Skip to content

Integrations

zenml.integrations special

ZenML integrations module.

The ZenML integrations module contains sub-modules for each integration that we support. This includes orchestrators like Apache Airflow, visualization tools like the facets library, as well as deep learning libraries like PyTorch.

airflow special

Airflow integration for ZenML.

The Airflow integration sub-module powers an alternative to the local orchestrator. You can enable it by registering the Airflow orchestrator with the CLI tool and then bootstrapping it using the `zenml orchestrator up` command.

AirflowIntegration (Integration)

Definition of Airflow Integration for ZenML.

Source code in zenml/integrations/airflow/__init__.py
class AirflowIntegration(Integration):
    """Definition of Airflow Integration for ZenML.

    Pins the Airflow package version and registers the Airflow orchestrator
    flavor with ZenML.
    """

    NAME = AIRFLOW
    REQUIREMENTS = ["apache-airflow==2.2.0"]

    @classmethod
    def flavors(cls) -> List[FlavorWrapper]:
        """Declare the stack component flavors for the Airflow integration.

        Returns:
            A single-element list holding the Airflow orchestrator flavor.
        """
        orchestrator_flavor = FlavorWrapper(
            name=AIRFLOW_ORCHESTRATOR_FLAVOR,
            source="zenml.integrations.airflow.orchestrators.AirflowOrchestrator",
            type=StackComponentType.ORCHESTRATOR,
            integration=cls.NAME,
        )
        return [orchestrator_flavor]
flavors() classmethod

Declare the stack component flavors for the Airflow integration.

Returns:

Type Description
List[zenml.zen_stores.models.flavor_wrapper.FlavorWrapper]

List of stack component flavors for this integration.

Source code in zenml/integrations/airflow/__init__.py
@classmethod
def flavors(cls) -> List[FlavorWrapper]:
    """Declare the stack component flavors for the Airflow integration.

    Returns:
        A single-element list holding the Airflow orchestrator flavor.
    """
    orchestrator_flavor = FlavorWrapper(
        name=AIRFLOW_ORCHESTRATOR_FLAVOR,
        source="zenml.integrations.airflow.orchestrators.AirflowOrchestrator",
        type=StackComponentType.ORCHESTRATOR,
        integration=cls.NAME,
    )
    return [orchestrator_flavor]

orchestrators special

The Airflow integration enables the use of Airflow as a pipeline orchestrator.

airflow_orchestrator

Implementation of Airflow orchestrator integration.

AirflowOrchestrator (BaseOrchestrator) pydantic-model

Orchestrator responsible for running pipelines using Airflow.

Source code in zenml/integrations/airflow/orchestrators/airflow_orchestrator.py
class AirflowOrchestrator(BaseOrchestrator):
    """Orchestrator responsible for running pipelines using Airflow.

    The orchestrator runs a local Airflow instance (via Airflow's
    `StandaloneCommand`) as a daemon. All of its state -- DAGs folder, PID
    file, log file and webserver password file -- lives below
    `airflow_home`, a directory derived from the orchestrator UUID.
    """

    # Replaced by the `set_airflow_home` root validator during validation.
    airflow_home: str = ""

    # Class Configuration
    FLAVOR: ClassVar[str] = AIRFLOW_ORCHESTRATOR_FLAVOR

    def __init__(self, **values: Any):
        """Sets environment variables to configure airflow.

        Args:
            **values: Values to set in the orchestrator.
        """
        # Base initialization runs validation, which fills in `airflow_home`
        # before `_set_env` exports it.
        super().__init__(**values)
        self._set_env()

    @staticmethod
    def _translate_schedule(
        schedule: Optional[Schedule] = None,
    ) -> Dict[str, Any]:
        """Convert ZenML schedule into Airflow schedule.

        The Airflow schedule uses slightly different naming and needs some
        default entries for execution without a schedule.

        Args:
            schedule: Containing the cron expression or interval, start and
                end date and a boolean flag that defines if past runs should
                be caught up on. If not given, the DAG is configured to run
                exactly once.

        Returns:
            Keyword arguments for the `airflow.DAG` constructor.
        """
        # A start time in the past combined with disabled catchup makes
        # Airflow run the DAG immediately.
        default_start_date = datetime.datetime.now() - datetime.timedelta(7)

        if not schedule:
            return {
                "schedule_interval": "@once",
                "start_date": default_start_date,
                "catchup": False,
            }

        if schedule.cron_expression:
            # Airflow cannot add tasks to a DAG that has no start date, so
            # fall back to the default when the schedule doesn't define one.
            return {
                "schedule_interval": schedule.cron_expression,
                "start_date": schedule.start_time or default_start_date,
                "end_date": schedule.end_time,
                "catchup": schedule.catchup,
            }

        return {
            "schedule_interval": schedule.interval_second,
            "start_date": schedule.start_time,
            "end_date": schedule.end_time,
            "catchup": schedule.catchup,
        }

    def prepare_or_run_pipeline(
        self,
        sorted_steps: List[BaseStep],
        pipeline: "BasePipeline",
        pb2_pipeline: Pb2Pipeline,
        stack: "Stack",
        runtime_configuration: "RuntimeConfiguration",
    ) -> Any:
        """Creates an Airflow DAG as the intermediate representation for the pipeline.

        This DAG will be loaded by airflow in the target environment
        and used for orchestration of the pipeline.

        How it works:
        -------------
        A new airflow_dag is instantiated with the pipeline name and, among
        other things, the run schedule.

        For each step of the pipeline a callable is created. This callable
        uses the run_step() method to execute the step. The parameters of
        this callable are pre-filled and an Airflow `PythonOperator` wrapping
        it is created within the dag. The dependencies to upstream steps are
        then configured.

        Finally, the dag is fully complete and can be returned.

        Args:
            sorted_steps: List of steps in the pipeline.
            pipeline: The pipeline to be executed.
            pb2_pipeline: The pipeline as a protobuf message.
            stack: The stack on which the pipeline will be deployed.
            runtime_configuration: The runtime configuration.

        Returns:
            The Airflow DAG.
        """
        import airflow
        from airflow.operators import python as airflow_python

        # Instantiate and configure airflow Dag with name and schedule
        airflow_dag = airflow.DAG(
            dag_id=pipeline.name,
            is_paused_upon_creation=False,
            **self._translate_schedule(runtime_configuration.schedule),
        )

        # Dictionary mapping step names to airflow_operators. This will be needed
        # to configure airflow operator dependencies
        step_name_to_airflow_operator: Dict[str, Any] = {}

        for step in sorted_steps:
            # Create callable that will be used by airflow to execute the step
            # within the orchestrated environment
            def _step_callable(
                step_instance: "BaseStep", **kwargs: Any
            ) -> None:
                # Only use `step_instance` in here: Airflow calls this after
                # the loop has finished, so the loop variable `step` would
                # refer to the last step for every task.
                if self.requires_resources_in_orchestration_environment(
                    step_instance
                ):
                    logger.warning(
                        "Specifying step resources is not yet supported for "
                        "the Airflow orchestrator, ignoring resource "
                        "configuration for step %s.",
                        step_instance.name,
                    )
                # The task instance (`ti`) passed in by Airflow gives access
                # to the DAG run, whose ID is used as the ZenML run name.
                run_name = kwargs["ti"].get_dagrun().run_id
                self.run_step(
                    step=step_instance,
                    run_name=run_name,
                    pb2_pipeline=pb2_pipeline,
                )

            # Create airflow python operator that contains the step callable
            airflow_operator = airflow_python.PythonOperator(
                dag=airflow_dag,
                task_id=step.name,
                provide_context=True,
                python_callable=functools.partial(
                    _step_callable, step_instance=step
                ),
            )

            # Configure the current airflow operator to run after all upstream
            # operators finished executing. `sorted_steps` is expected to be
            # ordered so that upstream operators already exist here.
            step_name_to_airflow_operator[step.name] = airflow_operator
            upstream_step_names = self.get_upstream_step_names(
                step=step, pb2_pipeline=pb2_pipeline
            )
            for upstream_step_name in upstream_step_names:
                airflow_operator.set_upstream(
                    step_name_to_airflow_operator[upstream_step_name]
                )

        # Return the finished airflow dag
        return airflow_dag

    @root_validator(skip_on_failure=True)
    def set_airflow_home(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Sets Airflow home according to orchestrator UUID.

        Args:
            values: Dictionary containing all orchestrator attributes values.

        Returns:
            Dictionary containing all orchestrator attributes values and the airflow home.

        Raises:
            ValueError: If the orchestrator UUID is not set.
        """
        if "uuid" not in values:
            raise ValueError("`uuid` needs to exist for AirflowOrchestrator.")
        values["airflow_home"] = os.path.join(
            io_utils.get_global_config_directory(),
            AIRFLOW_ROOT_DIR,
            str(values["uuid"]),
        )
        return values

    @property
    def dags_directory(self) -> str:
        """Returns path to the airflow dags directory.

        Returns:
            Path to the airflow dags directory.
        """
        return os.path.join(self.airflow_home, "dags")

    @property
    def pid_file(self) -> str:
        """Returns path to the daemon PID file.

        Returns:
            Path to the daemon PID file.
        """
        return os.path.join(self.airflow_home, "airflow_daemon.pid")

    @property
    def log_file(self) -> str:
        """Returns path to the airflow log file.

        Returns:
            str: Path to the airflow log file.
        """
        return os.path.join(self.airflow_home, "airflow_orchestrator.log")

    @property
    def password_file(self) -> str:
        """Returns path to the webserver password file.

        Returns:
            Path to the webserver password file.
        """
        return os.path.join(self.airflow_home, "standalone_admin_password.txt")

    def _set_env(self) -> None:
        """Sets environment variables to configure airflow."""
        os.environ["AIRFLOW_HOME"] = self.airflow_home
        os.environ["AIRFLOW__CORE__DAGS_FOLDER"] = self.dags_directory
        os.environ["AIRFLOW__CORE__DAG_DISCOVERY_SAFE_MODE"] = "false"
        os.environ["AIRFLOW__CORE__LOAD_EXAMPLES"] = "false"
        # check the DAG folder every 10 seconds for new files
        os.environ["AIRFLOW__SCHEDULER__DAG_DIR_LIST_INTERVAL"] = "10"

    def _copy_to_dag_directory_if_necessary(self, dag_filepath: str) -> None:
        """Copies DAG module to the Airflow DAGs directory if not already present.

        Args:
            dag_filepath: Path to the file in which the DAG is defined.
        """
        dags_directory = io_utils.resolve_relative_path(self.dags_directory)
        # Resolve the DAG path the same way as the DAGs directory so that a
        # relative path to a file already in that directory is recognized
        # instead of being copied onto itself.
        dag_filepath = io_utils.resolve_relative_path(dag_filepath)

        if dags_directory == os.path.dirname(dag_filepath):
            logger.debug("File is already in airflow DAGs directory.")
        else:
            logger.debug(
                "Copying dag file '%s' to DAGs directory.", dag_filepath
            )
            destination_path = os.path.join(
                dags_directory, os.path.basename(dag_filepath)
            )
            if fileio.exists(destination_path):
                logger.info(
                    "File '%s' already exists, overwriting with new DAG file",
                    destination_path,
                )
            fileio.copy(dag_filepath, destination_path, overwrite=True)

    def _log_webserver_credentials(self) -> None:
        """Logs URL and credentials to log in to the airflow webserver.

        Raises:
            FileNotFoundError: If the password file does not exist.
        """
        if fileio.exists(self.password_file):
            with open(self.password_file) as file:
                password = file.read().strip()
        else:
            raise FileNotFoundError(
                f"Can't find password file '{self.password_file}'"
            )
        logger.info(
            "To inspect your DAGs, login to http://0.0.0.0:8080 "
            "with username: admin password: %s",
            password,
        )

    def runtime_options(self) -> Dict[str, Any]:
        """Runtime options for the airflow orchestrator.

        Returns:
            Runtime options dictionary.
        """
        return {DAG_FILEPATH_OPTION_KEY: None}

    def prepare_pipeline_deployment(
        self,
        pipeline: "BasePipeline",
        stack: "Stack",
        runtime_configuration: "RuntimeConfiguration",
    ) -> None:
        """Checks Airflow is running and copies DAG file to the Airflow DAGs directory.

        Args:
            pipeline: Pipeline to be deployed.
            stack: Stack to be deployed.
            runtime_configuration: Runtime configuration for the pipeline.

        Raises:
            RuntimeError: If Airflow is not running, if called from within a
                notebook, or if no DAG filepath runtime option is provided.
        """
        if not self.is_running:
            raise RuntimeError(
                "Airflow orchestrator is currently not running. Run `zenml "
                "stack up` to provision resources for the active stack."
            )

        if Environment.in_notebook():
            raise RuntimeError(
                "Unable to run the Airflow orchestrator from within a "
                "notebook. Airflow requires a python file which contains a "
                "global Airflow DAG object and therefore does not work with "
                "notebooks. Please copy your ZenML pipeline code in a python "
                "file and try again."
            )

        try:
            dag_filepath = runtime_configuration[DAG_FILEPATH_OPTION_KEY]
        except KeyError:
            dag_filepath = None

        # `runtime_options` declares this option with a `None` default, so a
        # present-but-empty value must be rejected just like a missing key.
        if not dag_filepath:
            raise RuntimeError(
                f"No DAG filepath found in runtime configuration. Make sure "
                f"to add the filepath to your airflow DAG file as a runtime "
                f"option (key: '{DAG_FILEPATH_OPTION_KEY}')."
            )

        self._copy_to_dag_directory_if_necessary(dag_filepath=dag_filepath)

    @property
    def is_running(self) -> bool:
        """Returns whether the airflow daemon is currently running.

        Returns:
            True if the daemon is running, False otherwise.

        Raises:
            RuntimeError: If port 8080 is occupied.
        """
        from airflow.cli.commands.standalone_command import StandaloneCommand
        from airflow.jobs.triggerer_job import TriggererJob

        daemon_running = daemon.check_if_daemon_is_running(self.pid_file)

        command = StandaloneCommand()
        webserver_port_open = command.port_open(8080)

        if not daemon_running:
            if webserver_port_open:
                raise RuntimeError(
                    "The airflow daemon does not seem to be running but "
                    "local port 8080 is occupied. Make sure the port is "
                    "available and try again."
                )

            # exit early so we don't check non-existing airflow databases
            return False

        # we can't use StandaloneCommand().is_ready() here as the
        # Airflow SequentialExecutor apparently does not send a heartbeat
        # while running a task which would result in this returning `False`
        # even if Airflow is running.
        airflow_running = webserver_port_open and command.job_running(
            TriggererJob
        )
        return airflow_running

    @property
    def is_provisioned(self) -> bool:
        """Returns whether the airflow daemon is currently running.

        Returns:
            True if the airflow daemon is running, False otherwise.
        """
        return self.is_running

    def provision(self) -> None:
        """Ensures that Airflow is running."""
        if self.is_running:
            logger.info("Airflow is already running.")
            self._log_webserver_credentials()
            return

        if not fileio.exists(self.dags_directory):
            io_utils.create_dir_recursive_if_not_exists(self.dags_directory)

        from airflow.cli.commands.standalone_command import StandaloneCommand

        try:
            command = StandaloneCommand()
            # Run the daemon with a working directory inside the current
            # zenml repo so the same repo will be used to run the DAGs
            daemon.run_as_daemon(
                command.run,
                pid_file=self.pid_file,
                log_file=self.log_file,
                working_directory=get_source_root_path(),
            )
            while not self.is_running:
                # Wait until the daemon started all the relevant airflow
                # processes
                time.sleep(0.1)
            self._log_webserver_credentials()
        except Exception as e:
            logger.error(e)
            logger.error(
                "An error occurred while starting the Airflow daemon. If you "
                "want to start it manually, use the commands described in the "
                "official Airflow quickstart guide for running Airflow locally."
            )
            self.deprovision()

    def deprovision(self) -> None:
        """Stops the airflow daemon if necessary and tears down resources."""
        if self.is_running:
            daemon.stop_daemon(self.pid_file)

        # Only remove the Airflow home if it exists, e.g. the orchestrator may
        # never have been provisioned.
        if fileio.exists(self.airflow_home):
            fileio.rmtree(self.airflow_home)
        logger.info("Airflow spun down.")
dags_directory: str property readonly

Returns path to the airflow dags directory.

Returns:

Type Description
str

Path to the airflow dags directory.

is_provisioned: bool property readonly

Returns whether the airflow daemon is currently running.

Returns:

Type Description
bool

True if the airflow daemon is running, False otherwise.

is_running: bool property readonly

Returns whether the airflow daemon is currently running.

Returns:

Type Description
bool

True if the daemon is running, False otherwise.

Exceptions:

Type Description
RuntimeError

If port 8080 is occupied.

log_file: str property readonly

Returns path to the airflow log file.

Returns:

Type Description
str

Path to the airflow log file.

password_file: str property readonly

Returns path to the webserver password file.

Returns:

Type Description
str

Path to the webserver password file.

pid_file: str property readonly

Returns path to the daemon PID file.

Returns:

Type Description
str

Path to the daemon PID file.

__init__(self, **values) special

Sets environment variables to configure airflow.

Parameters:

Name Type Description Default
**values Any

Values to set in the orchestrator.

{}
Source code in zenml/integrations/airflow/orchestrators/airflow_orchestrator.py
def __init__(self, **values: Any):
    """Creates the orchestrator and exports its Airflow environment variables.

    Args:
        **values: Values to set in the orchestrator.
    """
    # Base initialization (including validation, which derives
    # `airflow_home`) must complete before the environment is exported.
    super().__init__(**values)
    self._set_env()
deprovision(self)

Stops the airflow daemon if necessary and tears down resources.

Source code in zenml/integrations/airflow/orchestrators/airflow_orchestrator.py
def deprovision(self) -> None:
    """Stops the airflow daemon if necessary and tears down resources."""
    if self.is_running:
        daemon.stop_daemon(self.pid_file)

    # Only remove the Airflow home if it exists, e.g. the orchestrator may
    # never have been provisioned.
    if fileio.exists(self.airflow_home):
        fileio.rmtree(self.airflow_home)
    logger.info("Airflow spun down.")
prepare_or_run_pipeline(self, sorted_steps, pipeline, pb2_pipeline, stack, runtime_configuration)

Creates an Airflow DAG as the intermediate representation for the pipeline.

This DAG will be loaded by airflow in the target environment and used for orchestration of the pipeline.

How it works:

A new airflow_dag is instantiated with the pipeline name and, among other things, the run schedule.

For each step of the pipeline, a callable is created. This callable uses the run_step() method to execute the step. The parameters of this callable are pre-filled, and an Airflow PythonOperator wrapping it is created within the DAG. The dependencies on upstream steps are then configured.

Finally, the dag is fully complete and can be returned.

Parameters:

Name Type Description Default
sorted_steps List[zenml.steps.base_step.BaseStep]

List of steps in the pipeline.

required
pipeline BasePipeline

The pipeline to be executed.

required
pb2_pipeline Pipeline

The pipeline as a protobuf message.

required
stack Stack

The stack on which the pipeline will be deployed.

required
runtime_configuration RuntimeConfiguration

The runtime configuration.

required

Returns:

Type Description
Any

The Airflow DAG.

Source code in zenml/integrations/airflow/orchestrators/airflow_orchestrator.py
def prepare_or_run_pipeline(
    self,
    sorted_steps: List[BaseStep],
    pipeline: "BasePipeline",
    pb2_pipeline: Pb2Pipeline,
    stack: "Stack",
    runtime_configuration: "RuntimeConfiguration",
) -> Any:
    """Creates an Airflow DAG as the intermediate representation for the pipeline.

    This DAG will be loaded by airflow in the target environment
    and used for orchestration of the pipeline.

    How it works:
    -------------
    A new airflow_dag is instantiated with the pipeline name and, among
    other things, the run schedule.

    For each step of the pipeline a callable is created. This callable
    uses the run_step() method to execute the step. The parameters of
    this callable are pre-filled and an Airflow `PythonOperator` wrapping
    it is created within the dag. The dependencies to upstream steps are
    then configured.

    Finally, the dag is fully complete and can be returned.

    Args:
        sorted_steps: List of steps in the pipeline.
        pipeline: The pipeline to be executed.
        pb2_pipeline: The pipeline as a protobuf message.
        stack: The stack on which the pipeline will be deployed.
        runtime_configuration: The runtime configuration.

    Returns:
        The Airflow DAG.
    """
    import airflow
    from airflow.operators import python as airflow_python

    # Instantiate and configure airflow Dag with name and schedule
    airflow_dag = airflow.DAG(
        dag_id=pipeline.name,
        is_paused_upon_creation=False,
        **self._translate_schedule(runtime_configuration.schedule),
    )

    # Dictionary mapping step names to airflow_operators. This will be needed
    # to configure airflow operator dependencies
    step_name_to_airflow_operator: Dict[str, Any] = {}

    for step in sorted_steps:
        # Create callable that will be used by airflow to execute the step
        # within the orchestrated environment
        def _step_callable(step_instance: "BaseStep", **kwargs: Any) -> None:
            # Only use `step_instance` in here: Airflow calls this after the
            # loop has finished, so the loop variable `step` would refer to
            # the last step for every task.
            if self.requires_resources_in_orchestration_environment(
                step_instance
            ):
                logger.warning(
                    "Specifying step resources is not yet supported for "
                    "the Airflow orchestrator, ignoring resource "
                    "configuration for step %s.",
                    step_instance.name,
                )
            # The task instance (`ti`) passed in by Airflow gives access to
            # the DAG run, whose ID is used as the ZenML run name.
            run_name = kwargs["ti"].get_dagrun().run_id
            self.run_step(
                step=step_instance,
                run_name=run_name,
                pb2_pipeline=pb2_pipeline,
            )

        # Create airflow python operator that contains the step callable
        airflow_operator = airflow_python.PythonOperator(
            dag=airflow_dag,
            task_id=step.name,
            provide_context=True,
            python_callable=functools.partial(
                _step_callable, step_instance=step
            ),
        )

        # Configure the current airflow operator to run after all upstream
        # operators finished executing. `sorted_steps` is expected to be
        # ordered so that upstream operators already exist here.
        step_name_to_airflow_operator[step.name] = airflow_operator
        upstream_step_names = self.get_upstream_step_names(
            step=step, pb2_pipeline=pb2_pipeline
        )
        for upstream_step_name in upstream_step_names:
            airflow_operator.set_upstream(
                step_name_to_airflow_operator[upstream_step_name]
            )

    # Return the finished airflow dag
    return airflow_dag
prepare_pipeline_deployment(self, pipeline, stack, runtime_configuration)

Checks Airflow is running and copies DAG file to the Airflow DAGs directory.

Parameters:

Name Type Description Default
pipeline BasePipeline

Pipeline to be deployed.

required
stack Stack

Stack to be deployed.

required
runtime_configuration RuntimeConfiguration

Runtime configuration for the pipeline.

required

Exceptions:

Type Description
RuntimeError

If Airflow is not running or no DAG filepath runtime option is provided.

Source code in zenml/integrations/airflow/orchestrators/airflow_orchestrator.py
def prepare_pipeline_deployment(
    self,
    pipeline: "BasePipeline",
    stack: "Stack",
    runtime_configuration: "RuntimeConfiguration",
) -> None:
    """Checks Airflow is running and copies DAG file to the Airflow DAGs directory.

    Args:
        pipeline: Pipeline to be deployed.
        stack: Stack to be deployed.
        runtime_configuration: Runtime configuration for the pipeline.

    Raises:
        RuntimeError: If Airflow is not running, if called from within a
            notebook, or if no DAG filepath runtime option is provided.
    """
    if not self.is_running:
        raise RuntimeError(
            "Airflow orchestrator is currently not running. Run `zenml "
            "stack up` to provision resources for the active stack."
        )

    if Environment.in_notebook():
        raise RuntimeError(
            "Unable to run the Airflow orchestrator from within a "
            "notebook. Airflow requires a python file which contains a "
            "global Airflow DAG object and therefore does not work with "
            "notebooks. Please copy your ZenML pipeline code in a python "
            "file and try again."
        )

    try:
        dag_filepath = runtime_configuration[DAG_FILEPATH_OPTION_KEY]
    except KeyError:
        dag_filepath = None

    # `runtime_options` declares this option with a `None` default, so a
    # present-but-empty value must be rejected just like a missing key.
    if not dag_filepath:
        raise RuntimeError(
            f"No DAG filepath found in runtime configuration. Make sure "
            f"to add the filepath to your airflow DAG file as a runtime "
            f"option (key: '{DAG_FILEPATH_OPTION_KEY}')."
        )

    self._copy_to_dag_directory_if_necessary(dag_filepath=dag_filepath)
provision(self)

Ensures that Airflow is running.

Source code in zenml/integrations/airflow/orchestrators/airflow_orchestrator.py
def provision(self) -> None:
    """Starts the local Airflow daemon unless it is already running.

    On any failure while starting up, the error is logged and the
    orchestrator is deprovisioned again.
    """
    if self.is_running:
        logger.info("Airflow is already running.")
        self._log_webserver_credentials()
        return

    if not fileio.exists(self.dags_directory):
        io_utils.create_dir_recursive_if_not_exists(self.dags_directory)

    from airflow.cli.commands.standalone_command import StandaloneCommand

    try:
        standalone = StandaloneCommand()
        # The daemon's working directory is the ZenML source root so the
        # DAGs it runs use the same repository.
        daemon.run_as_daemon(
            standalone.run,
            pid_file=self.pid_file,
            log_file=self.log_file,
            working_directory=get_source_root_path(),
        )
        # Poll until all Airflow processes checked by `is_running` are up.
        while True:
            if self.is_running:
                break
            time.sleep(0.1)
        self._log_webserver_credentials()
    except Exception as error:
        logger.error(error)
        logger.error(
            "An error occurred while starting the Airflow daemon. If you "
            "want to start it manually, use the commands described in the "
            "official Airflow quickstart guide for running Airflow locally."
        )
        self.deprovision()
runtime_options(self)

Runtime options for the airflow orchestrator.

Returns:

Type Description
Dict[str, Any]

Runtime options dictionary.

Source code in zenml/integrations/airflow/orchestrators/airflow_orchestrator.py
def runtime_options(self) -> Dict[str, Any]:
    """Declares the runtime options understood by the Airflow orchestrator.

    Returns:
        A dictionary with the DAG filepath option key mapped to `None`.
    """
    return dict.fromkeys([DAG_FILEPATH_OPTION_KEY])
set_airflow_home(values) classmethod

Sets Airflow home according to orchestrator UUID.

Parameters:

Name Type Description Default
values Dict[str, Any]

Dictionary containing all orchestrator attributes values.

required

Returns:

Type Description
Dict[str, Any]

Dictionary containing all orchestrator attributes values and the airflow home.

Exceptions:

Type Description
ValueError

If the orchestrator UUID is not set.

Source code in zenml/integrations/airflow/orchestrators/airflow_orchestrator.py
@root_validator(skip_on_failure=True)
def set_airflow_home(cls, values: Dict[str, Any]) -> Dict[str, Any]:
    """Derives the Airflow home directory from the orchestrator UUID.

    The directory is `<global config dir>/<AIRFLOW_ROOT_DIR>/<uuid>`, so each
    orchestrator instance gets a separate Airflow home.

    Args:
        values: Dictionary containing all orchestrator attributes values.

    Returns:
        The same dictionary, with `airflow_home` set.

    Raises:
        ValueError: If the orchestrator UUID is not set.
    """
    if "uuid" not in values:
        raise ValueError("`uuid` needs to exist for AirflowOrchestrator.")

    home_parts = (
        io_utils.get_global_config_directory(),
        AIRFLOW_ROOT_DIR,
        str(values["uuid"]),
    )
    values["airflow_home"] = os.path.join(*home_parts)
    return values

aws special

Integrates multiple AWS Tools as Stack Components.

The AWS integration provides a way for our users to manage their secrets through AWS and to use the AWS container registry. Additionally, the SageMaker integration submodule provides a way to run ZenML steps in SageMaker.

AWSIntegration (Integration)

Definition of AWS integration for ZenML.

Source code in zenml/integrations/aws/__init__.py
class AWSIntegration(Integration):
    """Definition of AWS integration for ZenML.

    Registers flavors for the AWS secrets manager, the AWS container registry
    and the SageMaker step operator.
    """

    NAME = AWS
    REQUIREMENTS = ["boto3==1.21.0", "sagemaker==2.82.2"]

    @classmethod
    def flavors(cls) -> List[FlavorWrapper]:
        """Declare the stack component flavors for the AWS integration.

        Returns:
            List of stack component flavors for this integration.
        """
        # (flavor name, import path of the implementation, component type)
        flavor_specs = [
            (
                AWS_SECRET_MANAGER_FLAVOR,
                "zenml.integrations.aws.secrets_managers.AWSSecretsManager",
                StackComponentType.SECRETS_MANAGER,
            ),
            (
                AWS_CONTAINER_REGISTRY_FLAVOR,
                "zenml.integrations.aws.container_registries.AWSContainerRegistry",
                StackComponentType.CONTAINER_REGISTRY,
            ),
            (
                AWS_SAGEMAKER_STEP_OPERATOR_FLAVOR,
                "zenml.integrations.aws.step_operators.SagemakerStepOperator",
                StackComponentType.STEP_OPERATOR,
            ),
        ]
        return [
            FlavorWrapper(
                name=flavor_name,
                source=source_path,
                type=component_type,
                integration=cls.NAME,
            )
            for flavor_name, source_path, component_type in flavor_specs
        ]
flavors() classmethod

Declare the stack component flavors for the AWS integration.

Returns:

Type Description
List[zenml.zen_stores.models.flavor_wrapper.FlavorWrapper]

List of stack component flavors for this integration.

Source code in zenml/integrations/aws/__init__.py
@classmethod
def flavors(cls) -> List[FlavorWrapper]:
    """Declare the stack component flavors for the AWS integration.

    Returns:
        List of stack component flavors for this integration.
    """
    # (flavor name, import path of the implementation, component type)
    flavor_specs = [
        (
            AWS_SECRET_MANAGER_FLAVOR,
            "zenml.integrations.aws.secrets_managers.AWSSecretsManager",
            StackComponentType.SECRETS_MANAGER,
        ),
        (
            AWS_CONTAINER_REGISTRY_FLAVOR,
            "zenml.integrations.aws.container_registries.AWSContainerRegistry",
            StackComponentType.CONTAINER_REGISTRY,
        ),
        (
            AWS_SAGEMAKER_STEP_OPERATOR_FLAVOR,
            "zenml.integrations.aws.step_operators.SagemakerStepOperator",
            StackComponentType.STEP_OPERATOR,
        ),
    ]
    return [
        FlavorWrapper(
            name=flavor_name,
            source=source_path,
            type=component_type,
            integration=cls.NAME,
        )
        for flavor_name, source_path, component_type in flavor_specs
    ]

container_registries special

Initialization of AWS Container Registry integration.

aws_container_registry

Implementation of the AWS container registry integration.

AWSContainerRegistry (BaseContainerRegistry) pydantic-model

Class for AWS Container Registry.

Source code in zenml/integrations/aws/container_registries/aws_container_registry.py
class AWSContainerRegistry(BaseContainerRegistry):
    """Class for AWS Container Registry."""

    # Class Configuration
    FLAVOR: ClassVar[str] = AWS_CONTAINER_REGISTRY_FLAVOR

    @validator("uri")
    def validate_aws_uri(cls, uri: str) -> str:
        """Validates that the URI is in the correct format.

        The URI must be the bare registry host without any repository path;
        repository names are appended to it (``<uri>/<repository>:<tag>``)
        when image names are built.

        Args:
            uri: URI to validate.

        Returns:
            URI in the correct format.

        Raises:
            ValueError: If the URI contains a slash character.
        """
        if "/" in uri:
            raise ValueError(
                "Property `uri` can not contain a `/`. An example of a valid "
                "URI is: `715803424592.dkr.ecr.us-east-1.amazonaws.com`"
            )

        return uri

    def _get_region(self) -> str:
        """Parses the AWS region from the registry URI.

        Expects URIs of the form ``<account>.dkr.ecr.<region>.amazonaws.com``.

        Returns:
            The region string.

        Raises:
            RuntimeError: If the region parsing fails due to an invalid URI.
        """
        match = re.fullmatch(r".*\.dkr\.ecr\.(.*)\.amazonaws\.com", self.uri)
        if not match:
            raise RuntimeError(
                f"Unable to parse region from ECR URI {self.uri}."
            )

        return match.group(1)

    def prepare_image_push(self, image_name: str) -> None:
        """Logs a warning message if trying to push an image for which no repository exists.

        A missing repository only triggers a warning; failures while querying
        ECR are logged at debug level and otherwise ignored.

        Args:
            image_name: Name of the docker image that will be pushed.

        Raises:
            ValueError: If the docker image name is invalid.
        """
        client = boto3.client("ecr", region_name=self._get_region())
        try:
            # The ECR call itself can raise ClientError, so it must be inside
            # the try. describe_repositories is paginated; walk every page so
            # registries with many repositories don't get spurious warnings.
            paginator = client.get_paginator("describe_repositories")
            repo_uris: List[str] = [
                repository["repositoryUri"]
                for page in paginator.paginate()
                for repository in page["repositories"]
            ]
        except (KeyError, ClientError) as e:
            # invalid boto response or failed ECR API call, let's hope for the
            # best and just push
            logger.debug("Error while trying to fetch ECR repositories: %s", e)
            return

        repo_exists = any(image_name.startswith(f"{uri}:") for uri in repo_uris)
        if not repo_exists:
            # Escape the registry URI: its `.` characters would otherwise
            # match any character in the pattern.
            match = re.search(f"{re.escape(self.uri)}/(.*):.*", image_name)
            if not match:
                raise ValueError(f"Invalid docker image name '{image_name}'.")

            repo_name = match.group(1)
            logger.warning(
                "Amazon ECR requires you to create a repository before you can "
                f"push an image to it. ZenML is trying to push the image "
                f"{image_name} but could only detect the following "
                f"repositories: {repo_uris}. We will try to push anyway, but "
                f"in case it fails you need to create a repository named "
                f"`{repo_name}`."
            )

    @property
    def post_registration_message(self) -> Optional[str]:
        """Optional message printed after the stack component is registered.

        Returns:
            Info message regarding docker repositories in AWS.
        """
        return (
            "Amazon ECR requires you to create a repository before you can "
            "push an image to it. If you want to for example run a pipeline "
            "using our Kubeflow orchestrator, ZenML will automatically build a "
            f"docker image called `{self.uri}/zenml-kubeflow:<PIPELINE_NAME>` "
            f"and try to push it. This will fail unless you create the "
            f"repository `zenml-kubeflow` inside your amazon registry."
        )
post_registration_message: Optional[str] property readonly

Optional message printed after the stack component is registered.

Returns:

Type Description
Optional[str]

Info message regarding docker repositories in AWS.

prepare_image_push(self, image_name)

Logs a warning message if trying to push an image for which no repository exists.

Parameters:

Name Type Description Default
image_name str

Name of the docker image that will be pushed.

required

Exceptions:

Type Description
ValueError

If the docker image name is invalid.

Source code in zenml/integrations/aws/container_registries/aws_container_registry.py
def prepare_image_push(self, image_name: str) -> None:
    """Logs a warning message if trying to push an image for which no repository exists.

    A missing repository only triggers a warning; failures while querying
    ECR are logged at debug level and otherwise ignored.

    Args:
        image_name: Name of the docker image that will be pushed.

    Raises:
        ValueError: If the docker image name is invalid.
    """
    client = boto3.client("ecr", region_name=self._get_region())
    try:
        # The ECR call itself can raise ClientError, so it must be inside the
        # try. describe_repositories is paginated; walk every page so
        # registries with many repositories don't get spurious warnings.
        paginator = client.get_paginator("describe_repositories")
        repo_uris: List[str] = [
            repository["repositoryUri"]
            for page in paginator.paginate()
            for repository in page["repositories"]
        ]
    except (KeyError, ClientError) as e:
        # invalid boto response or failed ECR API call, let's hope for the
        # best and just push
        logger.debug("Error while trying to fetch ECR repositories: %s", e)
        return

    repo_exists = any(image_name.startswith(f"{uri}:") for uri in repo_uris)
    if not repo_exists:
        # Escape the registry URI: its `.` characters would otherwise match
        # any character in the pattern.
        match = re.search(f"{re.escape(self.uri)}/(.*):.*", image_name)
        if not match:
            raise ValueError(f"Invalid docker image name '{image_name}'.")

        repo_name = match.group(1)
        logger.warning(
            "Amazon ECR requires you to create a repository before you can "
            f"push an image to it. ZenML is trying to push the image "
            f"{image_name} but could only detect the following "
            f"repositories: {repo_uris}. We will try to push anyway, but "
            f"in case it fails you need to create a repository named "
            f"`{repo_name}`."
        )
validate_aws_uri(uri) classmethod

Validates that the URI is in the correct format.

Parameters:

Name Type Description Default
uri str

URI to validate.

required

Returns:

Type Description
str

URI in the correct format.

Exceptions:

Type Description
ValueError

If the URI contains a slash character.

Source code in zenml/integrations/aws/container_registries/aws_container_registry.py
@validator("uri")
def validate_aws_uri(cls, uri: str) -> str:
    """Reject registry URIs that include a path component.

    Args:
        uri: URI to validate.

    Returns:
        The URI unchanged when it contains no `/`.

    Raises:
        ValueError: If the URI contains a slash character.
    """
    if "/" not in uri:
        return uri

    raise ValueError(
        "Property `uri` can not contain a `/`. An example of a valid "
        "URI is: `715803424592.dkr.ecr.us-east-1.amazonaws.com`"
    )

secrets_managers special

AWS Secrets Manager.

aws_secrets_manager

Implementation of the AWS Secrets Manager integration.

AWSSecretsManager (BaseSecretsManager) pydantic-model

Class to interact with the AWS secrets manager.

Source code in zenml/integrations/aws/secrets_managers/aws_secrets_manager.py
class AWSSecretsManager(BaseSecretsManager):
    """Class to interact with the AWS secrets manager."""

    region_name: str

    # Class configuration
    FLAVOR: ClassVar[str] = AWS_SECRET_MANAGER_FLAVOR
    SUPPORTS_SCOPING: ClassVar[bool] = True
    CLIENT: ClassVar[Any] = None

    @classmethod
    def _validate_scope(
        cls,
        scope: SecretsManagerScope,
        namespace: Optional[str],
    ) -> None:
        """Validate the scope and namespace value.

        Only the namespace (when given) is validated; the scope value itself
        is not inspected here.

        Args:
            scope: Scope value.
            namespace: Optional namespace value.
        """
        if namespace:
            cls.validate_secret_name_or_namespace(namespace)

    @classmethod
    def _ensure_client_connected(cls, region_name: str) -> None:
        """Ensure that the client is connected to the AWS secrets manager.

        The boto3 client is created lazily on first use and cached on the
        class, so it is shared by all instances.

        Args:
            region_name: the AWS region name
        """
        # NOTE(review): once cached, the client is reused even if a later call
        # passes a different region_name — confirm that secrets managers in
        # different regions are not expected to coexist in one process.
        if cls.CLIENT is None:
            # Create a Secrets Manager client
            session = boto3.session.Session()
            cls.CLIENT = session.client(
                service_name="secretsmanager", region_name=region_name
            )

    @classmethod
    def validate_secret_name_or_namespace(cls, name: str) -> None:
        """Validate a secret name or namespace.

        AWS secret names may contain alphanumeric characters and the
        characters /_+=.@-. ZenML uses the `/` character internally to delimit
        scopes, so it is rejected here.

        Args:
            name: the secret name or namespace

        Raises:
            ValueError: if the secret name or namespace is invalid
        """
        if not re.fullmatch(r"[a-zA-Z0-9_+=\.@\-]*", name):
            raise ValueError(
                f"Invalid secret name or namespace '{name}'. Must contain "
                f"only alphanumeric characters and the characters _+=.@-."
            )

    def _get_secret_tags(
        self, secret: BaseSecretSchema
    ) -> List[Dict[str, str]]:
        """Return a list of AWS secret tag values for a given secret.

        Args:
            secret: the secret object

        Returns:
            A list of AWS secret tag values
        """
        metadata = self._get_secret_metadata(secret)
        return [{"Key": k, "Value": v} for k, v in metadata.items()]

    def _get_secret_scope_filters(
        self,
        secret_name: Optional[str] = None,
    ) -> List[Dict[str, Any]]:
        """Return a list of AWS filters for the entire scope or just a scoped secret.

        These filters can be used when querying the AWS Secrets Manager
        for all secrets or for a single secret available in the configured
        scope. For more information see: https://docs.aws.amazon.com/secretsmanager/latest/userguide/manage_search-secret.html

        Each metadata key/value pair becomes two filters: one on the tag key
        and one on the (stringified) tag value.

        Example AWS filters for all secrets in the current (namespace) scope:

        ```python
        [
            {
                "Key": "tag-key",
                "Values": ["zenml_scope"],
            },
            {
                "Key": "tag-value",
                "Values": ["namespace"],
            },
            {
                "Key": "tag-key",
                "Values": ["zenml_namespace"],
            },
            {
                "Key": "tag-value",
                "Values": ["my_namespace"],
            },
        ]
        ```

        Example AWS filters for a particular secret in the current (namespace)
        scope:

        ```python
        [
            {
                "Key": "tag-key",
                "Values": ["zenml_secret_name"],
            },
            {
                "Key": "tag-value",
                "Values": ["my_secret"],
            },
            {
                "Key": "tag-key",
                "Values": ["zenml_scope"],
            },
            {
                "Key": "tag-value",
                "Values": ["namespace"],
            },
            {
                "Key": "tag-key",
                "Values": ["zenml_namespace"],
            },
            {
                "Key": "tag-value",
                "Values": ["my_namespace"],
            },
        ]
        ```

        Args:
            secret_name: Optional secret name to filter for.

        Returns:
            A list of AWS filters uniquely identifying all secrets
            or a named secret within the configured scope.
        """
        metadata = self._get_secret_scope_metadata(secret_name)
        filters: List[Dict[str, Any]] = []
        for k, v in metadata.items():
            filters.append({"Key": "tag-key", "Values": [k]})
            filters.append({"Key": "tag-value", "Values": [str(v)]})

        return filters

    def _list_secrets(self, secret_name: Optional[str] = None) -> List[str]:
        """List all secrets matching a name.

        This method lists all the secrets in the current scope without loading
        their contents. An optional secret name can be supplied to filter out
        all but a single secret identified by name. All result pages returned
        by AWS are fetched.

        Args:
            secret_name: Optional secret name to filter for.

        Returns:
            A list of secret names in the current scope and the optional
            secret name.
        """
        self._ensure_client_connected(self.region_name)

        filters: List[Dict[str, Any]] = []
        prefix: Optional[str] = None
        if self.scope == SecretsManagerScope.NONE:
            # unscoped (legacy) secrets don't have tags. We want to filter out
            # non-legacy secrets
            filters = [
                {
                    "Key": "tag-key",
                    "Values": [
                        "!zenml_scope",
                    ],
                },
            ]
            if secret_name:
                prefix = secret_name
        else:
            filters = self._get_secret_scope_filters()
            if secret_name:
                prefix = self._get_scoped_secret_name(secret_name)
            else:
                # add the name prefix to the filters to account for the fact
                # that AWS does not do exact matching but prefix-matching on the
                # filters
                prefix = self._get_scoped_secret_name_prefix()

        if prefix:
            filters.append(
                {
                    "Key": "name",
                    "Values": [
                        f"{prefix}",
                    ],
                }
            )

        # TODO [ENG-721]: take out this magic maxresults number
        list_kwargs: Dict[str, Any] = {"MaxResults": 100, "Filters": filters}
        results: List[str] = []
        while True:
            response = self.CLIENT.list_secrets(**list_kwargs)
            for secret in response["SecretList"]:
                name = self._get_unscoped_secret_name(secret["Name"])
                # keep only the names that are in scope and filter by secret
                # name, if one was given
                if name and (not secret_name or secret_name == name):
                    results.append(name)

            # AWS returns at most MaxResults entries per call; keep following
            # NextToken so secrets beyond the first page are not dropped.
            next_token = response.get("NextToken")
            if not next_token:
                break
            list_kwargs["NextToken"] = next_token

        return results

    def register_secret(self, secret: BaseSecretSchema) -> None:
        """Registers a new secret.

        Args:
            secret: the secret to register

        Raises:
            SecretExistsError: if the secret already exists
        """
        self.validate_secret_name_or_namespace(secret.name)
        self._ensure_client_connected(self.region_name)

        if self._list_secrets(secret.name):
            raise SecretExistsError(
                f"A Secret with the name {secret.name} already exists"
            )

        secret_value = json.dumps(secret_to_dict(secret, encode=False))
        kwargs: Dict[str, Any] = {
            "Name": self._get_scoped_secret_name(secret.name),
            "SecretString": secret_value,
            "Tags": self._get_secret_tags(secret),
        }

        self.CLIENT.create_secret(**kwargs)

        logger.debug("Created AWS secret: %s", kwargs["Name"])

    def get_secret(self, secret_name: str) -> BaseSecretSchema:
        """Gets a secret.

        Args:
            secret_name: the name of the secret to get

        Returns:
            The secret.

        Raises:
            KeyError: if the secret does not exist
            RuntimeError: if the stored secret has no string value and
                therefore cannot be decoded
        """
        self.validate_secret_name_or_namespace(secret_name)
        self._ensure_client_connected(self.region_name)

        if not self._list_secrets(secret_name):
            raise KeyError(f"Can't find the specified secret '{secret_name}'")

        get_secret_value_response = self.CLIENT.get_secret_value(
            SecretId=self._get_scoped_secret_name(secret_name)
        )
        if "SecretString" not in get_secret_value_response:
            # register_secret/update_secret always store a JSON string; a
            # missing SecretString presumably means the value was written
            # outside ZenML. Fail clearly instead of subscripting None.
            raise RuntimeError(
                f"Secret '{secret_name}' has no string value (`SecretString`) "
                f"and cannot be loaded."
            )

        return secret_from_dict(
            json.loads(get_secret_value_response["SecretString"]),
            secret_name=secret_name,
            decode=False,
        )

    def get_all_secret_keys(self) -> List[str]:
        """Get all secret keys.

        Returns:
            A list of all secret keys
        """
        return self._list_secrets()

    def update_secret(self, secret: BaseSecretSchema) -> None:
        """Update an existing secret.

        The value is serialized with ``encode=False``, matching
        `register_secret`, so that `get_secret` (which reads with
        ``decode=False``) returns the values unchanged.

        Args:
            secret: the secret to update

        Raises:
            KeyError: if the secret does not exist
        """
        self.validate_secret_name_or_namespace(secret.name)
        self._ensure_client_connected(self.region_name)

        if not self._list_secrets(secret.name):
            raise KeyError(f"Can't find the specified secret '{secret.name}'")

        secret_value = json.dumps(secret_to_dict(secret, encode=False))

        kwargs = {
            "SecretId": self._get_scoped_secret_name(secret.name),
            "SecretString": secret_value,
        }

        self.CLIENT.put_secret_value(**kwargs)

    def delete_secret(self, secret_name: str) -> None:
        """Delete an existing secret.

        The secret is deleted with ``ForceDeleteWithoutRecovery=True``.

        Args:
            secret_name: the name of the secret to delete

        Raises:
            KeyError: if the secret does not exist
        """
        self._ensure_client_connected(self.region_name)

        if not self._list_secrets(secret_name):
            raise KeyError(f"Can't find the specified secret '{secret_name}'")

        self.CLIENT.delete_secret(
            SecretId=self._get_scoped_secret_name(secret_name),
            ForceDeleteWithoutRecovery=True,
        )

    def delete_all_secrets(self) -> None:
        """Delete all existing secrets.

        This method will force delete all your secrets. You will not be able to
        recover them once this method is called.
        """
        self._ensure_client_connected(self.region_name)
        for secret_name in self._list_secrets():
            self.CLIENT.delete_secret(
                SecretId=self._get_scoped_secret_name(secret_name),
                ForceDeleteWithoutRecovery=True,
            )
delete_all_secrets(self)

Delete all existing secrets.

This method will force delete all your secrets. You will not be able to recover them once this method is called.

Source code in zenml/integrations/aws/secrets_managers/aws_secrets_manager.py
def delete_all_secrets(self) -> None:
    """Force-delete every secret listed in the current scope.

    Deletion passes ``ForceDeleteWithoutRecovery=True``, so removed secrets
    cannot be recovered afterwards.
    """
    self._ensure_client_connected(self.region_name)
    doomed_names = self._list_secrets()
    for name in doomed_names:
        scoped_id = self._get_scoped_secret_name(name)
        self.CLIENT.delete_secret(
            SecretId=scoped_id,
            ForceDeleteWithoutRecovery=True,
        )
delete_secret(self, secret_name)

Delete an existing secret.

Parameters:

Name Type Description Default
secret_name str

the name of the secret to delete

required

Exceptions:

Type Description
KeyError

if the secret does not exist

Source code in zenml/integrations/aws/secrets_managers/aws_secrets_manager.py
def delete_secret(self, secret_name: str) -> None:
    """Permanently remove one secret from the current scope.

    Args:
        secret_name: the name of the secret to delete

    Raises:
        KeyError: if the secret does not exist
    """
    self._ensure_client_connected(self.region_name)

    matches = self._list_secrets(secret_name)
    if not matches:
        raise KeyError(f"Can't find the specified secret '{secret_name}'")

    target_id = self._get_scoped_secret_name(secret_name)
    self.CLIENT.delete_secret(
        SecretId=target_id,
        ForceDeleteWithoutRecovery=True,
    )
get_all_secret_keys(self)

Get all secret keys.

Returns:

Type Description
List[str]

A list of all secret keys

Source code in zenml/integrations/aws/secrets_managers/aws_secrets_manager.py
def get_all_secret_keys(self) -> List[str]:
    """Return the names of every secret visible in the current scope.

    Returns:
        A list of all secret keys
    """
    secret_names = self._list_secrets()
    return secret_names
get_secret(self, secret_name)

Gets a secret.

Parameters:

Name Type Description Default
secret_name str

the name of the secret to get

required

Returns:

Type Description
BaseSecretSchema

The secret.

Exceptions:

Type Description
KeyError

if the secret does not exist

Source code in zenml/integrations/aws/secrets_managers/aws_secrets_manager.py
def get_secret(self, secret_name: str) -> BaseSecretSchema:
    """Gets a secret.

    Args:
        secret_name: the name of the secret to get

    Returns:
        The secret.

    Raises:
        KeyError: if the secret does not exist
        RuntimeError: if the stored secret has no string value and therefore
            cannot be decoded
    """
    self.validate_secret_name_or_namespace(secret_name)
    self._ensure_client_connected(self.region_name)

    if not self._list_secrets(secret_name):
        raise KeyError(f"Can't find the specified secret '{secret_name}'")

    get_secret_value_response = self.CLIENT.get_secret_value(
        SecretId=self._get_scoped_secret_name(secret_name)
    )
    if "SecretString" not in get_secret_value_response:
        # ZenML always stores secrets as a JSON string; a missing SecretString
        # presumably means the value was written outside ZenML. Fail clearly
        # instead of subscripting None (which raised an opaque TypeError).
        raise RuntimeError(
            f"Secret '{secret_name}' has no string value (`SecretString`) "
            f"and cannot be loaded."
        )

    return secret_from_dict(
        json.loads(get_secret_value_response["SecretString"]),
        secret_name=secret_name,
        decode=False,
    )
register_secret(self, secret)

Registers a new secret.

Parameters:

Name Type Description Default
secret BaseSecretSchema

the secret to register

required

Exceptions:

Type Description
SecretExistsError

if the secret already exists

Source code in zenml/integrations/aws/secrets_managers/aws_secrets_manager.py
def register_secret(self, secret: BaseSecretSchema) -> None:
    """Create a new secret in AWS Secrets Manager under the current scope.

    Args:
        secret: the secret to register

    Raises:
        SecretExistsError: if the secret already exists
    """
    self.validate_secret_name_or_namespace(secret.name)
    self._ensure_client_connected(self.region_name)

    already_present = self._list_secrets(secret.name)
    if already_present:
        raise SecretExistsError(
            f"A Secret with the name {secret.name} already exists"
        )

    payload = json.dumps(secret_to_dict(secret, encode=False))
    create_args: Dict[str, Any] = dict(
        Name=self._get_scoped_secret_name(secret.name),
        SecretString=payload,
        Tags=self._get_secret_tags(secret),
    )
    self.CLIENT.create_secret(**create_args)

    logger.debug("Created AWS secret: %s", create_args["Name"])
update_secret(self, secret)

Update an existing secret.

Parameters:

Name Type Description Default
secret BaseSecretSchema

the secret to update

required

Exceptions:

Type Description
KeyError

if the secret does not exist

Source code in zenml/integrations/aws/secrets_managers/aws_secrets_manager.py
def update_secret(self, secret: BaseSecretSchema) -> None:
    """Update an existing secret.

    The value is serialized with ``encode=False``, matching `register_secret`,
    so that `get_secret` (which reads with ``decode=False``) returns the
    values unchanged instead of in encoded form.

    Args:
        secret: the secret to update

    Raises:
        KeyError: if the secret does not exist
    """
    self.validate_secret_name_or_namespace(secret.name)
    self._ensure_client_connected(self.region_name)

    if not self._list_secrets(secret.name):
        raise KeyError(f"Can't find the specified secret '{secret.name}'")

    secret_value = json.dumps(secret_to_dict(secret, encode=False))

    kwargs = {
        "SecretId": self._get_scoped_secret_name(secret.name),
        "SecretString": secret_value,
    }

    self.CLIENT.put_secret_value(**kwargs)
validate_secret_name_or_namespace(name) classmethod

Validate a secret name or namespace.

AWS secret names may contain only alphanumeric characters and the characters /_+=.@-. Because ZenML uses the / character internally to delimit scopes, it is not allowed in secret names or namespaces.

Parameters:

Name Type Description Default
name str

the secret name or namespace

required

Exceptions:

Type Description
ValueError

if the secret name or namespace is invalid

Source code in zenml/integrations/aws/secrets_managers/aws_secrets_manager.py
@classmethod
def validate_secret_name_or_namespace(cls, name: str) -> None:
    """Check that a secret name or namespace uses only permitted characters.

    AWS itself also permits `/` in secret names, but ZenML reserves that
    character for delimiting scopes, so only alphanumerics and _+=.@- pass.

    Args:
        name: the secret name or namespace

    Raises:
        ValueError: if the secret name or namespace is invalid
    """
    allowed_pattern = r"[a-zA-Z0-9_+=\.@\-]*"
    if re.fullmatch(allowed_pattern, name) is None:
        raise ValueError(
            f"Invalid secret name or namespace '{name}'. Must contain "
            f"only alphanumeric characters and the characters _+=.@-."
        )

step_operators special

Initialization of the Sagemaker Step Operator.

sagemaker_step_operator

Implementation of the Sagemaker Step Operator.

SagemakerStepOperator (BaseStepOperator) pydantic-model

Step operator to run a step on Sagemaker.

This class defines code that builds an image with the ZenML entrypoint to run using Sagemaker's Estimator.

Attributes:

Name Type Description
role str

The role that has to be assigned to the jobs which are running in Sagemaker.

instance_type str

The type of the compute instance where jobs will run.

base_image Optional[str]

The base image to use for building the docker image that will be executed.

bucket Optional[str]

Name of the S3 bucket to use for storing artifacts from the job run. If not provided, a default bucket will be created based on the following format: "sagemaker-{region}-{aws-account-id}".

experiment_name Optional[str]

The name for the experiment to which the job will be associated. If not provided, the job runs would be independent.

Source code in zenml/integrations/aws/step_operators/sagemaker_step_operator.py
class SagemakerStepOperator(BaseStepOperator):
    """Step operator to run a step on Sagemaker.

    The step is packaged into a docker image containing the ZenML
    entrypoint, pushed to the active stack's container registry, and run
    through Sagemaker's Estimator.

    Attributes:
        role: The role that has to be assigned to the jobs which are
            running in Sagemaker.
        instance_type: The type of the compute instance where jobs will run.
        base_image: The base image to use for building the docker
            image that will be executed.
        bucket: Name of the S3 bucket to use for storing artifacts
            from the job run. If not provided, a default bucket will be created
            based on the following format: "sagemaker-{region}-{aws-account-id}".
        experiment_name: The name for the experiment to which the job
            will be associated. If not provided, the job runs would be
            independent.
    """

    role: str
    instance_type: str

    base_image: Optional[str] = None
    bucket: Optional[str] = None
    experiment_name: Optional[str] = None

    # Class Configuration
    FLAVOR: ClassVar[str] = AWS_SAGEMAKER_STEP_OPERATOR_FLAVOR

    @property
    def validator(self) -> Optional[StackValidator]:
        """Stack requirements for this step operator.

        The stack must contain a container registry and must use the local
        orchestrator.

        Returns:
            A validator enforcing both requirements.
        """

        def _ensure_local_orchestrator(stack: Stack) -> Tuple[bool, str]:
            uses_local = stack.orchestrator.FLAVOR == "local"
            return uses_local, "Local orchestrator is required"

        return StackValidator(
            required_components={StackComponentType.CONTAINER_REGISTRY},
            custom_validation_function=_ensure_local_orchestrator,
        )

    def _build_docker_image(
        self,
        pipeline_name: str,
        requirements: List[str],
        entrypoint_command: List[str],
    ) -> str:
        """Build the step image and push it to the active container registry.

        Args:
            pipeline_name: Name of the pipeline; used as the image tag.
            requirements: pip requirements to install into the image.
            entrypoint_command: Command tokens, joined with spaces to form the
                image entrypoint.

        Returns:
            The image digest reported by `docker_utils.get_image_digest`, or
            the image name if no digest is returned.

        Raises:
            RuntimeError: If the active stack has no container registry.
        """
        registry = Repository().active_stack.container_registry
        if not registry:
            raise RuntimeError("Missing container registry")

        base_uri = registry.uri.rstrip("/")
        image_name = f"{base_uri}/zenml-sagemaker:{pipeline_name}"

        docker_utils.build_docker_image(
            build_context_path=get_source_root_path(),
            image_name=image_name,
            entrypoint=" ".join(entrypoint_command),
            requirements=set(requirements),
            base_image=self.base_image,
        )
        registry.push_image(image_name)
        digest = docker_utils.get_image_digest(image_name)
        return digest or image_name

    def launch(
        self,
        pipeline_name: str,
        run_name: str,
        requirements: List[str],
        entrypoint_command: List[str],
        resource_configuration: "ResourceConfiguration",
    ) -> None:
        """Run one step as a Sagemaker job and wait for it to finish.

        Args:
            pipeline_name: Name of the pipeline which the step to be executed
                is part of.
            run_name: Name of the pipeline run which the step to be executed
                is part of.
            entrypoint_command: Command that executes the step.
            requirements: List of pip requirements that must be installed
                inside the step operator environment.
            resource_configuration: The resource configuration for this step;
                only a warning is logged if it is non-empty.
        """
        image_name = self._build_docker_image(
            pipeline_name=pipeline_name,
            requirements=requirements,
            entrypoint_command=entrypoint_command,
        )

        if not resource_configuration.empty:
            logger.warning(
                "Specifying custom step resources is not supported for "
                "the SageMaker step operator. If you want to run this step "
                "operator on specific resources, you can do so by configuring "
                "a different instance type like this: "
                "`zenml step-operator update %s "
                "--instance_type=<INSTANCE_TYPE>`",
                self.name,
            )

        sm_session = sagemaker.Session(default_bucket=self.bucket)
        estimator = sagemaker.estimator.Estimator(
            image_name,
            self.role,
            instance_count=1,
            instance_type=self.instance_type,
            sagemaker_session=sm_session,
        )

        # Sagemaker doesn't allow any underscores in job/experiment/trial names
        job_name = run_name.replace("_", "-")
        # NOTE(review): two launches with the same run_name (e.g. several
        # steps of one run) submit the same job name — confirm Sagemaker
        # accepts that, or consider appending the step name.

        experiment_config = (
            {"ExperimentName": self.experiment_name, "TrialName": job_name}
            if self.experiment_name
            else {}
        )

        estimator.fit(
            wait=True,
            experiment_config=experiment_config,
            job_name=job_name,
        )
validator: Optional[zenml.stack.stack_validator.StackValidator] property readonly

Validates that the stack contains a container registry and uses the local orchestrator.

Returns:

Type Description
Optional[zenml.stack.stack_validator.StackValidator]

A validator that checks that the stack contains a container registry.

launch(self, pipeline_name, run_name, requirements, entrypoint_command, resource_configuration)

Launches a step on Sagemaker.

Parameters:

Name Type Description Default
pipeline_name str

Name of the pipeline which the step to be executed is part of.

required
run_name str

Name of the pipeline run which the step to be executed is part of.

required
entrypoint_command List[str]

Command that executes the step.

required
requirements List[str]

List of pip requirements that must be installed inside the step operator environment.

required
resource_configuration ResourceConfiguration

The resource configuration for this step.

required
Source code in zenml/integrations/aws/step_operators/sagemaker_step_operator.py
def launch(
    self,
    pipeline_name: str,
    run_name: str,
    requirements: List[str],
    entrypoint_command: List[str],
    resource_configuration: "ResourceConfiguration",
) -> None:
    """Run a single pipeline step as a SageMaker training job.

    A Docker image containing the step's requirements and entrypoint is
    built first; that image is then submitted to SageMaker through an
    `Estimator` on one instance of the configured instance type, and this
    call blocks until the job finishes.

    Args:
        pipeline_name: Name of the pipeline which the step to be executed
            is part of.
        run_name: Name of the pipeline run which the step to be executed
            is part of.
        entrypoint_command: Command that executes the step.
        requirements: List of pip requirements that must be installed
            inside the step operator environment.
        resource_configuration: The resource configuration for this step.
    """
    docker_image = self._build_docker_image(
        pipeline_name=pipeline_name,
        requirements=requirements,
        entrypoint_command=entrypoint_command,
    )

    # Per-step resources are not honoured here: the hardware is decided by
    # the instance type configured on the step operator itself.
    if not resource_configuration.empty:
        logger.warning(
            "Specifying custom step resources is not supported for "
            "the SageMaker step operator. If you want to run this step "
            "operator on specific resources, you can do so by configuring "
            "a different instance type like this: "
            "`zenml step-operator update %s "
            "--instance_type=<INSTANCE_TYPE>`",
            self.name,
        )

    sagemaker_session = sagemaker.Session(default_bucket=self.bucket)
    training_estimator = sagemaker.estimator.Estimator(
        docker_image,
        self.role,
        instance_count=1,
        instance_type=self.instance_type,
        sagemaker_session=sagemaker_session,
    )

    # SageMaker rejects underscores in job/experiment/trial names.
    job_name = run_name.replace("_", "-")

    # Only attach the job to an experiment when one has been configured.
    experiment_config = (
        {"ExperimentName": self.experiment_name, "TrialName": job_name}
        if self.experiment_name
        else {}
    )

    training_estimator.fit(
        wait=True,
        experiment_config=experiment_config,
        job_name=job_name,
    )

azure special

Initialization of the ZenML Azure integration.

The Azure integration submodule provides a way to run ZenML pipelines in a cloud environment. Specifically, it allows the use of cloud artifact stores together with an io module that handles file operations on Azure Blob Storage, as well as a secrets manager backed by Azure Key Vault. The Azure step operator submodule provides a way to run ZenML steps in AzureML.

AzureIntegration (Integration)

Definition of Azure integration for ZenML.

Source code in zenml/integrations/azure/__init__.py
class AzureIntegration(Integration):
    """ZenML integration for Microsoft Azure.

    Registers the Azure artifact store, the Azure secrets manager and the
    AzureML step operator flavors, together with the pip requirements they
    need.
    """

    NAME = AZURE
    REQUIREMENTS = [
        "adlfs==2021.10.0",
        "azure-keyvault-keys",
        "azure-keyvault-secrets",
        "azure-identity",
        "azureml-core==1.42.0.post1",
    ]

    @classmethod
    def flavors(cls) -> List[FlavorWrapper]:
        """Declares the flavors for the integration.

        Returns:
            List of stack component flavors for this integration.
        """
        # (flavor name, import path of the implementation, component type)
        flavor_specs = [
            (
                AZURE_ARTIFACT_STORE_FLAVOR,
                "zenml.integrations.azure.artifact_stores.AzureArtifactStore",
                StackComponentType.ARTIFACT_STORE,
            ),
            (
                AZURE_SECRETS_MANAGER_FLAVOR,
                "zenml.integrations.azure.secrets_managers.AzureSecretsManager",
                StackComponentType.SECRETS_MANAGER,
            ),
            (
                AZUREML_STEP_OPERATOR_FLAVOR,
                "zenml.integrations.azure.step_operators.AzureMLStepOperator",
                StackComponentType.STEP_OPERATOR,
            ),
        ]
        return [
            FlavorWrapper(
                name=flavor_name,
                source=source_path,
                type=component_type,
                integration=cls.NAME,
            )
            for flavor_name, source_path, component_type in flavor_specs
        ]
flavors() classmethod

Declares the flavors for the integration.

Returns:

Type Description
List[zenml.zen_stores.models.flavor_wrapper.FlavorWrapper]

List of stack component flavors for this integration.

Source code in zenml/integrations/azure/__init__.py
@classmethod
def flavors(cls) -> List[FlavorWrapper]:
    """Declares the flavors for the integration.

    Returns:
        List of stack component flavors for this integration.
    """
    # (flavor name, import path of the implementation, component type)
    flavor_specs = [
        (
            AZURE_ARTIFACT_STORE_FLAVOR,
            "zenml.integrations.azure.artifact_stores.AzureArtifactStore",
            StackComponentType.ARTIFACT_STORE,
        ),
        (
            AZURE_SECRETS_MANAGER_FLAVOR,
            "zenml.integrations.azure.secrets_managers.AzureSecretsManager",
            StackComponentType.SECRETS_MANAGER,
        ),
        (
            AZUREML_STEP_OPERATOR_FLAVOR,
            "zenml.integrations.azure.step_operators.AzureMLStepOperator",
            StackComponentType.STEP_OPERATOR,
        ),
    ]
    return [
        FlavorWrapper(
            name=flavor_name,
            source=source_path,
            type=component_type,
            integration=cls.NAME,
        )
        for flavor_name, source_path, component_type in flavor_specs
    ]

artifact_stores special

Initialization of the Azure Artifact Store integration.

azure_artifact_store

Implementation of the Azure Artifact Store integration.

AzureArtifactStore (BaseArtifactStore, AuthenticationMixin) pydantic-model

Artifact Store for Microsoft Azure based artifacts.

Source code in zenml/integrations/azure/artifact_stores/azure_artifact_store.py
class AzureArtifactStore(BaseArtifactStore, AuthenticationMixin):
    """Artifact Store for Microsoft Azure based artifacts.

    All file operations are delegated to an `adlfs.AzureBlobFileSystem`
    that is created lazily on first use.
    """

    _filesystem: Optional[adlfs.AzureBlobFileSystem] = None

    # Class Configuration
    FLAVOR: ClassVar[str] = AZURE_ARTIFACT_STORE_FLAVOR
    SUPPORTED_SCHEMES: ClassVar[Set[str]] = {"abfs://", "az://"}

    @property
    def filesystem(self) -> adlfs.AzureBlobFileSystem:
        """The adlfs filesystem to access this artifact store.

        The filesystem is created on first access. The keyword arguments come
        from the authentication secret (an `AzureSecretSchema`) when one is
        configured; otherwise no credentials are passed. The result is cached
        on the instance.

        Returns:
            The adlfs filesystem to access this artifact store.
        """
        if self._filesystem is None:
            secret = self.get_authentication_secret(
                expected_schema_type=AzureSecretSchema
            )
            credentials = secret.content if secret else {}

            self._filesystem = adlfs.AzureBlobFileSystem(
                **credentials,
                anon=False,
                use_listings_cache=False,
            )
        return self._filesystem

    @classmethod
    def _split_path(cls, path: PathType) -> Tuple[str, str]:
        """Splits a path into the filesystem prefix and remainder.

        Example:
        ```python
        prefix, remainder = AzureArtifactStore._split_path(
            "az://my_container/test.txt"
        )
        print(prefix, remainder)  # "az://" "my_container/test.txt"
        ```

        Args:
            path: The path to split.

        Returns:
            A tuple of the filesystem prefix and the remainder. The prefix is
            an empty string if the path starts with none of the
            `SUPPORTED_SCHEMES`.
        """
        path = convert_to_str(path)
        prefix = ""
        for potential_prefix in cls.SUPPORTED_SCHEMES:
            if path.startswith(potential_prefix):
                prefix = potential_prefix
                path = path[len(potential_prefix) :]
                break

        return prefix, path

    def open(self, path: PathType, mode: str = "r") -> Any:
        """Open a file at the given path.

        Args:
            path: Path of the file to open.
            mode: Mode in which to open the file; forwarded unchanged to the
                adlfs filesystem. The binary modes 'rb' and 'wb' are the
                ones this store is documented to support; note that the
                default, 'r', is a text mode.

        Returns:
            A file-like object.
        """
        return self.filesystem.open(path=path, mode=mode)

    def copyfile(
        self, src: PathType, dst: PathType, overwrite: bool = False
    ) -> None:
        """Copy a file.

        Args:
            src: The path to copy from.
            dst: The path to copy to.
            overwrite: If a file already exists at the destination, this
                method will overwrite it if overwrite=`True` and
                raise a FileExistsError otherwise.

        Raises:
            FileExistsError: If a file already exists at the destination
                and overwrite is not set to `True`.
        """
        if not overwrite and self.filesystem.exists(dst):
            raise FileExistsError(
                f"Unable to copy to destination '{convert_to_str(dst)}', "
                f"file already exists. Set `overwrite=True` to copy anyway."
            )

        # TODO [ENG-151]: Check if it works with overwrite=True or if we need to
        #  manually remove it first
        self.filesystem.copy(path1=src, path2=dst)

    def exists(self, path: PathType) -> bool:
        """Check whether a path exists.

        Args:
            path: The path to check.

        Returns:
            True if the path exists, False otherwise.
        """
        return self.filesystem.exists(path=path)  # type: ignore[no-any-return]

    def glob(self, pattern: PathType) -> List[PathType]:
        """Return all paths that match the given glob pattern.

        The glob pattern may include:
        - '*' to match any number of characters
        - '?' to match a single character
        - '[...]' to match one of the characters inside the brackets
        - '**' as the full name of a path component to search
            in subdirectories of any depth (e.g. '/some_dir/**/some_file')

        Args:
            pattern: The glob pattern to match, see details above.

        Returns:
            A list of paths that match the given glob pattern. The scheme
            prefix of `pattern` (if any) is prepended to every match.
        """
        prefix, _ = self._split_path(pattern)
        return [
            f"{prefix}{path}" for path in self.filesystem.glob(path=pattern)
        ]

    def isdir(self, path: PathType) -> bool:
        """Check whether a path is a directory.

        Args:
            path: The path to check.

        Returns:
            True if the path is a directory, False otherwise.
        """
        return self.filesystem.isdir(path=path)  # type: ignore[no-any-return]

    def listdir(self, path: PathType) -> List[PathType]:
        """Return a list of files in a directory.

        Args:
            path: The path to list.

        Returns:
            A list of the entry names in the given directory, relative to
            that directory.
        """
        _, path = self._split_path(path)

        def _extract_basename(file_dict: Dict[str, Any]) -> str:
            """Extracts the basename from a dictionary returned by the Azure filesystem.

            Args:
                file_dict: A dictionary returned by the Azure filesystem.

            Returns:
                The basename of the file.
            """
            file_path = cast(str, file_dict["name"])
            # Strip the listed directory (without scheme) and any separator.
            base_name = file_path[len(path) :]
            return base_name.lstrip("/")

        return [
            _extract_basename(dict_)
            for dict_ in self.filesystem.listdir(path=path)
        ]

    def makedirs(self, path: PathType) -> None:
        """Create a directory at the given path.

        If needed also create missing parent directories. An already
        existing directory is not an error (`exist_ok=True`).

        Args:
            path: The path to create.
        """
        self.filesystem.makedirs(path=path, exist_ok=True)

    def mkdir(self, path: PathType) -> None:
        """Create a directory at the given path.

        Args:
            path: The path to create.
        """
        self.filesystem.makedir(path=path)

    def remove(self, path: PathType) -> None:
        """Remove the file at the given path.

        Args:
            path: The path to remove.
        """
        self.filesystem.rm_file(path=path)

    def rename(
        self, src: PathType, dst: PathType, overwrite: bool = False
    ) -> None:
        """Rename source file to destination file.

        Args:
            src: The path of the file to rename.
            dst: The path to rename the source file to.
            overwrite: If a file already exists at the destination, this
                method will overwrite it if overwrite=`True` and
                raise a FileExistsError otherwise.

        Raises:
            FileExistsError: If a file already exists at the destination
                and overwrite is not set to `True`.
        """
        if not overwrite and self.filesystem.exists(dst):
            raise FileExistsError(
                f"Unable to rename file to '{convert_to_str(dst)}', "
                f"file already exists. Set `overwrite=True` to rename anyway."
            )

        # TODO [ENG-152]: Check if it works with overwrite=True or if we need
        #  to manually remove it first
        self.filesystem.rename(path1=src, path2=dst)

    def rmtree(self, path: PathType) -> None:
        """Remove the given directory and everything below it.

        Args:
            path: The path of the directory to remove.
        """
        self.filesystem.delete(path=path, recursive=True)

    def stat(self, path: PathType) -> Dict[str, Any]:
        """Return stat info for the given path.

        Args:
            path: The path to get stat info for.

        Returns:
            Stat info.
        """
        return self.filesystem.stat(path=path)  # type: ignore[no-any-return]

    def walk(
        self,
        top: PathType,
        topdown: bool = True,
        onerror: Optional[Callable[..., None]] = None,
    ) -> Iterable[Tuple[PathType, List[PathType], List[PathType]]]:
        """Return an iterator that walks the contents of the given directory.

        Args:
            top: Path of directory to walk.
            topdown: Unused argument to conform to interface.
            onerror: Unused argument to conform to interface.

        Yields:
            Tuples, each of which contains the path of the current directory
            (with the scheme prefix of `top` restored), a list of the
            directories inside it and a list of the files inside it.
        """
        # TODO [ENG-153]: Additional params
        prefix, _ = self._split_path(top)
        for (
            directory,
            subdirectories,
            files,
        ) in self.filesystem.walk(path=top):
            yield f"{prefix}{directory}", subdirectories, files
filesystem: AzureBlobFileSystem property readonly

The adlfs filesystem to access this artifact store.

Returns:

Type Description
AzureBlobFileSystem

The adlfs filesystem to access this artifact store.

copyfile(self, src, dst, overwrite=False)

Copy a file.

Parameters:

Name Type Description Default
src Union[bytes, str]

The path to copy from.

required
dst Union[bytes, str]

The path to copy to.

required
overwrite bool

If a file already exists at the destination, this method will overwrite it if overwrite=True and raise a FileExistsError otherwise.

False

Exceptions:

Type Description
FileExistsError

If a file already exists at the destination and overwrite is not set to True.

Source code in zenml/integrations/azure/artifact_stores/azure_artifact_store.py
def copyfile(
    self, src: PathType, dst: PathType, overwrite: bool = False
) -> None:
    """Copy the file at `src` to `dst`.

    Args:
        src: The path to copy from.
        dst: The path to copy to.
        overwrite: Whether an existing file at `dst` may be replaced. When
            `False` and `dst` exists, a FileExistsError is raised instead.

    Raises:
        FileExistsError: If a file already exists at the destination
            and overwrite is not set to `True`.
    """
    # The existence check is skipped entirely when overwriting is allowed.
    destination_taken = not overwrite and self.filesystem.exists(dst)
    if destination_taken:
        raise FileExistsError(
            f"Unable to copy to destination '{convert_to_str(dst)}', "
            f"file already exists. Set `overwrite=True` to copy anyway."
        )

    # TODO [ENG-151]: Check if it works with overwrite=True or if we need to
    #  manually remove it first
    azure_fs = self.filesystem
    azure_fs.copy(path1=src, path2=dst)
exists(self, path)

Check whether a path exists.

Parameters:

Name Type Description Default
path Union[bytes, str]

The path to check.

required

Returns:

Type Description
bool

True if the path exists, False otherwise.

Source code in zenml/integrations/azure/artifact_stores/azure_artifact_store.py
def exists(self, path: PathType) -> bool:
    """Report whether anything exists at `path`.

    Args:
        path: The path to check.

    Returns:
        True if the path exists, False otherwise.
    """
    azure_fs = self.filesystem
    return azure_fs.exists(path=path)  # type: ignore[no-any-return]
glob(self, pattern)

Return all paths that match the given glob pattern.

The glob pattern may include: - `*` to match any number of characters - `?` to match a single character - `[...]` to match one of the characters inside the brackets - `**` as the full name of a path component to search in subdirectories of any depth (e.g. `/some_dir/**/some_file`)

Parameters:

Name Type Description Default
pattern Union[bytes, str]

The glob pattern to match, see details above.

required

Returns:

Type Description
List[Union[bytes, str]]

A list of paths that match the given glob pattern.

Source code in zenml/integrations/azure/artifact_stores/azure_artifact_store.py
def glob(self, pattern: PathType) -> List[PathType]:
    """Find every path matching a glob pattern.

    Supported wildcards:
    - '*' matches any number of characters
    - '?' matches a single character
    - '[...]' matches one of the characters inside the brackets
    - '**' as a whole path component searches subdirectories of any
        depth (e.g. '/some_dir/**/some_file')

    Args:
        pattern: The glob pattern to match, see details above.

    Returns:
        The matching paths, each prefixed with the scheme of `pattern`.
    """
    scheme, _ = self._split_path(pattern)
    # The underlying filesystem returns matches without the scheme prefix.
    raw_matches = self.filesystem.glob(path=pattern)
    return [f"{scheme}{match}" for match in raw_matches]
isdir(self, path)

Check whether a path is a directory.

Parameters:

Name Type Description Default
path Union[bytes, str]

The path to check.

required

Returns:

Type Description
bool

True if the path is a directory, False otherwise.

Source code in zenml/integrations/azure/artifact_stores/azure_artifact_store.py
def isdir(self, path: PathType) -> bool:
    """Report whether `path` refers to a directory.

    Args:
        path: The path to check.

    Returns:
        True if the path is a directory, False otherwise.
    """
    azure_fs = self.filesystem
    return azure_fs.isdir(path=path)  # type: ignore[no-any-return]
listdir(self, path)

Return a list of files in a directory.

Parameters:

Name Type Description Default
path Union[bytes, str]

The path to list.

required

Returns:

Type Description
List[Union[bytes, str]]

A list of files in the given directory.

Source code in zenml/integrations/azure/artifact_stores/azure_artifact_store.py
def listdir(self, path: PathType) -> List[PathType]:
    """List the entries of a directory.

    Args:
        path: The path to list.

    Returns:
        The names of the entries in the directory, relative to it.
    """
    _, container_path = self._split_path(path)
    # Entry names come back as full paths (without scheme); cut off the
    # listed directory and any leading separator to get the basename.
    cut = len(container_path)
    return [
        cast(str, entry["name"])[cut:].lstrip("/")
        for entry in self.filesystem.listdir(path=container_path)
    ]
makedirs(self, path)

Create a directory at the given path.

If needed also create missing parent directories.

Parameters:

Name Type Description Default
path Union[bytes, str]

The path to create.

required
Source code in zenml/integrations/azure/artifact_stores/azure_artifact_store.py
def makedirs(self, path: PathType) -> None:
    """Create `path` together with any missing parent directories.

    An already existing directory is not treated as an error.

    Args:
        path: The path to create.
    """
    azure_fs = self.filesystem
    azure_fs.makedirs(path=path, exist_ok=True)
mkdir(self, path)

Create a directory at the given path.

Parameters:

Name Type Description Default
path Union[bytes, str]

The path to create.

required
Source code in zenml/integrations/azure/artifact_stores/azure_artifact_store.py
def mkdir(self, path: PathType) -> None:
    """Create a single directory at `path`.

    Args:
        path: The path to create.
    """
    azure_fs = self.filesystem
    azure_fs.makedir(path=path)
open(self, path, mode='r')

Open a file at the given path.

Parameters:

Name Type Description Default
path Union[bytes, str]

Path of the file to open.

required
mode str

Mode in which to open the file. Currently, only 'rb' and 'wb' to read and write binary files are supported.

'r'

Returns:

Type Description
Any

A file-like object.

Source code in zenml/integrations/azure/artifact_stores/azure_artifact_store.py
def open(self, path: PathType, mode: str = "r") -> Any:
    """Open a file at the given path.

    Args:
        path: Path of the file to open.
        mode: Mode in which to open the file; forwarded unchanged to the
            adlfs filesystem. The binary modes 'rb' and 'wb' are the ones
            this store is documented to support; note that the default,
            'r', is a text mode.

    Returns:
        A file-like object.
    """
    return self.filesystem.open(path=path, mode=mode)
remove(self, path)

Remove the file at the given path.

Parameters:

Name Type Description Default
path Union[bytes, str]

The path to remove.

required
Source code in zenml/integrations/azure/artifact_stores/azure_artifact_store.py
def remove(self, path: PathType) -> None:
    """Delete the single file located at `path`.

    Args:
        path: The path to remove.
    """
    azure_fs = self.filesystem
    azure_fs.rm_file(path=path)
rename(self, src, dst, overwrite=False)

Rename source file to destination file.

Parameters:

Name Type Description Default
src Union[bytes, str]

The path of the file to rename.

required
dst Union[bytes, str]

The path to rename the source file to.

required
overwrite bool

If a file already exists at the destination, this method will overwrite it if overwrite=True and raise a FileExistsError otherwise.

False

Exceptions:

Type Description
FileExistsError

If a file already exists at the destination and overwrite is not set to True.

Source code in zenml/integrations/azure/artifact_stores/azure_artifact_store.py
def rename(
    self, src: PathType, dst: PathType, overwrite: bool = False
) -> None:
    """Move the file at `src` to `dst`.

    Args:
        src: The path of the file to rename.
        dst: The path to rename the source file to.
        overwrite: Whether an existing file at `dst` may be replaced. When
            `False` and `dst` exists, a FileExistsError is raised instead.

    Raises:
        FileExistsError: If a file already exists at the destination
            and overwrite is not set to `True`.
    """
    # The existence check is skipped entirely when overwriting is allowed.
    destination_taken = not overwrite and self.filesystem.exists(dst)
    if destination_taken:
        raise FileExistsError(
            f"Unable to rename file to '{convert_to_str(dst)}', "
            f"file already exists. Set `overwrite=True` to rename anyway."
        )

    # TODO [ENG-152]: Check if it works with overwrite=True or if we need
    #  to manually remove it first
    azure_fs = self.filesystem
    azure_fs.rename(path1=src, path2=dst)
rmtree(self, path)

Remove the given directory.

Parameters:

Name Type Description Default
path Union[bytes, str]

The path of the directory to remove.

required
Source code in zenml/integrations/azure/artifact_stores/azure_artifact_store.py
def rmtree(self, path: PathType) -> None:
    """Delete the directory at `path` together with all of its contents.

    Args:
        path: The path of the directory to remove.
    """
    azure_fs = self.filesystem
    azure_fs.delete(path=path, recursive=True)
stat(self, path)

Return stat info for the given path.

Parameters:

Name Type Description Default
path Union[bytes, str]

The path to get stat info for.

required

Returns:

Type Description
Dict[str, Any]

Stat info.

Source code in zenml/integrations/azure/artifact_stores/azure_artifact_store.py
def stat(self, path: PathType) -> Dict[str, Any]:
    """Fetch the filesystem metadata for `path`.

    Args:
        path: The path to get stat info for.

    Returns:
        Stat info.
    """
    stat_info = self.filesystem.stat(path=path)
    return stat_info  # type: ignore[no-any-return]
walk(self, top, topdown=True, onerror=None)

Return an iterator that walks the contents of the given directory.

Parameters:

Name Type Description Default
top Union[bytes, str]

Path of directory to walk.

required
topdown bool

Unused argument to conform to interface.

True
onerror Optional[Callable[..., NoneType]]

Unused argument to conform to interface.

None

Yields:

Type Description
Iterable[Tuple[Union[bytes, str], List[Union[bytes, str]], List[Union[bytes, str]]]]

An iterable of tuples, each of which contains the path of the current directory, a list of the directories inside it, and a list of the files inside it.

Source code in zenml/integrations/azure/artifact_stores/azure_artifact_store.py
def walk(
    self,
    top: PathType,
    topdown: bool = True,
    onerror: Optional[Callable[..., None]] = None,
) -> Iterable[Tuple[PathType, List[PathType], List[PathType]]]:
    """Walk the directory tree rooted at `top`.

    Args:
        top: Path of directory to walk.
        topdown: Unused argument to conform to interface.
        onerror: Unused argument to conform to interface.

    Yields:
        Tuples of (current directory path with the scheme of `top`
        restored, directories inside it, files inside it).
    """
    # TODO [ENG-153]: Additional params
    scheme, _ = self._split_path(top)
    # The underlying walk reports directory paths without the scheme prefix.
    yield from (
        (f"{scheme}{current_dir}", child_dirs, child_files)
        for current_dir, child_dirs, child_files in self.filesystem.walk(
            path=top
        )
    )

secrets_managers special

Initialization of the Azure Secrets Manager integration.

azure_secrets_manager

Implementation of the Azure Secrets Manager integration.

AzureSecretsManager (BaseSecretsManager) pydantic-model

Class to interact with the Azure secrets manager.

Attributes:

Name Type Description
key_vault_name str

Name of an Azure Key Vault that this secrets manager will use to store secrets.

Source code in zenml/integrations/azure/secrets_managers/azure_secrets_manager.py
class AzureSecretsManager(BaseSecretsManager):
    """Class to interact with the Azure secrets manager.

    Each key-value pair of a ZenML secret is stored as its own Azure Key
    Vault secret. The pairs that belong together are linked through tags:
    `ZENML_GROUP_KEY` holds the ZenML secret name, `ZENML_KEY_NAME` the key
    inside that secret and `ZENML_SCHEMA_NAME` the secret schema type.

    Attributes:
        key_vault_name: Name of an Azure Key Vault that this secrets manager
            will use to store secrets.
    """

    key_vault_name: str

    # Class configuration
    FLAVOR: ClassVar[str] = AZURE_SECRETS_MANAGER_FLAVOR
    # The Key Vault client is cached on the class and shared by all
    # instances. The vault it was created for is remembered so that an
    # instance pointing at a different vault gets a fresh client instead of
    # silently reusing one bound to the wrong vault.
    CLIENT: ClassVar[Any] = None
    CLIENT_VAULT_NAME: ClassVar[Any] = None

    @classmethod
    def _ensure_client_connected(cls, vault_name: str) -> None:
        """Create the shared Key Vault client for `vault_name` if needed.

        Args:
            vault_name: Name of the Azure Key Vault to connect to.
        """
        if cls.CLIENT is None or cls.CLIENT_VAULT_NAME != vault_name:
            KVUri = f"https://{vault_name}.vault.azure.net"

            credential = DefaultAzureCredential()
            cls.CLIENT = SecretClient(vault_url=KVUri, credential=credential)
            cls.CLIENT_VAULT_NAME = vault_name

    def register_secret(self, secret: BaseSecretSchema) -> None:
        """Registers a new secret.

        Args:
            secret: the secret to register

        Raises:
            SecretExistsError: if the secret already exists
        """
        self._ensure_client_connected(self.key_vault_name)

        if secret.name in self.get_all_secret_keys():
            raise SecretExistsError(
                f"A Secret with the name '{secret.name}' already exists."
            )

        self.update_secret(secret)

    def get_secret(self, secret_name: str) -> BaseSecretSchema:
        """Get a secret by its name.

        Args:
            secret_name: the name of the secret to get

        Returns:
            The secret.

        Raises:
            RuntimeError: if no key-value pairs are stored for the secret
            ValueError: if a matching Azure secret has no key tag or its
                key is 'name'
        """
        self._ensure_client_connected(self.key_vault_name)

        secret_contents = {}
        zenml_schema_name = ""

        for secret_property in self.CLIENT.list_properties_of_secrets():
            tags = secret_property.tags

            if tags and tags.get(ZENML_GROUP_KEY) == secret_name:
                secret_key = tags.get(ZENML_KEY_NAME)
                if not secret_key:
                    raise ValueError("Missing secret key tag.")

                # 'name' is reserved for the secret's own name (set below).
                if secret_key == "name":
                    raise ValueError("The secret's key cannot be 'name'.")

                response = self.CLIENT.get_secret(secret_property.name)
                secret_contents[secret_key] = response.value

                zenml_schema_name = tags.get(ZENML_SCHEMA_NAME)

        if not secret_contents:
            raise RuntimeError(f"No secrets found within the {secret_name}")

        secret_contents["name"] = secret_name

        secret_schema = SecretSchemaClassRegistry.get_class(
            secret_schema=zenml_schema_name
        )
        return secret_schema(**secret_contents)

    def get_all_secret_keys(self) -> List[str]:
        """Get all secret keys.

        Returns:
            A list of the names of all ZenML secrets stored in the vault.
        """
        self._ensure_client_connected(self.key_vault_name)

        set_of_secrets = set()

        for secret_property in self.CLIENT.list_properties_of_secrets():
            tags = secret_property.tags
            if tags and ZENML_GROUP_KEY in tags:
                set_of_secrets.add(tags.get(ZENML_GROUP_KEY))

        return list(set_of_secrets)

    def update_secret(self, secret: BaseSecretSchema) -> None:
        """Update an existing secret by creating new versions of the existing secrets.

        Args:
            secret: the secret to update
        """
        self._ensure_client_connected(self.key_vault_name)

        for key, value in secret.content.items():
            # Derive a deterministic Azure secret name from "<name>-<key>";
            # hex-encoding restricts it to the characters [0-9a-f].
            encoded_key = base64.b64encode(
                f"{secret.name}-{key}".encode()
            ).hex()
            azure_secret_name = f"zenml-{encoded_key}"

            self.CLIENT.set_secret(azure_secret_name, value)
            self.CLIENT.update_secret_properties(
                azure_secret_name,
                tags={
                    ZENML_GROUP_KEY: secret.name,
                    ZENML_KEY_NAME: key,
                    ZENML_SCHEMA_NAME: secret.TYPE,
                },
            )

            logger.debug("Wrote secret: %s", azure_secret_name)

    def delete_secret(self, secret_name: str) -> None:
        """Delete an existing secret by name.

        In Azure a secret is a single k-v pair. Within ZenML a secret is a
        collection of k-v pairs. As such, deleting a secret will iterate through
        all secrets and delete the ones with the secret_name as label.

        Args:
            secret_name: the name of the secret to delete
        """
        self._ensure_client_connected(self.key_vault_name)

        # Go through all Azure secrets and delete the ones with the secret_name
        #  as label. The listed properties already carry the tags, so there is
        #  no need to fetch each secret's value (which would also fail for
        #  disabled secrets).
        for secret_property in self.CLIENT.list_properties_of_secrets():
            tags = secret_property.tags
            if tags and tags.get(ZENML_GROUP_KEY) == secret_name:
                self.CLIENT.begin_delete_secret(secret_property.name).result()

    def delete_all_secrets(self) -> None:
        """Delete all secrets in the vault that carry ZenML tags."""
        self._ensure_client_connected(self.key_vault_name)

        # List all secrets; tags are read from the listed properties directly.
        for secret_property in self.CLIENT.list_properties_of_secrets():
            tags = secret_property.tags
            if tags and (ZENML_GROUP_KEY in tags or ZENML_SCHEMA_NAME in tags):
                # Wait for the deletion to complete before reporting it.
                self.CLIENT.begin_delete_secret(secret_property.name).result()
                logger.info(
                    "Deleted key-value pair {`%s`, `***`} from secret " "`%s`",
                    secret_property.name,
                    tags.get(ZENML_GROUP_KEY),
                )
delete_all_secrets(self)

Delete all existing secrets.

Source code in zenml/integrations/azure/secrets_managers/azure_secrets_manager.py
def delete_all_secrets(self) -> None:
    """Delete all existing secrets.

    Only Azure secrets tagged with a ZenML group key or schema name are
    deleted; untagged secrets in the key vault are left untouched.
    """
    self._ensure_client_connected(self.key_vault_name)

    # Tags come with the listed properties; fetching each secret's value
    # is not needed to decide whether it belongs to ZenML.
    for secret_property in self.CLIENT.list_properties_of_secrets():
        tags = secret_property.tags
        if tags and (ZENML_GROUP_KEY in tags or ZENML_SCHEMA_NAME in tags):
            self.CLIENT.begin_delete_secret(secret_property.name).result()
            # Log only once the deletion has actually completed.
            logger.info(
                "Deleted key-value pair {`%s`, `***`} from secret " "`%s`",
                secret_property.name,
                tags.get(ZENML_GROUP_KEY),
            )
delete_secret(self, secret_name)

Delete an existing secret by name.

In Azure a secret is a single k-v pair. Within ZenML a secret is a collection of k-v pairs. As such, deleting a secret will iterate through all secrets and delete the ones with the secret_name as label.

Parameters:

Name Type Description Default
secret_name str

the name of the secret to delete

required
Source code in zenml/integrations/azure/secrets_managers/azure_secrets_manager.py
def delete_secret(self, secret_name: str) -> None:
    """Delete an existing secret by name.

    In Azure a secret is a single k-v pair. Within ZenML a secret is a
    collection of k-v pairs. As such, deleting a secret will iterate through
    all secrets and delete the ones with the secret_name as label.

    Args:
        secret_name: the name of the secret to delete
    """
    self._ensure_client_connected(self.key_vault_name)

    # The listed secret properties already carry their tags, so there is
    # no need to fetch each secret (one extra request per secret, and a
    # needless read of the secret value) just to inspect its labels.
    for secret_property in self.CLIENT.list_properties_of_secrets():
        tags = secret_property.tags
        if tags and tags.get(ZENML_GROUP_KEY) == secret_name:
            self.CLIENT.begin_delete_secret(secret_property.name).result()
get_all_secret_keys(self)

Get all secret keys.

Returns:

Type Description
List[str]

A list of all secret keys

Source code in zenml/integrations/azure/secrets_managers/azure_secrets_manager.py
def get_all_secret_keys(self) -> List[str]:
    """Collect the names of all ZenML secrets stored in the key vault.

    Each Azure secret tagged with a ZenML group key contributes that group
    name; the several k-v pairs belonging to one ZenML secret collapse into
    a single entry.

    Returns:
        A list of all secret keys
    """
    self._ensure_client_connected(self.key_vault_name)

    group_names = {
        prop.tags.get(ZENML_GROUP_KEY)
        for prop in self.CLIENT.list_properties_of_secrets()
        if prop.tags and ZENML_GROUP_KEY in prop.tags
    }
    return list(group_names)
get_secret(self, secret_name)

Get a secret by its name.

Parameters:

Name Type Description Default
secret_name str

the name of the secret to get

required

Returns:

Type Description
BaseSecretSchema

The secret.

Exceptions:

Type Description
RuntimeError

if the secret does not exist

ValueError

if a matching Azure secret is missing its key tag, or one of the secret's keys is named 'name'

Source code in zenml/integrations/azure/secrets_managers/azure_secrets_manager.py
def get_secret(self, secret_name: str) -> BaseSecretSchema:
    """Get a secret by its name.

    Reassembles a ZenML secret from the individual Azure secrets tagged with
    ``secret_name`` as their group key; each of those holds one k-v pair.

    Args:
        secret_name: the name of the secret to get

    Returns:
        The secret.

    Raises:
        RuntimeError: if no Azure secret is tagged with ``secret_name``.
        ValueError: if a matching Azure secret is missing its key tag, or
            one of the secret's keys is ``name`` (that key is reserved for
            the secret's own name, which is added to the contents below).
    """
    self._ensure_client_connected(self.key_vault_name)

    secret_contents = {}
    zenml_schema_name = ""

    for secret_property in self.CLIENT.list_properties_of_secrets():
        tags = secret_property.tags

        if tags and tags.get(ZENML_GROUP_KEY) == secret_name:
            secret_key = tags.get(ZENML_KEY_NAME)
            if not secret_key:
                raise ValueError("Missing secret key tag.")

            # "name" would be overwritten by `secret_contents["name"]` below.
            if secret_key == "name":
                raise ValueError("The secret's key cannot be 'name'.")

            # Listing only returns properties; the value needs its own call.
            response = self.CLIENT.get_secret(secret_property.name)
            secret_contents[secret_key] = response.value

            zenml_schema_name = tags.get(ZENML_SCHEMA_NAME)

    if not secret_contents:
        raise RuntimeError(f"No secrets found within the {secret_name}")

    secret_contents["name"] = secret_name

    secret_schema = SecretSchemaClassRegistry.get_class(
        secret_schema=zenml_schema_name
    )
    return secret_schema(**secret_contents)
register_secret(self, secret)

Registers a new secret.

Parameters:

Name Type Description Default
secret BaseSecretSchema

the secret to register

required

Exceptions:

Type Description
SecretExistsError

if the secret already exists

Source code in zenml/integrations/azure/secrets_managers/azure_secrets_manager.py
def register_secret(self, secret: BaseSecretSchema) -> None:
    """Store a brand-new secret in the Azure key vault.

    Args:
        secret: the secret to register

    Raises:
        SecretExistsError: if the secret already exists
    """
    self._ensure_client_connected(self.key_vault_name)

    existing_names = self.get_all_secret_keys()
    if secret.name not in existing_names:
        # Writing a new secret is the same operation as updating one.
        self.update_secret(secret)
        return

    raise SecretExistsError(
        f"A Secret with the name '{secret.name}' already exists."
    )
update_secret(self, secret)

Update an existing secret by creating new versions of the existing secrets.

Parameters:

Name Type Description Default
secret BaseSecretSchema

the secret to update

required
Source code in zenml/integrations/azure/secrets_managers/azure_secrets_manager.py
def update_secret(self, secret: BaseSecretSchema) -> None:
    """Update an existing secret by creating new versions of the existing secrets.

    Each k-v pair of the ZenML secret is stored as its own Azure secret. Its
    name is ``zenml-`` followed by the hex form of the base64 encoding of
    ``"<secret name>-<key>"``. The hex step leaves only ``[0-9a-f]``
    characters, avoiding the ``+``, ``/`` and ``=`` that raw base64 can
    produce.

    Args:
        secret: the secret to update
    """
    self._ensure_client_connected(self.key_vault_name)

    for key, value in secret.content.items():
        # NOTE(review): "<name>-<key>" is ambiguous when names or keys
        # contain "-" (e.g. name "a-b"/key "c" vs name "a"/key "b-c" map to
        # the same Azure secret). Changing the scheme would orphan existing
        # secrets, so it is left as is.
        encoded_key = base64.b64encode(
            f"{secret.name}-{key}".encode()
        ).hex()
        azure_secret_name = f"zenml-{encoded_key}"

        # Write the value and its ZenML tags in one request. With a separate
        # `update_secret_properties` call, a failure in between left a new
        # untagged version that the tag-based lookup/delete helpers ignore.
        self.CLIENT.set_secret(
            azure_secret_name,
            value,
            tags={
                ZENML_GROUP_KEY: secret.name,
                ZENML_KEY_NAME: key,
                ZENML_SCHEMA_NAME: secret.TYPE,
            },
        )

        logger.debug("Wrote secret: %s", azure_secret_name)

step_operators special

Initialization of AzureML Step Operator integration.

azureml_step_operator

Implementation of the ZenML AzureML Step Operator.

AzureMLStepOperator (BaseStepOperator) pydantic-model

Step operator to run a step on AzureML.

This class defines code that can set up an AzureML environment and run the ZenML entrypoint command in it.

Attributes:

Name Type Description
subscription_id str

The Azure account's subscription ID

resource_group str

The resource group to which the AzureML workspace is deployed.

workspace_name str

The name of the AzureML Workspace.

compute_target_name str

The name of the configured ComputeTarget. An instance of it has to be created on the portal if it doesn't exist already.

environment_name Optional[str]

The name of the environment if there already exists one.

docker_base_image Optional[str]

The custom docker base image that the environment should use.

tenant_id Optional[str]

The Azure Tenant ID.

service_principal_id Optional[str]

The ID for the service principal that is created to allow apps to access secure resources.

service_principal_password Optional[str]

Password for the service principal.

Source code in zenml/integrations/azure/step_operators/azureml_step_operator.py
class AzureMLStepOperator(BaseStepOperator):
    """Step operator to run a step on AzureML.

    This class defines code that can set up an AzureML environment and run the
    ZenML entrypoint command in it.

    Attributes:
        subscription_id: The Azure account's subscription ID
        resource_group: The resource group to which the AzureML workspace
            is deployed.
        workspace_name: The name of the AzureML Workspace.
        compute_target_name: The name of the configured ComputeTarget.
            An instance of it has to be created on the portal if it doesn't
            exist already.
        environment_name: The name of the environment if there
            already exists one.
        docker_base_image: The custom docker base image that the
            environment should use.
        tenant_id: The Azure Tenant ID.
        service_principal_id: The ID for the service principal that is created
            to allow apps to access secure resources.
        service_principal_password: Password for the service principal.
    """

    subscription_id: str
    resource_group: str
    workspace_name: str
    compute_target_name: str

    # Environment
    environment_name: Optional[str] = None
    docker_base_image: Optional[str] = None

    # Service principal authentication
    # https://docs.microsoft.com/en-us/azure/machine-learning/how-to-setup-authentication#configure-a-service-principal
    tenant_id: Optional[str] = None
    service_principal_id: Optional[str] = None
    service_principal_password: Optional[str] = None

    # Class Configuration
    FLAVOR: ClassVar[str] = AZUREML_STEP_OPERATOR_FLAVOR

    def _get_authentication(self) -> Optional[AbstractAuthentication]:
        """Returns the authentication object for the AzureML environment.

        Returns:
            A ``ServicePrincipalAuthentication`` when tenant ID, service
            principal ID and password are all configured, otherwise ``None``.
        """
        if (
            self.tenant_id
            and self.service_principal_id
            and self.service_principal_password
        ):
            return ServicePrincipalAuthentication(
                tenant_id=self.tenant_id,
                service_principal_id=self.service_principal_id,
                service_principal_password=self.service_principal_password,
            )
        return None

    def _prepare_environment(
        self, workspace: Workspace, requirements: List[str], run_name: str
    ) -> Environment:
        """Prepares the environment in which Azure will run all jobs.

        If ``environment_name`` is set, the existing environment is fetched
        and the requirements are added to it as pip packages. Note that
        ``docker_base_image`` is only applied to newly created environments.
        Otherwise a new environment named ``zenml-<run_name>`` is created.

        Args:
            workspace: The AzureML Workspace that has configuration
                for a storage account, container registry among other
                things.
            requirements: The list of requirements to be installed
                in the environment.
            run_name: The name of the pipeline run that can be used
                for naming environments and runs.

        Returns:
            The AzureML Environment object.
        """
        if self.environment_name:
            environment = Environment.get(
                workspace=workspace, name=self.environment_name
            )
            if not environment.python.conda_dependencies:
                environment.python.conda_dependencies = (
                    CondaDependencies.create(
                        python_version=ZenMLEnvironment.python_version()
                    )
                )

            for requirement in requirements:
                environment.python.conda_dependencies.add_pip_package(
                    requirement
                )
        else:
            environment = Environment(name=f"zenml-{run_name}")
            environment.python.conda_dependencies = CondaDependencies.create(
                pip_packages=requirements,
                python_version=ZenMLEnvironment.python_version(),
            )

            if self.docker_base_image:
                # replace the default azure base image
                environment.docker.base_image = self.docker_base_image

        # The legacy string key below is the *name* of ZenML's Python constant
        # rather than the environment variable it names (compare
        # ENV_ZENML_CONFIG_PATH further down, which is used as a constant).
        # Set the variable named by the constant's value as well; the legacy
        # key is kept so nothing that may read it breaks.
        # NOTE(review): presumably this stops pipelines from being executed
        # when the step entrypoint imports user modules — confirm.
        from zenml.constants import ENV_ZENML_PREVENT_PIPELINE_EXECUTION

        environment_variables = {
            "ENV_ZENML_PREVENT_PIPELINE_EXECUTION": "True",
            ENV_ZENML_PREVENT_PIPELINE_EXECUTION: "True",
        }
        # set credentials to access azure storage
        for key in [
            "AZURE_STORAGE_ACCOUNT_KEY",
            "AZURE_STORAGE_ACCOUNT_NAME",
            "AZURE_STORAGE_CONNECTION_STRING",
            "AZURE_STORAGE_SAS_TOKEN",
        ]:
            value = os.getenv(key)
            if value:
                environment_variables[key] = value

        # Point ZenML inside the job at the configuration copied into the
        # build context by `launch`.
        environment_variables[
            ENV_ZENML_CONFIG_PATH
        ] = f"./{CONTAINER_ZENML_CONFIG_DIR}"

        environment.environment_variables = environment_variables
        return environment

    def launch(
        self,
        pipeline_name: str,
        run_name: str,
        requirements: List[str],
        entrypoint_command: List[str],
        resource_configuration: "ResourceConfiguration",
    ) -> None:
        """Launches a step on AzureML.

        Blocks until the submitted AzureML run completes.

        Args:
            pipeline_name: Name of the pipeline which the step to be executed
                is part of.
            run_name: Name of the pipeline run which the step to be executed
                is part of.
            entrypoint_command: Command that executes the step.
            requirements: List of pip requirements that must be installed
                inside the step operator environment.
            resource_configuration: The resource configuration for this step.
        """
        if not resource_configuration.empty:
            logger.warning(
                "Specifying custom step resources is not supported for "
                "the AzureML step operator. If you want to run this step "
                "operator on specific resources, you can do so by creating an "
                "Azure compute target (https://docs.microsoft.com/en-us/azure/machine-learning/concept-compute-target) "
                "with a specific machine type and then updating this step "
                "operator: `zenml step-operator update %s "
                "--compute_target_name=<COMPUTE_TARGET_NAME>`",
                self.name,
            )

        workspace = Workspace.get(
            subscription_id=self.subscription_id,
            resource_group=self.resource_group,
            name=self.workspace_name,
            auth=self._get_authentication(),
        )

        source_directory = get_source_root_path()
        config_path = os.path.join(source_directory, CONTAINER_ZENML_CONFIG_DIR)
        try:

            # Save a copy of the current global configuration with the
            # active profile contents into the build context, to have
            # the configured stacks accessible from within the Azure ML
            # environment.
            load_config_path = PurePosixPath(f"./{CONTAINER_ZENML_CONFIG_DIR}")
            GlobalConfiguration().copy_active_configuration(
                config_path,
                load_config_path=load_config_path,
            )

            environment = self._prepare_environment(
                workspace=workspace,
                requirements=requirements,
                run_name=run_name,
            )
            compute_target = ComputeTarget(
                workspace=workspace, name=self.compute_target_name
            )

            run_config = ScriptRunConfig(
                source_directory=source_directory,
                environment=environment,
                compute_target=compute_target,
                command=entrypoint_command,
            )

            experiment = Experiment(workspace=workspace, name=pipeline_name)
            run = experiment.submit(config=run_config)

        finally:
            # Clean up the temporary build files
            fileio.rmtree(config_path)

        # `run` is always bound here: any failure above propagates after the
        # cleanup in `finally`.
        run.display_name = run_name
        run.wait_for_completion(show_output=True)
launch(self, pipeline_name, run_name, requirements, entrypoint_command, resource_configuration)

Launches a step on AzureML.

Parameters:

Name Type Description Default
pipeline_name str

Name of the pipeline which the step to be executed is part of.

required
run_name str

Name of the pipeline run which the step to be executed is part of.

required
entrypoint_command List[str]

Command that executes the step.

required
requirements List[str]

List of pip requirements that must be installed inside the step operator environment.

required
resource_configuration ResourceConfiguration

The resource configuration for this step.

required
Source code in zenml/integrations/azure/step_operators/azureml_step_operator.py
def launch(
    self,
    pipeline_name: str,
    run_name: str,
    requirements: List[str],
    entrypoint_command: List[str],
    resource_configuration: "ResourceConfiguration",
) -> None:
    """Submit a single ZenML step to AzureML and wait for it to finish.

    The active ZenML configuration is staged into the source root so the
    remote job can load the configured stacks. The staged copy is removed
    again whether or not submission succeeds.

    Args:
        pipeline_name: Name of the pipeline which the step to be executed
            is part of.
        run_name: Name of the pipeline run which the step to be executed
            is part of.
        entrypoint_command: Command that executes the step.
        requirements: List of pip requirements that must be installed
            inside the step operator environment.
        resource_configuration: The resource configuration for this step.
    """
    if not resource_configuration.empty:
        # Per-step resources are not supported; the compute target decides.
        logger.warning(
            "Specifying custom step resources is not supported for "
            "the AzureML step operator. If you want to run this step "
            "operator on specific resources, you can do so by creating an "
            "Azure compute target (https://docs.microsoft.com/en-us/azure/machine-learning/concept-compute-target) "
            "with a specific machine type and then updating this step "
            "operator: `zenml step-operator update %s "
            "--compute_target_name=<COMPUTE_TARGET_NAME>`",
            self.name,
        )

    auth = self._get_authentication()
    workspace = Workspace.get(
        subscription_id=self.subscription_id,
        resource_group=self.resource_group,
        name=self.workspace_name,
        auth=auth,
    )

    source_root = get_source_root_path()
    staged_config_dir = os.path.join(source_root, CONTAINER_ZENML_CONFIG_DIR)
    try:
        # Stage the active configuration inside the build context; the job
        # loads it from a path relative to its working directory.
        GlobalConfiguration().copy_active_configuration(
            staged_config_dir,
            load_config_path=PurePosixPath(f"./{CONTAINER_ZENML_CONFIG_DIR}"),
        )

        script_config = ScriptRunConfig(
            source_directory=source_root,
            environment=self._prepare_environment(
                workspace=workspace,
                requirements=requirements,
                run_name=run_name,
            ),
            compute_target=ComputeTarget(
                workspace=workspace, name=self.compute_target_name
            ),
            command=entrypoint_command,
        )

        submitted_run = Experiment(
            workspace=workspace, name=pipeline_name
        ).submit(config=script_config)
    finally:
        # Never leave the staged configuration behind in the source tree.
        fileio.rmtree(staged_config_dir)

    submitted_run.display_name = run_name
    submitted_run.wait_for_completion(show_output=True)

constants

Constants for ZenML integrations.

dash special

Initialization of the Dash integration.

DashIntegration (Integration)

Definition of Dash integration for ZenML.

Source code in zenml/integrations/dash/__init__.py
class DashIntegration(Integration):
    """Dash integration definition for ZenML (name plus pip requirements)."""

    # Pip packages installed for this integration: Dash itself, the
    # Cytoscape graph component, Bootstrap components and Jupyter support.
    REQUIREMENTS = [
        "dash>=2.0.0",
        "dash-cytoscape>=0.3.0",
        "dash-bootstrap-components>=1.0.1",
        "jupyter-dash>=0.4.2",
    ]
    NAME = DASH

visualizers special

Initialization of the Pipeline Run Visualizer.

pipeline_run_lineage_visualizer

Implementation of the pipeline run lineage visualizer.

PipelineRunLineageVisualizer (BasePipelineRunVisualizer)

Implementation of a lineage diagram via the dash and dash-cytoscape libraries.

Source code in zenml/integrations/dash/visualizers/pipeline_run_lineage_visualizer.py
class PipelineRunLineageVisualizer(BasePipelineRunVisualizer):
    """Implementation of a lineage diagram via the dash and dash-cytoscape libraries."""

    ARTIFACT_PREFIX = "artifact_"
    STEP_PREFIX = "step_"
    # Maps a step's execution status to the CSS class (color) applied to its
    # nodes and edges in the Cytoscape stylesheet.
    STATUS_CLASS_MAPPING = {
        ExecutionStatus.CACHED: "green",
        ExecutionStatus.FAILED: "red",
        ExecutionStatus.RUNNING: "yellow",
        ExecutionStatus.COMPLETED: "blue",
    }

    def visualize(
        self,
        object: PipelineRunView,
        magic: bool = False,
        *args: Any,
        **kwargs: Any,
    ) -> dash.Dash:
        """Method to visualize pipeline runs via the Dash library.

        The layout puts every layer of the dag in a column.

        Args:
            object: The pipeline run to visualize.
            magic: If True and running inside a notebook, the app is
                rendered inline via JupyterDash. Outside a notebook a
                warning is emitted and a regular Dash app is served instead.
            *args: Additional positional arguments.
            **kwargs: Additional keyword arguments.

        Returns:
            The Dash application.
        """
        external_stylesheets = [
            dbc.themes.BOOTSTRAP,
            dbc.icons.BOOTSTRAP,
        ]
        # None -> regular Dash server; "inline" -> render in the notebook.
        mode = None
        # `in_notebook` must be called: a bare reference to the function is
        # always truthy. (Assumes `Environment.in_notebook` is a static method,
        # as used elsewhere in ZenML.)
        if magic and Environment.in_notebook():
            # Only import jupyter_dash in this case
            from jupyter_dash import JupyterDash  # noqa

            JupyterDash.infer_jupyter_proxy_config()

            app = JupyterDash(
                __name__,
                external_stylesheets=external_stylesheets,
            )
            mode = "inline"
        else:
            if magic:
                # Fall back to a regular Dash app instead of leaving `app`
                # unbound (which raised UnboundLocalError further down).
                cli_utils.warning(
                    "Cannot set magic flag in non-notebook environments."
                )
            app = dash.Dash(
                __name__,
                external_stylesheets=external_stylesheets,
            )

        nodes, edges, first_step_id = [], [], None
        for step in object.steps:
            step_output_artifacts = list(step.outputs.values())
            execution_id = (
                step_output_artifacts[0].producer_step.id
                if step_output_artifacts
                else step.id
            )
            step_id = self.STEP_PREFIX + str(step.id)
            # The first step becomes the root of the breadthfirst layout.
            if first_step_id is None:
                first_step_id = step_id
            nodes.append(
                {
                    "data": {
                        "id": step_id,
                        "execution_id": execution_id,
                        "label": f"{execution_id} / {step.entrypoint_name}",
                        "entrypoint_name": step.entrypoint_name,  # redundant for consistency
                        "name": step.name,  # redundant for consistency
                        "type": "step",
                        "parameters": step.parameters,
                        "inputs": {k: v.uri for k, v in step.inputs.items()},
                        "outputs": {k: v.uri for k, v in step.outputs.items()},
                    },
                    "classes": self.STATUS_CLASS_MAPPING[step.status],
                }
            )

            for artifact_name, artifact in step.outputs.items():
                nodes.append(
                    {
                        "data": {
                            "id": self.ARTIFACT_PREFIX + str(artifact.id),
                            "execution_id": artifact.id,
                            "label": f"{artifact.id} / {artifact_name} ("
                            f"{artifact.data_type})",
                            "type": "artifact",
                            "name": artifact_name,
                            "is_cached": artifact.is_cached,
                            "artifact_type": artifact.type,
                            "artifact_data_type": artifact.data_type,
                            "parent_step_id": artifact.parent_step_id,
                            "producer_step_id": artifact.producer_step.id,
                            "uri": artifact.uri,
                        },
                        "classes": f"rectangle "
                        f"{self.STATUS_CLASS_MAPPING[step.status]}",
                    }
                )
                edges.append(
                    {
                        "data": {
                            "source": self.STEP_PREFIX + str(step.id),
                            "target": self.ARTIFACT_PREFIX + str(artifact.id),
                        },
                        "classes": f"edge-arrow "
                        f"{self.STATUS_CLASS_MAPPING[step.status]}"
                        + (" dashed" if artifact.is_cached else " solid"),
                    }
                )

            for artifact_name, artifact in step.inputs.items():
                edges.append(
                    {
                        "data": {
                            "source": self.ARTIFACT_PREFIX + str(artifact.id),
                            "target": self.STEP_PREFIX + str(step.id),
                        },
                        "classes": "edge-arrow "
                        + (
                            f"{self.STATUS_CLASS_MAPPING[ExecutionStatus.CACHED]} dashed"
                            if artifact.is_cached
                            else f"{self.STATUS_CLASS_MAPPING[step.status]} solid"
                        ),
                    }
                )

        app.layout = dbc.Row(
            [
                dbc.Container(f"Run: {object.name}", class_name="h1"),
                dbc.Row(
                    [
                        dbc.Col(
                            [
                                dbc.Row(
                                    [
                                        html.Span(
                                            [
                                                html.Span(
                                                    [
                                                        html.I(
                                                            className="bi bi-circle-fill me-1"
                                                        ),
                                                        "Step",
                                                    ],
                                                    className="me-2",
                                                ),
                                                html.Span(
                                                    [
                                                        html.I(
                                                            className="bi bi-square-fill me-1"
                                                        ),
                                                        "Artifact",
                                                    ],
                                                    className="me-4",
                                                ),
                                                dbc.Badge(
                                                    "Completed",
                                                    color=COLOR_BLUE,
                                                    className="me-1",
                                                ),
                                                dbc.Badge(
                                                    "Cached",
                                                    color=COLOR_GREEN,
                                                    className="me-1",
                                                ),
                                                dbc.Badge(
                                                    "Running",
                                                    color=COLOR_YELLOW,
                                                    className="me-1",
                                                ),
                                                dbc.Badge(
                                                    "Failed",
                                                    color=COLOR_RED,
                                                    className="me-1",
                                                ),
                                            ]
                                        ),
                                    ]
                                ),
                                dbc.Row(
                                    [
                                        cyto.Cytoscape(
                                            id="cytoscape",
                                            layout={
                                                "name": "breadthfirst",
                                                "roots": f'[id = "{first_step_id}"]',
                                            },
                                            elements=edges + nodes,
                                            stylesheet=STYLESHEET,
                                            style={
                                                "width": "100%",
                                                "height": "800px",
                                            },
                                            zoom=1,
                                        )
                                    ]
                                ),
                                dbc.Row(
                                    [
                                        dbc.Button(
                                            "Reset",
                                            id="bt-reset",
                                            color="primary",
                                            className="me-1",
                                        )
                                    ]
                                ),
                            ]
                        ),
                        dbc.Col(
                            [
                                dcc.Markdown(id="markdown-selected-node-data"),
                            ]
                        ),
                    ]
                ),
            ],
            className="p-5",
        )

        @app.callback(  # type: ignore[misc]
            Output("markdown-selected-node-data", "children"),
            Input("cytoscape", "selectedNodeData"),
        )
        def display_data(data_list: List[Dict[str, Any]]) -> str:
            """Callback for the text area below the graph.

            Args:
                data_list: The selected node data.

            Returns:
                str: The selected node data.
            """
            if data_list is None:
                return "Click on a node in the diagram."

            text = ""
            for data in data_list:
                text += f'## {data["execution_id"]} / {data["name"]}' + "\n\n"
                if data["type"] == "artifact":
                    for item in [
                        "artifact_data_type",
                        "is_cached",
                        "producer_step_id",
                        "parent_step_id",
                        "uri",
                    ]:
                        text += f"**{item}**: {data[item]}" + "\n\n"
                elif data["type"] == "step":
                    text += "### Inputs:" + "\n\n"
                    for k, v in data["inputs"].items():
                        text += f"**{k}**: {v}" + "\n\n"
                    text += "### Outputs:" + "\n\n"
                    for k, v in data["outputs"].items():
                        text += f"**{k}**: {v}" + "\n\n"
                    # Blank line keeps the first parameter out of the heading.
                    text += "### Params:" + "\n\n"
                    for k, v in data["parameters"].items():
                        text += f"**{k}**: {v}" + "\n\n"
            return text

        @app.callback(  # type: ignore[misc]
            [Output("cytoscape", "zoom"), Output("cytoscape", "elements")],
            [Input("bt-reset", "n_clicks")],
        )
        def reset_layout(
            n_clicks: int,
        ) -> List[Union[int, List[Dict[str, Collection[str]]]]]:
            """Resets the layout.

            Args:
                n_clicks: The number of clicks on the reset button.

            Returns:
                The zoom and the elements.
            """
            # Message first, values as lazy %-args; the previous call passed
            # the int as the format string, which fails when debug is enabled.
            logger.debug("Reset button clicked (n_clicks=%s).", n_clicks)
            return [1, edges + nodes]

        # Start exactly one server: inline in the notebook, or a regular one.
        if mode is not None:
            app.run_server(mode=mode)
        else:
            app.run_server()
        return app
visualize(self, object, magic=False, *args, **kwargs)

Method to visualize pipeline runs via the Dash library.

The layout puts every layer of the dag in a column.

Parameters:

Name Type Description Default
object PipelineRunView

The pipeline run to visualize.

required
magic bool

If True, the visualization is rendered in a magic mode.

False
*args Any

Additional positional arguments.

()
**kwargs Any

Additional keyword arguments.

{}

Returns:

Type Description
Dash

The Dash application.

Source code in zenml/integrations/dash/visualizers/pipeline_run_lineage_visualizer.py
def visualize(
    self,
    object: PipelineRunView,
    magic: bool = False,
    *args: Any,
    **kwargs: Any,
) -> dash.Dash:
    """Method to visualize pipeline runs via the Dash library.

    The layout puts every layer of the dag in a column.

    Args:
        object: The pipeline run to visualize.
        magic: If True, the visualization is rendered in a magic mode
            (an inline `JupyterDash` app). Outside of a notebook a warning
            is emitted and a regular Dash app is served instead.
        *args: Additional positional arguments.
        **kwargs: Additional keyword arguments.

    Returns:
        The Dash application.
    """
    external_stylesheets = [
        dbc.themes.BOOTSTRAP,
        dbc.icons.BOOTSTRAP,
    ]

    # NOTE(review): `Environment.in_notebook` was tested as a bare attribute;
    # if it is a method, the attribute itself is always truthy. Call it when
    # it is callable, otherwise use its value directly.
    in_notebook = (
        Environment.in_notebook()
        if callable(Environment.in_notebook)
        else Environment.in_notebook
    )

    if magic and in_notebook:
        # Only import jupyter_dash in this case
        from jupyter_dash import JupyterDash  # noqa

        JupyterDash.infer_jupyter_proxy_config()

        app = JupyterDash(
            __name__,
            external_stylesheets=external_stylesheets,
        )
        mode = "inline"
    else:
        if magic:
            # Fall back to a regular Dash app so that `app` and `mode` are
            # always bound, instead of failing later with a NameError.
            cli_utils.warning(
                "Cannot set magic flag in non-notebook environments."
            )
        app = dash.Dash(
            __name__,
            external_stylesheets=external_stylesheets,
        )
        mode = None

    nodes: List[Dict[str, Any]] = []
    edges: List[Dict[str, Any]] = []
    # The first step becomes the root of the breadthfirst layout.
    first_step_id = None
    for step in object.steps:
        step_id = self.STEP_PREFIX + str(step.id)
        if first_step_id is None:
            first_step_id = step_id
        step_status_class = self.STATUS_CLASS_MAPPING[step.status]

        # Label the step with the id of the step that produced its outputs
        # when it has any, otherwise with the step's own id.
        step_output_artifacts = list(step.outputs.values())
        execution_id = (
            step_output_artifacts[0].producer_step.id
            if step_output_artifacts
            else step.id
        )
        nodes.append(
            {
                "data": {
                    "id": step_id,
                    "execution_id": execution_id,
                    "label": f"{execution_id} / {step.entrypoint_name}",
                    "entrypoint_name": step.entrypoint_name,  # redundant for consistency
                    "name": step.name,  # redundant for consistency
                    "type": "step",
                    "parameters": step.parameters,
                    "inputs": {k: v.uri for k, v in step.inputs.items()},
                    "outputs": {k: v.uri for k, v in step.outputs.items()},
                },
                "classes": step_status_class,
            }
        )

        # One node per output artifact plus an edge step -> artifact.
        for artifact_name, artifact in step.outputs.items():
            artifact_id = self.ARTIFACT_PREFIX + str(artifact.id)
            nodes.append(
                {
                    "data": {
                        "id": artifact_id,
                        "execution_id": artifact.id,
                        "label": f"{artifact.id} / {artifact_name} ("
                        f"{artifact.data_type})",
                        "type": "artifact",
                        "name": artifact_name,
                        "is_cached": artifact.is_cached,
                        "artifact_type": artifact.type,
                        "artifact_data_type": artifact.data_type,
                        "parent_step_id": artifact.parent_step_id,
                        "producer_step_id": artifact.producer_step.id,
                        "uri": artifact.uri,
                    },
                    "classes": f"rectangle {step_status_class}",
                }
            )
            edges.append(
                {
                    "data": {
                        "source": step_id,
                        "target": artifact_id,
                    },
                    "classes": f"edge-arrow {step_status_class}"
                    + (" dashed" if artifact.is_cached else " solid"),
                }
            )

        # Input artifacts only get an edge artifact -> step; their nodes are
        # created by the step that produced them.
        for artifact in step.inputs.values():
            edges.append(
                {
                    "data": {
                        "source": self.ARTIFACT_PREFIX + str(artifact.id),
                        "target": step_id,
                    },
                    "classes": "edge-arrow "
                    + (
                        f"{self.STATUS_CLASS_MAPPING[ExecutionStatus.CACHED]} dashed"
                        if artifact.is_cached
                        else f"{step_status_class} solid"
                    ),
                }
            )

    app.layout = dbc.Row(
        [
            dbc.Container(f"Run: {object.name}", class_name="h1"),
            dbc.Row(
                [
                    dbc.Col(
                        [
                            dbc.Row(
                                [
                                    html.Span(
                                        [
                                            html.Span(
                                                [
                                                    html.I(
                                                        className="bi bi-circle-fill me-1"
                                                    ),
                                                    "Step",
                                                ],
                                                className="me-2",
                                            ),
                                            html.Span(
                                                [
                                                    html.I(
                                                        className="bi bi-square-fill me-1"
                                                    ),
                                                    "Artifact",
                                                ],
                                                className="me-4",
                                            ),
                                            dbc.Badge(
                                                "Completed",
                                                color=COLOR_BLUE,
                                                className="me-1",
                                            ),
                                            dbc.Badge(
                                                "Cached",
                                                color=COLOR_GREEN,
                                                className="me-1",
                                            ),
                                            dbc.Badge(
                                                "Running",
                                                color=COLOR_YELLOW,
                                                className="me-1",
                                            ),
                                            dbc.Badge(
                                                "Failed",
                                                color=COLOR_RED,
                                                className="me-1",
                                            ),
                                        ]
                                    ),
                                ]
                            ),
                            dbc.Row(
                                [
                                    cyto.Cytoscape(
                                        id="cytoscape",
                                        layout={
                                            "name": "breadthfirst",
                                            "roots": f'[id = "{first_step_id}"]',
                                        },
                                        elements=edges + nodes,
                                        stylesheet=STYLESHEET,
                                        style={
                                            "width": "100%",
                                            "height": "800px",
                                        },
                                        zoom=1,
                                    )
                                ]
                            ),
                            dbc.Row(
                                [
                                    dbc.Button(
                                        "Reset",
                                        id="bt-reset",
                                        color="primary",
                                        className="me-1",
                                    )
                                ]
                            ),
                        ]
                    ),
                    dbc.Col(
                        [
                            dcc.Markdown(id="markdown-selected-node-data"),
                        ]
                    ),
                ]
            ),
        ],
        className="p-5",
    )

    @app.callback(  # type: ignore[misc]
        Output("markdown-selected-node-data", "children"),
        Input("cytoscape", "selectedNodeData"),
    )
    def display_data(data_list: List[Dict[str, Any]]) -> str:
        """Callback for the text area below the graph.

        Args:
            data_list: The selected node data.

        Returns:
            str: The selected node data.
        """
        # Dash passes None before any selection and [] after deselection.
        if not data_list:
            return "Click on a node in the diagram."

        parts: List[str] = []
        for data in data_list:
            parts.append(f'## {data["execution_id"]} / {data["name"]}' + "\n\n")
            if data["type"] == "artifact":
                for item in [
                    "artifact_data_type",
                    "is_cached",
                    "producer_step_id",
                    "parent_step_id",
                    "uri",
                ]:
                    parts.append(f"**{item}**: {data[item]}" + "\n\n")
            elif data["type"] == "step":
                for header, key in (
                    ("### Inputs:", "inputs"),
                    ("### Outputs:", "outputs"),
                    ("### Params:", "parameters"),
                ):
                    parts.append(header + "\n\n")
                    for k, v in data[key].items():
                        parts.append(f"**{k}**: {v}" + "\n\n")
        return "".join(parts)

    @app.callback(  # type: ignore[misc]
        [Output("cytoscape", "zoom"), Output("cytoscape", "elements")],
        [Input("bt-reset", "n_clicks")],
    )
    def reset_layout(
        n_clicks: int,
    ) -> List[Union[int, List[Dict[str, Collection[str]]]]]:
        """Resets the layout.

        Args:
            n_clicks: The number of clicks on the reset button.

        Returns:
            The zoom and the elements.
        """
        # Lazy %-formatting: the message is the first argument.
        logger.debug("%s clicked in reset button.", n_clicks)
        return [1, edges + nodes]

    # Start the server exactly once; inline mode for notebooks.
    if mode is not None:
        app.run_server(mode=mode)
    else:
        app.run_server()
    return app

deepchecks special

Deepchecks integration for ZenML.

The Deepchecks integration provides a way to validate the data in your pipelines. It can detect data anomalies and lets you define checks that ensure data quality.

The integration includes custom materializers to store Deepchecks SuiteResults, and a visualizer to display the results in a notebook or in your browser.

DeepchecksIntegration (Integration)

Definition of Deepchecks integration for ZenML.

Source code in zenml/integrations/deepchecks/__init__.py
class DeepchecksIntegration(Integration):
    """Definition of [Deepchecks](https://github.com/deepchecks/deepchecks) integration for ZenML."""

    NAME = DEEPCHECKS
    REQUIREMENTS = ["deepchecks[vision]==0.8.0", "torchvision==0.11.2"]

    @staticmethod
    def activate() -> None:
        """Activate the integration by importing its materializers and visualizers modules."""
        from zenml.integrations.deepchecks import materializers, visualizers  # noqa

    @classmethod
    def flavors(cls) -> List[FlavorWrapper]:
        """List the stack component flavors this integration provides.

        Returns:
            A single-element list holding the Deepchecks data validator
            flavor.
        """
        data_validator_flavor = FlavorWrapper(
            name=DEEPCHECKS_DATA_VALIDATOR_FLAVOR,
            source="zenml.integrations.deepchecks.data_validators.DeepchecksDataValidator",
            type=StackComponentType.DATA_VALIDATOR,
            integration=cls.NAME,
        )
        return [data_validator_flavor]
activate() staticmethod

Activate the Deepchecks integration.

Source code in zenml/integrations/deepchecks/__init__.py
@staticmethod
def activate() -> None:
    """Activate the integration by importing its materializers and visualizers modules."""
    from zenml.integrations.deepchecks import materializers, visualizers  # noqa
flavors() classmethod

Declare the stack component flavors for the Deepchecks integration.

Returns:

Type Description
List[zenml.zen_stores.models.flavor_wrapper.FlavorWrapper]

List of stack component flavors for this integration.

Source code in zenml/integrations/deepchecks/__init__.py
@classmethod
def flavors(cls) -> List[FlavorWrapper]:
    """List the stack component flavors this integration provides.

    Returns:
        A single-element list holding the Deepchecks data validator flavor.
    """
    data_validator_flavor = FlavorWrapper(
        name=DEEPCHECKS_DATA_VALIDATOR_FLAVOR,
        source="zenml.integrations.deepchecks.data_validators.DeepchecksDataValidator",
        type=StackComponentType.DATA_VALIDATOR,
        integration=cls.NAME,
    )
    return [data_validator_flavor]

data_validators special

Initialization of the Deepchecks data validator for ZenML.

deepchecks_data_validator

Implementation of the Deepchecks data validator.

DeepchecksDataValidator (BaseDataValidator) pydantic-model

Deepchecks data validator stack component.

Source code in zenml/integrations/deepchecks/data_validators/deepchecks_data_validator.py
class DeepchecksDataValidator(BaseDataValidator):
    """Deepchecks data validator stack component."""

    # Class Configuration
    FLAVOR: ClassVar[str] = DEEPCHECKS_DATA_VALIDATOR_FLAVOR
    NAME: ClassVar[str] = "Deepchecks"

    @staticmethod
    def _split_checks(
        check_list: Sequence[str],
    ) -> Tuple[Sequence[str], Sequence[str]]:
        """Split a list of check identifiers in two lists, one for tabular and one for computer vision checks.

        Args:
            check_list: A list of check identifiers.

        Returns:
            List of tabular check identifiers and list of computer vision
            check identifiers.
        """
        # Materialize once so that a one-shot iterable is not consumed by
        # the first pass.
        checks = list(check_list)
        tabular_checks = [
            check
            for check in checks
            if DeepchecksValidationCheck.is_tabular_check(check)
        ]
        vision_checks = [
            check
            for check in checks
            if DeepchecksValidationCheck.is_vision_check(check)
        ]
        return tabular_checks, vision_checks

    @staticmethod
    def _is_condition_kwarg(arg_name: str, arg_value: Any) -> bool:
        """Tell whether a check keyword argument describes a check condition.

        Condition arguments are named `condition_<name>`, hold a dict of
        keyword arguments and are applied through the check's
        `add_condition_<name>` method rather than its constructor.

        Args:
            arg_name: The keyword argument name.
            arg_value: The keyword argument value.

        Returns:
            True if the argument is a condition argument.
        """
        return arg_name.startswith("condition_") and isinstance(
            arg_value, dict
        )

    # flake8: noqa: C901
    @classmethod
    def _create_and_run_check_suite(
        cls,
        check_enum: Type[DeepchecksValidationCheck],
        reference_dataset: Union[pd.DataFrame, DataLoader[Any]],
        comparison_dataset: Optional[
            Union[pd.DataFrame, DataLoader[Any]]
        ] = None,
        model: Optional[Union[ClassifierMixin, Module]] = None,
        check_list: Optional[Sequence[str]] = None,
        dataset_kwargs: Optional[Dict[str, Any]] = None,
        check_kwargs: Optional[Dict[str, Dict[str, Any]]] = None,
        run_kwargs: Optional[Dict[str, Any]] = None,
    ) -> SuiteResult:
        """Create and run a Deepchecks check suite corresponding to the input parameters.

        This method contains generic logic common to all Deepchecks data
        validator methods that validates the input arguments and uses them to
        generate and run a Deepchecks check suite.

        Args:
            check_enum: ZenML enum type grouping together Deepchecks checks with
                the same characteristics. This is used to generate a default
                list of checks, if a custom list isn't provided via the
                `check_list` argument.
            reference_dataset: Primary (reference) dataset argument used during
                validation.
            comparison_dataset: Optional secondary (comparison) dataset argument
                used during comparison checks.
            model: Optional model argument used during validation.
            check_list: Optional list of ZenML Deepchecks check identifiers
                specifying the list of Deepchecks checks to be performed.
            dataset_kwargs: Additional keyword arguments to be passed to the
                Deepchecks tabular.Dataset or vision.VisionData constructor.
            check_kwargs: Additional keyword arguments to be passed to the
                Deepchecks check object constructors. Arguments are grouped for
                each check and indexed using the full check class name or
                check enum value as dictionary keys. Entries named
                `condition_<name>` whose value is a dict are not passed to the
                constructor; they are passed to the check's
                `add_condition_<name>` method instead.
            run_kwargs: Additional keyword arguments to be passed to the
                Deepchecks Suite `run` method.

        Returns:
            Deepchecks SuiteResult object with the Suite run results.

        Raises:
            TypeError: If the datasets, model and check list arguments combine
                data types and/or checks from different categories (tabular and
                computer vision).
        """
        dataset_kwargs = dataset_kwargs or {}
        check_kwargs = check_kwargs or {}
        run_kwargs = run_kwargs or {}

        # Detect what type of check to perform (tabular or computer vision) from
        # the dataset/model datatypes and the check list. At the same time,
        # validate the combination of data types used for dataset and model
        # arguments and the check list.
        is_tabular = False
        is_vision = False
        for dataset in [reference_dataset, comparison_dataset]:
            if dataset is None:
                continue
            if isinstance(dataset, pd.DataFrame):
                is_tabular = True
            elif isinstance(dataset, DataLoader):
                is_vision = True
            else:
                raise TypeError(
                    f"Unsupported dataset data type found: {type(dataset)}. "
                    f"Supported data types are {str(pd.DataFrame)} for tabular "
                    f"data and {str(DataLoader)} for computer vision data."
                )

        # Compare against None explicitly: some models (e.g. sklearn
        # pipelines, torch Sequential) define __len__ and may be falsy.
        if model is not None:
            if isinstance(model, ClassifierMixin):
                is_tabular = True
            elif isinstance(model, Module):
                is_vision = True
            else:
                raise TypeError(
                    f"Unsupported model data type found: {type(model)}. "
                    f"Supported data types are {str(ClassifierMixin)} for "
                    f"tabular data and {str(Module)} for computer vision "
                    f"data."
                )

        if is_tabular and is_vision:
            raise TypeError(
                f"Tabular and computer vision data types used for datasets and "
                f"models cannot be mixed. They must all belong to the same "
                f"category. Supported data types for tabular data are "
                f"{str(pd.DataFrame)} for datasets and {str(ClassifierMixin)} "
                f"for models. Supported data types for computer vision data "
                f"are {str(DataLoader)} for datasets and {str(Module)} "
                f"for models."
            )

        if not check_list:
            # default to executing all the checks listed in the supplied
            # checks enum type if a custom check list is not supplied
            tabular_checks, vision_checks = cls._split_checks(
                check_enum.values()
            )
            if is_tabular:
                check_list = tabular_checks
                vision_checks = []
            else:
                check_list = vision_checks
                tabular_checks = []
        else:
            tabular_checks, vision_checks = cls._split_checks(check_list)

        if tabular_checks and vision_checks:
            raise TypeError(
                f"The check list cannot mix tabular checks "
                f"({tabular_checks}) and computer vision checks ("
                f"{vision_checks})."
            )

        if is_tabular and vision_checks:
            raise TypeError(
                f"Tabular data types used for datasets and models can only "
                f"be used with tabular validation checks. The following "
                f"computer vision checks included in the check list are "
                f"not valid: {vision_checks}."
            )

        if is_vision and tabular_checks:
            raise TypeError(
                f"Computer vision data types used for datasets and models "
                f"can only be used with computer vision validation checks. "
                f"The following tabular checks included in the check list "
                f"are not valid: {tabular_checks}."
            )

        check_classes = [
            (check, check_enum.get_check_class(check)) for check in check_list
        ]

        # use the pipeline name and the step name to generate a unique suite
        # name
        try:
            # get pipeline name and step name
            step_env = cast(
                StepEnvironment, Environment()[STEP_ENVIRONMENT_NAME]
            )
            suite_name = f"{step_env.pipeline_name}_{step_env.step_name}"
        except KeyError:
            # if not running inside a pipeline step, use random values
            suite_name = f"suite_{random_str(5)}"

        if is_tabular:
            dataset_class = TabularData
            suite_class = TabularSuite
            full_suite = full_tabular_suite()
        else:
            dataset_class = VisionData
            suite_class = VisionSuite
            full_suite = full_vision_suite()

        train_dataset = dataset_class(reference_dataset, **dataset_kwargs)
        test_dataset = None
        if comparison_dataset is not None:
            test_dataset = dataset_class(comparison_dataset, **dataset_kwargs)
        suite = suite_class(name=suite_name)

        # Some Deepchecks checks require a minimum configuration such as
        # conditions to be configured (see https://docs.deepchecks.com/stable/user-guide/general/customizations/examples/plot_configure_check_conditions.html#sphx-glr-user-guide-general-customizations-examples-plot-configure-check-conditions-py)
        # for their execution to have meaning. For checks that don't have
        # custom configuration attributes explicitly specified in the
        # `check_kwargs` input parameter, we use the default check
        # instances extracted from the full suite shipped with Deepchecks.
        default_checks = {
            check.__class__: check for check in full_suite.checks.values()
        }
        for check_name, check_class in check_classes:
            extra_kwargs = check_kwargs.get(check_name, {})
            # Split this check's own kwargs into condition arguments and
            # constructor arguments.
            condition_args = {
                arg_name: arg_value
                for arg_name, arg_value in extra_kwargs.items()
                if cls._is_condition_kwarg(arg_name, arg_value)
            }
            constructor_args = {
                arg_name: arg_value
                for arg_name, arg_value in extra_kwargs.items()
                if arg_name not in condition_args
            }

            default_check = default_checks.get(check_class)
            check: BaseCheck
            if extra_kwargs or default_check is None:
                check = check_class(**constructor_args)
            else:
                check = default_check

            for arg_name, condition_kwargs in condition_args.items():
                condition_method = getattr(check, f"add_{arg_name}", None)
                if not condition_method or not callable(condition_method):
                    logger.warning(
                        f"Deepchecks check type {check.__class__} has no "
                        f"condition named {arg_name}. Ignoring the check "
                        f"argument."
                    )
                    continue
                condition_method(**condition_kwargs)

            suite.add(check)
        return suite.run(
            train_dataset=train_dataset,
            test_dataset=test_dataset,
            model=model,
            **run_kwargs,
        )

    def data_validation(
        self,
        dataset: Union[pd.DataFrame, DataLoader[Any]],
        comparison_dataset: Optional[Any] = None,
        check_list: Optional[Sequence[str]] = None,
        dataset_kwargs: Optional[Dict[str, Any]] = None,
        check_kwargs: Optional[Dict[str, Dict[str, Any]]] = None,
        run_kwargs: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> SuiteResult:
        """Run one or more Deepchecks data validation checks on a dataset.

        Call this method to analyze and identify potential integrity problems
        with a single dataset (e.g. missing values, conflicting labels, mixed
        data types etc.) and dataset comparison checks (e.g. data drift
        checks). Dataset comparison checks require that a second dataset be
        supplied via the `comparison_dataset` argument.

        The `check_list` argument may be used to specify a custom set of
        Deepchecks data integrity checks to perform, identified by
        `DeepchecksDataIntegrityCheck` and `DeepchecksDataDriftCheck` enum
        values. If omitted:

        * if the `comparison_dataset` is omitted, a suite with all available
        data integrity checks will be performed on the input data. See
        `DeepchecksDataIntegrityCheck` for a list of Deepchecks builtin
        checks that are compatible with this method.

        * if the `comparison_dataset` is supplied, a suite with all
        available data drift checks will be performed on the input
        data. See `DeepchecksDataDriftCheck` for a list of Deepchecks
        builtin checks that are compatible with this method.

        Args:
            dataset: Target dataset to be validated.
            comparison_dataset: Optional second dataset to be used for data
                comparison checks (e.g. data drift checks).
            check_list: Optional list of ZenML Deepchecks check identifiers
                specifying the data validation checks to be performed.
                `DeepchecksDataIntegrityCheck` enum values should be used for
                single data validation checks and `DeepchecksDataDriftCheck`
                enum values for data comparison checks. If not supplied, the
                entire set of checks applicable to the input dataset(s)
                will be performed.
            dataset_kwargs: Additional keyword arguments to be passed to the
                Deepchecks `tabular.Dataset` or `vision.VisionData` constructor.
            check_kwargs: Additional keyword arguments to be passed to the
                Deepchecks check object constructors. Arguments are grouped for
                each check and indexed using the full check class name or
                check enum value as dictionary keys.
            run_kwargs: Additional keyword arguments to be passed to the
                Deepchecks Suite `run` method.
            kwargs: Additional keyword arguments (unused).

        Returns:
            A Deepchecks SuiteResult with the results of the validation.
        """
        check_enum: Type[DeepchecksValidationCheck]
        if comparison_dataset is None:
            check_enum = DeepchecksDataIntegrityCheck
        else:
            check_enum = DeepchecksDataDriftCheck

        return self._create_and_run_check_suite(
            check_enum=check_enum,
            reference_dataset=dataset,
            comparison_dataset=comparison_dataset,
            check_list=check_list,
            dataset_kwargs=dataset_kwargs,
            check_kwargs=check_kwargs,
            run_kwargs=run_kwargs,
        )

    def model_validation(
        self,
        dataset: Union[pd.DataFrame, DataLoader[Any]],
        model: Union[ClassifierMixin, Module],
        comparison_dataset: Optional[Any] = None,
        check_list: Optional[Sequence[str]] = None,
        dataset_kwargs: Optional[Dict[str, Any]] = None,
        check_kwargs: Optional[Dict[str, Dict[str, Any]]] = None,
        run_kwargs: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> SuiteResult:
        """Run one or more Deepchecks model validation checks.

        Call this method to perform model validation checks (e.g. confusion
        matrix validation, performance reports, model error analyses, etc).
        A second dataset is required for model performance comparison tests
        (i.e. tests that identify changes in a model behavior by comparing how
        it performs on two different datasets).

        The `check_list` argument may be used to specify a custom set of
        Deepchecks model validation checks to perform, identified by
        `DeepchecksModelValidationCheck` and `DeepchecksModelDriftCheck` enum
        values. If omitted:

        * if the `comparison_dataset` is omitted, a suite with all available
        model validation checks will be performed on the input data. See
        `DeepchecksModelValidationCheck` for a list of Deepchecks builtin
        checks that are compatible with this method.

        * if the `comparison_dataset` is supplied, a suite with all
        available model comparison checks will be performed on the input
        data. See `DeepchecksModelDriftCheck` for a list of Deepchecks
        builtin checks that are compatible with this method.

        Args:
            dataset: Target dataset to be validated.
            model: Target model to be validated.
            comparison_dataset: Optional second dataset to be used for model
                comparison checks.
            check_list: Optional list of ZenML Deepchecks check identifiers
                specifying the model validation checks to be performed.
                `DeepchecksModelValidationCheck` enum values should be used for
                model validation checks and `DeepchecksModelDriftCheck` enum
                values for model comparison checks. If not supplied, the
                entire set of checks applicable to the input dataset(s)
                will be performed.
            dataset_kwargs: Additional keyword arguments to be passed to the
                Deepchecks tabular.Dataset or vision.VisionData constructor.
            check_kwargs: Additional keyword arguments to be passed to the
                Deepchecks check object constructors. Arguments are grouped for
                each check and indexed using the full check class name or
                check enum value as dictionary keys.
            run_kwargs: Additional keyword arguments to be passed to the
                Deepchecks Suite `run` method.
            kwargs: Additional keyword arguments (unused).

        Returns:
            A Deepchecks SuiteResult with the results of the validation.
        """
        check_enum: Type[DeepchecksValidationCheck]
        if comparison_dataset is None:
            check_enum = DeepchecksModelValidationCheck
        else:
            check_enum = DeepchecksModelDriftCheck

        return self._create_and_run_check_suite(
            check_enum=check_enum,
            reference_dataset=dataset,
            comparison_dataset=comparison_dataset,
            model=model,
            check_list=check_list,
            dataset_kwargs=dataset_kwargs,
            check_kwargs=check_kwargs,
            run_kwargs=run_kwargs,
        )
data_validation(self, dataset, comparison_dataset=None, check_list=None, dataset_kwargs={}, check_kwargs={}, run_kwargs={}, **kwargs)

Run one or more Deepchecks data validation checks on a dataset.

Call this method to analyze and identify potential integrity problems with a single dataset (e.g. missing values, conflicting labels, mixed data types etc.) and dataset comparison checks (e.g. data drift checks). Dataset comparison checks require that a second dataset be supplied via the comparison_dataset argument.

The check_list argument may be used to specify a custom set of Deepchecks data integrity checks to perform, identified by DeepchecksDataIntegrityCheck and DeepchecksDataDriftCheck enum values. If omitted:

  • if the comparison_dataset is omitted, a suite with all available data integrity checks will be performed on the input data. See DeepchecksDataIntegrityCheck for a list of Deepchecks builtin checks that are compatible with this method.

  • if the comparison_dataset is supplied, a suite with all available data drift checks will be performed on the input data. See DeepchecksDataDriftCheck for a list of Deepchecks builtin checks that are compatible with this method.

Parameters:

Name Type Description Default
dataset Union[pandas.core.frame.DataFrame, torch.utils.data.dataloader.DataLoader[Any]]

Target dataset to be validated.

required
comparison_dataset Optional[Any]

Optional second dataset to be used for data comparison checks (e.g. data drift checks).

None
check_list Optional[Sequence[str]]

Optional list of ZenML Deepchecks check identifiers specifying the data validation checks to be performed. DeepchecksDataIntegrityCheck enum values should be used for single data validation checks and DeepchecksDataDriftCheck enum values for data comparison checks. If not supplied, the entire set of checks applicable to the input dataset(s) will be performed.

None
dataset_kwargs Dict[str, Any]

Additional keyword arguments to be passed to the Deepchecks tabular.Dataset or vision.VisionData constructor.

{}
check_kwargs Dict[str, Dict[str, Any]]

Additional keyword arguments to be passed to the Deepchecks check object constructors. Arguments are grouped for each check and indexed using the full check class name or check enum value as dictionary keys.

{}
run_kwargs Dict[str, Any]

Additional keyword arguments to be passed to the Deepchecks Suite run method.

{}
kwargs Any

Additional keyword arguments (unused).

{}

Returns:

Type Description
SuiteResult

A Deepchecks SuiteResult with the results of the validation.

Source code in zenml/integrations/deepchecks/data_validators/deepchecks_data_validator.py
def data_validation(
    self,
    dataset: Union[pd.DataFrame, DataLoader[Any]],
    comparison_dataset: Optional[Any] = None,
    check_list: Optional[Sequence[str]] = None,
    dataset_kwargs: Dict[str, Any] = {},
    check_kwargs: Dict[str, Dict[str, Any]] = {},
    run_kwargs: Dict[str, Any] = {},
    **kwargs: Any,
) -> SuiteResult:
    """Validate a dataset, or compare two datasets, with Deepchecks.

    With a single `dataset`, Deepchecks data integrity checks are run to
    surface problems inside that dataset (e.g. missing values, conflicting
    labels, mixed data types). When `comparison_dataset` is also given,
    dataset comparison checks (e.g. data drift checks) are run instead.

    `check_list` restricts the run to specific checks, identified by
    `DeepchecksDataIntegrityCheck` enum values (single dataset) or
    `DeepchecksDataDriftCheck` enum values (two datasets). When it is
    omitted, a suite with every available check of the applicable kind is
    run; see those two enums for the compatible Deepchecks builtin checks.

    Args:
        dataset: Target dataset to be validated.
        comparison_dataset: Optional second dataset; supplying it switches
            the run to data comparison checks (e.g. data drift checks).
        check_list: Optional list of ZenML Deepchecks check identifiers
            selecting the checks to run. If not supplied, the entire set of
            checks applicable to the input dataset(s) is run.
        dataset_kwargs: Extra keyword arguments for the Deepchecks
            `tabular.Dataset` or `vision.VisionData` constructor.
        check_kwargs: Extra keyword arguments for the Deepchecks check
            constructors, grouped per check and keyed by the full check
            class name or the check enum value.
        run_kwargs: Extra keyword arguments for the Deepchecks Suite `run`
            method.
        kwargs: Additional keyword arguments (unused).

    Returns:
        A Deepchecks SuiteResult holding the validation results.
    """
    # The presence of a second dataset decides which family of checks
    # applies: integrity checks for one dataset, drift checks for two.
    suite_checks: Type[DeepchecksValidationCheck] = (
        DeepchecksDataIntegrityCheck
        if comparison_dataset is None
        else DeepchecksDataDriftCheck
    )

    # The `{}` defaults are only forwarded here; this method never mutates
    # them.
    return self._create_and_run_check_suite(
        check_enum=suite_checks,
        reference_dataset=dataset,
        comparison_dataset=comparison_dataset,
        check_list=check_list,
        dataset_kwargs=dataset_kwargs,
        check_kwargs=check_kwargs,
        run_kwargs=run_kwargs,
    )
model_validation(self, dataset, model, comparison_dataset=None, check_list=None, dataset_kwargs={}, check_kwargs={}, run_kwargs={}, **kwargs)

Run one or more Deepchecks model validation checks.

Call this method to perform model validation checks (e.g. confusion matrix validation, performance reports, model error analyses, etc.). A second dataset is required for model performance comparison tests (i.e. tests that identify changes in a model behavior by comparing how it performs on two different datasets).

The check_list argument may be used to specify a custom set of Deepchecks model validation checks to perform, identified by DeepchecksModelValidationCheck and DeepchecksModelDriftCheck enum values. If omitted:

  • if the comparison_dataset is omitted, a suite with all available model validation checks will be performed on the input data. See DeepchecksModelValidationCheck for a list of Deepchecks builtin checks that are compatible with this method.

  • if the comparison_dataset is supplied, a suite with all available model comparison checks will be performed on the input data. See DeepchecksModelDriftCheck for a list of Deepchecks builtin checks that are compatible with this method.

Parameters:

Name Type Description Default
dataset Union[pandas.core.frame.DataFrame, torch.utils.data.dataloader.DataLoader[Any]]

Target dataset to be validated.

required
model Union[sklearn.base.ClassifierMixin, torch.nn.modules.module.Module]

Target model to be validated.

required
comparison_dataset Optional[Any]

Optional second dataset to be used for model comparison checks.

None
check_list Optional[Sequence[str]]

Optional list of ZenML Deepchecks check identifiers specifying the model validation checks to be performed. DeepchecksModelValidationCheck enum values should be used for model validation checks and DeepchecksModelDriftCheck enum values for model comparison checks. If not supplied, the entire set of checks applicable to the input dataset(s) will be performed.

None
dataset_kwargs Dict[str, Any]

Additional keyword arguments to be passed to the Deepchecks tabular.Dataset or vision.VisionData constructor.

{}
check_kwargs Dict[str, Dict[str, Any]]

Additional keyword arguments to be passed to the Deepchecks check object constructors. Arguments are grouped for each check and indexed using the full check class name or check enum value as dictionary keys.

{}
run_kwargs Dict[str, Any]

Additional keyword arguments to be passed to the Deepchecks Suite run method.

{}
kwargs Any

Additional keyword arguments (unused).

{}

Returns:

Type Description
Any

A Deepchecks SuiteResult with the results of the validation.

Source code in zenml/integrations/deepchecks/data_validators/deepchecks_data_validator.py
def model_validation(
    self,
    dataset: Union[pd.DataFrame, DataLoader[Any]],
    model: Union[ClassifierMixin, Module],
    comparison_dataset: Optional[Any] = None,
    check_list: Optional[Sequence[str]] = None,
    dataset_kwargs: Dict[str, Any] = {},
    check_kwargs: Dict[str, Dict[str, Any]] = {},
    run_kwargs: Dict[str, Any] = {},
    **kwargs: Any,
) -> Any:
    """Run one or more Deepchecks model validation checks.

    Call this method to perform model validation checks (e.g. confusion
    matrix validation, performance reports, model error analyses, etc.).
    A second dataset is required for model performance comparison tests
    (i.e. tests that identify changes in a model behavior by comparing how
    it performs on two different datasets).

    The `check_list` argument may be used to specify a custom set of
    Deepchecks model validation checks to perform, identified by
    `DeepchecksModelValidationCheck` and `DeepchecksModelDriftCheck` enum
    values. If omitted:

    * if the `comparison_dataset` is omitted, a suite with all available
    model validation checks will be performed on the input data. See
    `DeepchecksModelValidationCheck` for a list of Deepchecks builtin
    checks that are compatible with this method.

    * if the `comparison_dataset` is supplied, a suite with all
    available model comparison checks will be performed on the input
    data. See `DeepchecksModelDriftCheck` for a list of Deepchecks
    builtin checks that are compatible with this method.

    Args:
        dataset: Target dataset to be validated.
        model: Target model to be validated.
        comparison_dataset: Optional second dataset to be used for model
            comparison checks.
        check_list: Optional list of ZenML Deepchecks check identifiers
            specifying the model validation checks to be performed.
            `DeepchecksModelValidationCheck` enum values should be used for
            model validation checks and `DeepchecksModelDriftCheck` enum
            values for model comparison checks. If not supplied, the
            entire set of checks applicable to the input dataset(s)
            will be performed.
        dataset_kwargs: Additional keyword arguments to be passed to the
            Deepchecks `tabular.Dataset` or `vision.VisionData` constructor.
        check_kwargs: Additional keyword arguments to be passed to the
            Deepchecks check object constructors. Arguments are grouped for
            each check and indexed using the full check class name or
            check enum value as dictionary keys.
        run_kwargs: Additional keyword arguments to be passed to the
            Deepchecks Suite `run` method.
        kwargs: Additional keyword arguments (unused).

    Returns:
        A Deepchecks SuiteResult with the results of the validation.
    """
    # A second dataset switches from single-dataset model validation checks
    # to model comparison (drift) checks.
    check_enum: Type[DeepchecksValidationCheck]
    if comparison_dataset is None:
        check_enum = DeepchecksModelValidationCheck
    else:
        check_enum = DeepchecksModelDriftCheck

    # The `{}` defaults are only forwarded here; this method never mutates
    # them.
    return self._create_and_run_check_suite(
        check_enum=check_enum,
        reference_dataset=dataset,
        comparison_dataset=comparison_dataset,
        model=model,
        check_list=check_list,
        dataset_kwargs=dataset_kwargs,
        check_kwargs=check_kwargs,
        run_kwargs=run_kwargs,
    )

materializers special

Deepchecks materializers.

deepchecks_dataset_materializer

Implementation of Deepchecks dataset materializer.

DeepchecksDatasetMaterializer (BaseMaterializer)

Materializer to read and write Deepchecks Dataset objects.

Source code in zenml/integrations/deepchecks/materializers/deepchecks_dataset_materializer.py
class DeepchecksDatasetMaterializer(BaseMaterializer):
    """Materializer that persists Deepchecks `Dataset` objects.

    Storage is delegated to `PandasMaterializer`: only the wrapped pandas
    DataFrame (`Dataset.data`) is written, and a new `Dataset` is built
    from the loaded DataFrame when reading.
    """

    ASSOCIATED_TYPES = (Dataset,)
    ASSOCIATED_ARTIFACT_TYPES = (DataArtifact,)

    def handle_input(self, data_type: Type[Any]) -> Dataset:
        """Load the stored DataFrame and wrap it in a deepchecks.Dataset.

        Args:
            data_type: The type of the data to read.

        Returns:
            A Deepchecks Dataset built from the stored DataFrame.
        """
        super().handle_input(data_type)

        # NOTE(review): the Dataset is rebuilt from the DataFrame alone, so
        # other Dataset constructor options (e.g. a label column) are
        # presumably not restored — confirm this is acceptable.
        frame = PandasMaterializer(self.artifact).handle_input(data_type)
        return Dataset(frame)

    def handle_return(self, df: Dataset) -> None:
        """Persist the pandas DataFrame wrapped by a deepchecks.Dataset.

        Args:
            df: A deepchecks.Dataset object.
        """
        super().handle_return(df)

        # Only `df.data` is written; the pandas materializer does the I/O.
        PandasMaterializer(self.artifact).handle_return(df.data)
handle_input(self, data_type)

Reads a pandas DataFrame and creates a deepchecks.Dataset from it.

Parameters:

Name Type Description Default
data_type Type[Any]

The type of the data to read.

required

Returns:

Type Description
Dataset

A Deepchecks Dataset.

Source code in zenml/integrations/deepchecks/materializers/deepchecks_dataset_materializer.py
def handle_input(self, data_type: Type[Any]) -> Dataset:
    """Load the stored DataFrame and wrap it in a deepchecks.Dataset.

    Args:
        data_type: The type of the data to read.

    Returns:
        A Deepchecks Dataset built from the stored DataFrame.
    """
    super().handle_input(data_type)

    # Loading is delegated to the pandas materializer. NOTE(review): the
    # Dataset is rebuilt from the DataFrame alone, so other constructor
    # options (e.g. a label column) are presumably not restored.
    frame = PandasMaterializer(self.artifact).handle_input(data_type)
    return Dataset(frame)
handle_return(self, df)

Serializes the pandas DataFrame wrapped by a Dataset object.

Parameters:

Name Type Description Default
df Dataset

A deepchecks.Dataset object.

required
Source code in zenml/integrations/deepchecks/materializers/deepchecks_dataset_materializer.py
def handle_return(self, df: Dataset) -> None:
    """Persist the pandas DataFrame wrapped by a deepchecks.Dataset.

    Args:
        df: A deepchecks.Dataset object.
    """
    super().handle_return(df)

    # Only the underlying DataFrame (`df.data`) is stored; the pandas
    # materializer handles the actual serialization.
    PandasMaterializer(self.artifact).handle_return(df.data)
deepchecks_results_materializer

Implementation of Deepchecks suite results materializer.

DeepchecksResultMaterializer (BaseMaterializer)

Materializer to read and write Deepchecks CheckResult and SuiteResult objects.

Source code in zenml/integrations/deepchecks/materializers/deepchecks_results_materializer.py
class DeepchecksResultMaterializer(BaseMaterializer):
    """Materializer to read and write CheckResult and SuiteResult objects.

    Results are stored as a single JSON file (`RESULTS_FILENAME`) inside
    the artifact URI.
    """

    ASSOCIATED_TYPES = (
        CheckResult,
        SuiteResult,
    )
    ASSOCIATED_ARTIFACT_TYPES = (DataAnalysisArtifact,)

    def handle_input(
        self, data_type: Type[Any]
    ) -> Union[CheckResult, SuiteResult]:
        """Reads a Deepchecks check or suite result from a serialized JSON file.

        Args:
            data_type: The type of the data to read: `SuiteResult`,
                `CheckResult`, or a subclass of either. Subclasses are
                loaded through the base class's `from_json`.

        Returns:
            A Deepchecks CheckResult or SuiteResult.

        Raises:
            RuntimeError: if the input data type is not supported.
        """
        super().handle_input(data_type)

        # Pick the parser before touching storage. `issubclass` (instead of
        # `==`) also accepts subclasses of the associated types. The
        # `isinstance(..., type)` guard keeps non-class inputs on the
        # RuntimeError path instead of letting `issubclass` raise TypeError.
        is_class = isinstance(data_type, type)
        result_cls: Type[Union[CheckResult, SuiteResult]]
        if is_class and issubclass(data_type, SuiteResult):
            result_cls = SuiteResult
        elif is_class and issubclass(data_type, CheckResult):
            result_cls = CheckResult
        else:
            raise RuntimeError(f"Unknown data type: {data_type}")

        filepath = os.path.join(self.artifact.uri, RESULTS_FILENAME)
        json_res = io_utils.read_file_contents_as_string(filepath)
        return result_cls.from_json(json_res)

    def handle_return(self, result: Union[CheckResult, SuiteResult]) -> None:
        """Creates a JSON serialization for a CheckResult or SuiteResult.

        Args:
            result: A Deepchecks CheckResult or SuiteResult.
        """
        super().handle_return(result)

        filepath = os.path.join(self.artifact.uri, RESULTS_FILENAME)

        # NOTE(review): `True` is passed positionally to `to_json`; it is
        # presumably the `with_display` flag — confirm against Deepchecks.
        serialized_json = result.to_json(True)
        io_utils.write_file_contents_as_string(filepath, serialized_json)
handle_input(self, data_type)

Reads a Deepchecks check or suite result from a serialized JSON file.

Parameters:

Name Type Description Default
data_type Type[Any]

The type of the data to read.

required

Returns:

Type Description
Union[deepchecks.core.check_result.CheckResult, deepchecks.core.suite.SuiteResult]

A Deepchecks CheckResult or SuiteResult.

Exceptions:

Type Description
RuntimeError

if the input data type is not supported.

Source code in zenml/integrations/deepchecks/materializers/deepchecks_results_materializer.py
def handle_input(
    self, data_type: Type[Any]
) -> Union[CheckResult, SuiteResult]:
    """Reads a Deepchecks check or suite result from a serialized JSON file.

    Args:
        data_type: The type of the data to read: `SuiteResult`,
            `CheckResult`, or a subclass of either. Subclasses are loaded
            through the base class's `from_json`.

    Returns:
        A Deepchecks CheckResult or SuiteResult.

    Raises:
        RuntimeError: if the input data type is not supported.
    """
    super().handle_input(data_type)

    # Pick the parser before touching storage. `issubclass` (instead of
    # `==`) also accepts subclasses of the associated types. The
    # `isinstance(..., type)` guard keeps non-class inputs on the
    # RuntimeError path instead of letting `issubclass` raise TypeError.
    is_class = isinstance(data_type, type)
    result_cls: Type[Union[CheckResult, SuiteResult]]
    if is_class and issubclass(data_type, SuiteResult):
        result_cls = SuiteResult
    elif is_class and issubclass(data_type, CheckResult):
        result_cls = CheckResult
    else:
        raise RuntimeError(f"Unknown data type: {data_type}")

    filepath = os.path.join(self.artifact.uri, RESULTS_FILENAME)
    json_res = io_utils.read_file_contents_as_string(filepath)
    return result_cls.from_json(json_res)
handle_return(self, result)

Creates a JSON serialization for a CheckResult or SuiteResult.

Parameters:

Name Type Description Default
result Union[deepchecks.core.check_result.CheckResult, deepchecks.core.suite.SuiteResult]

A Deepchecks CheckResult or SuiteResult.

required
Source code in zenml/integrations/deepchecks/materializers/deepchecks_results_materializer.py
def handle_return(self, result: Union[CheckResult, SuiteResult]) -> None:
    """Write a CheckResult or SuiteResult to the artifact as JSON.

    Args:
        result: A Deepchecks CheckResult or SuiteResult.
    """
    super().handle_return(result)

    target = os.path.join(self.artifact.uri, RESULTS_FILENAME)
    # NOTE(review): `True` is passed positionally to `to_json`; it is
    # presumably the `with_display` flag — confirm against Deepchecks.
    io_utils.write_file_contents_as_string(target, result.to_json(True))

steps special

Initialization of the Deepchecks Standard Steps.

deepchecks_data_drift

Implementation of the Deepchecks data drift validation step.

DeepchecksDataDriftCheckStep (BaseStep)

Deepchecks data drift validator step.

Source code in zenml/integrations/deepchecks/steps/deepchecks_data_drift.py
class DeepchecksDataDriftCheckStep(BaseStep):
    """Step that runs Deepchecks data drift checks on two DataFrames."""

    def entrypoint(  # type: ignore[override]
        self,
        reference_dataset: pd.DataFrame,
        target_dataset: pd.DataFrame,
        config: DeepchecksDataDriftCheckStepConfig,
    ) -> SuiteResult:
        """Compare two DataFrames using the active Deepchecks data validator.

        Args:
            reference_dataset: Reference dataset for the data drift check.
            target_dataset: Target dataset to be used for the data drift check.
            config: the configuration for the step

        Returns:
            A Deepchecks suite result with the validation results.
        """
        # `cast` only narrows the static type; it performs no runtime check
        # that the active validator really is a DeepchecksDataValidator.
        validator = cast(
            DeepchecksDataValidator,
            DeepchecksDataValidator.get_active_data_validator(),
        )
        # The configured enum values (presumably a str-based enum) are
        # passed through unchanged; the cast only adjusts the static type.
        selected_checks = cast(Optional[Sequence[str]], config.check_list)

        # The reference data is the primary dataset, and the target data is
        # the comparison dataset that triggers the drift checks.
        return validator.data_validation(
            dataset=reference_dataset,
            comparison_dataset=target_dataset,
            check_list=selected_checks,
            dataset_kwargs=config.dataset_kwargs,
            check_kwargs=config.check_kwargs,
            run_kwargs=config.run_kwargs,
        )
CONFIG_CLASS (BaseStepConfig) pydantic-model

Config class for the Deepchecks data drift validator step.

Attributes:

Name Type Description
check_list Optional[Sequence[zenml.integrations.deepchecks.validation_checks.DeepchecksDataDriftCheck]]

Optional list of DeepchecksDataDriftCheck identifiers specifying the subset of Deepchecks data drift checks to be performed. If not supplied, the entire set of data drift checks will be performed.

dataset_kwargs Dict[str, Any]

Additional keyword arguments to be passed to the Deepchecks tabular.Dataset or vision.VisionData constructor.

check_kwargs Dict[str, Dict[str, Any]]

Additional keyword arguments to be passed to the Deepchecks check object constructors. Arguments are grouped for each check and indexed using the full check class name or check enum value as dictionary keys.

run_kwargs Dict[str, Any]

Additional keyword arguments to be passed to the Deepchecks Suite run method.

Source code in zenml/integrations/deepchecks/steps/deepchecks_data_drift.py
class DeepchecksDataDriftCheckStepConfig(BaseStepConfig):
    """Config class for the Deepchecks data drift validator step.

    Every field has a default, so a default-constructed config runs the
    full set of data drift checks with no extra Deepchecks arguments.

    Attributes:
        check_list: Optional list of DeepchecksDataDriftCheck identifiers
            specifying the subset of Deepchecks data drift checks to be
            performed. If not supplied, the entire set of data drift checks will
            be performed.
        dataset_kwargs: Additional keyword arguments to be passed to the
            Deepchecks `tabular.Dataset` or `vision.VisionData` constructor.
        check_kwargs: Additional keyword arguments to be passed to the
            Deepchecks check object constructors. Arguments are grouped for
            each check and indexed using the full check class name or
            check enum value as dictionary keys.
        run_kwargs: Additional keyword arguments to be passed to the
            Deepchecks Suite `run` method.
    """

    # None (the default) means "run every data drift check".
    check_list: Optional[Sequence[DeepchecksDataDriftCheck]] = None
    # `default_factory=dict` gives each config instance its own empty dict
    # rather than one shared mutable default.
    dataset_kwargs: Dict[str, Any] = Field(default_factory=dict)
    check_kwargs: Dict[str, Dict[str, Any]] = Field(default_factory=dict)
    run_kwargs: Dict[str, Any] = Field(default_factory=dict)
entrypoint(self, reference_dataset, target_dataset, config)

Main entrypoint for the Deepchecks data drift validator step.

Parameters:

Name Type Description Default
reference_dataset DataFrame

Reference dataset for the data drift check.

required
target_dataset DataFrame

Target dataset to be used for the data drift check.

required
config DeepchecksDataDriftCheckStepConfig

the configuration for the step

required

Returns:

Type Description
SuiteResult

A Deepchecks suite result with the validation results.

Source code in zenml/integrations/deepchecks/steps/deepchecks_data_drift.py
def entrypoint(  # type: ignore[override]
    self,
    reference_dataset: pd.DataFrame,
    target_dataset: pd.DataFrame,
    config: DeepchecksDataDriftCheckStepConfig,
) -> SuiteResult:
    """Compare two DataFrames using the active Deepchecks data validator.

    Args:
        reference_dataset: Reference dataset for the data drift check.
        target_dataset: Target dataset to be used for the data drift check.
        config: the configuration for the step

    Returns:
        A Deepchecks suite result with the validation results.
    """
    # `cast` only narrows the static type; no runtime type check happens.
    validator = cast(
        DeepchecksDataValidator,
        DeepchecksDataValidator.get_active_data_validator(),
    )
    # Enum values (presumably str-based) are passed through unchanged.
    selected_checks = cast(Optional[Sequence[str]], config.check_list)

    # Supplying the target data as comparison_dataset selects drift checks.
    return validator.data_validation(
        dataset=reference_dataset,
        comparison_dataset=target_dataset,
        check_list=selected_checks,
        dataset_kwargs=config.dataset_kwargs,
        check_kwargs=config.check_kwargs,
        run_kwargs=config.run_kwargs,
    )
DeepchecksDataDriftCheckStepConfig (BaseStepConfig) pydantic-model

Config class for the Deepchecks data drift validator step.

Attributes:

Name Type Description
check_list Optional[Sequence[zenml.integrations.deepchecks.validation_checks.DeepchecksDataDriftCheck]]

Optional list of DeepchecksDataDriftCheck identifiers specifying the subset of Deepchecks data drift checks to be performed. If not supplied, the entire set of data drift checks will be performed.

dataset_kwargs Dict[str, Any]

Additional keyword arguments to be passed to the Deepchecks tabular.Dataset or vision.VisionData constructor.

check_kwargs Dict[str, Dict[str, Any]]

Additional keyword arguments to be passed to the Deepchecks check object constructors. Arguments are grouped for each check and indexed using the full check class name or check enum value as dictionary keys.

run_kwargs Dict[str, Any]

Additional keyword arguments to be passed to the Deepchecks Suite run method.

Source code in zenml/integrations/deepchecks/steps/deepchecks_data_drift.py
class DeepchecksDataDriftCheckStepConfig(BaseStepConfig):
    """Config class for the Deepchecks data drift validator step.

    Every field has a default, so a default-constructed config runs the
    full set of data drift checks with no extra Deepchecks arguments.

    Attributes:
        check_list: Optional list of DeepchecksDataDriftCheck identifiers
            specifying the subset of Deepchecks data drift checks to be
            performed. If not supplied, the entire set of data drift checks will
            be performed.
        dataset_kwargs: Additional keyword arguments to be passed to the
            Deepchecks `tabular.Dataset` or `vision.VisionData` constructor.
        check_kwargs: Additional keyword arguments to be passed to the
            Deepchecks check object constructors. Arguments are grouped for
            each check and indexed using the full check class name or
            check enum value as dictionary keys.
        run_kwargs: Additional keyword arguments to be passed to the
            Deepchecks Suite `run` method.
    """

    # None (the default) means "run every data drift check".
    check_list: Optional[Sequence[DeepchecksDataDriftCheck]] = None
    # `default_factory=dict` gives each config instance its own empty dict
    # rather than one shared mutable default.
    dataset_kwargs: Dict[str, Any] = Field(default_factory=dict)
    check_kwargs: Dict[str, Dict[str, Any]] = Field(default_factory=dict)
    run_kwargs: Dict[str, Any] = Field(default_factory=dict)
deepchecks_data_drift_check_step(step_name, config)

Shortcut function to create a new instance of the DeepchecksDataDriftCheckStep step.

The returned DeepchecksDataDriftCheckStep can be used in a pipeline to run data drift checks on two input pd.DataFrame objects and return the results as a Deepchecks SuiteResult object.

Parameters:

Name Type Description Default
step_name str

The name of the step

required
config DeepchecksDataDriftCheckStepConfig

The configuration for the step

required

Returns:

Type Description
BaseStep

a DeepchecksDataDriftCheckStep step instance

Source code in zenml/integrations/deepchecks/steps/deepchecks_data_drift.py
def deepchecks_data_drift_check_step(
    step_name: str,
    config: DeepchecksDataDriftCheckStepConfig,
) -> BaseStep:
    """Build a named DeepchecksDataDriftCheckStep instance.

    The resulting step can be added to a pipeline to run data drift checks
    on two input pd.DataFrame objects; it returns the outcome as a
    Deepchecks SuiteResult object.

    Args:
        step_name: The name of the step
        config: The configuration for the step

    Returns:
        a DeepchecksDataDriftCheckStep step instance
    """
    # clone_step derives a step class for `step_name` from the generic step
    # class; that class is then instantiated with the given config.
    step_class = clone_step(DeepchecksDataDriftCheckStep, step_name)
    return step_class(config=config)
deepchecks_data_integrity

Implementation of the Deepchecks data integrity validation step.

DeepchecksDataIntegrityCheckStep (BaseStep)

Deepchecks data integrity validator step.

Source code in zenml/integrations/deepchecks/steps/deepchecks_data_integrity.py
class DeepchecksDataIntegrityCheckStep(BaseStep):
    """Step that runs Deepchecks data integrity checks on one DataFrame."""

    def entrypoint(  # type: ignore[override]
        self,
        dataset: pd.DataFrame,
        config: DeepchecksDataIntegrityCheckStepConfig,
    ) -> SuiteResult:
        """Validate a DataFrame using the active Deepchecks data validator.

        Args:
            dataset: a Pandas DataFrame to validate
            config: the configuration for the step

        Returns:
            A Deepchecks suite result with the validation results.
        """
        # `cast` only narrows the static type; it performs no runtime check
        # that the active validator really is a DeepchecksDataValidator.
        validator = cast(
            DeepchecksDataValidator,
            DeepchecksDataValidator.get_active_data_validator(),
        )
        # The configured enum values (presumably a str-based enum) are
        # passed through unchanged; the cast only adjusts the static type.
        selected_checks = cast(Optional[Sequence[str]], config.check_list)

        # No comparison dataset is passed, so single-dataset (integrity)
        # checks are run.
        return validator.data_validation(
            dataset=dataset,
            check_list=selected_checks,
            dataset_kwargs=config.dataset_kwargs,
            check_kwargs=config.check_kwargs,
            run_kwargs=config.run_kwargs,
        )
CONFIG_CLASS (BaseStepConfig) pydantic-model

Config class for the Deepchecks data integrity validator step.

Attributes:

Name Type Description
check_list Optional[Sequence[zenml.integrations.deepchecks.validation_checks.DeepchecksDataIntegrityCheck]]

Optional list of DeepchecksDataIntegrityCheck identifiers specifying the subset of Deepchecks data integrity checks to be performed. If not supplied, the entire set of data integrity checks will be performed.

dataset_kwargs Dict[str, Any]

Additional keyword arguments to be passed to the Deepchecks tabular.Dataset or vision.VisionData constructor.

check_kwargs Dict[str, Dict[str, Any]]

Additional keyword arguments to be passed to the Deepchecks check object constructors. Arguments are grouped for each check and indexed using the full check class name or check enum value as dictionary keys.

run_kwargs Dict[str, Any]

Additional keyword arguments to be passed to the Deepchecks Suite run method.

Source code in zenml/integrations/deepchecks/steps/deepchecks_data_integrity.py
class DeepchecksDataIntegrityCheckStepConfig(BaseStepConfig):
    """Config class for the Deepchecks data integrity validator step.

    Every field has a default, so a default-constructed config runs the
    full set of data integrity checks with no extra Deepchecks arguments.

    Attributes:
        check_list: Optional list of DeepchecksDataIntegrityCheck identifiers
            specifying the subset of Deepchecks data integrity checks to be
            performed. If not supplied, the entire set of data integrity checks
            will be performed.
        dataset_kwargs: Additional keyword arguments to be passed to the
            Deepchecks `tabular.Dataset` or `vision.VisionData` constructor.
        check_kwargs: Additional keyword arguments to be passed to the
            Deepchecks check object constructors. Arguments are grouped for
            each check and indexed using the full check class name or
            check enum value as dictionary keys.
        run_kwargs: Additional keyword arguments to be passed to the
            Deepchecks Suite `run` method.
    """

    # None (the default) means "run every data integrity check".
    check_list: Optional[Sequence[DeepchecksDataIntegrityCheck]] = None
    # `default_factory=dict` gives each config instance its own empty dict
    # rather than one shared mutable default.
    dataset_kwargs: Dict[str, Any] = Field(default_factory=dict)
    check_kwargs: Dict[str, Dict[str, Any]] = Field(default_factory=dict)
    run_kwargs: Dict[str, Any] = Field(default_factory=dict)
entrypoint(self, dataset, config)

Main entrypoint for the Deepchecks data integrity validator step.

Parameters:

Name Type Description Default
dataset DataFrame

a Pandas DataFrame to validate

required
config DeepchecksDataIntegrityCheckStepConfig

the configuration for the step

required

Returns:

Type Description
SuiteResult

A Deepchecks suite result with the validation results.

Source code in zenml/integrations/deepchecks/steps/deepchecks_data_integrity.py
def entrypoint(  # type: ignore[override]
    self,
    dataset: pd.DataFrame,
    config: DeepchecksDataIntegrityCheckStepConfig,
) -> SuiteResult:
    """Validate a DataFrame using the active Deepchecks data validator.

    Args:
        dataset: a Pandas DataFrame to validate
        config: the configuration for the step

    Returns:
        A Deepchecks suite result with the validation results.
    """
    # `cast` only narrows the static type; no runtime type check happens.
    validator = cast(
        DeepchecksDataValidator,
        DeepchecksDataValidator.get_active_data_validator(),
    )
    # Enum values (presumably str-based) are passed through unchanged.
    selected_checks = cast(Optional[Sequence[str]], config.check_list)

    # Without a comparison dataset, single-dataset integrity checks run.
    return validator.data_validation(
        dataset=dataset,
        check_list=selected_checks,
        dataset_kwargs=config.dataset_kwargs,
        check_kwargs=config.check_kwargs,
        run_kwargs=config.run_kwargs,
    )
DeepchecksDataIntegrityCheckStepConfig (BaseStepConfig) pydantic-model

Config class for the Deepchecks data integrity validator step.

Attributes:

Name Type Description
check_list Optional[Sequence[zenml.integrations.deepchecks.validation_checks.DeepchecksDataIntegrityCheck]]

Optional list of DeepchecksDataIntegrityCheck identifiers specifying the subset of Deepchecks data integrity checks to be performed. If not supplied, the entire set of data integrity checks will be performed.

dataset_kwargs Dict[str, Any]

Additional keyword arguments to be passed to the Deepchecks tabular.Dataset or vision.VisionData constructor.

check_kwargs Dict[str, Dict[str, Any]]

Additional keyword arguments to be passed to the Deepchecks check object constructors. Arguments are grouped for each check and indexed using the full check class name or check enum value as dictionary keys.

run_kwargs Dict[str, Any]

Additional keyword arguments to be passed to the Deepchecks Suite run method.

Source code in zenml/integrations/deepchecks/steps/deepchecks_data_integrity.py
class DeepchecksDataIntegrityCheckStepConfig(BaseStepConfig):
    """Configuration for the Deepchecks data integrity validator step.

    Attributes:
        check_list: Optional list of DeepchecksDataIntegrityCheck identifiers
            selecting which Deepchecks data integrity checks to run. When it
            is left unset, every data integrity check is run.
        dataset_kwargs: Extra keyword arguments forwarded to the Deepchecks
            `tabular.Dataset` or `vision.VisionData` constructor.
        check_kwargs: Extra keyword arguments forwarded to the Deepchecks
            check object constructors, grouped per check and keyed by the
            full check class name or the check enum value.
        run_kwargs: Extra keyword arguments forwarded to the Deepchecks
            Suite `run` method.
    """

    check_list: Optional[Sequence[DeepchecksDataIntegrityCheck]] = Field(
        default=None
    )
    # Mutable defaults use `default_factory` so that every config instance
    # gets its own dictionary.
    dataset_kwargs: Dict[str, Any] = Field(default_factory=dict)
    check_kwargs: Dict[str, Dict[str, Any]] = Field(default_factory=dict)
    run_kwargs: Dict[str, Any] = Field(default_factory=dict)
deepchecks_data_integrity_check_step(step_name, config)

Shortcut function to create a new instance of the DeepchecksDataIntegrityCheckStep step.

The returned DeepchecksDataIntegrityCheckStep can be used in a pipeline to run data integrity checks on an input pd.DataFrame and return the results as a Deepchecks SuiteResult object.

Parameters:

Name Type Description Default
step_name str

The name of the step

required
config DeepchecksDataIntegrityCheckStepConfig

The configuration for the step

required

Returns:

Type Description
BaseStep

a DeepchecksDataIntegrityCheckStep step instance

Source code in zenml/integrations/deepchecks/steps/deepchecks_data_integrity.py
def deepchecks_data_integrity_check_step(
    step_name: str,
    config: DeepchecksDataIntegrityCheckStepConfig,
) -> BaseStep:
    """Shortcut function to create a new instance of the DeepchecksDataIntegrityCheckStep step.

    The returned DeepchecksDataIntegrityCheckStep can be used in a pipeline to
    run data integrity checks on an input pd.DataFrame and return the results
    as a Deepchecks SuiteResult object.

    Args:
        step_name: The name of the step
        config: The configuration for the step

    Returns:
        a DeepchecksDataIntegrityCheckStep step instance
    """
    # Clone the step class under the requested name, then instantiate it.
    named_step_class = clone_step(DeepchecksDataIntegrityCheckStep, step_name)
    return named_step_class(config=config)
deepchecks_model_drift

Implementation of the Deepchecks model drift validation step.

DeepchecksModelDriftCheckStep (BaseStep)

Deepchecks model drift step.

Source code in zenml/integrations/deepchecks/steps/deepchecks_model_drift.py
class DeepchecksModelDriftCheckStep(BaseStep):
    """Step that runs Deepchecks model drift checks.

    The checks run against a reference dataset, a target dataset and a
    model. The work is delegated to the `model_validation` method of the
    active `DeepchecksDataValidator`.
    """

    def entrypoint(  # type: ignore[override]
        self,
        reference_dataset: pd.DataFrame,
        target_dataset: pd.DataFrame,
        model: ClassifierMixin,
        config: DeepchecksModelDriftCheckStepConfig,
    ) -> SuiteResult:
        """Run the model drift checks.

        Args:
            reference_dataset: Reference dataset for the model drift check.
            target_dataset: Target dataset to be used for the model drift check.
            model: a scikit-learn model to validate
            config: the configuration for the step

        Returns:
            A Deepchecks suite result with the validation results.
        """
        validator = cast(
            DeepchecksDataValidator,
            DeepchecksDataValidator.get_active_data_validator(),
        )
        # `cast` is a static-typing no-op; the identifiers are forwarded as is.
        selected_checks = cast(Optional[Sequence[str]], config.check_list)
        # The reference data is the primary dataset and the target data is
        # the comparison dataset.
        suite_result = validator.model_validation(
            dataset=reference_dataset,
            comparison_dataset=target_dataset,
            model=model,
            check_list=selected_checks,
            dataset_kwargs=config.dataset_kwargs,
            check_kwargs=config.check_kwargs,
            run_kwargs=config.run_kwargs,
        )
        return suite_result
CONFIG_CLASS (BaseStepConfig) pydantic-model

Config class for the Deepchecks model drift validator step.

Attributes:

Name Type Description
check_list Optional[Sequence[zenml.integrations.deepchecks.validation_checks.DeepchecksModelDriftCheck]]

Optional list of DeepchecksModelDriftCheck identifiers specifying the subset of Deepchecks model drift checks to be performed. If not supplied, the entire set of model drift checks will be performed.

dataset_kwargs Dict[str, Any]

Additional keyword arguments to be passed to the Deepchecks tabular.Dataset or vision.VisionData constructor.

check_kwargs Dict[str, Dict[str, Any]]

Additional keyword arguments to be passed to the Deepchecks check object constructors. Arguments are grouped for each check and indexed using the full check class name or check enum value as dictionary keys.

run_kwargs Dict[str, Any]

Additional keyword arguments to be passed to the Deepchecks Suite run method.

Source code in zenml/integrations/deepchecks/steps/deepchecks_model_drift.py
class DeepchecksModelDriftCheckStepConfig(BaseStepConfig):
    """Configuration for the Deepchecks model drift validator step.

    Attributes:
        check_list: Optional list of DeepchecksModelDriftCheck identifiers
            selecting which Deepchecks model drift checks to run. When it is
            left unset, every model drift check is run.
        dataset_kwargs: Extra keyword arguments forwarded to the Deepchecks
            `tabular.Dataset` or `vision.VisionData` constructor.
        check_kwargs: Extra keyword arguments forwarded to the Deepchecks
            check object constructors, grouped per check and keyed by the
            full check class name or the check enum value.
        run_kwargs: Extra keyword arguments forwarded to the Deepchecks
            Suite `run` method.
    """

    check_list: Optional[Sequence[DeepchecksModelDriftCheck]] = Field(
        default=None
    )
    # Mutable defaults use `default_factory` so that every config instance
    # gets its own dictionary.
    dataset_kwargs: Dict[str, Any] = Field(default_factory=dict)
    check_kwargs: Dict[str, Dict[str, Any]] = Field(default_factory=dict)
    run_kwargs: Dict[str, Any] = Field(default_factory=dict)
entrypoint(self, reference_dataset, target_dataset, model, config)

Main entrypoint for the Deepchecks model drift step.

Parameters:

Name Type Description Default
reference_dataset DataFrame

Reference dataset for the model drift check.

required
target_dataset DataFrame

Target dataset to be used for the model drift check.

required
model ClassifierMixin

a scikit-learn model to validate

required
config DeepchecksModelDriftCheckStepConfig

the configuration for the step

required

Returns:

Type Description
SuiteResult

A Deepchecks suite result with the validation results.

Source code in zenml/integrations/deepchecks/steps/deepchecks_model_drift.py
def entrypoint(  # type: ignore[override]
    self,
    reference_dataset: pd.DataFrame,
    target_dataset: pd.DataFrame,
    model: ClassifierMixin,
    config: DeepchecksModelDriftCheckStepConfig,
) -> SuiteResult:
    """Run the Deepchecks model drift checks.

    The work is delegated to the `model_validation` method of the validator
    returned by `DeepchecksDataValidator.get_active_data_validator()`.

    Args:
        reference_dataset: Reference dataset for the model drift check.
        target_dataset: Target dataset to be used for the model drift check.
        model: a scikit-learn model to validate
        config: the configuration for the step

    Returns:
        A Deepchecks suite result with the validation results.
    """
    validator = cast(
        DeepchecksDataValidator,
        DeepchecksDataValidator.get_active_data_validator(),
    )
    # `cast` is a static-typing no-op; the identifiers are forwarded as is.
    selected_checks = cast(Optional[Sequence[str]], config.check_list)
    suite_result = validator.model_validation(
        dataset=reference_dataset,
        comparison_dataset=target_dataset,
        model=model,
        check_list=selected_checks,
        dataset_kwargs=config.dataset_kwargs,
        check_kwargs=config.check_kwargs,
        run_kwargs=config.run_kwargs,
    )
    return suite_result
DeepchecksModelDriftCheckStepConfig (BaseStepConfig) pydantic-model

Config class for the Deepchecks model drift validator step.

Attributes:

Name Type Description
check_list Optional[Sequence[zenml.integrations.deepchecks.validation_checks.DeepchecksModelDriftCheck]]

Optional list of DeepchecksModelDriftCheck identifiers specifying the subset of Deepchecks model drift checks to be performed. If not supplied, the entire set of model drift checks will be performed.

dataset_kwargs Dict[str, Any]

Additional keyword arguments to be passed to the Deepchecks tabular.Dataset or vision.VisionData constructor.

check_kwargs Dict[str, Dict[str, Any]]

Additional keyword arguments to be passed to the Deepchecks check object constructors. Arguments are grouped for each check and indexed using the full check class name or check enum value as dictionary keys.

run_kwargs Dict[str, Any]

Additional keyword arguments to be passed to the Deepchecks Suite run method.

Source code in zenml/integrations/deepchecks/steps/deepchecks_model_drift.py
class DeepchecksModelDriftCheckStepConfig(BaseStepConfig):
    """Configuration for the Deepchecks model drift validator step.

    Attributes:
        check_list: Optional list of DeepchecksModelDriftCheck identifiers
            selecting which Deepchecks model drift checks to run. When it is
            left unset, every model drift check is run.
        dataset_kwargs: Extra keyword arguments forwarded to the Deepchecks
            `tabular.Dataset` or `vision.VisionData` constructor.
        check_kwargs: Extra keyword arguments forwarded to the Deepchecks
            check object constructors, grouped per check and keyed by the
            full check class name or the check enum value.
        run_kwargs: Extra keyword arguments forwarded to the Deepchecks
            Suite `run` method.
    """

    check_list: Optional[Sequence[DeepchecksModelDriftCheck]] = Field(
        default=None
    )
    # Mutable defaults use `default_factory` so that every config instance
    # gets its own dictionary.
    dataset_kwargs: Dict[str, Any] = Field(default_factory=dict)
    check_kwargs: Dict[str, Dict[str, Any]] = Field(default_factory=dict)
    run_kwargs: Dict[str, Any] = Field(default_factory=dict)
deepchecks_model_drift_check_step(step_name, config)

Shortcut function to create a new instance of the DeepchecksModelDriftCheckStep step.

The returned DeepchecksModelDriftCheckStep can be used in a pipeline to run model drift checks on two input pd.DataFrame datasets and an input scikit-learn ClassifierMixin model and return the results as a Deepchecks SuiteResult object.

Parameters:

Name Type Description Default
step_name str

The name of the step

required
config DeepchecksModelDriftCheckStepConfig

The configuration for the step

required

Returns:

Type Description
BaseStep

a DeepchecksModelDriftCheckStep step instance

Source code in zenml/integrations/deepchecks/steps/deepchecks_model_drift.py
def deepchecks_model_drift_check_step(
    step_name: str,
    config: DeepchecksModelDriftCheckStepConfig,
) -> BaseStep:
    """Shortcut function to create a new instance of the DeepchecksModelDriftCheckStep step.

    The returned DeepchecksModelDriftCheckStep can be used in a pipeline to
    run model drift checks on two input pd.DataFrame datasets and an input
    scikit-learn ClassifierMixin model and return the results as a Deepchecks
    SuiteResult object.

    Args:
        step_name: The name of the step
        config: The configuration for the step

    Returns:
        a DeepchecksModelDriftCheckStep step instance
    """
    # Clone the step class under the requested name, then instantiate it.
    named_step_class = clone_step(DeepchecksModelDriftCheckStep, step_name)
    return named_step_class(config=config)
deepchecks_model_validation

Implementation of the Deepchecks model validation step.

DeepchecksModelValidationCheckStep (BaseStep)

Deepchecks model validation step.

Source code in zenml/integrations/deepchecks/steps/deepchecks_model_validation.py
class DeepchecksModelValidationCheckStep(BaseStep):
    """Step that runs Deepchecks model validation checks.

    The checks run against a single dataset and a model. The work is
    delegated to the `model_validation` method of the active
    `DeepchecksDataValidator`.
    """

    def entrypoint(  # type: ignore[override]
        self,
        dataset: pd.DataFrame,
        model: ClassifierMixin,
        config: DeepchecksModelValidationCheckStepConfig,
    ) -> SuiteResult:
        """Run the model validation checks.

        Args:
            dataset: a Pandas DataFrame to use for the validation
            model: a scikit-learn model to validate
            config: the configuration for the step

        Returns:
            A Deepchecks suite result with the validation results.
        """
        validator = cast(
            DeepchecksDataValidator,
            DeepchecksDataValidator.get_active_data_validator(),
        )
        # `cast` is a static-typing no-op; the identifiers are forwarded as is.
        selected_checks = cast(Optional[Sequence[str]], config.check_list)
        suite_result = validator.model_validation(
            dataset=dataset,
            model=model,
            check_list=selected_checks,
            dataset_kwargs=config.dataset_kwargs,
            check_kwargs=config.check_kwargs,
            run_kwargs=config.run_kwargs,
        )
        return suite_result
CONFIG_CLASS (BaseStepConfig) pydantic-model

Config class for the Deepchecks model validation validator step.

Attributes:

Name Type Description
check_list Optional[Sequence[zenml.integrations.deepchecks.validation_checks.DeepchecksModelValidationCheck]]

Optional list of DeepchecksModelValidationCheck identifiers specifying the subset of Deepchecks model validation checks to be performed. If not supplied, the entire set of model validation checks will be performed.

dataset_kwargs Dict[str, Any]

Additional keyword arguments to be passed to the Deepchecks tabular.Dataset or vision.VisionData constructor.

check_kwargs Dict[str, Dict[str, Any]]

Additional keyword arguments to be passed to the Deepchecks check object constructors. Arguments are grouped for each check and indexed using the full check class name or check enum value as dictionary keys.

run_kwargs Dict[str, Any]

Additional keyword arguments to be passed to the Deepchecks Suite run method.

Source code in zenml/integrations/deepchecks/steps/deepchecks_model_validation.py
class DeepchecksModelValidationCheckStepConfig(BaseStepConfig):
    """Configuration for the Deepchecks model validation validator step.

    Attributes:
        check_list: Optional list of DeepchecksModelValidationCheck
            identifiers selecting which Deepchecks model validation checks to
            run. When it is left unset, every model validation check is run.
        dataset_kwargs: Extra keyword arguments forwarded to the Deepchecks
            `tabular.Dataset` or `vision.VisionData` constructor.
        check_kwargs: Extra keyword arguments forwarded to the Deepchecks
            check object constructors, grouped per check and keyed by the
            full check class name or the check enum value.
        run_kwargs: Extra keyword arguments forwarded to the Deepchecks
            Suite `run` method.
    """

    check_list: Optional[Sequence[DeepchecksModelValidationCheck]] = Field(
        default=None
    )
    # Mutable defaults use `default_factory` so that every config instance
    # gets its own dictionary.
    dataset_kwargs: Dict[str, Any] = Field(default_factory=dict)
    check_kwargs: Dict[str, Dict[str, Any]] = Field(default_factory=dict)
    run_kwargs: Dict[str, Any] = Field(default_factory=dict)
entrypoint(self, dataset, model, config)

Main entrypoint for the Deepchecks model validation step.

Parameters:

Name Type Description Default
dataset DataFrame

a Pandas DataFrame to use for the validation

required
model ClassifierMixin

a scikit-learn model to validate

required
config DeepchecksModelValidationCheckStepConfig

the configuration for the step

required

Returns:

Type Description
SuiteResult

A Deepchecks suite result with the validation results.

Source code in zenml/integrations/deepchecks/steps/deepchecks_model_validation.py
def entrypoint(  # type: ignore[override]
    self,
    dataset: pd.DataFrame,
    model: ClassifierMixin,
    config: DeepchecksModelValidationCheckStepConfig,
) -> SuiteResult:
    """Run the Deepchecks model validation checks.

    The work is delegated to the `model_validation` method of the validator
    returned by `DeepchecksDataValidator.get_active_data_validator()`.

    Args:
        dataset: a Pandas DataFrame to use for the validation
        model: a scikit-learn model to validate
        config: the configuration for the step

    Returns:
        A Deepchecks suite result with the validation results.
    """
    validator = cast(
        DeepchecksDataValidator,
        DeepchecksDataValidator.get_active_data_validator(),
    )
    # `cast` is a static-typing no-op; the identifiers are forwarded as is.
    selected_checks = cast(Optional[Sequence[str]], config.check_list)
    suite_result = validator.model_validation(
        dataset=dataset,
        model=model,
        check_list=selected_checks,
        dataset_kwargs=config.dataset_kwargs,
        check_kwargs=config.check_kwargs,
        run_kwargs=config.run_kwargs,
    )
    return suite_result
DeepchecksModelValidationCheckStepConfig (BaseStepConfig) pydantic-model

Config class for the Deepchecks model validation validator step.

Attributes:

Name Type Description
check_list Optional[Sequence[zenml.integrations.deepchecks.validation_checks.DeepchecksModelValidationCheck]]

Optional list of DeepchecksModelValidationCheck identifiers specifying the subset of Deepchecks model validation checks to be performed. If not supplied, the entire set of model validation checks will be performed.

dataset_kwargs Dict[str, Any]

Additional keyword arguments to be passed to the Deepchecks tabular.Dataset or vision.VisionData constructor.

check_kwargs Dict[str, Dict[str, Any]]

Additional keyword arguments to be passed to the Deepchecks check object constructors. Arguments are grouped for each check and indexed using the full check class name or check enum value as dictionary keys.

run_kwargs Dict[str, Any]

Additional keyword arguments to be passed to the Deepchecks Suite run method.

Source code in zenml/integrations/deepchecks/steps/deepchecks_model_validation.py
class DeepchecksModelValidationCheckStepConfig(BaseStepConfig):
    """Configuration for the Deepchecks model validation validator step.

    Attributes:
        check_list: Optional list of DeepchecksModelValidationCheck
            identifiers selecting which Deepchecks model validation checks to
            run. When it is left unset, every model validation check is run.
        dataset_kwargs: Extra keyword arguments forwarded to the Deepchecks
            `tabular.Dataset` or `vision.VisionData` constructor.
        check_kwargs: Extra keyword arguments forwarded to the Deepchecks
            check object constructors, grouped per check and keyed by the
            full check class name or the check enum value.
        run_kwargs: Extra keyword arguments forwarded to the Deepchecks
            Suite `run` method.
    """

    check_list: Optional[Sequence[DeepchecksModelValidationCheck]] = Field(
        default=None
    )
    # Mutable defaults use `default_factory` so that every config instance
    # gets its own dictionary.
    dataset_kwargs: Dict[str, Any] = Field(default_factory=dict)
    check_kwargs: Dict[str, Dict[str, Any]] = Field(default_factory=dict)
    run_kwargs: Dict[str, Any] = Field(default_factory=dict)
deepchecks_model_validation_check_step(step_name, config)

Shortcut function to create a new instance of the DeepchecksModelValidationCheckStep step.

The returned DeepchecksModelValidationCheckStep can be used in a pipeline to run model validation checks on an input pd.DataFrame dataset and an input scikit-learn ClassifierMixin model and return the results as a Deepchecks SuiteResult object.

Parameters:

Name Type Description Default
step_name str

The name of the step

required
config DeepchecksModelValidationCheckStepConfig

The configuration for the step

required

Returns:

Type Description
BaseStep

a DeepchecksModelValidationCheckStep step instance

Source code in zenml/integrations/deepchecks/steps/deepchecks_model_validation.py
def deepchecks_model_validation_check_step(
    step_name: str,
    config: DeepchecksModelValidationCheckStepConfig,
) -> BaseStep:
    """Shortcut function to create a new instance of the DeepchecksModelValidationCheckStep step.

    The returned DeepchecksModelValidationCheckStep can be used in a pipeline to
    run model validation checks on an input pd.DataFrame dataset and an input
    scikit-learn ClassifierMixin model and return the results as a Deepchecks
    SuiteResult object.

    Args:
        step_name: The name of the step
        config: The configuration for the step

    Returns:
        a DeepchecksModelValidationCheckStep step instance
    """
    # Clone the step class under the requested name, then instantiate it.
    named_step_class = clone_step(
        DeepchecksModelValidationCheckStep, step_name
    )
    return named_step_class(config=config)

validation_checks

Definition of the Deepchecks validation check types.

DeepchecksDataDriftCheck (DeepchecksValidationCheck)

Categories of Deepchecks data drift checks.

This list reflects the set of train-test validation checks that Deepchecks provides for tabular data and for computer vision.

All these checks inherit from deepchecks.tabular.TrainTestCheck or deepchecks.vision.TrainTestCheck and require two datasets as input.

Source code in zenml/integrations/deepchecks/validation_checks.py
class DeepchecksDataDriftCheck(DeepchecksValidationCheck):
    """Categories of Deepchecks data drift checks.

    This list reflects the set of train-test validation checks provided by
    Deepchecks:

      * [for tabular data](https://docs.deepchecks.com/stable/checks_gallery/tabular.html#train-test-validation)
      * [for computer vision](https://docs.deepchecks.com/stable/checks_gallery/vision.html#train-test-validation)

    All these checks inherit from `deepchecks.tabular.TrainTestCheck` or
    `deepchecks.vision.TrainTestCheck` and require two datasets as input.

    Each member value is produced by `resolve_class` from the referenced
    Deepchecks check class. Per the conventions documented on
    `DeepchecksValidationCheck`, that value is a fully formed class path.
    Member names carry a `TABULAR_` or `VISION_` prefix that matches the
    module alias (`tabular_checks` / `vision_checks`) the class comes from.
    """

    # NOTE: every plain name bound in this enum body becomes a member, so
    # no helper variables may be introduced here.

    # Train-test checks for tabular data.
    TABULAR_CATEGORY_MISMATCH_TRAIN_TEST = resolve_class(
        tabular_checks.CategoryMismatchTrainTest
    )
    TABULAR_DATASET_SIZE_COMPARISON = resolve_class(
        tabular_checks.DatasetsSizeComparison
    )
    TABULAR_DATE_TRAIN_TEST_LEAKAGE_DUPLICATES = resolve_class(
        tabular_checks.DateTrainTestLeakageDuplicates
    )
    TABULAR_DATE_TRAIN_TEST_LEAKAGE_OVERLAP = resolve_class(
        tabular_checks.DateTrainTestLeakageOverlap
    )
    TABULAR_DOMINANT_FREQUENCY_CHANGE = resolve_class(
        tabular_checks.DominantFrequencyChange
    )
    TABULAR_FEATURE_LABEL_CORRELATION_CHANGE = resolve_class(
        tabular_checks.FeatureLabelCorrelationChange
    )
    TABULAR_INDEX_LEAKAGE = resolve_class(tabular_checks.IndexTrainTestLeakage)
    TABULAR_NEW_LABEL_TRAIN_TEST = resolve_class(
        tabular_checks.NewLabelTrainTest
    )
    TABULAR_STRING_MISMATCH_COMPARISON = resolve_class(
        tabular_checks.StringMismatchComparison
    )
    TABULAR_TRAIN_TEST_FEATURE_DRIFT = resolve_class(
        tabular_checks.TrainTestFeatureDrift
    )
    TABULAR_TRAIN_TEST_LABEL_DRIFT = resolve_class(
        tabular_checks.TrainTestLabelDrift
    )
    TABULAR_TRAIN_TEST_SAMPLES_MIX = resolve_class(
        tabular_checks.TrainTestSamplesMix
    )
    TABULAR_WHOLE_DATASET_DRIFT = resolve_class(
        tabular_checks.WholeDatasetDrift
    )

    # Train-test checks for computer vision data.
    VISION_FEATURE_LABEL_CORRELATION_CHANGE = resolve_class(
        vision_checks.FeatureLabelCorrelationChange
    )
    VISION_HEATMAP_COMPARISON = resolve_class(vision_checks.HeatmapComparison)
    VISION_IMAGE_DATASET_DRIFT = resolve_class(vision_checks.ImageDatasetDrift)
    VISION_IMAGE_PROPERTY_DRIFT = resolve_class(
        vision_checks.ImagePropertyDrift
    )
    VISION_NEW_LABELS = resolve_class(vision_checks.NewLabels)
    VISION_SIMILAR_IMAGE_LEAKAGE = resolve_class(
        vision_checks.SimilarImageLeakage
    )
    VISION_TRAIN_TEST_LABEL_DRIFT = resolve_class(
        vision_checks.TrainTestLabelDrift
    )
DeepchecksDataIntegrityCheck (DeepchecksValidationCheck)

Categories of Deepchecks data integrity checks.

This list reflects the set of data integrity checks that Deepchecks provides for tabular data and for computer vision.

All these checks inherit from deepchecks.tabular.SingleDatasetCheck or deepchecks.vision.SingleDatasetCheck and require a single dataset as input.

Source code in zenml/integrations/deepchecks/validation_checks.py
class DeepchecksDataIntegrityCheck(DeepchecksValidationCheck):
    """Categories of Deepchecks data integrity checks.

    This list reflects the set of data integrity checks provided by Deepchecks:

      * [for tabular data](https://docs.deepchecks.com/en/stable/checks_gallery/tabular.html#data-integrity)
      * [for computer vision](https://docs.deepchecks.com/en/stable/checks_gallery/vision.html#data-integrity)

    All these checks inherit from `deepchecks.tabular.SingleDatasetCheck` or
    `deepchecks.vision.SingleDatasetCheck` and require a single dataset as input.

    Each member value is produced by `resolve_class` from the referenced
    Deepchecks check class. Per the conventions documented on
    `DeepchecksValidationCheck`, that value is a fully formed class path.
    """

    # NOTE: every plain name bound in this enum body becomes a member, so
    # no helper variables may be introduced here.

    # Single-dataset integrity checks for tabular data.
    TABULAR_COLUMNS_INFO = resolve_class(tabular_checks.ColumnsInfo)
    TABULAR_CONFLICTING_LABELS = resolve_class(tabular_checks.ConflictingLabels)
    TABULAR_DATA_DUPLICATES = resolve_class(tabular_checks.DataDuplicates)
    # Unlike its neighbours, this class is referenced by a directly imported
    # name rather than through `tabular_checks`. NOTE(review): presumably it
    # is not exported from `tabular_checks` in the pinned Deepchecks
    # version; confirm.
    TABULAR_FEATURE_FEATURE_CORRELATION = resolve_class(
        FeatureFeatureCorrelation
    )
    TABULAR_FEATURE_LABEL_CORRELATION = resolve_class(
        tabular_checks.FeatureLabelCorrelation
    )
    TABULAR_IDENTIFIER_LEAKAGE = resolve_class(tabular_checks.IdentifierLeakage)
    TABULAR_IS_SINGLE_VALUE = resolve_class(tabular_checks.IsSingleValue)
    TABULAR_MIXED_DATA_TYPES = resolve_class(tabular_checks.MixedDataTypes)
    TABULAR_MIXED_NULLS = resolve_class(tabular_checks.MixedNulls)
    TABULAR_OUTLIER_SAMPLE_DETECTION = resolve_class(
        tabular_checks.OutlierSampleDetection
    )
    TABULAR_SPECIAL_CHARS = resolve_class(tabular_checks.SpecialCharacters)
    TABULAR_STRING_LENGTH_OUT_OF_BOUNDS = resolve_class(
        tabular_checks.StringLengthOutOfBounds
    )
    TABULAR_STRING_MISMATCH = resolve_class(tabular_checks.StringMismatch)

    # Single-dataset integrity checks for computer vision data.
    VISION_IMAGE_PROPERTY_OUTLIERS = resolve_class(
        vision_checks.ImagePropertyOutliers
    )
    VISION_LABEL_PROPERTY_OUTLIERS = resolve_class(
        vision_checks.LabelPropertyOutliers
    )
DeepchecksModelDriftCheck (DeepchecksValidationCheck)

Categories of Deepchecks model drift checks.

This list includes the subset of Deepchecks model evaluation checks (for tabular data and for computer vision) that require two datasets and a mandatory model as input.

All these checks inherit from deepchecks.tabular.TrainTestCheck or deepchecks.vision.TrainTestCheck and require two datasets and a mandatory model as input.

Source code in zenml/integrations/deepchecks/validation_checks.py
class DeepchecksModelDriftCheck(DeepchecksValidationCheck):
    """Categories of Deepchecks model drift checks.

    This list includes a subset of the model evaluation checks provided by
    Deepchecks that require two datasets and a mandatory model as input:

      * [for tabular data](https://docs.deepchecks.com/en/stable/checks_gallery/tabular.html#model-evaluation)
      * [for computer vision](https://docs.deepchecks.com/stable/checks_gallery/vision.html#model-evaluation)

    All these checks inherit from `deepchecks.tabular.TrainTestCheck` or
    `deepchecks.vision.TrainTestCheck` and require two datasets and a mandatory
    model as input.

    Each member value is produced by `resolve_class` from the referenced
    Deepchecks check class. Per the conventions documented on
    `DeepchecksValidationCheck`, that value is a fully formed class path.
    """

    # NOTE: every plain name bound in this enum body becomes a member, so
    # no helper variables may be introduced here.

    # Two-dataset model checks for tabular data.
    TABULAR_BOOSTING_OVERFIT = resolve_class(tabular_checks.BoostingOverfit)
    TABULAR_MODEL_ERROR_ANALYSIS = resolve_class(
        tabular_checks.ModelErrorAnalysis
    )
    TABULAR_PERFORMANCE_REPORT = resolve_class(tabular_checks.PerformanceReport)
    TABULAR_SIMPLE_MODEL_COMPARISON = resolve_class(
        tabular_checks.SimpleModelComparison
    )
    TABULAR_TRAIN_TEST_PREDICTION_DRIFT = resolve_class(
        tabular_checks.TrainTestPredictionDrift
    )
    TABULAR_UNUSED_FEATURES = resolve_class(tabular_checks.UnusedFeatures)

    # Two-dataset model checks for computer vision data.
    VISION_CLASS_PERFORMANCE = resolve_class(vision_checks.ClassPerformance)
    VISION_MODEL_ERROR_ANALYSIS = resolve_class(
        vision_checks.ModelErrorAnalysis
    )
    VISION_SIMPLE_MODEL_COMPARISON = resolve_class(
        vision_checks.SimpleModelComparison
    )
    VISION_TRAIN_TEST_PREDICTION_DRIFT = resolve_class(
        vision_checks.TrainTestPredictionDrift
    )
DeepchecksModelValidationCheck (DeepchecksValidationCheck)

Categories of Deepchecks model validation checks.

This list includes the subset of Deepchecks model evaluation checks (for tabular data and for computer vision) that require a single dataset and a mandatory model as input.

All these checks inherit from deepchecks.tabular.SingleDatasetCheck or deepchecks.vision.SingleDatasetCheck and require a dataset and a mandatory model as input.

Source code in zenml/integrations/deepchecks/validation_checks.py
class DeepchecksModelValidationCheck(DeepchecksValidationCheck):
    """Categories of Deepchecks model validation checks.

    This list includes a subset of the model evaluation checks provided by
    Deepchecks that require a single dataset and a mandatory model as input:

      * [for tabular data](https://docs.deepchecks.com/en/stable/checks_gallery/tabular.html#model-evaluation)
      * [for computer vision](https://docs.deepchecks.com/stable/checks_gallery/vision.html#model-evaluation)

    All these checks inherit from `deepchecks.tabular.SingleDatasetCheck` or
    `deepchecks.vision.SingleDatasetCheck` and require a dataset and a
    mandatory model as input.

    Each member value is produced by `resolve_class` from the referenced
    Deepchecks check class. Per the conventions documented on
    `DeepchecksValidationCheck`, that value is a fully formed class path.
    """

    # NOTE: every plain name bound in this enum body becomes a member, so
    # no helper variables may be introduced here.

    # Single-dataset model checks for tabular data.
    TABULAR_CALIBRATION_SCORE = resolve_class(tabular_checks.CalibrationScore)
    TABULAR_CONFUSION_MATRIX_REPORT = resolve_class(
        tabular_checks.ConfusionMatrixReport
    )
    TABULAR_MODEL_INFERENCE_TIME = resolve_class(
        tabular_checks.ModelInferenceTime
    )
    TABULAR_REGRESSION_ERROR_DISTRIBUTION = resolve_class(
        tabular_checks.RegressionErrorDistribution
    )
    TABULAR_REGRESSION_SYSTEMATIC_ERROR = resolve_class(
        tabular_checks.RegressionSystematicError
    )
    TABULAR_ROC_REPORT = resolve_class(tabular_checks.RocReport)
    TABULAR_SEGMENT_PERFORMANCE = resolve_class(
        tabular_checks.SegmentPerformance
    )

    # Single-dataset model checks for computer vision data.
    VISION_CONFUSION_MATRIX_REPORT = resolve_class(
        vision_checks.ConfusionMatrixReport
    )
    VISION_IMAGE_SEGMENT_PERFORMANCE = resolve_class(
        vision_checks.ImageSegmentPerformance
    )
    VISION_MEAN_AVERAGE_PRECISION_REPORT = resolve_class(
        vision_checks.MeanAveragePrecisionReport
    )
    VISION_MEAN_AVERAGE_RECALL_REPORT = resolve_class(
        vision_checks.MeanAverageRecallReport
    )
    VISION_ROBUSTNESS_REPORT = resolve_class(vision_checks.RobustnessReport)
    VISION_SINGLE_DATASET_SCALAR_PERFORMANCE = resolve_class(
        vision_checks.SingleDatasetScalarPerformance
    )
DeepchecksValidationCheck (StrEnum)

Base class for all Deepchecks categories of validation checks.

This base class defines the conventions followed by all enum values that identify the various validation checks that can be performed with Deepchecks:

  • enum values represent fully formed class paths pointing to Deepchecks BaseCheck subclasses
  • all tabular data checks are located under the deepchecks.tabular.checks module sub-tree
  • all computer vision data checks are located under the deepchecks.vision.checks module sub-tree
Source code in zenml/integrations/deepchecks/validation_checks.py
class DeepchecksValidationCheck(StrEnum):
    """Base class for all Deepchecks categories of validation checks.

    This base class defines some conventions used for all enum values used to
    identify the various validation checks that can be performed with
    Deepchecks:

      * enum values represent fully formed class paths pointing to Deepchecks
      BaseCheck subclasses
      * all tabular data checks are located under the
      `deepchecks.tabular.checks` module sub-tree
      * all computer vision data checks are located under the
      `deepchecks.vision.checks` module sub-tree
    """

    @classmethod
    def validate_check_name(cls, check_name: str) -> None:
        """Validate a Deepchecks check identifier.

        Only the prefix of the identifier is checked here; whether it actually
        resolves to a class is verified by `get_check_class`.

        Args:
            check_name: Identifies a builtin Deepchecks check. The identifier
                must be formatted as
                `deepchecks.{tabular|vision}.checks.<...>.<class-name>`.

        Raises:
            ValueError: If the check identifier does not follow the convention
                used by ZenML to identify Deepchecks builtin checks.
        """
        if not re.match(
            r"^deepchecks\.(tabular|vision)\.checks\.",
            check_name,
        ):
            raise ValueError(
                f"The supplied Deepcheck check identifier does not follow the "
                f"convention used by ZenML: `{check_name}`. The identifier "
                f"must be formatted as `deepchecks.<tabular|vision>.checks...` "
                f"and must be resolvable to a valid Deepchecks BaseCheck "
                f"subclass."
            )

    @classmethod
    def is_tabular_check(cls, check_name: str) -> bool:
        """Check if a validation check is applicable to tabular data.

        Args:
            check_name: Identifies a builtin Deepchecks check.

        Returns:
            True if the check is applicable to tabular data, otherwise False.
        """
        cls.validate_check_name(check_name)
        return check_name.startswith("deepchecks.tabular.")

    @classmethod
    def is_vision_check(cls, check_name: str) -> bool:
        """Check if a validation check is applicable to computer vision data.

        Args:
            check_name: Identifies a builtin Deepchecks check.

        Returns:
            True if the check is applicable to computer vision data, otherwise
            False.
        """
        cls.validate_check_name(check_name)
        return check_name.startswith("deepchecks.vision.")

    @classmethod
    def get_check_class(cls, check_name: str) -> Type[BaseCheck]:
        """Get the Deepchecks check class associated with an enum value or a custom check name.

        Args:
            check_name: Identifies a builtin Deepchecks check. The identifier
                must be formatted as
                `deepchecks.{tabular|vision}.checks.<...>.<class-name>`
                and must be resolvable to a valid Deepchecks BaseCheck class.

        Returns:
            The Deepchecks check class associated with this enum value.

        Raises:
            ValueError: If the check name could not be converted to a valid
                Deepchecks check class. This can happen for example if the enum
                values fall out of sync with the Deepchecks code base or if a
                custom check name is supplied that cannot be resolved to a valid
                Deepchecks BaseCheck class.
        """
        cls.validate_check_name(check_name)

        try:
            check_class = import_class_by_path(check_name)
        except (AttributeError, ImportError) as e:
            # AttributeError: the module exists but has no such attribute.
            # ImportError (incl. ModuleNotFoundError): the module part of the
            # path cannot be imported at all. Both mean the identifier does
            # not resolve, which callers are told to expect as ValueError.
            raise ValueError(
                f"Could not map the `{check_name}` check identifier to a valid "
                f"Deepchecks check class."
            ) from e

        # `issubclass` raises TypeError for non-class objects (e.g. a function
        # or module-level constant), so guard with an explicit type check.
        if not isinstance(check_class, type) or not issubclass(
            check_class, BaseCheck
        ):
            raise ValueError(
                f"The `{check_name}` check identifier is mapped to an invalid "
                f"data type. Expected a {str(BaseCheck)} subclass, but instead "
                f"got: {str(check_class)}."
            )

        if check_name not in cls.values():
            logger.warning(
                f"You are using a custom Deepchecks check identifier that is "
                f"not listed in the `{str(cls)}` enum type. This could lead "
                f"to unexpected behavior."
            )

        return check_class

    @property
    def check_class(self) -> Type[BaseCheck]:
        """Convert the enum value to a valid Deepchecks check class.

        Returns:
            The Deepchecks check class associated with the enum value.
        """
        return self.get_check_class(self.value)

visualizers special

Deepchecks visualizer.

deepchecks_visualizer

Implementation of the Deepchecks visualizer.

DeepchecksVisualizer (BaseStepVisualizer)

The implementation of a Deepchecks Visualizer.

Source code in zenml/integrations/deepchecks/visualizers/deepchecks_visualizer.py
class DeepchecksVisualizer(BaseStepVisualizer):
    """The implementation of a Deepchecks Visualizer."""

    # NOTE: `visualize` is a concrete implementation, so it must not be
    # decorated with `@abstractmethod` (that would mark this class abstract).
    def visualize(self, object: StepView, *args: Any, **kwargs: Any) -> None:
        """Visualize the Deepchecks results produced by a pipeline step.

        Only step outputs whose artifact type is `DataAnalysisArtifact` are
        rendered; all other outputs are skipped.

        Args:
            object: StepView fetched from run.get_step().
            *args: Additional arguments (unused).
            **kwargs: Additional keyword arguments (unused).
        """
        for artifact_view in object.outputs.values():
            # filter out anything but data analysis artifacts
            if artifact_view.type == DataAnalysisArtifact.__name__:
                artifact = artifact_view.read()
                self.generate_report(artifact)

    def generate_report(self, result: Union[CheckResult, SuiteResult]) -> None:
        """Generate a Deepchecks Report.

        Inside a notebook the result is displayed inline. Otherwise it is
        written to a temporary HTML file which is then opened in a browser.

        Args:
            result: A Deepchecks CheckResult or SuiteResult.
        """
        from pathlib import Path

        print(result)

        if Environment.in_notebook():
            result.show()
        else:
            logger.warning(
                "The magic functions are only usable in a Jupyter notebook."
            )
            # delete=False: the file must outlive this context so the
            # browser can still read it after it has been closed.
            with tempfile.NamedTemporaryFile(
                mode="w", delete=False, suffix=".html", encoding="utf-8"
            ) as f:
                result.save_as_html(f)
                # `as_uri` yields a well-formed file URI on every platform
                # (hand-built "file:///" + path doubles the slash on POSIX
                # and keeps backslashes on Windows).
                url = Path(f.name).as_uri()
            logger.info("Opening %s in a new browser..", f.name)
            webbrowser.open(url, new=2)
generate_report(self, result)

Generate a Deepchecks Report.

Parameters:

Name Type Description Default
result Union[deepchecks.core.check_result.CheckResult, deepchecks.core.suite.SuiteResult]

A Deepchecks CheckResult or SuiteResult.

required
Source code in zenml/integrations/deepchecks/visualizers/deepchecks_visualizer.py
def generate_report(self, result: Union[CheckResult, SuiteResult]) -> None:
    """Generate a Deepchecks Report.

    Inside a notebook the result is displayed inline. Otherwise it is written
    to a temporary HTML file which is then opened in a browser.

    Args:
        result: A Deepchecks CheckResult or SuiteResult.
    """
    from pathlib import Path

    print(result)

    if Environment.in_notebook():
        result.show()
    else:
        logger.warning(
            "The magic functions are only usable in a Jupyter notebook."
        )
        # delete=False: the file must outlive this context so the browser can
        # still read it after it has been closed.
        with tempfile.NamedTemporaryFile(
            mode="w", delete=False, suffix=".html", encoding="utf-8"
        ) as f:
            result.save_as_html(f)
            # `as_uri` yields a well-formed file URI on every platform.
            url = Path(f.name).as_uri()
        logger.info("Opening %s in a new browser..", f.name)
        webbrowser.open(url, new=2)
visualize(self, object, *args, **kwargs)

Method to visualize components.

Parameters:

Name Type Description Default
object StepView

StepView fetched from run.get_step().

required
*args Any

Additional arguments (unused).

()
**kwargs Any

Additional keyword arguments (unused).

{}
Source code in zenml/integrations/deepchecks/visualizers/deepchecks_visualizer.py
def visualize(self, object: StepView, *args: Any, **kwargs: Any) -> None:
    """Visualize the Deepchecks results produced by a pipeline step.

    Only step outputs whose artifact type is `DataAnalysisArtifact` are
    rendered; all other outputs are skipped. This is a concrete
    implementation and therefore not marked `@abstractmethod`.

    Args:
        object: StepView fetched from run.get_step().
        *args: Additional arguments (unused).
        **kwargs: Additional keyword arguments (unused).
    """
    for artifact_view in object.outputs.values():
        # filter out anything but data analysis artifacts
        if artifact_view.type == DataAnalysisArtifact.__name__:
            artifact = artifact_view.read()
            self.generate_report(artifact)

evidently special

Initialization of the Evidently integration.

The Evidently integration provides a way to monitor your models in production. It includes a way to detect data drift and different kinds of model performance issues.

The results of Evidently calculations can either be exported as an interactive dashboard (visualized as an html file or in your Jupyter notebook), or as a JSON file.

EvidentlyIntegration (Integration)

Evidently integration for ZenML.

Source code in zenml/integrations/evidently/__init__.py
class EvidentlyIntegration(Integration):
    """[Evidently](https://github.com/evidentlyai/evidently) integration for ZenML."""

    NAME = EVIDENTLY
    REQUIREMENTS = ["evidently==0.1.52dev0"]

    @staticmethod
    def activate() -> None:
        """Activate the Evidently integration.

        Importing the materializers and visualizers sub-modules is done for
        its side effects only, hence the `noqa` markers.
        """
        from zenml.integrations.evidently import materializers  # noqa
        from zenml.integrations.evidently import visualizers  # noqa

    @classmethod
    def flavors(cls) -> List[FlavorWrapper]:
        """Declare the stack component flavors for the Evidently integration.

        Returns:
            List of stack component flavors for this integration.
        """
        return [
            FlavorWrapper(
                name=EVIDENTLY_DATA_VALIDATOR_FLAVOR,
                source="zenml.integrations.evidently.data_validators.EvidentlyDataValidator",
                type=StackComponentType.DATA_VALIDATOR,
                integration=cls.NAME,
            ),
        ]
activate() staticmethod

Activate the Evidently integration.

Source code in zenml/integrations/evidently/__init__.py
@staticmethod
def activate() -> None:
    """Activate the Evidently integration.

    Importing the materializers and visualizers sub-modules is done for its
    side effects only, hence the `noqa` markers.
    """
    from zenml.integrations.evidently import materializers  # noqa
    from zenml.integrations.evidently import visualizers  # noqa
flavors() classmethod

Declare the stack component flavors for the Evidently integration.

Returns:

Type Description
List[zenml.zen_stores.models.flavor_wrapper.FlavorWrapper]

List of stack component flavors for this integration.

Source code in zenml/integrations/evidently/__init__.py
@classmethod
def flavors(cls) -> List[FlavorWrapper]:
    """Declare the stack component flavors for the Evidently integration.

    Returns:
        List of stack component flavors for this integration.
    """
    return [
        FlavorWrapper(
            name=EVIDENTLY_DATA_VALIDATOR_FLAVOR,
            source="zenml.integrations.evidently.data_validators.EvidentlyDataValidator",
            type=StackComponentType.DATA_VALIDATOR,
            integration=cls.NAME,
        ),
    ]

data_validators special

Initialization of the Evidently data validator for ZenML.

evidently_data_validator

Implementation of the Evidently data validator.

EvidentlyDataValidator (BaseDataValidator) pydantic-model

Evidently data validator stack component.

Source code in zenml/integrations/evidently/data_validators/evidently_data_validator.py
class EvidentlyDataValidator(BaseDataValidator):
    """Evidently data validator stack component."""

    # Class Configuration
    FLAVOR: ClassVar[str] = EVIDENTLY_DATA_VALIDATOR_FLAVOR
    NAME: ClassVar[str] = "Evidently"

    @classmethod
    def _unpack_options(
        cls, option_list: Sequence[Tuple[str, Dict[str, Any]]]
    ) -> Sequence[Any]:
        """Unpack Evidently options.

        Implements de-serialization for [Evidently options](https://docs.evidentlyai.com/user-guide/customization)
        that can be passed as constructor arguments when creating Profile and
        Dashboard objects. The convention used is that each item in the list
        consists of two elements:

        * a string containing the full class path of a `dataclass` based
        class with Evidently options
        * a dictionary with kwargs used as parameters for the option instance

        Example:

        ```python
            options = [
                (
                    "evidently.options.ColorOptions",{
                        "primary_color": "#5a86ad",
                        "fill_color": "#fff4f2",
                        "zero_line_color": "#016795",
                        "current_data_color": "#c292a1",
                        "reference_data_color": "#017b92",
                    }
                ),
            ]
        ```

        This is the same as saying:

        ```python
        from evidently.options import ColorOptions

        color_scheme = ColorOptions()
        color_scheme.primary_color = "#5a86ad"
        color_scheme.fill_color = "#fff4f2"
        color_scheme.zero_line_color = "#016795"
        color_scheme.current_data_color = "#c292a1"
        color_scheme.reference_data_color = "#017b92"
        ```

        Args:
            option_list: list of packed Evidently options

        Returns:
            A list of unpacked Evidently options

        Raises:
            ValueError: if one of the passed Evidently class paths cannot be
                resolved to an actual class.
        """
        options = []
        for option_clspath, option_args in option_list:
            try:
                option_cls = load_source_path_class(option_clspath)
            except (AttributeError, ImportError) as e:
                # AttributeError: module found but class missing;
                # ImportError: the module part of the path is not importable.
                raise ValueError(
                    f"Could not map the `{option_clspath}` Evidently option "
                    f"class path to a valid class."
                ) from e
            option = option_cls(**option_args)
            options.append(option)

        return options

    def data_profiling(
        self,
        dataset: pd.DataFrame,
        comparison_dataset: Optional[pd.DataFrame] = None,
        profile_list: Optional[Sequence[str]] = None,
        column_mapping: Optional[ColumnMapping] = None,
        verbose_level: int = 1,
        profile_options: Sequence[Tuple[str, Dict[str, Any]]] = [],
        dashboard_options: Sequence[Tuple[str, Dict[str, Any]]] = [],
        **kwargs: Any,
    ) -> Tuple[Profile, Dashboard]:
        """Analyze a dataset and generate a data profile with Evidently.

        The method takes in an optional list of Evidently options to be passed
        to the profile constructor (`profile_options`) and the dashboard
        constructor (`dashboard_options`). Each element in the list must be
        composed of two items: the first is a full class path of an Evidently
        option `dataclass`, the second is a dictionary of kwargs with the actual
        option parameters, e.g.:

        ```python
        options = [
            (
                "evidently.options.ColorOptions",{
                    "primary_color": "#5a86ad",
                    "fill_color": "#fff4f2",
                    "zero_line_color": "#016795",
                    "current_data_color": "#c292a1",
                    "reference_data_color": "#017b92",
                }
            ),
        ]
        ```

        Args:
            dataset: Target dataset to be profiled.
            comparison_dataset: Optional dataset to be used for data profiles
                that require a baseline for comparison (e.g. data drift
                profiles).
            profile_list: Optional list identifying the categories of Evidently
                data profiles to be generated.
            column_mapping: Properties of the DataFrame columns used by
                Evidently.
            verbose_level: Level of verbosity for the Evidently dashboards. Use
                0 for a brief dashboard, 1 for a detailed dashboard.
            profile_options: Optional list of options to pass to the
                profile constructor.
            dashboard_options: Optional list of options to pass to the
                dashboard constructor.
            **kwargs: Extra keyword arguments (unused).

        Returns:
            The Evidently Profile and Dashboard objects corresponding to the set
            of generated profiles.
        """
        sections, tabs = get_profile_sections_and_tabs(
            profile_list, verbose_level
        )
        # Options arrive in serialized (class path, kwargs) form; the list
        # defaults above are only read here, never mutated.
        profile_opts = self._unpack_options(profile_options)
        dashboard_opts = self._unpack_options(dashboard_options)

        # `dataset` is passed as Evidently's reference data and
        # `comparison_dataset` as its current data, for both reports.
        calculate_kwargs: Dict[str, Any] = dict(
            reference_data=dataset,
            current_data=comparison_dataset,
            column_mapping=column_mapping,
        )

        dashboard = Dashboard(tabs=tabs, options=dashboard_opts)
        dashboard.calculate(**calculate_kwargs)

        profile = Profile(sections=sections, options=profile_opts)
        profile.calculate(**calculate_kwargs)

        return profile, dashboard
data_profiling(self, dataset, comparison_dataset=None, profile_list=None, column_mapping=None, verbose_level=1, profile_options=[], dashboard_options=[], **kwargs)

Analyze a dataset and generate a data profile with Evidently.

The method takes in an optional list of Evidently options to be passed to the profile constructor (profile_options) and the dashboard constructor (dashboard_options). Each element in the list must be composed of two items: the first is a full class path of an Evidently option dataclass, the second is a dictionary of kwargs with the actual option parameters, e.g.:

options = [
    (
        "evidently.options.ColorOptions",{
            "primary_color": "#5a86ad",
            "fill_color": "#fff4f2",
            "zero_line_color": "#016795",
            "current_data_color": "#c292a1",
            "reference_data_color": "#017b92",
        }
    ),
]

Parameters:

Name Type Description Default
dataset DataFrame

Target dataset to be profiled.

required
comparison_dataset Optional[pandas.core.frame.DataFrame]

Optional dataset to be used for data profiles that require a baseline for comparison (e.g. data drift profiles).

None
profile_list Optional[Sequence[str]]

Optional list identifying the categories of Evidently data profiles to be generated.

None
column_mapping Optional[evidently.pipeline.column_mapping.ColumnMapping]

Properties of the DataFrame columns used by Evidently.

None
verbose_level int

Level of verbosity for the Evidently dashboards. Use 0 for a brief dashboard, 1 for a detailed dashboard.

1
profile_options Sequence[Tuple[str, Dict[str, Any]]]

Optional list of options to pass to the profile constructor.

[]
dashboard_options Sequence[Tuple[str, Dict[str, Any]]]

Optional list of options to pass to the dashboard constructor.

[]
**kwargs Any

Extra keyword arguments (unused).

{}

Returns:

Type Description
Tuple[evidently.model_profile.model_profile.Profile, evidently.dashboard.dashboard.Dashboard]

The Evidently Profile and Dashboard objects corresponding to the set of generated profiles.

Source code in zenml/integrations/evidently/data_validators/evidently_data_validator.py
def data_profiling(
    self,
    dataset: pd.DataFrame,
    comparison_dataset: Optional[pd.DataFrame] = None,
    profile_list: Optional[Sequence[str]] = None,
    column_mapping: Optional[ColumnMapping] = None,
    verbose_level: int = 1,
    profile_options: Sequence[Tuple[str, Dict[str, Any]]] = [],
    dashboard_options: Sequence[Tuple[str, Dict[str, Any]]] = [],
    **kwargs: Any,
) -> Tuple[Profile, Dashboard]:
    """Profile a dataset with Evidently and build the matching dashboard.

    Profile and dashboard options are supplied in serialized form. Each entry
    is a pair made of the full class path of an Evidently option `dataclass`
    and a dictionary of keyword arguments for it, e.g.:

    ```python
    options = [
        (
            "evidently.options.ColorOptions",{
                "primary_color": "#5a86ad",
                "fill_color": "#fff4f2",
                "zero_line_color": "#016795",
                "current_data_color": "#c292a1",
                "reference_data_color": "#017b92",
            }
        ),
    ]
    ```

    Args:
        dataset: Target dataset to be profiled.
        comparison_dataset: Optional dataset to be used for data profiles
            that require a baseline for comparison (e.g. data drift profiles).
        profile_list: Optional list identifying the categories of Evidently
            data profiles to be generated.
        column_mapping: Properties of the DataFrame columns used by Evidently.
        verbose_level: Level of verbosity for the Evidently dashboards. Use
            0 for a brief dashboard, 1 for a detailed dashboard.
        profile_options: Optional list of options to pass to the
            profile constructor.
        dashboard_options: Optional list of options to pass to the
            dashboard constructor.
        **kwargs: Extra keyword arguments (unused).

    Returns:
        The Evidently Profile and Dashboard objects corresponding to the set
        of generated profiles.
    """
    sections, tabs = get_profile_sections_and_tabs(
        profile_list, verbose_level
    )
    # Turn the serialized (class path, kwargs) pairs into option instances.
    profile_opts = self._unpack_options(profile_options)
    dashboard_opts = self._unpack_options(dashboard_options)

    # Both reports are computed over exactly the same inputs.
    calculate_kwargs: Dict[str, Any] = dict(
        reference_data=dataset,
        current_data=comparison_dataset,
        column_mapping=column_mapping,
    )

    dashboard = Dashboard(tabs=tabs, options=dashboard_opts)
    dashboard.calculate(**calculate_kwargs)

    profile = Profile(sections=sections, options=profile_opts)
    profile.calculate(**calculate_kwargs)

    return profile, dashboard
get_profile_sections_and_tabs(profile_list, verbose_level=1)

Get the profile sections and dashboard tabs for a profile list.

Parameters:

Name Type Description Default
profile_list Optional[Sequence[str]]

List of identifiers for Evidently profiles.

required
verbose_level int

Verbosity level for the rendered dashboard. Use 0 for a brief dashboard, 1 for a detailed dashboard.

1

Returns:

Type Description
Tuple[List[evidently.model_profile.sections.base_profile_section.ProfileSection], List[evidently.dashboard.tabs.base_tab.Tab]]

A tuple of two lists of profile sections and tabs.

Exceptions:

Type Description
ValueError

if the profile_section is not supported.

Source code in zenml/integrations/evidently/data_validators/evidently_data_validator.py
def get_profile_sections_and_tabs(
    profile_list: Optional[Sequence[str]],
    verbose_level: int = 1,
) -> Tuple[List[ProfileSection], List[Tab]]:
    """Get the profile sections and dashboard tabs for a profile list.

    Args:
        profile_list: List of identifiers for Evidently profiles. If `None`
            or empty, every identifier registered in `profile_mapper` is used.
        verbose_level: Verbosity level for the rendered dashboard. Use
            0 for a brief dashboard, 1 for a detailed dashboard.

    Returns:
        A tuple of two lists of profile sections and tabs.

    Raises:
        ValueError: if one or more of the requested profiles is not supported.
    """
    profile_list = profile_list or list(profile_mapper.keys())

    # A profile identifier is only usable when both a profile section and a
    # dashboard tab are registered for it. Validating up front means the
    # error names only the offending identifiers, and KeyErrors raised while
    # constructing a section/tab are no longer misreported as bad input.
    supported = [
        profile for profile in profile_mapper if profile in dashboard_mapper
    ]
    invalid = [profile for profile in profile_list if profile not in supported]
    if invalid:
        bullet = "\n- "
        raise ValueError(
            f"Invalid profile sections: {invalid} \n\n"
            f"Valid and supported options are: {bullet}"
            f"{bullet.join(supported)}"
        )

    sections = [profile_mapper[profile]() for profile in profile_list]
    tabs = [
        dashboard_mapper[profile](verbose_level=verbose_level)
        for profile in profile_list
    ]
    return sections, tabs

materializers special

Evidently materializers.

evidently_profile_materializer

Implementation of Evidently profile materializer.

EvidentlyProfileMaterializer (BaseMaterializer)

Materializer to read data to and from an Evidently Profile.

Source code in zenml/integrations/evidently/materializers/evidently_profile_materializer.py
class EvidentlyProfileMaterializer(BaseMaterializer):
    """Materializer to read data to and from an Evidently Profile."""

    ASSOCIATED_TYPES = (Profile,)
    ASSOCIATED_ARTIFACT_TYPES = (DataAnalysisArtifact,)

    def handle_input(self, data_type: Type[Any]) -> Profile:
        """Reads an Evidently Profile object from a json file.

        The section classes are re-instantiated from the class paths stored
        under the `section_types` key by `handle_return`. Each section's
        computed result is restored from the entry keyed by its `part_id()`.

        Args:
            data_type: The type of the data to read.

        Returns:
            The Evidently Profile

        Raises:
            TypeError: if the json file contains an invalid data type.
        """
        super().handle_input(data_type)
        filepath = os.path.join(self.artifact.uri, DEFAULT_FILENAME)
        contents = yaml_utils.read_json(filepath)
        if not isinstance(contents, dict):
            raise TypeError(
                f"Contents {contents} was type {type(contents)} but expected "
                f"dictionary"
            )

        section_types = contents.pop("section_types", [])
        sections = []
        for section_type in section_types:
            section_cls = import_class_by_path(section_type)
            section = section_cls()
            # restore the already-computed result instead of recalculating
            section._result = contents[section.part_id()]
            sections.append(section)

        return Profile(sections=sections)

    def handle_return(self, data: Profile) -> None:
        """Serialize an Evidently Profile to a json file.

        Args:
            data: The Evidently Profile to be serialized.
        """
        super().handle_return(data)

        contents = data.object()
        # include the list of profile sections in the serialized dictionary,
        # so we'll be able to re-create them during de-serialization
        contents["section_types"] = [
            resolve_class(stage.__class__) for stage in data.stages
        ]

        filepath = os.path.join(self.artifact.uri, DEFAULT_FILENAME)
        yaml_utils.write_json(filepath, contents, encoder=NumpyEncoder)
handle_input(self, data_type)

Reads an Evidently Profile object from a json file.

Parameters:

Name Type Description Default
data_type Type[Any]

The type of the data to read.

required

Returns:

Type Description
Profile

The Evidently Profile

Exceptions:

Type Description
TypeError

if the json file contains an invalid data type.

Source code in zenml/integrations/evidently/materializers/evidently_profile_materializer.py
def handle_input(self, data_type: Type[Any]) -> Profile:
    """Reads an Evidently Profile object from a json file.

    The section classes are re-instantiated from the class paths stored under
    the `section_types` key. Each section's computed result is restored from
    the entry keyed by its `part_id()`.

    Args:
        data_type: The type of the data to read.

    Returns:
        The Evidently Profile

    Raises:
        TypeError: if the json file contains an invalid data type.
    """
    super().handle_input(data_type)
    filepath = os.path.join(self.artifact.uri, DEFAULT_FILENAME)
    contents = yaml_utils.read_json(filepath)
    if not isinstance(contents, dict):
        raise TypeError(
            f"Contents {contents} was type {type(contents)} but expected "
            f"dictionary"
        )

    section_types = contents.pop("section_types", [])
    sections = []
    for section_type in section_types:
        section_cls = import_class_by_path(section_type)
        section = section_cls()
        # restore the already-computed result instead of recalculating
        section._result = contents[section.part_id()]
        sections.append(section)

    return Profile(sections=sections)
handle_return(self, data)

Serialize an Evidently Profile to a json file.

Parameters:

Name Type Description Default
data Profile

The Evidently Profile to be serialized.

required
Source code in zenml/integrations/evidently/materializers/evidently_profile_materializer.py
def handle_return(self, data: Profile) -> None:
    """Write an Evidently Profile to the artifact store as a JSON file.

    Alongside the profile's own serialized form, the class path of every
    profile section is stored under the `section_types` key. This allows the
    sections to be re-created when the artifact is loaded again.

    Args:
        data: The Evidently Profile to be serialized.
    """
    super().handle_return(data)

    serialized = data.object()
    serialized["section_types"] = [
        resolve_class(section.__class__) for section in data.stages
    ]

    write_json = yaml_utils.write_json
    write_json(
        os.path.join(self.artifact.uri, DEFAULT_FILENAME),
        serialized,
        encoder=NumpyEncoder,
    )

steps special

Initialization of the Evidently Standard Steps.

evidently_profile

Implementation of the Evidently Profile Step.

EvidentlyColumnMapping (BaseModel) pydantic-model

Column mapping configuration for Evidently.

This class is a 1-to-1 serializable analogue of Evidently's ColumnMapping data type that can be used as a step configuration field (see https://docs.evidentlyai.com/features/dashboards/column_mapping).

Attributes:

Name Type Description
target Optional[str]

target column

prediction Union[str, Sequence[str]]

prediction column(s)

datetime Optional[str]

datetime column

id Optional[str]

id column

numerical_features Optional[List[str]]

numerical features

categorical_features Optional[List[str]]

categorical features

datetime_features Optional[List[str]]

datetime features

target_names Optional[List[str]]

target column names

task Optional[Literal['classification', 'regression']]

model task (regression or classification)

Source code in zenml/integrations/evidently/steps/evidently_profile.py
class EvidentlyColumnMapping(BaseModel):
    """Column mapping configuration for Evidently.

    This class is a 1-to-1 serializable analogue of Evidently's
    ColumnMapping data type that can be used as a step configuration field
    (see https://docs.evidentlyai.com/features/dashboards/column_mapping).

    Attributes:
        target: target column
        prediction: prediction column(s)
        datetime: datetime column
        id: id column
        numerical_features: numerical features
        categorical_features: categorical features
        datetime_features: datetime features
        target_names: target column names
        task: model task (regression or classification)
    """

    target: Optional[str] = None
    prediction: Optional[Union[str, Sequence[str]]] = None
    datetime: Optional[str] = None
    id: Optional[str] = None
    numerical_features: Optional[List[str]] = None
    categorical_features: Optional[List[str]] = None
    datetime_features: Optional[List[str]] = None
    target_names: Optional[List[str]] = None
    task: Optional[Literal["classification", "regression"]] = None

    def to_evidently_column_mapping(self) -> ColumnMapping:
        """Convert this Pydantic object to an Evidently ColumnMapping object.

        Every field declared on this model is transferred. A field that is
        unset (or falsy) keeps the default of a fresh Evidently ColumnMapping.

        Returns:
            An Evidently column mapping converted from this Pydantic object.
        """
        column_mapping = ColumnMapping()

        # preserve the Evidently defaults where possible
        column_mapping.target = self.target or column_mapping.target
        column_mapping.prediction = self.prediction or column_mapping.prediction
        column_mapping.datetime = self.datetime or column_mapping.datetime
        column_mapping.id = self.id or column_mapping.id
        column_mapping.numerical_features = (
            self.numerical_features or column_mapping.numerical_features
        )
        # previously omitted: categorical features configured on the step
        # were silently dropped
        column_mapping.categorical_features = (
            self.categorical_features or column_mapping.categorical_features
        )
        column_mapping.datetime_features = (
            self.datetime_features or column_mapping.datetime_features
        )
        column_mapping.target_names = (
            self.target_names or column_mapping.target_names
        )
        column_mapping.task = self.task or column_mapping.task

        return column_mapping
to_evidently_column_mapping(self)

Convert this Pydantic object to an Evidently ColumnMapping object.

Returns:

Type Description
ColumnMapping

An Evidently column mapping converted from this Pydantic object.

Source code in zenml/integrations/evidently/steps/evidently_profile.py
def to_evidently_column_mapping(self) -> ColumnMapping:
    """Convert this Pydantic object to an Evidently ColumnMapping object.

    A fresh Evidently ``ColumnMapping`` is created and each of its
    attributes is overridden by the corresponding field of this object
    when that field is truthy. Unset (``None``) or empty fields keep the
    Evidently default.

    Returns:
        An Evidently column mapping converted from this Pydantic object.
    """
    column_mapping = ColumnMapping()

    # preserve the Evidently defaults where possible: only override an
    # attribute when the corresponding field was actually provided
    column_mapping.target = self.target or column_mapping.target
    column_mapping.prediction = self.prediction or column_mapping.prediction
    column_mapping.datetime = self.datetime or column_mapping.datetime
    column_mapping.id = self.id or column_mapping.id
    column_mapping.numerical_features = (
        self.numerical_features or column_mapping.numerical_features
    )
    # categorical_features is a declared field of the model; it must be
    # forwarded too, otherwise user-specified categorical columns are lost
    column_mapping.categorical_features = (
        self.categorical_features or column_mapping.categorical_features
    )
    column_mapping.datetime_features = (
        self.datetime_features or column_mapping.datetime_features
    )
    column_mapping.target_names = (
        self.target_names or column_mapping.target_names
    )
    column_mapping.task = self.task or column_mapping.task

    return column_mapping
EvidentlyProfileConfig (BaseDriftDetectionConfig) pydantic-model

Config class for Evidently profile steps.

Attributes:

Name Type Description
column_mapping Optional[zenml.integrations.evidently.steps.evidently_profile.EvidentlyColumnMapping]

properties of the DataFrame columns used

profile_sections Optional[Sequence[str]]

a list identifying the Evidently profile sections to be used. The following are valid options supported by Evidently: "datadrift", "categoricaltargetdrift", "numericaltargetdrift", "classificationmodelperformance", "regressionmodelperformance", "probabilisticmodelperformance".

verbose_level int

Verbosity level for the Evidently dashboards. Use 0 for a brief dashboard, 1 for a detailed dashboard.

profile_options Sequence[Tuple[str, Dict[str, Any]]]

Optional list of options to pass to the profile constructor. See EvidentlyDataValidator._unpack_options.

dashboard_options Sequence[Tuple[str, Dict[str, Any]]]

Optional list of options to pass to the dashboard constructor. See EvidentlyDataValidator._unpack_options.

Source code in zenml/integrations/evidently/steps/evidently_profile.py
class EvidentlyProfileConfig(BaseDriftDetectionConfig):
    """Config class for Evidently profile steps.

    Attributes:
        column_mapping: properties of the DataFrame columns used. When
            ``None``, no explicit column mapping is passed on to the data
            validator.
        profile_sections: a list identifying the Evidently profile sections to be
            used. The following are valid options supported by Evidently:
            - "datadrift"
            - "categoricaltargetdrift"
            - "numericaltargetdrift"
            - "classificationmodelperformance"
            - "regressionmodelperformance"
            - "probabilisticmodelperformance"
        verbose_level: Verbosity level for the Evidently dashboards. Use
            0 for a brief dashboard, 1 for a detailed dashboard.
        profile_options: Optional list of options to pass to the
            profile constructor. See `EvidentlyDataValidator._unpack_options`.
        dashboard_options: Optional list of options to pass to the
            dashboard constructor. See `EvidentlyDataValidator._unpack_options`.
    """

    # Converted with ``to_evidently_column_mapping()`` by the profile step
    # only when it is set.
    column_mapping: Optional[EvidentlyColumnMapping] = None
    # Forwarded as ``profile_list`` to ``EvidentlyDataValidator.data_profiling``.
    profile_sections: Optional[Sequence[str]] = None
    verbose_level: int = 1
    # ``default_factory=list`` gives every config instance its own empty
    # list instead of sharing one mutable default.
    profile_options: Sequence[Tuple[str, Dict[str, Any]]] = Field(
        default_factory=list
    )
    dashboard_options: Sequence[Tuple[str, Dict[str, Any]]] = Field(
        default_factory=list
    )
EvidentlyProfileStep (BaseDriftDetectionStep)

Step implementation implementing an Evidently Profile Step.

Source code in zenml/integrations/evidently/steps/evidently_profile.py
class EvidentlyProfileStep(BaseDriftDetectionStep):
    """Step that computes an Evidently profile and dashboard for two datasets."""

    def entrypoint(  # type: ignore[override]
        self,
        reference_dataset: pd.DataFrame,
        comparison_dataset: pd.DataFrame,
        config: EvidentlyProfileConfig,
    ) -> Output(  # type:ignore[valid-type]
        profile=Profile, dashboard=str
    ):
        """Main entrypoint for the Evidently profile step.

        Runs the profile sections selected in ``config.profile_sections``
        on the two datasets, using the active Evidently data validator.

        Args:
            reference_dataset: a Pandas DataFrame
            comparison_dataset: a Pandas DataFrame of new data you wish to
                compare against the reference data
            config: the configuration for the step

        Returns:
            profile: Evidently Profile generated for the selected profile
                sections
            dashboard: HTML report extracted from the Evidently Dashboard
                generated for the same sections
        """
        data_validator = cast(
            EvidentlyDataValidator,
            EvidentlyDataValidator.get_active_data_validator(),
        )
        # Only convert when the user actually supplied a mapping; otherwise
        # let the data validator fall back to its own behaviour.
        column_mapping = None
        if config.column_mapping is not None:
            column_mapping = config.column_mapping.to_evidently_column_mapping()
        profile, dashboard = data_validator.data_profiling(
            dataset=reference_dataset,
            comparison_dataset=comparison_dataset,
            profile_list=config.profile_sections,
            column_mapping=column_mapping,
            verbose_level=config.verbose_level,
            profile_options=config.profile_options,
            dashboard_options=config.dashboard_options,
        )
        return [profile, dashboard.html()]
CONFIG_CLASS (BaseDriftDetectionConfig) pydantic-model

Config class for Evidently profile steps.

Attributes:

Name Type Description
column_mapping Optional[zenml.integrations.evidently.steps.evidently_profile.EvidentlyColumnMapping]

properties of the DataFrame columns used

profile_sections Optional[Sequence[str]]

a list identifying the Evidently profile sections to be used. The following are valid options supported by Evidently: "datadrift", "categoricaltargetdrift", "numericaltargetdrift", "classificationmodelperformance", "regressionmodelperformance", "probabilisticmodelperformance".

verbose_level int

Verbosity level for the Evidently dashboards. Use 0 for a brief dashboard, 1 for a detailed dashboard.

profile_options Sequence[Tuple[str, Dict[str, Any]]]

Optional list of options to pass to the profile constructor. See EvidentlyDataValidator._unpack_options.

dashboard_options Sequence[Tuple[str, Dict[str, Any]]]

Optional list of options to pass to the dashboard constructor. See EvidentlyDataValidator._unpack_options.

Source code in zenml/integrations/evidently/steps/evidently_profile.py
class EvidentlyProfileConfig(BaseDriftDetectionConfig):
    """Config class for Evidently profile steps.

    Attributes:
        column_mapping: properties of the DataFrame columns used. When
            ``None``, no explicit column mapping is passed on to the data
            validator.
        profile_sections: a list identifying the Evidently profile sections to be
            used. The following are valid options supported by Evidently:
            - "datadrift"
            - "categoricaltargetdrift"
            - "numericaltargetdrift"
            - "classificationmodelperformance"
            - "regressionmodelperformance"
            - "probabilisticmodelperformance"
        verbose_level: Verbosity level for the Evidently dashboards. Use
            0 for a brief dashboard, 1 for a detailed dashboard.
        profile_options: Optional list of options to pass to the
            profile constructor. See `EvidentlyDataValidator._unpack_options`.
        dashboard_options: Optional list of options to pass to the
            dashboard constructor. See `EvidentlyDataValidator._unpack_options`.
    """

    # Converted with ``to_evidently_column_mapping()`` by the profile step
    # only when it is set.
    column_mapping: Optional[EvidentlyColumnMapping] = None
    # Forwarded as ``profile_list`` to ``EvidentlyDataValidator.data_profiling``.
    profile_sections: Optional[Sequence[str]] = None
    verbose_level: int = 1
    # ``default_factory=list`` gives every config instance its own empty
    # list instead of sharing one mutable default.
    profile_options: Sequence[Tuple[str, Dict[str, Any]]] = Field(
        default_factory=list
    )
    dashboard_options: Sequence[Tuple[str, Dict[str, Any]]] = Field(
        default_factory=list
    )
entrypoint(self, reference_dataset, comparison_dataset, config)

Main entrypoint for the Evidently profile step.

Parameters:

Name Type Description Default
reference_dataset DataFrame

a Pandas DataFrame

required
comparison_dataset DataFrame

a Pandas DataFrame of new data you wish to compare against the reference data

required
config EvidentlyProfileConfig

the configuration for the step

required

Returns:

Type Description
profile

profile: the Evidently Profile generated for the selected profile sections; dashboard: an HTML report extracted from the Evidently Dashboard generated for the same sections.

Source code in zenml/integrations/evidently/steps/evidently_profile.py
def entrypoint(  # type: ignore[override]
    self,
    reference_dataset: pd.DataFrame,
    comparison_dataset: pd.DataFrame,
    config: EvidentlyProfileConfig,
) -> Output(  # type:ignore[valid-type]
    profile=Profile, dashboard=str
):
    """Main entrypoint for the Evidently profile step.

    Runs the profile sections selected in ``config.profile_sections`` on the
    two datasets, using the active Evidently data validator.

    Args:
        reference_dataset: a Pandas DataFrame
        comparison_dataset: a Pandas DataFrame of new data you wish to
            compare against the reference data
        config: the configuration for the step

    Returns:
        profile: Evidently Profile generated for the selected profile
            sections
        dashboard: HTML report extracted from the Evidently Dashboard
            generated for the same sections
    """
    data_validator = cast(
        EvidentlyDataValidator,
        EvidentlyDataValidator.get_active_data_validator(),
    )
    # Only convert when the user actually supplied a mapping; otherwise let
    # the data validator fall back to its own behaviour.
    column_mapping = None
    if config.column_mapping is not None:
        column_mapping = config.column_mapping.to_evidently_column_mapping()
    profile, dashboard = data_validator.data_profiling(
        dataset=reference_dataset,
        comparison_dataset=comparison_dataset,
        profile_list=config.profile_sections,
        column_mapping=column_mapping,
        verbose_level=config.verbose_level,
        profile_options=config.profile_options,
        dashboard_options=config.dashboard_options,
    )
    return [profile, dashboard.html()]
evidently_profile_step(step_name, config)

Shortcut function to create a new instance of the EvidentlyProfileStep.

The returned EvidentlyProfileStep can be used in a pipeline to run model drift analyses on two input pd.DataFrame datasets and return the results as an Evidently profile object and a rendered dashboard object.

Parameters:

Name Type Description Default
step_name str

The name of the step

required
config EvidentlyProfileConfig

The configuration for the step

required

Returns:

Type Description
BaseStep

an EvidentlyProfileStep instance

Source code in zenml/integrations/evidently/steps/evidently_profile.py
def evidently_profile_step(
    step_name: str,
    config: EvidentlyProfileConfig,
) -> BaseStep:
    """Shortcut function to create a new instance of the EvidentlyProfileStep.

    The returned EvidentlyProfileStep can be used in a pipeline to run
    Evidently profile analyses on two input pd.DataFrame datasets and return
    the results as an Evidently profile object and a rendered HTML dashboard.

    Args:
        step_name: The name of the step
        config: The configuration for the step

    Returns:
        an EvidentlyProfileStep instance
    """
    # clone_step produces a copy of the step class under the requested name,
    # so several differently-named profile steps can coexist in one pipeline.
    step_class = clone_step(EvidentlyProfileStep, step_name)
    return step_class(config=config)

visualizers special

Initialization for Evidently visualizer.

evidently_visualizer

Implementation of the Evidently visualizer.

EvidentlyVisualizer (BaseStepVisualizer)

The implementation of an Evidently Visualizer.

Source code in zenml/integrations/evidently/visualizers/evidently_visualizer.py
class EvidentlyVisualizer(BaseStepVisualizer):
    """The implementation of an Evidently Visualizer."""

    def visualize(self, object: StepView, *args: Any, **kwargs: Any) -> None:
        """Display every string data artifact produced by a step.

        The Evidently profile step emits its dashboard as an HTML string, so
        each ``builtins.str`` data artifact is rendered as HTML.

        Args:
            object: StepView fetched from run.get_step().
            *args: Additional arguments.
            **kwargs: Additional keyword arguments.
        """
        for artifact_view in object.outputs.values():
            # filter out anything but string data artifacts (HTML reports)
            if (
                artifact_view.type == DataArtifact.__name__
                and artifact_view.data_type == "builtins.str"
            ):
                artifact = artifact_view.read()
                self.generate_facet(artifact)

    def generate_facet(self, html_: str) -> None:
        """Display an HTML report.

        Inside a Jupyter notebook or Google Colab the HTML is rendered
        inline. Otherwise it is written to a temporary ``.html`` file, which
        is then opened in a new browser window.

        Args:
            html_: HTML represented as a string.
        """
        if Environment.in_notebook() or Environment.in_google_colab():
            from IPython.display import HTML, display

            display(HTML(html_))
            return

        from pathlib import Path

        logger.warning(
            "The magic functions are only usable in a Jupyter notebook."
        )
        with tempfile.NamedTemporaryFile(
            mode="w", delete=False, suffix=".html", encoding="utf-8"
        ) as f:
            f.write(html_)
        # The browser is launched only after the file has been closed (and
        # thus flushed), so it never reads a partially written report.
        logger.info("Opening %s in a new browser..", f.name)
        # as_uri() builds a well-formed file URL on every platform, unlike
        # prefixing "file:///" to an absolute path ("file:////tmp/...").
        webbrowser.open(Path(f.name).as_uri(), new=2)
generate_facet(self, html_)

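Display an Evidently HTML report: inline when running in a Jupyter notebook or Google Colab, otherwise in a new browser window.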

Parameters:

Name Type Description Default
html_ str

HTML represented as a string.

required
Source code in zenml/integrations/evidently/visualizers/evidently_visualizer.py
def generate_facet(self, html_: str) -> None:
    """Display an HTML report.

    Inside a Jupyter notebook or Google Colab the HTML is rendered inline.
    Otherwise it is written to a temporary ``.html`` file, which is then
    opened in a new browser window.

    Args:
        html_: HTML represented as a string.
    """
    if Environment.in_notebook() or Environment.in_google_colab():
        from IPython.display import HTML, display

        display(HTML(html_))
        return

    from pathlib import Path

    logger.warning(
        "The magic functions are only usable in a Jupyter notebook."
    )
    with tempfile.NamedTemporaryFile(
        mode="w", delete=False, suffix=".html", encoding="utf-8"
    ) as f:
        f.write(html_)
    # The browser is launched only after the file has been closed (and thus
    # flushed), so it never reads a partially written report.
    logger.info("Opening %s in a new browser..", f.name)
    # as_uri() builds a well-formed file URL on every platform, unlike
    # prefixing "file:///" to an absolute path ("file:////tmp/...").
    webbrowser.open(Path(f.name).as_uri(), new=2)
visualize(self, object, *args, **kwargs)

Method to visualize components.

Parameters:

Name Type Description Default
object StepView

StepView fetched from run.get_step().

required
*args Any

Additional arguments.

()
**kwargs Any

Additional keyword arguments.

{}
Source code in zenml/integrations/evidently/visualizers/evidently_visualizer.py
def visualize(self, object: StepView, *args: Any, **kwargs: Any) -> None:
    """Display every string data artifact produced by a step.

    This is a concrete implementation, so it is not marked abstract. The
    Evidently profile step emits its dashboard as an HTML string, so each
    ``builtins.str`` data artifact is rendered as HTML.

    Args:
        object: StepView fetched from run.get_step().
        *args: Additional arguments.
        **kwargs: Additional keyword arguments.
    """
    for artifact_view in object.outputs.values():
        # filter out anything but string data artifacts (HTML reports)
        if (
            artifact_view.type == DataArtifact.__name__
            and artifact_view.data_type == "builtins.str"
        ):
            artifact = artifact_view.read()
            self.generate_facet(artifact)

facets special

Facets integration for ZenML.

The Facets integration provides a simple way to visualize post-execution objects like PipelineView, PipelineRunView and StepView. These objects can be extended using the BaseVisualization class. This integration requires that facets-overview be installed in your Python environment.

FacetsIntegration (Integration)

Definition of Facet integration for ZenML.

Source code in zenml/integrations/facets/__init__.py
class FacetsIntegration(Integration):
    """ZenML integration for [Facets](https://pair-code.github.io/facets/).

    Lists the packages required by the Facets statistics visualizer:
    ``facets-overview`` and ``IPython``.
    """

    # pip requirement specifiers for this integration
    REQUIREMENTS = ["facets-overview>=1.0.0", "IPython"]
    # identifier under which the integration is known to ZenML
    NAME = FACETS

visualizers special

Initialization of the Facet Visualizer.

facet_statistics_visualizer

Implementation of the Facet Statistics Visualizer.

FacetStatisticsVisualizer (BaseStepVisualizer)

Visualizer that renders statistics of pandas DataFrame step outputs with Facets Overview.

Source code in zenml/integrations/facets/visualizers/facet_statistics_visualizer.py
class FacetStatisticsVisualizer(BaseStepVisualizer):
    """Visualize statistics of DataFrame step outputs with Facets Overview."""

    def visualize(
        self, object: StepView, magic: bool = False, *args: Any, **kwargs: Any
    ) -> None:
        """Method to visualize components.

        Every step output that is a pandas DataFrame (including subclasses)
        is included in a single Facets Overview. Other outputs are skipped
        with a warning. If no DataFrame outputs exist, nothing is rendered.

        Args:
            object: StepView fetched from run.get_step().
            magic: Whether to render in a Jupyter notebook or not.
            *args: Additional arguments.
            **kwargs: Additional keyword arguments.
        """
        datasets = []
        for output_name, artifact_view in object.outputs.items():
            df = artifact_view.read()
            if not isinstance(df, pd.DataFrame):
                logger.warning(
                    "`%s` is not a pd.DataFrame. You can only visualize "
                    "statistics of steps that output pandas DataFrames. "
                    "Skipping this output.." % output_name
                )
            else:
                datasets.append({"name": output_name, "table": df})
        if not datasets:
            # Avoid rendering (or opening a browser on) an empty overview.
            logger.warning(
                "No pandas DataFrame outputs found; nothing to visualize."
            )
            return
        h = self.generate_html(datasets)
        self.generate_facet(h, magic)

    def generate_html(self, datasets: List[Dict[Text, pd.DataFrame]]) -> str:
        """Generates html for facet.

        Args:
            datasets: List of dicts of DataFrames to be visualized as stats.

        Returns:
            HTML template with proto string embedded.
        """
        proto = GenericFeatureStatisticsGenerator().ProtoFromDataFrames(
            datasets
        )
        protostr = base64.b64encode(proto.SerializeToString()).decode("utf-8")

        # stats.html ships next to this module and contains the literal
        # placeholder "protostr" where the encoded statistics are injected.
        template = os.path.join(
            os.path.abspath(os.path.dirname(__file__)),
            "stats.html",
        )
        html_template = io_utils.read_file_contents_as_string(template)

        html_ = html_template.replace("protostr", protostr)
        return html_

    def generate_facet(self, html_: str, magic: bool = False) -> None:
        """Generate a Facet Overview.

        Args:
            html_: HTML represented as a string.
            magic: Whether to magically materialize facet in a notebook.

        Raises:
            EnvironmentError: If magic is True and not in a notebook.
        """
        if magic:
            if not (Environment.in_notebook() or Environment.in_google_colab()):
                raise EnvironmentError(
                    "The magic functions are only usable in a Jupyter notebook."
                )
            display(HTML(html_))
            return

        from pathlib import Path

        # Write through the already-open handle as UTF-8 rather than
        # reopening the file by name with the platform default encoding.
        with tempfile.NamedTemporaryFile(
            mode="w", delete=False, suffix=".html", encoding="utf-8"
        ) as f:
            f.write(html_)
        # Launch the browser only once the file is closed and flushed;
        # as_uri() yields a valid file URL on every platform.
        logger.info("Opening %s in a new browser..", f.name)
        webbrowser.open(Path(f.name).as_uri(), new=2)
generate_facet(self, html_, magic=False)

Generate a Facet Overview.

Parameters:

Name Type Description Default
html_ str

HTML represented as a string.

required
magic bool

Whether to magically materialize facet in a notebook.

False

Exceptions:

Type Description
EnvironmentError

If magic is True and not in a notebook.

Source code in zenml/integrations/facets/visualizers/facet_statistics_visualizer.py
def generate_facet(self, html_: str, magic: bool = False) -> None:
    """Generate a Facet Overview.

    Args:
        html_: HTML represented as a string.
        magic: Whether to magically materialize facet in a notebook.

    Raises:
        EnvironmentError: If magic is True and not in a notebook.
    """
    if magic:
        if not (Environment.in_notebook() or Environment.in_google_colab()):
            raise EnvironmentError(
                "The magic functions are only usable in a Jupyter notebook."
            )
        display(HTML(html_))
        return

    from pathlib import Path

    # Write through the already-open handle as UTF-8 rather than reopening
    # the file by name with the platform default encoding.
    with tempfile.NamedTemporaryFile(
        mode="w", delete=False, suffix=".html", encoding="utf-8"
    ) as f:
        f.write(html_)
    # Launch the browser only once the file is closed and flushed; as_uri()
    # yields a valid file URL on every platform.
    logger.info("Opening %s in a new browser..", f.name)
    webbrowser.open(Path(f.name).as_uri(), new=2)
generate_html(self, datasets)

Generates html for facet.

Parameters:

Name Type Description Default
datasets List[Dict[str, pandas.core.frame.DataFrame]]

List of dicts of DataFrames to be visualized as stats.

required

Returns:

Type Description
str

HTML template with proto string embedded.

Source code in zenml/integrations/facets/visualizers/facet_statistics_visualizer.py
def generate_html(self, datasets: List[Dict[Text, pd.DataFrame]]) -> str:
    """Render the Facets Overview HTML page for the given datasets.

    The datasets are turned into a feature-statistics protobuf, which is
    base64-encoded and substituted for the ``protostr`` placeholder in the
    ``stats.html`` template located next to this module.

    Args:
        datasets: List of dicts of DataFrames to be visualized as stats.

    Returns:
        HTML template with proto string embedded.
    """
    generator = GenericFeatureStatisticsGenerator()
    statistics = generator.ProtoFromDataFrames(datasets)
    encoded = base64.b64encode(statistics.SerializeToString())
    encoded_text = encoded.decode("utf-8")

    module_dir = os.path.abspath(os.path.dirname(__file__))
    template_path = os.path.join(module_dir, "stats.html")
    template_text = io_utils.read_file_contents_as_string(template_path)

    return template_text.replace("protostr", encoded_text)
visualize(self, object, magic=False, *args, **kwargs)

Method to visualize components.

Parameters:

Name Type Description Default
object StepView

StepView fetched from run.get_step().

required
magic bool

Whether to render in a Jupyter notebook or not.

False
*args Any

Additional arguments.

()
**kwargs Any

Additional keyword arguments.

{}
Source code in zenml/integrations/facets/visualizers/facet_statistics_visualizer.py
def visualize(
    self, object: StepView, magic: bool = False, *args: Any, **kwargs: Any
) -> None:
    """Method to visualize components.

    This is a concrete implementation, so it is not marked abstract. Every
    step output that is a pandas DataFrame (including subclasses) is
    included in a single Facets Overview. Other outputs are skipped with a
    warning. If no DataFrame outputs exist, nothing is rendered.

    Args:
        object: StepView fetched from run.get_step().
        magic: Whether to render in a Jupyter notebook or not.
        *args: Additional arguments.
        **kwargs: Additional keyword arguments.
    """
    datasets = []
    for output_name, artifact_view in object.outputs.items():
        df = artifact_view.read()
        if not isinstance(df, pd.DataFrame):
            logger.warning(
                "`%s` is not a pd.DataFrame. You can only visualize "
                "statistics of steps that output pandas DataFrames. "
                "Skipping this output.." % output_name
            )
        else:
            datasets.append({"name": output_name, "table": df})
    if not datasets:
        # Avoid rendering (or opening a browser on) an empty overview.
        logger.warning(
            "No pandas DataFrame outputs found; nothing to visualize."
        )
        return
    h = self.generate_html(datasets)
    self.generate_facet(h, magic)

feast special

Initialization for Feast integration.

The Feast integration offers a way to connect to a Feast Feature Store. ZenML implements a dedicated stack component that you can access as part of your ZenML steps in the usual ways.

FeastIntegration (Integration)

Definition of Feast integration for ZenML.

Source code in zenml/integrations/feast/__init__.py
class FeastIntegration(Integration):
    """ZenML integration for the Feast feature store."""

    NAME = FEAST
    REQUIREMENTS = ["feast[redis]>=0.19.4", "redis-server"]

    @classmethod
    def flavors(cls) -> List[FlavorWrapper]:
        """List the stack component flavors contributed by the Feast integration.

        Returns:
            A one-element list describing the Feast feature store flavor.
        """
        feast_feature_store_flavor = FlavorWrapper(
            name=FEAST_FEATURE_STORE_FLAVOR,
            source="zenml.integrations.feast.feature_stores.FeastFeatureStore",
            type=StackComponentType.FEATURE_STORE,
            integration=cls.NAME,
        )
        return [feast_feature_store_flavor]
flavors() classmethod

Declare the stack component flavors for the Feast integration.

Returns:

Type Description
List[zenml.zen_stores.models.flavor_wrapper.FlavorWrapper]

List of stack component flavors for this integration.

Source code in zenml/integrations/feast/__init__.py
@classmethod
def flavors(cls) -> List[FlavorWrapper]:
    """List the stack component flavors contributed by the Feast integration.

    Returns:
        A one-element list describing the Feast feature store flavor.
    """
    feast_feature_store_flavor = FlavorWrapper(
        name=FEAST_FEATURE_STORE_FLAVOR,
        source="zenml.integrations.feast.feature_stores.FeastFeatureStore",
        type=StackComponentType.FEATURE_STORE,
        integration=cls.NAME,
    )
    return [feast_feature_store_flavor]

feature_stores special

Feast Feature Store integration for ZenML.

Feature stores allow data teams to serve data via an offline store and an online low-latency store where data is kept in sync between the two. It also offers a centralized registry where features (and feature schemas) are stored for use within a team or wider organization. Feature stores are a relatively recent addition to commonly-used machine learning stacks. Feast is a leading open-source feature store, first developed by Gojek in collaboration with Google.

feast_feature_store

Implementation of the Feast Feature Store for ZenML.

FeastFeatureStore (BaseFeatureStore) pydantic-model

Class to interact with the Feast feature store.

Source code in zenml/integrations/feast/feature_stores/feast_feature_store.py
class FeastFeatureStore(BaseFeatureStore):
    """Class to interact with the Feast feature store."""

    FLAVOR: ClassVar[str] = FEAST_FEATURE_STORE_FLAVOR

    online_host: str = "localhost"
    online_port: int = 6379
    feast_repo: str

    def _get_feature_store(self) -> FeatureStore:
        """Create a Feast ``FeatureStore`` bound to the configured repository.

        Returns:
            A Feast feature store for ``self.feast_repo``.
        """
        return FeatureStore(repo_path=self.feast_repo)

    def _validate_connection(self) -> None:
        """Validates the connection to the online store (Redis).

        Raises:
            redis.exceptions.ConnectionError: If the online component (Redis)
                is not available. Note that this is redis-py's exception
                type, not the builtin ``ConnectionError``.
        """
        client = redis.Redis(host=self.online_host, port=self.online_port)
        try:
            client.ping()
        except redis.exceptions.ConnectionError as e:
            raise redis.exceptions.ConnectionError(
                "Could not connect to feature store's online component. "
                "Please make sure that Redis is running."
            ) from e

    def get_historical_features(
        self,
        entity_df: Union[pd.DataFrame, str],
        features: List[str],
        full_feature_names: bool = False,
    ) -> pd.DataFrame:
        """Returns the historical features for training or batch scoring.

        The online store (Redis) connection is not validated here.

        Args:
            entity_df: The entity DataFrame or entity name.
            features: The features to retrieve.
            full_feature_names: Whether to return the full feature names.

        Returns:
            The historical features as a Pandas DataFrame.
        """
        return (
            self._get_feature_store()
            .get_historical_features(
                entity_df=entity_df,
                features=features,
                full_feature_names=full_feature_names,
            )
            .to_df()
        )

    def get_online_features(
        self,
        entity_rows: List[Dict[str, Any]],
        features: List[str],
        full_feature_names: bool = False,
    ) -> Dict[str, Any]:
        """Returns the latest online feature data.

        Args:
            entity_rows: The entity rows to retrieve.
            features: The features to retrieve.
            full_feature_names: Whether to return the full feature names.

        Raises:
            redis.exceptions.ConnectionError: If the online component (Redis)
                is not available.

        Returns:
            The latest online feature data as a dictionary.
        """
        self._validate_connection()
        fs = self._get_feature_store()

        return fs.get_online_features(  # type: ignore[no-any-return]
            entity_rows=entity_rows,
            features=features,
            full_feature_names=full_feature_names,
        ).to_dict()

    def get_data_sources(self) -> List[str]:
        """Returns the data sources' names.

        Raises:
            redis.exceptions.ConnectionError: If the online component (Redis)
                is not available.

        Returns:
            The data sources' names.
        """
        self._validate_connection()
        fs = self._get_feature_store()
        return [data_source.name for data_source in fs.list_data_sources()]

    def get_entities(self) -> List[str]:
        """Returns the entity names.

        Raises:
            redis.exceptions.ConnectionError: If the online component (Redis)
                is not available.

        Returns:
            The entity names.
        """
        self._validate_connection()
        fs = self._get_feature_store()
        return [entity.name for entity in fs.list_entities()]

    def get_feature_services(self) -> List[str]:
        """Returns the feature service names.

        Raises:
            redis.exceptions.ConnectionError: If the online component (Redis)
                is not available.

        Returns:
            The feature service names.
        """
        self._validate_connection()
        fs = self._get_feature_store()
        return [service.name for service in fs.list_feature_services()]

    def get_feature_views(self) -> List[str]:
        """Returns the feature view names.

        Raises:
            redis.exceptions.ConnectionError: If the online component (Redis)
                is not available.

        Returns:
            The feature view names.
        """
        self._validate_connection()
        fs = self._get_feature_store()
        return [view.name for view in fs.list_feature_views()]

    def get_project(self) -> str:
        """Returns the project name.

        The online store (Redis) connection is not validated here.

        Returns:
            The project name.
        """
        return str(self._get_feature_store().project)

    def get_registry(self) -> Registry:
        """Returns the feature store registry.

        The online store (Redis) connection is not validated here.

        Returns:
            The registry.
        """
        fs: FeatureStore = self._get_feature_store()
        return fs.registry

    def get_feast_version(self) -> str:
        """Returns the version of Feast used.

        The online store (Redis) connection is not validated here.

        Returns:
            The version of Feast currently being used.
        """
        return str(self._get_feature_store().version())
get_data_sources(self)

Returns the data sources' names.

Exceptions:

Type Description
ConnectionError

If the online component (Redis) is not available.

Returns:

Type Description
List[str]

The data sources' names.

Source code in zenml/integrations/feast/feature_stores/feast_feature_store.py
def get_data_sources(self) -> List[str]:
    """List the names of all data sources registered in the Feast repository.

    The online store is probed first, so this fails fast when Redis is down.

    Raises:
        redis.exceptions.ConnectionError: If the online component (Redis) is
            not available (raised by ``self._validate_connection``).

    Returns:
        The data sources' names.
    """
    self._validate_connection()
    store = FeatureStore(repo_path=self.feast_repo)
    names: List[str] = []
    for data_source in store.list_data_sources():
        names.append(data_source.name)
    return names
get_entities(self)

Returns the entity names.

Exceptions:

Type Description
ConnectionError

If the online component (Redis) is not available.

Returns:

Type Description
List[str]

The entity names.

Source code in zenml/integrations/feast/feature_stores/feast_feature_store.py
def get_entities(self) -> List[str]:
    """List the names of all entities registered in the Feast repository.

    The online store is probed first, so this fails fast when Redis is down.

    Raises:
        redis.exceptions.ConnectionError: If the online component (Redis) is
            not available (raised by ``self._validate_connection``).

    Returns:
        The entity names.
    """
    self._validate_connection()
    store = FeatureStore(repo_path=self.feast_repo)
    names: List[str] = []
    for entity in store.list_entities():
        names.append(entity.name)
    return names
get_feast_version(self)

Returns the version of Feast used.

Exceptions:

Type Description
ConnectionError

If the online component (Redis) is not available.

Returns:

Type Description
str

The version of Feast currently being used.

Source code in zenml/integrations/feast/feature_stores/feast_feature_store.py
def get_feast_version(self) -> str:
    """Returns the version of Feast used.

    Note:
        Unlike the listing/online getters of this store, this method does
        not call `_validate_connection()`, so no explicit check of the
        online component (Redis) is performed here.

    Returns:
        The version of Feast currently being used, as a string.
    """
    fs = FeatureStore(repo_path=self.feast_repo)
    return str(fs.version())
get_feature_services(self)

Returns the feature service names.

Exceptions:

Type Description
ConnectionError

If the online component (Redis) is not available.

Returns:

Type Description
List[str]

The feature service names.

Source code in zenml/integrations/feast/feature_stores/feast_feature_store.py
def get_feature_services(self) -> List[str]:
    """Returns the feature service names.

    The connection to the online store is validated before Feast is queried.

    Raises:
        ConnectionError: If the online component (Redis) is not available.

    Returns:
        The feature service names.
    """
    self._validate_connection()
    fs = FeatureStore(repo_path=self.feast_repo)
    return [service.name for service in fs.list_feature_services()]
get_feature_views(self)

Returns the feature view names.

Exceptions:

Type Description
ConnectionError

If the online component (Redis) is not available.

Returns:

Type Description
List[str]

The feature view names.

Source code in zenml/integrations/feast/feature_stores/feast_feature_store.py
def get_feature_views(self) -> List[str]:
    """Returns the feature view names.

    The connection to the online store is validated before Feast is queried.

    Raises:
        ConnectionError: If the online component (Redis) is not available.

    Returns:
        The feature view names.
    """
    self._validate_connection()
    fs = FeatureStore(repo_path=self.feast_repo)
    return [view.name for view in fs.list_feature_views()]
get_historical_features(self, entity_df, features, full_feature_names=False)

Returns the historical features for training or batch scoring.

Parameters:

Name Type Description Default
entity_df Union[pandas.core.frame.DataFrame, str]

The entity DataFrame or entity name.

required
features List[str]

The features to retrieve.

required
full_feature_names bool

Whether to return the full feature names.

False

Exceptions:

Type Description
ConnectionError

If the online component (Redis) is not available.

Returns:

Type Description
DataFrame

The historical features as a Pandas DataFrame.

Source code in zenml/integrations/feast/feature_stores/feast_feature_store.py
def get_historical_features(
    self,
    entity_df: Union[pd.DataFrame, str],
    features: List[str],
    full_feature_names: bool = False,
) -> pd.DataFrame:
    """Returns the historical features for training or batch scoring.

    Note:
        This method does not call `_validate_connection()`, so no explicit
        check of the online component (Redis) is performed before the
        historical retrieval job is created.

    Args:
        entity_df: The entity DataFrame or entity name.
        features: The features to retrieve.
        full_feature_names: Whether to return the full feature names.

    Returns:
        The historical features as a Pandas DataFrame.
    """
    fs = FeatureStore(repo_path=self.feast_repo)

    # Feast returns a retrieval job; `.to_df()` materializes it as pandas.
    return fs.get_historical_features(
        entity_df=entity_df,
        features=features,
        full_feature_names=full_feature_names,
    ).to_df()
get_online_features(self, entity_rows, features, full_feature_names=False)

Returns the latest online feature data.

Parameters:

Name Type Description Default
entity_rows List[Dict[str, Any]]

The entity rows to retrieve.

required
features List[str]

The features to retrieve.

required
full_feature_names bool

Whether to return the full feature names.

False

Exceptions:

Type Description
ConnectionError

If the online component (Redis) is not available.

Returns:

Type Description
Dict[str, Any]

The latest online feature data as a dictionary.

Source code in zenml/integrations/feast/feature_stores/feast_feature_store.py
def get_online_features(
    self,
    entity_rows: List[Dict[str, Any]],
    features: List[str],
    full_feature_names: bool = False,
) -> Dict[str, Any]:
    """Returns the latest online feature data.

    The connection to the online store is validated before Feast is queried.

    Args:
        entity_rows: The entity rows to retrieve.
        features: The features to retrieve.
        full_feature_names: Whether to return the full feature names.

    Raises:
        ConnectionError: If the online component (Redis) is not available.

    Returns:
        The latest online feature data as a dictionary.
    """
    self._validate_connection()
    fs = FeatureStore(repo_path=self.feast_repo)

    # Feast returns an online response object; `.to_dict()` converts it.
    return fs.get_online_features(  # type: ignore[no-any-return]
        entity_rows=entity_rows,
        features=features,
        full_feature_names=full_feature_names,
    ).to_dict()
get_project(self)

Returns the project name.

Exceptions:

Type Description
ConnectionError

If the online component (Redis) is not available.

Returns:

Type Description
str

The project name.

Source code in zenml/integrations/feast/feature_stores/feast_feature_store.py
def get_project(self) -> str:
    """Returns the project name.

    Note:
        This method does not call `_validate_connection()`, so no explicit
        check of the online component (Redis) is performed here.

    Returns:
        The project name, as a string.
    """
    fs = FeatureStore(repo_path=self.feast_repo)
    return str(fs.project)
get_registry(self)

Returns the feature store registry.

Exceptions:

Type Description
ConnectionError

If the online component (Redis) is not available.

Returns:

Type Description
Registry

The registry.

Source code in zenml/integrations/feast/feature_stores/feast_feature_store.py
def get_registry(self) -> Registry:
    """Returns the feature store registry.

    Note:
        This method does not call `_validate_connection()`, so no explicit
        check of the online component (Redis) is performed here.

    Returns:
        The registry of the Feast feature store at `self.feast_repo`.
    """
    fs: FeatureStore = FeatureStore(repo_path=self.feast_repo)
    return fs.registry

gcp special

Initialization of the GCP ZenML integration.

The GCP integration submodule provides a way to run ZenML pipelines in a cloud environment. Specifically, it allows the use of cloud artifact stores, metadata stores, and an io module to handle file operations on Google Cloud Storage (GCS).

Additionally, the GCP secrets manager integration submodule provides a way to access the GCP secrets manager from within your ZenML Pipeline runs.

The Vertex AI integration submodule provides a way to run ZenML pipelines in a Vertex AI environment.

GcpIntegration (Integration)

Definition of Google Cloud Platform integration for ZenML.

Source code in zenml/integrations/gcp/__init__.py
class GcpIntegration(Integration):
    """Definition of Google Cloud Platform integration for ZenML."""

    NAME = GCP
    REQUIREMENTS = [
        "kfp==1.8.9",
        "gcsfs",
        "google-cloud-secret-manager",
        "google-cloud-aiplatform>=1.11.0",
    ]

    @classmethod
    def flavors(cls) -> List[FlavorWrapper]:
        """List the stack component flavors contributed by the GCP integration.

        Each flavor is described by its name, the import path of its
        implementation class and the stack component type it provides.

        Returns:
            List of stack component flavors for this integration.
        """
        # (flavor name, implementation source path, stack component type)
        flavor_specs = (
            (
                GCP_ARTIFACT_STORE_FLAVOR,
                "zenml.integrations.gcp.artifact_stores.GCPArtifactStore",
                StackComponentType.ARTIFACT_STORE,
            ),
            (
                GCP_SECRETS_MANAGER_FLAVOR,
                "zenml.integrations.gcp.secrets_manager.GCPSecretsManager",
                StackComponentType.SECRETS_MANAGER,
            ),
            (
                GCP_VERTEX_ORCHESTRATOR_FLAVOR,
                "zenml.integrations.gcp.orchestrators.VertexOrchestrator",
                StackComponentType.ORCHESTRATOR,
            ),
            (
                GCP_VERTEX_STEP_OPERATOR_FLAVOR,
                "zenml.integrations.gcp.step_operators.VertexStepOperator",
                StackComponentType.STEP_OPERATOR,
            ),
        )
        return [
            FlavorWrapper(
                name=flavor_name,
                source=source_path,
                type=component_type,
                integration=cls.NAME,
            )
            for flavor_name, source_path, component_type in flavor_specs
        ]
flavors() classmethod

Declare the stack component flavors for the GCP integration.

Returns:

Type Description
List[zenml.zen_stores.models.flavor_wrapper.FlavorWrapper]

List of stack component flavors for this integration.

Source code in zenml/integrations/gcp/__init__.py
@classmethod
def flavors(cls) -> List[FlavorWrapper]:
    """List the stack component flavors contributed by the GCP integration.

    Each flavor is described by its name, the import path of its
    implementation class and the stack component type it provides.

    Returns:
        List of stack component flavors for this integration.
    """
    # (flavor name, implementation source path, stack component type)
    flavor_specs = (
        (
            GCP_ARTIFACT_STORE_FLAVOR,
            "zenml.integrations.gcp.artifact_stores.GCPArtifactStore",
            StackComponentType.ARTIFACT_STORE,
        ),
        (
            GCP_SECRETS_MANAGER_FLAVOR,
            "zenml.integrations.gcp.secrets_manager.GCPSecretsManager",
            StackComponentType.SECRETS_MANAGER,
        ),
        (
            GCP_VERTEX_ORCHESTRATOR_FLAVOR,
            "zenml.integrations.gcp.orchestrators.VertexOrchestrator",
            StackComponentType.ORCHESTRATOR,
        ),
        (
            GCP_VERTEX_STEP_OPERATOR_FLAVOR,
            "zenml.integrations.gcp.step_operators.VertexStepOperator",
            StackComponentType.STEP_OPERATOR,
        ),
    )
    return [
        FlavorWrapper(
            name=flavor_name,
            source=source_path,
            type=component_type,
            integration=cls.NAME,
        )
        for flavor_name, source_path, component_type in flavor_specs
    ]

artifact_stores special

Initialization of the GCP Artifact Store.

gcp_artifact_store

Implementation of the GCP Artifact Store.

GCPArtifactStore (BaseArtifactStore, AuthenticationMixin) pydantic-model

Artifact Store for Google Cloud Storage based artifacts.

Source code in zenml/integrations/gcp/artifact_stores/gcp_artifact_store.py
class GCPArtifactStore(BaseArtifactStore, AuthenticationMixin):
    """Artifact Store for Google Cloud Storage based artifacts."""

    # Lazily created gcsfs filesystem, populated on first `filesystem` access.
    _filesystem: Optional[gcsfs.GCSFileSystem] = None

    # Class Configuration
    FLAVOR: ClassVar[str] = GCP_ARTIFACT_STORE_FLAVOR
    SUPPORTED_SCHEMES: ClassVar[Set[str]] = {GCP_PATH_PREFIX}

    @property
    def filesystem(self) -> gcsfs.GCSFileSystem:
        """The gcsfs filesystem to access this artifact store.

        The filesystem is created on first access and cached afterwards. If
        `get_authentication_secret` returns a secret, its credential dict is
        passed to gcsfs as the token; otherwise the token is `None`.

        Returns:
            The gcsfs filesystem to access this artifact store.
        """
        # Check the cache explicitly against `None`: the cache is empty only
        # when no filesystem was created yet, regardless of how the
        # filesystem object itself evaluates in a boolean context.
        if self._filesystem is None:
            secret = self.get_authentication_secret(
                expected_schema_type=GCPSecretSchema
            )
            token = secret.get_credential_dict() if secret else None
            self._filesystem = gcsfs.GCSFileSystem(token=token)

        return self._filesystem

    def open(self, path: PathType, mode: str = "r") -> Any:
        """Open a file at the given path.

        Args:
            path: Path of the file to open.
            mode: Mode in which to open the file. Currently, only
                'rb' and 'wb' to read and write binary files are supported.

        Returns:
            A file-like object that can be used to read or write to the file.
        """
        return self.filesystem.open(path=path, mode=mode)

    def copyfile(
        self, src: PathType, dst: PathType, overwrite: bool = False
    ) -> None:
        """Copy a file.

        Args:
            src: The path to copy from.
            dst: The path to copy to.
            overwrite: If a file already exists at the destination, this
                method will overwrite it if overwrite=`True` and
                raise a FileExistsError otherwise.

        Raises:
            FileExistsError: If a file already exists at the destination
                and overwrite is not set to `True`.
        """
        if not overwrite and self.filesystem.exists(dst):
            raise FileExistsError(
                f"Unable to copy to destination '{convert_to_str(dst)}', "
                f"file already exists. Set `overwrite=True` to copy anyway."
            )
        # TODO [ENG-151]: Check if it works with overwrite=True or if we need to
        #  manually remove it first
        self.filesystem.copy(path1=src, path2=dst)

    def exists(self, path: PathType) -> bool:
        """Check whether a path exists.

        Args:
            path: The path to check.

        Returns:
            True if the path exists, False otherwise.
        """
        return self.filesystem.exists(path=path)  # type: ignore[no-any-return]

    def glob(self, pattern: PathType) -> List[PathType]:
        """Return all paths that match the given glob pattern.

        The glob pattern may include:
        - '*' to match any number of characters
        - '?' to match a single character
        - '[...]' to match one of the characters inside the brackets
        - '**' as the full name of a path component to search
          in subdirectories of any depth (e.g. '/some_dir/**/some_file')

        Args:
            pattern: The glob pattern to match, see details above.

        Returns:
            A list of paths that match the given glob pattern, each prefixed
            with the GCS scheme prefix.
        """
        # gcsfs returns matches without the scheme prefix, so re-add it.
        return [
            f"{GCP_PATH_PREFIX}{path}"
            for path in self.filesystem.glob(path=pattern)
        ]

    def isdir(self, path: PathType) -> bool:
        """Check whether a path is a directory.

        Args:
            path: The path to check.

        Returns:
            True if the path is a directory, False otherwise.
        """
        return self.filesystem.isdir(path=path)  # type: ignore[no-any-return]

    def listdir(self, path: PathType) -> List[PathType]:
        """Return a list of files in a directory.

        Args:
            path: The path of the directory to list.

        Returns:
            A list of the names of the entries in the directory, relative to
            `path`.
        """
        path_without_prefix = convert_to_str(path)
        if path_without_prefix.startswith(GCP_PATH_PREFIX):
            path_without_prefix = path_without_prefix[len(GCP_PATH_PREFIX) :]

        def _extract_basename(file_dict: Dict[str, Any]) -> str:
            """Extracts the basename from a file info dict returned by GCP.

            Args:
                file_dict: A file info dict returned by the GCP filesystem.

            Returns:
                The basename of the file.
            """
            # The "name" entry is compared against the prefix-free path, so
            # the directory part is cut off by length and the separator
            # stripped.
            file_path = cast(str, file_dict["name"])
            base_name = file_path[len(path_without_prefix) :]
            return base_name.lstrip("/")

        return [
            _extract_basename(dict_)
            for dict_ in self.filesystem.listdir(path=path)
        ]

    def makedirs(self, path: PathType) -> None:
        """Create a directory at the given path.

        If needed also create missing parent directories. An already
        existing directory is not an error (`exist_ok=True`).

        Args:
            path: The path of the directory to create.
        """
        self.filesystem.makedirs(path=path, exist_ok=True)

    def mkdir(self, path: PathType) -> None:
        """Create a directory at the given path.

        Args:
            path: The path of the directory to create.
        """
        self.filesystem.makedir(path=path)

    def remove(self, path: PathType) -> None:
        """Remove the file at the given path.

        Args:
            path: The path of the file to remove.
        """
        self.filesystem.rm_file(path=path)

    def rename(
        self, src: PathType, dst: PathType, overwrite: bool = False
    ) -> None:
        """Rename source file to destination file.

        Args:
            src: The path of the file to rename.
            dst: The path to rename the source file to.
            overwrite: If a file already exists at the destination, this
                method will overwrite it if overwrite=`True` and
                raise a FileExistsError otherwise.

        Raises:
            FileExistsError: If a file already exists at the destination
                and overwrite is not set to `True`.
        """
        if not overwrite and self.filesystem.exists(dst):
            raise FileExistsError(
                f"Unable to rename file to '{convert_to_str(dst)}', "
                f"file already exists. Set `overwrite=True` to rename anyway."
            )

        # TODO [ENG-152]: Check if it works with overwrite=True or if we need
        #  to manually remove it first
        self.filesystem.rename(path1=src, path2=dst)

    def rmtree(self, path: PathType) -> None:
        """Remove the given directory and everything below it.

        Args:
            path: The path of the directory to remove.
        """
        self.filesystem.delete(path=path, recursive=True)

    def stat(self, path: PathType) -> Dict[str, Any]:
        """Return stat info for the given path.

        Args:
            path: the path to get stat info for.

        Returns:
            A dictionary with the stat info.
        """
        return self.filesystem.stat(path=path)  # type: ignore[no-any-return]

    def walk(
        self,
        top: PathType,
        topdown: bool = True,
        onerror: Optional[Callable[..., None]] = None,
    ) -> Iterable[Tuple[PathType, List[PathType], List[PathType]]]:
        """Return an iterator that walks the contents of the given directory.

        Args:
            top: Path of directory to walk.
            topdown: Unused argument to conform to interface.
            onerror: Unused argument to conform to interface.

        Yields:
            An Iterable of Tuples, each of which contains the current
            directory path (prefixed with the GCS scheme prefix), a list of
            directories inside the current directory and a list of files
            inside the current directory.
        """
        # TODO [ENG-153]: Additional params
        for (
            directory,
            subdirectories,
            files,
        ) in self.filesystem.walk(path=top):
            yield f"{GCP_PATH_PREFIX}{directory}", subdirectories, files
filesystem: GCSFileSystem property readonly

The gcsfs filesystem to access this artifact store.

Returns:

Type Description
GCSFileSystem

The gcsfs filesystem to access this artifact store.

copyfile(self, src, dst, overwrite=False)

Copy a file.

Parameters:

Name Type Description Default
src Union[bytes, str]

The path to copy from.

required
dst Union[bytes, str]

The path to copy to.

required
overwrite bool

If a file already exists at the destination, this method will overwrite it if overwrite=True and raise a FileExistsError otherwise.

False

Exceptions:

Type Description
FileExistsError

If a file already exists at the destination and overwrite is not set to True.

Source code in zenml/integrations/gcp/artifact_stores/gcp_artifact_store.py
def copyfile(
    self, src: PathType, dst: PathType, overwrite: bool = False
) -> None:
    """Copy the file at `src` to `dst`.

    Args:
        src: Location of the file to copy.
        dst: Location the copy is written to.
        overwrite: When `True`, an existing file at `dst` is replaced;
            when `False`, an existing file at `dst` makes the call fail.

    Raises:
        FileExistsError: If `dst` already exists and `overwrite` is not
            `True`.
    """
    # Only probe the destination when overwriting is not allowed.
    refuse_copy = (not overwrite) and self.filesystem.exists(dst)
    if refuse_copy:
        message = (
            f"Unable to copy to destination '{convert_to_str(dst)}', "
            f"file already exists. Set `overwrite=True` to copy anyway."
        )
        raise FileExistsError(message)

    # TODO [ENG-151]: Check if it works with overwrite=True or if we need to
    #  manually remove it first
    gcs = self.filesystem
    gcs.copy(path1=src, path2=dst)
exists(self, path)

Check whether a path exists.

Parameters:

Name Type Description Default
path Union[bytes, str]

The path to check.

required

Returns:

Type Description
bool

True if the path exists, False otherwise.

Source code in zenml/integrations/gcp/artifact_stores/gcp_artifact_store.py
def exists(self, path: PathType) -> bool:
    """Tell whether anything exists at `path` in the bucket.

    Args:
        path: Location to look up.

    Returns:
        `True` when the location exists, `False` when it does not.
    """
    gcs = self.filesystem
    present = gcs.exists(path=path)
    return present  # type: ignore[no-any-return]
glob(self, pattern)

Return all paths that match the given glob pattern.

The glob pattern may include: - '*' to match any number of characters - '?' to match a single character - '[...]' to match one of the characters inside the brackets - '**' as the full name of a path component to search in subdirectories of any depth (e.g. '/some_dir/**/some_file')

Parameters:

Name Type Description Default
pattern Union[bytes, str]

The glob pattern to match, see details above.

required

Returns:

Type Description
List[Union[bytes, str]]

A list of paths that match the given glob pattern.

Source code in zenml/integrations/gcp/artifact_stores/gcp_artifact_store.py
def glob(self, pattern: PathType) -> List[PathType]:
    """Return all paths that match the given glob pattern.

    The glob pattern may include:
    - '*' to match any number of characters
    - '?' to match a single character
    - '[...]' to match one of the characters inside the brackets
    - '**' as the full name of a path component to search
      in subdirectories of any depth (e.g. '/some_dir/**/some_file')

    Args:
        pattern: The glob pattern to match, see details above.

    Returns:
        A list of paths that match the given glob pattern, each prefixed
        with the GCS scheme prefix.
    """
    # gcsfs returns matches without the scheme prefix, so re-add it.
    return [
        f"{GCP_PATH_PREFIX}{path}"
        for path in self.filesystem.glob(path=pattern)
    ]
isdir(self, path)

Check whether a path is a directory.

Parameters:

Name Type Description Default
path Union[bytes, str]

The path to check.

required

Returns:

Type Description
bool

True if the path is a directory, False otherwise.

Source code in zenml/integrations/gcp/artifact_stores/gcp_artifact_store.py
def isdir(self, path: PathType) -> bool:
    """Tell whether `path` refers to a directory in the bucket.

    Args:
        path: Location to inspect.

    Returns:
        `True` for a directory, `False` otherwise.
    """
    gcs = self.filesystem
    is_directory = gcs.isdir(path=path)
    return is_directory  # type: ignore[no-any-return]
listdir(self, path)

Return a list of files in a directory.

Parameters:

Name Type Description Default
path Union[bytes, str]

The path of the directory to list.

required

Returns:

Type Description
List[Union[bytes, str]]

A list of paths of files in the directory.

Source code in zenml/integrations/gcp/artifact_stores/gcp_artifact_store.py
def listdir(self, path: PathType) -> List[PathType]:
    """List the entries contained in a directory.

    Args:
        path: Directory whose contents should be listed.

    Returns:
        The entry names, relative to `path`.
    """
    bare_path = convert_to_str(path)
    if bare_path.startswith(GCP_PATH_PREFIX):
        bare_path = bare_path[len(GCP_PATH_PREFIX) :]

    # Entry names are full object paths without the scheme prefix; chop off
    # the directory part by length and drop the leading separator.
    cut = len(bare_path)
    entries = self.filesystem.listdir(path=path)
    return [cast(str, entry["name"])[cut:].lstrip("/") for entry in entries]
makedirs(self, path)

Create a directory at the given path.

If needed also create missing parent directories.

Parameters:

Name Type Description Default
path Union[bytes, str]

The path of the directory to create.

required
Source code in zenml/integrations/gcp/artifact_stores/gcp_artifact_store.py
def makedirs(self, path: PathType) -> None:
    """Create the directory `path`, including any missing parents.

    An already existing directory is accepted (`exist_ok=True`).

    Args:
        path: Directory to create.
    """
    gcs = self.filesystem
    gcs.makedirs(path=path, exist_ok=True)
mkdir(self, path)

Create a directory at the given path.

Parameters:

Name Type Description Default
path Union[bytes, str]

The path of the directory to create.

required
Source code in zenml/integrations/gcp/artifact_stores/gcp_artifact_store.py
def mkdir(self, path: PathType) -> None:
    """Create a single directory at `path`.

    Args:
        path: Directory to create.
    """
    gcs = self.filesystem
    gcs.makedir(path=path)
open(self, path, mode='r')

Open a file at the given path.

Parameters:

Name Type Description Default
path Union[bytes, str]

Path of the file to open.

required
mode str

Mode in which to open the file. Currently, only 'rb' and 'wb' to read and write binary files are supported.

'r'

Returns:

Type Description
Any

A file-like object that can be used to read or write to the file.

Source code in zenml/integrations/gcp/artifact_stores/gcp_artifact_store.py
def open(self, path: PathType, mode: str = "r") -> Any:
    """Open the file stored at `path`.

    Args:
        path: Location of the file.
        mode: File mode to open with. Currently, only 'rb' and 'wb' to
            read and write binary files are supported.

    Returns:
        A file-like handle for reading from or writing to the file.
    """
    gcs = self.filesystem
    handle = gcs.open(path=path, mode=mode)
    return handle
remove(self, path)

Remove the file at the given path.

Parameters:

Name Type Description Default
path Union[bytes, str]

The path of the file to remove.

required
Source code in zenml/integrations/gcp/artifact_stores/gcp_artifact_store.py
def remove(self, path: PathType) -> None:
    """Delete the single file stored at `path`.

    Args:
        path: Location of the file to delete.
    """
    gcs = self.filesystem
    gcs.rm_file(path=path)
rename(self, src, dst, overwrite=False)

Rename source file to destination file.

Parameters:

Name Type Description Default
src Union[bytes, str]

The path of the file to rename.

required
dst Union[bytes, str]

The path to rename the source file to.

required
overwrite bool

If a file already exists at the destination, this method will overwrite it if overwrite=True and raise a FileExistsError otherwise.

False

Exceptions:

Type Description
FileExistsError

If a file already exists at the destination and overwrite is not set to True.

Source code in zenml/integrations/gcp/artifact_stores/gcp_artifact_store.py
def rename(
    self, src: PathType, dst: PathType, overwrite: bool = False
) -> None:
    """Move the file at `src` to `dst`.

    Args:
        src: Current location of the file.
        dst: New location of the file.
        overwrite: When `True`, an existing file at `dst` is replaced;
            when `False`, an existing file at `dst` makes the call fail.

    Raises:
        FileExistsError: If `dst` already exists and `overwrite` is not
            `True`.
    """
    # Only probe the destination when overwriting is not allowed.
    refuse_rename = (not overwrite) and self.filesystem.exists(dst)
    if refuse_rename:
        message = (
            f"Unable to rename file to '{convert_to_str(dst)}', "
            f"file already exists. Set `overwrite=True` to rename anyway."
        )
        raise FileExistsError(message)

    # TODO [ENG-152]: Check if it works with overwrite=True or if we need
    #  to manually remove it first
    gcs = self.filesystem
    gcs.rename(path1=src, path2=dst)
rmtree(self, path)

Remove the given directory.

Parameters:

Name Type Description Default
path Union[bytes, str]

The path of the directory to remove.

required
Source code in zenml/integrations/gcp/artifact_stores/gcp_artifact_store.py
def rmtree(self, path: PathType) -> None:
    """Delete the directory at `path` together with all of its contents.

    Args:
        path: Directory to delete recursively.
    """
    gcs = self.filesystem
    gcs.delete(path=path, recursive=True)
stat(self, path)

Return stat info for the given path.

Parameters:

Name Type Description Default
path Union[bytes, str]

the path to get stat info for.

required

Returns:

Type Description
Dict[str, Any]

A dictionary with the stat info.

Source code in zenml/integrations/gcp/artifact_stores/gcp_artifact_store.py
def stat(self, path: PathType) -> Dict[str, Any]:
    """Fetch the metadata the filesystem reports for `path`.

    Args:
        path: Location to get the metadata for.

    Returns:
        The stat information as a dictionary.
    """
    gcs = self.filesystem
    info = gcs.stat(path=path)
    return info  # type: ignore[no-any-return]
walk(self, top, topdown=True, onerror=None)

Return an iterator that walks the contents of the given directory.

Parameters:

Name Type Description Default
top Union[bytes, str]

Path of directory to walk.

required
topdown bool

Unused argument to conform to interface.

True
onerror Optional[Callable[..., NoneType]]

Unused argument to conform to interface.

None

Yields:

Type Description
Iterable[Tuple[Union[bytes, str], List[Union[bytes, str]], List[Union[bytes, str]]]]

An Iterable of Tuples, each of which contains the current directory path, a list of directories inside the current directory and a list of files inside the current directory.

Source code in zenml/integrations/gcp/artifact_stores/gcp_artifact_store.py
def walk(
    self,
    top: PathType,
    topdown: bool = True,
    onerror: Optional[Callable[..., None]] = None,
) -> Iterable[Tuple[PathType, List[PathType], List[PathType]]]:
    """Lazily traverse the directory tree rooted at `top`.

    Args:
        top: Root directory of the traversal.
        topdown: Unused argument to conform to interface.
        onerror: Unused argument to conform to interface.

    Yields:
        Tuples of (current directory path prefixed with the GCS scheme
        prefix, directories inside it, files inside it).
    """
    # TODO [ENG-153]: Additional params
    for current_dir, child_dirs, child_files in self.filesystem.walk(path=top):
        prefixed_dir = f"{GCP_PATH_PREFIX}{current_dir}"
        yield prefixed_dir, child_dirs, child_files

constants

Constants for the VertexAI integration.

google_credentials_mixin

Implementation of the Google credentials mixin.

GoogleCredentialsMixin (BaseModel) pydantic-model

Mixin for Google Cloud Platform credentials.

Attributes:

Name Type Description
service_account_path Optional[str]

path to the service account credentials file to be used for authentication. If not provided, the default credentials will be used.

Source code in zenml/integrations/gcp/google_credentials_mixin.py
class GoogleCredentialsMixin(BaseModel):
    """Mixin for Google Cloud Platform credentials.

    Attributes:
        service_account_path: path to the service account credentials file to be
            used for authentication. If not provided, the default credentials
            will be used.
    """

    service_account_path: Optional[str] = None

    def _get_authentication(self) -> Tuple["Credentials", str]:
        """Resolve GCP credentials and their associated project ID.

        Without `service_account_path`, the ambient default credentials are
        used; otherwise the credentials are loaded from that file.

        NOTE(review): depending on the environment, google.auth may not be
        able to determine a project ID. Confirm whether `None` can reach
        callers despite the `str` annotation.

        Returns:
            A `(credentials, project_id)` tuple.
        """
        if not self.service_account_path:
            credentials, project_id = default()
        else:
            credentials, project_id = load_credentials_from_file(
                self.service_account_path
            )
        return credentials, project_id

orchestrators special

Initialization for the VertexAI orchestrator.

vertex_entrypoint_configuration

Implementation of the VertexAI entrypoint configuration.

VertexEntrypointConfiguration (StepEntrypointConfiguration)

Entrypoint configuration for running steps on Vertex AI Pipelines.

Source code in zenml/integrations/gcp/orchestrators/vertex_entrypoint_configuration.py
class VertexEntrypointConfiguration(StepEntrypointConfiguration):
    """Step entrypoint configuration used when running on Vertex AI Pipelines."""

    @classmethod
    def get_custom_entrypoint_options(cls) -> Set[str]:
        """Return the entrypoint options specific to Vertex AI Pipelines.

        The only extra option is `VERTEX_JOB_ID_OPTION`. It carries the job id
        of the Vertex AI Pipeline so that the value can be read back while the
        step is executing (see `get_run_name`).

        Returns:
            A set holding the custom option names.
        """
        options = {VERTEX_JOB_ID_OPTION}
        return options

    @classmethod
    def get_custom_entrypoint_arguments(
        cls, step: "BaseStep", *args: Any, **kwargs: Any
    ) -> List[str]:
        """Build the command-line arguments for `VERTEX_JOB_ID_OPTION`.

        Args:
            step: The step to be executed.
            *args: Additional positional arguments.
            **kwargs: Additional keyword arguments; a value stored under the
                `VERTEX_JOB_ID_OPTION` key is required (a `KeyError` is raised
                otherwise).

        Returns:
            The option flag followed by its value.
        """
        flag = f"--{VERTEX_JOB_ID_OPTION}"
        job_id = kwargs[VERTEX_JOB_ID_OPTION]
        return [flag, job_id]

    def get_run_name(self, pipeline_name: str) -> str:
        """Use the Vertex AI Pipeline job id as the run name.

        Args:
            pipeline_name: The name of the pipeline. It is not used; the run
                name comes from the parsed entrypoint arguments.

        Returns:
            The Vertex AI Pipeline job id.
        """
        run_name: str = self.entrypoint_args[VERTEX_JOB_ID_OPTION]
        return run_name
get_custom_entrypoint_arguments(step, *args, **kwargs) classmethod

Sets the value for the VERTEX_JOB_ID_OPTION argument.

Parameters:

Name Type Description Default
step BaseStep

The step to be executed.

required
*args Any

Additional arguments.

()
**kwargs Any

Additional keyword arguments.

{}

Returns:

Type Description
List[str]

A list of arguments for the entrypoint.

Source code in zenml/integrations/gcp/orchestrators/vertex_entrypoint_configuration.py
@classmethod
def get_custom_entrypoint_arguments(
    cls, step: "BaseStep", *args: Any, **kwargs: Any
) -> List[str]:
    """Build the command-line arguments for `VERTEX_JOB_ID_OPTION`.

    Args:
        step: The step to be executed.
        *args: Additional positional arguments.
        **kwargs: Additional keyword arguments; a value stored under the
            `VERTEX_JOB_ID_OPTION` key is required (a `KeyError` is raised
            otherwise).

    Returns:
        The option flag followed by its value.
    """
    flag = f"--{VERTEX_JOB_ID_OPTION}"
    job_id = kwargs[VERTEX_JOB_ID_OPTION]
    return [flag, job_id]
get_custom_entrypoint_options() classmethod

Vertex AI Pipelines specific entrypoint options.

The VERTEX_JOB_ID_OPTION argument makes it possible to specify the job ID of the Vertex AI Pipeline and to retrieve it during the execution of the step via the get_run_name method.

Returns:

Type Description
Set[str]

The set of custom entrypoint options.

Source code in zenml/integrations/gcp/orchestrators/vertex_entrypoint_configuration.py
@classmethod
def get_custom_entrypoint_options(cls) -> Set[str]:
    """Return the entrypoint options specific to Vertex AI Pipelines.

    The only extra option is `VERTEX_JOB_ID_OPTION`. It carries the job id of
    the Vertex AI Pipeline so that the value can be read back while the step is
    executing (see `get_run_name`).

    Returns:
        A set holding the custom option names.
    """
    options = {VERTEX_JOB_ID_OPTION}
    return options
get_run_name(self, pipeline_name)

Returns the Vertex AI Pipeline job id.

Parameters:

Name Type Description Default
pipeline_name str

The name of the pipeline.

required

Returns:

Type Description
str

The Vertex AI Pipeline job id.

Source code in zenml/integrations/gcp/orchestrators/vertex_entrypoint_configuration.py
def get_run_name(self, pipeline_name: str) -> str:
    """Use the Vertex AI Pipeline job id as the run name.

    Args:
        pipeline_name: The name of the pipeline. It is not used; the run name
            comes from the parsed entrypoint arguments.

    Returns:
        The Vertex AI Pipeline job id.
    """
    run_name: str = self.entrypoint_args[VERTEX_JOB_ID_OPTION]
    return run_name
vertex_orchestrator

Implementation of the VertexAI orchestrator.

VertexOrchestrator (BaseOrchestrator, GoogleCredentialsMixin) pydantic-model

Orchestrator responsible for running pipelines on Vertex AI.

Attributes:

Name Type Description
custom_docker_base_image_name Optional[str]

Name of the Docker image that should be used as the base for the image that will be used to execute each of the steps. If no custom base image is given, a basic image of the active ZenML version will be used. Note: This image needs to have ZenML installed, otherwise the pipeline execution will fail. For that reason, you might want to extend the ZenML Docker images found here: https://hub.docker.com/r/zenmldocker/zenml/

project Optional[str]

GCP project name. If None, the project will be inferred from the environment.

location str

Name of the GCP region where the pipeline job will be executed. Vertex AI Pipelines is available in the following regions: https://cloud.google.com/vertex-ai/docs/general/locations#feature-availability

pipeline_root Optional[str]

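A Cloud Storage URI that will be used by Vertex AI Pipelines. If not provided but the artifact store in the stack used to execute the pipeline is a zenml.integrations.gcp.artifact_stores.GCPArtifactStore, then a subdirectory of the artifact store will be used.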

encryption_spec_key_name Optional[str]

The Cloud KMS resource identifier of the customer managed encryption key used to protect the job. Has the form projects/<PRJCT>/locations/<REGION>/keyRings/<KR>/cryptoKeys/<KEY>. The key needs to be in the same region as where the compute resource is created.

workload_service_account Optional[str]

the service account for workload run-as account. Users submitting jobs must have act-as permission on this run-as account. If not provided, the default service account will be used.

network Optional[str]

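The full name of the Compute Engine Network to which the job should be peered. For example, projects/12345/global/networks/myVPC. If not provided, the job will not be peered with any network.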

synchronous bool

If True, running a pipeline using this orchestrator will block until all steps have finished running on the Vertex AI Pipelines service.

cpu_limit Optional[str]

The maximum CPU limit for this operator. This string value can be a number (integer value for number of CPUs) as string, or a number followed by "m", which means 1/1000. You can specify at most 96 CPUs. (see https://cloud.google.com/vertex-ai/docs/pipelines/machine-types)

memory_limit Optional[str]

The maximum memory limit for this operator. This string value can be a number, or a number followed by "K" (kilobyte), "M" (megabyte), or "G" (gigabyte). At most 624GB is supported.

node_selector_constraint Optional[Tuple[str, str]]

Each constraint is a key-value pair label. For the container to be eligible to run on a node, the node must have each of the constraints appeared as labels. For example a GPU type can be provided by one of the following tuples: - ("cloud.google.com/gke-accelerator", "NVIDIA_TESLA_A100") - ("cloud.google.com/gke-accelerator", "NVIDIA_TESLA_K80") - ("cloud.google.com/gke-accelerator", "NVIDIA_TESLA_P4") - ("cloud.google.com/gke-accelerator", "NVIDIA_TESLA_P100") - ("cloud.google.com/gke-accelerator", "NVIDIA_TESLA_T4") - ("cloud.google.com/gke-accelerator", "NVIDIA_TESLA_V100") Hint: the selected region (location) must provide the requested accelerator (see https://cloud.google.com/compute/docs/gpus/gpu-regions-zones).

gpu_limit Optional[int]

The GPU limit (positive number) for the operator. For more information about GPU resources, see: https://cloud.google.com/vertex-ai/docs/training/configure-compute#specifying_gpus

Source code in zenml/integrations/gcp/orchestrators/vertex_orchestrator.py
class VertexOrchestrator(BaseOrchestrator, GoogleCredentialsMixin):
    """Orchestrator responsible for running pipelines on Vertex AI.

    Attributes:
        custom_docker_base_image_name: Name of the Docker image that should be
            used as the base for the image that will be used to execute each of
            the steps. If no custom base image is given, a basic image of the
            active ZenML version will be used. **Note**: This image needs to
            have ZenML installed, otherwise the pipeline execution will fail.
            For that reason, you might want to extend the ZenML Docker images
            found here: https://hub.docker.com/r/zenmldocker/zenml/
        project: GCP project name. If `None`, the project associated with the
            credentials used for authentication is used (the service account
            file if one is configured, otherwise the default credentials).
        location: Name of GCP region where the pipeline job will be executed.
            Vertex AI Pipelines is available in the following regions:
            https://cloud.google.com/vertex-ai/docs/general/locations#feature-availability
        pipeline_root: A Cloud Storage URI that will be used by Vertex AI
            Pipelines. If not provided but the artifact store in the stack used
            to execute the pipeline is a
            `zenml.integrations.gcp.artifact_stores.GCPArtifactStore`, then a
            subdirectory of the artifact store will be used.
        labels: Labels attached to the Vertex AI Pipelines job (passed to
            `aiplatform.PipelineJob`).
        encryption_spec_key_name: The Cloud KMS resource identifier of the
            customer managed encryption key used to protect the job. Has the
            form:
            `projects/<PRJCT>/locations/<REGION>/keyRings/<KR>/cryptoKeys/<KEY>`.
            The key needs to be in the same region as where the compute
            resource is created.
        workload_service_account: The service account for the workload run-as
            account. Users submitting jobs must have act-as permission on this
            run-as account. If not provided, the default service account will
            be used.
        network: The full name of the Compute Engine Network to which the job
            should be peered. For example,
            `projects/12345/global/networks/myVPC`. If not provided, the job
            will not be peered with any network.
        synchronous: If `True`, running a pipeline using this orchestrator will
            block until all steps have finished running on the Vertex AI
            Pipelines service.
        cpu_limit: The maximum CPU limit for this operator. This string value
            can be a number (integer value for number of CPUs) as string,
            or a number followed by "m", which means 1/1000. You can specify
            at most 96 CPUs.
            (see https://cloud.google.com/vertex-ai/docs/pipelines/machine-types)
        memory_limit: The maximum memory limit for this operator. This string
            value can be a number, or a number followed by "K" (kilobyte),
            "M" (megabyte), or "G" (gigabyte). At most 624GB is supported.
        node_selector_constraint: Each constraint is a key-value pair label.
            For the container to be eligible to run on a node, the node must
            have each of the constraints appeared as labels.
            For example a GPU type can be provided by one of the following
            tuples:
                - ("cloud.google.com/gke-accelerator", "NVIDIA_TESLA_A100")
                - ("cloud.google.com/gke-accelerator", "NVIDIA_TESLA_K80")
                - ("cloud.google.com/gke-accelerator", "NVIDIA_TESLA_P4")
                - ("cloud.google.com/gke-accelerator", "NVIDIA_TESLA_P100")
                - ("cloud.google.com/gke-accelerator", "NVIDIA_TESLA_T4")
                - ("cloud.google.com/gke-accelerator", "NVIDIA_TESLA_V100")
            Hint: the selected region (location) must provide the requested
            accelerator
            (see https://cloud.google.com/compute/docs/gpus/gpu-regions-zones).
        gpu_limit: The GPU limit (positive number) for the operator.
            For more information about GPU resources, see:
            https://cloud.google.com/vertex-ai/docs/training/configure-compute#specifying_gpus
    """

    custom_docker_base_image_name: Optional[str] = None
    project: Optional[str] = None
    location: str
    pipeline_root: Optional[str] = None
    labels: Dict[str, str] = {}
    encryption_spec_key_name: Optional[str] = None
    workload_service_account: Optional[str] = None
    network: Optional[str] = None
    synchronous: bool = False

    cpu_limit: Optional[str] = None
    memory_limit: Optional[str] = None
    node_selector_constraint: Optional[Tuple[str, str]] = None
    gpu_limit: Optional[int] = None

    # Effective pipeline root for the current run; resolved in
    # `prepare_or_run_pipeline` and read in `_upload_and_run_pipeline`.
    _pipeline_root: str

    FLAVOR: ClassVar[str] = GCP_VERTEX_ORCHESTRATOR_FLAVOR

    @property
    def validator(self) -> Optional[StackValidator]:
        """Validates that the stack contains a container registry.

        Also validates that no stack component is local and that a
        `pipeline_root` can be determined (either configured explicitly or
        derived from a GCP artifact store).

        Returns:
            A StackValidator instance.
        """

        def _validate_stack_requirements(stack: "Stack") -> Tuple[bool, str]:
            """Validates that all the stack components are not local.

            Args:
                stack: The stack to validate.

            Returns:
                A tuple of (is_valid, error_message).
            """
            # Validate that the container registry is not local.
            container_registry = stack.container_registry
            if container_registry and container_registry.is_local:
                return False, (
                    f"The Vertex orchestrator does not support local "
                    f"container registries. You should replace the component '"
                    f"{container_registry.name}' "
                    f"{container_registry.TYPE.value} to a remote one."
                )

            # Validate that the rest of the components are not local; the
            # first component exposing a local path fails validation.
            for stack_comp in stack.components.values():
                local_path = stack_comp.local_path
                if not local_path:
                    continue
                return False, (
                    f"The '{stack_comp.name}' {stack_comp.TYPE.value} is a "
                    f"local stack component. The Vertex AI Pipelines "
                    f"orchestrator requires that all the components in the "
                    f"stack used to execute the pipeline have to be not local, "
                    f"because there is no way for Vertex to connect to your "
                    f"local machine. You should use a flavor of "
                    f"{stack_comp.TYPE.value} other than '"
                    f"{stack_comp.FLAVOR}'."
                )

            # If the `pipeline_root` has not been defined in the orchestrator
            # configuration, and the artifact store is not a GCP artifact store,
            # then raise an error.
            if (
                not self.pipeline_root
                and stack.artifact_store.FLAVOR != GCP_ARTIFACT_STORE_FLAVOR
            ):
                return False, (
                    f"The attribute `pipeline_root` has not been set and it "
                    f"cannot be generated using the path of the artifact store "
                    f"because it is not a "
                    f"`zenml.integrations.gcp.artifact_store.GCPArtifactStore`."
                    f" To solve this issue, set the `pipeline_root` attribute "
                    f"manually executing the following command: "
                    f"`zenml orchestrator update {stack.orchestrator.name} "
                    f'--pipeline_root="<Cloud Storage URI>"`.'
                )

            return True, ""

        return StackValidator(
            required_components={StackComponentType.CONTAINER_REGISTRY},
            custom_validation_function=_validate_stack_requirements,
        )

    def get_docker_image_name(self, pipeline_name: str) -> str:
        """Returns the full docker image name including registry and tag.

        Args:
            pipeline_name: The name of the pipeline.

        Returns:
            The full docker image name including registry and tag.
        """
        base_image_name = f"zenml-vertex:{pipeline_name}"
        container_registry = Repository().active_stack.container_registry

        if container_registry:
            registry_uri = container_registry.uri.rstrip("/")
            return f"{registry_uri}/{base_image_name}"

        return base_image_name

    @property
    def root_directory(self) -> str:
        """Returns path to the root directory for files for this orchestrator.

        Returns:
            The path to the root directory for all files concerning this
            orchestrator.
        """
        return os.path.join(
            get_global_config_directory(), "vertex", str(self.uuid)
        )

    @property
    def pipeline_directory(self) -> str:
        """Returns path to directory where kubeflow pipelines files are stored.

        Returns:
            Path to the pipeline directory.
        """
        return os.path.join(self.root_directory, "pipelines")

    def prepare_pipeline_deployment(
        self,
        pipeline: "BasePipeline",
        stack: "Stack",
        runtime_configuration: "RuntimeConfiguration",
    ) -> None:
        """Build a Docker image for the current environment.

        This uploads it to a container registry if configured.

        Args:
            pipeline: The pipeline to be deployed.
            stack: The stack that will be used to deploy the pipeline.
            runtime_configuration: The runtime configuration for the pipeline.

        Raises:
            RuntimeError: If the container registry is missing.
        """
        from zenml.utils import docker_utils

        repo = Repository()
        container_registry = repo.active_stack.container_registry

        if not container_registry:
            raise RuntimeError("Missing container registry")

        image_name = self.get_docker_image_name(pipeline.name)

        requirements = {*stack.requirements(), *pipeline.requirements}

        logger.debug(
            "Vertex AI Pipelines service docker container requirements %s",
            requirements,
        )

        docker_utils.build_docker_image(
            build_context_path=get_source_root_path(),
            image_name=image_name,
            dockerignore_path=pipeline.dockerignore_file,
            requirements=requirements,
            base_image=self.custom_docker_base_image_name,
        )
        container_registry.push_image(image_name)

    def _configure_container_resources(
        self,
        container_op: dsl.ContainerOp,
        resource_configuration: "ResourceConfiguration",
    ) -> None:
        """Adds resource requirements to the container.

        Values from the step's resource configuration take precedence over the
        orchestrator-level `cpu_limit`, `memory_limit` and `gpu_limit`.

        Args:
            container_op: The kubeflow container operation to configure.
            resource_configuration: The resource configuration to use for this
                container.
        """
        # Set optional CPU, RAM and GPU constraints for the pipeline
        cpu_limit = resource_configuration.cpu_count or self.cpu_limit
        if cpu_limit is not None:
            container_op = container_op.set_cpu_limit(str(cpu_limit))

        # The last character of the step's memory string is dropped before it
        # is handed to Kubeflow. NOTE(review): presumably this turns e.g.
        # "1GB" into "1G" -- confirm against `ResourceConfiguration.memory`.
        memory_limit = (
            resource_configuration.memory[:-1]
            if resource_configuration.memory
            else self.memory_limit
        )
        if memory_limit is not None:
            container_op = container_op.set_memory_limit(memory_limit)

        if self.node_selector_constraint is not None:
            container_op = container_op.add_node_selector_constraint(
                label_name=self.node_selector_constraint[0],
                value=self.node_selector_constraint[1],
            )

        gpu_limit = resource_configuration.gpu_count or self.gpu_limit
        if gpu_limit is not None:
            container_op = container_op.set_gpu_limit(gpu_limit)

    def prepare_or_run_pipeline(
        self,
        sorted_steps: List["BaseStep"],
        pipeline: "BasePipeline",
        pb2_pipeline: "Pb2Pipeline",
        stack: "Stack",
        runtime_configuration: "RuntimeConfiguration",
    ) -> Any:
        """Creates a KFP JSON pipeline.

        This is an intermediary representation of the pipeline which is then
        deployed to Vertex AI Pipelines service.

        How it works:
        -------------
        Before this method is called the `prepare_pipeline_deployment()` method
        builds a Docker image that contains the code for the pipeline, all
        steps, and the context around these files.

        Based on this Docker image a callable is created which builds
        container_ops for each step (`_construct_kfp_pipeline`). The function
        `kfp.components.load_component_from_text` is used to create the
        `ContainerOp`, because using the `dsl.ContainerOp` class directly is
        deprecated when using the Kubeflow SDK v2. The step entrypoint command
        with the entrypoint arguments is the command that will be executed by
        the container created using the previously created Docker image.

        This callable is then compiled into a JSON file that is used as the
        intermediary representation of the Kubeflow pipeline.

        This file then is submitted to the Vertex AI Pipelines service for
        execution.

        If `pipeline_root` is not configured, it is derived from the artifact
        store path; the stack validator of this orchestrator only accepts that
        case when the artifact store is a GCP artifact store.

        Args:
            sorted_steps: List of sorted steps.
            pipeline: Zenml Pipeline instance.
            pb2_pipeline: Protobuf Pipeline instance.
            stack: The stack the pipeline was run on.
            runtime_configuration: The Runtime configuration of the current run.
        """
        import json

        # The run name is used both for the generated pipeline root and for
        # the compiled pipeline file name, so it must be set before either.
        assert runtime_configuration.run_name

        # If the `pipeline_root` has not been defined in the orchestrator
        # configuration, try to create it from the artifact store if it is a
        # `GCPArtifactStore`.
        if not self.pipeline_root:
            artifact_store = stack.artifact_store
            self._pipeline_root = (
                f"{artifact_store.path.rstrip('/')}/vertex_pipeline_root/"
                f"{pipeline.name}/{runtime_configuration.run_name}"
            )
            logger.info(
                "The attribute `pipeline_root` has not been set in the "
                "orchestrator configuration. One has been generated "
                "automatically based on the path of the `GCPArtifactStore` "
                "artifact store in the stack used to execute the pipeline. "
                "The generated `pipeline_root` is `%s`.",
                self._pipeline_root,
            )
        else:
            self._pipeline_root = self.pipeline_root

        # Resolve the image built in `prepare_pipeline_deployment`, preferring
        # its digest when one is available.
        image_name = self.get_docker_image_name(pipeline.name)
        image_name = get_image_digest(image_name) or image_name

        def _construct_kfp_pipeline() -> None:
            """Create a `ContainerOp` for each step.

            This should contain the name of the Docker image and configures the
            entrypoint of the Docker image to run the step.

            Additionally, this gives each `ContainerOp` information about its
            direct downstream steps.

            If this callable is passed to the `compile()` method of
            `KFPV2Compiler` all `dsl.ContainerOp` instances will be
            automatically added to a singular `dsl.Pipeline` instance.
            """
            step_name_to_container_op: Dict[str, dsl.ContainerOp] = {}

            for step in sorted_steps:
                # The command will be needed to eventually call the python step
                # within the docker container
                command = VertexEntrypointConfiguration.get_entrypoint_command()

                # The arguments are passed to configure the entrypoint of the
                # docker container when the step is called.
                arguments = VertexEntrypointConfiguration.get_entrypoint_arguments(
                    step=step,
                    pb2_pipeline=pb2_pipeline,
                    **{VERTEX_JOB_ID_OPTION: dslv2.PIPELINE_JOB_ID_PLACEHOLDER},
                )

                # Serialize the command as a JSON array: JSON is valid YAML
                # flow syntax, whereas Python's `repr` emits single-quoted
                # strings whose backslash / `\'` escapes YAML does not
                # understand (which corrupts or breaks serialized arguments).
                component_command = json.dumps(list(command) + list(arguments))

                # Create the `ContainerOp` for the step. Using the
                # `dsl.ContainerOp` class directly is deprecated when using the
                # Kubeflow SDK v2.
                container_op = kfp.components.load_component_from_text(
                    f"""
                    name: {step.name}
                    implementation:
                        container:
                            image: {image_name}
                            command: {component_command}"""
                )()

                # Set upstream tasks as a dependency of the current step.
                # `sorted_steps` is ordered so upstream ops already exist.
                upstream_step_names = self.get_upstream_step_names(
                    step=step, pb2_pipeline=pb2_pipeline
                )
                for upstream_step_name in upstream_step_names:
                    upstream_container_op = step_name_to_container_op[
                        upstream_step_name
                    ]
                    container_op.after(upstream_container_op)

                self._configure_container_resources(
                    container_op=container_op,
                    resource_configuration=step.resource_configuration,
                )

                step_name_to_container_op[step.name] = container_op

        # Save the generated pipeline to a file.
        fileio.makedirs(self.pipeline_directory)
        pipeline_file_path = os.path.join(
            self.pipeline_directory,
            f"{runtime_configuration.run_name}.json",
        )

        # Compile the pipeline using the Kubeflow SDK V2 compiler that allows
        # to generate a JSON representation of the pipeline that can be later
        # upload to Vertex AI Pipelines service.
        logger.debug(
            "Compiling pipeline using Kubeflow SDK V2 compiler and saving it "
            "to `%s`",
            pipeline_file_path,
        )
        KFPV2Compiler().compile(
            pipeline_func=_construct_kfp_pipeline,
            package_path=pipeline_file_path,
            pipeline_name=_clean_pipeline_name(pipeline.name),
        )

        # Using the Google Cloud AIPlatform client, upload and execute the
        # pipeline on the Vertex AI Pipelines service.
        self._upload_and_run_pipeline(
            pipeline_name=pipeline.name,
            pipeline_file_path=pipeline_file_path,
            runtime_configuration=runtime_configuration,
            enable_cache=pipeline.enable_cache,
        )

    def _upload_and_run_pipeline(
        self,
        pipeline_name: str,
        pipeline_file_path: str,
        runtime_configuration: "RuntimeConfiguration",
        enable_cache: bool,
    ) -> None:
        """Uploads and run the pipeline on the Vertex AI Pipelines service.

        Client errors while creating the job and runtime errors while it runs
        are logged rather than raised.

        Args:
            pipeline_name: Name of the pipeline.
            pipeline_file_path: Path of the JSON file containing the compiled
                Kubeflow pipeline (compiled with Kubeflow SDK v2).
            runtime_configuration: Runtime configuration of the pipeline run.
            enable_cache: Whether caching is enabled for this pipeline run.
        """
        # The run name is normalized with `_clean_pipeline_name` because the
        # Vertex AI Pipelines service restricts the format of job IDs.
        assert runtime_configuration.run_name
        job_id = _clean_pipeline_name(runtime_configuration.run_name)

        # Warn the user that the scheduling is not available using the Vertex
        # Orchestrator
        if runtime_configuration.schedule:
            logger.warning(
                "Pipeline scheduling configuration was provided, but Vertex "
                "AI Pipelines "
                "do not have capabilities for scheduling yet."
            )

        # Get the credentials that would be used to create the Vertex AI
        # Pipelines job.
        credentials, project_id = self._get_authentication()
        if self.project and self.project != project_id:
            logger.warning(
                "Authenticated with project `%s`, but this orchestrator is "
                "configured to use the project `%s`.",
                project_id,
                self.project,
            )

        # If the project was set in the configuration, use it. Otherwise, use
        # the project that was used to authenticate.
        project_id = self.project if self.project else project_id

        # Instantiate the Vertex AI Pipelines job
        run = aiplatform.PipelineJob(
            display_name=pipeline_name,
            template_path=pipeline_file_path,
            job_id=job_id,
            pipeline_root=self._pipeline_root,
            parameter_values=None,
            enable_caching=enable_cache,
            encryption_spec_key_name=self.encryption_spec_key_name,
            labels=self.labels,
            credentials=credentials,
            project=project_id,
            location=self.location,
        )

        logger.info(
            "Submitting pipeline job with job_id `%s` to Vertex AI Pipelines "
            "service.",
            job_id,
        )

        # Submit the job to Vertex AI Pipelines service.
        try:
            if self.workload_service_account:
                logger.info(
                    "The Vertex AI Pipelines job workload will be executed "
                    "using `%s` "
                    "service account.",
                    self.workload_service_account,
                )

            if self.network:
                logger.info(
                    "The Vertex AI Pipelines job will be peered with `%s` "
                    "network.",
                    self.network,
                )

            run.submit(
                service_account=self.workload_service_account,
                network=self.network,
            )
            logger.info(
                "View the Vertex AI Pipelines job at %s", run._dashboard_uri()
            )

            if self.synchronous:
                logger.info(
                    "Waiting for the Vertex AI Pipelines job to finish..."
                )
                run.wait()

        except google_exceptions.ClientError as e:
            logger.warning(
                "Failed to create the Vertex AI Pipelines job: %s", e
            )

        except RuntimeError as e:
            logger.error(
                "The Vertex AI Pipelines job execution has failed: %s", e
            )
pipeline_directory: str property readonly

Returns path to directory where kubeflow pipelines files are stored.

Returns:

Type Description
str

Path to the pipeline directory.

root_directory: str property readonly

Returns path to the root directory for files for this orchestrator.

Returns:

Type Description
str

The path to the root directory for all files concerning this orchestrator.

validator: Optional[zenml.stack.stack_validator.StackValidator] property readonly

Validates that the stack contains a container registry.

Also validates that the artifact store and metadata store used are not local.

Returns:

Type Description
Optional[zenml.stack.stack_validator.StackValidator]

A StackValidator instance.

get_docker_image_name(self, pipeline_name)

Returns the full docker image name including registry and tag.

Parameters:

Name Type Description Default
pipeline_name str

The name of the pipeline.

required

Returns:

Type Description
str

The full docker image name including registry and tag.

Source code in zenml/integrations/gcp/orchestrators/vertex_orchestrator.py
def get_docker_image_name(self, pipeline_name: str) -> str:
    """Build the fully qualified Docker image name for a pipeline.

    The image is always named `zenml-vertex` and tagged with the pipeline
    name. When the active stack has a container registry, the registry URI
    (with any trailing slash removed) is prepended.

    Args:
        pipeline_name: The name of the pipeline.

    Returns:
        The full docker image name including registry and tag.
    """
    image = f"zenml-vertex:{pipeline_name}"
    registry = Repository().active_stack.container_registry

    # Without a registry the bare `name:tag` is returned unchanged.
    if not registry:
        return image

    return "{}/{}".format(registry.uri.rstrip("/"), image)
prepare_or_run_pipeline(self, sorted_steps, pipeline, pb2_pipeline, stack, runtime_configuration)

Creates a KFP JSON pipeline.

noqa: DAR402

This is an intermediary representation of the pipeline which is then deployed to the Vertex AI Pipelines service.

How it works:

Before this method is called, the prepare_pipeline_deployment() method builds a Docker image that contains the code for the pipeline, all steps, and the context around these files.

Based on this Docker image a callable is created which builds container_ops for each step (_construct_kfp_pipeline). The function kfp.components.load_component_from_text is used to create the ContainerOp, because using the dsl.ContainerOp class directly is deprecated when using the Kubeflow SDK v2. The step entrypoint command with the entrypoint arguments is the command that will be executed by the container created using the previously created Docker image.

This callable is then compiled into a JSON file that is used as the intermediary representation of the Kubeflow pipeline.

This file is then submitted to the Vertex AI Pipelines service for execution.

Parameters:

Name Type Description Default
sorted_steps List[BaseStep]

List of sorted steps.

required
pipeline BasePipeline

Zenml Pipeline instance.

required
pb2_pipeline Pb2Pipeline

Protobuf Pipeline instance.

required
stack Stack

The stack the pipeline was run on.

required
runtime_configuration RuntimeConfiguration

The Runtime configuration of the current run.

required

Exceptions:

Type Description
ValueError

If the attribute pipeline_root is not set and it cannot be generated from the path of the artifact store in the stack because that artifact store is not a zenml.integrations.gcp.artifact_store.GCPArtifactStore.

Source code in zenml/integrations/gcp/orchestrators/vertex_orchestrator.py
def prepare_or_run_pipeline(
    self,
    sorted_steps: List["BaseStep"],
    pipeline: "BasePipeline",
    pb2_pipeline: "Pb2Pipeline",
    stack: "Stack",
    runtime_configuration: "RuntimeConfiguration",
) -> Any:
    """Creates a KFP JSON pipeline.

    # noqa: DAR402

    This is an intermediary representation of the pipeline which is then
    deployed to the Vertex AI Pipelines service.

    How it works:
    -------------
    Before this method is called the `prepare_pipeline_deployment()` method
    builds a Docker image that contains the code for the pipeline, all steps
    and the context around these files.

    Based on this Docker image a callable is created which builds
    container_ops for each step (`_construct_kfp_pipeline`). The function
    `kfp.components.load_component_from_text` is used to create the
    `ContainerOp`, because using the `dsl.ContainerOp` class directly is
    deprecated when using the Kubeflow SDK v2. The step entrypoint command
    with the entrypoint arguments is the command that will be executed by
    the container created using the previously created Docker image.

    This callable is then compiled into a JSON file that is used as the
    intermediary representation of the Kubeflow pipeline.

    This file is then submitted to the Vertex AI Pipelines service for
    execution.

    Args:
        sorted_steps: List of sorted steps.
        pipeline: Zenml Pipeline instance.
        pb2_pipeline: Protobuf Pipeline instance.
        stack: The stack the pipeline was run on.
        runtime_configuration: The Runtime configuration of the current run.

    Raises:
        ValueError: If the attribute `pipeline_root` is not set and it
            cannot be generated using the path of the artifact store in the
            stack because it is not a
            `zenml.integrations.gcp.artifact_store.GCPArtifactStore`.
    """
    # The run name is part of both the auto-generated `pipeline_root` and
    # the compiled pipeline file name, so check that it is set before either
    # of them is derived from it.
    assert runtime_configuration.run_name

    # If the `pipeline_root` has not been defined in the orchestrator
    # configuration,
    # try to create it from the artifact store if it is a
    # `GCPArtifactStore`.
    # NOTE(review): no artifact store type check happens here even though a
    # `ValueError` is documented (hence `# noqa: DAR402`); presumably the
    # stack validator enforces it — confirm.
    if not self.pipeline_root:
        artifact_store = stack.artifact_store
        self._pipeline_root = f"{artifact_store.path.rstrip('/')}/vertex_pipeline_root/{pipeline.name}/{runtime_configuration.run_name}"
        logger.info(
            "The attribute `pipeline_root` has not been set in the "
            "orchestrator configuration. One has been generated "
            "automatically based on the path of the `GCPArtifactStore` "
            "artifact store in the stack used to execute the pipeline. "
            "The generated `pipeline_root` is `%s`.",
            self._pipeline_root,
        )
    else:
        self._pipeline_root = self.pipeline_root

    # Resolve the name of the Docker image that will be used to run the
    # steps of the pipeline, pinned to its digest when one is available.
    image_name = self.get_docker_image_name(pipeline.name)
    image_name = get_image_digest(image_name) or image_name

    def _construct_kfp_pipeline() -> None:
        """Create a `ContainerOp` for each step.

        This should contain the name of the Docker image and configures the
        entrypoint of the Docker image to run the step.

        Additionally, this gives each `ContainerOp` information about its
        direct downstream steps.

        If this callable is passed to the `compile()` method of
        `KFPV2Compiler` all `dsl.ContainerOp` instances will be
        automatically added to a singular `dsl.Pipeline` instance.
        """
        step_name_to_container_op: Dict[str, dsl.ContainerOp] = {}

        for step in sorted_steps:
            # The command will be needed to eventually call the python step
            # within the docker container
            command = VertexEntrypointConfiguration.get_entrypoint_command()

            # The arguments are passed to configure the entrypoint of the
            # docker container when the step is called.
            arguments = VertexEntrypointConfiguration.get_entrypoint_arguments(
                step=step,
                pb2_pipeline=pb2_pipeline,
                **{VERTEX_JOB_ID_OPTION: dslv2.PIPELINE_JOB_ID_PLACEHOLDER},
            )

            # Create the `ContainerOp` for the step. Using the
            # `dsl.ContainerOp`
            # class directly is deprecated when using the Kubeflow SDK v2.
            container_op = kfp.components.load_component_from_text(
                f"""
                name: {step.name}
                implementation:
                    container:
                        image: {image_name}
                        command: {command + arguments}"""
            )()

            # Set upstream tasks as a dependency of the current step.
            # `sorted_steps` is ordered, so upstream ops already exist.
            upstream_step_names = self.get_upstream_step_names(
                step=step, pb2_pipeline=pb2_pipeline
            )
            for upstream_step_name in upstream_step_names:
                upstream_container_op = step_name_to_container_op[
                    upstream_step_name
                ]
                container_op.after(upstream_container_op)

            self._configure_container_resources(
                container_op=container_op,
                resource_configuration=step.resource_configuration,
            )

            step_name_to_container_op[step.name] = container_op

    # Save the generated pipeline to a file.
    fileio.makedirs(self.pipeline_directory)
    pipeline_file_path = os.path.join(
        self.pipeline_directory,
        f"{runtime_configuration.run_name}.json",
    )

    # Compile the pipeline using the Kubeflow SDK V2 compiler that allows
    # to generate a JSON representation of the pipeline that can later be
    # uploaded to the Vertex AI Pipelines service.
    logger.debug(
        "Compiling pipeline using Kubeflow SDK V2 compiler and saving it "
        "to `%s`",
        pipeline_file_path,
    )
    KFPV2Compiler().compile(
        pipeline_func=_construct_kfp_pipeline,
        package_path=pipeline_file_path,
        pipeline_name=_clean_pipeline_name(pipeline.name),
    )

    # Using the Google Cloud AIPlatform client, upload and execute the
    # pipeline
    # on the Vertex AI Pipelines service.
    self._upload_and_run_pipeline(
        pipeline_name=pipeline.name,
        pipeline_file_path=pipeline_file_path,
        runtime_configuration=runtime_configuration,
        enable_cache=pipeline.enable_cache,
    )
prepare_pipeline_deployment(self, pipeline, stack, runtime_configuration)

Build a Docker image for the current environment.

This uploads it to a container registry if configured.

Parameters:

Name Type Description Default
pipeline BasePipeline

The pipeline to be deployed.

required
stack Stack

The stack that will be used to deploy the pipeline.

required
runtime_configuration RuntimeConfiguration

The runtime configuration for the pipeline.

required

Exceptions:

Type Description
RuntimeError

If the container registry is missing.

Source code in zenml/integrations/gcp/orchestrators/vertex_orchestrator.py
def prepare_pipeline_deployment(
    self,
    pipeline: "BasePipeline",
    stack: "Stack",
    runtime_configuration: "RuntimeConfiguration",
) -> None:
    """Build the pipeline's Docker image and push it to the registry.

    The image bundles the source root together with the requirements of the
    stack and of the pipeline, and is then pushed to the container registry
    of the active stack.

    Args:
        pipeline: The pipeline to be deployed.
        stack: The stack that will be used to deploy the pipeline.
        runtime_configuration: The runtime configuration for the pipeline.

    Raises:
        RuntimeError: If the container registry is missing.
    """
    from zenml.utils import docker_utils

    registry = Repository().active_stack.container_registry
    if not registry:
        raise RuntimeError("Missing container registry")

    image_name = self.get_docker_image_name(pipeline.name)

    # Union of stack-level and pipeline-level requirements (deduplicated).
    requirements = {*stack.requirements(), *pipeline.requirements}
    logger.debug(
        "Vertex AI Pipelines service docker container requirements %s",
        requirements,
    )

    build_kwargs = dict(
        build_context_path=get_source_root_path(),
        image_name=image_name,
        dockerignore_path=pipeline.dockerignore_file,
        requirements=requirements,
        base_image=self.custom_docker_base_image_name,
    )
    docker_utils.build_docker_image(**build_kwargs)
    registry.push_image(image_name)

secrets_manager special

ZenML integration for GCP Secrets Manager.

The GCP Secrets Manager allows your pipeline to directly access the GCP secrets manager and use the secrets it stores at runtime.

gcp_secrets_manager

Implementation of the GCP Secrets Manager.

GCPSecretsManager (BaseSecretsManager) pydantic-model

Class to interact with the GCP secrets manager.

Attributes:

Name Type Description
project_id str

The ID of the GCP project that contains the Secret Manager. This is necessary to access the correct GCP project.

Source code in zenml/integrations/gcp/secrets_manager/gcp_secrets_manager.py
class GCPSecretsManager(BaseSecretsManager):
    """Class to interact with the GCP secrets manager.

    Attributes:
        project_id: The ID of the GCP project that contains the Secret
            Manager. This is necessary to access the correct GCP project.
    """

    project_id: str

    # Class configuration
    FLAVOR: ClassVar[str] = GCP_SECRETS_MANAGER_FLAVOR
    SUPPORTS_SCOPING: ClassVar[bool] = True
    # Secret Manager client, created lazily by `_ensure_client_connected`
    # and stored on the class.
    CLIENT: ClassVar[Any] = None

    @classmethod
    def _ensure_client_connected(cls) -> None:
        """Create the Secret Manager client on first use.

        The client is stored in the `CLIENT` class variable, so later calls
        reuse it instead of creating a new one.
        """
        if cls.CLIENT is None:
            cls.CLIENT = secretmanager.SecretManagerServiceClient()

    @classmethod
    def _validate_scope(
        cls,
        scope: SecretsManagerScope,
        namespace: Optional[str],
    ) -> None:
        """Validate the scope and namespace value.

        Args:
            scope: Scope value.
            namespace: Optional namespace value.
        """
        if namespace:
            cls.validate_secret_name_or_namespace(namespace)

    @classmethod
    def validate_secret_name_or_namespace(
        cls,
        name: str,
    ) -> None:
        """Validate a secret name or namespace.

        A Google secret ID is a string with a maximum length of 255 characters
        and can contain uppercase and lowercase letters, numerals, and the
        hyphen (-) and underscore (_) characters. For scoped secrets, we have to
        limit the size of the name and namespace even further to allow space for
        both in the Google secret ID.

        Given that we also save secret names and namespaces as labels, we are
        also limited by the limitation that Google imposes on label values: max
        63 characters and must only contain lowercase letters, numerals
        and the hyphen (-) and underscore (_) characters.

        Args:
            name: the secret name or namespace

        Raises:
            ValueError: if the secret name or namespace is invalid
        """
        # `+` requires at least one character, so empty names fail here too.
        if not re.fullmatch(r"[a-z0-9_\-]+", name):
            raise ValueError(
                f"Invalid secret name or namespace '{name}'. Must contain "
                f"only lowercase alphanumeric characters and the hyphen (-) and "
                f"underscore (_) characters."
            )

        if len(name) > 63:
            raise ValueError(
                f"Invalid secret name or namespace '{name}'. The length is "
                f"limited to maximum 63 characters."
            )

    @property
    def parent_name(self) -> str:
        """Construct the GCP parent path to the secret manager.

        Returns:
            The parent path to the secret manager
        """
        return f"projects/{self.project_id}"

    def _convert_secret_content(
        self, secret: BaseSecretSchema
    ) -> Dict[str, str]:
        """Convert the secret content into a Google compatible representation.

        This method implements two currently supported modes of adapting between
        the naming schemas used for ZenML secrets and Google secrets:

        * for a scoped Secrets Manager, a Google secret is created for each
        ZenML secret with a name that reflects the ZenML secret name and scope
        and a value that contains all its key-value pairs in JSON format.

        * for an unscoped (i.e. legacy) Secrets Manager, this method creates
        multiple Google secret entries for a single ZenML secret by adding the
        secret name to the key name of each secret key-value pair. This allows
        using the same key across multiple secrets. This is only kept for
        backwards compatibility and will be removed some time in the future.

        Args:
            secret: The ZenML secret

        Returns:
            A dictionary with the Google secret name as key and the secret
            contents as value.
        """
        if self.scope == SecretsManagerScope.NONE:
            # legacy per-key secret mapping
            return {f"{secret.name}_{k}": v for k, v in secret.content.items()}

        return {
            self._get_scoped_secret_name(
                secret.name, separator=ZENML_GCP_SECRET_SCOPE_PATH_SEPARATOR
            ): json.dumps(secret_to_dict(secret)),
        }

    def _get_secret_labels(
        self, secret: BaseSecretSchema
    ) -> List[Tuple[str, str]]:
        """Return a list of Google secret label values for a given secret.

        Args:
            secret: the secret object

        Returns:
            A list of Google secret label values
        """
        if self.scope == SecretsManagerScope.NONE:
            # legacy per-key secret labels
            return [
                (ZENML_GROUP_KEY, secret.name),
                (ZENML_SCHEMA_NAME, secret.TYPE),
            ]

        metadata = self._get_secret_metadata(secret)
        return list(metadata.items())

    def _get_secret_scope_filters(
        self,
        secret_name: Optional[str] = None,
    ) -> str:
        """Return a Google filter expression for the entire scope or just a scoped secret.

        These filters can be used when querying the Google Secrets Manager
        for all secrets or for a single secret available in the configured
        scope (see https://cloud.google.com/secret-manager/docs/filtering).

        Args:
            secret_name: Optional secret name to include in the scope metadata.

        Returns:
            Google filter expression uniquely identifying all secrets
            or a named secret within the configured scope.
        """
        if self.scope == SecretsManagerScope.NONE:
            # legacy per-key secret label filters
            if secret_name:
                return f"labels.{ZENML_GROUP_KEY}={secret_name}"
            else:
                return f"labels.{ZENML_GROUP_KEY}:*"

        metadata = self._get_secret_scope_metadata(secret_name)
        filters = [f"labels.{l}={v}" for (l, v) in metadata.items()]
        if secret_name:
            filters.append(f"name:{secret_name}")

        return " AND ".join(filters)

    def _list_secrets(self, secret_name: Optional[str] = None) -> List[str]:
        """List all secrets matching a name.

        This method lists all the secrets in the current scope without loading
        their contents. An optional secret name can be supplied to filter out
        all but a single secret identified by name.

        Args:
            secret_name: Optional secret name to filter for.

        Returns:
            A list of secret names in the current scope and the optional
            secret name.
        """
        self._ensure_client_connected()

        # A set, because legacy secrets map to several Google secrets that
        # all carry the same group label.
        set_of_secrets = set()

        # List all secrets.
        for secret in self.CLIENT.list_secrets(
            request={
                "parent": self.parent_name,
                "filter": self._get_secret_scope_filters(secret_name),
            }
        ):
            if self.scope == SecretsManagerScope.NONE:
                name = secret.labels[ZENML_GROUP_KEY]
            else:
                name = secret.labels[ZENML_SECRET_NAME_LABEL]

            # filter by secret name, if one was given
            if name and (not secret_name or name == secret_name):
                set_of_secrets.add(name)

        return list(set_of_secrets)

    def register_secret(self, secret: BaseSecretSchema) -> None:
        """Registers a new secret.

        Args:
            secret: the secret to register

        Raises:
            SecretExistsError: if the secret already exists
        """
        self.validate_secret_name_or_namespace(secret.name)
        self._ensure_client_connected()

        if self._list_secrets(secret.name):
            raise SecretExistsError(
                f"A Secret with the name {secret.name} already exists"
            )

        adjusted_content = self._convert_secret_content(secret)
        for k, v in adjusted_content.items():
            # Create the secret, this only creates an empty secret with the
            #  supplied name.
            gcp_secret = self.CLIENT.create_secret(
                request={
                    "parent": self.parent_name,
                    "secret_id": k,
                    "secret": {
                        "replication": {"automatic": {}},
                        "labels": self._get_secret_labels(secret),
                    },
                }
            )

            logger.debug("Created empty secret: %s", gcp_secret.name)

            self.CLIENT.add_secret_version(
                request={
                    "parent": gcp_secret.name,
                    "payload": {"data": str(v).encode()},
                }
            )

            logger.debug("Added value to secret.")

    def get_secret(self, secret_name: str) -> BaseSecretSchema:
        """Get a secret by its name.

        Args:
            secret_name: the name of the secret to get

        Returns:
            The secret.

        Raises:
            KeyError: if the secret does not exist
        """
        self.validate_secret_name_or_namespace(secret_name)
        self._ensure_client_connected()

        zenml_secret: Optional[BaseSecretSchema] = None

        if self.scope == SecretsManagerScope.NONE:
            # Legacy secrets are mapped to multiple Google secrets, one for
            # each secret key

            secret_contents = {}
            zenml_schema_name = ""

            # List all secrets.
            for google_secret in self.CLIENT.list_secrets(
                request={
                    "parent": self.parent_name,
                    "filter": self._get_secret_scope_filters(secret_name),
                }
            ):
                secret_version_name = google_secret.name + "/versions/latest"

                response = self.CLIENT.access_secret_version(
                    request={"name": secret_version_name}
                )

                secret_value = response.payload.data.decode("UTF-8")

                secret_key = remove_group_name_from_key(
                    google_secret.name.split("/")[-1], secret_name
                )

                secret_contents[secret_key] = secret_value

                zenml_schema_name = google_secret.labels[ZENML_SCHEMA_NAME]

            if not secret_contents:
                raise KeyError(
                    f"Can't find the specified secret '{secret_name}'"
                )

            secret_contents["name"] = secret_name

            secret_schema = SecretSchemaClassRegistry.get_class(
                secret_schema=zenml_schema_name
            )
            zenml_secret = secret_schema(**secret_contents)

        else:
            # Scoped secrets are mapped 1-to-1 with Google secrets

            google_secret_name = self.CLIENT.secret_path(
                self.project_id,
                self._get_scoped_secret_name(
                    secret_name, separator=ZENML_GCP_SECRET_SCOPE_PATH_SEPARATOR
                ),
            )

            try:
                # fetch the secret itself (metadata including its labels)
                google_secret = self.CLIENT.get_secret(name=google_secret_name)
            except google_exceptions.NotFound:
                raise KeyError(
                    f"Can't find the specified secret '{secret_name}'"
                )

            # make sure the secret has the correct scope labels to filter out
            # unscoped secrets with similar names
            scope_labels = self._get_secret_scope_metadata(secret_name)
            # all scope labels need to be included in the google secret labels,
            # otherwise the secret does not belong to the current scope
            if not scope_labels.items() <= google_secret.labels.items():
                raise KeyError(
                    f"Can't find the specified secret '{secret_name}'"
                )

            try:
                # fetch the latest secret version
                response = self.CLIENT.access_secret_version(
                    name=f"{google_secret_name}/versions/latest"
                )
            except google_exceptions.NotFound:
                raise KeyError(
                    f"Can't find the specified secret '{secret_name}'"
                )

            secret_value = response.payload.data.decode("UTF-8")
            zenml_secret = secret_from_dict(
                json.loads(secret_value), secret_name=secret_name
            )

        return zenml_secret

    def get_all_secret_keys(self) -> List[str]:
        """Get all secret keys.

        Returns:
            A list of all secret keys
        """
        return self._list_secrets()

    def update_secret(self, secret: BaseSecretSchema) -> None:
        """Update an existing secret by creating new versions of the existing secrets.

        For legacy (unscoped) secrets every key is stored in its own Google
        secret. A key that was added after the secret was registered has no
        Google secret yet, so one is created for it (with the same labels
        `register_secret` would use) before its value is added.

        Args:
            secret: the secret to update

        Raises:
            KeyError: if the secret does not exist
        """
        self.validate_secret_name_or_namespace(secret.name)
        self._ensure_client_connected()

        if not self._list_secrets(secret.name):
            raise KeyError(f"Can't find the specified secret '{secret.name}'")

        adjusted_content = self._convert_secret_content(secret)

        for k, v in adjusted_content.items():
            google_secret_name = self.CLIENT.secret_path(self.project_id, k)
            payload = {"data": str(v).encode()}

            try:
                # Add a new version with the updated value to the existing
                # Google secret.
                self.CLIENT.add_secret_version(
                    request={"parent": google_secret_name, "payload": payload}
                )
            except google_exceptions.NotFound:
                # No Google secret exists for this key yet (e.g. a key that
                # was added to a legacy per-key secret): create the empty
                # secret first, then add the value to it.
                gcp_secret = self.CLIENT.create_secret(
                    request={
                        "parent": self.parent_name,
                        "secret_id": k,
                        "secret": {
                            "replication": {"automatic": {}},
                            "labels": self._get_secret_labels(secret),
                        },
                    }
                )
                logger.debug("Created empty secret: %s", gcp_secret.name)

                self.CLIENT.add_secret_version(
                    request={"parent": gcp_secret.name, "payload": payload}
                )

            # NOTE(review): keys removed from a legacy secret keep their old
            # per-key Google secrets; confirm whether they should be deleted.

    def delete_secret(self, secret_name: str) -> None:
        """Delete an existing secret by name.

        Args:
            secret_name: the name of the secret to delete

        Raises:
            KeyError: if the secret no longer exists
        """
        self.validate_secret_name_or_namespace(secret_name)
        self._ensure_client_connected()

        if not self._list_secrets(secret_name):
            raise KeyError(f"Can't find the specified secret '{secret_name}'")

        # Go through all gcp secrets and delete the ones with the secret_name
        # as label.
        for secret in self.CLIENT.list_secrets(
            request={
                "parent": self.parent_name,
                "filter": self._get_secret_scope_filters(secret_name),
            }
        ):
            self.CLIENT.delete_secret(request={"name": secret.name})

    def delete_all_secrets(self) -> None:
        """Delete all existing secrets."""
        self._ensure_client_connected()

        # List all secrets.
        for secret in self.CLIENT.list_secrets(
            request={
                "parent": self.parent_name,
                "filter": self._get_secret_scope_filters(),
            }
        ):
            logger.info("Deleting Google secret %s", secret.name)
            self.CLIENT.delete_secret(request={"name": secret.name})
parent_name: str property readonly

Construct the GCP parent path to the secret manager.

Returns:

Type Description
str

The parent path to the secret manager

delete_all_secrets(self)

Delete all existing secrets.

Source code in zenml/integrations/gcp/secrets_manager/gcp_secrets_manager.py
def delete_all_secrets(self) -> None:
    """Delete all existing secrets.

    Every Google secret matched by this secrets manager's scope filter is
    deleted, and each deletion is logged at INFO level.
    """
    self._ensure_client_connected()

    # List all secrets in the configured scope.
    for secret in self.CLIENT.list_secrets(
        request={
            "parent": self.parent_name,
            "filter": self._get_secret_scope_filters(),
        }
    ):
        # Lazy %-style arguments: the message is only formatted when the
        # INFO level is enabled.
        logger.info("Deleting Google secret %s", secret.name)
        self.CLIENT.delete_secret(request={"name": secret.name})
delete_secret(self, secret_name)

Delete an existing secret by name.

Parameters:

Name Type Description Default
secret_name str

the name of the secret to delete

required

Exceptions:

Type Description
KeyError

if the secret no longer exists

Source code in zenml/integrations/gcp/secrets_manager/gcp_secrets_manager.py
def delete_secret(self, secret_name: str) -> None:
    """Remove every Google secret that backs the named ZenML secret.

    Args:
        secret_name: the name of the secret to delete

    Raises:
        KeyError: if the secret no longer exists
    """
    self.validate_secret_name_or_namespace(secret_name)
    self._ensure_client_connected()

    if not self._list_secrets(secret_name):
        raise KeyError(f"Can't find the specified secret '{secret_name}'")

    client = self.CLIENT
    matching = client.list_secrets(
        request={
            "parent": self.parent_name,
            "filter": self._get_secret_scope_filters(secret_name),
        }
    )
    # For legacy (unscoped) secrets this yields one Google secret per key.
    for google_secret in matching:
        client.delete_secret(request={"name": google_secret.name})
get_all_secret_keys(self)

Get all secret keys.

Returns:

Type Description
List[str]

A list of all secret keys

Source code in zenml/integrations/gcp/secrets_manager/gcp_secrets_manager.py
def get_all_secret_keys(self) -> List[str]:
    """Return the names of all secrets in the current scope.

    Returns:
        A list of all secret keys
    """
    # `_list_secrets` without a name returns every secret in scope.
    all_keys = self._list_secrets()
    return all_keys
get_secret(self, secret_name)

Get a secret by its name.

Parameters:

Name Type Description Default
secret_name str

the name of the secret to get

required

Returns:

Type Description
BaseSecretSchema

The secret.

Exceptions:

Type Description
KeyError

if the secret does not exist

Source code in zenml/integrations/gcp/secrets_manager/gcp_secrets_manager.py
def get_secret(self, secret_name: str) -> BaseSecretSchema:
    """Load a ZenML secret from the Google Secret Manager.

    Scoped secrets are stored as a single Google secret holding the whole
    secret as JSON. Legacy (unscoped) secrets are spread over one Google
    secret per key and are reassembled here.

    Args:
        secret_name: the name of the secret to get

    Returns:
        The secret.

    Raises:
        KeyError: if the secret does not exist
    """
    self.validate_secret_name_or_namespace(secret_name)
    self._ensure_client_connected()

    not_found = f"Can't find the specified secret '{secret_name}'"

    if self.scope != SecretsManagerScope.NONE:
        # Scoped mode: exactly one Google secret per ZenML secret.
        path = self.CLIENT.secret_path(
            self.project_id,
            self._get_scoped_secret_name(
                secret_name, separator=ZENML_GCP_SECRET_SCOPE_PATH_SEPARATOR
            ),
        )

        # Retrieve the secret's metadata to inspect its labels.
        try:
            google_secret = self.CLIENT.get_secret(name=path)
        except google_exceptions.NotFound:
            raise KeyError(not_found)

        # A secret outside the current scope (e.g. an unscoped one with a
        # similar name) is missing some of the expected scope labels.
        expected_labels = self._get_scope_labels(secret_name) if False else self._get_secret_scope_metadata(secret_name)
        if not expected_labels.items() <= google_secret.labels.items():
            raise KeyError(not_found)

        try:
            latest = self.CLIENT.access_secret_version(
                name=f"{path}/versions/latest"
            )
        except google_exceptions.NotFound:
            raise KeyError(not_found)

        decoded = latest.payload.data.decode("UTF-8")
        return secret_from_dict(json.loads(decoded), secret_name=secret_name)

    # Legacy mode: gather one Google secret per key of the ZenML secret.
    contents = {}
    schema_name = ""
    listing = self.CLIENT.list_secrets(
        request={
            "parent": self.parent_name,
            "filter": self._get_secret_scope_filters(secret_name),
        }
    )
    for google_secret in listing:
        latest = self.CLIENT.access_secret_version(
            request={"name": google_secret.name + "/versions/latest"}
        )
        value = latest.payload.data.decode("UTF-8")
        key = remove_group_name_from_key(
            google_secret.name.split("/")[-1], secret_name
        )
        contents[key] = value
        schema_name = google_secret.labels[ZENML_SCHEMA_NAME]

    if not contents:
        raise KeyError(not_found)

    contents["name"] = secret_name
    schema_class = SecretSchemaClassRegistry.get_class(
        secret_schema=schema_name
    )
    return schema_class(**contents)
register_secret(self, secret)

Registers a new secret.

Parameters:

Name Type Description Default
secret BaseSecretSchema

the secret to register

required

Exceptions:

Type Description
SecretExistsError

if the secret already exists

Source code in zenml/integrations/gcp/secrets_manager/gcp_secrets_manager.py
def register_secret(self, secret: BaseSecretSchema) -> None:
    """Registers a new secret.

    Every entry returned by `_convert_secret_content` is stored as its own
    Google secret. For each entry an empty secret is created first (with
    automatic replication and the labels from `_get_secret_labels`), and the
    value is then attached to it as a secret version.

    Args:
        secret: the secret to register

    Raises:
        SecretExistsError: if the secret already exists
    """
    self.validate_secret_name_or_namespace(secret.name)
    self._ensure_client_connected()

    already_registered = self._list_secrets(secret.name)
    if already_registered:
        raise SecretExistsError(
            f"A Secret with the name {secret.name} already exists"
        )

    for secret_id, secret_value in self._convert_secret_content(secret).items():
        # Step 1: create_secret only reserves the named (empty) secret.
        create_request = {
            "parent": self.parent_name,
            "secret_id": secret_id,
            "secret": {
                "replication": {"automatic": {}},
                "labels": self._get_secret_labels(secret),
            },
        }
        gcp_secret = self.CLIENT.create_secret(request=create_request)
        logger.debug("Created empty secret: %s", gcp_secret.name)

        # Step 2: the actual value is stored as a version of that secret.
        version_request = {
            "parent": gcp_secret.name,
            "payload": {"data": str(secret_value).encode()},
        }
        self.CLIENT.add_secret_version(request=version_request)
        logger.debug("Added value to secret.")
update_secret(self, secret)

Update an existing secret by creating new versions of the existing secrets.

Parameters:

Name Type Description Default
secret BaseSecretSchema

the secret to update

required

Exceptions:

Type Description
KeyError

if the secret does not exist

Source code in zenml/integrations/gcp/secrets_manager/gcp_secrets_manager.py
def update_secret(self, secret: BaseSecretSchema) -> None:
    """Update an existing secret by creating new versions of the existing secrets.

    Each entry returned by `_convert_secret_content` is written as a new
    version of the Google secret whose ID is the entry's key.

    Args:
        secret: the secret to update

    Raises:
        KeyError: if the secret does not exist
    """
    self.validate_secret_name_or_namespace(secret.name)
    self._ensure_client_connected()

    secret_exists = bool(self._list_secrets(secret.name))
    if not secret_exists:
        raise KeyError(f"Can't find the specified secret '{secret.name}'")

    for secret_id, secret_value in self._convert_secret_content(secret).items():
        # The Google secret is expected to exist already; the new value is
        # appended to it as a fresh version (nothing is created here).
        # NOTE(review): a key that was not present when the secret was
        # registered has no Google secret yet -- confirm how
        # add_secret_version behaves in that case.
        secret_resource = self.CLIENT.secret_path(self.project_id, secret_id)
        version_payload = {"data": str(secret_value).encode()}

        self.CLIENT.add_secret_version(
            request={"parent": secret_resource, "payload": version_payload}
        )
validate_secret_name_or_namespace(name) classmethod

Validate a secret name or namespace.

A Google secret ID is a string with a maximum length of 255 characters and can contain uppercase and lowercase letters, numerals, and the hyphen (-) and underscore (_) characters. For scoped secrets, we have to limit the size of the name and namespace even further to allow space for both in the Google secret ID.

Given that we also save secret names and namespaces as labels, we are also bound by the restrictions that Google imposes on label values: a maximum of 63 characters, and only lowercase letters, numerals, and the hyphen (-) and underscore (_) characters are allowed.

Parameters:

Name Type Description Default
name str

the secret name or namespace

required

Exceptions:

Type Description
ValueError

if the secret name or namespace is invalid

Source code in zenml/integrations/gcp/secrets_manager/gcp_secrets_manager.py
@classmethod
def validate_secret_name_or_namespace(
    cls,
    name: str,
) -> None:
    """Validate a secret name or namespace.

    A Google secret ID is a string with a maximum length of 255 characters
    and can contain uppercase and lowercase letters, numerals, and the
    hyphen (-) and underscore (_) characters. For scoped secrets, we have to
    limit the size of the name and namespace even further to allow space for
    both in the Google secret ID.

    Given that we also save secret names and namespaces as labels, we are
    also bound by the restrictions that Google imposes on label values: a
    maximum of 63 characters, and only lowercase letters, numerals, and the
    hyphen (-) and underscore (_) characters are allowed.

    Args:
        name: the secret name or namespace

    Raises:
        ValueError: if the secret name or namespace is invalid (empty,
            containing disallowed characters, or longer than 63 characters)
    """
    # `fullmatch` with `+` also rejects the empty string, so every name that
    # passes this check is non-empty.
    if not re.fullmatch(r"[a-z0-9_\-]+", name):
        raise ValueError(
            f"Invalid secret name or namespace '{name}'. Must contain "
            f"only lowercase alphanumeric characters and the hyphen (-) and "
            f"underscore (_) characters."
        )

    # Google label values are limited to 63 characters.
    if len(name) > 63:
        raise ValueError(
            f"Invalid secret name or namespace '{name}'. The length is "
            f"limited to maximum 63 characters."
        )
remove_group_name_from_key(combined_key_name, group_name)

Removes the secret group name from the secret key.

Parameters:

Name Type Description Default
combined_key_name str

Full name as it is within the gcp secrets manager

required
group_name str

Group name (the ZenML Secret name)

required

Returns:

Type Description
str

The cleaned key

Exceptions:

Type Description
RuntimeError

If the group name is not found in the key

Source code in zenml/integrations/gcp/secrets_manager/gcp_secrets_manager.py
def remove_group_name_from_key(combined_key_name: str, group_name: str) -> str:
    """Removes the secret group name from the secret key.

    The combined key is expected to look like `<group_name>_<key>`; the
    `<group_name>_` prefix is stripped and the remainder returned.

    Args:
        combined_key_name: Full name as it is within the gcp secrets manager
        group_name: Group name (the ZenML Secret name)

    Returns:
        The cleaned key

    Raises:
        RuntimeError: If the group name is not found in the key
    """
    expected_prefix = group_name + "_"
    if not combined_key_name.startswith(expected_prefix):
        raise RuntimeError(
            f"Key-name `{combined_key_name}` does not have the "
            f"prefix `{group_name}`. Key could not be "
            f"extracted."
        )
    return combined_key_name[len(expected_prefix) :]

step_operators special

Initialization for the VertexAI Step Operator.

vertex_step_operator

Implementation of a VertexAI step operator.

Code heavily inspired by the TFX implementation: https://github.com/tensorflow/tfx/blob/master/tfx/extensions/google_cloud_ai_platform/training_clients.py

VertexStepOperator (BaseStepOperator, GoogleCredentialsMixin) pydantic-model

Step operator to run a step on Vertex AI.

This class defines code that can set up a Vertex AI environment and run the ZenML entrypoint command in it.

Attributes:

Name Type Description
region str

Region name, e.g., europe-west1.

project Optional[str]

GCP project name. If left None, inferred from the environment.

accelerator_type Optional[str]

Accelerator type from list: https://cloud.google.com/vertex-ai/docs/reference/rest/v1/MachineSpec#AcceleratorType

accelerator_count int

Defines number of accelerators to be used for the job.

machine_type str

Machine type specified here: https://cloud.google.com/vertex-ai/docs/training/configure-compute#machine-types

base_image Optional[str]

Base image for building the custom job container.

encryption_spec_key_name Optional[str]

Encryption spec key name.

Source code in zenml/integrations/gcp/step_operators/vertex_step_operator.py
class VertexStepOperator(BaseStepOperator, GoogleCredentialsMixin):
    """Step operator to run a step on Vertex AI.

    This class defines code that can set up a Vertex AI environment and run the
    ZenML entrypoint command in it.

    Attributes:
        region: Region name, e.g., `europe-west1`.
        project: GCP project name. If left None, inferred from the
            environment.
        accelerator_type: Accelerator type from list: https://cloud.google.com/vertex-ai/docs/reference/rest/v1/MachineSpec#AcceleratorType
        accelerator_count: Defines number of accelerators to be
            used for the job.
        machine_type: Machine type specified here: https://cloud.google.com/vertex-ai/docs/training/configure-compute#machine-types
        base_image: Base image for building the custom job container.
        encryption_spec_key_name: Encryption spec key name.
    """

    region: str
    project: Optional[str] = None
    accelerator_type: Optional[str] = None
    accelerator_count: int = 0
    machine_type: str = "n1-standard-4"
    base_image: Optional[str] = None

    # customer managed encryption key resource name
    # will be applied to all Vertex AI resources if set
    encryption_spec_key_name: Optional[str] = None

    # Class configuration
    FLAVOR: ClassVar[str] = GCP_VERTEX_STEP_OPERATOR_FLAVOR

    @property
    def validator(self) -> Optional[StackValidator]:
        """Validates the stack this step operator is used in.

        The stack must contain a container registry (the step image is pushed
        there). In addition, only the local orchestrator combined with the GCP
        artifact store is currently supported.

        Returns:
            StackValidator: Validator for the stack.
        """

        def _ensure_local_orchestrator(stack: Stack) -> Tuple[bool, str]:
            # For now this only works on local orchestrator and GCP artifact
            #  store
            return (
                (
                    stack.orchestrator.FLAVOR == "local"
                    and stack.artifact_store.FLAVOR == "gcp"
                ),
                "Only local orchestrator and GCP artifact store are currently "
                "supported",
            )

        return StackValidator(
            required_components={StackComponentType.CONTAINER_REGISTRY},
            custom_validation_function=_ensure_local_orchestrator,
        )

    @property_validator("accelerator_type")
    def validate_accelerator_enum(
        cls, accelerator_type: Optional[str]
    ) -> Optional[str]:
        """Validates that the accelerator type is valid.

        The check is case-insensitive against the members of
        `aiplatform.gapic.AcceleratorType`; an unset (None/empty) value is
        accepted.

        Args:
            accelerator_type: Accelerator type

        Returns:
            The accelerator type, unchanged.

        Raises:
            ValueError: If the accelerator type is not valid.
        """
        accepted_vals = list(
            aiplatform.gapic.AcceleratorType.__members__.keys()
        )
        if accelerator_type and accelerator_type.upper() not in accepted_vals:
            raise ValueError(
                f"Accelerator must be one of the following: {accepted_vals}"
            )
        # Assuming `property_validator` is pydantic's field validator (this
        # class is a pydantic model), the returned value becomes the field
        # value, so the input must be passed back; returning None would
        # discard the configured accelerator type.
        return accelerator_type

    def _build_and_push_docker_image(
        self,
        pipeline_name: str,
        requirements: List[str],
        entrypoint_command: List[str],
    ) -> str:
        """Builds and pushes a docker image.

        Args:
            pipeline_name: Pipeline name
            requirements: Requirements
            entrypoint_command: Entrypoint command

        Returns:
            Docker image name (the pushed image digest when it can be
            determined, otherwise the image tag)

        Raises:
            RuntimeError: If no container registry is found in the stack.
        """
        repo = Repository()
        container_registry = repo.active_stack.container_registry

        if not container_registry:
            raise RuntimeError("Missing container registry")

        registry_uri = container_registry.uri.rstrip("/")
        image_name = f"{registry_uri}/zenml-vertex:{pipeline_name}"

        docker_utils.build_docker_image(
            build_context_path=get_source_root_path(),
            image_name=image_name,
            entrypoint=" ".join(entrypoint_command),
            requirements=set(requirements),
            base_image=self.base_image,
        )
        container_registry.push_image(image_name)
        return docker_utils.get_image_digest(image_name) or image_name

    def launch(
        self,
        pipeline_name: str,
        run_name: str,
        requirements: List[str],
        entrypoint_command: List[str],
        resource_configuration: "ResourceConfiguration",
    ) -> None:
        """Launches a step on Vertex AI.

        Builds and pushes the step image, submits a Vertex AI custom job that
        runs it, and polls the job until it reaches a completed state.

        Args:
            pipeline_name: Name of the pipeline which the step to be executed
                is part of.
            run_name: Name of the pipeline run which the step to be executed
                is part of.
            entrypoint_command: Command that executes the step.
            requirements: List of pip requirements that must be installed
                inside the step operator environment.
            resource_configuration: The resource configuration for this step.

        Raises:
            RuntimeError: If the run fails.
            ConnectionError: If the run fails due to a connection error.
        """
        if resource_configuration.cpu_count or resource_configuration.memory:
            logger.warning(
                "Specifying cpus or memory is not supported for "
                "the Vertex step operator. If you want to run this step "
                "operator on specific resources, you can do so by configuring "
                "a different machine_type like this: "
                "`zenml step-operator update %s "
                "--machine_type=<MACHINE_TYPE>`",
                self.name,
            )

        job_labels = {"source": f"zenml-{__version__.replace('.', '_')}"}

        # Step 1: Authenticate with Google
        credentials, project_id = self._get_authentication()
        if self.project:
            if self.project != project_id:
                logger.warning(
                    "Authenticated with project `%s`, but this step operator "
                    "is configured to use the project `%s`.",
                    project_id,
                    self.project,
                )
        else:
            self.project = project_id

        # Step 2: Build and push image
        image_name = self._build_and_push_docker_image(
            pipeline_name=pipeline_name,
            requirements=requirements,
            entrypoint_command=entrypoint_command,
        )

        # Step 3: Launch the job
        # The AI Platform services require regional API endpoints.
        client_options = {"api_endpoint": self.region + VERTEX_ENDPOINT_SUFFIX}
        # Initialize client that will be used to create and send requests.
        # This client only needs to be created once, and can be reused for multiple requests.
        client = aiplatform.gapic.JobServiceClient(
            credentials=credentials, client_options=client_options
        )
        # A GPU count from the step's resource configuration takes precedence
        # over the component-level default; without an accelerator type no
        # accelerators are requested at all.
        accelerator_count = (
            resource_configuration.gpu_count or self.accelerator_count
        )
        custom_job = {
            "display_name": run_name,
            "job_spec": {
                "worker_pool_specs": [
                    {
                        "machine_spec": {
                            "machine_type": self.machine_type,
                            "accelerator_type": self.accelerator_type,
                            "accelerator_count": accelerator_count
                            if self.accelerator_type
                            else 0,
                        },
                        "replica_count": 1,
                        "container_spec": {
                            "image_uri": image_name,
                            "command": [],
                            "args": [],
                        },
                    }
                ]
            },
            "labels": job_labels,
            "encryption_spec": {"kmsKeyName": self.encryption_spec_key_name}
            if self.encryption_spec_key_name
            else {},
        }
        logger.debug("Vertex AI Job=%s", custom_job)

        parent = f"projects/{self.project}/locations/{self.region}"
        logger.info(
            "Submitting custom job='%s', path='%s' to Vertex AI Training.",
            custom_job["display_name"],
            parent,
        )
        response = client.create_custom_job(
            parent=parent, custom_job=custom_job
        )
        # The format string needs a placeholder for the extra argument;
        # otherwise logging reports a formatting error when DEBUG is enabled.
        logger.debug("Vertex AI response: %s", response)

        # Step 4: Monitor the job

        # Monitors the long-running operation by polling the job state
        # periodically, and retries the polling when a transient connectivity
        # issue is encountered.
        #
        # Long-running operation monitoring:
        #   The possible states of "get job" response can be found at
        #   https://cloud.google.com/ai-platform/training/docs/reference/rest/v1/projects.jobs#State
        #   where SUCCEEDED/FAILED/CANCELED are considered to be final states.
        #   The following logic will keep polling the state of the job until
        #   the job enters a final state.
        #
        # During the polling, if a connection error was encountered, the GET
        # request will be retried by recreating the Python API client to
        # refresh the lifecycle of the connection being used. See
        # https://github.com/googleapis/google-api-python-client/issues/218
        # for a detailed description of the problem. If the error persists for
        # _CONNECTION_ERROR_RETRY_LIMIT consecutive attempts, the function
        # will raise ConnectionError.
        retry_count = 0
        job_id = response.name

        while response.state not in VERTEX_JOB_STATES_COMPLETED:
            time.sleep(POLLING_INTERVAL_IN_SECONDS)
            try:
                response = client.get_custom_job(name=job_id)
                retry_count = 0
            # Handle transient connection error.
            except ConnectionError as err:
                if retry_count < CONNECTION_ERROR_RETRY_LIMIT:
                    retry_count += 1
                    logger.warning(
                        "ConnectionError (%s) encountered when polling job: "
                        "%s. Trying to recreate the API client.",
                        err,
                        job_id,
                    )
                    # Recreate the Python API client with the same
                    # credentials that were used to submit the job, so the
                    # retry does not fall back to different default
                    # credentials.
                    client = aiplatform.gapic.JobServiceClient(
                        credentials=credentials, client_options=client_options
                    )
                else:
                    logger.error(
                        "Request failed after %s retries.",
                        CONNECTION_ERROR_RETRY_LIMIT,
                    )
                    raise

            if response.state in VERTEX_JOB_STATES_FAILED:
                err_msg = (
                    "Job '{}' did not succeed.  Detailed response {}.".format(
                        job_id, response
                    )
                )
                logger.error(err_msg)
                raise RuntimeError(err_msg)

        # Cloud training complete
        logger.info("Job '%s' successful.", job_id)
validator: Optional[zenml.stack.stack_validator.StackValidator] property readonly

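Validates that the stack contains a container registry, uses the local orchestrator, and uses the GCP artifact store (the only combination currently supported).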

Returns:

Type Description
StackValidator

Validator for the stack.

launch(self, pipeline_name, run_name, requirements, entrypoint_command, resource_configuration)

Launches a step on Vertex AI.

Parameters:

Name Type Description Default
pipeline_name str

Name of the pipeline which the step to be executed is part of.

required
run_name str

Name of the pipeline run which the step to be executed is part of.

required
entrypoint_command List[str]

Command that executes the step.

required
requirements List[str]

List of pip requirements that must be installed inside the step operator environment.

required
resource_configuration ResourceConfiguration

The resource configuration for this step.

required

Exceptions:

Type Description
RuntimeError

If the run fails.

ConnectionError

If the run fails due to a connection error.

Source code in zenml/integrations/gcp/step_operators/vertex_step_operator.py
def launch(
    self,
    pipeline_name: str,
    run_name: str,
    requirements: List[str],
    entrypoint_command: List[str],
    resource_configuration: "ResourceConfiguration",
) -> None:
    """Launches a step on Vertex AI.

    Builds and pushes the step image, submits a Vertex AI custom job that
    runs it, and polls the job until it reaches a completed state.

    Args:
        pipeline_name: Name of the pipeline which the step to be executed
            is part of.
        run_name: Name of the pipeline run which the step to be executed
            is part of.
        entrypoint_command: Command that executes the step.
        requirements: List of pip requirements that must be installed
            inside the step operator environment.
        resource_configuration: The resource configuration for this step.

    Raises:
        RuntimeError: If the run fails.
        ConnectionError: If the run fails due to a connection error.
    """
    if resource_configuration.cpu_count or resource_configuration.memory:
        logger.warning(
            "Specifying cpus or memory is not supported for "
            "the Vertex step operator. If you want to run this step "
            "operator on specific resources, you can do so by configuring "
            "a different machine_type like this: "
            "`zenml step-operator update %s "
            "--machine_type=<MACHINE_TYPE>`",
            self.name,
        )

    job_labels = {"source": f"zenml-{__version__.replace('.', '_')}"}

    # Step 1: Authenticate with Google
    credentials, project_id = self._get_authentication()
    if self.project:
        if self.project != project_id:
            logger.warning(
                "Authenticated with project `%s`, but this step operator "
                "is configured to use the project `%s`.",
                project_id,
                self.project,
            )
    else:
        self.project = project_id

    # Step 2: Build and push image
    image_name = self._build_and_push_docker_image(
        pipeline_name=pipeline_name,
        requirements=requirements,
        entrypoint_command=entrypoint_command,
    )

    # Step 3: Launch the job
    # The AI Platform services require regional API endpoints.
    client_options = {"api_endpoint": self.region + VERTEX_ENDPOINT_SUFFIX}
    # Initialize client that will be used to create and send requests.
    # This client only needs to be created once, and can be reused for multiple requests.
    client = aiplatform.gapic.JobServiceClient(
        credentials=credentials, client_options=client_options
    )
    # A GPU count from the step's resource configuration takes precedence
    # over the component-level default; without an accelerator type no
    # accelerators are requested at all.
    accelerator_count = (
        resource_configuration.gpu_count or self.accelerator_count
    )
    custom_job = {
        "display_name": run_name,
        "job_spec": {
            "worker_pool_specs": [
                {
                    "machine_spec": {
                        "machine_type": self.machine_type,
                        "accelerator_type": self.accelerator_type,
                        "accelerator_count": accelerator_count
                        if self.accelerator_type
                        else 0,
                    },
                    "replica_count": 1,
                    "container_spec": {
                        "image_uri": image_name,
                        "command": [],
                        "args": [],
                    },
                }
            ]
        },
        "labels": job_labels,
        "encryption_spec": {"kmsKeyName": self.encryption_spec_key_name}
        if self.encryption_spec_key_name
        else {},
    }
    logger.debug("Vertex AI Job=%s", custom_job)

    parent = f"projects/{self.project}/locations/{self.region}"
    logger.info(
        "Submitting custom job='%s', path='%s' to Vertex AI Training.",
        custom_job["display_name"],
        parent,
    )
    response = client.create_custom_job(
        parent=parent, custom_job=custom_job
    )
    # The format string needs a placeholder for the extra argument;
    # otherwise logging reports a formatting error when DEBUG is enabled.
    logger.debug("Vertex AI response: %s", response)

    # Step 4: Monitor the job

    # Monitors the long-running operation by polling the job state
    # periodically, and retries the polling when a transient connectivity
    # issue is encountered.
    #
    # Long-running operation monitoring:
    #   The possible states of "get job" response can be found at
    #   https://cloud.google.com/ai-platform/training/docs/reference/rest/v1/projects.jobs#State
    #   where SUCCEEDED/FAILED/CANCELED are considered to be final states.
    #   The following logic will keep polling the state of the job until
    #   the job enters a final state.
    #
    # During the polling, if a connection error was encountered, the GET
    # request will be retried by recreating the Python API client to
    # refresh the lifecycle of the connection being used. See
    # https://github.com/googleapis/google-api-python-client/issues/218
    # for a detailed description of the problem. If the error persists for
    # _CONNECTION_ERROR_RETRY_LIMIT consecutive attempts, the function
    # will raise ConnectionError.
    retry_count = 0
    job_id = response.name

    while response.state not in VERTEX_JOB_STATES_COMPLETED:
        time.sleep(POLLING_INTERVAL_IN_SECONDS)
        try:
            response = client.get_custom_job(name=job_id)
            retry_count = 0
        # Handle transient connection error.
        except ConnectionError as err:
            if retry_count < CONNECTION_ERROR_RETRY_LIMIT:
                retry_count += 1
                logger.warning(
                    "ConnectionError (%s) encountered when polling job: "
                    "%s. Trying to recreate the API client.",
                    err,
                    job_id,
                )
                # Recreate the Python API client with the same credentials
                # that were used to submit the job, so the retry does not
                # fall back to different default credentials.
                client = aiplatform.gapic.JobServiceClient(
                    credentials=credentials, client_options=client_options
                )
            else:
                logger.error(
                    "Request failed after %s retries.",
                    CONNECTION_ERROR_RETRY_LIMIT,
                )
                raise

        if response.state in VERTEX_JOB_STATES_FAILED:
            err_msg = (
                "Job '{}' did not succeed.  Detailed response {}.".format(
                    job_id, response
                )
            )
            logger.error(err_msg)
            raise RuntimeError(err_msg)

    # Cloud training complete
    logger.info("Job '%s' successful.", job_id)
validate_accelerator_enum(accelerator_type) classmethod

Validates that the accelerator type is valid.

Parameters:

Name Type Description Default
accelerator_type Optional[str]

Accelerator type

required

Exceptions:

Type Description
ValueError

If the accelerator type is not valid.

Source code in zenml/integrations/gcp/step_operators/vertex_step_operator.py
@property_validator("accelerator_type")
def validate_accelerator_enum(
    cls, accelerator_type: Optional[str]
) -> Optional[str]:
    """Validates that the accelerator type is valid.

    The check is case-insensitive against the members of
    `aiplatform.gapic.AcceleratorType`; an unset (None/empty) value is
    accepted.

    Args:
        accelerator_type: Accelerator type

    Returns:
        The accelerator type, unchanged.

    Raises:
        ValueError: If the accelerator type is not valid.
    """
    accepted_vals = list(
        aiplatform.gapic.AcceleratorType.__members__.keys()
    )
    if accelerator_type and accelerator_type.upper() not in accepted_vals:
        raise ValueError(
            f"Accelerator must be one of the following: {accepted_vals}"
        )
    # Assuming `property_validator` is pydantic's field validator (the owning
    # class is a pydantic model), the returned value becomes the field value,
    # so the input must be passed back; returning None would discard the
    # configured accelerator type.
    return accelerator_type

github special

Initialization of the GitHub ZenML integration.

The GitHub integration provides a way to orchestrate pipelines using GitHub Actions.

GitHubIntegration (Integration)

Definition of GitHub integration for ZenML.

Source code in zenml/integrations/github/__init__.py
class GitHubIntegration(Integration):
    """Definition of GitHub integration for ZenML."""

    NAME = GITHUB
    REQUIREMENTS: List[str] = ["PyNaCl~=1.5.0"]

    @classmethod
    def flavors(cls) -> List[FlavorWrapper]:
        """List the stack component flavors provided by the GitHub integration.

        The integration provides a GitHub Actions orchestrator flavor and a
        GitHub secrets manager flavor, both tagged with this integration's
        name.

        Returns:
            List of stack component flavors for this integration.
        """
        flavor_table = (
            (
                GITHUB_ORCHESTRATOR_FLAVOR,
                "zenml.integrations.github.orchestrators.GitHubActionsOrchestrator",
                StackComponentType.ORCHESTRATOR,
            ),
            (
                GITHUB_SECRET_MANAGER_FLAVOR,
                "zenml.integrations.github.secrets_managers.GitHubSecretsManager",
                StackComponentType.SECRETS_MANAGER,
            ),
        )
        return [
            FlavorWrapper(
                name=flavor_name,
                source=source_path,
                type=component_type,
                integration=cls.NAME,
            )
            for flavor_name, source_path, component_type in flavor_table
        ]
flavors() classmethod

Declare the stack component flavors for the GitHub integration.

Returns:

Type Description
List[zenml.zen_stores.models.flavor_wrapper.FlavorWrapper]

List of stack component flavors for this integration.

Source code in zenml/integrations/github/__init__.py
@classmethod
def flavors(cls) -> List[FlavorWrapper]:
    """List the stack component flavors provided by the GitHub integration.

    The integration provides a GitHub Actions orchestrator flavor and a
    GitHub secrets manager flavor, both tagged with this integration's name.

    Returns:
        List of stack component flavors for this integration.
    """
    flavor_table = (
        (
            GITHUB_ORCHESTRATOR_FLAVOR,
            "zenml.integrations.github.orchestrators.GitHubActionsOrchestrator",
            StackComponentType.ORCHESTRATOR,
        ),
        (
            GITHUB_SECRET_MANAGER_FLAVOR,
            "zenml.integrations.github.secrets_managers.GitHubSecretsManager",
            StackComponentType.SECRETS_MANAGER,
        ),
    )
    return [
        FlavorWrapper(
            name=flavor_name,
            source=source_path,
            type=component_type,
            integration=cls.NAME,
        )
        for flavor_name, source_path, component_type in flavor_table
    ]

orchestrators special

Initialization of the GitHub Actions Orchestrator.

github_actions_entrypoint_configuration

Implementation of the GitHub Actions Orchestrator entrypoint.

GitHubActionsEntrypointConfiguration (StepEntrypointConfiguration)

Entrypoint configuration for running steps on GitHub Actions runners.

Source code in zenml/integrations/github/orchestrators/github_actions_entrypoint_configuration.py
class GitHubActionsEntrypointConfiguration(StepEntrypointConfiguration):
    """Entrypoint configuration for running steps on GitHub Actions runners."""

    @classmethod
    def get_custom_entrypoint_options(cls) -> Set[str]:
        """Return the options specific to the GitHub Actions entrypoint.

        Returns:
            Set with the custom run id option.
        """
        custom_options = {RUN_ID_OPTION}
        return custom_options

    @classmethod
    def get_custom_entrypoint_arguments(
        cls, step: BaseStep, **kwargs: Any
    ) -> List[str]:
        """Build the run id argument passed to the entrypoint.

        Args:
            step: Step for which the arguments are passed.
            **kwargs: Additional keyword arguments.

        Returns:
            GitHub Actions placeholder for the run id option.
        """
        # GitHub Actions expression placeholders; the runner substitutes the
        # concrete values when the workflow file is executed.
        workflow_placeholders = (
            "${{ github.run_id }}",
            "${{ github.run_number }}",
            "${{ github.run_attempt }}",
        )
        run_id_template = "_".join(workflow_placeholders)
        return [f"--{RUN_ID_OPTION}", run_id_template]

    def get_run_name(self, pipeline_name: str) -> str:
        """Compose the pipeline run name from the pipeline name and run id.

        Args:
            pipeline_name: Name of the pipeline which will run.

        Returns:
            The run name, formatted as `<pipeline_name>-<run_id>`.
        """
        parsed_args = self.entrypoint_args
        run_identifier = cast(str, parsed_args[RUN_ID_OPTION])
        return "{}-{}".format(pipeline_name, run_identifier)
get_custom_entrypoint_arguments(step, **kwargs) classmethod

Adds a run id argument for the entrypoint.

Parameters:

Name Type Description Default
step BaseStep

Step for which the arguments are passed.

required
**kwargs Any

Additional keyword arguments.

{}

Returns:

Type Description
List[str]

GitHub Actions placeholder for the run id option.

Source code in zenml/integrations/github/orchestrators/github_actions_entrypoint_configuration.py
@classmethod
def get_custom_entrypoint_arguments(
    cls, step: BaseStep, **kwargs: Any
) -> List[str]:
    """Build the run id argument passed to the entrypoint.

    Args:
        step: Step for which the arguments are passed.
        **kwargs: Additional keyword arguments.

    Returns:
        GitHub Actions placeholder for the run id option.
    """
    # GitHub Actions expression placeholders; the runner substitutes the
    # concrete values when the workflow file is executed.
    workflow_placeholders = (
        "${{ github.run_id }}",
        "${{ github.run_number }}",
        "${{ github.run_attempt }}",
    )
    run_id_template = "_".join(workflow_placeholders)
    return [f"--{RUN_ID_OPTION}", run_id_template]
get_custom_entrypoint_options() classmethod

GitHub Actions specific entrypoint options.

Returns:

Type Description
Set[str]

Set with the custom run id option.

Source code in zenml/integrations/github/orchestrators/github_actions_entrypoint_configuration.py
@classmethod
def get_custom_entrypoint_options(cls) -> Set[str]:
    """Lists the entrypoint options specific to GitHub Actions.

    Returns:
        Set with the custom run id option.
    """
    options: Set[str] = set()
    options.add(RUN_ID_OPTION)
    return options
get_run_name(self, pipeline_name)

Returns the pipeline run name.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `pipeline_name` | `str` | Name of the pipeline which will run. | *required* |

Returns:

| Type | Description |
| --- | --- |
| `str` | The run name. |

Source code in zenml/integrations/github/orchestrators/github_actions_entrypoint_configuration.py
def get_run_name(self, pipeline_name: str) -> str:
    """Builds the name of the pipeline run.

    The run name is the pipeline name followed by a hyphen and the run id
    that was passed to the entrypoint via the run id option.

    Args:
        pipeline_name: Name of the pipeline which will run.

    Returns:
        The run name.
    """
    run_identifier = cast(str, self.entrypoint_args[RUN_ID_OPTION])
    run_name = f"{pipeline_name}-{run_identifier}"
    return run_name
github_actions_orchestrator

Implementation of the GitHub Actions Orchestrator.

GitHubActionsOrchestrator (BaseOrchestrator) pydantic-model

Orchestrator responsible for running pipelines using GitHub Actions.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| `custom_docker_base_image_name` | `Optional[str]` | Name of a docker image that should be used as the base for the image that will be run on GitHub Actions runners. If no custom image is given, a basic image of the active ZenML version will be used. **Note**: This image needs to have ZenML installed, otherwise the pipeline execution will fail. For that reason, you might want to extend the ZenML docker images found here: https://hub.docker.com/r/zenmldocker/zenml/ |
| `skip_dirty_repository_check` | `bool` | If `True`, this orchestrator will not raise an exception when trying to run a pipeline while there are still untracked/uncommitted files in the git repository. |
| `skip_github_repository_check` | `bool` | If `True`, the orchestrator will not check if your git repository is pointing to a GitHub remote. |
| `push` | `bool` | If `True`, this orchestrator will automatically commit and push the GitHub workflow file when running a pipeline. If `False`, the workflow file will be written to the correct location but needs to be committed and pushed manually. |

Source code in zenml/integrations/github/orchestrators/github_actions_orchestrator.py
class GitHubActionsOrchestrator(BaseOrchestrator):
    """Orchestrator responsible for running pipelines using GitHub Actions.

    Attributes:
        custom_docker_base_image_name: Name of a docker image that should be
            used as the base for the image that will be run on GitHub Action
            runners. If no custom image is given, a basic image of the active
            ZenML version will be used. **Note**: This image needs to have
            ZenML installed, otherwise the pipeline execution will fail. For
            that reason, you might want to extend the ZenML docker images
            found here: https://hub.docker.com/r/zenmldocker/zenml/
        skip_dirty_repository_check: If `True`, this orchestrator will not
            raise an exception when trying to run a pipeline while there are
            still untracked/uncommitted files in the git repository.
        skip_github_repository_check: If `True`, the orchestrator will not check
            if your git repository is pointing to a GitHub remote.
        push: If `True`, this orchestrator will automatically commit and push
            the GitHub workflow file when running a pipeline. If `False`, the
            workflow file will be written to the correct location but needs to
            be committed and pushed manually.
    """

    custom_docker_base_image_name: Optional[str] = None
    skip_dirty_repository_check: bool = False
    skip_github_repository_check: bool = False
    push: bool = False

    # Cached git repository; only populated by `git_repo` once the repository
    # passed validation.
    _git_repo: Optional[Repo] = None

    # Class configuration
    FLAVOR: ClassVar[str] = GITHUB_ORCHESTRATOR_FLAVOR

    @property
    def git_repo(self) -> Repo:
        """Returns the git repository for the current working directory.

        The repository is searched in the current working directory and its
        parents. Unless `skip_github_repository_check` is set, its remote URL
        must start with one of the known GitHub prefixes. The repository is
        cached only after these checks pass, so a failed check is repeated on
        the next access instead of being silently skipped.

        Returns:
            Git repository for the current working directory.

        Raises:
            RuntimeError: If there is no git repository for the current working
                directory or the repository remote is not pointing to GitHub.
        """
        if self._git_repo is None:
            try:
                repo = Repo(search_parent_directories=True)
            except InvalidGitRepositoryError:
                raise RuntimeError(
                    "Unable to find git repository in current working "
                    f"directory {os.getcwd()} or its parent directories."
                )

            remote_url = repo.remote().url
            is_github_repo = any(
                remote_url.startswith(prefix)
                for prefix in GITHUB_REMOTE_URL_PREFIXES
            )
            if not (is_github_repo or self.skip_github_repository_check):
                raise RuntimeError(
                    f"The remote URL '{remote_url}' of your git repo "
                    f"({repo.git_dir}) is not pointing to a GitHub "
                    "repository. The GitHub Actions orchestrator runs "
                    "pipelines using GitHub Actions and therefore only works "
                    "with GitHub repositories. If you want to skip this check "
                    "and run this orchestrator anyway, run: \n"
                    f"`zenml orchestrator update {self.name} "
                    "--skip_github_repository_check=true`"
                )

            # Cache only after validation succeeded.
            self._git_repo = repo

        return self._git_repo

    @property
    def workflow_directory(self) -> str:
        """Returns path to the GitHub workflows directory.

        Returns:
            The `.github/workflows` directory inside the working directory of
            the git repository.
        """
        assert self.git_repo.working_dir
        return os.path.join(self.git_repo.working_dir, ".github", "workflows")

    @property
    def validator(self) -> Optional[StackValidator]:
        """Validator that ensures that the stack is compatible.

        Makes sure that the stack contains a container registry and only
        remote components.

        Returns:
            The stack validator.
        """

        def _validate_local_requirements(stack: "Stack") -> Tuple[bool, str]:
            """Checks the container registry and components of a stack.

            Args:
                stack: The stack to validate.

            Returns:
                Tuple of a flag whether the stack is valid and an error
                message (empty if the stack is valid).
            """
            container_registry = stack.container_registry
            assert container_registry is not None

            if container_registry.is_local:
                return False, (
                    "The GitHub Actions orchestrator requires a remote "
                    f"container registry, but the '{container_registry.name}' "
                    "container registry of your active stack points to a local "
                    f"URI '{container_registry.uri}'. Please make sure stacks "
                    "with a GitHub Actions orchestrator always contain remote "
                    "container registries."
                )

            if container_registry.requires_authentication:
                return False, (
                    "The GitHub Actions orchestrator currently only works with "
                    "GitHub container registries or public container "
                    f"registries, but your {container_registry.FLAVOR} "
                    f"container registry '{container_registry.name}' requires "
                    "authentication."
                )

            for component in stack.components.values():
                if component.local_path:
                    return False, (
                        "The GitHub Actions orchestrator runs pipelines on "
                        "remote GitHub Actions runners, but the "
                        f"'{component.name}' {component.TYPE.value} of your "
                        "active stack is a local component. Please make sure "
                        "to only use remote stack components in combination "
                        "with the GitHub Actions orchestrator. "
                    )

            return True, ""

        return StackValidator(
            required_components={StackComponentType.CONTAINER_REGISTRY},
            custom_validation_function=_validate_local_requirements,
        )

    def get_docker_image_name(self, pipeline_name: str) -> str:
        """Returns the full docker image name including registry and tag.

        Args:
            pipeline_name: Name of the pipeline for which to generate a docker
                image name.

        Returns:
            The docker image name.
        """
        container_registry = Repository().active_stack.container_registry
        assert container_registry  # should never happen due to validation
        return f"{container_registry.uri}/zenml-github-actions:{pipeline_name}"

    def _docker_login_step(
        self,
        container_registry: BaseContainerRegistry,
    ) -> Optional[Dict[str, Any]]:
        """GitHub Actions step for authenticating with the container registry.

        Only GitHub container registries with automatic token authentication
        currently get a login step.

        Args:
            container_registry: The container registry which (potentially)
                requires a step to authenticate.

        Returns:
            Dictionary specifying the GitHub Actions step for authenticating
            with the container registry if that is required, `None` otherwise.
        """
        if (
            isinstance(container_registry, GitHubContainerRegistry)
            and container_registry.automatic_token_authentication
        ):
            # Use GitHub Actions specific placeholder if the container registry
            # specifies automatic token authentication
            username = "${{ github.actor }}"
            password = "${{ secrets.GITHUB_TOKEN }}"
        # TODO: Uncomment these lines once we support different private
        #  container registries in GitHub Actions
        # elif container_registry.requires_authentication:
        #     username = cast(str, container_registry.username)
        #     password = cast(str, container_registry.password)
        else:
            return None

        return {
            "name": "Authenticate with the container registry",
            "uses": DOCKER_LOGIN_ACTION,
            "with": {
                "registry": container_registry.uri,
                "username": username,
                "password": password,
            },
        }

    def _write_environment_file_step(
        self,
        file_name: str,
        secrets_manager: Optional[BaseSecretsManager] = None,
    ) -> Optional[Dict[str, Any]]:
        """GitHub Actions step for writing secrets to an environment file.

        Args:
            file_name: Name of the environment file that should be written.
            secrets_manager: Secrets manager that will be used to read secrets
                during pipeline execution.

        Returns:
            Dictionary specifying the GitHub Actions step for writing the
            environment file, or `None` if the secrets manager is not a
            `GitHubSecretsManager`.
        """
        if not isinstance(secrets_manager, GitHubSecretsManager):
            return None

        # Always include the environment variable that specifies whether
        # we're running in a GitHub Action workflow so the secret manager knows
        # how to query secret values
        command = (
            f'echo {ENV_IN_GITHUB_ACTIONS}="${ENV_IN_GITHUB_ACTIONS}" '
            f"> {file_name}; "
        )

        # Write all ZenML secrets into the environment file. Explicitly writing
        # these `${{ secrets.<SECRET_NAME> }}` placeholders into the workflow
        # yaml is the only way for us to access the GitHub secrets in a GitHub
        # Actions workflow.
        append_secret_placeholder = (
            "echo {secret_name}=${{{{ secrets.{secret_name} }}}} >> {file}; "
        )
        for secret_name in secrets_manager.get_all_secret_keys(
            include_prefix=True
        ):
            command += append_secret_placeholder.format(
                secret_name=secret_name, file=file_name
            )

        return {
            "name": "Write environment file",
            "run": command,
        }

    def prepare_pipeline_deployment(
        self,
        pipeline: "BasePipeline",
        stack: "Stack",
        runtime_configuration: "RuntimeConfiguration",
    ) -> None:
        """Builds and uploads a docker image.

        Args:
            pipeline: The pipeline for which the image is built.
            stack: The stack on which the pipeline will be executed.
            runtime_configuration: Runtime configuration for the pipeline run.

        Raises:
            RuntimeError: If the orchestrator should only run in a clean git
                repository and the repository is dirty.
        """
        if not self.skip_dirty_repository_check and self.git_repo.is_dirty(
            untracked_files=True
        ):
            raise RuntimeError(
                "Trying to run a pipeline from within a dirty (=containing "
                "untracked/uncommitted files) git repository. "
                "If you want this orchestrator to skip the dirty repo check in "
                f"the future, run\n `zenml orchestrator update {self.name} "
                "--skip_dirty_repository_check=true`"
            )

        image_name = self.get_docker_image_name(pipeline.name)
        requirements = {*stack.requirements(), *pipeline.requirements}

        logger.debug(
            "Github actions docker image requirements: %s", requirements
        )

        docker_utils.build_docker_image(
            build_context_path=source_utils.get_source_root_path(),
            image_name=image_name,
            dockerignore_path=pipeline.dockerignore_file,
            requirements=requirements,
            base_image=self.custom_docker_base_image_name,
        )

        assert stack.container_registry  # should never happen due to validation
        stack.container_registry.push_image(image_name)

        # Store the docker image digest in the runtime configuration so it gets
        # tracked in the ZenStore
        image_digest = docker_utils.get_image_digest(image_name) or image_name
        runtime_configuration["docker_image"] = image_digest

    def prepare_or_run_pipeline(
        self,
        sorted_steps: List[BaseStep],
        pipeline: "BasePipeline",
        pb2_pipeline: Pb2Pipeline,
        stack: "Stack",
        runtime_configuration: "RuntimeConfiguration",
    ) -> None:
        """Writes a GitHub Action workflow yaml and optionally pushes it.

        Each step becomes one job of the workflow. Every job runs the pipeline
        docker image with the entrypoint arguments for its step and lists the
        jobs of its upstream steps under `needs`.

        Args:
            sorted_steps: List of sorted steps.
            pipeline: ZenML pipeline instance.
            pb2_pipeline: Protobuf pipeline instance.
            stack: The stack the pipeline was run on.
            runtime_configuration: The runtime configuration of the current
                run.

        Raises:
            ValueError: If a schedule without a cron expression or with an
                invalid cron expression is passed.
        """
        schedule = runtime_configuration.schedule

        workflow_name = pipeline.name
        if schedule:
            # Add a suffix to the workflow filename so we don't overwrite
            # scheduled pipeline by future schedules or single pipeline runs.
            datetime_string = datetime.now().strftime("%y_%m_%d_%H_%M_%S")
            workflow_name += f"-scheduled-{datetime_string}"

        workflow_path = os.path.join(
            self.workflow_directory,
            f"{workflow_name}.yaml",
        )

        # Store the encoded pb2 pipeline once as an environment variable.
        # We will replace the entrypoint argument later to reduce the size
        # of the workflow file.
        encoded_pb2_pipeline = string_utils.b64_encode(
            json_format.MessageToJson(pb2_pipeline)
        )
        workflow_dict: Dict[str, Any] = {
            "name": workflow_name,
            "env": {ENV_ENCODED_ZENML_PIPELINE: encoded_pb2_pipeline},
        }

        if schedule:
            if not schedule.cron_expression:
                raise ValueError(
                    "GitHub Action workflows can only be scheduled using cron "
                    "expressions and not using a periodic schedule. If you "
                    "want to schedule pipelines using this GitHub Action "
                    "orchestrator, please include a cron expression in your "
                    "schedule object. For more information on GitHub workflow "
                    "schedules check out https://docs.github.com/en/actions/using-workflows/events-that-trigger-workflows#schedule."
                )

            # GitHub workflows require a schedule interval of at least 5
            # minutes. Reject expressions that run every minute
            # (`* * * * *`) or every 1-4 minutes (`*/3 * * * *`): all stars,
            # except that the minute field may be `*/minute_interval`.
            if re.fullmatch(
                r"(\*|\*/[1-4])( \*){4,}", schedule.cron_expression
            ):
                raise ValueError(
                    "GitHub workflows requires a schedule interval of at "
                    "least 5 minutes which is incompatible with your cron "
                    f"expression '{schedule.cron_expression}'. An example of a "
                    "valid cron expression would be '0 * * * *' to run "
                    "every hour. For more information on GitHub workflow "
                    "schedules check out https://docs.github.com/en/actions/using-workflows/events-that-trigger-workflows#schedule."
                )

            logger.warning(
                "GitHub only runs scheduled workflows once the "
                "workflow file is merged to the default branch of the "
                "repository (https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/about-branches#about-the-default-branch). "
                "Please make sure to merge your current branch into the "
                "default branch for this scheduled pipeline to run."
            )
            workflow_dict["on"] = {
                "schedule": [{"cron": schedule.cron_expression}]
            }
        else:
            # The pipeline should only run once. The only fool-proof way to
            # only execute a workflow once seems to be running on specific tags.
            # We don't want to create tags for each pipeline run though, so
            # instead we only run this workflow if the workflow file is
            # modified. As long as users don't manually modify these files this
            # should be sufficient.
            workflow_path_in_repo = os.path.relpath(
                workflow_path, self.git_repo.working_dir
            )
            # GitHub path filters use forward slashes; normalize the
            # separators in case this runs on Windows.
            workflow_path_in_repo = workflow_path_in_repo.replace(os.sep, "/")
            workflow_dict["on"] = {"push": {"paths": [workflow_path_in_repo]}}

        image_name = self.get_docker_image_name(pipeline.name)
        image_name = docker_utils.get_image_digest(image_name) or image_name

        # Prepare the step that writes an environment file which will get
        # passed to the docker image
        env_file_name = ".zenml_docker_env"
        write_env_file_step = self._write_environment_file_step(
            file_name=env_file_name, secrets_manager=stack.secrets_manager
        )
        docker_run_args = (
            ["--env-file", env_file_name] if write_env_file_step else []
        )

        # Prepare the docker login step if necessary
        container_registry = stack.container_registry
        assert container_registry
        docker_login_step = self._docker_login_step(container_registry)

        # The base command that each job will execute with specific arguments
        base_command = [
            "docker",
            "run",
            *docker_run_args,
            image_name,
        ] + GitHubActionsEntrypointConfiguration.get_entrypoint_command()

        jobs: Dict[str, Dict[str, Any]] = {}
        for step in sorted_steps:
            if self.requires_resources_in_orchestration_environment(step):
                logger.warning(
                    "Specifying step resources is not supported for the "
                    "GitHub Actions orchestrator, ignoring resource "
                    "configuration for step %s.",
                    step.name,
                )

            job_steps = []

            # Copy the shared dicts here to avoid creating yaml anchors (which
            # are currently not supported in GitHub workflow yaml files)
            if write_env_file_step:
                job_steps.append(copy.deepcopy(write_env_file_step))

            if docker_login_step:
                job_steps.append(copy.deepcopy(docker_login_step))

            entrypoint_args = (
                GitHubActionsEntrypointConfiguration.get_entrypoint_arguments(
                    step=step,
                    pb2_pipeline=pb2_pipeline,
                )
            )

            # Replace the encoded string by a global environment variable to
            # keep the workflow file small
            index = entrypoint_args.index(f"--{PIPELINE_JSON_OPTION}")
            entrypoint_args[index + 1] = f"${ENV_ENCODED_ZENML_PIPELINE}"

            command = base_command + entrypoint_args
            docker_run_step = {
                "name": "Run the docker image",
                "run": " ".join(command),
            }

            job_steps.append(docker_run_step)
            job_dict = {
                "runs-on": "ubuntu-latest",
                "needs": self.get_upstream_step_names(
                    step=step, pb2_pipeline=pb2_pipeline
                ),
                "steps": job_steps,
            }
            jobs[step.name] = job_dict

        workflow_dict["jobs"] = jobs

        fileio.makedirs(self.workflow_directory)
        yaml_utils.write_yaml(workflow_path, workflow_dict, sort_keys=False)
        logger.info("Wrote GitHub workflow file to %s", workflow_path)

        if self.push:
            # Add, commit and push the pipeline workflow yaml
            self.git_repo.index.add(workflow_path)
            self.git_repo.index.commit(
                "[ZenML GitHub Actions Orchestrator] Add github workflow for "
                f"pipeline {pipeline.name}."
            )
            self.git_repo.remote().push()
            logger.info("Pushed workflow file '%s'", workflow_path)
        else:
            logger.info(
                "Automatically committing and pushing is disabled for this "
                "orchestrator. To run the pipeline, you'll have to commit and "
                "push the workflow file '%s' manually.\n"
                "If you want to update this orchestrator to automatically "
                "commit and push in the future, run "
                "`zenml orchestrator update %s --push=true`",
                workflow_path,
                self.name,
            )
git_repo: Repo property readonly

Returns the git repository for the current working directory.

Returns:

| Type | Description |
| --- | --- |
| `Repo` | Git repository for the current working directory. |

Exceptions:

| Type | Description |
| --- | --- |
| `RuntimeError` | If there is no git repository for the current working directory or the repository remote is not pointing to GitHub. |

validator: Optional[zenml.stack.stack_validator.StackValidator] property readonly

Validator that ensures that the stack is compatible.

Makes sure that the stack contains a container registry and only remote components.

Returns:

| Type | Description |
| --- | --- |
| `Optional[zenml.stack.stack_validator.StackValidator]` | The stack validator. |

workflow_directory: str property readonly

Returns path to the GitHub workflows directory.

Returns:

| Type | Description |
| --- | --- |
| `str` | The GitHub workflows directory. |

get_docker_image_name(self, pipeline_name)

Returns the full docker image name including registry and tag.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `pipeline_name` | `str` | Name of the pipeline for which to generate a docker image name. | *required* |

Returns:

| Type | Description |
| --- | --- |
| `str` | The docker image name. |

Source code in zenml/integrations/github/orchestrators/github_actions_orchestrator.py
def get_docker_image_name(self, pipeline_name: str) -> str:
    """Builds the full docker image name, including registry and tag.

    The image lives in the container registry of the active stack, in the
    `zenml-github-actions` repository, tagged with the pipeline name.

    Args:
        pipeline_name: Name of the pipeline for which to generate a docker
            image name.

    Returns:
        The docker image name.
    """
    registry = Repository().active_stack.container_registry
    # Stack validation requires a container registry, so this always holds.
    assert registry
    registry_uri = registry.uri
    return f"{registry_uri}/zenml-github-actions:{pipeline_name}"
prepare_or_run_pipeline(self, sorted_steps, pipeline, pb2_pipeline, stack, runtime_configuration)

Writes a GitHub Action workflow yaml and optionally pushes it.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `sorted_steps` | `List[zenml.steps.base_step.BaseStep]` | List of sorted steps. | *required* |
| `pipeline` | `BasePipeline` | ZenML pipeline instance. | *required* |
| `pb2_pipeline` | `Pipeline` | Protobuf pipeline instance. | *required* |
| `stack` | `Stack` | The stack the pipeline was run on. | *required* |
| `runtime_configuration` | `RuntimeConfiguration` | The runtime configuration of the current run. | *required* |

Exceptions:

| Type | Description |
| --- | --- |
| `ValueError` | If a schedule without a cron expression or with an invalid cron expression is passed. |

Source code in zenml/integrations/github/orchestrators/github_actions_orchestrator.py
def prepare_or_run_pipeline(
    self,
    sorted_steps: List[BaseStep],
    pipeline: "BasePipeline",
    pb2_pipeline: Pb2Pipeline,
    stack: "Stack",
    runtime_configuration: "RuntimeConfiguration",
) -> None:
    """Writes a GitHub Action workflow yaml and optionally pushes it.

    Each step becomes one job of the workflow. Every job runs the pipeline
    docker image with the entrypoint arguments for its step and lists the
    jobs of its upstream steps under `needs`.

    Args:
        sorted_steps: List of sorted steps.
        pipeline: ZenML pipeline instance.
        pb2_pipeline: Protobuf pipeline instance.
        stack: The stack the pipeline was run on.
        runtime_configuration: The runtime configuration of the current run.

    Raises:
        ValueError: If a schedule without a cron expression or with an
            invalid cron expression is passed.
    """
    schedule = runtime_configuration.schedule

    workflow_name = pipeline.name
    if schedule:
        # Add a suffix to the workflow filename so we don't overwrite
        # scheduled pipeline by future schedules or single pipeline runs.
        datetime_string = datetime.now().strftime("%y_%m_%d_%H_%M_%S")
        workflow_name += f"-scheduled-{datetime_string}"

    workflow_path = os.path.join(
        self.workflow_directory,
        f"{workflow_name}.yaml",
    )

    # Store the encoded pb2 pipeline once as an environment variable.
    # We will replace the entrypoint argument later to reduce the size
    # of the workflow file.
    encoded_pb2_pipeline = string_utils.b64_encode(
        json_format.MessageToJson(pb2_pipeline)
    )
    workflow_dict: Dict[str, Any] = {
        "name": workflow_name,
        "env": {ENV_ENCODED_ZENML_PIPELINE: encoded_pb2_pipeline},
    }

    if schedule:
        if not schedule.cron_expression:
            raise ValueError(
                "GitHub Action workflows can only be scheduled using cron "
                "expressions and not using a periodic schedule. If you "
                "want to schedule pipelines using this GitHub Action "
                "orchestrator, please include a cron expression in your "
                "schedule object. For more information on GitHub workflow "
                "schedules check out https://docs.github.com/en/actions/using-workflows/events-that-trigger-workflows#schedule."
            )

        # GitHub workflows require a schedule interval of at least 5
        # minutes. Reject expressions that run every minute (`* * * * *`)
        # or every 1-4 minutes (`*/3 * * * *`): all stars, except that the
        # minute field may be `*/minute_interval`.
        if re.fullmatch(r"(\*|\*/[1-4])( \*){4,}", schedule.cron_expression):
            raise ValueError(
                "GitHub workflows requires a schedule interval of at "
                "least 5 minutes which is incompatible with your cron "
                f"expression '{schedule.cron_expression}'. An example of a "
                "valid cron expression would be '0 * * * *' to run "
                "every hour. For more information on GitHub workflow "
                "schedules check out https://docs.github.com/en/actions/using-workflows/events-that-trigger-workflows#schedule."
            )

        logger.warning(
            "GitHub only runs scheduled workflows once the "
            "workflow file is merged to the default branch of the "
            "repository (https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/about-branches#about-the-default-branch). "
            "Please make sure to merge your current branch into the "
            "default branch for this scheduled pipeline to run."
        )
        workflow_dict["on"] = {
            "schedule": [{"cron": schedule.cron_expression}]
        }
    else:
        # The pipeline should only run once. The only fool-proof way to
        # only execute a workflow once seems to be running on specific tags.
        # We don't want to create tags for each pipeline run though, so
        # instead we only run this workflow if the workflow file is
        # modified. As long as users don't manually modify these files this
        # should be sufficient.
        workflow_path_in_repo = os.path.relpath(
            workflow_path, self.git_repo.working_dir
        )
        # GitHub path filters use forward slashes; normalize the separators
        # in case this runs on Windows.
        workflow_path_in_repo = workflow_path_in_repo.replace(os.sep, "/")
        workflow_dict["on"] = {"push": {"paths": [workflow_path_in_repo]}}

    image_name = self.get_docker_image_name(pipeline.name)
    image_name = docker_utils.get_image_digest(image_name) or image_name

    # Prepare the step that writes an environment file which will get
    # passed to the docker image
    env_file_name = ".zenml_docker_env"
    write_env_file_step = self._write_environment_file_step(
        file_name=env_file_name, secrets_manager=stack.secrets_manager
    )
    docker_run_args = (
        ["--env-file", env_file_name] if write_env_file_step else []
    )

    # Prepare the docker login step if necessary
    container_registry = stack.container_registry
    assert container_registry
    docker_login_step = self._docker_login_step(container_registry)

    # The base command that each job will execute with specific arguments
    base_command = [
        "docker",
        "run",
        *docker_run_args,
        image_name,
    ] + GitHubActionsEntrypointConfiguration.get_entrypoint_command()

    jobs: Dict[str, Dict[str, Any]] = {}
    for step in sorted_steps:
        if self.requires_resources_in_orchestration_environment(step):
            logger.warning(
                "Specifying step resources is not supported for the "
                "GitHub Actions orchestrator, ignoring resource "
                "configuration for step %s.",
                step.name,
            )

        job_steps = []

        # Copy the shared dicts here to avoid creating yaml anchors (which
        # are currently not supported in GitHub workflow yaml files)
        if write_env_file_step:
            job_steps.append(copy.deepcopy(write_env_file_step))

        if docker_login_step:
            job_steps.append(copy.deepcopy(docker_login_step))

        entrypoint_args = (
            GitHubActionsEntrypointConfiguration.get_entrypoint_arguments(
                step=step,
                pb2_pipeline=pb2_pipeline,
            )
        )

        # Replace the encoded string by a global environment variable to
        # keep the workflow file small
        index = entrypoint_args.index(f"--{PIPELINE_JSON_OPTION}")
        entrypoint_args[index + 1] = f"${ENV_ENCODED_ZENML_PIPELINE}"

        command = base_command + entrypoint_args
        docker_run_step = {
            "name": "Run the docker image",
            "run": " ".join(command),
        }

        job_steps.append(docker_run_step)
        job_dict = {
            "runs-on": "ubuntu-latest",
            "needs": self.get_upstream_step_names(
                step=step, pb2_pipeline=pb2_pipeline
            ),
            "steps": job_steps,
        }
        jobs[step.name] = job_dict

    workflow_dict["jobs"] = jobs

    fileio.makedirs(self.workflow_directory)
    yaml_utils.write_yaml(workflow_path, workflow_dict, sort_keys=False)
    logger.info("Wrote GitHub workflow file to %s", workflow_path)

    if self.push:
        # Add, commit and push the pipeline workflow yaml
        self.git_repo.index.add(workflow_path)
        self.git_repo.index.commit(
            "[ZenML GitHub Actions Orchestrator] Add github workflow for "
            f"pipeline {pipeline.name}."
        )
        self.git_repo.remote().push()
        logger.info("Pushed workflow file '%s'", workflow_path)
    else:
        logger.info(
            "Automatically committing and pushing is disabled for this "
            "orchestrator. To run the pipeline, you'll have to commit and "
            "push the workflow file '%s' manually.\n"
            "If you want to update this orchestrator to automatically "
            "commit and push in the future, run "
            "`zenml orchestrator update %s --push=true`",
            workflow_path,
            self.name,
        )
prepare_pipeline_deployment(self, pipeline, stack, runtime_configuration)

Builds and uploads a docker image.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `pipeline` | `BasePipeline` | The pipeline for which the image is built. | *required* |
| `stack` | `Stack` | The stack on which the pipeline will be executed. | *required* |
| `runtime_configuration` | `RuntimeConfiguration` | Runtime configuration for the pipeline run. | *required* |

Exceptions:

| Type | Description |
| --- | --- |
| `RuntimeError` | If the orchestrator should only run in a clean git repository and the repository is dirty. |

Source code in zenml/integrations/github/orchestrators/github_actions_orchestrator.py
def prepare_pipeline_deployment(
    self,
    pipeline: "BasePipeline",
    stack: "Stack",
    runtime_configuration: "RuntimeConfiguration",
) -> None:
    """Builds and uploads a docker image.

    Before building, the git repository is checked for uncommitted or
    untracked files (unless `skip_dirty_repository_check` is set). The image
    is built from the source root with the stack and pipeline requirements,
    pushed to the stack's container registry, and its digest (or name, if no
    digest is available) is stored in the runtime configuration.

    Args:
        pipeline: The pipeline for which the image is built.
        stack: The stack on which the pipeline will be executed.
        runtime_configuration: Runtime configuration for the pipeline run.

    Raises:
        RuntimeError: If the orchestrator should only run in a clean git
            repository and the repository is dirty.
    """
    if not self.skip_dirty_repository_check and self.git_repo.is_dirty(
        untracked_files=True
    ):
        raise RuntimeError(
            "Trying to run a pipeline from within a dirty (=containing "
            "untracked/uncommitted files) git repository. "
            "If you want this orchestrator to skip the dirty repo check in "
            f"the future, run\n `zenml orchestrator update {self.name} "
            "--skip_dirty_repository_check=true`"
        )

    image_name = self.get_docker_image_name(pipeline.name)
    # Union of stack-level and pipeline-level requirements (deduplicated).
    requirements = {*stack.requirements(), *pipeline.requirements}

    logger.debug(
        "Github actions docker image requirements: %s", requirements
    )

    docker_utils.build_docker_image(
        build_context_path=source_utils.get_source_root_path(),
        image_name=image_name,
        dockerignore_path=pipeline.dockerignore_file,
        requirements=requirements,
        base_image=self.custom_docker_base_image_name,
    )

    assert stack.container_registry  # should never happen due to validation
    stack.container_registry.push_image(image_name)

    # Store the docker image digest in the runtime configuration so it gets
    # tracked in the ZenStore; fall back to the image name if no digest.
    image_digest = docker_utils.get_image_digest(image_name) or image_name
    runtime_configuration["docker_image"] = image_digest

secrets_managers special

Initialization of the GitHub Secrets Manager.

github_secrets_manager

Implementation of the GitHub Secrets Manager.

GitHubSecretsManager (BaseSecretsManager) pydantic-model

Class to interact with the GitHub secrets manager.

Attributes:

Name Type Description
owner str

The owner (either individual or organization) of the repository.

repository str

Name of the GitHub repository.

Source code in zenml/integrations/github/secrets_managers/github_secrets_manager.py
class GitHubSecretsManager(BaseSecretsManager):
    """Class to interact with the GitHub secrets manager.

    Attributes:
        owner: The owner (either individual or organization) of the repository.
        repository: Name of the GitHub repository.
    """

    owner: str
    repository: str

    _session: Optional[requests.Session] = None

    # Class configuration
    FLAVOR: ClassVar[str] = GITHUB_SECRET_MANAGER_FLAVOR

    @property
    def post_registration_message(self) -> Optional[str]:
        """Info message regarding GitHub API authentication env variables.

        Returns:
            The info message.
        """
        return AUTHENTICATION_CREDENTIALS_MESSAGE

    @property
    def session(self) -> requests.Session:
        """Session to send requests to the GitHub API.

        The session is created lazily on first access and cached afterwards.

        Returns:
            Session to use for GitHub API calls.

        Raises:
            RuntimeError: If authentication credentials for the GitHub API are
                not set.
        """
        if not self._session:
            session = requests.Session()
            github_username = os.getenv(ENV_GITHUB_USERNAME)
            authentication_token = os.getenv(ENV_GITHUB_AUTHENTICATION_TOKEN)

            if not github_username or not authentication_token:
                raise RuntimeError(
                    "Missing authentication credentials for GitHub secrets "
                    "manager. " + AUTHENTICATION_CREDENTIALS_MESSAGE
                )

            session.auth = HTTPBasicAuth(github_username, authentication_token)
            session.headers["Accept"] = "application/vnd.github.v3+json"
            self._session = session

        return self._session

    def _send_request(
        self, method: str, resource: Optional[str] = None, **kwargs: Any
    ) -> requests.Response:
        """Sends an HTTP request to the GitHub API.

        Args:
            method: Method of the HTTP request that should be sent.
            resource: Optional resource to which the request should be sent. If
                none is given, the default GitHub API secrets endpoint will be
                used.
            **kwargs: Will be passed to the `requests` library.

        Returns:
            HTTP response.

        # noqa: DAR402

        Raises:
            HTTPError: If the request failed due to a client or server error.
        """
        url = (
            f"https://api.github.com/repos/{self.owner}/{self.repository}"
            f"/actions/secrets"
        )
        if resource:
            url += resource

        response = self.session.request(method=method, url=url, **kwargs)
        # Raise an exception in case of a client or server error
        response.raise_for_status()

        return response

    def _encrypt_secret(self, secret_value: str) -> Tuple[str, str]:
        """Encrypts a secret value.

        This method first fetches a public key from the GitHub API and then uses
        this key to encrypt the secret value. This is needed in order to
        register GitHub secrets using the API.

        Args:
            secret_value: Secret value to encrypt.

        Returns:
            The encrypted secret value and the key id of the GitHub public key.
        """
        from nacl.encoding import Base64Encoder
        from nacl.public import PublicKey, SealedBox

        response_json = self._send_request("GET", resource="/public-key").json()
        public_key = PublicKey(
            response_json["key"].encode("utf-8"), Base64Encoder
        )
        sealed_box = SealedBox(public_key)
        encrypted_bytes = sealed_box.encrypt(secret_value.encode("utf-8"))
        encrypted_string = base64.b64encode(encrypted_bytes).decode("utf-8")
        return encrypted_string, cast(str, response_json["key_id"])

    def _has_secret(self, secret_name: str) -> bool:
        """Checks whether a secret exists for the given name.

        Args:
            secret_name: Name of the secret which should be checked.

        Returns:
            `True` if a secret with the given name exists, `False` otherwise.
        """
        secret_name = _convert_secret_name(secret_name, remove_prefix=True)
        return secret_name in self.get_all_secret_keys(include_prefix=False)

    def get_secret(self, secret_name: str) -> BaseSecretSchema:
        """Gets the value of a secret.

        This method only works when called from within a GitHub Actions
        environment.

        Args:
            secret_name: The name of the secret to get.

        Returns:
            The secret.

        Raises:
            KeyError: If a secret with this name doesn't exist.
            RuntimeError: If not inside a GitHub Actions environment.
        """
        full_secret_name = _convert_secret_name(secret_name, add_prefix=True)
        # Raise a KeyError if the secret doesn't exist. We can do that even
        # if we're not inside a GitHub Actions environment
        if not self._has_secret(secret_name):
            raise KeyError(
                f"Unable to find secret '{secret_name}'. Please check the "
                "GitHub UI to see if a **Repository** secret called "
                f"'{full_secret_name}' exists. (ZenML uses the "
                f"'{GITHUB_SECRET_PREFIX}' to differentiate ZenML "
                "secrets from other GitHub secrets)"
            )

        if not inside_github_action_environment():
            stack_name = Repository().active_stack_name
            commands = [
                f"zenml stack copy {stack_name} <NEW_STACK_NAME>",
                "zenml secrets_manager register <NEW_SECRETS_MANAGER_NAME> "
                "--flavor=local",
                "zenml stack update <NEW_STACK_NAME> "
                "--secrets_manager=<NEW_SECRETS_MANAGER_NAME>",
                "zenml stack set <NEW_STACK_NAME>",
                f"zenml secret register {secret_name} ...",
            ]

            raise RuntimeError(
                "Getting GitHub secrets is only possible within a GitHub "
                "Actions workflow. If you need this secret to access "
                "stack components (e.g. your metadata store to fetch pipelines "
                "during the post-execution workflow) locally, you need to "
                "register this secret in a different secrets manager. "
                "You can do this by running the following commands: \n\n"
                + "\n".join(commands)
            )

        # If we're running inside a GitHub Actions environment using a
        # workflow generated by the GitHub Actions orchestrator, all ZenML
        # secrets stored in the GitHub secrets manager will be accessible as
        # environment variables
        secret_value = cast(str, os.getenv(full_secret_name))

        secret_dict = json.loads(string_utils.b64_decode(secret_value))
        schema_class = SecretSchemaClassRegistry.get_class(
            secret_schema=secret_dict[SECRET_SCHEMA_DICT_KEY]
        )
        secret_content = secret_dict[SECRET_CONTENT_DICT_KEY]

        return schema_class(name=secret_name, **secret_content)

    def get_all_secret_keys(self, include_prefix: bool = False) -> List[str]:
        """Get all secret keys.

        If we're running inside a GitHub Actions environment, this will return
        the names of all environment variables starting with a ZenML internal
        prefix. Otherwise, this will return all GitHub **Repository** secrets
        created by ZenML.

        Args:
            include_prefix: Whether or not the internal prefix that is used to
                differentiate ZenML secrets from other GitHub secrets should be
                included in the returned names.

        Returns:
            List of all secret keys.
        """
        if inside_github_action_environment():
            potential_secret_keys = list(os.environ)
        else:
            logger.info(
                "Fetching list of secrets for repository %s/%s",
                self.owner,
                self.repository,
            )
            potential_secret_keys = []
            page = 1
            while True:
                # The GitHub API returns at most 100 secrets per request, so
                # keep requesting pages until every secret reported by
                # `total_count` has been collected (or a page comes back
                # empty). Without this, secrets beyond the first 100 would be
                # invisible to ZenML.
                response_json = self._send_request(
                    "GET", params={"per_page": 100, "page": page}
                ).json()
                secrets = response_json["secrets"]
                potential_secret_keys.extend(
                    secret_dict["name"] for secret_dict in secrets
                )
                total_count = response_json.get(
                    "total_count", len(potential_secret_keys)
                )
                if not secrets or len(potential_secret_keys) >= total_count:
                    break
                page += 1

        keys = [
            _convert_secret_name(key, remove_prefix=not include_prefix)
            for key in potential_secret_keys
            if key.startswith(GITHUB_SECRET_PREFIX)
        ]

        return keys

    def register_secret(self, secret: BaseSecretSchema) -> None:
        """Registers a new secret.

        Args:
            secret: The secret to register.

        Raises:
            SecretExistsError: If a secret with this name already exists.
        """
        if self._has_secret(secret.name):
            raise SecretExistsError(
                f"A secret with name '{secret.name}' already exists for this "
                "GitHub repository. If you want to register a new value for "
                f"this secret, please run `zenml secret delete {secret.name}` "
                f"followed by `zenml secret register {secret.name} ...`."
            )

        secret_dict = {
            SECRET_SCHEMA_DICT_KEY: secret.TYPE,
            SECRET_CONTENT_DICT_KEY: secret.content,
        }
        secret_value = string_utils.b64_encode(json.dumps(secret_dict))
        encrypted_secret, public_key_id = self._encrypt_secret(
            secret_value=secret_value
        )
        body = {
            "encrypted_value": encrypted_secret,
            "key_id": public_key_id,
        }

        full_secret_name = _convert_secret_name(secret.name, add_prefix=True)
        self._send_request("PUT", resource=f"/{full_secret_name}", json=body)

    def update_secret(self, secret: BaseSecretSchema) -> NoReturn:
        """Update an existing secret.

        Args:
            secret: The secret to update.

        Raises:
            NotImplementedError: Always, as this functionality is not possible
                using GitHub secrets, which don't allow us to retrieve the
                secret values outside of a GitHub Actions environment.
        """
        raise NotImplementedError(
            "Updating secrets is not possible with the GitHub secrets manager "
            "as it is not possible to retrieve GitHub secrets values outside "
            "of a GitHub Actions environment."
        )

    def delete_secret(self, secret_name: str) -> None:
        """Delete an existing secret.

        Args:
            secret_name: The name of the secret to delete.
        """
        full_secret_name = _convert_secret_name(secret_name, add_prefix=True)
        self._send_request("DELETE", resource=f"/{full_secret_name}")

    def delete_all_secrets(self) -> None:
        """Delete all existing secrets."""
        for secret_name in self.get_all_secret_keys(include_prefix=False):
            self.delete_secret(secret_name=secret_name)
post_registration_message: Optional[str] property readonly

Info message regarding GitHub API authentication env variables.

Returns:

Type Description
Optional[str]

The info message.

session: Session property readonly

Session to send requests to the GitHub API.

Returns:

Type Description
Session

Session to use for GitHub API calls.

Exceptions:

Type Description
RuntimeError

If authentication credentials for the GitHub API are not set.

delete_all_secrets(self)

Delete all existing secrets.

Source code in zenml/integrations/github/secrets_managers/github_secrets_manager.py
def delete_all_secrets(self) -> None:
    """Remove every ZenML-managed secret from the GitHub repository.

    The un-prefixed name of each secret returned by `get_all_secret_keys` is
    handed to `delete_secret`, one call per secret.
    """
    secret_names = self.get_all_secret_keys(include_prefix=False)
    for name in secret_names:
        self.delete_secret(secret_name=name)
delete_secret(self, secret_name)

Delete an existing secret.

Parameters:

Name Type Description Default
secret_name str

The name of the secret to delete.

required
Source code in zenml/integrations/github/secrets_managers/github_secrets_manager.py
def delete_secret(self, secret_name: str) -> None:
    """Remove a single ZenML secret from the GitHub repository.

    The ZenML prefix is added to the name and a DELETE request is sent for
    the corresponding secret resource.

    Args:
        secret_name: The name of the secret to delete.
    """
    resource_path = "/" + _convert_secret_name(secret_name, add_prefix=True)
    self._send_request("DELETE", resource=resource_path)
get_all_secret_keys(self, include_prefix=False)

Get all secret keys.

If we're running inside a GitHub Actions environment, this will return the names of all environment variables starting with a ZenML internal prefix. Otherwise, this will return all GitHub Repository secrets created by ZenML.

Parameters:

Name Type Description Default
include_prefix bool

Whether or not the internal prefix that is used to differentiate ZenML secrets from other GitHub secrets should be included in the returned names.

False

Returns:

Type Description
List[str]

List of all secret keys.

Source code in zenml/integrations/github/secrets_managers/github_secrets_manager.py
def get_all_secret_keys(self, include_prefix: bool = False) -> List[str]:
    """Get all secret keys.

    If we're running inside a GitHub Actions environment, this will return
    the names of all environment variables starting with a ZenML internal
    prefix. Otherwise, this will return all GitHub **Repository** secrets
    created by ZenML.

    Args:
        include_prefix: Whether or not the internal prefix that is used to
            differentiate ZenML secrets from other GitHub secrets should be
            included in the returned names.

    Returns:
        List of all secret keys.
    """
    if inside_github_action_environment():
        potential_secret_keys = list(os.environ)
    else:
        logger.info(
            "Fetching list of secrets for repository %s/%s",
            self.owner,
            self.repository,
        )
        potential_secret_keys = []
        page = 1
        while True:
            # The GitHub API returns at most 100 secrets per request, so keep
            # requesting pages until every secret reported by `total_count`
            # has been collected (or a page comes back empty). Without this,
            # secrets beyond the first 100 would be invisible to ZenML.
            response_json = self._send_request(
                "GET", params={"per_page": 100, "page": page}
            ).json()
            secrets = response_json["secrets"]
            potential_secret_keys.extend(
                secret_dict["name"] for secret_dict in secrets
            )
            total_count = response_json.get(
                "total_count", len(potential_secret_keys)
            )
            if not secrets or len(potential_secret_keys) >= total_count:
                break
            page += 1

    keys = [
        _convert_secret_name(key, remove_prefix=not include_prefix)
        for key in potential_secret_keys
        if key.startswith(GITHUB_SECRET_PREFIX)
    ]

    return keys
get_secret(self, secret_name)

Gets the value of a secret.

This method only works when called from within a GitHub Actions environment.

Parameters:

Name Type Description Default
secret_name str

The name of the secret to get.

required

Returns:

Type Description
BaseSecretSchema

The secret.

Exceptions:

Type Description
KeyError

If a secret with this name doesn't exist.

RuntimeError

If not inside a GitHub Actions environment.

Source code in zenml/integrations/github/secrets_managers/github_secrets_manager.py
def get_secret(self, secret_name: str) -> BaseSecretSchema:
    """Load a ZenML secret that is stored as a GitHub repository secret.

    The secret value is read from an environment variable, so this method
    only succeeds when called from within a GitHub Actions environment.

    Args:
        secret_name: The name of the secret to get.

    Returns:
        The secret, instantiated with the schema class recorded when it was
        registered.

    Raises:
        KeyError: If a secret with this name doesn't exist.
        RuntimeError: If not inside a GitHub Actions environment.
    """
    prefixed_name = _convert_secret_name(secret_name, add_prefix=True)

    # Existence can be checked both locally (through the API) and inside
    # GitHub Actions (through environment variables), so a missing secret is
    # reported before the environment check.
    if not self._has_secret(secret_name):
        raise KeyError(
            f"Unable to find secret '{secret_name}'. Please check the "
            "GitHub UI to see if a **Repository** secret called "
            f"'{prefixed_name}' exists. (ZenML uses the "
            f"'{GITHUB_SECRET_PREFIX}' to differentiate ZenML "
            "secrets from other GitHub secrets)"
        )

    if inside_github_action_environment():
        # Workflows generated by the GitHub Actions orchestrator expose every
        # ZenML secret stored in this secrets manager as an environment
        # variable named after the prefixed secret.
        encoded_value = cast(str, os.getenv(prefixed_name))
        payload = json.loads(string_utils.b64_decode(encoded_value))
        schema_class = SecretSchemaClassRegistry.get_class(
            secret_schema=payload[SECRET_SCHEMA_DICT_KEY]
        )
        return schema_class(
            name=secret_name, **payload[SECRET_CONTENT_DICT_KEY]
        )

    active_stack = Repository().active_stack_name
    suggested_commands = "\n".join(
        (
            f"zenml stack copy {active_stack} <NEW_STACK_NAME>",
            "zenml secrets_manager register <NEW_SECRETS_MANAGER_NAME> "
            "--flavor=local",
            "zenml stack update <NEW_STACK_NAME> "
            "--secrets_manager=<NEW_SECRETS_MANAGER_NAME>",
            "zenml stack set <NEW_STACK_NAME>",
            f"zenml secret register {secret_name} ...",
        )
    )
    raise RuntimeError(
        "Getting GitHub secrets is only possible within a GitHub "
        "Actions workflow. If you need this secret to access "
        "stack components (e.g. your metadata store to fetch pipelines "
        "during the post-execution workflow) locally, you need to "
        "register this secret in a different secrets manager. "
        "You can do this by running the following commands: \n\n"
        + suggested_commands
    )
register_secret(self, secret)

Registers a new secret.

Parameters:

Name Type Description Default
secret BaseSecretSchema

The secret to register.

required

Exceptions:

Type Description
SecretExistsError

If a secret with this name already exists.

Source code in zenml/integrations/github/secrets_managers/github_secrets_manager.py
def register_secret(self, secret: BaseSecretSchema) -> None:
    """Store a new ZenML secret as an encrypted GitHub repository secret.

    The schema type and content are JSON-serialized and base64-encoded, then
    encrypted via `_encrypt_secret` and uploaded with a PUT request.

    Args:
        secret: The secret to register.

    Raises:
        SecretExistsError: If a secret with this name already exists.
    """
    if self._has_secret(secret.name):
        raise SecretExistsError(
            f"A secret with name '{secret.name}' already exists for this "
            "GitHub repository. If you want to register a new value for "
            f"this secret, please run `zenml secret delete {secret.name}` "
            f"followed by `zenml secret register {secret.name} ...`."
        )

    serialized = json.dumps(
        {
            SECRET_SCHEMA_DICT_KEY: secret.TYPE,
            SECRET_CONTENT_DICT_KEY: secret.content,
        }
    )
    encrypted_value, key_id = self._encrypt_secret(
        secret_value=string_utils.b64_encode(serialized)
    )

    prefixed_name = _convert_secret_name(secret.name, add_prefix=True)
    self._send_request(
        "PUT",
        resource=f"/{prefixed_name}",
        json={"encrypted_value": encrypted_value, "key_id": key_id},
    )
update_secret(self, secret)

Update an existing secret.

Parameters:

Name Type Description Default
secret BaseSecretSchema

The secret to update.

required

Exceptions:

Type Description
NotImplementedError

Always, as this functionality is not possible using GitHub secrets, which don't allow us to retrieve the secret values outside of a GitHub Actions environment.

Source code in zenml/integrations/github/secrets_managers/github_secrets_manager.py
def update_secret(self, secret: BaseSecretSchema) -> NoReturn:
    """Reject secret updates, which this secrets manager cannot perform.

    GitHub secret values cannot be retrieved outside of a GitHub Actions
    environment, so an existing secret can never be read back and modified.

    Args:
        secret: The secret to update (unused).

    Raises:
        NotImplementedError: Always.
    """
    message = (
        "Updating secrets is not possible with the GitHub secrets manager "
        "as it is not possible to retrieve GitHub secrets values outside "
        "of a GitHub Actions environment."
    )
    raise NotImplementedError(message)
inside_github_action_environment()

Returns if the current code is executing in a GitHub Actions environment.

Returns:

Type Description
bool

True if running in a GitHub Actions environment, False otherwise.

Source code in zenml/integrations/github/secrets_managers/github_secrets_manager.py
def inside_github_action_environment() -> bool:
    """Detect whether the current process runs inside GitHub Actions.

    Returns:
        `True` if the environment variable named by `ENV_IN_GITHUB_ACTIONS`
        holds exactly the string "true", `False` otherwise (including when
        the variable is unset).
    """
    marker = os.getenv(ENV_IN_GITHUB_ACTIONS)
    return marker == "true"

graphviz special

Initialization of the Graphviz integration.

GraphvizIntegration (Integration)

Definition of Graphviz integration for ZenML.

Source code in zenml/integrations/graphviz/__init__.py
class GraphvizIntegration(Integration):
    """ZenML integration for rendering pipeline graphs with Graphviz.

    Both the `graphviz` Python package and a system-level Graphviz
    installation are declared as requirements.
    """

    NAME = GRAPHVIZ
    # Python package requirement.
    REQUIREMENTS = [
        "graphviz>=0.17",
    ]
    # System requirement: presumably maps the system package name to the
    # executable (`dot`) used to check for it — verify against Integration.
    SYSTEM_REQUIREMENTS = {
        "graphviz": "dot",
    }

visualizers special

Initialization of Graphviz visualizers.

pipeline_run_dag_visualizer

Implementation of the Graphviz pipeline run DAG visualizer.

PipelineRunDagVisualizer (BasePipelineRunVisualizer)

Visualize the lineage of runs in a pipeline.

Source code in zenml/integrations/graphviz/visualizers/pipeline_run_dag_visualizer.py
class PipelineRunDagVisualizer(BasePipelineRunVisualizer):
    """Visualize the DAG of steps and artifacts of a pipeline run."""

    ARTIFACT_DEFAULT_COLOR = "blue"
    ARTIFACT_CACHED_COLOR = "green"
    ARTIFACT_SHAPE = "box"
    ARTIFACT_PREFIX = "artifact_"
    STEP_COLOR = "#431D93"
    STEP_SHAPE = "ellipse"
    STEP_PREFIX = "step_"
    FONT = "Roboto"

    # NOTE: this is a concrete implementation, so it must not be decorated
    # with `@abstractmethod` (that would make the class uninstantiable if the
    # base class uses ABCMeta).
    def visualize(
        self, object: PipelineRunView, *args: Any, **kwargs: Any
    ) -> graphviz.Digraph:
        """Creates a pipeline lineage diagram using graphviz.

        Each step becomes a node, and so does each step output artifact.
        Edges run from a step to the artifacts it produces and from each input
        artifact to the step that consumes it. The diagram is rendered to a
        PNG file in the system temp directory and opened in the default
        viewer.

        Args:
            object: The pipeline run view to visualize.
            *args: Additional arguments to pass to the visualization.
            **kwargs: Additional keyword arguments to pass to the visualization.

        Returns:
            A graphviz digraph object.
        """
        logger.warning(
            "This integration is not completed yet. Results might be unexpected."
        )

        dot = graphviz.Digraph(comment=object.name)

        for step in object.steps:
            step_node = self.STEP_PREFIX + str(step.id)
            # Add each step as a node.
            dot.node(step_node, step.entrypoint_name, shape=self.STEP_SHAPE)

            # Add a node per output artifact and link the producing step to it.
            for artifact_name, artifact in step.outputs.items():
                artifact_node = self.ARTIFACT_PREFIX + str(artifact.id)
                dot.node(
                    artifact_node,
                    f"{artifact_name} \n({artifact._data_type})",
                    shape=self.ARTIFACT_SHAPE,
                )
                dot.edge(step_node, artifact_node)

            # Link every input artifact to the step consuming it.
            for artifact in step.inputs.values():
                dot.edge(self.ARTIFACT_PREFIX + str(artifact.id), step_node)

        # Reserve a unique temporary path, then render only after the handle
        # is closed: writing to a file that is still open fails on Windows.
        # graphviz writes the DOT source to `source_path`, the image to
        # `source_path + ".png"`, and removes the source (cleanup=True).
        with tempfile.NamedTemporaryFile(delete=False, suffix=".gv") as f:
            source_path = f.name
        dot.render(filename=source_path, format="png", view=True, cleanup=True)
        return dot
visualize(self, object, *args, **kwargs)

Creates a pipeline lineage diagram using graphviz.

Parameters:

Name Type Description Default
object PipelineRunView

The pipeline run view to visualize.

required
*args Any

Additional arguments to pass to the visualization.

()
**kwargs Any

Additional keyword arguments to pass to the visualization.

{}

Returns:

Type Description
Digraph

A graphviz digraph object.

Source code in zenml/integrations/graphviz/visualizers/pipeline_run_dag_visualizer.py
def visualize(
    self, object: PipelineRunView, *args: Any, **kwargs: Any
) -> graphviz.Digraph:
    """Creates a pipeline lineage diagram using graphviz.

    Each step becomes a node, and so does each step output artifact. Edges
    run from a step to the artifacts it produces and from each input artifact
    to the step that consumes it. The diagram is rendered to a PNG file in the
    system temp directory and opened in the default viewer.

    This is a concrete implementation, so it is intentionally not decorated
    with `@abstractmethod`.

    Args:
        object: The pipeline run view to visualize.
        *args: Additional arguments to pass to the visualization.
        **kwargs: Additional keyword arguments to pass to the visualization.

    Returns:
        A graphviz digraph object.
    """
    logger.warning(
        "This integration is not completed yet. Results might be unexpected."
    )

    dot = graphviz.Digraph(comment=object.name)

    for step in object.steps:
        step_node = self.STEP_PREFIX + str(step.id)
        # Add each step as a node.
        dot.node(step_node, step.entrypoint_name, shape=self.STEP_SHAPE)

        # Add a node per output artifact and link the producing step to it.
        for artifact_name, artifact in step.outputs.items():
            artifact_node = self.ARTIFACT_PREFIX + str(artifact.id)
            dot.node(
                artifact_node,
                f"{artifact_name} \n({artifact._data_type})",
                shape=self.ARTIFACT_SHAPE,
            )
            dot.edge(step_node, artifact_node)

        # Link every input artifact to the step consuming it.
        for artifact in step.inputs.values():
            dot.edge(self.ARTIFACT_PREFIX + str(artifact.id), step_node)

    # Reserve a unique temporary path, then render only after the handle is
    # closed: writing to a file that is still open fails on Windows.
    # graphviz writes the DOT source to `source_path`, the image to
    # `source_path + ".png"`, and removes the source (cleanup=True).
    with tempfile.NamedTemporaryFile(delete=False, suffix=".gv") as f:
        source_path = f.name
    dot.render(filename=source_path, format="png", view=True, cleanup=True)
    return dot

great_expectations special

Great Expectations integration for ZenML.

The Great Expectations integration enables you to use Great Expectations as a way of profiling and validating your data.

GreatExpectationsIntegration (Integration)

Definition of Great Expectations integration for ZenML.

Source code in zenml/integrations/great_expectations/__init__.py
class GreatExpectationsIntegration(Integration):
    """ZenML integration for profiling and validating data with Great Expectations."""

    NAME = GREAT_EXPECTATIONS
    REQUIREMENTS = ["great-expectations~=0.15.11"]

    @staticmethod
    def activate() -> None:
        """Import the Great Expectations materializers module.

        The import is done for its side effects only (presumably registering
        the materializers with ZenML — verify in the materializers package).
        """
        import zenml.integrations.great_expectations.materializers  # noqa: F401

    @classmethod
    def flavors(cls) -> List[FlavorWrapper]:
        """List the stack component flavors shipped with this integration.

        Returns:
            A single-element list holding the Great Expectations data
            validator flavor.
        """
        data_validator_flavor = FlavorWrapper(
            name=GREAT_EXPECTATIONS_DATA_VALIDATOR_FLAVOR,
            source="zenml.integrations.great_expectations.data_validators.GreatExpectationsDataValidator",
            type=StackComponentType.DATA_VALIDATOR,
            integration=cls.NAME,
        )
        return [data_validator_flavor]
activate() staticmethod

Activate the Great Expectations integration.

Source code in zenml/integrations/great_expectations/__init__.py
@staticmethod
def activate() -> None:
    """Import the Great Expectations materializers module.

    The import is done for its side effects only (presumably registering the
    materializers with ZenML — verify in the materializers package).
    """
    import zenml.integrations.great_expectations.materializers  # noqa: F401
flavors() classmethod

Declare the stack component flavors for the Great Expectations integration.

Returns:

Type Description
List[zenml.zen_stores.models.flavor_wrapper.FlavorWrapper]

List of stack component flavors for this integration.

Source code in zenml/integrations/great_expectations/__init__.py
@classmethod
def flavors(cls) -> List[FlavorWrapper]:
    """List the stack component flavors shipped with this integration.

    Returns:
        A single-element list holding the Great Expectations data validator
        flavor.
    """
    data_validator_flavor = FlavorWrapper(
        name=GREAT_EXPECTATIONS_DATA_VALIDATOR_FLAVOR,
        source="zenml.integrations.great_expectations.data_validators.GreatExpectationsDataValidator",
        type=StackComponentType.DATA_VALIDATOR,
        integration=cls.NAME,
    )
    return [data_validator_flavor]

data_validators special

Initialization of the Great Expectations data validator for ZenML.

ge_data_validator

Implementation of the Great Expectations data validator.

GreatExpectationsDataValidator (BaseDataValidator) pydantic-model

Great Expectations data validator stack component.

Attributes:

Name Type Description
context_root_dir Optional[str]

location of an already initialized Great Expectations data context. If configured, the data validator will only be usable with local orchestrators.

context_config Optional[Dict[str, Any]]

in-line Great Expectations data context configuration.

configure_zenml_stores bool

if set, ZenML will automatically configure stores that use the Artifact Store as a backend. If neither context_root_dir nor context_config are set, this is the default behavior.

configure_local_docs bool

configure a local data docs site where Great Expectations docs are generated and can be visualized locally.

Source code in zenml/integrations/great_expectations/data_validators/ge_data_validator.py
class GreatExpectationsDataValidator(BaseDataValidator):
    """Great Expectations data validator stack component.

    Attributes:
        context_root_dir: location of an already initialized Great Expectations
            data context. If configured, the data validator will only be usable
            with local orchestrators.
        context_config: in-line Great Expectations data context configuration.
        configure_zenml_stores: if set, ZenML will automatically configure
            stores that use the Artifact Store as a backend. If neither
            `context_root_dir` nor `context_config` are set, this is the default
            behavior.
        configure_local_docs: configure a local data docs site where Great
            Expectations docs are generated and can be visualized locally.
    """

    context_root_dir: Optional[str] = None
    context_config: Optional[Dict[str, Any]] = None
    configure_zenml_stores: bool = False
    configure_local_docs: bool = True
    # Data context built lazily and cached by the `data_context` property.
    _context: BaseDataContext = None

    # Class Configuration
    FLAVOR: ClassVar[str] = GREAT_EXPECTATIONS_DATA_VALIDATOR_FLAVOR

    @validator("context_root_dir")
    def _ensure_valid_context_root_dir(
        cls, context_root_dir: Optional[str] = None
    ) -> Optional[str]:
        """Ensures that the root directory is an absolute path and points to an existing path.

        Args:
            context_root_dir: The context_root_dir value to validate.

        Returns:
            The context_root_dir (converted to an absolute path) if it is
            valid, or the unchanged falsy value if none was configured.

        Raises:
            ValueError: If the context_root_dir is not valid.
        """
        if context_root_dir:
            context_root_dir = os.path.abspath(context_root_dir)
            if not fileio.exists(context_root_dir):
                raise ValueError(
                    f"The Great Expectations context_root_dir value doesn't "
                    f"point to an existing data context path: {context_root_dir}"
                )
        return context_root_dir

    @root_validator(pre=True)
    def _convert_context_config(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Converts context_config from JSON/YAML string format to a dict.

        Args:
            values: Values passed to the object constructor

        Returns:
            Values passed to the object constructor, with a string
            `context_config` replaced by its parsed dictionary.

        Raises:
            ValueError: If the context_config value is not a valid JSON/YAML or
                if the GE configuration extracted from it fails GE validation.
        """
        context_config = values.get("context_config")
        if context_config and not isinstance(context_config, dict):
            try:
                context_config_dict = yaml.safe_load(context_config)
            except yaml.YAMLError as e:
                # `YAMLError` is the common base of both parser and scanner
                # errors, so every kind of malformed input is reported as a
                # ValueError instead of leaking a raw YAML exception.
                raise ValueError(
                    f"Malformed `context_config` value. Only JSON and YAML formats "
                    f"are supported: {str(e)}"
                ) from e
            try:
                context_config = DataContextConfig(**context_config_dict)
                # Instantiating a throw-away context validates the config.
                BaseDataContext(project_config=context_config)
            except Exception as e:
                raise ValueError(
                    f"Invalid `context_config` value: {str(e)}"
                ) from e

            values["context_config"] = context_config_dict
        return values

    @classmethod
    def get_data_context(cls) -> BaseDataContext:
        """Get the Great Expectations data context managed by ZenML.

        Call this method to retrieve the data context managed by ZenML
        through the active Great Expectations data validator stack component.

        Returns:
            A Great Expectations data context managed by ZenML as configured
            through the active data validator stack component.
        """
        data_validator = cast(
            "GreatExpectationsDataValidator", cls.get_active_data_validator()
        )
        return data_validator.data_context

    @property
    def local_path(self) -> Optional[str]:
        """Return a local path where this component stores information.

        If an existing local GE data context is used, it is
        interpreted as a local path that needs to be accessible in
        all runtime environments.

        Returns:
            The local path where this component stores information.
        """
        return self.context_root_dir

    def get_store_config(self, class_name: str, prefix: str) -> Dict[str, Any]:
        """Generate a Great Expectations store configuration.

        The store is backed by the ZenML artifact store, under a path prefix
        namespaced by this component's UUID.

        Args:
            class_name: The store class name
            prefix: The path prefix for the ZenML store configuration

        Returns:
            A dictionary with the GE store configuration.
        """
        return {
            "class_name": class_name,
            "store_backend": {
                "module_name": ZenMLArtifactStoreBackend.__module__,
                "class_name": ZenMLArtifactStoreBackend.__name__,
                "prefix": f"{str(self.uuid)}/{prefix}",
            },
        }

    def get_data_docs_config(
        self, prefix: str, local: bool = False
    ) -> Dict[str, Any]:
        """Generate Great Expectations data docs configuration.

        Args:
            prefix: The path prefix for the ZenML data docs configuration
            local: Whether the data docs site is local (stored under this
                component's root directory on the local filesystem) or
                remote (stored in the ZenML artifact store).

        Returns:
            A dictionary with the GE data docs site configuration.
        """
        if local:
            store_backend = {
                "class_name": "TupleFilesystemStoreBackend",
                "base_directory": f"{self.root_directory}/{prefix}",
            }
        else:
            store_backend = {
                "module_name": ZenMLArtifactStoreBackend.__module__,
                "class_name": ZenMLArtifactStoreBackend.__name__,
                "prefix": f"{str(self.uuid)}/{prefix}",
            }

        return {
            "class_name": "SiteBuilder",
            "store_backend": store_backend,
            "site_index_builder": {
                "class_name": "DefaultSiteIndexBuilder",
            },
        }

    @property
    def data_context(self) -> BaseDataContext:
        """Returns the Great Expectations data context configured for this component.

        The context is built on first access and cached afterwards. It is
        loaded from `context_root_dir` if set, otherwise built in memory from
        `context_config` or from a default ZenML-store-backed configuration.

        Returns:
            The Great Expectations data context configured for this component.
        """
        if self._context is None:
            expectations_store_name = "zenml_expectations_store"
            validations_store_name = "zenml_validations_store"
            checkpoint_store_name = "zenml_checkpoint_store"
            profiler_store_name = "zenml_profiler_store"
            evaluation_parameter_store_name = "evaluation_parameter_store"

            zenml_context_config = dict(
                stores={
                    expectations_store_name: self.get_store_config(
                        "ExpectationsStore", "expectations"
                    ),
                    validations_store_name: self.get_store_config(
                        "ValidationsStore", "validations"
                    ),
                    checkpoint_store_name: self.get_store_config(
                        "CheckpointStore", "checkpoints"
                    ),
                    profiler_store_name: self.get_store_config(
                        "ProfilerStore", "profilers"
                    ),
                    evaluation_parameter_store_name: {
                        "class_name": "EvaluationParameterStore"
                    },
                },
                expectations_store_name=expectations_store_name,
                validations_store_name=validations_store_name,
                checkpoint_store_name=checkpoint_store_name,
                profiler_store_name=profiler_store_name,
                evaluation_parameter_store_name=evaluation_parameter_store_name,
                data_docs_sites={
                    "zenml_artifact_store": self.get_data_docs_config(
                        "data_docs"
                    )
                },
            )

            configure_zenml_stores = self.configure_zenml_stores
            if self.context_root_dir:
                # initialize the local data context, if a local path was
                # configured
                self._context = DataContext(self.context_root_dir)
            else:
                # create an in-memory data context configuration that is not
                # backed by a local YAML file (see https://docs.greatexpectations.io/docs/guides/setup/configuring_data_contexts/how_to_instantiate_a_data_context_without_a_yml_file/).
                if self.context_config:
                    context_config = DataContextConfig(**self.context_config)
                else:
                    context_config = DataContextConfig(**zenml_context_config)
                    # skip adding the stores after initialization, as they are
                    # already baked in the initial configuration
                    configure_zenml_stores = False
                self._context = BaseDataContext(project_config=context_config)

            if configure_zenml_stores:
                self._context.config.expectations_store_name = (
                    expectations_store_name
                )
                self._context.config.validations_store_name = (
                    validations_store_name
                )
                self._context.config.checkpoint_store_name = (
                    checkpoint_store_name
                )
                self._context.config.profiler_store_name = profiler_store_name
                self._context.config.evaluation_parameter_store_name = (
                    evaluation_parameter_store_name
                )
                for store_name, store_config in zenml_context_config[  # type: ignore[attr-defined]
                    "stores"
                ].items():
                    self._context.add_store(
                        store_name=store_name,
                        store_config=store_config,
                    )
                for site_name, site_config in zenml_context_config[  # type: ignore[attr-defined]
                    "data_docs_sites"
                ].items():
                    self._context.config.data_docs_sites[
                        site_name
                    ] = site_config

            if self.configure_local_docs:

                repo = Repository(skip_repository_check=True)  # type: ignore[call-arg]
                artifact_store = repo.active_stack.artifact_store
                # The extra local docs site is only added for non-local
                # artifact stores (presumably because docs in a local artifact
                # store are already on the local filesystem).
                if artifact_store.FLAVOR != "local":
                    self._context.config.data_docs_sites[
                        "zenml_local"
                    ] = self.get_data_docs_config("data_docs", local=True)

        return self._context

    @property
    def root_directory(self) -> str:
        """Returns path to the root directory for all local files concerning this data validator.

        The directory is created if it does not exist yet.

        Returns:
            Path to the root directory.
        """
        path = os.path.join(
            io_utils.get_global_config_directory(),
            self.FLAVOR,
            str(self.uuid),
        )

        if not os.path.exists(path):
            fileio.makedirs(path)

        return path

    def data_profiling(
        self,
        dataset: pd.DataFrame,
        comparison_dataset: Optional[Any] = None,
        profile_list: Optional[Sequence[str]] = None,
        expectation_suite_name: Optional[str] = None,
        data_asset_name: Optional[str] = None,
        profiler_kwargs: Optional[Dict[str, Any]] = None,
        overwrite_existing_suite: bool = True,
        **kwargs: Any,
    ) -> ExpectationSuite:
        """Infer a Great Expectations Expectation Suite from a given dataset.

        This Great Expectations specific data profiling method implementation
        builds an Expectation Suite automatically by running a
        UserConfigurableProfiler on an input dataset [as covered in the official
        GE documentation](https://docs.greatexpectations.io/docs/guides/expectations/how_to_create_and_edit_expectations_with_a_profiler).

        Args:
            dataset: The dataset from which the expectation suite will be
                inferred.
            comparison_dataset: Optional dataset used to generate data
                comparison (i.e. data drift) profiles. Not supported by the
                Great Expectations data validator.
            profile_list: Optional list identifying the categories of data
                profiles to be generated. Not supported by the Great
                Expectations data validator.
            expectation_suite_name: The name of the expectation suite to create
                or update. If not supplied, a unique name will be generated from
                the current pipeline and step name, if running in the context of
                a pipeline step.
            data_asset_name: The name of the data asset to use to identify the
                dataset in the Great Expectations docs.
            profiler_kwargs: A dictionary of custom keyword arguments to pass to
                the profiler. If omitted, no extra arguments are passed.
            overwrite_existing_suite: Whether to overwrite an existing
                expectation suite, if one exists with that name.
            kwargs: Additional keyword arguments (unused).

        Returns:
            The inferred Expectation Suite.

        Raises:
            ValueError: if an `expectation_suite_name` value is not supplied and
                a name for the expectation suite cannot be generated from the
                current step name and pipeline name.
        """
        context = self.data_context
        # `profiler_kwargs` defaults to None; unpacking None with `**` below
        # would raise a TypeError, so fall back to an empty mapping.
        profiler_kwargs = profiler_kwargs or {}

        if comparison_dataset is not None:
            logger.warning(
                "A comparison dataset is not required by Great Expectations "
                "to do data profiling. Silently ignoring the supplied dataset "
            )

        if not expectation_suite_name:
            try:
                # get pipeline name and step name
                step_env = cast(
                    StepEnvironment, Environment()[STEP_ENVIRONMENT_NAME]
                )
                pipeline_name = step_env.pipeline_name
                step_name = step_env.step_name
                expectation_suite_name = f"{pipeline_name}_{step_name}"
            except KeyError:
                raise ValueError(
                    "A expectation suite name is required when not running in "
                    "the context of a pipeline step."
                )

        suite_exists = False
        if context.expectations_store.has_key(  # noqa
            ExpectationSuiteIdentifier(expectation_suite_name)
        ):
            suite_exists = True
            suite = context.get_expectation_suite(expectation_suite_name)
            if not overwrite_existing_suite:
                logger.info(
                    f"Expectation Suite `{expectation_suite_name}` "
                    f"already exists and `overwrite_existing_suite` is not set "
                    f"in the step configuration. Skipping re-running the "
                    f"profiler."
                )
                return suite

        batch_request = create_batch_request(context, dataset, data_asset_name)

        try:
            if suite_exists:
                ge_validator = context.get_validator(
                    batch_request=batch_request,
                    expectation_suite_name=expectation_suite_name,
                )
            else:
                ge_validator = context.get_validator(
                    batch_request=batch_request,
                    create_expectation_suite_with_name=expectation_suite_name,
                )

            profiler = UserConfigurableProfiler(
                profile_dataset=ge_validator, **profiler_kwargs
            )

            suite = profiler.build_suite()
            context.save_expectation_suite(
                expectation_suite=suite,
                expectation_suite_name=expectation_suite_name,
            )

            context.build_data_docs()
        finally:
            # drop the datasource referenced by the temporary batch request
            context.delete_datasource(batch_request.datasource_name)

        return suite

    def data_validation(
        self,
        dataset: pd.DataFrame,
        comparison_dataset: Optional[Any] = None,
        check_list: Optional[Sequence[str]] = None,
        expectation_suite_name: Optional[str] = None,
        data_asset_name: Optional[str] = None,
        action_list: Optional[List[Dict[str, Any]]] = None,
        **kwargs: Any,
    ) -> CheckpointResult:
        """Great Expectations data validation.

        This Great Expectations specific data validation method
        implementation validates an input dataset against an Expectation Suite
        (the GE definition of a profile) [as covered in the official GE
        documentation](https://docs.greatexpectations.io/docs/guides/validation/how_to_validate_data_by_running_a_checkpoint).

        Args:
            dataset: The dataset to validate.
            comparison_dataset: Optional dataset used to run data
                comparison (i.e. data drift) checks. Not supported by the
                Great Expectations data validator.
            check_list: Optional list identifying the data validation checks to
                be performed. Not supported by the Great Expectations data
                validator.
            expectation_suite_name: The name of the expectation suite to use to
                validate the dataset. A value must be provided.
            data_asset_name: The name of the data asset to use to identify the
                dataset in the Great Expectations docs.
            action_list: A list of Great Expectations actions to run after the
                validation check. If supplied, it replaces the default actions
                (storing the validation result, storing the evaluation
                parameters and updating the data docs).
            kwargs: Additional keyword arguments (unused).

        Returns:
            The Great Expectations validation (checkpoint) result.

        Raises:
            ValueError: if the `expectation_suite_name` argument is omitted.
        """
        if not expectation_suite_name:
            raise ValueError("Missing expectation_suite_name argument value.")

        if comparison_dataset is not None:
            logger.warning(
                "A comparison dataset is not required by Great Expectations "
                "to do data validation. Silently ignoring the supplied dataset "
            )

        try:
            # get pipeline name, step name and run id
            step_env = cast(
                StepEnvironment, Environment()[STEP_ENVIRONMENT_NAME]
            )
            run_id = step_env.pipeline_run_id
            step_name = step_env.step_name
        except KeyError:
            # if not running inside a pipeline step, use random values
            run_id = f"pipeline_{random_str(5)}"
            step_name = f"step_{random_str(5)}"

        context = self.data_context

        checkpoint_name = f"{run_id}_{step_name}"

        batch_request = create_batch_request(context, dataset, data_asset_name)

        action_list = action_list or [
            {
                "name": "store_validation_result",
                "action": {"class_name": "StoreValidationResultAction"},
            },
            {
                "name": "store_evaluation_params",
                "action": {"class_name": "StoreEvaluationParametersAction"},
            },
            {
                "name": "update_data_docs",
                "action": {"class_name": "UpdateDataDocsAction"},
            },
        ]

        checkpoint_config = {
            "name": checkpoint_name,
            "run_name_template": f"{run_id}",
            "config_version": 1,
            "class_name": "Checkpoint",
            "expectation_suite_name": expectation_suite_name,
            "action_list": action_list,
        }

        try:
            # Adding the checkpoint happens inside the outer `try` so that the
            # batch request's datasource is cleaned up even if it fails.
            context.add_checkpoint(**checkpoint_config)
            try:
                results = context.run_checkpoint(
                    checkpoint_name=checkpoint_name,
                    validations=[{"batch_request": batch_request}],
                )
            finally:
                # only reached once the checkpoint was successfully added
                context.delete_checkpoint(checkpoint_name)
        finally:
            context.delete_datasource(batch_request.datasource_name)

        return results
data_context: BaseDataContext property readonly

Returns the Great Expectations data context configured for this component.

Returns:

Type Description
BaseDataContext

The Great Expectations data context configured for this component.

local_path: Optional[str] property readonly

Return a local path where this component stores information.

If an existing local GE data context is used, it is interpreted as a local path that needs to be accessible in all runtime environments.

Returns:

Type Description
Optional[str]

The local path where this component stores information.

root_directory: str property readonly

Returns path to the root directory for all local files concerning this data validator.

Returns:

Type Description
str

Path to the root directory.

data_profiling(self, dataset, comparison_dataset=None, profile_list=None, expectation_suite_name=None, data_asset_name=None, profiler_kwargs=None, overwrite_existing_suite=True, **kwargs)

Infer a Great Expectations Expectation Suite from a given dataset.

This Great Expectations specific data profiling method implementation builds an Expectation Suite automatically by running a UserConfigurableProfiler on an input dataset as covered in the official GE documentation.

Parameters:

Name Type Description Default
dataset DataFrame

The dataset from which the expectation suite will be inferred.

required
comparison_dataset Optional[Any]

Optional dataset used to generate data comparison (i.e. data drift) profiles. Not supported by the Great Expectations data validator.

None
profile_list Optional[Sequence[str]]

Optional list identifying the categories of data profiles to be generated. Not supported by the Great Expectations data validator.

None
expectation_suite_name Optional[str]

The name of the expectation suite to create or update. If not supplied, a unique name will be generated from the current pipeline and step name, if running in the context of a pipeline step.

None
data_asset_name Optional[str]

The name of the data asset to use to identify the dataset in the Great Expectations docs.

None
profiler_kwargs Optional[Dict[str, Any]]

A dictionary of custom keyword arguments to pass to the profiler.

None
overwrite_existing_suite bool

Whether to overwrite an existing expectation suite, if one exists with that name.

True
kwargs Any

Additional keyword arguments (unused).

{}

Returns:

Type Description
ExpectationSuite

The inferred Expectation Suite.

Exceptions:

Type Description
ValueError

if an expectation_suite_name value is not supplied and a name for the expectation suite cannot be generated from the current step name and pipeline name.

Source code in zenml/integrations/great_expectations/data_validators/ge_data_validator.py
def data_profiling(
    self,
    dataset: pd.DataFrame,
    comparison_dataset: Optional[Any] = None,
    profile_list: Optional[Sequence[str]] = None,
    expectation_suite_name: Optional[str] = None,
    data_asset_name: Optional[str] = None,
    profiler_kwargs: Optional[Dict[str, Any]] = None,
    overwrite_existing_suite: bool = True,
    **kwargs: Any,
) -> ExpectationSuite:
    """Infer a Great Expectations Expectation Suite from a given dataset.

    This Great Expectations specific data profiling method implementation
    builds an Expectation Suite automatically by running a
    UserConfigurableProfiler on an input dataset [as covered in the official
    GE documentation](https://docs.greatexpectations.io/docs/guides/expectations/how_to_create_and_edit_expectations_with_a_profiler).

    Args:
        dataset: The dataset from which the expectation suite will be
            inferred.
        comparison_dataset: Optional dataset used to generate data
            comparison (i.e. data drift) profiles. Not supported by the
            Great Expectations data validator.
        profile_list: Optional list identifying the categories of data
            profiles to be generated. Not supported by the Great
            Expectations data validator.
        expectation_suite_name: The name of the expectation suite to create
            or update. If not supplied, a unique name will be generated from
            the current pipeline and step name, if running in the context of
            a pipeline step.
        data_asset_name: The name of the data asset to use to identify the
            dataset in the Great Expectations docs.
        profiler_kwargs: A dictionary of custom keyword arguments to pass to
            the profiler. If omitted, no extra arguments are passed.
        overwrite_existing_suite: Whether to overwrite an existing
            expectation suite, if one exists with that name.
        kwargs: Additional keyword arguments (unused).

    Returns:
        The inferred Expectation Suite.

    Raises:
        ValueError: if an `expectation_suite_name` value is not supplied and
            a name for the expectation suite cannot be generated from the
            current step name and pipeline name.
    """
    context = self.data_context
    # `profiler_kwargs` defaults to None; unpacking None with `**` below
    # would raise a TypeError, so fall back to an empty mapping.
    profiler_kwargs = profiler_kwargs or {}

    if comparison_dataset is not None:
        logger.warning(
            "A comparison dataset is not required by Great Expectations "
            "to do data profiling. Silently ignoring the supplied dataset "
        )

    if not expectation_suite_name:
        try:
            # get pipeline name and step name
            step_env = cast(
                StepEnvironment, Environment()[STEP_ENVIRONMENT_NAME]
            )
            pipeline_name = step_env.pipeline_name
            step_name = step_env.step_name
            expectation_suite_name = f"{pipeline_name}_{step_name}"
        except KeyError:
            raise ValueError(
                "A expectation suite name is required when not running in "
                "the context of a pipeline step."
            )

    suite_exists = False
    if context.expectations_store.has_key(  # noqa
        ExpectationSuiteIdentifier(expectation_suite_name)
    ):
        suite_exists = True
        suite = context.get_expectation_suite(expectation_suite_name)
        if not overwrite_existing_suite:
            logger.info(
                f"Expectation Suite `{expectation_suite_name}` "
                f"already exists and `overwrite_existing_suite` is not set "
                f"in the step configuration. Skipping re-running the "
                f"profiler."
            )
            return suite

    batch_request = create_batch_request(context, dataset, data_asset_name)

    try:
        if suite_exists:
            ge_validator = context.get_validator(
                batch_request=batch_request,
                expectation_suite_name=expectation_suite_name,
            )
        else:
            ge_validator = context.get_validator(
                batch_request=batch_request,
                create_expectation_suite_with_name=expectation_suite_name,
            )

        profiler = UserConfigurableProfiler(
            profile_dataset=ge_validator, **profiler_kwargs
        )

        suite = profiler.build_suite()
        context.save_expectation_suite(
            expectation_suite=suite,
            expectation_suite_name=expectation_suite_name,
        )

        context.build_data_docs()
    finally:
        # drop the datasource referenced by the temporary batch request
        context.delete_datasource(batch_request.datasource_name)

    return suite
data_validation(self, dataset, comparison_dataset=None, check_list=None, expectation_suite_name=None, data_asset_name=None, action_list=None, **kwargs)

Great Expectations data validation.

This Great Expectations specific data validation method implementation validates an input dataset against an Expectation Suite (the GE definition of a profile) as covered in the official GE documentation.

Parameters:

Name Type Description Default
dataset DataFrame

The dataset to validate.

required
comparison_dataset Optional[Any]

Optional dataset used to run data comparison (i.e. data drift) checks. Not supported by the Great Expectations data validator.

None
check_list Optional[Sequence[str]]

Optional list identifying the data validation checks to be performed. Not supported by the Great Expectations data validator.

None
expectation_suite_name Optional[str]

The name of the expectation suite to use to validate the dataset. A value must be provided.

None
data_asset_name Optional[str]

The name of the data asset to use to identify the dataset in the Great Expectations docs.

None
action_list Optional[List[Dict[str, Any]]]

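A list of Great Expectations actions to run after the validation check. If supplied, it replaces the default actions (storing the validation result, storing the evaluation parameters and updating the data docs).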

None
kwargs Any

Additional keyword arguments (unused).

{}

Returns:

Type Description
CheckpointResult

The Great Expectations validation (checkpoint) result.

Exceptions:

Type Description
ValueError

if the expectation_suite_name argument is omitted.

Source code in zenml/integrations/great_expectations/data_validators/ge_data_validator.py
def data_validation(
    self,
    dataset: pd.DataFrame,
    comparison_dataset: Optional[Any] = None,
    check_list: Optional[Sequence[str]] = None,
    expectation_suite_name: Optional[str] = None,
    data_asset_name: Optional[str] = None,
    action_list: Optional[List[Dict[str, Any]]] = None,
    **kwargs: Any,
) -> CheckpointResult:
    """Great Expectations data validation.

    This Great Expectations specific data validation method
    implementation validates an input dataset against an Expectation Suite
    (the GE definition of a profile) [as covered in the official GE
    documentation](https://docs.greatexpectations.io/docs/guides/validation/how_to_validate_data_by_running_a_checkpoint).

    Args:
        dataset: The dataset to validate.
        comparison_dataset: Optional dataset used to run data
            comparison (i.e. data drift) checks. Not supported by the
            Great Expectations data validator.
        check_list: Optional list identifying the data validation checks to
            be performed. Not supported by the Great Expectations data
            validator.
        expectation_suite_name: The name of the expectation suite to use to
            validate the dataset. A value must be provided.
        data_asset_name: The name of the data asset to use to identify the
            dataset in the Great Expectations docs.
        action_list: A list of Great Expectations actions to run after the
            validation check. If supplied, it replaces the default actions
            (storing the validation result, storing the evaluation
            parameters and updating the data docs).
        kwargs: Additional keyword arguments (unused).

    Returns:
        The Great Expectations validation (checkpoint) result.

    Raises:
        ValueError: if the `expectation_suite_name` argument is omitted.
    """
    if not expectation_suite_name:
        raise ValueError("Missing expectation_suite_name argument value.")

    if comparison_dataset is not None:
        logger.warning(
            "A comparison dataset is not required by Great Expectations "
            "to do data validation. Silently ignoring the supplied dataset "
        )

    try:
        # get pipeline name, step name and run id
        step_env = cast(
            StepEnvironment, Environment()[STEP_ENVIRONMENT_NAME]
        )
        run_id = step_env.pipeline_run_id
        step_name = step_env.step_name
    except KeyError:
        # if not running inside a pipeline step, use random values
        run_id = f"pipeline_{random_str(5)}"
        step_name = f"step_{random_str(5)}"

    context = self.data_context

    checkpoint_name = f"{run_id}_{step_name}"

    batch_request = create_batch_request(context, dataset, data_asset_name)

    action_list = action_list or [
        {
            "name": "store_validation_result",
            "action": {"class_name": "StoreValidationResultAction"},
        },
        {
            "name": "store_evaluation_params",
            "action": {"class_name": "StoreEvaluationParametersAction"},
        },
        {
            "name": "update_data_docs",
            "action": {"class_name": "UpdateDataDocsAction"},
        },
    ]

    checkpoint_config = {
        "name": checkpoint_name,
        "run_name_template": f"{run_id}",
        "config_version": 1,
        "class_name": "Checkpoint",
        "expectation_suite_name": expectation_suite_name,
        "action_list": action_list,
    }

    try:
        # Adding the checkpoint happens inside the outer `try` so that the
        # batch request's datasource is cleaned up even if it fails.
        context.add_checkpoint(**checkpoint_config)
        try:
            results = context.run_checkpoint(
                checkpoint_name=checkpoint_name,
                validations=[{"batch_request": batch_request}],
            )
        finally:
            # only reached once the checkpoint was successfully added
            context.delete_checkpoint(checkpoint_name)
    finally:
        context.delete_datasource(batch_request.datasource_name)

    return results
get_data_context() classmethod

Get the Great Expectations data context managed by ZenML.

Call this method to retrieve the data context managed by ZenML through the active Great Expectations data validator stack component.

Returns:

Type Description
BaseDataContext

A Great Expectations data context managed by ZenML as configured through the active data validator stack component.

Source code in zenml/integrations/great_expectations/data_validators/ge_data_validator.py
@classmethod
def get_data_context(cls) -> BaseDataContext:
    """Return the Great Expectations data context owned by ZenML.

    The context is taken from the currently active Great Expectations data
    validator stack component, so callers always work with the context that
    ZenML configured for the active stack.

    Returns:
        A Great Expectations data context managed by ZenML as configured
        through the active data validator stack component.
    """
    active_validator = cls.get_active_data_validator()
    # The active validator is known to be the GE flavor in this integration;
    # the cast only informs the type checker.
    ge_validator = cast("GreatExpectationsDataValidator", active_validator)
    return ge_validator.data_context
get_data_docs_config(self, prefix, local=False)

Generate Great Expectations data docs configuration.

Parameters:

Name Type Description Default
prefix str

The path prefix for the ZenML data docs configuration

required
local bool

Whether the data docs site is local or remote.

False

Returns:

Type Description
Dict[str, Any]

A dictionary with the GE data docs site configuration.

Source code in zenml/integrations/great_expectations/data_validators/ge_data_validator.py
def get_data_docs_config(
    self, prefix: str, local: bool = False
) -> Dict[str, Any]:
    """Build the Great Expectations data docs site configuration.

    Args:
        prefix: The path prefix for the ZenML data docs configuration
        local: Whether the data docs site is local or remote.

    Returns:
        A dictionary with the GE data docs site configuration. Local sites
        use GE's ``TupleFilesystemStoreBackend`` rooted at
        ``<root_directory>/<prefix>``; remote sites use the ZenML artifact
        store backend with the prefix ``<uuid>/<prefix>``.
    """
    if not local:
        # Remote docs go through the ZenML artifact store backend,
        # namespaced by this component's uuid.
        backend = {
            "module_name": ZenMLArtifactStoreBackend.__module__,
            "class_name": ZenMLArtifactStoreBackend.__name__,
            "prefix": f"{str(self.uuid)}/{prefix}",
        }
    else:
        backend = {
            "class_name": "TupleFilesystemStoreBackend",
            "base_directory": f"{self.root_directory}/{prefix}",
        }

    site_config: Dict[str, Any] = {"class_name": "SiteBuilder"}
    site_config["store_backend"] = backend
    site_config["site_index_builder"] = {
        "class_name": "DefaultSiteIndexBuilder",
    }
    return site_config
get_store_config(self, class_name, prefix)

Generate a Great Expectations store configuration.

Parameters:

Name Type Description Default
class_name str

The store class name

required
prefix str

The path prefix for the ZenML store configuration

required

Returns:

Type Description
Dict[str, Any]

A dictionary with the GE store configuration.

Source code in zenml/integrations/great_expectations/data_validators/ge_data_validator.py
def get_store_config(self, class_name: str, prefix: str) -> Dict[str, Any]:
    """Build a Great Expectations store configuration backed by ZenML.

    Args:
        class_name: The store class name
        prefix: The path prefix for the ZenML store configuration

    Returns:
        A dictionary with the GE store configuration, whose store backend
        is the ZenML artifact store backend under ``<uuid>/<prefix>``.
    """
    backend_cls = ZenMLArtifactStoreBackend
    return dict(
        class_name=class_name,
        store_backend=dict(
            module_name=backend_cls.__module__,
            class_name=backend_cls.__name__,
            prefix=f"{str(self.uuid)}/{prefix}",
        ),
    )

ge_store_backend

Great Expectations store plugin for ZenML.

ZenMLArtifactStoreBackend (TupleStoreBackend)

Great Expectations store backend that uses the active ZenML Artifact Store as a store.

Source code in zenml/integrations/great_expectations/ge_store_backend.py
class ZenMLArtifactStoreBackend(TupleStoreBackend):  # type: ignore[misc]
    """Great Expectations store backend that uses the active ZenML Artifact Store as a store.

    Objects are stored below a ``great_expectations`` folder inside the root
    path of the artifact store of the active ZenML stack, optionally nested
    under a configurable ``prefix``. File access goes through ZenML's
    ``fileio`` helpers; local parent directories are created and pruned
    with ``os``.
    """

    def __init__(
        self,
        prefix: str = "",
        **kwargs: Any,
    ) -> None:
        """Create a Great Expectations ZenML store backend instance.

        Args:
            prefix: Subpath prefix to use for this store backend.
            kwargs: Additional keyword arguments passed by the Great Expectations
                core. These are transparently passed to the `TupleStoreBackend`
                constructor.
        """
        super().__init__(**kwargs)

        repo = Repository(skip_repository_check=True)  # type: ignore[call-arg]
        artifact_store = repo.active_stack.artifact_store
        self.root_path = os.path.join(artifact_store.path, "great_expectations")

        # extract the protocol used in the artifact store root path
        protocols = [
            scheme
            for scheme in artifact_store.SUPPORTED_SCHEMES
            if self.root_path.startswith(scheme)
        ]
        if protocols:
            self.proto = protocols[0]
        else:
            self.proto = ""

        if prefix:
            if self.platform_specific_separator:
                prefix = prefix.strip(os.sep)
            prefix = prefix.strip("/")
        self.prefix = prefix

        # Initialize with store_backend_id if not part of an HTMLSiteStore
        if not self._suppress_store_backend_id:
            _ = self.store_backend_id

        self._config = {
            "prefix": prefix,
            "module_name": self.__class__.__module__,
            "class_name": self.__class__.__name__,
        }
        self._config.update(kwargs)
        filter_properties_dict(
            properties=self._config, clean_falsy=True, inplace=True
        )

    def _build_object_path(
        self, key: Tuple[str, ...], is_prefix: bool = False
    ) -> str:
        """Build a filepath corresponding to an object key.

        Args:
            key: Great Expectation object key.
            is_prefix: If True, the key will be interpreted as a prefix instead
                of a full key identifier.

        Returns:
            The file path pointing to where the object is stored.
        """
        if not isinstance(key, tuple):
            key = key.to_tuple()  # type: ignore[attr-defined]
        if not is_prefix:
            object_relative_path = self._convert_key_to_filepath(key)
        elif key:
            object_relative_path = os.path.join(*key)
        else:
            object_relative_path = ""
        if self.prefix:
            object_key = os.path.join(self.prefix, object_relative_path)
        else:
            object_key = object_relative_path
        return os.path.join(self.root_path, object_key)

    def _get(self, key: Tuple[str, ...]) -> str:
        """Get the value of an object from the store.

        Args:
            key: object key identifier.

        Raises:
            InvalidKeyError: if the key doesn't point to an existing object.

        Returns:
            str: the object's contents
        """
        filepath: str = self._build_object_path(key)
        if fileio.exists(filepath):
            contents = io_utils.read_file_contents_as_string(filepath).rstrip(
                "\n"
            )
        else:
            raise InvalidKeyError(
                f"Unable to retrieve object from {self.__class__.__name__} with "
                f"the following Key: {str(filepath)}"
            )
        return contents

    def _set(self, key: Tuple[str, ...], value: str, **kwargs: Any) -> str:
        """Set the value of an object in the store.

        Args:
            key: object key identifier.
            value: object value to set. ``str`` values are UTF-8 encoded;
                any other value is written as-is (presumably ``bytes``).
            kwargs: additional keyword arguments (ignored).

        Returns:
            The file path where the object was stored.
        """
        filepath: str = self._build_object_path(key)
        if not io_utils.is_remote(filepath):
            # Local filesystems need the parent directory to exist first.
            parent_dir = str(Path(filepath).parent)
            os.makedirs(parent_dir, exist_ok=True)

        with fileio.open(filepath, "wb") as outfile:
            if isinstance(value, str):
                outfile.write(value.encode("utf-8"))
            else:
                outfile.write(value)
        return filepath

    def _move(
        self,
        source_key: Tuple[str, ...],
        dest_key: Tuple[str, ...],
        **kwargs: Any,
    ) -> None:
        """Associate an object with a different key in the store.

        Missing source objects are silently ignored.

        Args:
            source_key: current object key identifier.
            dest_key: new object key identifier.
            kwargs: additional keyword arguments (ignored).
        """
        source_path = self._build_object_path(source_key)
        dest_path = self._build_object_path(dest_key)

        if fileio.exists(source_path):
            if not io_utils.is_remote(dest_path):
                parent_dir = str(Path(dest_path).parent)
                os.makedirs(parent_dir, exist_ok=True)
            fileio.rename(source_path, dest_path, overwrite=True)

    def list_keys(self, prefix: Tuple[str, ...] = ()) -> List[Tuple[str, ...]]:
        """List the keys of all objects identified by a partial key.

        Args:
            prefix: partial object key identifier.

        Returns:
            List of keys identifying all objects present in the store that
            match the input partial key.
        """
        key_list = []
        list_path = self._build_object_path(prefix, is_prefix=True)
        root_path = self._build_object_path(tuple(), is_prefix=True)
        for root, _dirs, files in fileio.walk(list_path):
            for file_ in files:
                # Keys are derived from paths relative to the store root.
                filepath = os.path.relpath(
                    os.path.join(str(root), str(file_)), root_path
                )

                if self.filepath_prefix and not filepath.startswith(
                    self.filepath_prefix
                ):
                    continue
                elif self.filepath_suffix and not filepath.endswith(
                    self.filepath_suffix
                ):
                    continue
                key = self._convert_filepath_to_key(filepath)
                if key and not self.is_ignored_key(key):
                    key_list.append(key)
        return key_list

    def remove_key(self, key: Tuple[str, ...]) -> bool:
        """Delete an object from the store.

        For local stores, parent directories left empty by the removal are
        pruned up to (but not including) the store root path.

        Args:
            key: object key identifier.

        Returns:
            True if the object existed in the store and was removed, otherwise
            False.
        """
        filepath: str = self._build_object_path(key)

        if fileio.exists(filepath):
            fileio.remove(filepath)
            if not io_utils.is_remote(filepath):
                parent_dir = str(Path(filepath).parent)
                self.rrmdir(self.root_path, str(parent_dir))
            return True
        return False

    def _has_key(self, key: Tuple[str, ...]) -> bool:
        """Check if an object is present in the store.

        Args:
            key: object key identifier.

        Returns:
            True if the object is present in the store, otherwise False.
        """
        filepath: str = self._build_object_path(key)
        result = fileio.exists(filepath)
        return result

    def get_url_for_key(
        self, key: Tuple[str, ...], protocol: Optional[str] = None
    ) -> str:
        """Get the URL of an object in the store.

        Args:
            key: object key identifier.
            protocol: optional protocol to use instead of the store protocol.
                Local paths default to ``file:`` when no protocol is given.

        Returns:
            The URL of the object in the store.
        """
        filepath = self._build_object_path(key)
        if not protocol and not io_utils.is_remote(filepath):
            protocol = "file:"
        if protocol:
            # Swap the store scheme (empty for local paths, in which case
            # the new scheme is simply prepended) for the requested one.
            filepath = filepath.replace(self.proto, f"{protocol}//", 1)

        return filepath

    def get_public_url_for_key(
        self, key: str, protocol: Optional[str] = None
    ) -> str:
        """Get the public URL of an object in the store.

        Args:
            key: object key identifier.
            protocol: optional protocol to use instead of the store protocol.
                Currently unused; accepted for interface compatibility with
                the Great Expectations store backend API.

        Returns:
            The public URL where the object can be accessed.

        Raises:
            StoreBackendError: if a `base_public_path` attribute was not
                configured for the store.
        """
        if not self.base_public_path:
            raise StoreBackendError(
                f"Error: No base_public_path was configured! A public URL was "
                f"requested but `base_public_path` was not configured for the "
                f"{self.__class__.__name__}"
            )
        filepath = self._convert_key_to_filepath(key)
        public_url = self.base_public_path + filepath.replace(self.proto, "")
        return cast(str, public_url)

    @staticmethod
    def rrmdir(start_path: str, end_path: str) -> None:
        """Remove empty directories walking upward from end_path to start_path.

        ``end_path`` itself is removed if it is empty, followed by each empty
        parent directory in turn. The walk stops at the first non-empty
        directory, at ``start_path`` (which is never removed), or as soon as
        it reaches a directory that is not located below ``start_path``.

        Both paths are converted to normalized absolute paths before being
        compared. Cosmetic differences such as a trailing separator or a
        doubled ``//`` therefore cannot make the walk miss ``start_path`` and
        delete empty directories above it.

        Args:
            start_path: Directory to use as a starting point. It bounds the
                walk and is never removed.
            end_path: Directory to use as a destination point; the walk
                begins here.

        Raises:
            OSError: if a directory below ``start_path`` cannot be listed or
                removed (e.g. it does not exist).
        """
        boundary = os.path.abspath(start_path)
        # Only directories strictly inside `boundary` may be deleted.
        inside_prefix = boundary.rstrip(os.sep) + os.sep
        current = os.path.abspath(end_path)
        while current != boundary and current.startswith(inside_prefix):
            if os.listdir(current):
                break
            os.rmdir(current)
            current = os.path.dirname(current)

    @property
    def config(self) -> Dict[str, Any]:
        """Get the store configuration.

        Returns:
            The store configuration.
        """
        return self._config
config: Dict[str, Any] property readonly

Get the store configuration.

Returns:

Type Description
Dict[str, Any]

The store configuration.

__init__(self, prefix='', **kwargs) special

Create a Great Expectations ZenML store backend instance.

Parameters:

Name Type Description Default
prefix str

Subpath prefix to use for this store backend.

''
kwargs Any

Additional keyword arguments passed by the Great Expectations core. These are transparently passed to the TupleStoreBackend constructor.

{}
Source code in zenml/integrations/great_expectations/ge_store_backend.py
def __init__(
    self,
    prefix: str = "",
    **kwargs: Any,
) -> None:
    """Initialize a ZenML-backed Great Expectations store backend.

    Args:
        prefix: Subpath prefix to use for this store backend.
        kwargs: Additional keyword arguments passed by the Great Expectations
            core. These are transparently passed to the `TupleStoreBackend`
            constructor.
    """
    super().__init__(**kwargs)

    active_stack = Repository(skip_repository_check=True).active_stack  # type: ignore[call-arg]
    artifact_store = active_stack.artifact_store
    self.root_path = os.path.join(artifact_store.path, "great_expectations")

    # Remember the first supported scheme the root path starts with (empty
    # string when none matches) so it can be swapped when building URLs.
    self.proto = next(
        (
            scheme
            for scheme in artifact_store.SUPPORTED_SCHEMES
            if self.root_path.startswith(scheme)
        ),
        "",
    )

    # Normalize the prefix by dropping leading/trailing separators.
    if prefix and self.platform_specific_separator:
        prefix = prefix.strip(os.sep)
    if prefix:
        prefix = prefix.strip("/")
    self.prefix = prefix

    # Accessing store_backend_id initializes it, unless this backend is part
    # of an HTMLSiteStore (which suppresses it).
    if not self._suppress_store_backend_id:
        _ = self.store_backend_id

    # Extra kwargs override the defaults, exactly like dict.update would.
    self._config = {
        "prefix": prefix,
        "module_name": self.__class__.__module__,
        "class_name": self.__class__.__name__,
        **kwargs,
    }
    filter_properties_dict(
        properties=self._config, clean_falsy=True, inplace=True
    )
get_public_url_for_key(self, key, protocol=None)

Get the public URL of an object in the store.

Parameters:

Name Type Description Default
key str

object key identifier.

required
protocol Optional[str]

optional protocol to use instead of the store protocol.

None

Returns:

Type Description
str

The public URL where the object can be accessed.

Exceptions:

Type Description
StoreBackendError

if a base_public_path attribute was not configured for the store.

Source code in zenml/integrations/great_expectations/ge_store_backend.py
def get_public_url_for_key(
    self, key: str, protocol: Optional[str] = None
) -> str:
    """Return the public URL under which an object can be reached.

    Args:
        key: object key identifier.
        protocol: optional protocol to use instead of the store protocol.
            Not used by this implementation.

    Returns:
        The public URL where the object can be accessed: the configured
        ``base_public_path`` followed by the object's relative file path
        with the store scheme removed.

    Raises:
        StoreBackendError: if a `base_public_path` attribute was not
            configured for the store.
    """
    base_url = self.base_public_path
    if base_url:
        relative_path = self._convert_key_to_filepath(key).replace(
            self.proto, ""
        )
        return cast(str, base_url + relative_path)
    raise StoreBackendError(
        f"Error: No base_public_path was configured! A public URL was "
        f"requested but `base_public_path` was not configured for the "
        f"{self.__class__.__name__}"
    )
get_url_for_key(self, key, protocol=None)

Get the URL of an object in the store.

Parameters:

Name Type Description Default
key Tuple[str, ...]

object key identifier.

required
protocol Optional[str]

optional protocol to use instead of the store protocol.

None

Returns:

Type Description
str

The URL of the object in the store.

Source code in zenml/integrations/great_expectations/ge_store_backend.py
def get_url_for_key(
    self, key: Tuple[str, ...], protocol: Optional[str] = None
) -> str:
    """Return the URL of a stored object.

    Args:
        key: object key identifier.
        protocol: optional protocol to use instead of the store protocol.
            When omitted, local paths are exposed with the ``file:`` scheme
            and remote paths are returned unchanged.

    Returns:
        The URL of the object in the store.
    """
    url = self._build_object_path(key)
    scheme = protocol
    if not scheme and not io_utils.is_remote(url):
        scheme = "file:"
    if not scheme:
        return url
    # Replace the first occurrence of the store's own scheme; for local
    # stores self.proto is "" so the new scheme is prepended.
    return url.replace(self.proto, f"{scheme}//", 1)
list_keys(self, prefix=())

List the keys of all objects identified by a partial key.

Parameters:

Name Type Description Default
prefix Tuple[str, ...]

partial object key identifier.

()

Returns:

Type Description
List[Tuple[str, ...]]

List of keys identifying all objects present in the store that match the input partial key.

Source code in zenml/integrations/great_expectations/ge_store_backend.py
def list_keys(self, prefix: Tuple[str, ...] = ()) -> List[Tuple[str, ...]]:
    """Collect the keys of every stored object matching a partial key.

    Args:
        prefix: partial object key identifier.

    Returns:
        List of keys identifying all objects present in the store that
        match the input partial key. Files rejected by the configured
        filepath prefix/suffix filters or by ``is_ignored_key`` are skipped.
    """
    search_root = self._build_object_path(prefix, is_prefix=True)
    store_root = self._build_object_path(tuple(), is_prefix=True)
    found = []
    for directory, _subdirs, filenames in fileio.walk(search_root):
        for filename in filenames:
            relative = os.path.relpath(
                os.path.join(str(directory), str(filename)), store_root
            )
            wanted_prefix = self.filepath_prefix
            if wanted_prefix and not relative.startswith(wanted_prefix):
                continue
            wanted_suffix = self.filepath_suffix
            if wanted_suffix and not relative.endswith(wanted_suffix):
                continue
            candidate = self._convert_filepath_to_key(relative)
            if candidate and not self.is_ignored_key(candidate):
                found.append(candidate)
    return found
remove_key(self, key)

Delete an object from the store.

Parameters:

Name Type Description Default
key Tuple[str, ...]

object key identifier.

required

Returns:

Type Description
bool

True if the object existed in the store and was removed, otherwise False.

Source code in zenml/integrations/great_expectations/ge_store_backend.py
def remove_key(self, key: Tuple[str, ...]) -> bool:
    """Remove a stored object if it exists.

    Args:
        key: object key identifier.

    Returns:
        True if the object existed in the store and was removed, otherwise
        False.
    """
    filepath: str = self._build_object_path(key)
    if not fileio.exists(filepath):
        return False

    fileio.remove(filepath)
    if not io_utils.is_remote(filepath):
        # Local stores: prune directories left empty by the removal, up to
        # the store root path.
        self.rrmdir(self.root_path, str(Path(filepath).parent))
    return True
rrmdir(start_path, end_path) staticmethod

Recursively removes empty directories, walking upward from end_path (inclusive) towards start_path; start_path itself is never removed.

Parameters:

Name Type Description Default
start_path str

Directory to use as a starting point.

required
end_path str

Directory to use as a destination point.

required
Source code in zenml/integrations/great_expectations/ge_store_backend.py
@staticmethod
def rrmdir(start_path: str, end_path: str) -> None:
    """Remove empty directories walking upward from end_path to start_path.

    ``end_path`` itself is removed if it is empty, followed by each empty
    parent directory in turn. The walk stops at the first non-empty
    directory, at ``start_path`` (which is never removed), or as soon as it
    reaches a directory that is not located below ``start_path``.

    Both paths are converted to normalized absolute paths before being
    compared. Cosmetic differences such as a trailing separator or a doubled
    ``//`` therefore cannot make the walk miss ``start_path`` and delete
    empty directories above it.

    Args:
        start_path: Directory to use as a starting point. It bounds the walk
            and is never removed.
        end_path: Directory to use as a destination point; the walk begins
            here.

    Raises:
        OSError: if a directory below ``start_path`` cannot be listed or
            removed (e.g. it does not exist).
    """
    boundary = os.path.abspath(start_path)
    # Only directories strictly inside `boundary` may be deleted.
    inside_prefix = boundary.rstrip(os.sep) + os.sep
    current = os.path.abspath(end_path)
    while current != boundary and current.startswith(inside_prefix):
        if os.listdir(current):
            break
        os.rmdir(current)
        current = os.path.dirname(current)

materializers special

Materializers for Great Expectations serializable objects.

ge_materializer

Implementation of the Great Expectations materializers.

GreatExpectationsMaterializer (BaseMaterializer)

Materializer to read/write Great Expectations objects.

Source code in zenml/integrations/great_expectations/materializers/ge_materializer.py
class GreatExpectationsMaterializer(BaseMaterializer):
    """Materializer to read/write Great Expectations objects.

    Objects are serialized with their ``to_json_dict`` method into a single
    JSON file inside the artifact URI, together with the fully qualified
    name of their class so the exact type can be restored on read.
    """

    ASSOCIATED_TYPES = (
        ExpectationSuite,
        CheckpointResult,
    )
    ASSOCIATED_ARTIFACT_TYPES = (DataAnalysisArtifact,)

    @staticmethod
    def preprocess_checkpoint_result_dict(
        artifact_dict: Dict[str, Any]
    ) -> None:
        """Pre-processes a GE checkpoint dict before it is used to de-serialize a GE CheckpointResult object.

        The GE CheckpointResult object is not fully de-serializable from its
        JSON dict, so the checkpoint config, the run result keys and the
        nested validation results are converted back to their proper GE
        types here. The dict is modified in place.

        Args:
            artifact_dict: A dict containing the GE checkpoint result.
        """

        def _restore_run_result(name: str, payload: Any) -> Any:
            # Only validation results need to be turned back into objects.
            if name != "validation_result":
                return payload
            return ExpectationSuiteValidationResult(**payload)

        artifact_dict["checkpoint_config"] = CheckpointConfig(
            **artifact_dict["checkpoint_config"]
        )

        restored_runs = {}
        for raw_ident, run_result in artifact_dict["run_results"].items():
            # The serialized key has a "::"-separated head; the second part
            # holds the "/"-joined identifier components.
            identifier_parts = raw_ident.split("::")[1].split("/")
            identifier = ValidationResultIdentifier.from_fixed_length_tuple(
                identifier_parts
            )
            restored_runs[identifier] = {
                name: _restore_run_result(name, payload)
                for name, payload in run_result.items()
            }
        artifact_dict["run_results"] = restored_runs

    def handle_input(self, data_type: Type[Any]) -> SerializableDictDot:
        """Load a Great Expectations object from the artifact store.

        Args:
            data_type: The type of the data to read.

        Returns:
            A loaded Great Expectations object, instantiated from the class
            recorded when the object was written.
        """
        super().handle_input(data_type)
        artifact_dict = yaml_utils.read_json(
            os.path.join(self.artifact.uri, ARTIFACT_FILENAME)
        )
        stored_type = import_class_by_path(artifact_dict.pop("data_type"))

        if stored_type is CheckpointResult:
            self.preprocess_checkpoint_result_dict(artifact_dict)

        return stored_type(**artifact_dict)

    def handle_return(self, obj: SerializableDictDot) -> None:
        """Persist a Great Expectations object to the artifact store.

        Args:
            obj: A Great Expectations object.
        """
        super().handle_return(obj)
        target = os.path.join(self.artifact.uri, ARTIFACT_FILENAME)
        serialized = obj.to_json_dict()
        obj_cls = type(obj)
        # Record the concrete class so handle_input can rebuild it.
        serialized["data_type"] = f"{obj_cls.__module__}.{obj_cls.__name__}"
        yaml_utils.write_json(target, serialized)
handle_input(self, data_type)

Reads and returns a Great Expectations object.

Parameters:

Name Type Description Default
data_type Type[Any]

The type of the data to read.

required

Returns:

Type Description
SerializableDictDot

A loaded Great Expectations object.

Source code in zenml/integrations/great_expectations/materializers/ge_materializer.py
def handle_input(self, data_type: Type[Any]) -> SerializableDictDot:
    """Load a Great Expectations object from the artifact store.

    Args:
        data_type: The type of the data to read.

    Returns:
        A loaded Great Expectations object, instantiated from the class
        recorded when the object was written.
    """
    super().handle_input(data_type)
    artifact_dict = yaml_utils.read_json(
        os.path.join(self.artifact.uri, ARTIFACT_FILENAME)
    )
    # The class stored alongside the data decides what gets instantiated.
    stored_type = import_class_by_path(artifact_dict.pop("data_type"))

    if stored_type is CheckpointResult:
        self.preprocess_checkpoint_result_dict(artifact_dict)

    return stored_type(**artifact_dict)
handle_return(self, obj)

Writes a Great Expectations object.

Parameters:

Name Type Description Default
obj SerializableDictDot

A Great Expectations object.

required
Source code in zenml/integrations/great_expectations/materializers/ge_materializer.py
def handle_return(self, obj: SerializableDictDot) -> None:
    """Persist a Great Expectations object to the artifact store.

    Args:
        obj: A Great Expectations object.
    """
    super().handle_return(obj)
    target = os.path.join(self.artifact.uri, ARTIFACT_FILENAME)
    serialized = obj.to_json_dict()
    obj_cls = type(obj)
    # Record the concrete class so it can be re-imported when loading.
    serialized["data_type"] = f"{obj_cls.__module__}.{obj_cls.__name__}"
    yaml_utils.write_json(target, serialized)
preprocess_checkpoint_result_dict(artifact_dict) staticmethod

Pre-processes a GE checkpoint dict before it is used to de-serialize a GE CheckpointResult object.

The GE CheckpointResult object is not fully de-serializable due to some missing code in the GE codebase. We need to compensate for this by manually converting some of the attributes to their correct data types.

Parameters:

Name Type Description Default
artifact_dict Dict[str, Any]

A dict containing the GE checkpoint result.

required
Source code in zenml/integrations/great_expectations/materializers/ge_materializer.py
@staticmethod
def preprocess_checkpoint_result_dict(
    artifact_dict: Dict[str, Any]
) -> None:
    """Pre-processes a GE checkpoint dict before it is used to de-serialize a GE CheckpointResult object.

    The GE CheckpointResult object is not fully de-serializable from its
    JSON dict, so the checkpoint config, the run result keys and the nested
    validation results are converted back to their proper GE types here.
    The dict is modified in place.

    Args:
        artifact_dict: A dict containing the GE checkpoint result.
    """

    def _restore_run_result(name: str, payload: Any) -> Any:
        # Only validation results need to be turned back into objects.
        if name != "validation_result":
            return payload
        return ExpectationSuiteValidationResult(**payload)

    artifact_dict["checkpoint_config"] = CheckpointConfig(
        **artifact_dict["checkpoint_config"]
    )

    restored_runs = {}
    for raw_ident, run_result in artifact_dict["run_results"].items():
        # The serialized key has a "::"-separated head; the second part
        # holds the "/"-joined identifier components.
        identifier_parts = raw_ident.split("::")[1].split("/")
        identifier = ValidationResultIdentifier.from_fixed_length_tuple(
            identifier_parts
        )
        restored_runs[identifier] = {
            name: _restore_run_result(name, payload)
            for name, payload in run_result.items()
        }
    artifact_dict["run_results"] = restored_runs

steps special

Great Expectations data profiling and validation standard steps.

ge_profiler

Great Expectations data profiling standard step.

GreatExpectationsProfilerConfig (BaseStepConfig) pydantic-model

Config class for a Great Expectations profiler step.

Attributes:

Name Type Description
expectation_suite_name str

The name of the expectation suite to create or update.

data_asset_name Optional[str]

The name of the data asset to run the expectation suite on.

profiler_kwargs Optional[Dict[str, Any]]

A dictionary of keyword arguments to pass to the profiler.

overwrite_existing_suite bool

Whether to overwrite an existing expectation suite.

Source code in zenml/integrations/great_expectations/steps/ge_profiler.py
class GreatExpectationsProfilerConfig(BaseStepConfig):
    """Config class for a Great Expectations profiler step.

    The profiler step forwards these values unchanged to the active Great
    Expectations data validator's ``data_profiling`` method.

    Attributes:
        expectation_suite_name: The name of the expectation suite to create
            or update.
        data_asset_name: The name of the data asset to run the expectation
            suite on. Defaults to ``None``.
        profiler_kwargs: A dictionary of keyword arguments to pass to the
            profiler. Defaults to a new empty dict per config instance.
        overwrite_existing_suite: Whether to overwrite an existing
            expectation suite. Defaults to ``True``.
    """

    # Required field: no default is declared.
    expectation_suite_name: str
    data_asset_name: Optional[str] = None
    # default_factory avoids sharing one mutable dict between instances.
    profiler_kwargs: Optional[Dict[str, Any]] = Field(default_factory=dict)
    overwrite_existing_suite: bool = True
GreatExpectationsProfilerStep (BaseStep)

Standard Great Expectations profiling step implementation.

Use this standard Great Expectations profiling step to build an Expectation Suite automatically by running a UserConfigurableProfiler on an input dataset as covered in the official GE documentation.

Source code in zenml/integrations/great_expectations/steps/ge_profiler.py
class GreatExpectationsProfilerStep(BaseStep):
    """Standard Great Expectations profiling step implementation.

    Use this standard Great Expectations profiling step to build an Expectation
    Suite automatically by running a UserConfigurableProfiler on an input
    dataset [as covered in the official GE documentation](https://docs.greatexpectations.io/docs/guides/expectations/how_to_create_and_edit_expectations_with_a_profiler).
    """

    def entrypoint(  # type: ignore[override]
        self,
        dataset: pd.DataFrame,
        config: GreatExpectationsProfilerConfig,
    ) -> ExpectationSuite:
        """Profile a dataset and return the inferred expectation suite.

        All work is delegated to the active Great Expectations data
        validator stack component.

        Args:
            dataset: The dataset from which the expectation suite will be inferred.
            config: The configuration for the step.

        Returns:
            The generated Great Expectations suite.
        """
        validator = GreatExpectationsDataValidator.get_active_data_validator()
        suite = validator.data_profiling(
            dataset,
            expectation_suite_name=config.expectation_suite_name,
            data_asset_name=config.data_asset_name,
            profiler_kwargs=config.profiler_kwargs,
            overwrite_existing_suite=config.overwrite_existing_suite,
        )
        return suite
CONFIG_CLASS (BaseStepConfig) pydantic-model

Config class for a Great Expectations profiler step.

Attributes:

Name Type Description
expectation_suite_name str

The name of the expectation suite to create or update.

data_asset_name Optional[str]

The name of the data asset to run the expectation suite on.

profiler_kwargs Optional[Dict[str, Any]]

A dictionary of keyword arguments to pass to the profiler.

overwrite_existing_suite bool

Whether to overwrite an existing expectation suite.

Source code in zenml/integrations/great_expectations/steps/ge_profiler.py
class GreatExpectationsProfilerConfig(BaseStepConfig):
    """Config class for a Great Expectations profiler step.

    All fields are forwarded unchanged to the data validator's
    ``data_profiling`` call by the profiler step.

    Attributes:
        expectation_suite_name: The name of the expectation suite to create
            or update.
        data_asset_name: The name of the data asset to run the expectation suite on.
        profiler_kwargs: A dictionary of keyword arguments to pass to the profiler.
        overwrite_existing_suite: Whether to overwrite an existing expectation suite.
    """

    # No default value, so this field must always be provided.
    expectation_suite_name: str
    # None is passed through as-is; how a missing asset name is resolved is up
    # to the data validator (NOTE(review): not visible here — confirm).
    data_asset_name: Optional[str] = None
    # default_factory gives each config instance its own fresh empty dict.
    profiler_kwargs: Optional[Dict[str, Any]] = Field(default_factory=dict)
    # By default an existing suite with the same name is overwritten.
    overwrite_existing_suite: bool = True
entrypoint(self, dataset, config)

Standard Great Expectations data profiling step entrypoint.

Parameters:

Name Type Description Default
dataset DataFrame

The dataset from which the expectation suite will be inferred.

required
config GreatExpectationsProfilerConfig

The configuration for the step.

required

Returns:

Type Description
ExpectationSuite

The generated Great Expectations suite.

Source code in zenml/integrations/great_expectations/steps/ge_profiler.py
def entrypoint(  # type: ignore[override]
    self,
    dataset: pd.DataFrame,
    config: GreatExpectationsProfilerConfig,
) -> ExpectationSuite:
    """Profile a dataset and return the inferred Expectation Suite.

    Args:
        dataset: The dataset from which the expectation suite will be inferred.
        config: The configuration for the step.

    Returns:
        The generated Great Expectations suite.
    """
    validator = GreatExpectationsDataValidator.get_active_data_validator()

    # Every profiling option comes straight from the step configuration.
    profiling_options = dict(
        expectation_suite_name=config.expectation_suite_name,
        data_asset_name=config.data_asset_name,
        profiler_kwargs=config.profiler_kwargs,
        overwrite_existing_suite=config.overwrite_existing_suite,
    )
    suite = validator.data_profiling(dataset, **profiling_options)
    return suite
great_expectations_profiler_step(step_name, config)

Shortcut function to create a new instance of the GreatExpectationsProfilerStep step.

The returned GreatExpectationsProfilerStep can be used in a pipeline to infer data validation rules from an input pd.DataFrame dataset and return them as an Expectation Suite. The Expectation Suite is also persisted in the Great Expectations expectation store.

Parameters:

Name Type Description Default
step_name str

The name of the step

required
config GreatExpectationsProfilerConfig

The configuration for the step

required

Returns:

Type Description
BaseStep

a GreatExpectationsProfilerStep step instance

Source code in zenml/integrations/great_expectations/steps/ge_profiler.py
def great_expectations_profiler_step(
    step_name: str,
    config: GreatExpectationsProfilerConfig,
) -> BaseStep:
    """Shortcut function to create a new instance of the GreatExpectationsProfilerStep step.

    The returned step can be used in a pipeline to infer data validation
    rules from an input pd.DataFrame dataset and return them as an
    Expectation Suite. The Expectation Suite is also persisted in the Great
    Expectations expectation store.

    Args:
        step_name: The name of the step
        config: The configuration for the step

    Returns:
        a GreatExpectationsProfilerStep step instance
    """
    # clone_step presumably produces a copy of the step class under
    # `step_name`; the copy is then instantiated with the given config.
    named_step_class = clone_step(GreatExpectationsProfilerStep, step_name)
    return named_step_class(config=config)
ge_validator

Great Expectations data validation standard step.

GreatExpectationsValidatorConfig (BaseStepConfig) pydantic-model

Config class for a Great Expectations checkpoint step.

Attributes:

Name Type Description
expectation_suite_name str

The name of the expectation suite to use to validate the dataset.

data_asset_name Optional[str]

The name of the data asset to use to identify the dataset in the Great Expectations docs.

action_list Optional[List[Dict[str, Any]]]

A list of additional Great Expectations actions to run after the validation check.

exit_on_error bool

Set this flag to raise an error and exit the pipeline early if the validation fails.

Source code in zenml/integrations/great_expectations/steps/ge_validator.py
class GreatExpectationsValidatorConfig(BaseStepConfig):
    """Config class for a Great Expectations checkpoint step.

    Attributes:
        expectation_suite_name: The name of the expectation suite to use to
            validate the dataset.
        data_asset_name: The name of the data asset to use to identify the
            dataset in the Great Expectations docs.
        action_list: A list of additional Great Expectations actions to run
            after the validation check.
        exit_on_error: Set this flag to raise an error and exit the pipeline
            early if the validation fails.
    """

    # No default value, so this field must always be provided.
    expectation_suite_name: str
    # Optional label for the dataset in the Great Expectations data docs.
    data_asset_name: Optional[str] = None
    # None means no additional actions are requested by this step.
    action_list: Optional[List[Dict[str, Any]]] = None
    # Off by default: a failed validation is returned rather than raised.
    exit_on_error: bool = False
GreatExpectationsValidatorStep (BaseStep)

Standard Great Expectations data validation step implementation.

Use this standard Great Expectations data validation step to run an existing Expectation Suite on an input dataset as covered in the official GE documentation.

Source code in zenml/integrations/great_expectations/steps/ge_validator.py
class GreatExpectationsValidatorStep(BaseStep):
    """Standard Great Expectations data validation step implementation.

    Use this standard Great Expectations data validation step to run an
    existing Expectation Suite on an input dataset [as covered in the official GE documentation](https://docs.greatexpectations.io/docs/guides/validation/how_to_validate_data_by_running_a_checkpoint).
    """

    def entrypoint(  # type: ignore[override]
        self,
        dataset: pd.DataFrame,
        condition: bool,
        config: GreatExpectationsValidatorConfig,
    ) -> CheckpointResult:
        """Standard Great Expectations data validation step entrypoint.

        Args:
            dataset: The dataset to run the expectation suite on.
            condition: This dummy argument can be used as a condition to enforce
                that this step is only run after another step has completed. This
                is useful for example if the Expectation Suite used to validate
                the data is computed in a `GreatExpectationsProfilerStep` that
                is part of the same pipeline. Its value is not used.
            config: The configuration for the step.

        Returns:
            The Great Expectations validation (checkpoint) result.

        Raises:
            RuntimeError: if the step is configured to exit on error and the
                data validation failed.
        """
        data_validator = (
            GreatExpectationsDataValidator.get_active_data_validator()
        )

        results = data_validator.data_validation(
            dataset,
            expectation_suite_name=config.expectation_suite_name,
            data_asset_name=config.data_asset_name,
            action_list=config.action_list,
        )

        # `CheckpointResult.success` is a read-only bool property, not a
        # method: calling it would raise `TypeError: 'bool' object is not
        # callable` precisely when `exit_on_error` is enabled.
        if config.exit_on_error and not results.success:
            raise RuntimeError(
                "The Great Expectations validation failed. Check "
                "the logs or the Great Expectations data docs for more "
                "information."
            )

        return results
CONFIG_CLASS (BaseStepConfig) pydantic-model

Config class for a Great Expectations checkpoint step.

Attributes:

Name Type Description
expectation_suite_name str

The name of the expectation suite to use to validate the dataset.

data_asset_name Optional[str]

The name of the data asset to use to identify the dataset in the Great Expectations docs.

action_list Optional[List[Dict[str, Any]]]

A list of additional Great Expectations actions to run after the validation check.

exit_on_error bool

Set this flag to raise an error and exit the pipeline early if the validation fails.

Source code in zenml/integrations/great_expectations/steps/ge_validator.py
class GreatExpectationsValidatorConfig(BaseStepConfig):
    """Config class for a Great Expectations checkpoint step.

    Attributes:
        expectation_suite_name: The name of the expectation suite to use to
            validate the dataset.
        data_asset_name: The name of the data asset to use to identify the
            dataset in the Great Expectations docs.
        action_list: A list of additional Great Expectations actions to run
            after the validation check.
        exit_on_error: Set this flag to raise an error and exit the pipeline
            early if the validation fails.
    """

    # No default value, so this field must always be provided.
    expectation_suite_name: str
    # Optional label for the dataset in the Great Expectations data docs.
    data_asset_name: Optional[str] = None
    # None means no additional actions are requested by this step.
    action_list: Optional[List[Dict[str, Any]]] = None
    # Off by default: a failed validation is returned rather than raised.
    exit_on_error: bool = False
entrypoint(self, dataset, condition, config)

Standard Great Expectations data validation step entrypoint.

Parameters:

Name Type Description Default
dataset DataFrame

The dataset to run the expectation suite on.

required
condition bool

This dummy argument can be used as a condition to enforce that this step is only run after another step has completed. This is useful for example if the Expectation Suite used to validate the data is computed in a GreatExpectationsProfilerStep that is part of the same pipeline.

required
config GreatExpectationsValidatorConfig

The configuration for the step.

required

Returns:

Type Description
CheckpointResult

The Great Expectations validation (checkpoint) result.

Exceptions:

Type Description
RuntimeError

if the step is configured to exit on error and the data validation failed.

Source code in zenml/integrations/great_expectations/steps/ge_validator.py
def entrypoint(  # type: ignore[override]
    self,
    dataset: pd.DataFrame,
    condition: bool,
    config: GreatExpectationsValidatorConfig,
) -> CheckpointResult:
    """Standard Great Expectations data validation step entrypoint.

    Args:
        dataset: The dataset to run the expectation suite on.
        condition: This dummy argument can be used as a condition to enforce
            that this step is only run after another step has completed. This
            is useful for example if the Expectation Suite used to validate
            the data is computed in a `GreatExpectationsProfilerStep` that
            is part of the same pipeline. Its value is not used.
        config: The configuration for the step.

    Returns:
        The Great Expectations validation (checkpoint) result.

    Raises:
        RuntimeError: if the step is configured to exit on error and the
            data validation failed.
    """
    data_validator = (
        GreatExpectationsDataValidator.get_active_data_validator()
    )

    results = data_validator.data_validation(
        dataset,
        expectation_suite_name=config.expectation_suite_name,
        data_asset_name=config.data_asset_name,
        action_list=config.action_list,
    )

    # `CheckpointResult.success` is a read-only bool property, not a method:
    # calling it would raise `TypeError: 'bool' object is not callable`
    # precisely when `exit_on_error` is enabled.
    if config.exit_on_error and not results.success:
        raise RuntimeError(
            "The Great Expectations validation failed. Check "
            "the logs or the Great Expectations data docs for more "
            "information."
        )

    return results
great_expectations_validator_step(step_name, config)

Shortcut function to create a new instance of the GreatExpectationsValidatorStep step.

The returned GreatExpectationsValidatorStep can be used in a pipeline to validate an input pd.DataFrame dataset and return the result as a Great Expectations CheckpointResult object. The validation results are also persisted in the Great Expectations validation store.

Parameters:

Name Type Description Default
step_name str

The name of the step

required
config GreatExpectationsValidatorConfig

The configuration for the step

required

Returns:

Type Description
BaseStep

a GreatExpectationsValidatorStep step instance

Source code in zenml/integrations/great_expectations/steps/ge_validator.py
def great_expectations_validator_step(
    step_name: str,
    config: GreatExpectationsValidatorConfig,
) -> BaseStep:
    """Shortcut function to create a new instance of the GreatExpectationsValidatorStep step.

    The returned GreatExpectationsValidatorStep can be used in a pipeline to
    validate an input pd.DataFrame dataset and return the result as a Great
    Expectations CheckpointResult object. The validation results are also
    persisted in the Great Expectations validation store.

    Args:
        step_name: The name of the step
        config: The configuration for the step

    Returns:
        a GreatExpectationsValidatorStep step instance
    """
    return clone_step(GreatExpectationsValidatorStep, step_name)(config=config)

utils

Great Expectations integration utilities.

create_batch_request(context, dataset, data_asset_name)

Create a temporary runtime GE batch request from a dataset step artifact.

Parameters:

Name Type Description Default
context BaseDataContext

Great Expectations data context.

required
dataset DataFrame

Input dataset.

required
data_asset_name Optional[str]

Optional custom name for the data asset.

required

Returns:

Type Description
RuntimeBatchRequest

A Great Expectations runtime batch request.

Source code in zenml/integrations/great_expectations/utils.py
def create_batch_request(
    context: BaseDataContext,
    dataset: pd.DataFrame,
    data_asset_name: Optional[str],
) -> RuntimeBatchRequest:
    """Create a temporary runtime GE batch request from a dataset step artifact.

    A new datasource is registered on ``context``. It uses the pandas
    execution engine and a single ``RuntimeDataConnector``. The returned
    ``RuntimeBatchRequest`` passes ``dataset`` in memory through that
    connector.

    When called inside a ZenML step, the datasource is named after the
    pipeline run and step, and the data asset name defaults to
    ``<pipeline_name>_<step_name>``. Outside a step, random placeholder names
    are used instead.

    Args:
        context: Great Expectations data context.
        dataset: Input dataset.
        data_asset_name: Optional custom name for the data asset.

    Returns:
        A Great Expectations runtime batch request.
    """
    try:
        # get pipeline name, step name and run id
        step_env = cast(StepEnvironment, Environment()[STEP_ENVIRONMENT_NAME])
        pipeline_name = step_env.pipeline_name
        run_id = step_env.pipeline_run_id
        step_name = step_env.step_name
    except KeyError:
        # Not running inside a pipeline step: use random placeholder values.
        # Each placeholder is prefixed with what it stands for, so the
        # generated datasource name reads as "<run>_<step>".
        pipeline_name = f"pipeline_{random_str(5)}"
        run_id = f"run_{random_str(5)}"
        step_name = f"step_{random_str(5)}"

    datasource_name = f"{run_id}_{step_name}"
    data_connector_name = datasource_name
    data_asset_name = data_asset_name or f"{pipeline_name}_{step_name}"
    # A single fixed batch identifier: each request carries exactly one batch.
    batch_identifier = "default"

    datasource_config = {
        "name": datasource_name,
        "class_name": "Datasource",
        "module_name": "great_expectations.datasource",
        "execution_engine": {
            "module_name": "great_expectations.execution_engine",
            "class_name": "PandasExecutionEngine",
        },
        "data_connectors": {
            data_connector_name: {
                "class_name": "RuntimeDataConnector",
                "batch_identifiers": [batch_identifier],
            },
        },
    }

    context.add_datasource(**datasource_config)
    batch_request = RuntimeBatchRequest(
        datasource_name=datasource_name,
        data_connector_name=data_connector_name,
        data_asset_name=data_asset_name,
        runtime_parameters={"batch_data": dataset},
        batch_identifiers={batch_identifier: batch_identifier},
    )

    return batch_request

visualizers special

Great Expectations visualizers for expectation suites and validation results.

ge_visualizer

Great Expectations visualizers for expectation suites and validation results.

GreatExpectationsVisualizer (BaseStepVisualizer)

The implementation of a Great Expectations Visualizer.

Source code in zenml/integrations/great_expectations/visualizers/ge_visualizer.py
class GreatExpectationsVisualizer(BaseStepVisualizer):
    """Visualizer that opens Great Expectations data docs for step outputs."""

    def visualize(self, object: StepView, *args: Any, **kwargs: Any) -> None:
        """Open the Great Expectations data docs for a step's GE artifacts.

        For each step output that is a ``DataAnalysisArtifact`` with a
        ``great_expectations.*`` data type, the artifact is read:

        - a checkpoint result opens the docs of its first run result key;
        - anything else is treated as an expectation suite and opens the
          docs page for that suite.

        Args:
            object: StepView fetched from run.get_step().
            *args: Additional arguments (unused).
            **kwargs: Additional keyword arguments (unused).
        """
        for output in object.outputs.values():
            is_ge_artifact = (
                output.type == DataAnalysisArtifact.__name__
                and output.data_type.startswith("great_expectations.")
            )
            if not is_ge_artifact:
                continue

            resource = output.read()
            if isinstance(resource, CheckpointResult):
                identifier = next(iter(resource.run_results.keys()))
            else:
                suite = cast(ExpectationSuite, resource)
                identifier = ExpectationSuiteIdentifier(
                    suite.expectation_suite_name
                )

            data_context = GreatExpectationsDataValidator.get_data_context()
            data_context.open_data_docs(identifier)
visualize(self, object, *args, **kwargs)

Method to visualize a Great Expectations resource.

Parameters:

Name Type Description Default
object StepView

StepView fetched from run.get_step().

required
*args Any

Additional arguments.

()
**kwargs Any

Additional keyword arguments.

{}
Source code in zenml/integrations/great_expectations/visualizers/ge_visualizer.py
def visualize(self, object: StepView, *args: Any, **kwargs: Any) -> None:
    """Open the Great Expectations data docs for a step's GE artifacts.

    For each step output that is a ``DataAnalysisArtifact`` with a
    ``great_expectations.*`` data type, the artifact is read:

    - a checkpoint result opens the docs of its first run result key;
    - anything else is treated as an expectation suite and opens the docs
      page for that suite.

    Args:
        object: StepView fetched from run.get_step().
        *args: Additional arguments (unused).
        **kwargs: Additional keyword arguments (unused).
    """
    for output in object.outputs.values():
        is_ge_artifact = (
            output.type == DataAnalysisArtifact.__name__
            and output.data_type.startswith("great_expectations.")
        )
        if not is_ge_artifact:
            continue

        resource = output.read()
        if isinstance(resource, CheckpointResult):
            identifier = next(iter(resource.run_results.keys()))
        else:
            suite = cast(ExpectationSuite, resource)
            identifier = ExpectationSuiteIdentifier(
                suite.expectation_suite_name
            )

        data_context = GreatExpectationsDataValidator.get_data_context()
        data_context.open_data_docs(identifier)

huggingface special

Initialization of the Huggingface integration.

HuggingfaceIntegration (Integration)

Definition of Huggingface integration for ZenML.

Source code in zenml/integrations/huggingface/__init__.py
class HuggingfaceIntegration(Integration):
    """Definition of the Hugging Face integration for ZenML.

    Requires the ``transformers`` and ``datasets`` packages.
    """

    NAME = HUGGINGFACE
    REQUIREMENTS = ["transformers", "datasets"]

    @classmethod
    def activate(cls) -> None:
        """Activate the integration by importing its materializers module.

        The import is done only for its import-time side effects; the bound
        name is never used (presumably the materializers register themselves
        on import — the module body is not shown here).
        """
        import zenml.integrations.huggingface.materializers  # noqa
activate() classmethod

Activates the integration.

Source code in zenml/integrations/huggingface/__init__.py
@classmethod
def activate(cls) -> None:
    """Activate the integration by importing its materializers module.

    The import is done only for its import-time side effects; the bound name
    is never used (presumably the materializers register themselves on
    import — the module body is not shown here).
    """
    import zenml.integrations.huggingface.materializers  # noqa

materializers special

Initialization of Huggingface materializers.

huggingface_datasets_materializer

Implementation of the Huggingface datasets materializer.

HFDatasetMaterializer (BaseMaterializer)

Materializer to read and write Hugging Face datasets.

Source code in zenml/integrations/huggingface/materializers/huggingface_datasets_materializer.py
class HFDatasetMaterializer(BaseMaterializer):
    """Materializer to read and write Hugging Face datasets."""

    ASSOCIATED_TYPES = (Dataset, DatasetDict)
    ASSOCIATED_ARTIFACT_TYPES = (DataArtifact,)

    def handle_input(self, data_type: Type[Any]) -> Dataset:
        """Reads Dataset.

        Args:
            data_type: The type of the dataset to read.

        Returns:
            The dataset read from the ``DEFAULT_DATASET_DIR`` subdirectory of
            the artifact URI.
        """
        super().handle_input(data_type)
        dataset_dir = os.path.join(self.artifact.uri, DEFAULT_DATASET_DIR)
        return load_from_disk(dataset_dir)

    def handle_return(self, ds: Any) -> None:
        """Writes a Dataset to the artifact directory.

        Args:
            ds: The Dataset (or DatasetDict) to write.
        """
        super().handle_return(ds)
        # Save into a local temporary directory, then copy it into the
        # artifact URI with io_utils.copy_dir (NOTE(review): presumably
        # because the URI may be a remote store save_to_disk can't target).
        # The context manager deletes the temp dir deterministically, even
        # on error, instead of leaving it to garbage collection.
        with TemporaryDirectory() as temp_dir:
            ds.save_to_disk(temp_dir)
            io_utils.copy_dir(
                temp_dir, os.path.join(self.artifact.uri, DEFAULT_DATASET_DIR)
            )
handle_input(self, data_type)

Reads Dataset.

Parameters:

Name Type Description Default
data_type Type[Any]

The type of the dataset to read.

required

Returns:

Type Description
Dataset

The dataset read from the specified dir.

Source code in zenml/integrations/huggingface/materializers/huggingface_datasets_materializer.py
def handle_input(self, data_type: Type[Any]) -> Dataset:
    """Load the Hugging Face dataset stored in this artifact.

    Args:
        data_type: The type of the dataset to read.

    Returns:
        The dataset loaded from the ``DEFAULT_DATASET_DIR`` subdirectory of
        the artifact URI.
    """
    super().handle_input(data_type)
    dataset_dir = os.path.join(self.artifact.uri, DEFAULT_DATASET_DIR)
    return load_from_disk(dataset_dir)
handle_return(self, ds)

Writes a Dataset to the specified dir.

Parameters:

Name Type Description Default
ds Type[Any]

The Dataset to write.

required
Source code in zenml/integrations/huggingface/materializers/huggingface_datasets_materializer.py
def handle_return(self, ds: Any) -> None:
    """Writes a Dataset to the artifact directory.

    Args:
        ds: The Dataset (or DatasetDict) to write.
    """
    super().handle_return(ds)
    # Save into a local temporary directory, then copy it into the artifact
    # URI with io_utils.copy_dir (NOTE(review): presumably because the URI may
    # be a remote store save_to_disk can't target). The context manager
    # deletes the temp dir deterministically, even on error, instead of
    # leaving it to garbage collection.
    with TemporaryDirectory() as temp_dir:
        ds.save_to_disk(temp_dir)
        io_utils.copy_dir(
            temp_dir, os.path.join(self.artifact.uri, DEFAULT_DATASET_DIR)
        )
huggingface_pt_model_materializer

Implementation of the Huggingface PyTorch model materializer.

HFPTModelMaterializer (BaseMaterializer)

Materializer to read and write Hugging Face pretrained PyTorch models.

Source code in zenml/integrations/huggingface/materializers/huggingface_pt_model_materializer.py
class HFPTModelMaterializer(BaseMaterializer):
    """Materializer to read and write Hugging Face pretrained PyTorch models."""

    ASSOCIATED_TYPES = (PreTrainedModel,)
    ASSOCIATED_ARTIFACT_TYPES = (ModelArtifact,)

    def handle_input(self, data_type: Type[Any]) -> PreTrainedModel:
        """Reads HFModel.

        The stored config's first entry in ``architectures`` is looked up by
        name in the ``transformers`` package to pick the model class.

        Args:
            data_type: The type of the model to read.

        Returns:
            The model read from the specified dir.
        """
        super().handle_input(data_type)

        model_dir = os.path.join(self.artifact.uri, DEFAULT_PT_MODEL_DIR)
        config = AutoConfig.from_pretrained(model_dir)
        architecture = config.architectures[0]
        model_cls = getattr(
            importlib.import_module("transformers"), architecture
        )
        return model_cls.from_pretrained(model_dir)

    def handle_return(self, model: Any) -> None:
        """Writes a Model to the artifact directory.

        Args:
            model: The Torch Model to write.
        """
        super().handle_return(model)
        # Save into a local temporary directory, then copy it into the
        # artifact URI. The context manager deletes the temp dir
        # deterministically, even on error, instead of relying on GC.
        with TemporaryDirectory() as temp_dir:
            model.save_pretrained(temp_dir)
            io_utils.copy_dir(
                temp_dir,
                os.path.join(self.artifact.uri, DEFAULT_PT_MODEL_DIR),
            )
handle_input(self, data_type)

Reads HFModel.

Parameters:

Name Type Description Default
data_type Type[Any]

The type of the model to read.

required

Returns:

Type Description
PreTrainedModel

The model read from the specified dir.

Source code in zenml/integrations/huggingface/materializers/huggingface_pt_model_materializer.py
def handle_input(self, data_type: Type[Any]) -> PreTrainedModel:
    """Load the Hugging Face PyTorch model stored in this artifact.

    The saved config's first ``architectures`` entry names the model class,
    which is then resolved as an attribute of the ``transformers`` package.

    Args:
        data_type: The type of the model to read.

    Returns:
        The model read from the specified dir.
    """
    super().handle_input(data_type)

    model_dir = os.path.join(self.artifact.uri, DEFAULT_PT_MODEL_DIR)
    hf_config = AutoConfig.from_pretrained(model_dir)
    class_name = hf_config.architectures[0]
    transformers_module = importlib.import_module("transformers")
    model_class = getattr(transformers_module, class_name)
    return model_class.from_pretrained(model_dir)
handle_return(self, model)

Writes a Model to the specified dir.

Parameters:

Name Type Description Default
model Type[Any]

The Torch Model to write.

required
Source code in zenml/integrations/huggingface/materializers/huggingface_pt_model_materializer.py
def handle_return(self, model: Any) -> None:
    """Writes a Model to the artifact directory.

    Args:
        model: The Torch Model to write.
    """
    super().handle_return(model)
    # Save into a local temporary directory, then copy it into the artifact
    # URI. The context manager deletes the temp dir deterministically, even on
    # error, instead of relying on garbage collection.
    with TemporaryDirectory() as temp_dir:
        model.save_pretrained(temp_dir)
        io_utils.copy_dir(
            temp_dir,
            os.path.join(self.artifact.uri, DEFAULT_PT_MODEL_DIR),
        )
huggingface_tf_model_materializer

Implementation of the Huggingface TF model materializer.

HFTFModelMaterializer (BaseMaterializer)

Materializer to read and write Hugging Face pretrained TensorFlow models.

Source code in zenml/integrations/huggingface/materializers/huggingface_tf_model_materializer.py
class HFTFModelMaterializer(BaseMaterializer):
    """Materializer to read and write Hugging Face pretrained TensorFlow models."""

    ASSOCIATED_TYPES = (TFPreTrainedModel,)
    ASSOCIATED_ARTIFACT_TYPES = (ModelArtifact,)

    def handle_input(self, data_type: Type[Any]) -> TFPreTrainedModel:
        """Reads HFModel.

        The stored config's first ``architectures`` entry, prefixed with
        ``"TF"``, is looked up by name in the ``transformers`` package to pick
        the TensorFlow model class.

        Args:
            data_type: The type of the model to read.

        Returns:
            The model read from the specified dir.
        """
        super().handle_input(data_type)

        model_dir = os.path.join(self.artifact.uri, DEFAULT_TF_MODEL_DIR)
        config = AutoConfig.from_pretrained(model_dir)
        architecture = "TF" + config.architectures[0]
        model_cls = getattr(
            importlib.import_module("transformers"), architecture
        )
        return model_cls.from_pretrained(model_dir)

    def handle_return(self, model: Any) -> None:
        """Writes a Model to the artifact directory.

        Args:
            model: The TF Model to write.
        """
        super().handle_return(model)
        # Save into a local temporary directory, then copy it into the
        # artifact URI. The context manager deletes the temp dir
        # deterministically, even on error, instead of relying on GC.
        with TemporaryDirectory() as temp_dir:
            model.save_pretrained(temp_dir)
            io_utils.copy_dir(
                temp_dir,
                os.path.join(self.artifact.uri, DEFAULT_TF_MODEL_DIR),
            )
handle_input(self, data_type)

Reads HFModel.

Parameters:

Name Type Description Default
data_type Type[Any]

The type of the model to read.

required

Returns:

Type Description
TFPreTrainedModel

The model read from the specified dir.

Source code in zenml/integrations/huggingface/materializers/huggingface_tf_model_materializer.py
def handle_input(self, data_type: Type[Any]) -> TFPreTrainedModel:
    """Load the Hugging Face TensorFlow model stored in this artifact.

    The saved config's first ``architectures`` entry, prefixed with ``"TF"``,
    names the model class, which is resolved in the ``transformers`` package.

    Args:
        data_type: The type of the model to read.

    Returns:
        The model read from the specified dir.
    """
    super().handle_input(data_type)

    model_dir = os.path.join(self.artifact.uri, DEFAULT_TF_MODEL_DIR)
    hf_config = AutoConfig.from_pretrained(model_dir)
    class_name = "TF" + hf_config.architectures[0]
    transformers_module = importlib.import_module("transformers")
    model_class = getattr(transformers_module, class_name)
    return model_class.from_pretrained(model_dir)
handle_return(self, model)

Writes a Model to the specified dir.

Parameters:

Name Type Description Default
model Type[Any]

The TF Model to write.

required
Source code in zenml/integrations/huggingface/materializers/huggingface_tf_model_materializer.py
def handle_return(self, model: Any) -> None:
    """Writes a Model to the artifact directory.

    Args:
        model: The TF Model to write.
    """
    super().handle_return(model)
    # Save into a local temporary directory, then copy it into the artifact
    # URI. The context manager deletes the temp dir deterministically, even on
    # error, instead of relying on garbage collection.
    with TemporaryDirectory() as temp_dir:
        model.save_pretrained(temp_dir)
        io_utils.copy_dir(
            temp_dir,
            os.path.join(self.artifact.uri, DEFAULT_TF_MODEL_DIR),
        )
huggingface_tokenizer_materializer

Implementation of the Huggingface tokenizer materializer.

HFTokenizerMaterializer (BaseMaterializer)

Materializer to read and write Hugging Face tokenizers.

Source code in zenml/integrations/huggingface/materializers/huggingface_tokenizer_materializer.py
class HFTokenizerMaterializer(BaseMaterializer):
    """Materializer to read and write Hugging Face tokenizers."""

    ASSOCIATED_TYPES = (PreTrainedTokenizerBase,)
    ASSOCIATED_ARTIFACT_TYPES = (ModelArtifact,)

    def handle_input(self, data_type: Type[Any]) -> PreTrainedTokenizerBase:
        """Reads Tokenizer.

        Args:
            data_type: The type of the tokenizer to read.

        Returns:
            The tokenizer read from the specified dir.
        """
        super().handle_input(data_type)
        tokenizer_dir = os.path.join(self.artifact.uri, DEFAULT_TOKENIZER_DIR)
        return AutoTokenizer.from_pretrained(tokenizer_dir)

    def handle_return(self, tokenizer: Any) -> None:
        """Writes a Tokenizer to the artifact directory.

        Args:
            tokenizer: The HFTokenizer to write.
        """
        super().handle_return(tokenizer)
        # Save into a local temporary directory, then copy it into the
        # artifact URI. The context manager deletes the temp dir
        # deterministically, even on error, instead of relying on GC.
        with TemporaryDirectory() as temp_dir:
            tokenizer.save_pretrained(temp_dir)
            io_utils.copy_dir(
                temp_dir,
                os.path.join(self.artifact.uri, DEFAULT_TOKENIZER_DIR),
            )
handle_input(self, data_type)

Reads Tokenizer.

Parameters:

Name Type Description Default
data_type Type[Any]

The type of the tokenizer to read.

required

Returns:

Type Description
PreTrainedTokenizerBase

The tokenizer read from the specified dir.

Source code in zenml/integrations/huggingface/materializers/huggingface_tokenizer_materializer.py
def handle_input(self, data_type: Type[Any]) -> PreTrainedTokenizerBase:
    """Load the Hugging Face tokenizer stored in this artifact.

    Args:
        data_type: The type of the tokenizer to read.

    Returns:
        The tokenizer loaded from the ``DEFAULT_TOKENIZER_DIR`` subdirectory
        of the artifact URI.
    """
    super().handle_input(data_type)
    tokenizer_dir = os.path.join(self.artifact.uri, DEFAULT_TOKENIZER_DIR)
    return AutoTokenizer.from_pretrained(tokenizer_dir)
handle_return(self, tokenizer)

Writes a Tokenizer to the specified dir.

Parameters:

Name Type Description Default
tokenizer Type[Any]

The HFTokenizer to write.

required
Source code in zenml/integrations/huggingface/materializers/huggingface_tokenizer_materializer.py
def handle_return(self, tokenizer: Any) -> None:
    """Writes a Tokenizer to the artifact directory.

    Args:
        tokenizer: The HFTokenizer to write.
    """
    super().handle_return(tokenizer)
    # Save into a local temporary directory, then copy it into the artifact
    # URI. The context manager deletes the temp dir deterministically, even on
    # error, instead of relying on garbage collection.
    with TemporaryDirectory() as temp_dir:
        tokenizer.save_pretrained(temp_dir)
        io_utils.copy_dir(
            temp_dir,
            os.path.join(self.artifact.uri, DEFAULT_TOKENIZER_DIR),
        )

integration

Base and meta classes for ZenML integrations.

Integration

Base class for integration in ZenML.

Source code in zenml/integrations/integration.py
class Integration(metaclass=IntegrationMeta):
    """Base class for integration in ZenML.

    Subclasses declare a `NAME`, their Python package `REQUIREMENTS` and any
    `SYSTEM_REQUIREMENTS`, and may override `activate` and `flavors`.
    """

    NAME = "base_integration"

    # Requirement strings, each resolved with
    # `pkg_resources.get_distribution` in `check_installation`.
    REQUIREMENTS: List[str] = []

    # Maps a requirement name (used in log messages) to the command that must
    # be found on the PATH via `shutil.which`.
    SYSTEM_REQUIREMENTS: Dict[str, str] = {}

    @classmethod
    def check_installation(cls) -> bool:
        """Method to check whether the required packages are installed.

        System requirements are checked first, then Python package
        requirements.

        Returns:
            True if all required packages are installed, False otherwise.
        """
        try:
            for requirement, command in cls.SYSTEM_REQUIREMENTS.items():
                result = shutil.which(command)

                if result is None:
                    logger.debug(
                        "Unable to find the required packages for %s on your "
                        "system. Please install the packages on your system "
                        "and try again.",
                        requirement,
                    )
                    return False

            for r in cls.REQUIREMENTS:
                # Raises DistributionNotFound / VersionConflict on failure.
                pkg_resources.get_distribution(r)
            logger.debug(
                f"Integration {cls.NAME} is installed correctly with "
                f"requirements {cls.REQUIREMENTS}."
            )
            return True
        except pkg_resources.DistributionNotFound as e:
            logger.debug(
                f"Unable to find required package '{e.req}' for "
                f"integration {cls.NAME}."
            )
            return False
        except pkg_resources.VersionConflict as e:
            logger.debug(
                f"VersionConflict error when loading installation {cls.NAME}: "
                f"{str(e)}"
            )
            return False

    @classmethod
    def activate(cls) -> None:
        """Abstract method to activate the integration."""

    @classmethod
    def flavors(cls) -> List[FlavorWrapper]:
        """Abstract method to declare new stack component flavors.

        Returns:
            List of stack component flavors for this integration. The base
            implementation declares none, so an empty list is returned
            rather than `None`, keeping the annotated return type honest and
            letting callers iterate the result unconditionally.
        """
        return []
activate() classmethod

Abstract method to activate the integration.

Source code in zenml/integrations/integration.py
@classmethod
def activate(cls) -> None:
    """Hook called to activate the integration; does nothing by default.

    Concrete integrations override this, e.g. to import the modules that make
    up the integration.
    """
check_installation() classmethod

Method to check whether the required packages are installed.

Returns:

Type Description
bool

True if all required packages are installed, False otherwise.

Source code in zenml/integrations/integration.py
@classmethod
def check_installation(cls) -> bool:
    """Report whether this integration's requirements are satisfied.

    Every command in `SYSTEM_REQUIREMENTS` must resolve via `shutil.which`.
    After that, every entry in `REQUIREMENTS` must resolve via
    `pkg_resources.get_distribution`.

    Returns:
        True if all required packages are installed, False otherwise.
    """
    try:
        for requirement, command in cls.SYSTEM_REQUIREMENTS.items():
            if shutil.which(command) is not None:
                continue
            logger.debug(
                "Unable to find the required packages for %s on your "
                "system. Please install the packages on your system "
                "and try again.",
                requirement,
            )
            return False

        for package_spec in cls.REQUIREMENTS:
            pkg_resources.get_distribution(package_spec)
        logger.debug(
            f"Integration {cls.NAME} is installed correctly with "
            f"requirements {cls.REQUIREMENTS}."
        )
        return True
    except pkg_resources.DistributionNotFound as missing:
        logger.debug(
            f"Unable to find required package '{missing.req}' for "
            f"integration {cls.NAME}."
        )
        return False
    except pkg_resources.VersionConflict as conflict:
        logger.debug(
            f"VersionConflict error when loading installation {cls.NAME}: "
            f"{str(conflict)}"
        )
        return False
flavors() classmethod

Abstract method to declare new stack component flavors.

Source code in zenml/integrations/integration.py
@classmethod
def flavors(cls) -> List[FlavorWrapper]:
    """Abstract method to declare new stack component flavors.

    Returns:
        List of stack component flavors for this integration. The base
        implementation declares none and returns an empty list instead of
        `None`, so callers can iterate the result unconditionally.
    """
    return []

IntegrationMeta (type)

Metaclass responsible for registering different Integration subclasses.

Source code in zenml/integrations/integration.py
class IntegrationMeta(type):
    """Metaclass that adds every Integration subclass to the registry."""

    def __new__(
        mcs, name: str, bases: Tuple[Type[Any], ...], dct: Dict[str, Any]
    ) -> "IntegrationMeta":
        """Create the class and register it unless it is the base class.

        Args:
            name: The name of the class being created.
            bases: The base classes of the class being created.
            dct: The dictionary of attributes of the class being created.

        Returns:
            The newly created class.
        """
        new_class = super().__new__(mcs, name, bases, dct)
        integration_cls = cast(Type["Integration"], new_class)
        # The `Integration` base class itself is never registered.
        if name == "Integration":
            return integration_cls
        integration_registry.register_integration(
            integration_cls.NAME, integration_cls
        )
        return integration_cls
__new__(mcs, name, bases, dct) special staticmethod

Hook into creation of an Integration class.

Parameters:

Name Type Description Default
name str

The name of the class being created.

required
bases Tuple[Type[Any], ...]

The base classes of the class being created.

required
dct Dict[str, Any]

The dictionary of attributes of the class being created.

required

Returns:

Type Description
IntegrationMeta

The newly created class.

Source code in zenml/integrations/integration.py
def __new__(
    mcs, name: str, bases: Tuple[Type[Any], ...], dct: Dict[str, Any]
) -> "IntegrationMeta":
    """Create the class and register it unless it is the base class.

    Args:
        name: The name of the class being created.
        bases: The base classes of the class being created.
        dct: The dictionary of attributes of the class being created.

    Returns:
        The newly created class.
    """
    created = cast(Type["Integration"], super().__new__(mcs, name, bases, dct))
    # Skip the `Integration` base class; every other class is registered.
    if name == "Integration":
        return created
    integration_registry.register_integration(created.NAME, created)
    return created

kserve special

Initialization of the KServe integration for ZenML.

The KServe integration allows you to use the KServe model serving platform to implement continuous model deployment.

KServeIntegration (Integration)

Definition of KServe integration for ZenML.

Source code in zenml/integrations/kserve/__init__.py
class KServeIntegration(Integration):
    """Definition of KServe integration for ZenML."""

    NAME = KSERVE
    REQUIREMENTS = [
        "kserve==0.9.0",
        "torch-model-archiver",
    ]

    @classmethod
    def activate(cls) -> None:
        """Activate the KServe integration.

        Imports the KServe model deployer, secret schema, service and step
        modules so they are loaded when the integration is activated. The
        imports are kept only for their side effects, hence `# noqa`.
        """
        from zenml.integrations.kserve import model_deployers  # noqa
        from zenml.integrations.kserve import secret_schemas  # noqa
        from zenml.integrations.kserve import services  # noqa
        from zenml.integrations.kserve import steps  # noqa

    @classmethod
    def flavors(cls) -> List[FlavorWrapper]:
        """Declare the stack component flavors for KServe.

        Returns:
            List of stack component flavors for this integration.
        """
        return [
            FlavorWrapper(
                name=KSERVE_MODEL_DEPLOYER_FLAVOR,
                source="zenml.integrations.kserve.model_deployers.KServeModelDeployer",
                type=StackComponentType.MODEL_DEPLOYER,
                integration=cls.NAME,
            )
        ]
activate() classmethod

Activate the KServe integration.

Source code in zenml/integrations/kserve/__init__.py
@classmethod
def activate(cls) -> None:
    """Activate the KServe integration.

    Imports the KServe model deployer, secret schema, service and step
    modules so they are loaded when the integration is activated. The
    imports are kept only for their side effects, hence `# noqa`.
    """
    from zenml.integrations.kserve import model_deployers  # noqa
    from zenml.integrations.kserve import secret_schemas  # noqa
    from zenml.integrations.kserve import services  # noqa
    from zenml.integrations.kserve import steps  # noqa
flavors() classmethod

Declare the stack component flavors for KServe.

Returns:

Type Description
List[zenml.zen_stores.models.flavor_wrapper.FlavorWrapper]

List of stack component flavors for this integration.

Source code in zenml/integrations/kserve/__init__.py
@classmethod
def flavors(cls) -> List[FlavorWrapper]:
    """Declare the stack component flavors for KServe.

    Returns:
        List of stack component flavors for this integration.
    """
    return [
        FlavorWrapper(
            name=KSERVE_MODEL_DEPLOYER_FLAVOR,
            source="zenml.integrations.kserve.model_deployers.KServeModelDeployer",
            type=StackComponentType.MODEL_DEPLOYER,
            integration=cls.NAME,
        )
    ]

model_deployers special

Initialization of the KServe Model Deployer.

kserve_model_deployer

Implementation of the KServe Model Deployer.

KServeModelDeployer (BaseModelDeployer) pydantic-model

KServe model deployer stack component implementation.

Attributes:

Name Type Description
kubernetes_context Optional[str]

the Kubernetes context to use to contact the remote KServe installation. If not specified, the current configuration is used. Depending on where the KServe model deployer is being used, this can be either a locally active context or an in-cluster Kubernetes configuration (if running inside a pod).

kubernetes_namespace Optional[str]

the Kubernetes namespace where the KServe inference service CRDs are provisioned and managed by ZenML. If not specified, the namespace set in the current configuration is used. Depending on where the KServe model deployer is being used, this can be either the current namespace configured in the locally active context or the namespace in the context of which the pod is running (if running inside a pod).

base_url str

the base URL of the Kubernetes ingress used to expose the KServe inference services.

secret Optional[str]

the name of the secret containing the credentials for the KServe inference services.

Source code in zenml/integrations/kserve/model_deployers/kserve_model_deployer.py
class KServeModelDeployer(BaseModelDeployer):
    """KServe model deployer stack component implementation.

    Attributes:
        kubernetes_context: the Kubernetes context to use to contact the remote
            KServe installation. If not specified, the current
            configuration is used. Depending on where the KServe model deployer
            is being used, this can be either a locally active context or an
            in-cluster Kubernetes configuration (if running inside a pod).
        kubernetes_namespace: the Kubernetes namespace where the KServe
            inference service CRDs are provisioned and managed by ZenML. If not
            specified, the namespace set in the current configuration is used.
            Depending on where the KServe model deployer is being used, this can
            be either the current namespace configured in the locally active
            context or the namespace in the context of which the pod is running
            (if running inside a pod).
        base_url: the base URL of the Kubernetes ingress used to expose the
            KServe inference services.
        secret: the name of the secret containing the credentials for the
            KServe inference services.
        custom_domain: an optional custom domain setting. It is not read by
            the methods of this class; presumably it is consumed by the
            KServe deployment service (TODO confirm).
    """

    # Class Configuration
    FLAVOR: ClassVar[str] = KSERVE_MODEL_DEPLOYER_FLAVOR

    kubernetes_context: Optional[str]
    kubernetes_namespace: Optional[str]
    base_url: str
    secret: Optional[str]
    custom_domain: Optional[str]

    # private attributes
    _client: Optional[KServeClient] = None

    @staticmethod
    def get_model_server_info(  # type: ignore[override]
        service_instance: "KServeDeploymentService",
    ) -> Dict[str, Optional[str]]:
        """Return implementation specific information on the model server.

        Args:
            service_instance: KServe deployment service object

        Returns:
            A dictionary containing the model server information.
        """
        return {
            "PREDICTION_URL": service_instance.prediction_url,
            "PREDICTION_HOSTNAME": service_instance.prediction_hostname,
            "MODEL_URI": service_instance.config.model_uri,
            "MODEL_NAME": service_instance.config.model_name,
            "KSERVE_INFERENCE_SERVICE": service_instance.crd_name,
        }

    @staticmethod
    def get_active_model_deployer() -> "KServeModelDeployer":
        """Get the KServe model deployer registered in the active stack.

        Returns:
            The KServe model deployer registered in the active stack.

        Raises:
            TypeError: if the KServe model deployer is not available.
        """
        model_deployer = Repository(  # type: ignore [call-arg]
            skip_repository_check=True
        ).active_stack.model_deployer
        if not model_deployer or not isinstance(
            model_deployer, KServeModelDeployer
        ):
            raise TypeError(
                f"The active stack needs to have a KServe model deployer "
                f"component registered to be able to deploy models with KServe. "
                f"You can create a new stack with a KServe model "
                f"deployer component or update your existing stack to add this "
                f"component, e.g.:\n\n"
                f"  'zenml model-deployer register kserve --flavor={KSERVE_MODEL_DEPLOYER_FLAVOR} "
                f"--kubernetes_context=context-name --kubernetes_namespace="
                f"namespace-name --base_url=https://ingress.cluster.kubernetes'\n"
                f"  'zenml stack create stack-name -d kserve ...'\n"
            )
        return model_deployer

    @property
    def kserve_client(self) -> KServeClient:
        """Get the KServe client associated with this model deployer.

        The client is created lazily on first access and cached afterwards.

        Returns:
            The KServe client.
        """
        if self._client is None:
            self._client = KServeClient(
                context=self.kubernetes_context,
            )
        return self._client

    def _set_credentials(self) -> None:
        """Configure the KServe client with credentials from the ZenML secret.

        Does nothing when no secret is configured. If the secret contains a
        non-empty `credentials` value, that value is written to a file
        readable only by its owner and passed to the client as
        `credentials_file`. The file is deleted again afterwards. All other
        secret values are passed to the client as keyword arguments.

        Raises:
            RuntimeError: if the configured secret cannot be fetched or the
                credentials cannot be set on the KServe client.
        """
        secret = self._get_kserve_secret()
        if not secret:
            return

        secret_folder = Path(
            GlobalConfiguration().config_directory,
            "kserve-storage",
            str(self.uuid),
        )
        # Resolve the namespace before any credentials file is written, so a
        # failure here cannot leave a secret file behind on disk.
        namespace = (
            self.kubernetes_namespace or utils.get_default_target_namespace()
        )

        kserve_credentials: Dict[str, Any] = {}
        # Only set when a credentials file is actually written. Otherwise
        # the cleanup in `finally` would hit an unbound name.
        file_path: Optional[Path] = None
        for key in secret.content.keys():
            content = getattr(secret, key)
            if key == "credentials" and content:
                fileio.makedirs(str(secret_folder))
                file_path = Path(secret_folder, f"{key}.json")
                kserve_credentials["credentials_file"] = str(file_path)
                # Restrict permissions *before* writing the secret, so the
                # file is never created with default (umask) permissions.
                file_path.touch(mode=0o600)
                file_path.chmod(0o600)
                with open(file_path, "w") as f:
                    f.write(content)
            else:
                # Handle additional params
                kserve_credentials[key] = content

        # Set last so the namespace always wins over a `namespace` key that
        # might be present in the secret.
        kserve_credentials["namespace"] = namespace

        try:
            self.kserve_client.set_credentials(**kserve_credentials)
        except Exception as e:
            raise RuntimeError(
                f"Failed to set credentials for KServe model deployer: {e}"
            ) from e
        finally:
            if file_path is not None and file_path.exists():
                file_path.unlink()

    def deploy_model(
        self,
        config: ServiceConfig,
        replace: bool = False,
        timeout: int = DEFAULT_KSERVE_DEPLOYMENT_START_STOP_TIMEOUT,
    ) -> BaseService:
        """Create a new KServe deployment or update an existing one.

        This method has two modes of operation, depending on the `replace`
        argument value:

          * if `replace` is False, calling this method will create a new KServe
            deployment server to reflect the model and other configuration
            parameters specified in the supplied KServe deployment `config`.

          * if `replace` is True, this method will first attempt to find an
            existing KServe deployment that is *equivalent* to the supplied
            configuration parameters. Two or more KServe deployments are
            considered equivalent if they have the same `pipeline_name`,
            `pipeline_step_name` and `model_name` configuration parameters. To
            put it differently, two KServe deployments are equivalent if
            they serve versions of the same model deployed by the same pipeline
            step. If an equivalent KServe deployment is found, it will be
            updated in place to reflect the new configuration parameters. This
            allows an existing KServe deployment to retain its prediction
            URL while performing a rolling update to serve a new model version.

        Callers should set `replace` to True if they want a continuous model
        deployment workflow that doesn't spin up a new KServe deployment
        server for each new model version. If multiple equivalent KServe
        deployments are found, the first one returned by `find_model_server`
        is updated and the others are stopped without waiting for them to be
        deprovisioned.

        If `config.secret_name` is set, it replaces this deployer's `secret`
        (the change persists on this instance).

        Args:
            config: the configuration of the model to be deployed with KServe.
            replace: set this flag to True to find and update an equivalent
                KServeDeployment server with the new model instead of
                starting a new deployment server.
            timeout: the timeout in seconds to wait for the KServe server
                to be provisioned and successfully started or updated. If set
                to 0, the method will return immediately after the KServe
                server is provisioned, without waiting for it to fully start.

        Returns:
            The ZenML KServe deployment service object that can be used to
            interact with the remote KServe server.

        Raises:
            RuntimeError: if an older equivalent KServe deployment server could
                not be stopped, or if the credentials could not be set.
        """
        config = cast(KServeDeploymentConfig, config)
        service = None

        # if the secret is passed in the config, use it to set the credentials
        if config.secret_name:
            self.secret = config.secret_name
        self._set_credentials()

        # if replace is True, find equivalent KServe deployments
        if replace:
            equivalent_services = self.find_model_server(
                running=False,
                pipeline_name=config.pipeline_name,
                pipeline_step_name=config.pipeline_step_name,
                model_name=config.model_name,
            )

            for equivalent_service in equivalent_services:
                if service is None:
                    # keep the first equivalent service and update it below
                    service = equivalent_service
                    continue
                try:
                    # stop the other equivalent services (not the one being
                    # kept!) and don't wait for them to be deprovisioned
                    equivalent_service.stop(timeout=0)
                except RuntimeError as e:
                    raise RuntimeError(
                        "Failed to stop the KServe deployment server:\n"
                        f"{e}\n"
                        "Please stop it manually and try again."
                    ) from e
        if service:
            # update an equivalent service in place
            service.update(config)
            logger.info(
                f"Updating an existing KServe deployment service: {service}"
            )
        else:
            # create a new service
            service = KServeDeploymentService(config=config)
            logger.info(f"Creating a new KServe deployment service: {service}")

        # start the service which in turn provisions the KServe
        # deployment server and waits for it to reach a ready state
        service.start(timeout=timeout)
        return service

    def get_kserve_deployments(
        self, labels: Dict[str, str]
    ) -> List[V1beta1InferenceService]:
        """Get a list of KServe deployments that match the supplied labels.

        Args:
            labels: a dictionary of labels to match against KServe deployments.
                An empty dictionary matches all inference services in the
                namespace.

        Returns:
            A list of KServe deployments that match the supplied labels.

        Raises:
            RuntimeError: if an operational failure is encountered while
                listing the inference services through the Kubernetes API.
        """
        label_selector = (
            ",".join(f"{k}={v}" for k, v in labels.items()) if labels else None
        )

        namespace = (
            self.kubernetes_namespace or utils.get_default_target_namespace()
        )

        try:
            response = (
                self.kserve_client.api_instance.list_namespaced_custom_object(
                    constants.KSERVE_GROUP,
                    constants.KSERVE_V1BETA1_VERSION,
                    namespace,
                    constants.KSERVE_PLURAL,
                    label_selector=label_selector,
                )
            )
        except client.rest.ApiException as e:
            raise RuntimeError(
                "Exception when retrieving KServe inference services:\n"
                f"{e}"
            ) from e

        # TODO[CRITICAL]: de-serialize each item into a complete
        #   V1beta1InferenceService object recursively using the OpenApi
        #   schema (this doesn't work right now)
        inference_services: List[V1beta1InferenceService] = []
        for item in response.get("items", []):
            snake_case_item = self._camel_to_snake(item)
            inference_service = V1beta1InferenceService(**snake_case_item)
            inference_services.append(inference_service)
        return inference_services

    def _camel_to_snake(self, obj: Any) -> Any:
        """Recursively convert camelCase dictionary keys to snake_case.

        Dictionaries are rebuilt (with the same class) using converted keys
        and recursively converted values. Lists, sets and tuples are rebuilt
        with recursively converted elements. Any other value is returned
        unchanged.

        Args:
            obj: the object to convert, typically an item decoded from the
                Kubernetes API response.

        Returns:
            The converted object.
        """
        if isinstance(obj, dict):
            converted = obj.__class__()
            for key, value in obj.items():
                converted[self._convert_to_snake(key)] = self._camel_to_snake(
                    value
                )
            return converted
        if isinstance(obj, (list, set, tuple)):
            return obj.__class__(self._camel_to_snake(v) for v in obj)
        return obj

    def _convert_to_snake(self, k: str) -> str:
        """Convert a single camelCase key to snake_case.

        An underscore is inserted before every uppercase letter except at the
        start, and the result is lowercased. Consecutive capitals are
        therefore split individually (e.g. `podIP` -> `pod_i_p`).

        Args:
            k: the camelCase key.

        Returns:
            The snake_case key.
        """
        return re.sub(r"(?<!^)(?=[A-Z])", "_", k).lower()

    def find_model_server(
        self,
        running: bool = False,
        service_uuid: Optional[UUID] = None,
        pipeline_name: Optional[str] = None,
        pipeline_run_id: Optional[str] = None,
        pipeline_step_name: Optional[str] = None,
        model_name: Optional[str] = None,
        model_uri: Optional[str] = None,
        predictor: Optional[str] = None,
    ) -> List[BaseService]:
        """Find one or more KServe model services that match the given criteria.

        Args:
            running: If true, only running services will be returned.
            service_uuid: The UUID of the service that was originally used
                to deploy the model.
            pipeline_name: name of the pipeline that the deployed model was part
                of.
            pipeline_run_id: ID of the pipeline run which the deployed model was
                part of.
            pipeline_step_name: the name of the pipeline model deployment step
                that deployed the model.
            model_name: the name of the deployed model.
            model_uri: URI of the deployed model.
            predictor: the name of the predictor that was used to deploy the model.

        Returns:
            One or more Service objects representing model servers that match
            the input search criteria.
        """
        # Build a throwaway config only to derive the Kubernetes label
        # selector from the search criteria.
        config = KServeDeploymentConfig(
            pipeline_name=pipeline_name or "",
            pipeline_run_id=pipeline_run_id or "",
            pipeline_step_name=pipeline_step_name or "",
            model_uri=model_uri or "",
            model_name=model_name or "",
            predictor=predictor or "",
            resources={},
        )
        labels = config.get_kubernetes_labels()

        if service_uuid:
            labels["zenml.service_uuid"] = str(service_uuid)

        deployments = self.get_kserve_deployments(labels=labels)

        services: List[BaseService] = []
        for deployment in deployments:
            # recreate the KServe deployment service object from the KServe
            # deployment resource
            service = KServeDeploymentService.create_from_deployment(
                deployment=deployment
            )
            if running and not service.is_running:
                # skip non-running services
                continue
            services.append(service)

        return services

    def stop_model_server(
        self,
        uuid: UUID,
        timeout: int = DEFAULT_KSERVE_DEPLOYMENT_START_STOP_TIMEOUT,
        force: bool = False,
    ) -> None:
        """Stop a KServe model server.

        Args:
            uuid: UUID of the model server to stop.
            timeout: timeout in seconds to wait for the service to stop.
            force: if True, force the service to stop.

        Raises:
            NotImplementedError: stopping on KServe model servers is not
                supported.
        """
        raise NotImplementedError(
            "Stopping KServe model servers is not implemented. Try "
            "deleting the KServe model server instead."
        )

    def start_model_server(
        self,
        uuid: UUID,
        timeout: int = DEFAULT_KSERVE_DEPLOYMENT_START_STOP_TIMEOUT,
    ) -> None:
        """Start a KServe model deployment server.

        Args:
            uuid: UUID of the model server to start.
            timeout: timeout in seconds to wait for the service to become
                active. If set to 0, the method will return immediately after
                provisioning the service, without waiting for it to become
                active.

        Raises:
            NotImplementedError: since we don't support starting KServe
                model servers
        """
        raise NotImplementedError(
            "Starting KServe model servers is not implemented"
        )

    def delete_model_server(
        self,
        uuid: UUID,
        timeout: int = DEFAULT_KSERVE_DEPLOYMENT_START_STOP_TIMEOUT,
        force: bool = False,
    ) -> None:
        """Delete a KServe model deployment server.

        Does nothing if no server with the given UUID is found.

        Args:
            uuid: UUID of the model server to delete.
            timeout: timeout in seconds to wait for the service to stop. If
                set to 0, the method will return immediately after
                deprovisioning the service, without waiting for it to stop.
            force: if True, force the service to stop.
        """
        services = self.find_model_server(service_uuid=uuid)
        if len(services) == 0:
            return
        services[0].stop(timeout=timeout, force=force)

    def _get_kserve_secret(self) -> Any:
        """Get the secret object for the KServe deployment.

        Returns:
            The secret object for the KServe deployment, or None if no
            secret name is configured.

        Raises:
            RuntimeError: if the secret object is not found or secrets_manager is not set.
        """
        if not self.secret:
            return None

        secret_manager = Repository(  # type: ignore [call-arg]
            skip_repository_check=True
        ).active_stack.secrets_manager

        if not secret_manager or not isinstance(
            secret_manager, BaseSecretsManager
        ):
            raise RuntimeError(
                f"The active stack doesn't have a secret manager component. "
                f"The ZenML secret specified in the KServe Model "
                f"Deployer configuration cannot be fetched: {self.secret}."
            )
        try:
            return secret_manager.get_secret(self.secret)
        except KeyError as e:
            raise RuntimeError(
                f"The secret `{self.secret}` used for your KServe Model "
                f"Deployer configuration does not exist in your secrets "
                f"manager `{secret_manager.name}`."
            ) from e
kserve_client: KServeClient property readonly

Get the KServe client associated with this model deployer.

Returns:

Type Description
KServeClient

The KServe client.

delete_model_server(self, uuid, timeout=300, force=False)

Delete a KServe model deployment server.

Parameters:

Name Type Description Default
uuid UUID

UUID of the model server to delete.

required
timeout int

timeout in seconds to wait for the service to stop. If set to 0, the method will return immediately after deprovisioning the service, without waiting for it to stop.

300
force bool

if True, force the service to stop.

False
Source code in zenml/integrations/kserve/model_deployers/kserve_model_deployer.py
def delete_model_server(
    self,
    uuid: UUID,
    timeout: int = DEFAULT_KSERVE_DEPLOYMENT_START_STOP_TIMEOUT,
    force: bool = False,
) -> None:
    """Stop the KServe model server identified by `uuid`, if one exists.

    Only the first matching server is stopped. If no server matches, the
    call is a no-op.

    Args:
        uuid: UUID of the model server to delete.
        timeout: timeout in seconds to wait for the service to stop. If
            set to 0, the method will return immediately after
            deprovisioning the service, without waiting for it to stop.
        force: if True, force the service to stop.
    """
    matching_services = self.find_model_server(service_uuid=uuid)
    if matching_services:
        first_match = matching_services[0]
        first_match.stop(timeout=timeout, force=force)
deploy_model(self, config, replace=False, timeout=300)

Create a new KServe deployment or update an existing one.

This method has two modes of operation, depending on the replace argument value:

  • if replace is False, calling this method will create a new KServe deployment server to reflect the model and other configuration parameters specified in the supplied KServe deployment config.

  • if replace is True, this method will first attempt to find an existing KServe deployment that is equivalent to the supplied configuration parameters. Two or more KServe deployments are considered equivalent if they have the same pipeline_name, pipeline_step_name and model_name configuration parameters. To put it differently, two KServe deployments are equivalent if they serve versions of the same model deployed by the same pipeline step. If an equivalent KServe deployment is found, it will be updated in place to reflect the new configuration parameters. This allows an existing KServe deployment to retain its prediction URL while performing a rolling update to serve a new model version.

Callers should set replace to True if they want a continuous model deployment workflow that doesn't spin up a new KServe deployment server for each new model version. If multiple equivalent KServe deployments are found, the most recently created deployment is selected to be updated and the others are deleted.

Parameters:

Name Type Description Default
config ServiceConfig

the configuration of the model to be deployed with KServe.

required
replace bool

set this flag to True to find and update an equivalent KServeDeployment server with the new model instead of starting a new deployment server.

False
timeout int

the timeout in seconds to wait for the KServe server to be provisioned and successfully started or updated. If set to 0, the method will return immediately after the KServe server is provisioned, without waiting for it to fully start.

300

Returns:

Type Description
BaseService

The ZenML KServe deployment service object that can be used to interact with the remote KServe server.

Exceptions:

Type Description
RuntimeError

if the KServe deployment server could not be stopped.

Source code in zenml/integrations/kserve/model_deployers/kserve_model_deployer.py
def deploy_model(
    self,
    config: ServiceConfig,
    replace: bool = False,
    timeout: int = DEFAULT_KSERVE_DEPLOYMENT_START_STOP_TIMEOUT,
) -> BaseService:
    """Create a new KServe deployment or update an existing one.

    This method has two modes of operation, depending on the `replace`
    argument value:

      * if `replace` is False, calling this method will create a new KServe
        deployment server to reflect the model and other configuration
        parameters specified in the supplied KServe deployment `config`.

      * if `replace` is True, this method will first attempt to find an
        existing KServe deployment that is *equivalent* to the supplied
        configuration parameters. Two or more KServe deployments are
        considered equivalent if they have the same `pipeline_name`,
        `pipeline_step_name` and `model_name` configuration parameters. To
        put it differently, two KServe deployments are equivalent if
        they serve versions of the same model deployed by the same pipeline
        step. If an equivalent KServe deployment is found, it will be
        updated in place to reflect the new configuration parameters. This
        allows an existing KServe deployment to retain its prediction
        URL while performing a rolling update to serve a new model version.

    Callers should set `replace` to True if they want a continuous model
    deployment workflow that doesn't spin up a new KServe deployment
    server for each new model version. If multiple equivalent KServe
    deployments are found, the most recently created deployment is selected
    to be updated and the others are stopped.

    Args:
        config: the configuration of the model to be deployed with KServe.
        replace: set this flag to True to find and update an equivalent
            KServeDeployment server with the new model instead of
            starting a new deployment server.
        timeout: the timeout in seconds to wait for the KServe server
            to be provisioned and successfully started or updated. If set
            to 0, the method will return immediately after the KServe
            server is provisioned, without waiting for it to fully start.

    Returns:
        The ZenML KServe deployment service object that can be used to
        interact with the remote KServe server.

    Raises:
        RuntimeError: if one of the extra equivalent KServe deployment
            servers could not be stopped.
    """
    config = cast(KServeDeploymentConfig, config)
    service = None

    # if a secret is passed in the config, use it to set the credentials
    if config.secret_name:
        self.secret = config.secret_name
    self._set_credentials()

    # if replace is True, find equivalent KServe deployments
    if replace is True:
        equivalent_services = self.find_model_server(
            running=False,
            pipeline_name=config.pipeline_name,
            pipeline_step_name=config.pipeline_step_name,
            model_name=config.model_name,
        )

        for equivalent_service in equivalent_services:
            if service is None:
                # keep the first equivalent service returned by the lookup
                # (assumed to be the most recently created one -- TODO
                # confirm the ordering of find_model_server results)
                service = equivalent_service
                continue
            try:
                # stop the *other* equivalent services (not the one being
                # kept) and don't wait for them to be deprovisioned
                equivalent_service.stop()
            except RuntimeError as e:
                raise RuntimeError(
                    f"Failed to stop the KServe deployment server:\n{e}\n"
                    "Please stop it manually and try again."
                ) from e

    if service:
        # update an equivalent service in place
        service.update(config)
        logger.info(
            f"Updating an existing KServe deployment service: {service}"
        )
    else:
        # create a new service
        service = KServeDeploymentService(config=config)
        logger.info(f"Creating a new KServe deployment service: {service}")

    # start the service which in turn provisions the KServe
    # deployment server and waits for it to reach a ready state
    service.start(timeout=timeout)
    return service
find_model_server(self, running=False, service_uuid=None, pipeline_name=None, pipeline_run_id=None, pipeline_step_name=None, model_name=None, model_uri=None, predictor=None)

Find one or more KServe model services that match the given criteria.

Parameters:

Name Type Description Default
running bool

If true, only running services will be returned.

False
service_uuid Optional[uuid.UUID]

The UUID of the service that was originally used to deploy the model.

None
pipeline_name Optional[str]

name of the pipeline that the deployed model was part of.

None
pipeline_run_id Optional[str]

ID of the pipeline run which the deployed model was part of.

None
pipeline_step_name Optional[str]

the name of the pipeline model deployment step that deployed the model.

None
model_name Optional[str]

the name of the deployed model.

None
model_uri Optional[str]

URI of the deployed model.

None
predictor Optional[str]

the name of the predictor that was used to deploy the model.

None

Returns:

Type Description
List[zenml.services.service.BaseService]

One or more Service objects representing model servers that match the input search criteria.

Source code in zenml/integrations/kserve/model_deployers/kserve_model_deployer.py
def find_model_server(
    self,
    running: bool = False,
    service_uuid: Optional[UUID] = None,
    pipeline_name: Optional[str] = None,
    pipeline_run_id: Optional[str] = None,
    pipeline_step_name: Optional[str] = None,
    model_name: Optional[str] = None,
    model_uri: Optional[str] = None,
    predictor: Optional[str] = None,
) -> List[BaseService]:
    """Look up the KServe model services that match the given criteria.

    Criteria that are not supplied are left out of the label selector used
    for the lookup, so they do not restrict the results.

    Args:
        running: If true, only running services will be returned.
        service_uuid: The UUID of the service that was originally used
            to deploy the model.
        pipeline_name: name of the pipeline that the deployed model was part
            of.
        pipeline_run_id: ID of the pipeline run which the deployed model was
            part of.
        pipeline_step_name: the name of the pipeline model deployment step
            that deployed the model.
        model_name: the name of the deployed model.
        model_uri: URI of the deployed model.
        predictor: the name of the predictor that was used to deploy the model.

    Returns:
        One or more Service objects representing model servers that match
        the input search criteria.
    """
    # A throwaway config is built only to reuse its label generation; empty
    # values are not turned into labels.
    search_config = KServeDeploymentConfig(
        pipeline_name=pipeline_name or "",
        pipeline_run_id=pipeline_run_id or "",
        pipeline_step_name=pipeline_step_name or "",
        model_uri=model_uri or "",
        model_name=model_name or "",
        predictor=predictor or "",
        resources={},
    )
    selector = search_config.get_kubernetes_labels()
    if service_uuid:
        selector["zenml.service_uuid"] = str(service_uuid)

    # Rebuild a ZenML service object from every matching KServe resource
    # and, when requested, keep only those that are currently running.
    candidates = (
        KServeDeploymentService.create_from_deployment(deployment=deployment)
        for deployment in self.get_kserve_deployments(labels=selector)
    )
    matches: List[BaseService] = [
        candidate
        for candidate in candidates
        if not running or candidate.is_running
    ]
    return matches
get_active_model_deployer() staticmethod

Get the KServe model deployer registered in the active stack.

Returns:

Type Description
KServeModelDeployer

The KServe model deployer registered in the active stack.

Exceptions:

Type Description
TypeError

if the KServe model deployer is not available.

Source code in zenml/integrations/kserve/model_deployers/kserve_model_deployer.py
@staticmethod
def get_active_model_deployer() -> "KServeModelDeployer":
    """Get the KServe model deployer registered in the active stack.

    Returns:
        The KServe model deployer registered in the active stack.

    Raises:
        TypeError: if the active stack has no model deployer, or its model
            deployer is not a KServe model deployer.
    """
    model_deployer = Repository(  # type: ignore [call-arg]
        skip_repository_check=True
    ).active_stack.model_deployer
    if not model_deployer or not isinstance(
        model_deployer, KServeModelDeployer
    ):
        # the message ends each sentence explicitly; the pieces below are
        # implicitly concatenated string literals
        raise TypeError(
            "The active stack needs to have a KServe model deployer "
            "component registered to be able to deploy models with KServe. "
            "You can create a new stack with a KServe model "
            "deployer component or update your existing stack to add this "
            "component, e.g.:\n\n"
            f"  'zenml model-deployer register kserve --flavor={KSERVE_MODEL_DEPLOYER_FLAVOR} "
            "--kubernetes_context=context-name --kubernetes_namespace="
            "namespace-name --base_url=https://ingress.cluster.kubernetes'\n"
            "  'zenml stack create stack-name -d kserve ...'\n"
        )
    return model_deployer
get_kserve_deployments(self, labels)

Get a list of KServe deployments that match the supplied labels.

Parameters:

Name Type Description Default
labels Dict[str, str]

a dictionary of labels to match against KServe deployments.

required

Returns:

Type Description
List[kserve.models.v1beta1_inference_service.V1beta1InferenceService]

A list of KServe deployments that match the supplied labels.

Exceptions:

Type Description
RuntimeError

if an operational failure is encountered while retrieving the KServe inference services.

Source code in zenml/integrations/kserve/model_deployers/kserve_model_deployer.py
def get_kserve_deployments(
    self, labels: Dict[str, str]
) -> List[V1beta1InferenceService]:
    """Get a list of KServe deployments that match the supplied labels.

    Args:
        labels: a dictionary of labels to match against KServe deployments.
            When empty, no label selector is sent with the request.

    Returns:
        A list of KServe deployments that match the supplied labels.

    Raises:
        RuntimeError: if an operational failure is encountered while
            retrieving the KServe inference services.
    """
    # label selectors are sent as comma-separated `key=value` pairs
    label_selector = (
        ",".join(f"{k}={v}" for k, v in labels.items()) if labels else None
    )

    namespace = (
        self.kubernetes_namespace or utils.get_default_target_namespace()
    )

    try:
        response = (
            self.kserve_client.api_instance.list_namespaced_custom_object(
                constants.KSERVE_GROUP,
                constants.KSERVE_V1BETA1_VERSION,
                namespace,
                constants.KSERVE_PLURAL,
                label_selector=label_selector,
            )
        )
    except client.rest.ApiException as e:
        # keep the message on one logical string (no backslash continuation,
        # which would embed source indentation) and chain the API error
        raise RuntimeError(
            f"Exception when retrieving KServe inference services: {e}"
        ) from e

    # TODO[CRITICAL]: de-serialize each item into a complete
    #   V1beta1InferenceService object recursively using the OpenApi
    #   schema (this doesn't work right now)
    return [
        V1beta1InferenceService(**self._camel_to_snake(item))
        for item in response.get("items", [])
    ]
get_model_server_info(service_instance) staticmethod

Return implementation specific information on the model server.

Parameters:

Name Type Description Default
service_instance KServeDeploymentService

KServe deployment service object

required

Returns:

Type Description
Dict[str, Optional[str]]

A dictionary containing the model server information.

Source code in zenml/integrations/kserve/model_deployers/kserve_model_deployer.py
@staticmethod
def get_model_server_info(  # type: ignore[override]
    service_instance: "KServeDeploymentService",
) -> Dict[str, Optional[str]]:
    """Describe a KServe model server in implementation-specific terms.

    Args:
        service_instance: KServe deployment service object

    Returns:
        A dictionary with the prediction URL and hostname, the model URI
        and name, and the name of the KServe inference service resource.
    """
    url = service_instance.prediction_url
    hostname = service_instance.prediction_hostname
    config = service_instance.config
    return dict(
        PREDICTION_URL=url,
        PREDICTION_HOSTNAME=hostname,
        MODEL_URI=config.model_uri,
        MODEL_NAME=config.model_name,
        KSERVE_INFERENCE_SERVICE=service_instance.crd_name,
    )
start_model_server(self, uuid, timeout=300)

Start a KServe model deployment server.

Parameters:

Name Type Description Default
uuid UUID

UUID of the model server to start.

required
timeout int

timeout in seconds to wait for the service to become active. If set to 0, the method will return immediately after provisioning the service, without waiting for it to become active.

300

Exceptions:

Type Description
NotImplementedError

since we don't support starting KServe model servers

Source code in zenml/integrations/kserve/model_deployers/kserve_model_deployer.py
def start_model_server(
    self,
    uuid: UUID,
    timeout: int = DEFAULT_KSERVE_DEPLOYMENT_START_STOP_TIMEOUT,
) -> None:
    """Start a KServe model deployment server (not supported).

    Args:
        uuid: UUID of the model server to start.
        timeout: timeout in seconds to wait for the service to become
            active. If set to 0, the method will return immediately after
            provisioning the service, without waiting for it to become
            active.

    Raises:
        NotImplementedError: always, since starting KServe model servers
            is not supported.
    """
    unsupported = "Starting KServe model servers is not implemented"
    raise NotImplementedError(unsupported)
stop_model_server(self, uuid, timeout=300, force=False)

Stop a KServe model server.

Parameters:

Name Type Description Default
uuid UUID

UUID of the model server to stop.

required
timeout int

timeout in seconds to wait for the service to stop.

300
force bool

if True, force the service to stop.

False

Exceptions:

Type Description
NotImplementedError

stopping KServe model servers is not supported.

Source code in zenml/integrations/kserve/model_deployers/kserve_model_deployer.py
def stop_model_server(
    self,
    uuid: UUID,
    timeout: int = DEFAULT_KSERVE_DEPLOYMENT_START_STOP_TIMEOUT,
    force: bool = False,
) -> None:
    """Stop a KServe model server (not supported).

    Args:
        uuid: UUID of the model server to stop.
        timeout: timeout in seconds to wait for the service to stop.
        force: if True, force the service to stop.

    Raises:
        NotImplementedError: always, since stopping KServe model servers
            is not supported; delete the server instead.
    """
    hint = (
        "Stopping KServe model servers is not implemented. Try "
        "deleting the KServe model server instead."
    )
    raise NotImplementedError(hint)

secret_schemas special

Initialization of KServe Secret Schemas.

These are secret schemas that can be used to authenticate KServe to the Artifact Store used to store served ML models.

secret_schemas

Implementation for KServe secret schemas.

KServeAzureSecretSchema (BaseSecretSchema) pydantic-model

KServe Azure Blob Storage credentials.

Attributes:

Name Type Description
storage_type Literal['Azure']

the storage type. Must be set to "Azure" for this schema.

credentials Optional[str]

the credentials to use.

Source code in zenml/integrations/kserve/secret_schemas/secret_schemas.py
class KServeAzureSecretSchema(BaseSecretSchema):
    """KServe Azure Blob Storage credentials.

    Attributes:
        storage_type: the storage type. Must be set to "Azure" for this
            schema.
        credentials: the credentials to use.
    """

    TYPE: ClassVar[str] = KSERVE_AZUREBLOB_SECRET_SCHEMA_TYPE

    storage_type: Literal["Azure"] = "Azure"
    credentials: Optional[str]
KServeGSSecretSchema (BaseSecretSchema) pydantic-model

KServe GCS credentials.

Attributes:

Name Type Description
storage_type Literal['GCS']

the storage type. Must be set to "GCS" for this schema.

credentials Optional[str]

the credentials to use.

service_account Optional[str]

the service account.

Source code in zenml/integrations/kserve/secret_schemas/secret_schemas.py
class KServeGSSecretSchema(BaseSecretSchema):
    """KServe Google Cloud Storage (GCS) credentials.

    Attributes:
        storage_type: the storage type. Must be set to "GCS" for this schema.
        credentials: the credentials to use.
        service_account: the service account.
    """

    # schema type identifier for this secret schema
    TYPE: ClassVar[str] = KSERVE_GS_SECRET_SCHEMA_TYPE

    # annotated as Literal["GCS"], so "GCS" is the only accepted value
    storage_type: Literal["GCS"] = "GCS"
    credentials: Optional[str]
    service_account: Optional[str]
KServeS3SecretSchema (BaseSecretSchema) pydantic-model

KServe S3 credentials.

Attributes:

Name Type Description
storage_type Literal['S3']

the storage type. Must be set to "S3" for this schema.

credentials Optional[str]

the credentials to use.

service_account Optional[str]

the name of the service account.

s3_endpoint Optional[str]

the S3 endpoint.

s3_region Optional[str]

the S3 region.

s3_use_https Optional[str]

whether to use HTTPS.

s3_verify_ssl Optional[str]

whether to verify SSL.

Source code in zenml/integrations/kserve/secret_schemas/secret_schemas.py
class KServeS3SecretSchema(BaseSecretSchema):
    """KServe S3 credentials.

    Attributes:
        storage_type: the storage type. Must be set to "S3" for this schema.
        credentials: the credentials to use.
        service_account: the name of the service account.
        s3_endpoint: the S3 endpoint.
        s3_region: the S3 region.
        s3_use_https: whether to use HTTPS.
        s3_verify_ssl: whether to verify SSL.
    """

    TYPE: ClassVar[str] = KSERVE_S3_SECRET_SCHEMA_TYPE

    storage_type: Literal["S3"] = "S3"
    credentials: Optional[str]
    service_account: Optional[str]
    s3_endpoint: Optional[str]
    s3_region: Optional[str]
    s3_use_https: Optional[str]
    s3_verify_ssl: Optional[str]

services special

Initialization for KServe services.

kserve_deployment

Implementation for the KServe inference service.

KServeDeploymentConfig (ServiceConfig) pydantic-model

KServe deployment service configuration.

Attributes:

Name Type Description
model_uri str

URI of the model (or models) to serve.

model_name str

the name of the model. Multiple versions of the same model should use the same model name.

predictor str

the KServe predictor used to serve the model.

replicas int

number of replicas to use for the prediction service.

resources Optional[Dict[str, Any]]

the Kubernetes resources to allocate for the prediction service.

Source code in zenml/integrations/kserve/services/kserve_deployment.py
class KServeDeploymentConfig(ServiceConfig):
    """KServe deployment service configuration.

    Attributes:
        model_uri: URI of the model (or models) to serve.
        model_name: the name of the model. Multiple versions of the same model
            should use the same model name.
        secret_name: optional name of the secret the model deployer uses to
            set up the credentials needed to serve the model.
        predictor: the KServe predictor used to serve the model.
        replicas: number of replicas to use for the prediction service.
        resources: the Kubernetes resources to allocate for the prediction service.
    """

    model_uri: str = ""
    model_name: str
    secret_name: Optional[str]
    predictor: str
    replicas: int = 1
    resources: Optional[Dict[str, Any]]

    @staticmethod
    def sanitize_labels(labels: Dict[str, str]) -> None:
        """Update the label values to be valid Kubernetes labels.

        See: https://kubernetes.io/docs/concepts/overview/working-with-objects/labels/#syntax-and-character-set

        Each value has every run of characters outside ``[0-9a-zA-Z-_.]``
        replaced by a single underscore, is truncated to 63 characters and
        then has leading/trailing ``-``, ``_`` and ``.`` characters removed.

        Args:
            labels: The labels to sanitize (modified in place).
        """
        # TODO[MEDIUM]: Move k8s label sanitization to a common module with all K8s utils.
        for key, value in labels.items():
            # Kubernetes labels must be alphanumeric, no longer than
            # 63 characters, and must begin and end with an alphanumeric
            # character ([a-z0-9A-Z])
            labels[key] = re.sub(r"[^0-9a-zA-Z-_\.]+", "_", value)[:63].strip(
                "-_."
            )

    def get_kubernetes_labels(self) -> Dict[str, str]:
        """Generate the labels for the KServe inference CRD from the service configuration.

        These labels are attached to the KServe inference service CRD
        and may be used as label selectors in lookup operations. Empty
        configuration values are not turned into labels.

        Returns:
            The labels for the KServe inference service CRD.
        """
        labels = {"app": "zenml"}
        if self.pipeline_name:
            labels["zenml.pipeline_name"] = self.pipeline_name
        if self.pipeline_run_id:
            labels["zenml.pipeline_run_id"] = self.pipeline_run_id
        if self.pipeline_step_name:
            labels["zenml.pipeline_step_name"] = self.pipeline_step_name
        if self.model_name:
            labels["zenml.model_name"] = self.model_name
        if self.model_uri:
            labels["zenml.model_uri"] = self.model_uri
        if self.predictor:
            labels["zenml.model_type"] = self.predictor
        self.sanitize_labels(labels)
        return labels

    def get_kubernetes_annotations(self) -> Dict[str, str]:
        """Generate the annotations for the KServe inference CRD from the service configuration.

        The annotations are used to store additional information about the
        KServe ZenML service associated with the deployment that is
        not available on the labels. One particularly important annotation
        is the serialized Service configuration itself, which is used to
        recreate the service configuration from a remote KServe inference
        service CRD.

        Returns:
            The annotations for the KServe inference service CRD.
        """
        annotations = {
            "zenml.service_config": self.json(),
            "zenml.version": __version__,
        }
        return annotations

    @classmethod
    def create_from_deployment(
        cls, deployment: V1beta1InferenceService
    ) -> "KServeDeploymentConfig":
        """Recreate a KServe service from a KServe deployment resource.

        Args:
            deployment: the KServe inference service CRD.

        Returns:
            The KServe ZenML service configuration corresponding to the given
            KServe inference service CRD.

        Raises:
            ValueError: if the given deployment resource does not contain
                the expected annotations or it contains an invalid or
                incompatible KServe ZenML service configuration.
        """
        # `metadata` is accessed as a mapping (presumably a plain dict).
        # Either it or its `annotations` entry may be None; fall back to
        # empty mappings so that case raises the documented ValueError
        # instead of an AttributeError on None.
        metadata = deployment.metadata or {}
        annotations = metadata.get("annotations") or {}
        config_data = annotations.get("zenml.service_config")
        if not config_data:
            raise ValueError(
                f"The given deployment resource does not contain a "
                f"'zenml.service_config' annotation: {deployment}"
            )
        try:
            service_config = cls.parse_raw(config_data)
        except ValidationError as e:
            raise ValueError(
                f"The loaded KServe Inference Service resource contains an "
                f"invalid or incompatible KServe ZenML service configuration: "
                f"{config_data}"
            ) from e
        return service_config
create_from_deployment(deployment) classmethod

Recreate a KServe service from a KServe deployment resource.

Parameters:

Name Type Description Default
deployment V1beta1InferenceService

the KServe inference service CRD.

required

Returns:

Type Description
KServeDeploymentConfig

The KServe ZenML service configuration corresponding to the given KServe inference service CRD.

Exceptions:

Type Description
ValueError

if the given deployment resource does not contain the expected annotations or it contains an invalid or incompatible KServe ZenML service configuration.

Source code in zenml/integrations/kserve/services/kserve_deployment.py
@classmethod
def create_from_deployment(
    cls, deployment: V1beta1InferenceService
) -> "KServeDeploymentConfig":
    """Recreate a KServe service from a KServe deployment resource.

    Args:
        deployment: the KServe inference service CRD.

    Returns:
        The KServe ZenML service configuration corresponding to the given
        KServe inference service CRD.

    Raises:
        ValueError: if the given deployment resource does not contain
            the expected annotations or it contains an invalid or
            incompatible KServe ZenML service configuration.
    """
    # `metadata` is accessed as a mapping (presumably a plain dict). Either
    # it or its `annotations` entry may be None; fall back to empty mappings
    # so that case raises the documented ValueError instead of an
    # AttributeError on None.
    metadata = deployment.metadata or {}
    annotations = metadata.get("annotations") or {}
    config_data = annotations.get("zenml.service_config")
    if not config_data:
        raise ValueError(
            f"The given deployment resource does not contain a "
            f"'zenml.service_config' annotation: {deployment}"
        )
    try:
        service_config = cls.parse_raw(config_data)
    except ValidationError as e:
        raise ValueError(
            f"The loaded KServe Inference Service resource contains an "
            f"invalid or incompatible KServe ZenML service configuration: "
            f"{config_data}"
        ) from e
    return service_config
get_kubernetes_annotations(self)

Generate the annotations for the KServe inference CRD from the service configuration.

The annotations are used to store additional information about the KServe ZenML service associated with the deployment that is not available on the labels. One particularly important annotation is the serialized Service configuration itself, which is used to recreate the service configuration from a remote KServe inference service CRD.

Returns:

Type Description
Dict[str, str]

The annotations for the KServe inference service CRD.

Source code in zenml/integrations/kserve/services/kserve_deployment.py
def get_kubernetes_annotations(self) -> Dict[str, str]:
    """Build the annotations attached to the KServe inference service CRD.

    Annotations carry information about the associated ZenML KServe service
    that is not available on the labels. The most important one holds the
    serialized service configuration, which is what allows the
    configuration to be recreated from a remote KServe inference service
    CRD.

    Returns:
        The annotations for the KServe inference service CRD.
    """
    result: Dict[str, str] = {}
    result["zenml.service_config"] = self.json()
    result["zenml.version"] = __version__
    return result
get_kubernetes_labels(self)

Generate the labels for the KServe inference CRD from the service configuration.

These labels are attached to the KServe inference service CRD and may be used as label selectors in lookup operations.

Returns:

Type Description
Dict[str, str]

The labels for the KServe inference service CRD.

Source code in zenml/integrations/kserve/services/kserve_deployment.py
def get_kubernetes_labels(self) -> Dict[str, str]:
    """Build the labels attached to the KServe inference service CRD.

    The labels can later be used as label selectors when looking up
    deployments. Configuration values that are empty are skipped, and all
    values are passed through `sanitize_labels` before being returned.

    Returns:
        The labels for the KServe inference service CRD.
    """
    # (label key, configuration value) pairs, in emission order
    candidates = (
        ("zenml.pipeline_name", self.pipeline_name),
        ("zenml.pipeline_run_id", self.pipeline_run_id),
        ("zenml.pipeline_step_name", self.pipeline_step_name),
        ("zenml.model_name", self.model_name),
        ("zenml.model_uri", self.model_uri),
        ("zenml.model_type", self.predictor),
    )
    result = {"app": "zenml"}
    result.update((key, value) for key, value in candidates if value)
    self.sanitize_labels(result)
    return result
sanitize_labels(labels) staticmethod

Update the label values to be valid Kubernetes labels.

See: https://kubernetes.io/docs/concepts/overview/working-with-objects/labels/#syntax-and-character-set

Parameters:

Name Type Description Default
labels Dict[str, str]

The labels to sanitize.

required
Source code in zenml/integrations/kserve/services/kserve_deployment.py
@staticmethod
def sanitize_labels(labels: Dict[str, str]) -> None:
    """Rewrite the label values in place so they are valid Kubernetes labels.

    See: https://kubernetes.io/docs/concepts/overview/working-with-objects/labels/#syntax-and-character-set

    Every run of characters outside ``[0-9a-zA-Z-_.]`` becomes a single
    underscore; the result is cut to 63 characters and then stripped of
    leading/trailing ``-``, ``_`` and ``.`` characters.

    Args:
        labels: The labels to sanitize.
    """
    # TODO[MEDIUM]: Move k8s label sanitization to a common module with all K8s utils.
    invalid_run = re.compile(r"[^0-9a-zA-Z-_\.]+")
    for label_key in list(labels):
        cleaned = invalid_run.sub("_", labels[label_key])
        # Truncate first, then strip, so the stored value both fits the
        # 63-character limit and starts/ends with an alphanumeric character.
        labels[label_key] = cleaned[:63].strip("-_.")
KServeDeploymentService (BaseService) pydantic-model

A ZenML service that represents a KServe inference service CRD.

Attributes:

Name Type Description
config KServeDeploymentConfig

service configuration.

status ServiceStatus

service status.

Source code in zenml/integrations/kserve/services/kserve_deployment.py
class KServeDeploymentService(BaseService):
    """A ZenML service that represents a KServe inference service CRD.

    Attributes:
        config: service configuration.
        status: service status.
    """

    SERVICE_TYPE = ServiceType(
        name="kserve-deployment",
        type="model-serving",
        flavor="kserve",
        description="KServe inference service",
    )

    config: KServeDeploymentConfig = Field(
        default_factory=KServeDeploymentConfig
    )
    status: ServiceStatus = Field(default_factory=ServiceStatus)

    def _get_model_deployer(self) -> "KServeModelDeployer":
        """Get the active KServe model deployer.

        Returns:
            The active KServeModelDeployer.

        Raises:
            TypeError: if the current stack has no KServeModelDeployer.
        """
        from zenml.integrations.kserve.model_deployers.kserve_model_deployer import (
            KServeModelDeployer,
        )

        try:
            model_deployer = KServeModelDeployer.get_active_model_deployer()
        except TypeError as e:
            raise TypeError(
                "No active KServe model deployer is present in the active "
                "stack. Please make sure that a KServe model deployer is "
                "present in the active stack."
            ) from e
        return model_deployer

    def _get_client(self) -> KServeClient:
        """Get the KServe client from the active KServe model deployer.

        Returns:
            The KServe client.
        """
        return self._get_model_deployer().kserve_client

    def _get_namespace(self) -> Optional[str]:
        """Get the Kubernetes namespace from the active KServe model deployer.

        Returns:
            The Kubernetes namespace, or None, if the default namespace is
            used.
        """
        return self._get_model_deployer().kubernetes_namespace

    def check_status(self) -> Tuple[ServiceState, str]:
        """Check the state of the KServe inference service.

        Fetches the inference service resource and translates its
        `PredictorReady` condition into a `ServiceState` value and a
        printable message.

        Returns:
            The operational state of the external service and a message
            providing additional information about that state (e.g. the
            condition message reported when the predictor is not ready).
        """
        client = self._get_client()
        namespace = self._get_namespace()

        name = self.crd_name
        try:
            deployment = client.get(name=name, namespace=namespace)
        except RuntimeError:
            # A RuntimeError from the client is treated as "the inference
            # service does not exist".
            return (ServiceState.INACTIVE, "")

        # TODO[MEDIUM]: Implement better operational status checking that also
        #   cover errors
        if "status" not in deployment:
            return (ServiceState.INACTIVE, "No operational status available")
        for condition in deployment["status"].get("conditions") or []:
            if condition.get("type", "") != "PredictorReady":
                continue
            ready = condition.get("status", "Unknown").lower()
            if ready == "true":
                return (
                    ServiceState.ACTIVE,
                    f"Inference service '{name}' is available",
                )
            if ready == "false":
                return (
                    ServiceState.PENDING_STARTUP,
                    f"Inference service '{name}' is not available: "
                    f"{condition.get('message', 'Unknown')}",
                )
        return (
            ServiceState.PENDING_STARTUP,
            f"Inference service '{name}' still starting up",
        )

    @property
    def crd_name(self) -> str:
        """Get the name of the KServe inference service CRD that uniquely corresponds to this service instance.

        Returns:
            The name of the KServe inference service CRD.
        """
        return (
            self._get_kubernetes_labels().get("zenml.model_name")
            or f"zenml-{str(self.uuid)[:8]}"
        )

    def _get_kubernetes_labels(self) -> Dict[str, str]:
        """Generate the labels for the KServe inference service CRD from the service configuration.

        Returns:
            The labels for the KServe inference service.
        """
        labels = self.config.get_kubernetes_labels()
        labels["zenml.service_uuid"] = str(self.uuid)
        KServeDeploymentConfig.sanitize_labels(labels)
        return labels

    @classmethod
    def create_from_deployment(
        cls, deployment: V1beta1InferenceService
    ) -> "KServeDeploymentService":
        """Recreate the configuration of a KServe Service from a deployed instance.

        Args:
            deployment: the KServe deployment resource.

        Returns:
            The KServe service configuration corresponding to the given
            KServe deployment resource.

        Raises:
            ValueError: if the given deployment resource does not contain
                the expected annotations or labels, or it contains an invalid
                or incompatible KServe service configuration.
        """
        config = KServeDeploymentConfig.create_from_deployment(deployment)
        # A resource without any labels must yield the descriptive ValueError
        # below rather than an AttributeError on `None`.
        labels = deployment.metadata.get("labels") or {}
        uuid = labels.get("zenml.service_uuid")
        if not uuid:
            raise ValueError(
                f"The given deployment resource does not contain a valid "
                f"'zenml.service_uuid' label: {deployment}"
            )
        service = cls(uuid=UUID(uuid), config=config)
        service.update_status()
        return service

    def provision(self) -> None:
        """Provision or update remote KServe deployment instance.

        Creates the inference service if it does not exist yet, otherwise
        replaces it so that it matches the current configuration.
        """
        client = self._get_client()
        namespace = self._get_namespace()

        api_version = constants.KSERVE_GROUP + "/" + "v1beta1"
        name = self.crd_name

        # All supported model specs seem to have the same fields
        # so we can use any one of them (see https://kserve.github.io/website/0.8/reference/api/#serving.kserve.io/v1beta1.PredictorExtensionSpec)
        predictor_kwargs = {
            self.config.predictor: V1beta1PredictorExtensionSpec(
                storage_uri=self.config.model_uri,
                resources=self.config.resources,
            )
        }

        isvc = V1beta1InferenceService(
            api_version=api_version,
            kind=constants.KSERVE_KIND,
            metadata=k8s_client.V1ObjectMeta(
                name=name,
                namespace=namespace,
                labels=self._get_kubernetes_labels(),
                annotations=self.config.get_kubernetes_annotations(),
            ),
            spec=V1beta1InferenceServiceSpec(
                predictor=V1beta1PredictorSpec(**predictor_kwargs)
            ),
        )

        # TODO[HIGH]: better error handling when provisioning KServe instances
        try:
            client.get(name=name, namespace=namespace)
        except RuntimeError:
            # No existing inference service with this name: create it.
            client.create(isvc)
        else:
            # Update the existing deployment. Errors from `replace` propagate
            # instead of falling through to a spurious `create`.
            client.replace(name, isvc, namespace=namespace)

    def deprovision(self, force: bool = False) -> None:
        """Deprovisions all resources used by the service.

        Args:
            force: not read by this implementation; deletion is always
                attempted.

        Raises:
            ValueError: if the KServe client fails to delete the inference
                service.
        """
        client = self._get_client()
        namespace = self._get_namespace()
        name = self.crd_name

        # TODO[HIGH]: catch errors if deleting a KServe instance that is no
        #   longer available
        try:
            client.delete(name=name, namespace=namespace)
        except RuntimeError as e:
            raise ValueError(
                f"Could not delete KServe instance '{name}' from namespace: '{namespace}'."
            ) from e

    def _get_deployment_logs(
        self,
        name: str,
        follow: bool = False,
        tail: Optional[int] = None,
    ) -> Generator[str, bool, None]:
        """Get the logs of a KServe deployment resource.

        Args:
            name: the name of the KServe deployment to get logs for.
            follow: if True, the logs will be streamed as they are written
            tail: only retrieve the last NUM lines of log output.

        Returns:
            A generator that can be accessed to get the service logs.

        Raises:
            Exception: if no pod is found for the service or the Kubernetes
                API call fails while fetching the logs.

        Yields:
            The logs of the given deployment.
        """
        client = self._get_client()
        namespace = self._get_namespace()

        logger.debug(f"Retrieving logs for InferenceService resource: {name}")
        try:
            response = client.core_api.list_namespaced_pod(
                namespace=namespace,
                label_selector=f"zenml.service_uuid={self.uuid}",
            )
            logger.debug("Kubernetes API response: %s", response)
            pods = response.items
            if not pods:
                raise Exception(
                    f"The KServe deployment {name} is not currently "
                    f"running: no Kubernetes pods associated with it were found"
                )
            pod = pods[0]
            pod_name = pod.metadata.name

            # The Kubernetes client leaves these lists as None when empty.
            containers = [c.name for c in pod.spec.containers]
            init_containers = [c.name for c in (pod.spec.init_containers or [])]
            container_statuses = {
                c.name: c.started or c.restart_count
                for c in (pod.status.container_statuses or [])
            }

            container = "default"
            if container not in containers:
                container = containers[0]

            # While the main container has neither started nor restarted,
            # read from the first init container instead (if there is one).
            if not container_statuses.get(container) and init_containers:
                container = init_containers[0]

            logger.info(
                f"Retrieving logs for pod: `{pod_name}` and container "
                f"`{container}` in namespace `{namespace}`"
            )
            response = client.core_api.read_namespaced_pod_log(
                name=pod_name,
                namespace=namespace,
                container=container,
                follow=follow,
                tail_lines=tail,
                _preload_content=False,
            )
        except k8s_client.rest.ApiException as e:
            logger.error(
                "Exception when fetching logs for InferenceService resource "
                "%s: %s",
                name,
                str(e),
            )
            raise Exception(
                f"Unexpected exception when fetching logs for InferenceService "
                f"resource: {name}"
            ) from e

        try:
            while True:
                raw_line = response.readline()
                # Only an empty *raw* read means end-of-stream; a blank log
                # line still contains b"\n" and must not stop the stream.
                if not raw_line:
                    return
                stop = yield raw_line.decode("utf-8").rstrip("\n")
                if stop:
                    return
        finally:
            response.release_conn()

    def get_logs(
        self, follow: bool = False, tail: Optional[int] = None
    ) -> Generator[str, bool, None]:
        """Retrieve the logs from the remote KServe inference service instance.

        Args:
            follow: if True, the logs will be streamed as they are written.
            tail: only retrieve the last NUM lines of log output.

        Returns:
            A generator that can be accessed to get the service logs.
        """
        return self._get_deployment_logs(
            self.crd_name,
            follow=follow,
            tail=tail,
        )

    @property
    def prediction_url(self) -> Optional[str]:
        """The prediction URI exposed by the prediction service.

        Returns:
            The prediction URI exposed by the prediction service, or None if
            the service is not yet ready.
        """
        if not self.is_running:
            return None

        # Build the URL with POSIX separators: `os.path.join` would insert
        # backslashes on Windows, producing an invalid URL.
        import posixpath

        model_deployer = self._get_model_deployer()
        return posixpath.join(
            model_deployer.base_url,
            "v1/models",
            f"{self.crd_name}:predict",
        )

    @property
    def prediction_hostname(self) -> Optional[str]:
        """The prediction hostname exposed by the prediction service.

        Returns:
            The prediction hostname exposed by the prediction service status
            that will be used in the headers of the prediction request, or
            None if the service is not running.
        """
        if not self.is_running:
            return None

        namespace = self._get_namespace()

        model_deployer = self._get_model_deployer()
        custom_domain = model_deployer.custom_domain or "example.com"
        return f"{self.crd_name}.{namespace}.{custom_domain}"

    def predict(self, request: str) -> Any:
        """Make a prediction using the service.

        Args:
            request: a JSON-encoded string; its decoded value is sent as the
                `instances` field of the prediction request.

        Returns:
            The `predictions` field of the JSON response returned by the
            service.

        Raises:
            Exception: if the service is not running.
            ValueError: if the prediction URL or hostname is not set, or if
                `request` is not a string. HTTP error responses are raised by
                `response.raise_for_status()`.
        """
        if not self.is_running:
            raise Exception(
                "KServe prediction service is not running. "
                "Please start the service before making predictions."
            )

        # Read each property once: both re-evaluate `self.is_running`, so
        # repeated access repeats that check.
        prediction_url = self.prediction_url
        if prediction_url is None:
            raise ValueError("`self.prediction_url` is not set, cannot post.")
        prediction_hostname = self.prediction_hostname
        if prediction_hostname is None:
            raise ValueError(
                "`self.prediction_hostname` is not set, cannot post."
            )
        headers = {"Host": prediction_hostname}
        if not isinstance(request, str):
            raise ValueError("Request must be a json string.")
        instances = json.loads(request)
        response = requests.post(
            prediction_url,
            headers=headers,
            json={"instances": instances},
        )
        response.raise_for_status()
        return response.json()["predictions"]
crd_name: str property readonly

Get the name of the KServe inference service CRD that uniquely corresponds to this service instance.

Returns:

Type Description
str

The name of the KServe inference service CRD.

prediction_hostname: Optional[str] property readonly

The prediction hostname exposed by the prediction service.

Returns:

Type Description
Optional[str]

The prediction hostname exposed by the prediction service status that will be used in the headers of the prediction request.

prediction_url: Optional[str] property readonly

The prediction URI exposed by the prediction service.

Returns:

Type Description
Optional[str]

The prediction URI exposed by the prediction service, or None if the service is not yet ready.

check_status(self)

Check the state of the KServe inference service.

This method checks the current operational state of the external KServe inference service and translates it into a ServiceState value and a printable message.

This method should be overridden by subclasses that implement concrete service tracking functionality.

Returns:

Type Description
Tuple[zenml.services.service_status.ServiceState, str]

The operational state of the external service and a message providing additional information about that state (e.g. a description of the error if one is encountered while checking the service status).

Source code in zenml/integrations/kserve/services/kserve_deployment.py
def check_status(self) -> Tuple[ServiceState, str]:
    """Check the state of the KServe inference service.

    Fetches the inference service resource and translates its
    `PredictorReady` condition into a `ServiceState` value and a printable
    message.

    Returns:
        The operational state of the external service and a message
        providing additional information about that state (e.g. the
        condition message reported when the predictor is not ready).
    """
    client = self._get_client()
    namespace = self._get_namespace()

    name = self.crd_name
    try:
        deployment = client.get(name=name, namespace=namespace)
    except RuntimeError:
        # A RuntimeError from the client is treated as "the inference
        # service does not exist".
        return (ServiceState.INACTIVE, "")

    # TODO[MEDIUM]: Implement better operational status checking that also
    #   cover errors
    if "status" not in deployment:
        return (ServiceState.INACTIVE, "No operational status available")
    for condition in deployment["status"].get("conditions") or []:
        if condition.get("type", "") != "PredictorReady":
            continue
        ready = condition.get("status", "Unknown").lower()
        if ready == "true":
            return (
                ServiceState.ACTIVE,
                f"Inference service '{name}' is available",
            )
        if ready == "false":
            return (
                ServiceState.PENDING_STARTUP,
                f"Inference service '{name}' is not available: "
                f"{condition.get('message', 'Unknown')}",
            )
    return (
        ServiceState.PENDING_STARTUP,
        f"Inference service '{name}' still starting up",
    )
create_from_deployment(deployment) classmethod

Recreate the configuration of a KServe Service from a deployed instance.

Parameters:

Name Type Description Default
deployment V1beta1InferenceService

the KServe deployment resource.

required

Returns:

Type Description
KServeDeploymentService

The KServe service configuration corresponding to the given KServe deployment resource.

Exceptions:

Type Description
ValueError

if the given deployment resource does not contain the expected annotations or it contains an invalid or incompatible KServe service configuration.

Source code in zenml/integrations/kserve/services/kserve_deployment.py
@classmethod
def create_from_deployment(
    cls, deployment: V1beta1InferenceService
) -> "KServeDeploymentService":
    """Recreate the configuration of a KServe Service from a deployed instance.

    Args:
        deployment: the KServe deployment resource.

    Returns:
        The KServe service configuration corresponding to the given
        KServe deployment resource.

    Raises:
        ValueError: if the given deployment resource does not contain
            the expected annotations or labels, or it contains an invalid or
            incompatible KServe service configuration.
    """
    config = KServeDeploymentConfig.create_from_deployment(deployment)
    # A resource without any labels must yield the descriptive ValueError
    # below rather than an AttributeError on `None`.
    labels = deployment.metadata.get("labels") or {}
    uuid = labels.get("zenml.service_uuid")
    if not uuid:
        raise ValueError(
            f"The given deployment resource does not contain a valid "
            f"'zenml.service_uuid' label: {deployment}"
        )
    service = cls(uuid=UUID(uuid), config=config)
    service.update_status()
    return service
deprovision(self, force=False)

Deprovisions all resources used by the service.

Parameters:

Name Type Description Default
force bool

if True, the service will be deprovisioned even if it is still in use.

False

Exceptions:

Type Description
ValueError

if the KServe client fails to delete the inference service (the force flag is not read by this implementation).

Source code in zenml/integrations/kserve/services/kserve_deployment.py
def deprovision(self, force: bool = False) -> None:
    """Deprovisions all resources used by the service.

    Args:
        force: not read by this implementation; deletion is always
            attempted.

    Raises:
        ValueError: if the KServe client fails to delete the inference
            service.
    """
    client = self._get_client()
    namespace = self._get_namespace()
    name = self.crd_name

    # TODO[HIGH]: catch errors if deleting a KServe instance that is no
    #   longer available
    try:
        client.delete(name=name, namespace=namespace)
    except RuntimeError as e:
        # Keep the client error as the cause so its details are not lost.
        raise ValueError(
            f"Could not delete KServe instance '{name}' from namespace: '{namespace}'."
        ) from e
get_logs(self, follow=False, tail=None)

Retrieve the logs from the remote KServe inference service instance.

Parameters:

Name Type Description Default
follow bool

if True, the logs will be streamed as they are written.

False
tail Optional[int]

only retrieve the last NUM lines of log output.

None

Returns:

Type Description
Generator[str, bool, NoneType]

A generator that can be accessed to get the service logs.

Source code in zenml/integrations/kserve/services/kserve_deployment.py
def get_logs(
    self, follow: bool = False, tail: Optional[int] = None
) -> Generator[str, bool, None]:
    """Stream log lines from the remote KServe inference service instance.

    Delegates to `_get_deployment_logs` for the inference service named by
    `crd_name`.

    Args:
        follow: when True, keep streaming new log lines as they are written.
        tail: if given, only the last `tail` lines of log output are returned.

    Returns:
        A generator that yields the service log lines.
    """
    target = self.crd_name
    log_stream = self._get_deployment_logs(target, follow=follow, tail=tail)
    return log_stream
predict(self, request)

Make a prediction using the service.

Parameters:

Name Type Description Default
request str

a JSON-encoded string; its decoded value is sent as the instances field of the prediction request

required

Returns:

Type Description
Any

The predictions field of the JSON response returned by the service.

Exceptions:

Type Description
Exception

if the service is not yet ready.

ValueError

if the prediction_url is not set.

Source code in zenml/integrations/kserve/services/kserve_deployment.py
def predict(self, request: str) -> Any:
    """Make a prediction using the service.

    Args:
        request: a JSON-encoded string; its decoded value is sent as the
            `instances` field of the prediction request.

    Returns:
        The `predictions` field of the JSON response returned by the service.

    Raises:
        Exception: if the service is not running.
        ValueError: if the prediction URL or hostname is not set, or if
            `request` is not a string. HTTP error responses are raised by
            `response.raise_for_status()`.
    """
    if not self.is_running:
        raise Exception(
            "KServe prediction service is not running. "
            "Please start the service before making predictions."
        )

    # Read each property once: both re-evaluate `self.is_running`, so
    # repeated access repeats that check.
    prediction_url = self.prediction_url
    if prediction_url is None:
        raise ValueError("`self.prediction_url` is not set, cannot post.")
    prediction_hostname = self.prediction_hostname
    if prediction_hostname is None:
        raise ValueError(
            "`self.prediction_hostname` is not set, cannot post."
        )
    headers = {"Host": prediction_hostname}
    if not isinstance(request, str):
        raise ValueError("Request must be a json string.")
    instances = json.loads(request)
    response = requests.post(
        prediction_url,
        headers=headers,
        json={"instances": instances},
    )
    response.raise_for_status()
    return response.json()["predictions"]
provision(self)

Provision or update remote KServe deployment instance.

This should then match the current configuration.

Source code in zenml/integrations/kserve/services/kserve_deployment.py
def provision(self) -> None:
    """Provision or update remote KServe deployment instance.

    Creates the inference service if it does not exist yet, otherwise
    replaces it so that it matches the current configuration.
    """
    client = self._get_client()
    namespace = self._get_namespace()

    api_version = constants.KSERVE_GROUP + "/" + "v1beta1"
    name = self.crd_name

    # All supported model specs seem to have the same fields
    # so we can use any one of them (see https://kserve.github.io/website/0.8/reference/api/#serving.kserve.io/v1beta1.PredictorExtensionSpec)
    predictor_kwargs = {
        self.config.predictor: V1beta1PredictorExtensionSpec(
            storage_uri=self.config.model_uri,
            resources=self.config.resources,
        )
    }

    isvc = V1beta1InferenceService(
        api_version=api_version,
        kind=constants.KSERVE_KIND,
        metadata=k8s_client.V1ObjectMeta(
            name=name,
            namespace=namespace,
            labels=self._get_kubernetes_labels(),
            annotations=self.config.get_kubernetes_annotations(),
        ),
        spec=V1beta1InferenceServiceSpec(
            predictor=V1beta1PredictorSpec(**predictor_kwargs)
        ),
    )

    # TODO[HIGH]: better error handling when provisioning KServe instances
    try:
        client.get(name=name, namespace=namespace)
    except RuntimeError:
        # No existing inference service with this name: create it.
        client.create(isvc)
    else:
        # Update the existing deployment. Errors from `replace` propagate
        # instead of falling through to a spurious `create`.
        client.replace(name, isvc, namespace=namespace)

steps special

Initialization for KServe steps.

kserve_deployer

Implementation of the KServe Deployer step.

KServeDeployerStepConfig (BaseStepConfig) pydantic-model

KServe model deployer step configuration.

Attributes:

Name Type Description
service_config KServeDeploymentConfig

KServe deployment service configuration.

torch_serve_parameters

TorchServe set of parameters to deploy model.

timeout int

Timeout for model deployment.

Source code in zenml/integrations/kserve/steps/kserve_deployer.py
class KServeDeployerStepConfig(BaseStepConfig):
    """KServe model deployer step configuration.

    Attributes:
        service_config: KServe deployment service configuration.
        torch_serve_parameters: TorchServe set of parameters to deploy model.
        timeout: Timeout for model deployment.
    """

    service_config: KServeDeploymentConfig
    torch_serve_parameters: Optional[TorchServeParameters] = None
    timeout: int = DEFAULT_KSERVE_DEPLOYMENT_START_STOP_TIMEOUT
TorchServeParameters (BaseModel) pydantic-model

KServe PyTorch model deployer step configuration.

Attributes:

Name Type Description
service_config

KServe deployment service configuration.

model_class str

Path to Python file containing model architecture.

handler str

TorchServe's handler file to handle custom TorchServe inference logic.

extra_files Optional[List[str]]

Paths to extra dependency files.

model_version Optional[str]

Model version.

requirements_file Optional[str]

Path to requirements file.

torch_config Optional[str]

TorchServe configuration file path.

Source code in zenml/integrations/kserve/steps/kserve_deployer.py
class TorchServeParameters(BaseModel):
    """KServe PyTorch model deployer step configuration.

    Attributes:
        model_class: Path to Python file containing model architecture.
        handler: TorchServe's handler file to handle custom TorchServe inference logic.
        extra_files: Paths to extra dependency files.
        model_version: Model version.
        requirements_file: Path to requirements file.
        torch_config: TorchServe configuration file path.
    """

    model_class: str
    handler: str
    extra_files: Optional[List[str]] = None
    requirements_file: Optional[str] = None
    model_version: Optional[str] = "1.0"
    torch_config: Optional[str] = None

    @validator("model_class")
    def model_class_validate(cls, v: str) -> str:
        """Validate model class file path.

        Args:
            v: model class file path

        Returns:
            model class file path

        Raises:
            ValueError: if model class file path is empty or not inside the
                repository.
        """
        if not v:
            raise ValueError("Model class file path is required.")
        if not is_inside_repository(v):
            raise ValueError(
                "Model class file path must be inside the repository."
            )
        return v

    @validator("handler")
    def handler_validate(cls, v: str) -> str:
        """Validate handler.

        Args:
            v: handler name or handler file path

        Returns:
            handler name or handler file path

        Raises:
            ValueError: if the handler is empty, or is neither one of the
                built-in TorchServe handlers nor a file inside the repository.
        """
        if not v:
            raise ValueError("Handler is required.")
        if v in TORCH_HANDLERS or is_inside_repository(v):
            return v
        # A single message string: passing two positional arguments to
        # ValueError would render the message as a tuple.
        raise ValueError(
            "Handler must be one of the TorchServe handlers "
            "or a file that exists inside the repository."
        )

    @validator("extra_files")
    def extra_files_validate(
        cls, v: Optional[List[str]]
    ) -> Optional[List[str]]:
        """Validate extra files.

        Args:
            v: extra files paths

        Returns:
            a new list with the same paths, or None if no extra files were
            given

        Raises:
            ValueError: if any extra file path is not inside the repository
        """
        if v is None:
            return v
        for file_path in v:
            if not is_inside_repository(file_path):
                raise ValueError(
                    "Extra file path must be inside the repository."
                )
        return list(v)

    @validator("torch_config")
    def torch_config_validate(cls, v: Optional[str]) -> Optional[str]:
        """Validate torch config file.

        Args:
            v: torch config file path

        Returns:
            torch config file path

        Raises:
            ValueError: if torch config file path is not valid.
        """
        if v:
            if is_inside_repository(v):
                return v
            else:
                raise ValueError(
                    "Torch config file path must be inside the repository."
                )
        return v
extra_files_validate(v) classmethod

Validate extra files.

Parameters:

Name Type Description Default
v Optional[List[str]]

extra files path

required

Returns:

Type Description
Optional[List[str]]

extra files path

Exceptions:

Type Description
ValueError

if the extra files path is not valid

Source code in zenml/integrations/kserve/steps/kserve_deployer.py
@validator("extra_files")
def extra_files_validate(
    cls, v: Optional[List[str]]
) -> Optional[List[str]]:
    """Check that every extra dependency file lives inside the repository.

    Args:
        v: extra files paths, or None

    Returns:
        None when no extra files were given, otherwise a new list containing
        the same paths in the same order

    Raises:
        ValueError: on the first path that is not inside the repository
    """
    if v is None:
        return v
    validated = []
    for candidate in v:
        if not is_inside_repository(candidate):
            raise ValueError(
                "Extra file path must be inside the repository."
            )
        validated.append(candidate)
    return validated
handler_validate(v) classmethod

Validate handler.

Parameters:

Name Type Description Default
v str

handler file path

required

Returns:

Type Description
str

handler file path

Exceptions:

Type Description
ValueError

if handler file path is not valid

Source code in zenml/integrations/kserve/steps/kserve_deployer.py
@validator("handler")
def handler_validate(cls, v: str) -> str:
    """Validate handler.

    Args:
        v: handler name or handler file path

    Returns:
        handler name or handler file path

    Raises:
        ValueError: if the handler is empty, or is neither one of the built-in
            TorchServe handlers nor a file inside the repository.
    """
    if not v:
        raise ValueError("Handler is required.")
    if v in TORCH_HANDLERS or is_inside_repository(v):
        return v
    # A single message string: passing two positional arguments to
    # ValueError would render the message as a tuple.
    raise ValueError(
        "Handler must be one of the TorchServe handlers "
        "or a file that exists inside the repository."
    )
model_class_validate(v) classmethod

Validate model class file path.

Parameters:

Name Type Description Default
v str

model class file path

required

Returns:

Type Description
str

model class file path

Exceptions:

Type Description
ValueError

if model class file path is not valid

Source code in zenml/integrations/kserve/steps/kserve_deployer.py
@validator("model_class")
def model_class_validate(cls, v: str) -> str:
    """Ensure the model class file path is set and lies inside the repository.

    Args:
        v: model class file path

    Returns:
        the unchanged model class file path

    Raises:
        ValueError: if the path is empty or points outside the repository
    """
    is_missing = not v
    if is_missing:
        raise ValueError("Model class file path is required.")
    if is_inside_repository(v):
        return v
    raise ValueError(
        "Model class file path must be inside the repository."
    )
torch_config_validate(v) classmethod

Validate torch config file.

Parameters:

Name Type Description Default
v Optional[str]

torch config file path

required

Returns:

Type Description
Optional[str]

torch config file path

Exceptions:

Type Description
ValueError

if torch config file path is not valid.

Source code in zenml/integrations/kserve/steps/kserve_deployer.py
@validator("torch_config")
def torch_config_validate(cls, v: Optional[str]) -> Optional[str]:
    """Ensure an optional torch config path, when given, is in the repository.

    Args:
        v: torch config file path

    Returns:
        torch config file path (unchanged, possibly empty or ``None``)

    Raises:
        ValueError: if torch config file path is not valid.
    """
    # An empty/missing value is allowed; only real paths are checked.
    if not v or is_inside_repository(v):
        return v
    raise ValueError(
        "Torch config file path must be inside the repository."
    )
kserve_model_deployer_step (BaseStep)

KServe model deployer pipeline step.

This step can be used in a pipeline to implement continuous deployment for an ML model with KServe.

Parameters:

Name Type Description Default
deploy_decision

whether to deploy the model or not

required
config

configuration for the deployer step

required
model

the model artifact to deploy

required
context

the step context

required

Returns:

Type Description

KServe deployment service

CONFIG_CLASS (BaseStepConfig) pydantic-model

KServe model deployer step configuration.

Attributes:

Name Type Description
service_config KServeDeploymentConfig

KServe deployment service configuration.

torch_serve_parameters

TorchServe set of parameters to deploy model.

timeout int

Timeout for model deployment.

Source code in zenml/integrations/kserve/steps/kserve_deployer.py
class KServeDeployerStepConfig(BaseStepConfig):
    """KServe model deployer step configuration.

    Attributes:
        service_config: KServe deployment service configuration.
        torch_serve_parameters: TorchServe set of parameters to deploy model.
            Only read when the service config's predictor is ``pytorch``.
        timeout: Timeout for model deployment.
    """

    service_config: KServeDeploymentConfig
    torch_serve_parameters: Optional[TorchServeParameters] = None
    timeout: int = DEFAULT_KSERVE_DEPLOYMENT_START_STOP_TIMEOUT
entrypoint(deploy_decision, config, context, model) staticmethod

KServe model deployer pipeline step.

This step can be used in a pipeline to implement continuous deployment for an ML model with KServe.

Parameters:

Name Type Description Default
deploy_decision bool

whether to deploy the model or not

required
config KServeDeployerStepConfig

configuration for the deployer step

required
model ModelArtifact

the model artifact to deploy

required
context StepContext

the step context

required

Returns:

Type Description
KServeDeploymentService

KServe deployment service

Source code in zenml/integrations/kserve/steps/kserve_deployer.py
@step(enable_cache=False)
def kserve_model_deployer_step(
    deploy_decision: bool,
    config: KServeDeployerStepConfig,
    context: StepContext,
    model: ModelArtifact,
) -> KServeDeploymentService:
    """Deploy (or keep serving) a model with KServe as a pipeline step.

    If the deploy decision is negative and a model server already exists for
    this pipeline/step/model, that server is reused (and restarted if it is
    not running). Otherwise the model files are prepared and a new service is
    deployed, replacing any previous one.

    Args:
        deploy_decision: whether to deploy the model or not
        config: configuration for the deployer step
        model: the model artifact to deploy
        context: the step context

    Returns:
        KServe deployment service
    """
    model_deployer = KServeModelDeployer.get_active_model_deployer()
    svc_cfg = config.service_config

    # Stamp the service config with the real runtime identity of this step.
    step_env = cast(StepEnvironment, Environment()[STEP_ENVIRONMENT_NAME])
    pipeline_name = step_env.pipeline_name
    pipeline_run_id = step_env.pipeline_run_id
    step_name = step_env.step_name
    svc_cfg.pipeline_name = pipeline_name
    svc_cfg.pipeline_run_id = pipeline_run_id
    svc_cfg.pipeline_step_name = step_name

    # Services previously deployed by the same pipeline/step for this model.
    existing_services = model_deployer.find_model_server(
        pipeline_name=pipeline_name,
        pipeline_step_name=step_name,
        model_name=svc_cfg.model_name,
    )

    if not deploy_decision and existing_services:
        # Keep the previous server alive so a model is always being served.
        logger.info(
            f"Skipping model deployment because the model quality does not "
            f"meet the criteria. Reusing the last model server deployed by step "
            f"'{step_name}' and pipeline '{pipeline_name}' for model "
            f"'{svc_cfg.model_name}'..."
        )
        previous_service = cast(KServeDeploymentService, existing_services[0])
        if not previous_service.is_running:
            previous_service.start(timeout=config.timeout)
        return previous_service

    # PyTorch models need TorchServe archiving; everything else uses the
    # generic preparation routine.
    if svc_cfg.predictor == "pytorch":
        from zenml.integrations.kserve.steps.kserve_step_utils import (
            prepare_torch_service_config as prepare_config,
        )
    else:
        from zenml.integrations.kserve.steps.kserve_step_utils import (
            prepare_service_config as prepare_config,
        )

    service_config = prepare_config(
        model_uri=model.uri,
        output_artifact_uri=context.get_output_artifact_uri(),
        config=config,
    )
    deployed = model_deployer.deploy_model(
        service_config, replace=True, timeout=config.timeout
    )
    service = cast(KServeDeploymentService, deployed)

    logger.info(
        f"KServe deployment service started and reachable at:\n"
        f"    {service.prediction_url}\n"
        f"    With the hostname: {service.prediction_hostname}."
    )
    return service
kserve_step_utils

This module contains the utility functions used by the KServe deployer step.

TorchModelArchiver (BaseModel) pydantic-model

Model Archiver for PyTorch models.

Attributes:

Name Type Description
model_name str

Model name.

version Optional[str]

Model version.

serialized_file str

Serialized model file.

handler str

TorchServe's handler file to handle custom TorchServe inference logic.

extra_files Optional[List[str]]

Comma separated path to extra dependency files.

requirements_file Optional[str]

Path to requirements file.

export_path str

Path to export model.

runtime Optional[str]

Runtime of the model.

force Optional[bool]

Force export of the model.

archive_format Optional[str]

Archive format.

Source code in zenml/integrations/kserve/steps/kserve_step_utils.py
class TorchModelArchiver(BaseModel):
    """Model Archiver for PyTorch models.

    Instances are passed as the argument object to the TorchServe model
    archiver (``ModelExportUtils.generate_manifest_json`` / ``package_model``),
    so the field names must match the attributes that tool reads.

    Attributes:
        model_name: Model name.
        serialized_file: Serialized model file.
        model_file: Path to the file defining the model class.
        handler: TorchServe's handler file to handle custom TorchServe inference logic.
        export_path: Path to export model.
        extra_files: Extra dependency files to include in the archive.
        version: Model version.
        requirements_file: Path to requirements file.
        runtime: Runtime of the model.
        force: Force export of the model.
        archive_format: Archive format.
    """

    model_name: str
    serialized_file: str
    model_file: str
    handler: str
    export_path: str
    extra_files: Optional[List[str]] = None
    version: Optional[str] = None
    requirements_file: Optional[str] = None
    runtime: Optional[str] = "python"
    force: Optional[bool] = None
    archive_format: Optional[str] = "default"
generate_model_deployer_config(model_name, directory)

Generate a model deployer config.

Parameters:

Name Type Description Default
model_name str

the name of the model

required
directory str

the directory where the model is stored

required

Returns:

Type Description
str

Path to the generated config file.

Source code in zenml/integrations/kserve/steps/kserve_step_utils.py
def generate_model_deployer_config(
    model_name: str,
    directory: str,
) -> str:
    """Generate a TorchServe ``config.properties`` file for a model.

    The file is created with a unique name and a ``.properties`` suffix inside
    ``directory``. It is not deleted automatically; the caller owns it.

    Args:
        model_name: the name of the model; used as the model key in the
            startup snapshot and to derive the ``<model_name>.mar`` name.
        directory: the directory in which the config file is created.

    Returns:
        The path of the generated config file.
    """
    import json

    config_lines = [
        "inference_address=http://0.0.0.0:8085",
        "management_address=http://0.0.0.0:8085",
        "metrics_address=http://0.0.0.0:8082",
        "grpc_inference_port=7070",
        "grpc_management_port=7071",
        "enable_metrics_api=true",
        "metrics_format=prometheus",
        "number_of_netty_threads=4",
        "job_queue_size=10",
        "enable_envvars_config=true",
        "install_py_dep_per_model=true",
        "model_store=/mnt/models/model-store",
    ]

    # Startup snapshot registering the archived model as version "1.0".
    model_snapshot = {
        "name": "startup.cfg",
        "modelCount": 1,
        "models": {
            model_name: {
                "1.0": {
                    "defaultVersion": True,
                    "marName": f"{model_name}.mar",
                    "minWorkers": 1,
                    "maxWorkers": 5,
                    "batchSize": 1,
                    "maxBatchDelay": 10,
                    "responseTimeout": 120,
                }
            }
        },
    }
    # Compact separators keep the value on one line with no extra spaces.
    config_lines.append(
        "model_snapshot="
        + json.dumps(model_snapshot, separators=(",", ":"), ensure_ascii=False)
    )

    with tempfile.NamedTemporaryFile(
        suffix=".properties", mode="w+", dir=directory, delete=False
    ) as f:
        # Newline after every line except the final snapshot line.
        f.write("\n".join(config_lines))
    return f.name
prepare_service_config(model_uri, output_artifact_uri, config)

Prepare the model files for model serving.

This function ensures that the model files are in the correct format and file structure required by the KServe server implementation used for model serving.

Parameters:

Name Type Description Default
model_uri str

the URI of the model artifact being served

required
output_artifact_uri str

the URI of the output artifact

required
config KServeDeployerStepConfig

the KServe deployer step config

required

Returns:

Type Description
KServeDeploymentConfig

The deployment config whose model URI points to the files prepared for serving.

Exceptions:

Type Description
RuntimeError

if the model files cannot be prepared.

Source code in zenml/integrations/kserve/steps/kserve_step_utils.py
def prepare_service_config(
    model_uri: str, output_artifact_uri: str, config: KServeDeployerStepConfig
) -> KServeDeploymentConfig:
    """Lay out model files in the structure expected by a KServe predictor.

    Files are copied under
    ``<output_artifact_uri>/kserve/<predictor>/<model_name>``, arranged
    according to the predictor type.

    Args:
        model_uri: the URI of the model artifact being served
        output_artifact_uri: the URI of the output artifact
        config: the KServe deployer step config

    Returns:
        A copy of the service config whose ``model_uri`` points at the
        prepared model files.

    Raises:
        RuntimeError: if the model files cannot be prepared.
    """
    kserve_root = os.path.join(output_artifact_uri, "kserve")
    fileio.makedirs(kserve_root)

    # TODO [ENG-773]: determine how to formalize how models are organized into
    #   folders and sub-folders depending on the model type/format and the
    #   KServe protocol used to serve the model.

    # TODO [ENG-791]: an auto-detect built-in KServe server implementation
    #   from the model artifact type

    # TODO [ENG-792]: validate the model artifact type against the
    #   supported built-in KServe server implementations
    predictor = config.service_config.predictor
    target_uri = os.path.join(
        kserve_root, predictor, config.service_config.model_name
    )

    if predictor == "tensorflow":
        # TensorFlow Serving wants numbered version subdirectories.
        fileio.makedirs(target_uri)
        io_utils.copy_dir(model_uri, os.path.join(target_uri, "1"))
    elif predictor == "sklearn":
        # The sklearn server expects a single file named model.joblib.
        sklearn_model_uri = os.path.join(model_uri, "model")
        if not fileio.exists(sklearn_model_uri):
            raise RuntimeError(
                f"Expected sklearn model artifact was not found at "
                f"{sklearn_model_uri}"
            )
        fileio.makedirs(target_uri)
        fileio.copy(
            sklearn_model_uri, os.path.join(target_uri, "model.joblib")
        )
    else:
        # Any other predictor: copy the artifact as-is.
        fileio.makedirs(target_uri)
        fileio.copy(model_uri, target_uri)

    prepared_config = config.service_config.copy()
    prepared_config.model_uri = target_uri
    return prepared_config
prepare_torch_service_config(model_uri, output_artifact_uri, config)

Prepare the PyTorch model files for model serving.

This function ensures that the model files are in the correct format and file structure required by the KServe server implementation used for model serving.

Parameters:

Name Type Description Default
model_uri str

the URI of the model artifact being served

required
output_artifact_uri str

the URI of the output artifact

required
config KServeDeployerStepConfig

the KServe deployer step config

required

Returns:

Type Description
KServeDeploymentConfig

The deployment config whose model URI points to the files prepared for serving.

Exceptions:

Type Description
RuntimeError

if the model files cannot be prepared.

Source code in zenml/integrations/kserve/steps/kserve_step_utils.py
def prepare_torch_service_config(
    model_uri: str, output_artifact_uri: str, config: KServeDeployerStepConfig
) -> KServeDeploymentConfig:
    """Prepare the PyTorch model files for model serving.

    The ``checkpoint.pt`` in ``model_uri`` is packaged into a TorchServe
    ``<model_name>.mar`` archive. The archive goes into
    ``<output_artifact_uri>/kserve/model-store``. A ``config.properties`` file
    is written to ``<output_artifact_uri>/kserve/config``; it is either the
    user-supplied torch config or a generated one. Intermediate files are
    built in a temporary local directory that is removed afterwards.

    Args:
        model_uri: the URI of the model artifact being served
        output_artifact_uri: the URI of the output artifact
        config: the KServe deployer step config

    Returns:
        A copy of the service config whose ``model_uri`` points at the
        ``kserve`` deployment folder.

    Raises:
        RuntimeError: if no TorchServe parameters are configured or the
            model archive could not be produced.
    """
    import shutil

    torch_params = config.torch_serve_parameters
    # Fail before touching the artifact store if nothing can be packaged.
    if torch_params is None:
        raise RuntimeError("No torch serve parameters provided")

    model_name = config.service_config.model_name
    deployment_folder_uri = os.path.join(output_artifact_uri, "kserve")
    served_model_uri = os.path.join(deployment_folder_uri, "model-store")
    config_properties_uri = os.path.join(deployment_folder_uri, "config")
    fileio.makedirs(served_model_uri)
    fileio.makedirs(config_properties_uri)

    # Local scratch space for the archiver; cleaned up in `finally`.
    temp_dir = tempfile.mkdtemp(prefix="zenml-pytorch-temp-")
    try:
        tmp_model_uri = os.path.join(temp_dir, f"{model_name}.pt")

        # Copy from artifact store to temporary file
        fileio.copy(f"{model_uri}/checkpoint.pt", tmp_model_uri)

        torch_archiver_args = TorchModelArchiver(
            model_name=model_name,
            serialized_file=tmp_model_uri,
            model_file=torch_params.model_class,
            handler=torch_params.handler,
            export_path=temp_dir,
            version=torch_params.model_version,
        )

        manifest = ModelExportUtils.generate_manifest_json(torch_archiver_args)
        package_model(torch_archiver_args, manifest=manifest)

        archived_model_uri = os.path.join(temp_dir, f"{model_name}.mar")
        if not fileio.exists(archived_model_uri):
            raise RuntimeError(
                f"Expected torch archived model artifact was not found at "
                f"{archived_model_uri}"
            )

        # Copy the torch model archive artifact to the model store
        fileio.copy(
            archived_model_uri,
            os.path.join(served_model_uri, f"{model_name}.mar"),
        )

        # Use the user-supplied config file if given, otherwise generate one.
        torch_config_uri = torch_params.torch_config or (
            generate_model_deployer_config(
                model_name=model_name,
                directory=temp_dir,
            )
        )
        fileio.copy(
            torch_config_uri,
            os.path.join(config_properties_uri, "config.properties"),
        )
    finally:
        # Everything needed has been copied out; drop the scratch directory.
        shutil.rmtree(temp_dir, ignore_errors=True)

    service_config = config.service_config.copy()
    service_config.model_uri = deployment_folder_uri
    return service_config

kubeflow special

Initialization of the Kubeflow integration for ZenML.

The Kubeflow integration sub-module powers an alternative to the local orchestrator. You can enable it by registering the Kubeflow orchestrator with the CLI tool.

KubeflowIntegration (Integration)

Definition of Kubeflow Integration for ZenML.

Source code in zenml/integrations/kubeflow/__init__.py
class KubeflowIntegration(Integration):
    """Definition of Kubeflow Integration for ZenML."""

    NAME = KUBEFLOW
    REQUIREMENTS = ["kfp==1.8.9"]

    @classmethod
    def flavors(cls) -> List[FlavorWrapper]:
        """List the stack component flavors provided by this integration.

        Returns:
            The Kubeflow metadata store flavor followed by the Kubeflow
            orchestrator flavor.
        """
        metadata_store_flavor = FlavorWrapper(
            name=KUBEFLOW_METADATA_STORE_FLAVOR,
            source="zenml.integrations.kubeflow.metadata_stores.KubeflowMetadataStore",
            type=StackComponentType.METADATA_STORE,
            integration=cls.NAME,
        )
        orchestrator_flavor = FlavorWrapper(
            name=KUBEFLOW_ORCHESTRATOR_FLAVOR,
            source="zenml.integrations.kubeflow.orchestrators.KubeflowOrchestrator",
            type=StackComponentType.ORCHESTRATOR,
            integration=cls.NAME,
        )
        return [metadata_store_flavor, orchestrator_flavor]
flavors() classmethod

Declare the stack component flavors for the Kubeflow integration.

Returns:

Type Description
List[zenml.zen_stores.models.flavor_wrapper.FlavorWrapper]

List of stack component flavors for this integration.

Source code in zenml/integrations/kubeflow/__init__.py
@classmethod
def flavors(cls) -> List[FlavorWrapper]:
    """List the stack component flavors provided by the Kubeflow integration.

    Returns:
        The Kubeflow metadata store flavor followed by the Kubeflow
        orchestrator flavor.
    """
    metadata_store_flavor = FlavorWrapper(
        name=KUBEFLOW_METADATA_STORE_FLAVOR,
        source="zenml.integrations.kubeflow.metadata_stores.KubeflowMetadataStore",
        type=StackComponentType.METADATA_STORE,
        integration=cls.NAME,
    )
    orchestrator_flavor = FlavorWrapper(
        name=KUBEFLOW_ORCHESTRATOR_FLAVOR,
        source="zenml.integrations.kubeflow.orchestrators.KubeflowOrchestrator",
        type=StackComponentType.ORCHESTRATOR,
        integration=cls.NAME,
    )
    return [metadata_store_flavor, orchestrator_flavor]

metadata_stores special

Initialization of the Kubeflow metadata store for ZenML.

kubeflow_metadata_store

Implementation of the Kubeflow metadata store.

KubeflowMetadataStore (BaseMetadataStore) pydantic-model

Kubeflow GRPC backend for ZenML metadata store.

Source code in zenml/integrations/kubeflow/metadata_stores/kubeflow_metadata_store.py
class KubeflowMetadataStore(BaseMetadataStore):
    """Kubeflow GRPC backend for ZenML metadata store.

    Inside a Kubeflow Pipelines pod the store is reached through the
    ``METADATA_GRPC_SERVICE_HOST``/``PORT`` environment variables. Outside of
    one, a local ``kubectl port-forward`` daemon exposes the cluster's
    ``metadata-grpc-service`` on ``host:port``.
    """

    upgrade_migration_enabled: bool = False
    # Local address the port-forwarded metadata gRPC service is reached on.
    host: str = "127.0.0.1"
    port: int = DEFAULT_KFP_METADATA_GRPC_PORT

    # Class Configuration
    FLAVOR: ClassVar[str] = KUBEFLOW_METADATA_STORE_FLAVOR

    @property
    def validator(self) -> Optional[StackValidator]:
        """Validates that the stack contains a KFP orchestrator.

        Returns:
            The stack validator.
        """

        def _ensure_kfp_orchestrator(stack: Stack) -> Tuple[bool, str]:
            return (
                stack.orchestrator.FLAVOR == KUBEFLOW,
                "The Kubeflow metadata store can only be used with a Kubeflow "
                "orchestrator.",
            )

        return StackValidator(
            custom_validation_function=_ensure_kfp_orchestrator
        )

    def get_tfx_metadata_config(
        self,
    ) -> Union[
        metadata_store_pb2.ConnectionConfig,
        metadata_store_pb2.MetadataStoreClientConfig,
    ]:
        """Return tfx metadata config for the kubeflow metadata store.

        Returns:
            The tfx metadata config for the kubeflow metadata store.

        Raises:
            RuntimeError: If the metadata store is not running.
        """
        connection_config = metadata_store_pb2.MetadataStoreClientConfig()
        if inside_kfp_pod():
            connection_config.host = os.environ["METADATA_GRPC_SERVICE_HOST"]
            connection_config.port = int(
                os.environ["METADATA_GRPC_SERVICE_PORT"]
            )
        else:
            if not self.is_running:
                raise RuntimeError(
                    "The KFP metadata daemon is not running. Please run the "
                    "following command to start it first:\n\n"
                    "    'zenml metadata-store up'\n"
                )
            connection_config.host = self.host
            connection_config.port = self.port
        return connection_config

    @property
    def kfp_orchestrator(self) -> KubeflowOrchestrator:
        """Returns the Kubeflow orchestrator in the active stack.

        Returns:
            The Kubeflow orchestrator in the active stack.
        """
        repo = Repository(skip_repository_check=True)  # type: ignore[call-arg]
        return cast(KubeflowOrchestrator, repo.active_stack.orchestrator)

    @property
    def kubernetes_context(self) -> str:
        """Returns the kubernetes context.

        This is returned to the cluster where the Kubeflow Pipelines services
        are running.

        Returns:
            The kubernetes context.
        """
        kubernetes_context = self.kfp_orchestrator.kubernetes_context

        # will never happen, but mypy doesn't know that
        assert kubernetes_context is not None
        return kubernetes_context

    @property
    def root_directory(self) -> str:
        """Returns path to the root directory.

        This is for all files concerning this KFP metadata store.

        Note: the root directory for the KFP metadata store is relative to the
        root directory of the KFP orchestrator, because it is a sub-component
        of it.

        Returns:
            Path to the root directory.
        """
        return os.path.join(
            self.kfp_orchestrator.root_directory,
            "metadata-store",
            str(self.uuid),
        )

    @property
    def _pid_file_path(self) -> str:
        """Returns path to the daemon PID file.

        Returns:
            Path to the daemon PID file.
        """
        return os.path.join(self.root_directory, "kubeflow_daemon.pid")

    @property
    def _log_file(self) -> str:
        """Path of the daemon log file.

        Returns:
            Path to the daemon log file.
        """
        return os.path.join(self.root_directory, "kubeflow_daemon.log")

    @property
    def is_provisioned(self) -> bool:
        """If the component provisioned resources to run locally.

        Returns:
            True if the component provisioned resources to run locally.
        """
        return fileio.exists(self.root_directory)

    @property
    def is_running(self) -> bool:
        """If the component is running locally.

        On Windows the daemon cannot be tracked, so this always reports True.

        Returns:
            True if the component is running locally, False otherwise.
        """
        if sys.platform != "win32":
            from zenml.utils.daemon import check_if_daemon_is_running

            if not check_if_daemon_is_running(self._pid_file_path):
                return False
        else:
            # Daemon functionality is not supported on Windows, so the PID
            # file won't exist. This if clause exists just for mypy to not
            # complain about missing functions
            pass

        return True

    def provision(self) -> None:
        """Provisions resources to run the component locally."""
        logger.info("Provisioning local Kubeflow Pipelines deployment...")
        fileio.makedirs(self.root_directory)

    def deprovision(self) -> None:
        """Deprovisions all local resources of the component."""
        if fileio.exists(self._log_file):
            fileio.remove(self._log_file)

        logger.info("Local kubeflow pipelines deployment deprovisioned.")

    def resume(self) -> None:
        """Resumes the local metadata store by starting the port-forward daemon.

        Waits until the metadata store answers before returning.
        """
        if self.is_running:
            logger.info("Local kubeflow pipelines deployment already running.")
            return

        self.start_kfp_metadata_daemon()
        self.wait_until_metadata_store_ready()

    def suspend(self) -> None:
        """Suspends the local metadata store by stopping the port-forward daemon."""
        if not self.is_running:
            logger.info("Local kubeflow pipelines deployment not running.")
            return

        self.stop_kfp_metadata_daemon()

    def start_kfp_metadata_daemon(self) -> None:
        """Starts a daemon process that forwards ports.

        This is so the Kubeflow Pipelines Metadata MySQL database is accessible
        on the localhost.

        Raises:
            ProvisioningError: if the daemon fails to start.
        """
        command = [
            "kubectl",
            "--context",
            self.kubernetes_context,
            "--namespace",
            "kubeflow",
            "port-forward",
            "svc/metadata-grpc-service",
            f"{self.port}:8080",
        ]

        if sys.platform == "win32":
            # Exactly one argument for the single '%s' placeholder.
            logger.warning(
                "Daemon functionality not supported on Windows. "
                "In order to access the Kubeflow Pipelines Metadata locally, "
                "please run '%s' in a separate command line shell.",
                " ".join(command),
            )
        elif not networking_utils.port_available(self.port):
            raise ProvisioningError(
                f"Unable to port-forward Kubeflow Pipelines Metadata to local "
                f"port {self.port} because the port is occupied. In order to "
                f"access the Kubeflow Pipelines Metadata locally, please "
                f"change the metadata store configuration to use an available "
                f"port or stop the other process currently using the port."
            )
        else:
            from zenml.utils import daemon

            def _daemon_function() -> None:
                """Forwards the port of the Kubeflow Pipelines Metadata pod."""
                subprocess.check_call(command)

            daemon.run_as_daemon(
                _daemon_function,
                pid_file=self._pid_file_path,
                log_file=self._log_file,
            )
            logger.info(
                "Started Kubeflow Pipelines Metadata daemon (check the daemon "
                "logs at %s in case you're not able to access the pipeline "
                "metadata).",
                self._log_file,
            )

    def stop_kfp_metadata_daemon(self) -> None:
        """Stops the KFP Metadata daemon process if it is running."""
        if fileio.exists(self._pid_file_path):
            if sys.platform == "win32":
                # Daemon functionality is not supported on Windows, so the PID
                # file won't exist. This if clause exists just for mypy to not
                # complain about missing functions
                pass
            else:
                from zenml.utils import daemon

                daemon.stop_daemon(self._pid_file_path)
                fileio.remove(self._pid_file_path)

    def wait_until_metadata_store_ready(
        self, timeout: int = DEFAULT_KFP_METADATA_DAEMON_TIMEOUT
    ) -> None:
        """Waits until the metadata store connection is ready.

        Potentially an irrecoverable error could occur or the timeout could
        expire, so it checks for this.

        Args:
            timeout: The maximum time to wait for the metadata store to be
                ready.

        Raises:
            RuntimeError: if the metadata store is not ready after the timeout
        """
        logger.info(
            "Waiting for the Kubeflow metadata store to be ready (this might "
            "take a few minutes)."
        )
        while True:
            try:
                # it doesn't matter what we call here as long as it exercises
                # the MLMD connection
                self.get_pipelines()
                break
            except Exception as e:
                logger.info(
                    "The Kubeflow metadata store is not ready yet. Waiting for "
                    "10 seconds..."
                )
                if timeout <= 0:
                    raise RuntimeError(
                        f"An unexpected error was encountered while waiting for the "
                        f"Kubeflow metadata store to be functional: {str(e)}"
                    ) from e
                timeout -= 10
                time.sleep(10)

        logger.info("The Kubeflow metadata store is functional.")
is_provisioned: bool property readonly

If the component provisioned resources to run locally.

Returns:

Type Description
bool

True if the component provisioned resources to run locally.

is_running: bool property readonly

If the component is running locally.

Returns:

Type Description
bool

True if the component is running locally, False otherwise.

kfp_orchestrator: KubeflowOrchestrator property readonly

Returns the Kubeflow orchestrator in the active stack.

Returns:

Type Description
KubeflowOrchestrator

The Kubeflow orchestrator in the active stack.

kubernetes_context: str property readonly

Returns the kubernetes context.

This is returned to the cluster where the Kubeflow Pipelines services are running.

Returns:

Type Description
str

The kubernetes context.

root_directory: str property readonly

Returns path to the root directory.

This is for all files concerning this KFP metadata store.

Note: the root directory for the KFP metadata store is relative to the root directory of the KFP orchestrator, because it is a sub-component of it.

Returns:

Type Description
str

Path to the root directory.

validator: Optional[zenml.stack.stack_validator.StackValidator] property readonly

Validates that the stack contains a KFP orchestrator.

Returns:

Type Description
Optional[zenml.stack.stack_validator.StackValidator]

The stack validator.

deprovision(self)

Deprovisions all local resources of the component.

Source code in zenml/integrations/kubeflow/metadata_stores/kubeflow_metadata_store.py
def deprovision(self) -> None:
    """Remove the local resources owned by the metadata store.

    Only the daemon log file (`self._log_file`) is deleted, and only if it
    exists; a confirmation message is logged in either case.
    """
    log_path = self._log_file
    if fileio.exists(log_path):
        fileio.remove(log_path)

    logger.info("Local kubeflow pipelines deployment deprovisioned.")
get_tfx_metadata_config(self)

Return tfx metadata config for the kubeflow metadata store.

Returns:

Type Description
Union[ml_metadata.proto.metadata_store_pb2.ConnectionConfig, ml_metadata.proto.metadata_store_pb2.MetadataStoreClientConfig]

The tfx metadata config for the kubeflow metadata store.

Exceptions:

Type Description
RuntimeError

If the metadata store is not running.

Source code in zenml/integrations/kubeflow/metadata_stores/kubeflow_metadata_store.py
def get_tfx_metadata_config(
    self,
) -> Union[
    metadata_store_pb2.ConnectionConfig,
    metadata_store_pb2.MetadataStoreClientConfig,
]:
    """Build the MLMD client config that points at the Kubeflow metadata store.

    Inside a KFP pod, host and port are taken from the
    `METADATA_GRPC_SERVICE_HOST` / `METADATA_GRPC_SERVICE_PORT` environment
    variables. Anywhere else, the locally forwarded `self.host` / `self.port`
    are used, which requires the metadata daemon to be running.

    Returns:
        The tfx metadata config for the kubeflow metadata store.

    Raises:
        RuntimeError: If not running inside a KFP pod and the metadata daemon
            is not running.
    """
    client_config = metadata_store_pb2.MetadataStoreClientConfig()

    if not inside_kfp_pod():
        # Outside the cluster we depend on the local port-forwarding daemon.
        if not self.is_running:
            raise RuntimeError(
                "The KFP metadata daemon is not running. Please run the "
                "following command to start it first:\n\n"
                "    'zenml metadata-store up'\n"
            )
        client_config.host = self.host
        client_config.port = self.port
        return client_config

    # Inside a KFP pod the metadata service is advertised via env variables.
    client_config.host = os.environ["METADATA_GRPC_SERVICE_HOST"]
    client_config.port = int(os.environ["METADATA_GRPC_SERVICE_PORT"])
    return client_config
provision(self)

Provisions resources to run the component locally.

Source code in zenml/integrations/kubeflow/metadata_stores/kubeflow_metadata_store.py
def provision(self) -> None:
    """Prepare the component to run locally.

    Logs a message, then creates `self.root_directory`, the directory that
    holds the local files of this metadata store.
    """
    logger.info("Provisioning local Kubeflow Pipelines deployment...")
    root_dir = self.root_directory
    fileio.makedirs(root_dir)
resume(self)

Starts the local Kubeflow metadata daemon (if it is not already running) and waits until the metadata store is ready.

Source code in zenml/integrations/kubeflow/metadata_stores/kubeflow_metadata_store.py
def resume(self) -> None:
    """Start the local metadata daemon if it is not already running.

    After starting the daemon, this blocks until the metadata store answers
    requests (via `wait_until_metadata_store_ready`). If the daemon is
    already running, only an informational message is logged.
    """
    if not self.is_running:
        self.start_kfp_metadata_daemon()
        self.wait_until_metadata_store_ready()
        return

    logger.info("Local kubeflow pipelines deployment already running.")
start_kfp_metadata_daemon(self)

Starts a daemon process that forwards ports.

This is so the Kubeflow Pipelines Metadata MySQL database is accessible on the localhost.

Exceptions:

Type Description
ProvisioningError

If the local port is already occupied.

Source code in zenml/integrations/kubeflow/metadata_stores/kubeflow_metadata_store.py
def start_kfp_metadata_daemon(self) -> None:
    """Start a daemon that port-forwards the KFP metadata service locally.

    The daemon runs `kubectl port-forward` against `svc/metadata-grpc-service`
    in the `kubeflow` namespace of `self.kubernetes_context`, mapping local
    port `self.port` to service port 8080. This makes the Kubeflow Pipelines
    metadata store reachable on localhost.

    On Windows no daemon is started; a warning with the command to run
    manually is logged instead.

    Raises:
        ProvisioningError: if the local port `self.port` is already occupied.
    """
    command = [
        "kubectl",
        "--context",
        self.kubernetes_context,
        "--namespace",
        "kubeflow",
        "port-forward",
        "svc/metadata-grpc-service",
        f"{self.port}:8080",
    ]

    if sys.platform == "win32":
        # The message has exactly one `%s` placeholder, so exactly one
        # argument is passed. Extra arguments make the logging module fail
        # with "not all arguments converted during string formatting" and
        # the warning is never shown.
        logger.warning(
            "Daemon functionality not supported on Windows. "
            "In order to access the Kubeflow Pipelines Metadata locally, "
            "please run '%s' in a separate command line shell.",
            " ".join(command),
        )
    elif not networking_utils.port_available(self.port):
        raise ProvisioningError(
            f"Unable to port-forward Kubeflow Pipelines Metadata to local "
            f"port {self.port} because the port is occupied. In order to "
            f"access the Kubeflow Pipelines Metadata locally, please "
            f"change the metadata store configuration to use an available "
            f"port or stop the other process currently using the port."
        )
    else:
        # Imported lazily: the daemon utilities are only used on
        # non-Windows platforms.
        from zenml.utils import daemon

        def _daemon_function() -> None:
            """Forwards the port of the Kubeflow Pipelines Metadata pod."""
            subprocess.check_call(command)

        daemon.run_as_daemon(
            _daemon_function,
            pid_file=self._pid_file_path,
            log_file=self._log_file,
        )
        logger.info(
            "Started Kubeflow Pipelines Metadata daemon (check the daemon "
            "logs at %s in case you're not able to access the pipeline "
            "metadata).",
            self._log_file,
        )
stop_kfp_metadata_daemon(self)

Stops the KFP Metadata daemon process if it is running.

Source code in zenml/integrations/kubeflow/metadata_stores/kubeflow_metadata_store.py
def stop_kfp_metadata_daemon(self) -> None:
    """Stop the KFP metadata daemon process if its PID file exists.

    The daemon is stopped via the PID stored in `self._pid_file_path`, and the
    PID file is removed afterwards. Nothing happens if the PID file is
    missing.
    """
    pid_file = self._pid_file_path
    if not fileio.exists(pid_file):
        return

    if sys.platform == "win32":
        # Daemon functionality is not supported on Windows, so the PID file
        # is not expected to exist there. The explicit platform check also
        # keeps mypy from complaining about daemon functions that are
        # unavailable on win32.
        return

    from zenml.utils import daemon

    daemon.stop_daemon(pid_file)
    fileio.remove(pid_file)
suspend(self)

Stops the local Kubeflow metadata daemon if it is running.

Source code in zenml/integrations/kubeflow/metadata_stores/kubeflow_metadata_store.py
def suspend(self) -> None:
    """Stop the local metadata daemon if it is running.

    If the daemon is not running, only an informational message is logged.
    """
    if self.is_running:
        self.stop_kfp_metadata_daemon()
        return

    logger.info("Local kubeflow pipelines deployment not running.")
wait_until_metadata_store_ready(self, timeout=60)

Waits until the metadata store connection is ready.

The connection is retried every 10 seconds; if the metadata store is still not reachable once the timeout has expired, an error is raised.

Parameters:

Name Type Description Default
timeout int

The maximum time to wait for the metadata store to be ready.

60

Exceptions:

Type Description
RuntimeError

if the metadata store is not ready after the timeout

Source code in zenml/integrations/kubeflow/metadata_stores/kubeflow_metadata_store.py
def wait_until_metadata_store_ready(
    self, timeout: int = DEFAULT_KFP_METADATA_DAEMON_TIMEOUT
) -> None:
    """Wait until the metadata store connection is ready.

    The MLMD connection is exercised by calling `self.get_pipelines()`. Any
    exception is treated as "not ready yet", and the call is retried every
    10 seconds until it succeeds or the timeout budget is used up.

    Args:
        timeout: The maximum time (in seconds) to wait for the metadata store
            to be ready.

    Raises:
        RuntimeError: if the metadata store is not ready after the timeout
    """
    logger.info(
        "Waiting for the Kubeflow metadata store to be ready (this might "
        "take a few minutes)."
    )
    while True:
        try:
            # It doesn't matter what we call here as long as it exercises
            # the MLMD connection.
            self.get_pipelines()
            break
        except Exception as e:
            # Check the budget first so we never announce a wait that is
            # immediately followed by an error.
            if timeout <= 0:
                raise RuntimeError(
                    f"An unexpected error was encountered while waiting for the "
                    f"Kubeflow metadata store to be functional: {str(e)}"
                ) from e
            logger.info(
                "The Kubeflow metadata store is not ready yet. Waiting for "
                "10 seconds..."
            )
            timeout -= 10
            time.sleep(10)

    logger.info("The Kubeflow metadata store is functional.")
inside_kfp_pod()

Returns whether the current Python process is running inside a KFP Pod.

Returns:

Type Description
bool

True if the current python process is running inside a KFP Pod, False otherwise.

Source code in zenml/integrations/kubeflow/metadata_stores/kubeflow_metadata_store.py
def inside_kfp_pod() -> bool:
    """Report whether the current Python process runs inside a KFP Pod.

    The process counts as running inside a KFP Pod when the `KFP_POD_NAME`
    environment variable is set AND the in-cluster Kubernetes configuration
    can be loaded.

    Returns:
        True if the current python process is running inside a KFP Pod, False otherwise.
    """
    if "KFP_POD_NAME" in os.environ:
        try:
            k8s_config.load_incluster_config()
        except k8s_config.ConfigException:
            return False
        return True
    return False

orchestrators special

Initialization of the Kubeflow ZenML orchestrator.

kubeflow_entrypoint_configuration

Implementation of the Kubeflow entrypoint configuration.

KubeflowEntrypointConfiguration (StepEntrypointConfiguration)

Entrypoint configuration for running steps on kubeflow.

This class writes a markdown file that will be displayed in the KFP UI.

Source code in zenml/integrations/kubeflow/orchestrators/kubeflow_entrypoint_configuration.py
class KubeflowEntrypointConfiguration(StepEntrypointConfiguration):
    """Step entrypoint configuration used when running steps on Kubeflow.

    Besides the regular step entrypoint behavior, after a step has run this
    configuration writes a markdown file that will be displayed in the KFP UI.
    """

    @classmethod
    def get_custom_entrypoint_options(cls) -> Set[str]:
        """Return the Kubeflow-specific entrypoint options.

        The only option is the metadata UI path: the location where the
        markdown file shown in the Kubeflow UI should be written. The same
        path must also be declared as an output artifact named
        `mlpipeline-ui-metadata` on the corresponding `kfp.dsl.ContainerOp`.

        Returns:
            The set of custom entrypoint options.
        """
        options = {METADATA_UI_PATH_OPTION}
        return options

    @classmethod
    def get_custom_entrypoint_arguments(
        cls, step: BaseStep, *args: Any, **kwargs: Any
    ) -> List[str]:
        """Build the metadata UI path argument from the keyword arguments.

        Args:
            step: The step that is being executed.
            *args: The positional arguments passed to the step.
            **kwargs: The keyword arguments passed to the step; must contain
                the metadata UI path option.

        Returns:
            A list of strings that will be used as arguments to the step.
        """
        option_flag = f"--{METADATA_UI_PATH_OPTION}"
        ui_path = kwargs[METADATA_UI_PATH_OPTION]
        return [option_flag, ui_path]

    def get_run_name(self, pipeline_name: str) -> str:
        """Look up the name of the current Kubeflow pipeline run.

        Loads the in-cluster Kubernetes config, reads the run ID from the
        `KFP_RUN_ID` environment variable and asks the KFP API for that run.

        Args:
            pipeline_name: The name of the pipeline.

        Returns:
            The Kubeflow pipeline run name.
        """
        k8s_config.load_incluster_config()
        current_run_id = os.environ["KFP_RUN_ID"]
        run_response = kfp.Client().get_run(current_run_id)
        return run_response.run.name  # type: ignore[no-any-return]

    def post_run(
        self,
        pipeline_name: str,
        step_name: str,
        pipeline_node: Pb2PipelineNode,
        execution_info: Optional[data_types.ExecutionInfo] = None,
    ) -> None:
        """Write the KFP UI markdown file describing the finished step.

        The file covers the step execution and its input/output artifacts and
        is written to the metadata UI path passed on the command line.
        Nothing is written when no execution info is available.

        Args:
            pipeline_name: The name of the pipeline.
            step_name: The name of the step.
            pipeline_node: The pipeline node that is being executed.
            execution_info: The execution info of the step.
        """
        if not execution_info:
            return

        metadata_ui_path = self.entrypoint_args[METADATA_UI_PATH_OPTION]
        utils.dump_ui_metadata(
            node=pipeline_node,
            execution_info=execution_info,
            metadata_ui_path=metadata_ui_path,
        )
get_custom_entrypoint_arguments(step, *args, **kwargs) classmethod

Sets the metadata ui path argument to the value passed in via the keyword args.

Parameters:

Name Type Description Default
step BaseStep

The step that is being executed.

required
*args Any

The positional arguments passed to the step.

()
**kwargs Any

The keyword arguments passed to the step.

{}

Returns:

Type Description
List[str]

A list of strings that will be used as arguments to the step.

Source code in zenml/integrations/kubeflow/orchestrators/kubeflow_entrypoint_configuration.py
@classmethod
def get_custom_entrypoint_arguments(
    cls, step: BaseStep, *args: Any, **kwargs: Any
) -> List[str]:
    """Build the metadata UI path argument from the keyword arguments.

    Args:
        step: The step that is being executed.
        *args: The positional arguments passed to the step.
        **kwargs: The keyword arguments passed to the step; must contain the
            metadata UI path option.

    Returns:
        A list of strings that will be used as arguments to the step.
    """
    option_flag = f"--{METADATA_UI_PATH_OPTION}"
    ui_path = kwargs[METADATA_UI_PATH_OPTION]
    return [option_flag, ui_path]
get_custom_entrypoint_options() classmethod

Kubeflow specific entrypoint options.

The metadata ui path option expects a path where the markdown file that will be displayed in the kubeflow UI should be written. The same path needs to be added as an output artifact called mlpipeline-ui-metadata for the corresponding kfp.dsl.ContainerOp.

Returns:

Type Description
Set[str]

The set of custom entrypoint options.

Source code in zenml/integrations/kubeflow/orchestrators/kubeflow_entrypoint_configuration.py
@classmethod
def get_custom_entrypoint_options(cls) -> Set[str]:
    """Return the Kubeflow-specific entrypoint options.

    The only option is the metadata UI path: the location where the markdown
    file shown in the Kubeflow UI should be written. The same path must also
    be declared as an output artifact named `mlpipeline-ui-metadata` on the
    corresponding `kfp.dsl.ContainerOp`.

    Returns:
        The set of custom entrypoint options.
    """
    options = {METADATA_UI_PATH_OPTION}
    return options
get_run_name(self, pipeline_name)

Returns the Kubeflow pipeline run name.

Parameters:

Name Type Description Default
pipeline_name str

The name of the pipeline.

required

Returns:

Type Description
str

The Kubeflow pipeline run name.

Source code in zenml/integrations/kubeflow/orchestrators/kubeflow_entrypoint_configuration.py
def get_run_name(self, pipeline_name: str) -> str:
    """Look up the name of the current Kubeflow pipeline run.

    Loads the in-cluster Kubernetes config, reads the run ID from the
    `KFP_RUN_ID` environment variable and asks the KFP API for that run.

    Args:
        pipeline_name: The name of the pipeline.

    Returns:
        The Kubeflow pipeline run name.
    """
    k8s_config.load_incluster_config()
    current_run_id = os.environ["KFP_RUN_ID"]
    run_response = kfp.Client().get_run(current_run_id)
    return run_response.run.name  # type: ignore[no-any-return]
post_run(self, pipeline_name, step_name, pipeline_node, execution_info=None)

Writes a markdown file that will display information.

This will be about the step execution and input/output artifacts in the KFP UI.

Parameters:

Name Type Description Default
pipeline_name str

The name of the pipeline.

required
step_name str

The name of the step.

required
pipeline_node PipelineNode

The pipeline node that is being executed.

required
execution_info Optional[tfx.orchestration.portable.data_types.ExecutionInfo]

The execution info of the step.

None
Source code in zenml/integrations/kubeflow/orchestrators/kubeflow_entrypoint_configuration.py
def post_run(
    self,
    pipeline_name: str,
    step_name: str,
    pipeline_node: Pb2PipelineNode,
    execution_info: Optional[data_types.ExecutionInfo] = None,
) -> None:
    """Write the KFP UI markdown file describing the finished step.

    The file covers the step execution and its input/output artifacts and is
    written to the metadata UI path passed on the command line. Nothing is
    written when no execution info is available.

    Args:
        pipeline_name: The name of the pipeline.
        step_name: The name of the step.
        pipeline_node: The pipeline node that is being executed.
        execution_info: The execution info of the step.
    """
    if not execution_info:
        return

    metadata_ui_path = self.entrypoint_args[METADATA_UI_PATH_OPTION]
    utils.dump_ui_metadata(
        node=pipeline_node,
        execution_info=execution_info,
        metadata_ui_path=metadata_ui_path,
    )
kubeflow_orchestrator

Implementation of the Kubeflow orchestrator.

KubeflowOrchestrator (BaseOrchestrator) pydantic-model

Orchestrator responsible for running pipelines using Kubeflow.

Attributes:

Name Type Description
custom_docker_base_image_name Optional[str]

Name of a docker image that should be used as the base for the image that will be run on KFP pods. If no custom image is given, a basic image of the active ZenML version will be used. Note: This image needs to have ZenML installed, otherwise the pipeline execution will fail. For that reason, you might want to extend the ZenML docker images found here: https://hub.docker.com/r/zenmldocker/zenml/

kubeflow_pipelines_ui_port int

A local port to which the KFP UI will be forwarded.

kubeflow_hostname Optional[str]

The hostname to use to talk to the Kubeflow Pipelines API. If not set, the hostname will be derived from the Kubernetes API proxy.

kubernetes_context Optional[str]

Optional name of a kubernetes context to run pipelines in. If not set, the current active context will be used. You can find the active context by running kubectl config current-context.

synchronous bool

If True, running a pipeline using this orchestrator will block until all steps have finished running on KFP.

skip_local_validations bool

If True, the local validations will be skipped.

skip_cluster_provisioning bool

If True, the k3d cluster provisioning will be skipped.

skip_ui_daemon_provisioning bool

If True, provisioning the KFP UI daemon will be skipped.

Source code in zenml/integrations/kubeflow/orchestrators/kubeflow_orchestrator.py
class KubeflowOrchestrator(BaseOrchestrator):
    """Orchestrator responsible for running pipelines using Kubeflow.

    Attributes:
        custom_docker_base_image_name: Name of a docker image that should be
            used as the base for the image that will be run on KFP pods. If no
            custom image is given, a basic image of the active ZenML version
            will be used. **Note**: This image needs to have ZenML installed,
            otherwise the pipeline execution will fail. For that reason, you
            might want to extend the ZenML docker images found here:
            https://hub.docker.com/r/zenmldocker/zenml/
        kubeflow_pipelines_ui_port: A local port to which the KFP UI will be
            forwarded.
        kubeflow_hostname: The hostname to use to talk to the Kubeflow Pipelines
            API. If not set, the hostname will be derived from the Kubernetes
            API proxy.
        kubernetes_context: Optional name of a kubernetes context to run
            pipelines in. If not set, the current active context will be used.
            You can find the active context by running `kubectl config
            current-context`.
        synchronous: If `True`, running a pipeline using this orchestrator will
            block until all steps finished running on KFP.
        skip_local_validations: If `True`, the local validations will be
            skipped.
        skip_cluster_provisioning: If `True`, the k3d cluster provisioning will
            be skipped.
        skip_ui_daemon_provisioning: If `True`, provisioning the KFP UI daemon
            will be skipped.
    """

    custom_docker_base_image_name: Optional[str] = None
    kubeflow_pipelines_ui_port: int = DEFAULT_KFP_UI_PORT
    kubeflow_hostname: Optional[str] = None
    kubernetes_context: Optional[str] = None
    synchronous: bool = False
    skip_local_validations: bool = False
    skip_cluster_provisioning: bool = False
    skip_ui_daemon_provisioning: bool = False

    # Class Configuration
    FLAVOR: ClassVar[str] = KUBEFLOW_ORCHESTRATOR_FLAVOR

    @staticmethod
    def _get_k3d_cluster_name(uuid: UUID) -> str:
        """Derive the k3d cluster name from the orchestrator UUID.

        Args:
            uuid: The UUID of the orchestrator.

        Returns:
            The k3d cluster name.
        """
        # k3d caps cluster names at 32 characters, so only the first 8
        # characters of the orchestrator UUID serve as the identifier.
        short_id = str(uuid)[:8]
        return "zenml-kubeflow-" + short_id

    @staticmethod
    def _get_k3d_kubernetes_context(uuid: UUID) -> str:
        """Derive the kubernetes context name of the ZenML-managed k3d cluster.

        Args:
            uuid: The UUID of the orchestrator.

        Returns:
            The name of the kubernetes context associated with the k3d
                cluster managed locally by ZenML corresponding to the orchestrator UUID.
        """
        # k3d prefixes the contexts it creates with "k3d-".
        cluster_name = KubeflowOrchestrator._get_k3d_cluster_name(uuid)
        return "k3d-" + cluster_name

    @root_validator(skip_on_failure=True)
    def set_default_kubernetes_context(
        cls, values: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Pydantic root_validator filling in a default `kubernetes_context`.

        When no (or an empty) `kubernetes_context` is configured, it defaults
        to the context of the locally managed k3d cluster derived from the
        orchestrator UUID.

        Args:
            values: Values passed to the object constructor

        Returns:
            Values passed to the Pydantic constructor
        """
        if values.get("kubernetes_context"):
            # An explicitly configured context always wins.
            return values

        # Should be guaranteed by Pydantic validation; asserted for mypy.
        assert "uuid" in values
        values["kubernetes_context"] = cls._get_k3d_kubernetes_context(
            values["uuid"]
        )
        return values

    def get_kubernetes_contexts(self) -> Tuple[List[str], Optional[str]]:
        """Get the list of configured Kubernetes contexts and the active context.

        The contexts are read with `k8s_config.list_kube_config_contexts()`.
        If that raises a `ConfigException` (e.g. no usable kubeconfig), an
        empty list and `None` are returned instead of propagating the error.

        Returns:
            A tuple containing the list of configured Kubernetes contexts and
            the active context name (or `None` if there is no active context).
        """
        try:
            contexts, active_context = k8s_config.list_kube_config_contexts()
        except k8s_config.config_exception.ConfigException:
            return [], None

        context_names = [c["name"] for c in contexts]
        # Guard against a missing active context so callers always receive
        # the documented `Optional[str]` rather than a TypeError.
        active_context_name = active_context["name"] if active_context else None
        return context_names, active_context_name

    @property
    def validator(self) -> Optional[StackValidator]:
        """Validates that the stack contains a container registry.

        Also check that requirements are met for local components.

        The returned validator requires a container registry and runs a
        custom check that:

        - fails if `kubernetes_context` is not among the configured
          Kubernetes contexts, unless the orchestrator targets the locally
          managed k3d cluster;
        - only warns if `kubernetes_context` differs from the active context;
        - unless `skip_local_validations` is set, for a remote orchestrator
          fails if any stack component advertises a `local_path` or if the
          container registry is local;
        - unless `skip_local_validations` is set, for a local orchestrator
          fails if the container registry is not local.

        Returns:
            A `StackValidator` instance.
        """

        def _validate_local_requirements(stack: "Stack") -> Tuple[bool, str]:
            """Check the stack against the orchestrator's locality.

            Args:
                stack: The stack to validate.

            Returns:
                A `(is_valid, error_message)` tuple; the message is empty
                when the stack is valid.
            """
            container_registry = stack.container_registry

            # should not happen, because the stack validation takes care of
            # this, but just in case
            assert container_registry is not None

            contexts, active_context = self.get_kubernetes_contexts()

            if self.kubernetes_context not in contexts:
                # An unknown context is tolerated only for the locally
                # managed k3d cluster (presumably because it may not have
                # been provisioned yet — TODO confirm).
                if not self.is_local:
                    return False, (
                        f"Could not find a Kubernetes context named "
                        f"'{self.kubernetes_context}' in the local Kubernetes "
                        f"configuration. Please make sure that the Kubernetes "
                        f"cluster is running and that the kubeconfig file is "
                        f"configured correctly. To list all configured "
                        f"contexts, run:\n\n"
                        f"  `kubectl config get-contexts`\n"
                    )
            elif active_context and self.kubernetes_context != active_context:
                # Not fatal: only warn that the configured context differs
                # from the kubeconfig's active context.
                logger.warning(
                    f"The Kubernetes context '{self.kubernetes_context}' "
                    f"configured for the Kubeflow orchestrator is not the "
                    f"same as the active context in the local Kubernetes "
                    f"configuration. If this is not deliberate, you should "
                    f"update the orchestrator's `kubernetes_context` field by "
                    f"running:\n\n"
                    f"  `zenml orchestrator update {self.name} "
                    f"--kubernetes_context={active_context}`\n"
                    f"To list all configured contexts, run:\n\n"
                    f"  `kubectl config get-contexts`\n"
                    f"To set the active context to be the same as the one "
                    f"configured in the Kubeflow orchestrator and silence "
                    f"this warning, run:\n\n"
                    f"  `kubectl config use-context "
                    f"{self.kubernetes_context}`\n"
                )

            # Appended to every locality error below.
            silence_local_validations_msg = (
                f"To silence this warning, set the "
                f"`skip_local_validations` attribute to True in the "
                f"orchestrator configuration by running:\n\n"
                f"  'zenml orchestrator update {self.name} "
                f"--skip_local_validations=True'\n"
            )

            if not self.skip_local_validations and not self.is_local:

                # if the orchestrator is not running in a local k3d cluster,
                # we cannot have any other local components in our stack,
                # because we cannot mount the local path into the container.
                # This may result in problems when running the pipeline, because
                # the local components will not be available inside the
                # Kubeflow containers.

                # go through all stack components and identify those that
                # advertise a local path where they persist information that
                # they need to be available when running pipelines.
                # The first such component found fails the validation.
                for stack_comp in stack.components.values():
                    local_path = stack_comp.local_path
                    if not local_path:
                        continue
                    return False, (
                        f"The Kubeflow orchestrator is configured to run "
                        f"pipelines in a remote Kubernetes cluster designated "
                        f"by the '{self.kubernetes_context}' configuration "
                        f"context, but the '{stack_comp.name}' "
                        f"{stack_comp.TYPE.value} is a local stack component "
                        f"and will not be available in the Kubeflow pipeline "
                        f"step.\nPlease ensure that you always use non-local "
                        f"stack components with a remote Kubeflow orchestrator, "
                        f"otherwise you may run into pipeline execution "
                        f"problems. You should use a flavor of "
                        f"{stack_comp.TYPE.value} other than "
                        f"'{stack_comp.FLAVOR}'.\n"
                        + silence_local_validations_msg
                    )

                # if the orchestrator is remote, the container registry must
                # also be remote.
                if container_registry.is_local:
                    return False, (
                        f"The Kubeflow orchestrator is configured to run "
                        f"pipelines in a remote Kubernetes cluster designated "
                        f"by the '{self.kubernetes_context}' configuration "
                        f"context, but the '{container_registry.name}' "
                        f"container registry URI '{container_registry.uri}' "
                        f"points to a local container registry. Please ensure "
                        f"that you always use non-local stack components with "
                        f"a remote Kubeflow orchestrator, otherwise you will "
                        f"run into problems. You should use a flavor of "
                        f"container registry other than "
                        f"'{container_registry.FLAVOR}'.\n"
                        + silence_local_validations_msg
                    )

            if not self.skip_local_validations and self.is_local:

                # if the orchestrator is local, the container registry must
                # also be local.
                if not container_registry.is_local:
                    return False, (
                        f"The Kubeflow orchestrator is configured to run "
                        f"pipelines in a local k3d Kubernetes cluster "
                        f"designated by the '{self.kubernetes_context}' "
                        f"configuration context, but the container registry "
                        f"URI '{container_registry.uri}' doesn't match the "
                        f"expected format 'localhost:$PORT'. "
                        f"The local Kubeflow orchestrator only works with a "
                        f"local container registry because it cannot "
                        f"currently authenticate to external container "
                        f"registries. You should use a flavor of container "
                        f"registry other than '{container_registry.FLAVOR}'.\n"
                        + silence_local_validations_msg
                    )

            return True, ""

        return StackValidator(
            required_components={StackComponentType.CONTAINER_REGISTRY},
            custom_validation_function=_validate_local_requirements,
        )

    def get_docker_image_name(self, pipeline_name: str) -> str:
        """Build the full docker image name, including registry and tag.

        The image is tagged with the pipeline name. If the active stack has a
        container registry, the registry URI (without a trailing slash) is
        prepended.

        Args:
            pipeline_name: The name of the pipeline.

        Returns:
            The full docker image name including registry and tag.
        """
        image_ref = f"zenml-kubeflow:{pipeline_name}"
        registry = Repository().active_stack.container_registry

        if not registry:
            return image_ref

        registry_uri = registry.uri.rstrip("/")
        return f"{registry_uri}/{image_ref}"

    @property
    def is_local(self) -> bool:
        """Tell whether the orchestrator targets the ZenML-managed k3d cluster.

        Returns:
            `True` if the KFP orchestrator is running locally (i.e. in
            the local k3d cluster managed by ZenML).
        """
        local_context = self._get_k3d_kubernetes_context(self.uuid)
        return self.kubernetes_context == local_context

    @property
    def root_directory(self) -> str:
        """Path of the directory holding all files of this orchestrator.

        The directory is `kubeflow/<orchestrator uuid>` inside the global
        ZenML config directory.

        Returns:
            Path to the root directory.
        """
        config_dir = io_utils.get_global_config_directory()
        return os.path.join(config_dir, "kubeflow", str(self.uuid))

    @property
    def pipeline_directory(self) -> str:
        """Path of the directory where the kubeflow pipeline files are stored.

        This is the `pipelines` subdirectory of `root_directory`.

        Returns:
            Path to the pipeline directory.
        """
        root = self.root_directory
        return os.path.join(root, "pipelines")

    def prepare_pipeline_deployment(
        self,
        pipeline: "BasePipeline",
        stack: "Stack",
        runtime_configuration: "RuntimeConfiguration",
    ) -> None:
        """Build the docker image for the pipeline and push it to the registry.

        The image is built from the source root with the union of the stack's
        and the pipeline's requirements, on top of
        `custom_docker_base_image_name` (if set), and with environment
        variables derived from the pipeline's secrets. It is then pushed to
        the stack's container registry. Its digest (or the image name if no
        digest is available) is stored in `runtime_configuration` under
        `"docker_image"` so it gets tracked in the ZenStore.

        Args:
            pipeline: The pipeline to be deployed.
            stack: The stack to be deployed.
            runtime_configuration: The runtime configuration to be used.
        """
        from zenml.utils import docker_utils

        image_name = self.get_docker_image_name(pipeline.name)

        requirements = {*stack.requirements(), *pipeline.requirements}
        logger.debug("Kubeflow docker container requirements: %s", requirements)

        build_context = get_source_root_path()
        secret_env_vars = self._get_environment_vars_from_secrets(
            pipeline.secrets
        )
        docker_utils.build_docker_image(
            build_context_path=build_context,
            image_name=image_name,
            dockerignore_path=pipeline.dockerignore_file,
            requirements=requirements,
            base_image=self.custom_docker_base_image_name,
            environment_vars=secret_env_vars,
        )

        registry = stack.container_registry
        assert registry  # should never happen due to validation
        registry.push_image(image_name)

        # Prefer the immutable digest; fall back to the tag-based image name.
        runtime_configuration["docker_image"] = (
            docker_utils.get_image_digest(image_name) or image_name
        )

    @staticmethod
    def _configure_container_op(container_op: dsl.ContainerOp) -> None:
        """Makes changes in place to the configuration of the container op.

        Configures persistent mounted volumes for each stack component that
        writes to a local path. Adds some labels to the container_op and applies
        some functions to it.

        Args:
            container_op: The kubeflow container operation to configure.

        Raises:
            ValueError: If the local path is not in the global config directory.
        """
        # Path to a metadata file that will be displayed in the KFP UI
        # This metadata file needs to be in a mounted emptyDir to avoid
        # sporadic failures with the (not mature) PNS executor
        # See these links for more information about limitations of PNS +
        # security context:
        # https://www.kubeflow.org/docs/components/pipelines/installation/localcluster-deployment/#deploying-kubeflow-pipelines
        # https://argoproj.github.io/argo-workflows/empty-dir/
        # KFP will switch to the Emissary executor (soon), when this emptyDir
        # mount will not be necessary anymore, but for now it's still in alpha
        # status (https://www.kubeflow.org/docs/components/pipelines/installation/choose-executor/#emissary-executor)
        volumes: Dict[str, k8s_client.V1Volume] = {
            "/outputs": k8s_client.V1Volume(
                name="outputs", empty_dir=k8s_client.V1EmptyDirVolumeSource()
            ),
        }

        stack = Repository().active_stack
        global_cfg_dir = io_utils.get_global_config_directory()
        # The global config dir with a guaranteed trailing separator. A plain
        # string-prefix check against the bare directory would wrongly accept
        # sibling directories sharing a name prefix (e.g. `.../zenml-other`
        # for `.../zenml`).
        global_cfg_prefix = os.path.join(global_cfg_dir, "")

        # go through all stack components and identify those that advertise
        # a local path where they persist information that they need to be
        # available when running pipelines. For those that do, mount them
        # into the Kubeflow container.
        has_local_repos = False
        for stack_comp in stack.components.values():
            local_path = stack_comp.local_path
            if not local_path:
                continue
            # double-check this convention, just in case it wasn't respected
            # as documented in `StackComponent.local_path`
            if not (
                local_path == global_cfg_dir
                or local_path.startswith(global_cfg_prefix)
            ):
                raise ValueError(
                    f"Local path {local_path} for component {stack_comp.name} "
                    f"is not in the global config directory ({global_cfg_dir})."
                )
            has_local_repos = True
            host_path = k8s_client.V1HostPathVolumeSource(
                path=local_path, type="Directory"
            )
            volume_name = f"{stack_comp.TYPE.value}-{stack_comp.name}"
            # Volume names must only contain alphanumerics and dashes, so
            # collapse every other run of characters into a single dash.
            volumes[local_path] = k8s_client.V1Volume(
                name=re.sub(r"[^0-9a-zA-Z-]+", "-", volume_name)
                .strip("-")
                .lower(),
                host_path=host_path,
            )
            logger.debug(
                "Adding host path volume for %s %s (path: %s) "
                "in kubeflow pipelines container.",
                stack_comp.TYPE.value,
                stack_comp.name,
                local_path,
            )
        container_op.add_pvolumes(volumes)

        if has_local_repos:
            if sys.platform == "win32":
                # File permissions are not checked on Windows. This if clause
                # prevents mypy from complaining about unused 'type: ignore'
                # statements
                pass
            else:
                # Run KFP containers in the context of the local UID/GID
                # to ensure that the artifact and metadata stores can be shared
                # with the local pipeline runs.
                container_op.container.security_context = (
                    k8s_client.V1SecurityContext(
                        run_as_user=os.getuid(),
                        run_as_group=os.getgid(),
                    )
                )
                logger.debug(
                    "Setting security context UID and GID to local user/group "
                    "in kubeflow pipelines container."
                )

        # Add environment variables for Azure Blob Storage to pod in case they
        # are set locally
        # TODO [ENG-699]: remove this as soon as we implement credential
        #  handling
        for key in [
            "AZURE_STORAGE_ACCOUNT_KEY",
            "AZURE_STORAGE_ACCOUNT_NAME",
            "AZURE_STORAGE_CONNECTION_STRING",
            "AZURE_STORAGE_SAS_TOKEN",
        ]:
            value = os.getenv(key)
            if value:
                container_op.container.add_env_variable(
                    k8s_client.V1EnvVar(name=key, value=value)
                )

        # Add some pod labels to the container_op
        for k, v in KFP_POD_LABELS.items():
            container_op.add_pod_label(k, v)

        # Mounts configmap containing Metadata gRPC server configuration.
        container_op.apply(utils.mount_config_map_op("metadata-grpc-configmap"))

    @staticmethod
    def _configure_container_resources(
        container_op: dsl.ContainerOp,
        resource_configuration: "ResourceConfiguration",
    ) -> None:
        """Apply the step's CPU, GPU and memory limits to a container op.

        Only limits that are set (i.e. not `None`) on the resource
        configuration are applied. Each `set_*_limit` call is chained on the
        object returned by the previous call.

        Args:
            container_op: The kubeflow container operation to configure.
            resource_configuration: The resource configuration to use for this
                container.
        """
        op = container_op

        cpu_count = resource_configuration.cpu_count
        if cpu_count is not None:
            op = op.set_cpu_limit(str(cpu_count))

        gpu_count = resource_configuration.gpu_count
        if gpu_count is not None:
            op = op.set_gpu_limit(gpu_count)

        memory = resource_configuration.memory
        if memory is not None:
            # Drops the final character of the memory string. NOTE(review):
            # presumably this turns e.g. "1GB"/"1GiB" into the Kubernetes
            # quantities "1G"/"1Gi"; confirm against the formats that
            # ResourceConfiguration accepts (a "KB" input would become "K",
            # which Kubernetes spells "k").
            op = op.set_memory_limit(memory[:-1])

    def prepare_or_run_pipeline(
        self,
        sorted_steps: List["BaseStep"],
        pipeline: "BasePipeline",
        pb2_pipeline: Pb2Pipeline,
        stack: "Stack",
        runtime_configuration: "RuntimeConfiguration",
    ) -> Any:
        """Creates a kfp yaml file.

        This functions as an intermediary representation of the pipeline which
        is then deployed to the kubeflow pipelines instance.

        How it works:
        -------------
        Before this method is called the `prepare_pipeline_deployment()`
        method builds a docker image that contains the code for the
        pipeline, all steps and the context around these files.

        Based on this docker image a callable is created which builds
        container_ops for each step (`_construct_kfp_pipeline`).
        To do this the entrypoint of the docker image is configured to
        run the correct step within the docker image. The dependencies
        between these container_ops are then also configured onto each
        container_op by pointing at the upstream steps.

        This callable is then compiled into a kfp yaml file that is used as
        the intermediary representation of the kubeflow pipeline.

        This file, together with some metadata and runtime configurations,
        is then uploaded into the kubeflow pipelines cluster for execution.

        Args:
            sorted_steps: A list of steps sorted by their order in the
                pipeline.
            pipeline: The pipeline object.
            pb2_pipeline: The pipeline object in protobuf format.
            stack: The stack object (not referenced directly by this method).
            runtime_configuration: The runtime configuration object.

        Returns:
            Nothing explicit (implicitly `None`); the pipeline is compiled
            and submitted as a side effect.

        Raises:
            RuntimeError: If you try to run the pipelines in a notebook environment.
        """
        # First check whether the code running in a notebook
        if Environment.in_notebook():
            raise RuntimeError(
                "The Kubeflow orchestrator cannot run pipelines in a notebook "
                "environment. The reason is that it is non-trivial to create "
                "a Docker image of a notebook. Please consider refactoring "
                "your notebook cells into separate scripts in a Python module "
                "and run the code outside of a notebook when using this "
                "orchestrator."
            )

        image_name = self.get_docker_image_name(pipeline.name)
        # Prefer the image digest when one can be resolved; otherwise fall
        # back to the plain image name.
        image_name = get_image_digest(image_name) or image_name

        # Create a callable for future compilation into a dsl.Pipeline.
        def _construct_kfp_pipeline() -> None:
            """Create a container_op for each step.

            This should contain the name of the docker image and configures the
            entrypoint of the docker image to run the step.

            Additionally, this makes each container_op run after the
            container_ops of its direct upstream steps.

            If this callable is passed to the `_create_and_write_workflow()`
            method of a KFPCompiler all dsl.ContainerOp instances will be
            automatically added to a singular dsl.Pipeline instance.
            """
            # Dictionary of container_ops indexed by the associated step name
            step_name_to_container_op: Dict[str, dsl.ContainerOp] = {}

            for step in sorted_steps:
                # The command will be needed to eventually call the python step
                # within the docker container
                command = (
                    KubeflowEntrypointConfiguration.get_entrypoint_command()
                )

                # The arguments are passed to configure the entrypoint of the
                # docker container when the step is called.
                metadata_ui_path = "/outputs/mlpipeline-ui-metadata.json"
                arguments = (
                    KubeflowEntrypointConfiguration.get_entrypoint_arguments(
                        step=step,
                        pb2_pipeline=pb2_pipeline,
                        **{METADATA_UI_PATH_OPTION: metadata_ui_path},
                    )
                )

                # Create a container_op - the kubeflow equivalent of a step. It
                # contains the name of the step, the name of the docker image,
                # the command to use to run the step entrypoint
                # (e.g. `python -m zenml.entrypoints.step_entrypoint`)
                # and the arguments to be passed along with the command. Find
                # out more about how these arguments are parsed and used
                # in the base entrypoint `run()` method.
                container_op = dsl.ContainerOp(
                    name=step.name,
                    image=image_name,
                    command=command,
                    arguments=arguments,
                    output_artifact_paths={
                        "mlpipeline-ui-metadata": metadata_ui_path,
                    },
                )

                # Mounts persistent volumes, configmaps and adds labels to the
                # container op
                self._configure_container_op(container_op=container_op)

                if self.requires_resources_in_orchestration_environment(step):
                    self._configure_container_resources(
                        container_op=container_op,
                        resource_configuration=step.resource_configuration,
                    )

                # Find the upstream container ops of the current step and
                # configure the current container op to run after them.
                # This lookup relies on `sorted_steps` listing every upstream
                # step before its downstream steps; otherwise it raises
                # KeyError.
                upstream_step_names = self.get_upstream_step_names(
                    step=step, pb2_pipeline=pb2_pipeline
                )
                for upstream_step_name in upstream_step_names:
                    upstream_container_op = step_name_to_container_op[
                        upstream_step_name
                    ]
                    container_op.after(upstream_container_op)

                # Update dictionary of container ops with the current one
                step_name_to_container_op[step.name] = container_op

        # Get a filepath to use to save the finished yaml to
        assert runtime_configuration.run_name
        fileio.makedirs(self.pipeline_directory)
        pipeline_file_path = os.path.join(
            self.pipeline_directory, f"{runtime_configuration.run_name}.yaml"
        )

        # write the argo pipeline yaml
        KFPCompiler()._create_and_write_workflow(
            pipeline_func=_construct_kfp_pipeline,
            pipeline_name=pipeline.name,
            package_path=pipeline_file_path,
        )

        # using the kfp client uploads the pipeline to kubeflow pipelines and
        # runs it there
        self._upload_and_run_pipeline(
            pipeline_name=pipeline.name,
            pipeline_file_path=pipeline_file_path,
            runtime_configuration=runtime_configuration,
            enable_cache=pipeline.enable_cache,
        )

    def _upload_and_run_pipeline(
        self,
        pipeline_name: str,
        pipeline_file_path: str,
        runtime_configuration: "RuntimeConfiguration",
        enable_cache: bool,
    ) -> None:
        """Tries to upload and run a KFP pipeline.

        If the runtime configuration has a schedule, a recurring run is
        created in an experiment named after the pipeline. The experiment is
        created first if it cannot be fetched. Otherwise a one-off run is
        started and, if the orchestrator is `synchronous`, awaited for up to
        1200 seconds.

        HTTP errors from the KFP client are logged as a warning and are not
        re-raised.

        Args:
            pipeline_name: Name of the pipeline.
            pipeline_file_path: Path to the pipeline definition file.
            runtime_configuration: Runtime configuration of the pipeline run.
            enable_cache: Whether caching is enabled for this pipeline run.
        """
        try:
            logger.info(
                "Running in kubernetes context '%s'.",
                self.kubernetes_context,
            )

            # upload the pipeline to Kubeflow and start it
            client = kfp.Client(
                host=self.kubeflow_hostname,
                kube_context=self.kubernetes_context,
            )
            if runtime_configuration.schedule:
                # Reuse the pipeline's experiment if it exists, otherwise
                # create it.
                try:
                    experiment = client.get_experiment(pipeline_name)
                    logger.info(
                        "A recurring run has already been created with this "
                        "pipeline. Creating new recurring run now.."
                    )
                except (ValueError, ApiException):
                    experiment = client.create_experiment(pipeline_name)
                    logger.info(
                        "Creating a new recurring run for pipeline '%s'.. ",
                        pipeline_name,
                    )
                logger.info(
                    "You can see all recurring runs under the '%s' experiment.",
                    pipeline_name,
                )

                schedule = runtime_configuration.schedule
                interval_seconds = (
                    schedule.interval_second.seconds
                    if schedule.interval_second
                    else None
                )
                result = client.create_recurring_run(
                    experiment_id=experiment.id,
                    job_name=runtime_configuration.run_name,
                    pipeline_package_path=pipeline_file_path,
                    enable_caching=enable_cache,
                    cron_expression=schedule.cron_expression,
                    start_time=schedule.utc_start_time,
                    end_time=schedule.utc_end_time,
                    interval_second=interval_seconds,
                    no_catchup=not schedule.catchup,
                )

                logger.info("Started recurring run with ID '%s'.", result.id)
            else:
                logger.info(
                    "No schedule detected. Creating a one-off pipeline run.."
                )
                result = client.create_run_from_pipeline_package(
                    pipeline_file_path,
                    arguments={},
                    run_name=runtime_configuration.run_name,
                    enable_caching=enable_cache,
                )
                logger.info(
                    "Started one-off pipeline run with ID '%s'.", result.run_id
                )

                if self.synchronous:
                    # TODO [ENG-698]: Allow configuration of the timeout as a
                    #  runtime option
                    client.wait_for_run_completion(
                        run_id=result.run_id, timeout=1200
                    )
        except urllib3.exceptions.HTTPError as error:
            # Pass every dynamic value as a lazy logging argument. Baking the
            # context name into the format string via an f-string would break
            # formatting if the name contained a '%'.
            logger.warning(
                "Failed to upload Kubeflow pipeline: %s. "
                "Please make sure your kubernetes config is present and the "
                "%s kubernetes context is configured correctly.",
                error,
                self.kubernetes_context,
            )

    @property
    def _pid_file_path(self) -> str:
        """Returns path to the daemon PID file.

        Returns:
            Path to the daemon PID file.
        """
        return os.path.join(self.root_directory, "kubeflow_daemon.pid")

    @property
    def log_file(self) -> str:
        """Location of the log file written by the KFP UI daemon.

        Returns:
            `kubeflow_daemon.log` inside the orchestrator's root directory.
        """
        log_file_name = "kubeflow_daemon.log"
        return os.path.join(self.root_directory, log_file_name)

    @property
    def _k3d_cluster_name(self) -> str:
        """Returns the K3D cluster name.

        Returns:
            The K3D cluster name.
        """
        return self._get_k3d_cluster_name(self.uuid)

    def _get_k3d_registry_name(self, port: int) -> str:
        """Returns the K3D registry name.

        Args:
            port: Port of the registry.

        Returns:
            The registry name.
        """
        return f"k3d-zenml-kubeflow-registry.localhost:{port}"

    @property
    def _k3d_registry_config_path(self) -> str:
        """Returns the path to the K3D registry config yaml.

        Returns:
            str: Path to the K3D registry config yaml.
        """
        return os.path.join(self.root_directory, "k3d_registry.yaml")

    def _get_kfp_ui_daemon_port(self) -> int:
        """Choose the local port for the KFP UI daemon.

        The configured port is used unless it is the default port and that
        port is already taken. In that case a free port is picked instead.

        Returns:
            Port to use for the KFP UI daemon.
        """
        configured_port = self.kubeflow_pipelines_ui_port
        is_default_port = configured_port == DEFAULT_KFP_UI_PORT
        if is_default_port and not networking_utils.port_available(
            configured_port
        ):
            # only fall back for the default; an explicitly chosen port is
            # respected even if it is occupied
            return networking_utils.find_available_port()
        return configured_port

    def list_manual_setup_steps(
        self, container_registry_name: str, container_registry_path: str
    ) -> None:
        """Logs manual steps needed to setup the Kubeflow local orchestrator.

        For non-local Kubernetes contexts only a warning is logged, because
        manual deployment instructions apply to local orchestrators only.

        Args:
            container_registry_name: Name of the container registry.
            container_registry_path: Path to the container registry.
        """
        if not self.is_local:
            # Make sure we're not telling users to deploy Kubeflow on their
            # remote clusters
            logger.warning(
                "This Kubeflow orchestrator is configured to use a non-local "
                f"Kubernetes context {self.kubernetes_context}. Manually "
                f"deploying Kubeflow Pipelines is only possible for local "
                f"Kubeflow orchestrators."
            )
            return

        global_config_dir_path = io_utils.get_global_config_directory()
        # The kustomize URLs contain `?` (a shell glob character) and `&` (the
        # shell's background operator). They are single-quoted so the printed
        # commands work when copy-pasted into a shell.
        cluster_scoped_url = (
            "github.com/kubeflow/pipelines/manifests/kustomize/"
            f"cluster-scoped-resources?ref={KFP_VERSION}&timeout=5m"
        )
        platform_agnostic_url = (
            "github.com/kubeflow/pipelines/manifests/kustomize/env/"
            f"platform-agnostic-pns?ref={KFP_VERSION}&timeout=5m"
        )
        kubeflow_commands = [
            f"> k3d cluster create {self._k3d_cluster_name} --image {local_deployment_utils.K3S_IMAGE_NAME} --registry-create {container_registry_name} --registry-config {container_registry_path} --volume {global_config_dir_path}:{global_config_dir_path}\n",
            f"> kubectl --context {self.kubernetes_context} apply -k '{cluster_scoped_url}'",
            f"> kubectl --context {self.kubernetes_context} wait --timeout=60s --for condition=established crd/applications.app.k8s.io",
            f"> kubectl --context {self.kubernetes_context} apply -k '{platform_agnostic_url}'",
            f"> kubectl --context {self.kubernetes_context} --namespace kubeflow port-forward svc/ml-pipeline-ui {self.kubeflow_pipelines_ui_port}:80",
        ]

        logger.info(
            "If you wish to spin up this Kubeflow local orchestrator manually, "
            "please enter the following commands:\n"
        )
        logger.info("\n".join(kubeflow_commands))

    @property
    def is_provisioned(self) -> bool:
        """Whether a local k3d cluster for this orchestrator exists.

        If the CLI tools needed for this configuration are missing, the
        deployment is reported as not provisioned without further checks.

        Returns:
            True if a local k3d cluster exists, False otherwise.
        """
        k3d_not_needed = self.skip_cluster_provisioning or not self.is_local
        kubectl_not_needed = (
            self.skip_cluster_provisioning and self.skip_ui_daemon_provisioning
        )
        prerequisites_met = local_deployment_utils.check_prerequisites(
            skip_k3d=k3d_not_needed,
            skip_kubectl=kubectl_not_needed,
        )
        if prerequisites_met:
            return self.is_cluster_provisioned
        # missing tools mean there certainly is no local deployment
        return False

    @property
    def is_running(self) -> bool:
        """Whether the local k3d cluster and the UI daemon are both up.

        The checks are evaluated lazily in order: provisioned, cluster
        running, daemon running.

        Returns:
            True if the local k3d cluster and UI daemon for this orchestrator are both running.
        """
        provisioned = self.is_provisioned
        if not provisioned:
            return provisioned
        cluster_up = self.is_cluster_running
        if not cluster_up:
            return cluster_up
        return self.is_daemon_running

    @property
    def is_suspended(self) -> bool:
        """Whether the local k3d cluster and the UI daemon are both stopped.

        Components whose provisioning is skipped count as stopped.

        Returns:
            True if the cluster and daemon for this orchestrator are both stopped, False otherwise.
        """
        provisioned = self.is_provisioned
        if not provisioned:
            return provisioned
        cluster_down = (
            self.skip_cluster_provisioning or not self.is_cluster_running
        )
        if not cluster_down:
            return cluster_down
        return self.skip_ui_daemon_provisioning or not self.is_daemon_running

    @property
    def is_cluster_provisioned(self) -> bool:
        """Whether this orchestrator's local k3d cluster exists.

        Clusters that ZenML does not manage (remote installations, or when
        cluster provisioning is skipped) are always reported as provisioned.

        Returns:
            True if the local k3d cluster is provisioned, False otherwise.
        """
        if not self.skip_cluster_provisioning and self.is_local:
            return local_deployment_utils.k3d_cluster_exists(
                cluster_name=self._k3d_cluster_name
            )
        return True

    @property
    def is_cluster_running(self) -> bool:
        """Whether this orchestrator's local k3d cluster is running.

        Clusters that ZenML does not manage (remote installations, or when
        cluster provisioning is skipped) are always reported as running.

        Returns:
            True if the local k3d cluster is running, False otherwise.
        """
        if not self.skip_cluster_provisioning and self.is_local:
            return local_deployment_utils.k3d_cluster_running(
                cluster_name=self._k3d_cluster_name
            )
        return True

    @property
    def is_daemon_running(self) -> bool:
        """Whether the local Kubeflow UI daemon is running.

        It is reported as running when UI daemon provisioning is skipped and
        on Windows. Otherwise the daemon's PID file is checked.

        Returns:
            True if the daemon is running, False otherwise.
        """
        if self.skip_ui_daemon_provisioning or sys.platform == "win32":
            return True

        from zenml.utils.daemon import check_if_daemon_is_running

        return check_if_daemon_is_running(self._pid_file_path)

    def provision(self) -> None:
        """Provisions a local Kubeflow Pipelines deployment.

        Does nothing if cluster provisioning is skipped or the deployment is
        already running. For remote Kubeflow Pipelines installations
        (`is_local` is False), only the orchestrator's root directory is
        created.

        For local installations it:
        - creates a k3d cluster with a local container registry;
        - deploys Kubeflow Pipelines into the cluster;
        - if the active artifact store is a `LocalArtifactStore`, adds its
          path as a host path to Kubeflow Pipelines.

        If any of these steps fail, the error is logged, the manual setup
        steps are printed and the partial deployment is deprovisioned.

        Raises:
            ProvisioningError: If the required CLI tools (`kubectl`, and
                `k3d` for local deployments) are not installed.
        """
        if self.skip_cluster_provisioning:
            return

        if self.is_running:
            logger.info(
                "Found already existing local Kubeflow Pipelines deployment. "
                "If there are any issues with the existing deployment, please "
                "run 'zenml stack down --force' to delete it."
            )
            return

        # k3d is only needed when ZenML manages a local cluster. Remote KFP
        # installations must not fail here just because k3d is missing (this
        # matches the prerequisites checked by `is_provisioned`).
        if not local_deployment_utils.check_prerequisites(
            skip_k3d=not self.is_local
        ):
            raise ProvisioningError(
                "Unable to provision local Kubeflow Pipelines deployment: "
                "Please install 'k3d' and 'kubectl' and try again."
            )

        container_registry = Repository().active_stack.container_registry

        # should not happen, because the stack validation takes care of this,
        # but just in case
        assert container_registry is not None

        fileio.makedirs(self.root_directory)

        if not self.is_local:
            # don't provision any resources if using a remote KFP installation
            return

        logger.info("Provisioning local Kubeflow Pipelines deployment...")

        container_registry_port = int(container_registry.uri.split(":")[-1])
        container_registry_name = self._get_k3d_registry_name(
            port=container_registry_port
        )
        local_deployment_utils.write_local_registry_yaml(
            yaml_path=self._k3d_registry_config_path,
            registry_name=container_registry_name,
            registry_uri=container_registry.uri,
        )

        try:
            local_deployment_utils.create_k3d_cluster(
                cluster_name=self._k3d_cluster_name,
                registry_name=container_registry_name,
                registry_config_path=self._k3d_registry_config_path,
            )
            kubernetes_context = self.kubernetes_context

            # will never happen, but mypy doesn't know that
            assert kubernetes_context is not None

            local_deployment_utils.deploy_kubeflow_pipelines(
                kubernetes_context=kubernetes_context
            )

            artifact_store = Repository().active_stack.artifact_store
            if isinstance(artifact_store, LocalArtifactStore):
                local_deployment_utils.add_hostpath_to_kubeflow_pipelines(
                    kubernetes_context=kubernetes_context,
                    local_path=artifact_store.path,
                )
        except Exception as e:
            # Best effort: report the failure, show the manual steps and
            # clean up whatever was partially created instead of raising.
            logger.error(e)
            logger.error(
                "Unable to spin up local Kubeflow Pipelines deployment."
            )

            self.list_manual_setup_steps(
                container_registry_name, self._k3d_registry_config_path
            )
            self.deprovision()

    def deprovision(self) -> None:
        """Tear down the local Kubeflow Pipelines deployment.

        Does nothing if cluster provisioning is skipped. Otherwise it:
        - stops the UI daemon if ZenML manages it and it is running;
        - deletes the k3d cluster for local deployments;
        - removes the daemon log file if it exists.
        """
        if self.skip_cluster_provisioning:
            return

        manages_daemon = not self.skip_ui_daemon_provisioning
        if manages_daemon and self.is_daemon_running:
            local_deployment_utils.stop_kfp_ui_daemon(
                pid_file_path=self._pid_file_path
            )

        # remote KFP installations keep their resources
        if self.is_local:
            cluster = self._k3d_cluster_name
            local_deployment_utils.delete_k3d_cluster(cluster_name=cluster)
            logger.info("Local kubeflow pipelines deployment deprovisioned.")

        log_path = self.log_file
        if fileio.exists(log_path):
            fileio.remove(log_path)

    def resume(self) -> None:
        """Start the local k3d cluster and the KFP UI daemon if stopped.

        Only a ZenML-managed local cluster is started. After starting, this
        waits until Kubeflow Pipelines is ready. The UI daemon is started
        whenever it is not reported as running.

        Raises:
            ProvisioningError: If the k3d cluster is not provisioned.
        """
        if self.is_running:
            logger.info("Local kubeflow pipelines deployment already running.")
            return

        if not self.is_provisioned:
            raise ProvisioningError(
                "Unable to resume local kubeflow pipelines deployment: No "
                "resources provisioned for local deployment."
            )

        kube_context = self.kubernetes_context
        # always set at this point; the assert narrows the type for mypy
        assert kube_context is not None

        # leave remote KFP installations (and skipped clusters) untouched
        cluster_managed = not self.skip_cluster_provisioning and self.is_local
        if cluster_managed and not self.is_cluster_running:
            cluster = self._k3d_cluster_name
            local_deployment_utils.start_k3d_cluster(cluster_name=cluster)
            local_deployment_utils.wait_until_kubeflow_pipelines_ready(
                kubernetes_context=kube_context
            )

        if self.is_daemon_running:
            return
        local_deployment_utils.start_kfp_ui_daemon(
            pid_file_path=self._pid_file_path,
            log_file_path=self.log_file,
            port=self._get_kfp_ui_daemon_port(),
            kubernetes_context=kube_context,
        )

    def suspend(self) -> None:
        """Stop the KFP UI daemon and the local k3d cluster.

        Only a daemon that ZenML manages and a ZenML-managed local cluster
        are stopped. Nothing happens if no deployment is provisioned.
        """
        if not self.is_provisioned:
            logger.info("Local kubeflow pipelines deployment not provisioned.")
            return

        daemon_managed = not self.skip_ui_daemon_provisioning
        if daemon_managed and self.is_daemon_running:
            local_deployment_utils.stop_kfp_ui_daemon(
                pid_file_path=self._pid_file_path
            )

        # leave remote KFP installations (and skipped clusters) untouched
        cluster_managed = not self.skip_cluster_provisioning and self.is_local
        if cluster_managed and self.is_cluster_running:
            cluster = self._k3d_cluster_name
            local_deployment_utils.stop_k3d_cluster(cluster_name=cluster)

    def _get_environment_vars_from_secrets(
        self, secrets: List[str]
    ) -> Dict[str, str]:
        """Collect the key-value content of the requested secrets.

        The content of every named secret is merged into one dictionary.
        Later secrets override keys of earlier ones.

        Args:
            secrets: List of secrets provided by the user.

        Returns:
            A dictionary of key-value pairs (empty if no secrets were given).

        Raises:
            ProvisioningError: If secrets were requested but the active stack
                has no secrets manager.
        """
        collected: Dict[str, str] = {}
        manager = Repository().active_stack.secrets_manager

        if not secrets:
            return collected

        if not manager:
            raise ProvisioningError(
                "Unable to provision local Kubeflow Pipelines deployment: "
                f"You passed in the following secrets: { ', '.join(secrets) }, "
                "however, no secrets manager is registered for the current "
                "stack."
            )

        for secret_name in secrets:
            collected.update(manager.get_secret(secret_name).content)
        return collected
is_cluster_provisioned: bool property readonly

Returns if the local k3d cluster for this orchestrator is provisioned.

For remote (i.e. not managed by ZenML) Kubeflow Pipelines installations, this always returns True.

Returns:

Type Description
bool

True if the local k3d cluster is provisioned, False otherwise.

is_cluster_running: bool property readonly

Returns if the local k3d cluster for this orchestrator is running.

For remote (i.e. not managed by ZenML) Kubeflow Pipelines installations, this always returns True.

Returns:

Type Description
bool

True if the local k3d cluster is running, False otherwise.

is_daemon_running: bool property readonly

Returns if the local Kubeflow UI daemon for this orchestrator is running.

Returns:

Type Description
bool

True if the daemon is running, False otherwise.

is_local: bool property readonly

Checks if the KFP orchestrator is running locally.

Returns:

Type Description
bool

True if the KFP orchestrator is running locally (i.e. in the local k3d cluster managed by ZenML).

is_provisioned: bool property readonly

Returns if a local k3d cluster for this orchestrator exists.

Returns:

Type Description
bool

True if a local k3d cluster exists, False otherwise.

is_running: bool property readonly

Checks if the local k3d cluster and UI daemon are both running.

Returns:

Type Description
bool

True if the local k3d cluster and UI daemon for this orchestrator are both running.

is_suspended: bool property readonly

Checks if the local k3d cluster and UI daemon are both stopped.

Returns:

Type Description
bool

True if the cluster and daemon for this orchestrator are both stopped, False otherwise.

log_file: str property readonly

Path of the daemon log file.

Returns:

Type Description
str

Path of the daemon log file.

pipeline_directory: str property readonly

Returns path to a directory in which the kubeflow pipeline files are stored.

Returns:

Type Description
str

Path to the pipeline directory.

root_directory: str property readonly

Returns path to the root directory for all files concerning this orchestrator.

Returns:

Type Description
str

Path to the root directory.

validator: Optional[zenml.stack.stack_validator.StackValidator] property readonly

Validates that the stack contains a container registry.

Also checks that requirements are met for local components.

Returns:

Type Description
Optional[zenml.stack.stack_validator.StackValidator]

A StackValidator instance.

deprovision(self)

Deprovisions a local Kubeflow Pipelines deployment.

Source code in zenml/integrations/kubeflow/orchestrators/kubeflow_orchestrator.py
def deprovision(self) -> None:
    """Tear down the local Kubeflow Pipelines deployment.

    Does nothing at all when cluster provisioning is skipped. Otherwise it
    stops the UI daemon (unless UI daemon provisioning is skipped or the
    daemon is not running). It deletes the k3d cluster only when this
    orchestrator uses the locally managed cluster, and finally removes the
    daemon log file if one exists.
    """
    if self.skip_cluster_provisioning:
        return

    daemon_needs_stopping = (
        not self.skip_ui_daemon_provisioning and self.is_daemon_running
    )
    if daemon_needs_stopping:
        local_deployment_utils.stop_kfp_ui_daemon(
            pid_file_path=self._pid_file_path
        )

    # Remote KFP installations are never torn down from here; only the
    # ZenML-managed local k3d cluster is deleted.
    if self.is_local:
        local_deployment_utils.delete_k3d_cluster(
            cluster_name=self._k3d_cluster_name
        )
        logger.info("Local kubeflow pipelines deployment deprovisioned.")

    if fileio.exists(self.log_file):
        fileio.remove(self.log_file)
get_docker_image_name(self, pipeline_name)

Returns the full docker image name including registry and tag.

Parameters:

Name Type Description Default
pipeline_name str

The name of the pipeline.

required

Returns:

Type Description
str

The full docker image name including registry and tag.

Source code in zenml/integrations/kubeflow/orchestrators/kubeflow_orchestrator.py
def get_docker_image_name(self, pipeline_name: str) -> str:
    """Build the fully qualified docker image name for a pipeline.

    Args:
        pipeline_name: Name of the pipeline; it is used as the image tag.

    Returns:
        `zenml-kubeflow:<pipeline_name>`. When the active stack has a
        container registry, the name is prefixed with that registry's URI
        (with trailing slashes stripped).
    """
    image = f"zenml-kubeflow:{pipeline_name}"
    registry = Repository().active_stack.container_registry

    if not registry:
        return image

    registry_uri = registry.uri.rstrip("/")
    return f"{registry_uri}/{image}"
get_kubernetes_contexts(self)

Get the list of configured Kubernetes contexts and the active context.

Returns:

Type Description
Tuple[List[str], Optional[str]]

A tuple containing the list of configured Kubernetes contexts and the active context.

Source code in zenml/integrations/kubeflow/orchestrators/kubeflow_orchestrator.py
def get_kubernetes_contexts(self) -> Tuple[List[str], Optional[str]]:
    """Return the names of all configured Kubernetes contexts and the active one.

    Returns:
        A `(context_names, active_context_name)` tuple. If the kubeconfig
        cannot be loaded (`ConfigException`), this is `([], None)`.
    """
    try:
        all_contexts, current = k8s_config.list_kube_config_contexts()
    except k8s_config.config_exception.ConfigException:
        return [], None

    names = [ctx["name"] for ctx in all_contexts]
    return names, current["name"]
list_manual_setup_steps(self, container_registry_name, container_registry_path)

Logs the manual steps needed to set up the local Kubeflow orchestrator.

Parameters:

Name Type Description Default
container_registry_name str

Name of the container registry.

required
container_registry_path str

Path to the container registry.

required
Source code in zenml/integrations/kubeflow/orchestrators/kubeflow_orchestrator.py
def list_manual_setup_steps(
    self, container_registry_name: str, container_registry_path: str
) -> None:
    """Logs the manual steps needed to set up the local Kubeflow orchestrator.

    The logged commands are meant to be copied into a shell. Every
    interpolated value is therefore shell-quoted. Without quoting, the
    `?`/`&` in the kustomize URLs would be interpreted by the shell: `&`
    backgrounds the command, and zsh aborts on the unmatched `?` glob.
    Quoting also protects config paths that contain spaces.

    Nothing but a warning is logged for non-local orchestrators.

    Args:
        container_registry_name: Name of the container registry.
        container_registry_path: Path to the container registry config file.
    """
    import shlex

    if not self.is_local:
        # Make sure we're not telling users to deploy Kubeflow on their
        # remote clusters
        logger.warning(
            "This Kubeflow orchestrator is configured to use a non-local "
            f"Kubernetes context {self.kubernetes_context}. Manually "
            "deploying Kubeflow Pipelines is only possible for local "
            "Kubeflow orchestrators."
        )
        return

    quote = shlex.quote
    global_config_dir_path = io_utils.get_global_config_directory()
    # `kubernetes_context` may be Optional; str() keeps the previous "None"
    # rendering instead of crashing inside shlex.quote.
    context = quote(str(self.kubernetes_context))
    volume = quote(f"{global_config_dir_path}:{global_config_dir_path}")
    cluster_scoped_resources = quote(
        "github.com/kubeflow/pipelines/manifests/kustomize/"
        f"cluster-scoped-resources?ref={KFP_VERSION}&timeout=5m"
    )
    platform_agnostic_env = quote(
        "github.com/kubeflow/pipelines/manifests/kustomize/"
        f"env/platform-agnostic-pns?ref={KFP_VERSION}&timeout=5m"
    )

    kubeflow_commands = [
        f"> k3d cluster create {quote(self._k3d_cluster_name)} "
        f"--image {quote(local_deployment_utils.K3S_IMAGE_NAME)} "
        f"--registry-create {quote(container_registry_name)} "
        f"--registry-config {quote(container_registry_path)} "
        f"--volume {volume}\n",
        f"> kubectl --context {context} apply -k {cluster_scoped_resources}",
        f"> kubectl --context {context} wait --timeout=60s --for condition=established crd/applications.app.k8s.io",
        f"> kubectl --context {context} apply -k {platform_agnostic_env}",
        f"> kubectl --context {context} --namespace kubeflow port-forward svc/ml-pipeline-ui {self.kubeflow_pipelines_ui_port}:80",
    ]

    logger.info(
        "If you wish to spin up this Kubeflow local orchestrator manually, "
        "please enter the following commands:\n"
    )
    logger.info("\n".join(kubeflow_commands))
prepare_or_run_pipeline(self, sorted_steps, pipeline, pb2_pipeline, stack, runtime_configuration)

Creates a kfp yaml file.

This functions as an intermediary representation of the pipeline which is then deployed to the kubeflow pipelines instance.

How it works:

Before this method is called, the prepare_pipeline_deployment() method builds a docker image that contains the code for the pipeline, all steps, and the context around these files.

Based on this docker image a callable is created which builds container_ops for each step (_construct_kfp_pipeline). To do this the entrypoint of the docker image is configured to run the correct step within the docker image. The dependencies between these container_ops are then also configured by pointing each container_op at its upstream steps.

This callable is then compiled into a kfp yaml file that is used as the intermediary representation of the kubeflow pipeline.

This file, together with some metadata and the runtime configuration, is then uploaded to the Kubeflow Pipelines cluster for execution.

Parameters:

Name Type Description Default
sorted_steps List[BaseStep]

A list of steps sorted by their order in the pipeline.

required
pipeline BasePipeline

The pipeline object.

required
pb2_pipeline Pipeline

The pipeline object in protobuf format.

required
stack Stack

The stack object.

required
runtime_configuration RuntimeConfiguration

The runtime configuration object.

required

Exceptions:

Type Description
RuntimeError

If you try to run the pipelines in a notebook environment.

Source code in zenml/integrations/kubeflow/orchestrators/kubeflow_orchestrator.py
def prepare_or_run_pipeline(
    self,
    sorted_steps: List["BaseStep"],
    pipeline: "BasePipeline",
    pb2_pipeline: Pb2Pipeline,
    stack: "Stack",
    runtime_configuration: "RuntimeConfiguration",
) -> Any:
    """Compiles the pipeline into a kfp yaml file and runs it on Kubeflow.

    The yaml file functions as an intermediary representation of the
    pipeline which is then deployed to the kubeflow pipelines instance.

    How it works:
    -------------
    Before this method is called, the `prepare_pipeline_deployment()`
    method builds a docker image that contains the code for the
    pipeline, all steps, and the context around these files.

    Based on this docker image a callable is created which builds
    container_ops for each step (`_construct_kfp_pipeline`).
    To do this the entrypoint of the docker image is configured to
    run the correct step within the docker image. The dependencies
    between these container_ops are then also configured by pointing
    each container_op at its upstream steps.

    This callable is then compiled into a kfp yaml file that is used as
    the intermediary representation of the kubeflow pipeline.

    This file, together with some metadata and the runtime configuration,
    is then uploaded to the kubeflow pipelines cluster for execution.

    Args:
        sorted_steps: A list of steps sorted by their order in the
            pipeline. Every upstream step must appear before the steps
            that depend on it, because container ops are looked up by the
            name of an already-processed upstream step.
        pipeline: The pipeline object.
        pb2_pipeline: The pipeline object in protobuf format; used to
            resolve each step's upstream steps and entrypoint arguments.
        stack: The stack object. Not referenced by this method.
        runtime_configuration: The runtime configuration object. Its
            `run_name` must be set; it names the compiled yaml file.

    Returns:
        None. The pipeline is submitted via `_upload_and_run_pipeline`.

    Raises:
        RuntimeError: If you try to run the pipelines in a notebook environment.
    """
    # First check whether the code is running in a notebook
    if Environment.in_notebook():
        raise RuntimeError(
            "The Kubeflow orchestrator cannot run pipelines in a notebook "
            "environment. The reason is that it is non-trivial to create "
            "a Docker image of a notebook. Please consider refactoring "
            "your notebook cells into separate scripts in a Python module "
            "and run the code outside of a notebook when using this "
            "orchestrator."
        )

    # Reference the image by its digest when one can be resolved, falling
    # back to the tag-based name otherwise.
    image_name = self.get_docker_image_name(pipeline.name)
    image_name = get_image_digest(image_name) or image_name

    # Create a callable for future compilation into a dsl.Pipeline.
    def _construct_kfp_pipeline() -> None:
        """Create a container_op for each step.

        This should contain the name of the docker image and configures the
        entrypoint of the docker image to run the step.

        Additionally, this configures each container_op to run after the
        container_ops of its direct upstream steps.

        If this callable is passed to the `_create_and_write_workflow()`
        method of a KFPCompiler all dsl.ContainerOp instances will be
        automatically added to a singular dsl.Pipeline instance.
        """
        # Dictionary of container_ops indexed by the associated step name
        step_name_to_container_op: Dict[str, dsl.ContainerOp] = {}

        for step in sorted_steps:
            # The command will be needed to eventually call the python step
            # within the docker container
            command = (
                KubeflowEntrypointConfiguration.get_entrypoint_command()
            )

            # The arguments are passed to configure the entrypoint of the
            # docker container when the step is called.
            metadata_ui_path = "/outputs/mlpipeline-ui-metadata.json"
            arguments = (
                KubeflowEntrypointConfiguration.get_entrypoint_arguments(
                    step=step,
                    pb2_pipeline=pb2_pipeline,
                    **{METADATA_UI_PATH_OPTION: metadata_ui_path},
                )
            )

            # Create a container_op - the kubeflow equivalent of a step. It
            # contains the name of the step, the name of the docker image,
            # the command to use to run the step entrypoint
            # (e.g. `python -m zenml.entrypoints.step_entrypoint`)
            # and the arguments to be passed along with the command. Find
            # out more about how these arguments are parsed and used
            # in the base entrypoint `run()` method.
            container_op = dsl.ContainerOp(
                name=step.name,
                image=image_name,
                command=command,
                arguments=arguments,
                output_artifact_paths={
                    "mlpipeline-ui-metadata": metadata_ui_path,
                },
            )

            # Mounts persistent volumes, configmaps and adds labels to the
            # container op
            self._configure_container_op(container_op=container_op)

            if self.requires_resources_in_orchestration_environment(step):
                self._configure_container_resources(
                    container_op=container_op,
                    resource_configuration=step.resource_configuration,
                )

            # Find the upstream container ops of the current step and
            # configure the current container op to run after them.
            # Relies on `sorted_steps` ordering: a KeyError here means an
            # upstream step was not processed before this one.
            upstream_step_names = self.get_upstream_step_names(
                step=step, pb2_pipeline=pb2_pipeline
            )
            for upstream_step_name in upstream_step_names:
                upstream_container_op = step_name_to_container_op[
                    upstream_step_name
                ]
                container_op.after(upstream_container_op)

            # Update dictionary of container ops with the current one
            step_name_to_container_op[step.name] = container_op

    # Get a filepath to use to save the finished yaml to; the run name is
    # used as the file name, so it must be set.
    assert runtime_configuration.run_name
    fileio.makedirs(self.pipeline_directory)
    pipeline_file_path = os.path.join(
        self.pipeline_directory, f"{runtime_configuration.run_name}.yaml"
    )

    # write the argo pipeline yaml
    # NOTE(review): `_create_and_write_workflow` is a private KFP compiler
    # method and may change between kfp releases.
    KFPCompiler()._create_and_write_workflow(
        pipeline_func=_construct_kfp_pipeline,
        pipeline_name=pipeline.name,
        package_path=pipeline_file_path,
    )

    # using the kfp client uploads the pipeline to kubeflow pipelines and
    # runs it there
    self._upload_and_run_pipeline(
        pipeline_name=pipeline.name,
        pipeline_file_path=pipeline_file_path,
        runtime_configuration=runtime_configuration,
        enable_cache=pipeline.enable_cache,
    )
prepare_pipeline_deployment(self, pipeline, stack, runtime_configuration)

Builds a docker image for the current environment.

This function also uploads it to a container registry if configured.

Parameters:

Name Type Description Default
pipeline BasePipeline

The pipeline to be deployed.

required
stack Stack

The stack to be deployed.

required
runtime_configuration RuntimeConfiguration

The runtime configuration to be used.

required
Source code in zenml/integrations/kubeflow/orchestrators/kubeflow_orchestrator.py
def prepare_pipeline_deployment(
    self,
    pipeline: "BasePipeline",
    stack: "Stack",
    runtime_configuration: "RuntimeConfiguration",
) -> None:
    """Build the pipeline's docker image and push it to the container registry.

    The image bundles the source root, the union of stack and pipeline
    requirements, and environment variables derived from the pipeline's
    secrets. After pushing, the image digest (or the plain image name if no
    digest can be resolved) is stored in the runtime configuration under
    `"docker_image"` so that it is tracked in the ZenStore.

    Args:
        pipeline: The pipeline to be deployed.
        stack: The stack to be deployed; it must have a container registry.
        runtime_configuration: The runtime configuration to be used.
    """
    from zenml.utils import docker_utils

    image_name = self.get_docker_image_name(pipeline.name)

    requirements = {*stack.requirements(), *pipeline.requirements}
    logger.debug("Kubeflow docker container requirements: %s", requirements)

    # Evaluate the build inputs in the same order as the keyword arguments
    # below would be evaluated.
    build_context = get_source_root_path()
    dockerignore = pipeline.dockerignore_file
    base_image = self.custom_docker_base_image_name
    env_vars = self._get_environment_vars_from_secrets(pipeline.secrets)

    docker_utils.build_docker_image(
        build_context_path=build_context,
        image_name=image_name,
        dockerignore_path=dockerignore,
        requirements=requirements,
        base_image=base_image,
        environment_vars=env_vars,
    )

    # Stack validation guarantees a container registry is present.
    assert stack.container_registry
    stack.container_registry.push_image(image_name)

    runtime_configuration["docker_image"] = (
        docker_utils.get_image_digest(image_name) or image_name
    )
provision(self)

Provisions a local Kubeflow Pipelines deployment.

Exceptions:

Type Description
ProvisioningError

If the provisioning fails.

Source code in zenml/integrations/kubeflow/orchestrators/kubeflow_orchestrator.py
def provision(self) -> None:
    """Provisions a local Kubeflow Pipelines deployment.

    Nothing is done when cluster provisioning is skipped or when the
    deployment is already running. For a remote (non-local) KFP
    installation, only the orchestrator's root directory is created.

    For a local deployment, this method:

    - creates a k3d cluster with its own registry;
    - deploys Kubeflow Pipelines into it;
    - mounts the artifact store path into the Kubeflow Pipelines UI when
      the active artifact store is a `LocalArtifactStore`.

    If the local setup fails, the error is logged, the equivalent manual
    setup commands are printed, and `deprovision()` is called. The
    exception is not re-raised.

    Raises:
        ProvisioningError: If the required command line tools are missing
            (`kubectl`, plus `k3d` when this orchestrator is local).
    """
    if self.skip_cluster_provisioning:
        return

    if self.is_running:
        logger.info(
            "Found already existing local Kubeflow Pipelines deployment. "
            "If there are any issues with the existing deployment, please "
            "run 'zenml stack down --force' to delete it."
        )
        return

    # `k3d` is only needed to create the ZenML-managed local cluster, so it
    # must not be required when targeting a remote KFP installation.
    if not local_deployment_utils.check_prerequisites(
        skip_k3d=not self.is_local
    ):
        raise ProvisioningError(
            "Unable to provision local Kubeflow Pipelines deployment: "
            "Please install 'k3d' and 'kubectl' and try again."
        )

    container_registry = Repository().active_stack.container_registry

    # should not happen, because the stack validation takes care of this,
    # but just in case
    assert container_registry is not None

    fileio.makedirs(self.root_directory)

    if not self.is_local:
        # don't provision any resources if using a remote KFP installation
        return

    logger.info("Provisioning local Kubeflow Pipelines deployment...")

    # NOTE(review): assumes the registry URI ends in `:<port>` (local
    # registry); a URI without an explicit port would fail int() here.
    container_registry_port = int(container_registry.uri.split(":")[-1])
    container_registry_name = self._get_k3d_registry_name(
        port=container_registry_port
    )
    local_deployment_utils.write_local_registry_yaml(
        yaml_path=self._k3d_registry_config_path,
        registry_name=container_registry_name,
        registry_uri=container_registry.uri,
    )

    try:
        local_deployment_utils.create_k3d_cluster(
            cluster_name=self._k3d_cluster_name,
            registry_name=container_registry_name,
            registry_config_path=self._k3d_registry_config_path,
        )
        kubernetes_context = self.kubernetes_context

        # will never happen, but mypy doesn't know that
        assert kubernetes_context is not None

        local_deployment_utils.deploy_kubeflow_pipelines(
            kubernetes_context=kubernetes_context
        )

        artifact_store = Repository().active_stack.artifact_store
        if isinstance(artifact_store, LocalArtifactStore):
            local_deployment_utils.add_hostpath_to_kubeflow_pipelines(
                kubernetes_context=kubernetes_context,
                local_path=artifact_store.path,
            )
    except Exception as e:
        # Best effort: report the failure, show the manual alternative and
        # clean up whatever was partially created.
        logger.error(e)
        logger.error(
            "Unable to spin up local Kubeflow Pipelines deployment."
        )

        self.list_manual_setup_steps(
            container_registry_name, self._k3d_registry_config_path
        )
        self.deprovision()
resume(self)

Resumes the local k3d cluster.

Exceptions:

Type Description
ProvisioningError

If the k3d cluster is not provisioned.

Source code in zenml/integrations/kubeflow/orchestrators/kubeflow_orchestrator.py
def resume(self) -> None:
    """Resumes the local k3d cluster and the Kubeflow Pipelines UI daemon.

    The k3d cluster is started only when all of the following hold:

    - cluster provisioning is not skipped;
    - this orchestrator uses the ZenML-managed local cluster;
    - the cluster is not already running.

    The UI daemon is started only when UI daemon provisioning is not
    skipped and the daemon is not already running. These are the same
    flags `suspend()` and `deprovision()` check before stopping it.

    Raises:
        ProvisioningError: If the k3d cluster is not provisioned.
    """
    if self.is_running:
        logger.info("Local kubeflow pipelines deployment already running.")
        return

    if not self.is_provisioned:
        raise ProvisioningError(
            "Unable to resume local kubeflow pipelines deployment: No "
            "resources provisioned for local deployment."
        )

    kubernetes_context = self.kubernetes_context

    # will never happen, but mypy doesn't know that
    assert kubernetes_context is not None

    # don't resume any resources if using a remote KFP installation
    if (
        not self.skip_cluster_provisioning
        and self.is_local
        and not self.is_cluster_running
    ):
        local_deployment_utils.start_k3d_cluster(
            cluster_name=self._k3d_cluster_name
        )

        local_deployment_utils.wait_until_kubeflow_pipelines_ready(
            kubernetes_context=kubernetes_context
        )

    # Honour `skip_ui_daemon_provisioning` here too: `suspend()` and
    # `deprovision()` never stop the daemon when it is set, so starting it
    # here would leave an unmanaged port-forward process behind.
    if not self.skip_ui_daemon_provisioning and not self.is_daemon_running:
        local_deployment_utils.start_kfp_ui_daemon(
            pid_file_path=self._pid_file_path,
            log_file_path=self.log_file,
            port=self._get_kfp_ui_daemon_port(),
            kubernetes_context=kubernetes_context,
        )
set_default_kubernetes_context(values) classmethod

Pydantic root_validator.

This sets the default kubernetes_context value to the value that is used to create the locally managed k3d cluster, if not explicitly set.

Parameters:

Name Type Description Default
values Dict[str, Any]

Values passed to the object constructor

required

Returns:

Type Description
Dict[str, Any]

Values passed to the Pydantic constructor

Source code in zenml/integrations/kubeflow/orchestrators/kubeflow_orchestrator.py
@root_validator(skip_on_failure=True)
def set_default_kubernetes_context(
    cls, values: Dict[str, Any]
) -> Dict[str, Any]:
    """Pydantic root validator filling in a default `kubernetes_context`.

    If no (truthy) `kubernetes_context` was given, it is set to the
    context name of the ZenML-managed local k3d cluster derived from the
    component's `uuid`.

    Args:
        values: Field values passed to the Pydantic constructor.

    Returns:
        The (possibly updated) field values.
    """
    if values.get("kubernetes_context"):
        return values

    # Pydantic validation should always provide `uuid`; the assert mainly
    # narrows the type for mypy.
    assert "uuid" in values
    values["kubernetes_context"] = cls._get_k3d_kubernetes_context(
        values["uuid"]
    )
    return values
suspend(self)

Suspends the local k3d cluster.

Source code in zenml/integrations/kubeflow/orchestrators/kubeflow_orchestrator.py
def suspend(self) -> None:
    """Stop the Kubeflow UI daemon and the local k3d cluster.

    Only logs a message when nothing is provisioned. The UI daemon is
    stopped when UI daemon provisioning is not skipped and the daemon is
    running. The k3d cluster is stopped only when cluster provisioning is
    not skipped, this orchestrator is local, and the cluster is running.
    """
    if not self.is_provisioned:
        logger.info("Local kubeflow pipelines deployment not provisioned.")
        return

    manages_daemon = not self.skip_ui_daemon_provisioning
    if manages_daemon and self.is_daemon_running:
        local_deployment_utils.stop_kfp_ui_daemon(
            pid_file_path=self._pid_file_path
        )

    # Remote KFP installations are left untouched; only the ZenML-managed
    # local k3d cluster is stopped.
    should_stop_cluster = (
        not self.skip_cluster_provisioning
        and self.is_local
        and self.is_cluster_running
    )
    if should_stop_cluster:
        local_deployment_utils.stop_k3d_cluster(
            cluster_name=self._k3d_cluster_name
        )
local_deployment_utils

Utils for the local Kubeflow deployment behaviors.

add_hostpath_to_kubeflow_pipelines(kubernetes_context, local_path)

Patches the Kubeflow Pipelines deployment to mount a local folder.

This folder serves as a hostpath for visualization purposes.

This function reconfigures the Kubeflow pipelines deployment to use a shared local folder to support loading the TensorBoard viewer and other pipeline visualization results from a local artifact store, as described here:

https://github.com/kubeflow/pipelines/blob/master/docs/config/volume-support.md

Parameters:

Name Type Description Default
kubernetes_context str

The kubernetes context on which Kubeflow Pipelines should be patched.

required
local_path str

The path to the local folder to mount as a hostpath.

required
Source code in zenml/integrations/kubeflow/orchestrators/local_deployment_utils.py
def add_hostpath_to_kubeflow_pipelines(
    kubernetes_context: str, local_path: str
) -> None:
    """Patches the Kubeflow Pipelines deployment to mount a local folder.

    This folder serves as a hostpath for visualization purposes.

    This function reconfigures the Kubeflow pipelines deployment to use a
    shared local folder to support loading the TensorBoard viewer and other
    pipeline visualization results from a local artifact store, as described
    here:

    https://github.com/kubeflow/pipelines/blob/master/docs/config/volume-support.md

    Two patches are applied, each mounting `local_path` at the same path:

    1. `configmap/ml-pipeline-ui-configmap` gets a
       `viewer-pod-template.json` entry (JSON merge patch);
    2. `deployment/ml-pipeline-ui` gets the volume and mount added to its
       `ml-pipeline-ui` container (strategic merge patch).

    Afterwards it waits until the Kubeflow Pipelines pods are ready again.

    Args:
        kubernetes_context: The kubernetes context on which Kubeflow Pipelines
            should be patched.
        local_path: The path to the local folder to mount as a hostpath.

    Raises:
        subprocess.CalledProcessError: If a `kubectl patch` call fails.
    """
    logger.info("Patching Kubeflow Pipelines to mount a local folder.")

    # Both patches mount the same host directory at the same path, so the
    # volume and its mount are defined once and shared by both.
    volume_name = "local-artifact-store"
    volume_mount = {"mountPath": local_path, "name": volume_name}
    host_path_volume = {
        "hostPath": {"path": local_path, "type": "Directory"},
        "name": volume_name,
    }

    def _kubectl_patch(resource: str, patch_type: str, patch: str) -> None:
        """Apply `patch` to `resource` in the `kubeflow` namespace."""
        subprocess.check_call(
            [
                "kubectl",
                "--context",
                kubernetes_context,
                "-n",
                "kubeflow",
                "patch",
                resource,
                "--type",
                patch_type,
                "-p",
                patch,
            ]
        )

    # 1) Template for the viewer pods spawned by the KFP UI.
    pod_template = {
        "spec": {
            "serviceAccountName": "kubeflow-pipelines-viewer",
            "containers": [{"volumeMounts": [volume_mount]}],
            "volumes": [host_path_volume],
        }
    }
    pod_template_json = json.dumps(pod_template, indent=2)
    # The template is stored as a JSON string inside the config map data.
    config_map_data = {"data": {"viewer-pod-template.json": pod_template_json}}
    config_map_data_json = json.dumps(config_map_data, indent=2)

    logger.debug(
        "Adding host path volume for local path `%s` to kubeflow pipeline "
        "viewer pod template configuration.",
        local_path,
    )
    _kubectl_patch(
        "configmap/ml-pipeline-ui-configmap", "merge", config_map_data_json
    )

    # 2) The KFP UI deployment itself.
    deployment_patch = {
        "spec": {
            "template": {
                "spec": {
                    "containers": [
                        {
                            "name": "ml-pipeline-ui",
                            "volumeMounts": [volume_mount],
                        }
                    ],
                    "volumes": [host_path_volume],
                }
            }
        }
    }
    deployment_patch_json = json.dumps(deployment_patch, indent=2)

    logger.debug(
        "Adding host path volume for local path `%s` to the kubeflow UI",
        local_path,
    )
    _kubectl_patch(
        "deployment/ml-pipeline-ui", "strategic", deployment_patch_json
    )
    wait_until_kubeflow_pipelines_ready(kubernetes_context=kubernetes_context)

    logger.info("Finished patching Kubeflow Pipelines setup.")
check_prerequisites(skip_k3d=False, skip_kubectl=False)

Checks prerequisites for a local kubeflow pipelines deployment.

It makes sure they are installed.

Parameters:

Name Type Description Default
skip_k3d bool

Whether to skip the check for the k3d command.

False
skip_kubectl bool

Whether to skip the check for the kubectl command.

False

Returns:

Type Description
bool

Whether all prerequisites are installed.

Source code in zenml/integrations/kubeflow/orchestrators/local_deployment_utils.py
def check_prerequisites(
    skip_k3d: bool = False, skip_kubectl: bool = False
) -> bool:
    """Report whether the CLI tools for a local KFP deployment are installed.

    A tool counts as available when its check is skipped or its executable
    can be found on the `PATH` via `shutil.which`.

    Args:
        skip_k3d: If true, do not require the `k3d` executable.
        skip_kubectl: If true, do not require the `kubectl` executable.

    Returns:
        True if every non-skipped tool is installed, False otherwise.
    """

    def _on_path(executable: str) -> bool:
        return shutil.which(executable) is not None

    k3d_installed = skip_k3d or _on_path("k3d")
    kubectl_installed = skip_kubectl or _on_path("kubectl")
    logger.debug(
        "Local kubeflow deployment prerequisites: K3D - %s, Kubectl - %s",
        k3d_installed,
        kubectl_installed,
    )
    return k3d_installed and kubectl_installed
create_k3d_cluster(cluster_name, registry_name, registry_config_path)

Creates a K3D cluster.

Parameters:

Name Type Description Default
cluster_name str

Name of the cluster to create.

required
registry_name str

Name of the registry to create for this cluster.

required
registry_config_path str

Path to the registry config file.

required
Source code in zenml/integrations/kubeflow/orchestrators/local_deployment_utils.py
def create_k3d_cluster(
    cluster_name: str, registry_name: str, registry_config_path: str
) -> None:
    """Create a local K3D cluster together with a new container registry.

    The global ZenML config directory is mounted into the cluster at the
    same path it has on the host.

    Args:
        cluster_name: Name of the cluster to create.
        registry_name: Name of the registry to create for this cluster.
        registry_config_path: Path to the registry config file.

    Raises:
        subprocess.CalledProcessError: If `k3d` exits with a non-zero code.
    """
    logger.info("Creating local K3D cluster '%s'.", cluster_name)
    config_dir = io_utils.get_global_config_directory()

    command = ["k3d", "cluster", "create", cluster_name]
    command += ["--image", K3S_IMAGE_NAME]
    command += ["--registry-create", registry_name]
    command += ["--registry-config", registry_config_path]
    command += ["--volume", f"{config_dir}:{config_dir}"]
    subprocess.check_call(command)

    logger.info("Finished K3D cluster creation.")
delete_k3d_cluster(cluster_name)

Deletes a K3D cluster with the given name.

Parameters:

Name Type Description Default
cluster_name str

Name of the cluster to delete.

required
Source code in zenml/integrations/kubeflow/orchestrators/local_deployment_utils.py
def delete_k3d_cluster(cluster_name: str) -> None:
    """Remove the K3D cluster called `cluster_name`.

    Args:
        cluster_name: Name of the cluster to delete.

    Raises:
        subprocess.CalledProcessError: If `k3d` exits with a non-zero code.
    """
    delete_command = ["k3d", "cluster", "delete", cluster_name]
    subprocess.check_call(delete_command)
    logger.info("Deleted local k3d cluster '%s'.", cluster_name)
deploy_kubeflow_pipelines(kubernetes_context)

Deploys Kubeflow Pipelines.

Parameters:

Name Type Description Default
kubernetes_context str

The kubernetes context on which Kubeflow Pipelines should be deployed.

required
Source code in zenml/integrations/kubeflow/orchestrators/local_deployment_utils.py
def deploy_kubeflow_pipelines(kubernetes_context: str) -> None:
    """Install Kubeflow Pipelines into the given Kubernetes context.

    This function:

    1. applies the cluster-scoped kustomize resources;
    2. waits up to 60s for the `applications.app.k8s.io` CRD to be
       established;
    3. applies the platform-agnostic (PNS) environment;
    4. blocks until the pipelines pods are ready.

    Args:
        kubernetes_context: The kubernetes context on which Kubeflow Pipelines
            should be deployed.

    Raises:
        subprocess.CalledProcessError: If any `kubectl` call fails.
    """
    logger.info("Deploying Kubeflow Pipelines.")
    kubectl = ["kubectl", "--context", kubernetes_context]

    subprocess.check_call(
        kubectl
        + [
            "apply",
            "-k",
            f"github.com/kubeflow/pipelines/manifests/kustomize/cluster-scoped-resources?ref={KFP_VERSION}&timeout=5m",
        ]
    )
    # The namespaced manifests below depend on this CRD being available.
    subprocess.check_call(
        kubectl
        + [
            "wait",
            "--timeout=60s",
            "--for",
            "condition=established",
            "crd/applications.app.k8s.io",
        ]
    )
    subprocess.check_call(
        kubectl
        + [
            "apply",
            "-k",
            f"github.com/kubeflow/pipelines/manifests/kustomize/env/platform-agnostic-pns?ref={KFP_VERSION}&timeout=5m",
        ]
    )

    wait_until_kubeflow_pipelines_ready(kubernetes_context=kubernetes_context)
    logger.info("Finished Kubeflow Pipelines setup.")
k3d_cluster_exists(cluster_name)

Checks whether there exists a K3D cluster with the given name.

Parameters:

Name Type Description Default
cluster_name str

Name of the cluster to check.

required

Returns:

Type Description
bool

Whether the cluster exists.

Source code in zenml/integrations/kubeflow/orchestrators/local_deployment_utils.py
def k3d_cluster_exists(cluster_name: str) -> bool:
    """Tell whether `k3d` knows a cluster named `cluster_name`.

    Args:
        cluster_name: Name of the cluster to check.

    Returns:
        True if `k3d cluster list` reports a cluster with that name.
    """
    raw_listing = subprocess.check_output(
        ["k3d", "cluster", "list", "--output", "json"]
    )
    known_clusters = json.loads(raw_listing)
    return any(entry["name"] == cluster_name for entry in known_clusters)
k3d_cluster_running(cluster_name)

Checks whether the K3D cluster with the given name is running.

Parameters:

Name Type Description Default
cluster_name str

Name of the cluster to check.

required

Returns:

Type Description
bool

Whether the cluster is running.

Source code in zenml/integrations/kubeflow/orchestrators/local_deployment_utils.py
def k3d_cluster_running(cluster_name: str) -> bool:
    """Report whether the K3D cluster called ``cluster_name`` is running.

    Args:
        cluster_name: Name of the cluster to check.

    Returns:
        Whether the cluster is running. A cluster that does not exist is
        reported as not running.
    """
    raw_listing = subprocess.check_output(
        ["k3d", "cluster", "list", "--output", "json"]
    )
    matching = next(
        (
            entry
            for entry in json.loads(raw_listing)
            if entry["name"] == cluster_name
        ),
        None,
    )
    if matching is None:
        return False

    # The cluster only counts as running when every server node is up.
    total_servers: int = matching["serversCount"]
    running_servers: int = matching["serversRunning"]
    return running_servers == total_servers
kubeflow_pipelines_ready(kubernetes_context)

Returns whether all Kubeflow Pipelines pods are ready.

Parameters:

Name Type Description Default
kubernetes_context str

The kubernetes context in which the pods should be checked.

required

Returns:

Type Description
bool

Whether all Kubeflow Pipelines pods are ready.

Source code in zenml/integrations/kubeflow/orchestrators/local_deployment_utils.py
def kubeflow_pipelines_ready(kubernetes_context: str) -> bool:
    """Check once whether every Kubeflow Pipelines pod is ready.

    Args:
        kubernetes_context: The kubernetes context in which the pods
            should be checked.

    Returns:
        Whether all Kubeflow Pipelines pods are ready.
    """
    readiness_probe = [
        "kubectl",
        "--context",
        kubernetes_context,
        "--namespace",
        "kubeflow",
        "wait",
        "--for",
        "condition=ready",
        "--timeout=0s",
        "pods",
        "-l",
        "application-crd-id=kubeflow-pipelines",
    ]
    try:
        # With a zero timeout `kubectl wait` checks once instead of
        # blocking; a non-zero exit code means some pod is not ready yet.
        subprocess.check_call(
            readiness_probe,
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
        )
    except subprocess.CalledProcessError:
        return False
    return True
start_k3d_cluster(cluster_name)

Starts a K3D cluster with the given name.

Parameters:

Name Type Description Default
cluster_name str

Name of the cluster to start.

required
Source code in zenml/integrations/kubeflow/orchestrators/local_deployment_utils.py
def start_k3d_cluster(cluster_name: str) -> None:
    """Start the existing K3D cluster called ``cluster_name``.

    Args:
        cluster_name: Name of the cluster to start.
    """
    start_command = ["k3d", "cluster", "start", cluster_name]
    # check_call raises CalledProcessError if k3d exits non-zero.
    subprocess.check_call(start_command)
    logger.info("Started local k3d cluster '%s'.", cluster_name)
start_kfp_ui_daemon(pid_file_path, log_file_path, port, kubernetes_context)

Starts a daemon process that forwards ports.

This is so the Kubeflow Pipelines UI is accessible in the browser.

Parameters:

Name Type Description Default
pid_file_path str

Path where the file with the daemon's process ID should be written.

required
log_file_path str

Path to a file where the daemon logs should be written.

required
port int

Port on which the UI should be accessible.

required
kubernetes_context str

The kubernetes context for the cluster where Kubeflow Pipelines is running.

required
Source code in zenml/integrations/kubeflow/orchestrators/local_deployment_utils.py
def start_kfp_ui_daemon(
    pid_file_path: str,
    log_file_path: str,
    port: int,
    kubernetes_context: str,
) -> None:
    """Port-forward the Kubeflow Pipelines UI service in a background daemon.

    This makes the Kubeflow Pipelines UI reachable from a local browser. If
    the requested port is taken, or the platform is Windows (no daemon
    support), only a warning with a command to run manually is logged.

    Args:
        pid_file_path: Path where the file with the daemon's process ID
            should be written.
        log_file_path: Path to a file where the daemon logs should be written.
        port: Port on which the UI should be accessible.
        kubernetes_context: The kubernetes context for the cluster where
            Kubeflow Pipelines is running.
    """
    forward_args = [
        "kubectl",
        "--context",
        kubernetes_context,
        "--namespace",
        "kubeflow",
        "port-forward",
        "svc/ml-pipeline-ui",
        f"{port}:80",
    ]

    if not networking_utils.port_available(port):
        # Suggest the same command, with a PORT placeholder in place of the
        # occupied local port.
        suggested_args = forward_args[:-1] + ["PORT:80"]
        logger.warning(
            "Unable to port-forward Kubeflow Pipelines UI to local port %d "
            "because the port is occupied. In order to access the Kubeflow "
            "Pipelines UI at http://localhost:PORT/, please run '%s' in a "
            "separate command line shell (replace PORT with a free port of "
            "your choice).",
            port,
            " ".join(suggested_args),
        )
        return

    if sys.platform == "win32":
        logger.warning(
            "Daemon functionality not supported on Windows. "
            "In order to access the Kubeflow Pipelines UI at "
            "http://localhost:%d/, please run '%s' in a separate command "
            "line shell.",
            port,
            " ".join(forward_args),
        )
        return

    from zenml.utils import daemon

    def _daemon_function() -> None:
        """Port-forwards the Kubeflow Pipelines UI pod."""
        subprocess.check_call(forward_args)

    daemon.run_as_daemon(
        _daemon_function, pid_file=pid_file_path, log_file=log_file_path
    )
    logger.info(
        "Started Kubeflow Pipelines UI daemon (check the daemon logs at %s "
        "in case you're not able to view the UI). The Kubeflow Pipelines "
        "UI should now be accessible at http://localhost:%d/.",
        log_file_path,
        port,
    )
stop_k3d_cluster(cluster_name)

Stops a K3D cluster with the given name.

Parameters:

Name Type Description Default
cluster_name str

Name of the cluster to stop.

required
Source code in zenml/integrations/kubeflow/orchestrators/local_deployment_utils.py
def stop_k3d_cluster(cluster_name: str) -> None:
    """Stop the K3D cluster called ``cluster_name``.

    Args:
        cluster_name: Name of the cluster to stop.
    """
    stop_command = ["k3d", "cluster", "stop", cluster_name]
    # check_call raises CalledProcessError if k3d exits non-zero.
    subprocess.check_call(stop_command)
    logger.info("Stopped local k3d cluster '%s'.", cluster_name)
stop_kfp_ui_daemon(pid_file_path)

Stops the KFP UI daemon process if it is running.

Parameters:

Name Type Description Default
pid_file_path str

Path to the file with the daemon's process ID.

required
Source code in zenml/integrations/kubeflow/orchestrators/local_deployment_utils.py
def stop_kfp_ui_daemon(pid_file_path: str) -> None:
    """Stop the KFP UI daemon process if it is running.

    Does nothing when no PID file exists.

    Args:
        pid_file_path: Path to the file holding the daemon's process ID.
    """
    if not fileio.exists(pid_file_path):
        return

    # No daemon is ever started on Windows, so there is nothing to stop.
    # The explicit platform check also keeps mypy from complaining about
    # the POSIX-only daemon helpers there.
    if sys.platform != "win32":
        from zenml.utils import daemon

        daemon.stop_daemon(pid_file_path)
        fileio.remove(pid_file_path)
        logger.info("Stopped Kubeflow Pipelines UI daemon.")
wait_until_kubeflow_pipelines_ready(kubernetes_context)

Waits until all Kubeflow Pipelines pods are ready.

Parameters:

Name Type Description Default
kubernetes_context str

The kubernetes context in which the pods should be checked.

required
Source code in zenml/integrations/kubeflow/orchestrators/local_deployment_utils.py
def wait_until_kubeflow_pipelines_ready(kubernetes_context: str) -> None:
    """Block until every Kubeflow Pipelines pod reports ready.

    Prints the pod status before each readiness check and polls every 30
    seconds. There is no overall timeout.

    Args:
        kubernetes_context: The kubernetes context in which the pods
            should be checked.
    """
    logger.info(
        "Waiting for all Kubeflow Pipelines pods to be ready (this might "
        "take a few minutes)."
    )

    def _show_pod_status() -> None:
        """Log a header and dump `kubectl get pods` for the kubeflow namespace."""
        logger.info("Current pod status:")
        subprocess.check_call(
            [
                "kubectl",
                "--context",
                kubernetes_context,
                "--namespace",
                "kubeflow",
                "get",
                "pods",
            ]
        )

    _show_pod_status()
    while not kubeflow_pipelines_ready(kubernetes_context=kubernetes_context):
        logger.info("One or more pods not ready yet, waiting for 30 seconds...")
        time.sleep(30)
        _show_pod_status()
write_local_registry_yaml(yaml_path, registry_name, registry_uri)

Writes a K3D registry config file.

Parameters:

Name Type Description Default
yaml_path str

Path where the config file should be written to.

required
registry_name str

Name of the registry.

required
registry_uri str

URI of the registry.

required
Source code in zenml/integrations/kubeflow/orchestrators/local_deployment_utils.py
def write_local_registry_yaml(
    yaml_path: str, registry_name: str, registry_uri: str
) -> None:
    """Write a K3D registry config mapping ``registry_uri`` to a mirror.

    Args:
        yaml_path: Path where the config file should be written to.
        registry_name: Name of the registry.
        registry_uri: URI of the registry.
    """
    # The single mirror endpoint is reached over plain HTTP at registry_name.
    mirror_entry = {"endpoint": [f"http://{registry_name}"]}
    yaml_utils.write_yaml(yaml_path, {"mirrors": {registry_uri: mirror_entry}})
utils

Utils for ZenML Kubeflow orchestrators implementation.

dump_ui_metadata(node, execution_info, metadata_ui_path)

Dump KFP UI metadata json file for visualization purpose.

For general components we just render a simple Markdown file for exec_properties/inputs/outputs.

Parameters:

Name Type Description Default
node PipelineNode

associated TFX node.

required
execution_info ExecutionInfo

runtime execution info for this component, including materialized inputs/outputs/execution properties and id.

required
metadata_ui_path str

path to dump ui metadata.

required
Source code in zenml/integrations/kubeflow/orchestrators/utils.py
def dump_ui_metadata(
    node: PipelineNode,
    execution_info: data_types.ExecutionInfo,
    metadata_ui_path: str,
) -> None:
    """Dump KFP UI metadata json file for visualization purpose.

    For general components we just render a simple Markdown file for
    exec_properties/inputs/outputs. Outputs whose artifact type is
    `ModelRun` or `ModelArtifact` additionally get a TensorBoard view that
    points at the first materialized artifact of that output. Outputs with
    no materialized artifacts are skipped.

    Args:
        node: associated TFX node.
        execution_info: runtime execution info for this component, including
            materialized inputs/outputs/execution properties and id.
        metadata_ui_path: path to dump ui metadata.
    """
    exec_properties_list = [
        "**{}**: {}".format(
            _sanitize_underscore(name), _sanitize_underscore(exec_property)
        )
        for name, exec_property in execution_info.exec_properties.items()
    ]
    src_str_exec_properties = "# Execution properties:\n{}".format(
        "\n\n".join(exec_properties_list) or "No execution property."
    )

    def _dump_input_populated_artifacts(
        node_inputs: MutableMapping[str, InputSpec],
        name_to_artifacts: Dict[str, List[artifact.Artifact]],
    ) -> List[str]:
        """Dump artifacts markdown string for inputs.

        Args:
            node_inputs: maps from input name to input spec proto.
            name_to_artifacts: maps from input key to list of populated
                artifacts.

        Returns:
            A list of dumped markdown strings, each of which represents a
            channel.
        """
        rendered_list = []
        for name, spec in node_inputs.items():
            # Need to look for materialized artifacts in the execution decision.
            rendered_artifacts = "".join(
                [
                    _render_artifact_as_mdstr(single_artifact)
                    for single_artifact in name_to_artifacts.get(name, [])
                ]
            )
            # There must be at least one channel in an input, and all
            # channels in an input share the same artifact type.
            artifact_type = spec.channels[0].artifact_query.type.name
            rendered_list.append(
                "## {name}\n\n**Type**: {channel_type}\n\n{artifacts}".format(
                    name=_sanitize_underscore(name),
                    channel_type=_sanitize_underscore(artifact_type),
                    artifacts=rendered_artifacts,
                )
            )

        return rendered_list

    def _dump_output_populated_artifacts(
        node_outputs: MutableMapping[str, OutputSpec],
        name_to_artifacts: Dict[str, List[artifact.Artifact]],
    ) -> List[str]:
        """Dump artifacts markdown string for outputs.

        Args:
            node_outputs: maps from output name to output spec proto.
            name_to_artifacts: maps from output key to list of populated
                artifacts.

        Returns:
            A list of dumped markdown strings, each of which represents a
            channel.
        """
        rendered_list = []
        for name, spec in node_outputs.items():
            # Need to look for materialized artifacts in the execution decision.
            rendered_artifacts = "".join(
                [
                    _render_artifact_as_mdstr(single_artifact)
                    for single_artifact in name_to_artifacts.get(name, [])
                ]
            )
            # An output spec carries a single artifact spec, which defines
            # the artifact type of that output.
            artifact_type = spec.artifact_spec.type.name
            rendered_list.append(
                "## {name}\n\n**Type**: {channel_type}\n\n{artifacts}".format(
                    name=_sanitize_underscore(name),
                    channel_type=_sanitize_underscore(artifact_type),
                    artifacts=rendered_artifacts,
                )
            )

        return rendered_list

    # Either dict may be missing/empty; normalize once so every use below
    # (markdown rendering and the TensorBoard lookup) is equally safe.
    input_dict = execution_info.input_dict or {}
    output_dict = execution_info.output_dict or {}

    src_str_inputs = "# Inputs:\n{}".format(
        "".join(
            _dump_input_populated_artifacts(
                node_inputs=node.inputs.inputs,
                name_to_artifacts=input_dict,
            )
        )
        or "No input."
    )

    src_str_outputs = "# Outputs:\n{}".format(
        "".join(
            _dump_output_populated_artifacts(
                node_outputs=node.outputs.outputs,
                name_to_artifacts=output_dict,
            )
        )
        or "No output."
    )

    outputs = [
        {
            "storage": "inline",
            "source": "{exec_properties}\n\n{inputs}\n\n{outputs}".format(
                exec_properties=src_str_exec_properties,
                inputs=src_str_inputs,
                outputs=src_str_outputs,
            ),
            "type": "markdown",
        }
    ]

    # Add TensorBoard view for ModelRun / ModelArtifact outputs.
    tensorboard_type_names = (
        standard_artifacts.ModelRun.TYPE_NAME,
        ModelArtifact.TYPE_NAME,
    )
    for name, spec in node.outputs.outputs.items():
        if spec.artifact_spec.type.name not in tensorboard_type_names:
            continue

        output_artifacts = output_dict.get(name)
        if not output_artifacts:
            # Nothing was materialized for this output, so there is no
            # artifact URI to point TensorBoard at.
            continue
        source = output_artifacts[0].uri

        # For local artifact repository, use a path that is relative to
        # the point where the local artifact folder is mounted as a volume
        artifact_store = Repository().active_stack.artifact_store
        if isinstance(artifact_store, LocalArtifactStore):
            source = os.path.relpath(source, artifact_store.path)
            source = f"volume://local-artifact-store/{source}"
        # Add TensorBoard view.
        tensorboard_output = {
            "type": "tensorboard",
            "source": source,
        }
        outputs.append(tensorboard_output)

    metadata_dict = {"outputs": outputs}

    with open(metadata_ui_path, "w") as f:
        json.dump(metadata_dict, f)
mount_config_map_op(config_map_name)

Mounts all key-value pairs found in the named Kubernetes ConfigMap.

All key-value pairs in the ConfigMap are mounted as environment variables.

Parameters:

Name Type Description Default
config_map_name str

The name of the ConfigMap resource.

required

Returns:

Type Description
Callable[[kfp.dsl._container_op.ContainerOp], NoneType]

An OpFunc for mounting the ConfigMap.

Source code in zenml/integrations/kubeflow/orchestrators/utils.py
def mount_config_map_op(
    config_map_name: str,
) -> Callable[[dsl.ContainerOp], None]:
    """Build an op function that exposes a ConfigMap as environment variables.

    Every key-value pair of the named Kubernetes ConfigMap becomes an
    environment variable of the container the returned function is applied to.

    Args:
        config_map_name: The name of the ConfigMap resource.

    Returns:
        An OpFunc for mounting the ConfigMap.
    """

    def mount_config_map(container_op: dsl.ContainerOp) -> None:
        """Attach the ConfigMap as an `envFrom` source of the op's container.

        Args:
            container_op: The container op to mount the ConfigMap.
        """
        # optional=True: the ConfigMap is not required to exist.
        env_source = k8s_client.V1EnvFromSource(
            config_map_ref=k8s_client.V1ConfigMapEnvSource(
                name=config_map_name, optional=True
            )
        )
        container_op.container.add_env_from(env_source)

    return mount_config_map

kubernetes special

Kubernetes integration for Kubernetes-native orchestration.

The Kubernetes integration sub-module powers an alternative to the local orchestrator. You can enable it by registering the Kubernetes orchestrator with the CLI tool.

KubernetesIntegration (Integration)

Definition of Kubernetes integration for ZenML.

Source code in zenml/integrations/kubernetes/__init__.py
class KubernetesIntegration(Integration):
    """Definition of Kubernetes integration for ZenML."""

    NAME = KUBERNETES
    REQUIREMENTS = ["kubernetes==18.20.0"]

    @classmethod
    def flavors(cls) -> List[FlavorWrapper]:
        """List the stack component flavors contributed by this integration.

        Returns:
            List of new stack component flavors.
        """
        # (flavor name, import path of the implementing class, component type)
        flavor_table = [
            (
                KUBERNETES_METADATA_STORE_FLAVOR,
                "zenml.integrations.kubernetes.metadata_stores.KubernetesMetadataStore",
                StackComponentType.METADATA_STORE,
            ),
            (
                KUBERNETES_ORCHESTRATOR_FLAVOR,
                "zenml.integrations.kubernetes.orchestrators.KubernetesOrchestrator",
                StackComponentType.ORCHESTRATOR,
            ),
        ]
        return [
            FlavorWrapper(
                name=flavor_name,
                source=source_path,
                type=component_type,
                integration=cls.NAME,
            )
            for flavor_name, source_path, component_type in flavor_table
        ]
flavors() classmethod

Declare the stack component flavors for the Kubernetes integration.

Returns:

Type Description
List[zenml.zen_stores.models.flavor_wrapper.FlavorWrapper]

List of new stack component flavors.

Source code in zenml/integrations/kubernetes/__init__.py
@classmethod
def flavors(cls) -> List[FlavorWrapper]:
    """List the stack component flavors contributed by the Kubernetes integration.

    Returns:
        List of new stack component flavors.
    """
    metadata_store_flavor = FlavorWrapper(
        name=KUBERNETES_METADATA_STORE_FLAVOR,
        source="zenml.integrations.kubernetes.metadata_stores.KubernetesMetadataStore",
        type=StackComponentType.METADATA_STORE,
        integration=cls.NAME,
    )
    orchestrator_flavor = FlavorWrapper(
        name=KUBERNETES_ORCHESTRATOR_FLAVOR,
        source="zenml.integrations.kubernetes.orchestrators.KubernetesOrchestrator",
        type=StackComponentType.ORCHESTRATOR,
        integration=cls.NAME,
    )
    return [metadata_store_flavor, orchestrator_flavor]

metadata_stores special

Initialization of the Kubernetes metadata store for ZenML.

kubernetes_metadata_store

Implementation of Kubernetes metadata store.

KubernetesMetadataStore (BaseMetadataStore) pydantic-model

Kubernetes metadata store (MySQL database deployed in the cluster).

Attributes:

Name Type Description
deployment_name str

Name of the Kubernetes deployment and corresponding service/pod that will be created when calling provision().

kubernetes_context str

Name of the Kubernetes context in which to deploy and provision the MySQL database.

kubernetes_namespace str

Name of the Kubernetes namespace. Defaults to "zenml".

storage_capacity str

Storage capacity of the metadata store. Defaults to "10Gi" (10 GiB).

Source code in zenml/integrations/kubernetes/metadata_stores/kubernetes_metadata_store.py
class KubernetesMetadataStore(BaseMetadataStore):
    """Kubernetes metadata store (MySQL database deployed in the cluster).

    Attributes:
        deployment_name: Name of the Kubernetes deployment and corresponding
            service/pod that will be created when calling `provision()`.
        kubernetes_context: Name of the Kubernetes context in which to deploy
            and provision the MySQL database.
        kubernetes_namespace: Name of the Kubernetes namespace.
            Defaults to "zenml".
        storage_capacity: Storage capacity of the metadata store.
            Defaults to `"10Gi"` (10 GiB).
    """

    deployment_name: str
    kubernetes_context: str
    kubernetes_namespace: str = "zenml"
    storage_capacity: str = "10Gi"
    _k8s_core_api: k8s_client.CoreV1Api = None
    _k8s_apps_api: k8s_client.AppsV1Api = None

    FLAVOR: ClassVar[str] = KUBERNETES_METADATA_STORE_FLAVOR

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Initiate the Pydantic object and initialize the Kubernetes clients.

        Args:
            *args: The positional arguments to pass to the Pydantic object.
            **kwargs: The keyword arguments to pass to the Pydantic object.
        """
        super().__init__(*args, **kwargs)
        self._initialize_k8s_clients()

    def _initialize_k8s_clients(self) -> None:
        """Initialize the Kubernetes clients."""
        kube_utils.load_kube_config(context=self.kubernetes_context)
        self._k8s_core_api = k8s_client.CoreV1Api()
        self._k8s_apps_api = k8s_client.AppsV1Api()

    @root_validator(skip_on_failure=False)
    def check_required_attributes(
        cls, values: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Pydantic root_validator.

        This ensures that both `deployment_name` and `kubernetes_context` are
        set and raises an error with a custom error message otherwise.

        Args:
            values: Values passed to the Pydantic constructor.

        Raises:
            StackComponentInterfaceError: if either `deployment_name` or
                `kubernetes_context` is not defined.

        Returns:
            Values passed to the Pydantic constructor.
        """
        usage_note = (
            "Note: the `kubernetes` metadata store flavor is a special "
            "subtype of the `mysql` metadata store that deploys a fresh "
            "MySQL database within your Kubernetes cluster when running "
            "`zenml stack up`. "
            "If you already have a MySQL database running in your cluster "
            "(or elsewhere), simply use the `mysql` metadata store flavor "
            "instead."
        )

        for required_field in ("deployment_name", "kubernetes_context"):
            if required_field not in values:
                raise StackComponentInterfaceError(
                    f"Required field `{required_field}` missing for "
                    "`KubernetesMetadataStore`. " + usage_note
                )

        return values

    @property
    def deployment_exists(self) -> bool:
        """Check whether a MySQL deployment exists in the cluster.

        Returns:
            Whether a MySQL deployment exists in the cluster.
        """
        resp = self._k8s_apps_api.list_namespaced_deployment(
            namespace=self.kubernetes_namespace
        )
        for i in resp.items:
            if i.metadata.name == self.deployment_name:
                return True
        return False

    @property
    def is_provisioned(self) -> bool:
        """If the component provisioned resources to run.

        Checks whether the required MySQL deployment exists.

        Returns:
            True if the component provisioned resources to run.
        """
        return super().is_provisioned and self.deployment_exists

    @property
    def is_running(self) -> bool:
        """If the component is running.

        On non-Windows platforms this additionally requires the local
        port-forwarding daemon to be running.

        Returns:
            True if the daemon check passes (non-Windows only) and the
            component is provisioned, False otherwise.
        """
        if sys.platform != "win32":
            from zenml.utils.daemon import check_if_daemon_is_running

            if not check_if_daemon_is_running(self._pid_file_path):
                return False
        else:
            # Daemon functionality is not supported on Windows, so the PID
            # file won't exist. This if clause exists just for mypy to not
            # complain about missing functions
            pass

        return self.is_provisioned

    def provision(self) -> None:
        """Provision the metadata store.

        Creates a deployment with a MySQL database running in it.
        """
        logger.info("Provisioning Kubernetes MySQL metadata store...")
        kube_utils.create_namespace(
            core_api=self._k8s_core_api, namespace=self.kubernetes_namespace
        )
        kube_utils.create_mysql_deployment(
            core_api=self._k8s_core_api,
            apps_api=self._k8s_apps_api,
            namespace=self.kubernetes_namespace,
            storage_capacity=self.storage_capacity,
            deployment_name=self.deployment_name,
        )

        # wait a bit, then make sure deployment pod is alive and running.
        logger.info("Trying to reach Kubernetes MySQL metadata store pod...")
        time.sleep(10)
        kube_utils.wait_pod(
            core_api=self._k8s_core_api,
            pod_name=self.pod_name,
            namespace=self.kubernetes_namespace,
            exit_condition_lambda=kube_utils.pod_is_not_pending,
        )
        logger.info("Kubernetes MySQL metadata store pod is up and running.")

    def deprovision(self) -> None:
        """Deprovision the metadata store by deleting the MySQL deployment."""
        logger.info("Deleting Kubernetes MySQL metadata store...")
        self.suspend()
        kube_utils.delete_deployment(
            apps_api=self._k8s_apps_api,
            deployment_name=self.deployment_name,
            namespace=self.kubernetes_namespace,
        )

    # TODO: code duplication with kubeflow metadata store below.

    @property
    def root_directory(self) -> str:
        """Returns path to the root directory for all files concerning this metadata store.

        Returns:
            Path to the root directory.
        """
        return os.path.join(
            io_utils.get_global_config_directory(),
            self.FLAVOR,
            str(self.uuid),
        )

    @property
    def _pid_file_path(self) -> str:
        """Returns path to the daemon PID file.

        Returns:
            Path to the daemon PID file.
        """
        return os.path.join(
            self.root_directory, DEFAULT_KUBERNETES_METADATA_DAEMON_PID_FILE
        )

    @property
    def _log_file(self) -> str:
        """Path of the daemon log file.

        Returns:
            Path to the daemon log file.
        """
        return os.path.join(
            self.root_directory, DEFAULT_KUBERNETES_METADATA_DAEMON_LOG_FILE
        )

    def resume(self) -> None:
        """Resumes the metadata store."""
        self.start_metadata_daemon()
        self.wait_until_metadata_store_ready(
            timeout=DEFAULT_KUBERNETES_METADATA_DAEMON_TIMEOUT
        )

    def suspend(self) -> None:
        """Suspends the metadata store."""
        self.stop_metadata_daemon()

    @property
    def pod_name(self) -> str:
        """Name of the Kubernetes pod where the MySQL database is deployed.

        Returns:
            Name of the Kubernetes pod.
        """
        pod_list = self._k8s_core_api.list_namespaced_pod(
            namespace=self.kubernetes_namespace,
            label_selector=f"app={self.deployment_name}",
        )
        return pod_list.items[0].metadata.name  # type: ignore[no-any-return]

    @property
    def host(self) -> str:
        """Get the MySQL host required to access the metadata store.

        This overwrites the MySQL host to use local host when
        running outside of the cluster so we can access the metadata store
        locally for post execution.

        Raises:
            RuntimeError: If the metadata store is not running.

        Returns:
            MySQL host.
        """
        if kube_utils.is_inside_kubernetes():
            return DEFAULT_KUBERNETES_MYSQL_HOST
        if not self.is_running:
            raise RuntimeError(
                "The Kubernetes metadata daemon is not running. Please run the "
                "following command to start it first:\n\n"
                "    'zenml metadata-store up'\n"
            )
        return DEFAULT_KUBERNETES_MYSQL_LOCAL_HOST

    @property
    def port(self) -> int:
        """Get the MySQL port required to access the metadata store.

        Returns:
            int: MySQL port.
        """
        return DEFAULT_KUBERNETES_MYSQL_PORT

    def get_tfx_metadata_config(
        self,
    ) -> Union[
        metadata_store_pb2.ConnectionConfig,
        metadata_store_pb2.MetadataStoreClientConfig,
    ]:
        """Return tfx metadata config for the Kubernetes metadata store.

        Returns:
            The tfx metadata config.
        """
        config = MySQLDatabaseConfig(
            host=self.host,
            port=self.port,
            database=DEFAULT_KUBERNETES_MYSQL_DATABASE,
            user=DEFAULT_KUBERNETES_MYSQL_USERNAME,
            password=DEFAULT_KUBERNETES_MYSQL_PASSWORD,
        )
        connection_config = metadata_store_pb2.ConnectionConfig(mysql=config)
        return connection_config

    def start_metadata_daemon(self) -> None:
        """Starts a daemon process that forwards ports.

        This is so the MySQL database in the Kubernetes cluster is accessible
        on the localhost.

        Raises:
            ProvisioningError: if the daemon fails to start.
        """
        command = [
            "kubectl",
            "--context",
            self.kubernetes_context,
            "--namespace",
            self.kubernetes_namespace,
            "port-forward",
            f"svc/{self.deployment_name}",
            f"{self.port}:{self.port}",
        ]
        if sys.platform == "win32":
            # The message has a single `%s` placeholder, so pass exactly one
            # argument; an extra argument makes logging fail to format it.
            logger.warning(
                "Daemon functionality not supported on Windows. "
                "In order to access the Kubernetes Metadata locally, "
                "please run '%s' in a separate command line shell.",
                " ".join(command),
            )
        elif not networking_utils.port_available(self.port):
            raise ProvisioningError(
                f"Unable to port-forward Kubernetes Metadata to local "
                f"port {self.port} because the port is occupied. In order to "
                f"access the Kubernetes Metadata locally, please "
                f"change the metadata store configuration to use an available "
                f"port or stop the other process currently using the port."
            )
        else:
            from zenml.utils import daemon

            def _daemon_function() -> None:
                """Forwards the port of the Kubernetes metadata store pod ."""
                subprocess.check_call(command)

            daemon.run_as_daemon(
                _daemon_function,
                pid_file=self._pid_file_path,
                log_file=self._log_file,
            )
            logger.info(
                "Started Kubernetes Metadata daemon (check the daemon "
                "logs at %s in case you're not able to access the pipeline "
                "metadata).",
                self._log_file,
            )

    def stop_metadata_daemon(self) -> None:
        """Stops the Kubernetes metadata daemon process if it is running."""
        if sys.platform != "win32" and fileio.exists(self._pid_file_path):
            from zenml.utils import daemon

            daemon.stop_daemon(self._pid_file_path)
            fileio.remove(self._pid_file_path)

    def wait_until_metadata_store_ready(self, timeout: int) -> None:
        """Waits until the metadata store connection is ready.

        Potentially an irrecoverable error could occur or the timeout could
        expire, so it checks for this.

        Args:
            timeout: The maximum time to wait for the metadata store to be
                ready.

        Raises:
            RuntimeError: if the metadata store is not ready after the timeout
        """
        logger.info(
            "Waiting for the Kubernetes metadata store to be ready (this "
            "might take a few minutes)."
        )
        while True:
            try:
                # it doesn't matter what we call here as long as it exercises
                # the MLMD connection
                self.get_pipelines()
                break
            except Exception as e:
                # Check the budget first so we don't announce another wait
                # right before giving up.
                if timeout <= 0:
                    raise RuntimeError(
                        f"An unexpected error was encountered while waiting "
                        f"for the Kubernetes metadata store to be functional: "
                        f"{str(e)}"
                    ) from e
                logger.info(
                    "The Kubernetes metadata store is not ready yet. Waiting "
                    "for 10 seconds..."
                )
                timeout -= 10
                time.sleep(10)

        logger.info("The Kubernetes metadata store is functional.")
deployment_exists: bool property readonly

Check whether a MySQL deployment exists in the cluster.

Returns:

Type Description
bool

Whether a MySQL deployment exists in the cluster.

host: str property readonly

Get the MySQL host required to access the metadata store.

This overwrites the MySQL host to use local host when running outside of the cluster so we can access the metadata store locally for post execution.

Exceptions:

Type Description
RuntimeError

If the metadata store is not running.

Returns:

Type Description
str

MySQL host.

is_provisioned: bool property readonly

If the component provisioned resources to run.

Checks whether the required MySQL deployment exists.

Returns:

Type Description
bool

True if the component provisioned resources to run.

is_running: bool property readonly

If the component is running.

Returns:

Type Description
bool

True if the component is provisioned and, on non-Windows platforms, the local port-forwarding daemon is running; False otherwise.

pod_name: str property readonly

Name of the Kubernetes pod where the MySQL database is deployed.

Returns:

Type Description
str

Name of the Kubernetes pod.

port: int property readonly

Get the MySQL port required to access the metadata store.

Returns:

Type Description
int

MySQL port.

root_directory: str property readonly

Returns path to the root directory for all files concerning this metadata store.

Returns:

Type Description
str

Path to the root directory.

__init__(self, *args, **kwargs) special

Initiate the Pydantic object and initialize the Kubernetes clients.

Parameters:

Name Type Description Default
*args Any

The positional arguments to pass to the Pydantic object.

()
**kwargs Any

The keyword arguments to pass to the Pydantic object.

{}
Source code in zenml/integrations/kubernetes/metadata_stores/kubernetes_metadata_store.py
def __init__(self, *args: Any, **kwargs: Any) -> None:
    """Build the store via Pydantic, then set up its Kubernetes API clients.

    Args:
        *args: Positional arguments, forwarded unchanged to the Pydantic
            base constructor.
        **kwargs: Keyword arguments, forwarded unchanged to the Pydantic
            base constructor.
    """
    # Let Pydantic populate and validate the fields first ...
    super().__init__(*args, **kwargs)
    # ... and only afterwards create the Kubernetes clients.
    self._initialize_k8s_clients()
check_required_attributes(values) classmethod

Pydantic root_validator.

This ensures that both deployment_name and kubernetes_context are set and raises an error with a custom error message otherwise.

Parameters:

Name Type Description Default
values Dict[str, Any]

Values passed to the Pydantic constructor.

required

Exceptions:

Type Description
StackComponentInterfaceError

if either deployment_name or kubernetes_context is not defined.

Returns:

Type Description
Dict[str, Any]

Values passed to the Pydantic constructor.

Source code in zenml/integrations/kubernetes/metadata_stores/kubernetes_metadata_store.py
@root_validator(skip_on_failure=False)
def check_required_attributes(
    cls, values: Dict[str, Any]
) -> Dict[str, Any]:
    """Pydantic root validator guarding the Kubernetes-specific fields.

    Fails with a tailored error message if `deployment_name` or
    `kubernetes_context` is absent from the values handed to the validator.
    The fields are checked in that order and the first missing one is
    reported.

    Args:
        values: Values passed to the Pydantic constructor.

    Raises:
        StackComponentInterfaceError: if either `deployment_name` or
            `kubernetes_context` is not defined.

    Returns:
        The unmodified `values`.
    """
    missing_fields = [
        field_name
        for field_name in ("deployment_name", "kubernetes_context")
        if field_name not in values
    ]
    if not missing_fields:
        return values

    usage_note = (
        "Note: the `kubernetes` metadata store flavor is a special "
        "subtype of the `mysql` metadata store that deploys a fresh "
        "MySQL database within your Kubernetes cluster when running "
        "`zenml stack up`. "
        "If you already have a MySQL database running in your cluster "
        "(or elsewhere), simply use the `mysql` metadata store flavor "
        "instead."
    )
    raise StackComponentInterfaceError(
        f"Required field `{missing_fields[0]}` missing for "
        "`KubernetesMetadataStore`. " + usage_note
    )
deprovision(self)

Deprovision the metadata store by deleting the MySQL deployment.

Source code in zenml/integrations/kubernetes/metadata_stores/kubernetes_metadata_store.py
def deprovision(self) -> None:
    """Tear down the metadata store by removing its MySQL deployment.

    The local port-forwarding daemon is stopped first (via `suspend()`),
    then the deployment is deleted from the configured namespace.
    """
    logger.info("Deleting Kubernetes MySQL metadata store...")
    self.suspend()

    apps_api = self._k8s_apps_api
    kube_utils.delete_deployment(
        apps_api=apps_api,
        deployment_name=self.deployment_name,
        namespace=self.kubernetes_namespace,
    )
get_tfx_metadata_config(self)

Return tfx metadata config for the Kubernetes metadata store.

Returns:

Type Description
Union[ml_metadata.proto.metadata_store_pb2.ConnectionConfig, ml_metadata.proto.metadata_store_pb2.MetadataStoreClientConfig]

The tfx metadata config.

Source code in zenml/integrations/kubernetes/metadata_stores/kubernetes_metadata_store.py
def get_tfx_metadata_config(
    self,
) -> Union[
    metadata_store_pb2.ConnectionConfig,
    metadata_store_pb2.MetadataStoreClientConfig,
]:
    """Build the MLMD connection config for the in-cluster MySQL database.

    Returns:
        A `ConnectionConfig` wrapping a MySQL config built from this store's
        `host` and `port` together with the default Kubernetes MySQL
        database name, user and password.
    """
    mysql_config = MySQLDatabaseConfig(
        host=self.host,
        port=self.port,
        database=DEFAULT_KUBERNETES_MYSQL_DATABASE,
        user=DEFAULT_KUBERNETES_MYSQL_USERNAME,
        password=DEFAULT_KUBERNETES_MYSQL_PASSWORD,
    )
    return metadata_store_pb2.ConnectionConfig(mysql=mysql_config)
provision(self)

Provision the metadata store.

Creates a deployment with a MySQL database running in it.

Source code in zenml/integrations/kubernetes/metadata_stores/kubernetes_metadata_store.py
def provision(self) -> None:
    """Create the in-cluster MySQL database backing this metadata store.

    Ensures the target namespace exists and creates the MySQL deployment.
    It then blocks until the deployment's pod is no longer in the
    `Pending` phase.
    """
    logger.info("Provisioning Kubernetes MySQL metadata store...")
    core_api = self._k8s_core_api
    namespace = self.kubernetes_namespace

    kube_utils.create_namespace(core_api=core_api, namespace=namespace)
    kube_utils.create_mysql_deployment(
        core_api=core_api,
        apps_api=self._k8s_apps_api,
        namespace=namespace,
        storage_capacity=self.storage_capacity,
        deployment_name=self.deployment_name,
    )

    # Give Kubernetes a moment to schedule the pod before polling it.
    logger.info("Trying to reach Kubernetes MySQL metadata store pod...")
    time.sleep(10)
    kube_utils.wait_pod(
        core_api=core_api,
        pod_name=self.pod_name,
        namespace=namespace,
        exit_condition_lambda=kube_utils.pod_is_not_pending,
    )
    logger.info("Kubernetes MySQL metadata store pod is up and running.")
resume(self)

Resumes the metadata store.

Source code in zenml/integrations/kubernetes/metadata_stores/kubernetes_metadata_store.py
def resume(self) -> None:
    """Resume the store: restart port forwarding, then wait for MLMD access."""
    self.start_metadata_daemon()
    ready_timeout = DEFAULT_KUBERNETES_METADATA_DAEMON_TIMEOUT
    self.wait_until_metadata_store_ready(timeout=ready_timeout)
start_metadata_daemon(self)

Starts a daemon process that forwards ports.

This is so the MySQL database in the Kubernetes cluster is accessible on the localhost.

Exceptions:

Type Description
ProvisioningError

if the local port to forward to is already in use.

Source code in zenml/integrations/kubernetes/metadata_stores/kubernetes_metadata_store.py
def start_metadata_daemon(self) -> None:
    """Starts a daemon process that forwards ports.

    This is so the MySQL database in the Kubernetes cluster is accessible
    on the localhost. The daemon runs `kubectl port-forward` against the
    `svc/<deployment_name>` service, mapping local port `self.port` to the
    same port in the cluster.

    On Windows no daemon is started. Instead, a warning shows the user the
    command to run manually.

    Raises:
        ProvisioningError: if the local port to forward to is already in
            use.
    """
    command = [
        "kubectl",
        "--context",
        self.kubernetes_context,
        "--namespace",
        self.kubernetes_namespace,
        "port-forward",
        f"svc/{self.deployment_name}",
        f"{self.port}:{self.port}",
    ]
    if sys.platform == "win32":
        # The message has exactly one placeholder, so pass exactly one
        # argument; an extra argument makes `logging` fail to format the
        # record and the command would never be shown to the user.
        logger.warning(
            "Daemon functionality not supported on Windows. "
            "In order to access the Kubernetes Metadata locally, "
            "please run '%s' in a separate command line shell.",
            " ".join(command),
        )
    elif not networking_utils.port_available(self.port):
        raise ProvisioningError(
            "Unable to port-forward Kubernetes Metadata to local "
            f"port {self.port} because the port is occupied. In order to "
            "access the Kubernetes Metadata locally, please "
            "change the metadata store configuration to use an available "
            "port or stop the other process currently using the port."
        )
    else:
        from zenml.utils import daemon

        def _daemon_function() -> None:
            """Forwards the port of the Kubernetes metadata store pod."""
            subprocess.check_call(command)

        daemon.run_as_daemon(
            _daemon_function,
            pid_file=self._pid_file_path,
            log_file=self._log_file,
        )
        logger.info(
            "Started Kubernetes Metadata daemon (check the daemon "
            "logs at %s in case you're not able to access the pipeline "
            "metadata).",
            self._log_file,
        )
stop_metadata_daemon(self)

Stops the Kubernetes metadata daemon process if it is running.

Source code in zenml/integrations/kubernetes/metadata_stores/kubernetes_metadata_store.py
def stop_metadata_daemon(self) -> None:
    """Stop the port-forwarding daemon if one was started.

    Nothing happens on Windows, where no daemon is ever started. Nothing
    happens either when the daemon's PID file does not exist. Otherwise the
    daemon is stopped and its PID file removed.
    """
    if sys.platform == "win32":
        return
    if not fileio.exists(self._pid_file_path):
        return

    from zenml.utils import daemon

    daemon.stop_daemon(self._pid_file_path)
    fileio.remove(self._pid_file_path)
suspend(self)

Suspends the metadata store.

Source code in zenml/integrations/kubernetes/metadata_stores/kubernetes_metadata_store.py
def suspend(self) -> None:
    """Suspend the store by shutting down its local port-forwarding daemon."""
    return self.stop_metadata_daemon()
wait_until_metadata_store_ready(self, timeout)

Waits until the metadata store connection is ready.

Connection attempts are retried until one succeeds; if the timeout expires first, an error is raised.

Parameters:

Name Type Description Default
timeout int

The maximum time to wait for the metadata store to be ready.

required

Exceptions:

Type Description
RuntimeError

if the metadata store is not ready after the timeout

Source code in zenml/integrations/kubernetes/metadata_stores/kubernetes_metadata_store.py
def wait_until_metadata_store_ready(self, timeout: int) -> None:
    """Waits until the metadata store connection is ready.

    The MLMD connection is probed repeatedly, pausing at most 10 seconds
    between attempts, until a probe succeeds or `timeout` seconds have
    elapsed. The deadline uses a monotonic clock, so time spent inside
    slow, failing connection attempts also counts against the timeout.

    Args:
        timeout: The maximum time (in seconds) to wait for the metadata
            store to be ready.

    Raises:
        RuntimeError: if the metadata store is not ready after the timeout.
            The last connection error is chained as its cause.
    """
    logger.info(
        "Waiting for the Kubernetes metadata store to be ready (this "
        "might take a few minutes)."
    )
    deadline = time.monotonic() + timeout
    while True:
        try:
            # it doesn't matter what we call here as long as it exercises
            # the MLMD connection
            self.get_pipelines()
            break
        except Exception as e:
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                raise RuntimeError(
                    f"An unexpected error was encountered while waiting "
                    f"for the Kubernetes metadata store to be functional: "
                    f"{e}"
                ) from e
            logger.info(
                "The Kubernetes metadata store is not ready yet. Waiting "
                "for 10 seconds..."
            )
            # Never sleep past the deadline.
            time.sleep(min(10, remaining))

    logger.info("The Kubernetes metadata store is functional.")

orchestrators special

Kubernetes-native orchestration.

dag_runner

DAG (Directed Acyclic Graph) Runners.

NodeStatus (Enum)

Status of the execution of a node.

Source code in zenml/integrations/kubernetes/orchestrators/dag_runner.py
class NodeStatus(Enum):
    """Lifecycle state of a single node during a DAG run.

    A node starts out `WAITING`, becomes `RUNNING` once it is scheduled,
    and ends as `COMPLETED` after its run function has returned.
    """

    WAITING = "Waiting"  # not started yet
    RUNNING = "Running"  # scheduled / executing
    COMPLETED = "Completed"  # run function returned
ThreadedDagRunner

Multi-threaded DAG Runner.

This class expects a DAG of strings in adjacency list representation, as well as a custom run_fn as input, then calls run_fn(node) for each string node in the DAG.

Steps that can be executed in parallel will be started in separate threads.

Source code in zenml/integrations/kubernetes/orchestrators/dag_runner.py
class ThreadedDagRunner:
    """Multi-threaded DAG Runner.

    This class expects a DAG of strings in adjacency list representation, as
    well as a custom `run_fn` as input, then calls `run_fn(node)` for each
    string node in the DAG.

    Steps that can be executed in parallel will be started in separate threads.

    Deciding that a node may run and marking it as running happen atomically
    under `self._lock`. Otherwise two upstream nodes finishing at the same
    time in different threads could both start the same downstream node.
    """

    def __init__(
        self, dag: Dict[str, List[str]], run_fn: Callable[[str], Any]
    ) -> None:
        """Define attributes and initialize all nodes in waiting state.

        Args:
            dag: Adjacency list representation of a DAG, mapping each node
                to the list of its upstream nodes.
                E.g.: [(1->2), (1->3), (2->4), (3->4)] should be represented as
                `dag={2: [1], 3: [1], 4: [2, 3]}`
            run_fn: A function `run_fn(node)` that runs a single node
        """
        self.dag = dag
        self.reversed_dag = reverse_dag(dag)
        self.run_fn = run_fn
        self.nodes = dag.keys()
        self.node_states = {node: NodeStatus.WAITING for node in self.nodes}
        self._lock = threading.Lock()

    def _can_run(self, node: str) -> bool:
        """Determine whether a node is ready to be run.

        This is the case if the node has not run yet and all of its upstream
        nodes have already completed. Callers that act on the result must
        hold `self._lock`, so that the check and the subsequent state change
        are atomic.

        Args:
            node: The node.

        Returns:
            True if the node can run else False.
        """
        # Check that node has not run yet.
        if self.node_states[node] != NodeStatus.WAITING:
            return False

        # Check that all upstream nodes of this node have already completed.
        return all(
            self.node_states[upstream_node] == NodeStatus.COMPLETED
            for upstream_node in self.dag[node]
        )

    def _run_node(self, node: str) -> None:
        """Run a single node.

        Calls the user-defined run_fn, then calls `self._finish_node`. If
        `run_fn` raises, the node is never marked as completed. Its
        downstream nodes then stay waiting, and `run` warns about them.

        Args:
            node: The node.
        """
        self.run_fn(node)
        self._finish_node(node)

    def _start_node_thread(self, node: str) -> threading.Thread:
        """Start `self._run_node(node)` in a new thread.

        The caller must already have marked `node` as running.

        Args:
            node: The node.

        Returns:
            The started thread.
        """
        thread = threading.Thread(target=self._run_node, args=(node,))
        thread.start()
        return thread

    def _run_node_in_thread(self, node: str) -> threading.Thread:
        """Run a single node in a separate thread.

        First updates the node status to running.
        Then calls self._run_node() in a new thread and returns the thread.

        Args:
            node: The node.

        Returns:
            The thread in which the node was run.
        """
        # Update node status to running.
        with self._lock:
            assert self.node_states[node] == NodeStatus.WAITING
            self.node_states[node] = NodeStatus.RUNNING
        return self._start_node_thread(node)

    def _start_runnable_nodes(
        self, nodes: List[str]
    ) -> List[threading.Thread]:
        """Start every node in `nodes` that is currently able to run.

        For each node, the readiness check and the switch to running are
        done under one lock acquisition, so each node is started at most
        once even when several threads call this concurrently.

        Args:
            nodes: Candidate nodes.

        Returns:
            The threads of the nodes that were started.
        """
        threads = []
        for node in nodes:
            with self._lock:
                if not self._can_run(node):
                    continue
                self.node_states[node] = NodeStatus.RUNNING
            threads.append(self._start_node_thread(node))
        return threads

    def _finish_node(self, node: str) -> None:
        """Finish a node run.

        First updates the node status to completed.
        Then starts all other nodes that can now be run and waits for them.

        Args:
            node: The node.
        """
        # Update node status to completed.
        with self._lock:
            assert self.node_states[node] == NodeStatus.RUNNING
            self.node_states[node] = NodeStatus.COMPLETED

        # Run downstream nodes and wait for all of them to complete.
        threads = self._start_runnable_nodes(self.reversed_dag[node])
        for thread in threads:
            thread.join()

    def run(self) -> None:
        """Call `self.run_fn` on all nodes in `self.dag`.

        The order of execution is determined using topological sort.
        Each node is run in a separate thread to enable parallelism.
        """
        # Run all nodes that can be started immediately.
        # These will, in turn, start other nodes once all of their respective
        # upstream nodes have completed.
        threads = self._start_runnable_nodes(list(self.nodes))

        # Wait till all nodes have completed.
        for thread in threads:
            thread.join()

        # Make sure all nodes were run, otherwise print a warning.
        for node in self.nodes:
            if self.node_states[node] == NodeStatus.WAITING:
                upstream_nodes = self.dag[node]
                logger.warning(
                    f"Node `{node}` was never run, because it was still"
                    f" waiting for the following nodes: `{upstream_nodes}`."
                )
__init__(self, dag, run_fn) special

Define attributes and initialize all nodes in waiting state.

Parameters:

Name Type Description Default
dag Dict[str, List[str]]

Adjacency list representation of a DAG. E.g.: [(1->2), (1->3), (2->4), (3->4)] should be represented as dag={2: [1], 3: [1], 4: [2, 3]}

required
run_fn Callable[[str], Any]

A function run_fn(node) that runs a single node

required
Source code in zenml/integrations/kubernetes/orchestrators/dag_runner.py
def __init__(
    self, dag: Dict[str, List[str]], run_fn: Callable[[str], Any]
) -> None:
    """Store the DAG and the run function, and mark every node as waiting.

    Args:
        dag: Adjacency list representation of a DAG, mapping each node to
            the list of its upstream nodes. E.g.: [(1->2), (1->3), (2->4),
            (3->4)] should be represented as
            `dag={2: [1], 3: [1], 4: [2, 3]}`
        run_fn: A function `run_fn(node)` that runs a single node
    """
    self.dag = dag
    # Downstream lookup table, used to find nodes to start after a node ends.
    self.reversed_dag = reverse_dag(dag)
    self.run_fn = run_fn
    self.nodes = dag.keys()
    # Every node begins in the waiting state.
    self.node_states = dict.fromkeys(self.nodes, NodeStatus.WAITING)
    self._lock = threading.Lock()
run(self)

Call self.run_fn on all nodes in self.dag.

The order of execution is determined using topological sort. Each node is run in a separate thread to enable parallelism.

Source code in zenml/integrations/kubernetes/orchestrators/dag_runner.py
def run(self) -> None:
    """Call `self.run_fn` on all nodes in `self.dag`.

    The order of execution is determined using topological sort.
    Each node is run in a separate thread to enable parallelism.
    """
    # Run all nodes that can be started immediately.
    # These will, in turn, start other nodes once all of their respective
    # upstream nodes have completed.
    threads = []
    for node in self.nodes:
        # Check-and-claim atomically: a node started earlier in this loop
        # may already have finished and be starting its downstream nodes
        # from another thread. A separate check and update could then start
        # the same node twice.
        with self._lock:
            if not self._can_run(node):
                continue
            self.node_states[node] = NodeStatus.RUNNING
        thread = threading.Thread(target=self._run_node, args=(node,))
        thread.start()
        threads.append(thread)

    # Wait till all nodes have completed.
    for thread in threads:
        thread.join()

    # Make sure all nodes were run, otherwise print a warning.
    for node in self.nodes:
        if self.node_states[node] == NodeStatus.WAITING:
            upstream_nodes = self.dag[node]
            logger.warning(
                f"Node `{node}` was never run, because it was still"
                f" waiting for the following nodes: `{upstream_nodes}`."
            )
reverse_dag(dag)

Reverse a DAG.

Parameters:

Name Type Description Default
dag Dict[str, List[str]]

Adjacency list representation of a DAG.

required

Returns:

Type Description
Dict[str, List[str]]

Adjacency list representation of the reversed DAG.

Source code in zenml/integrations/kubernetes/orchestrators/dag_runner.py
def reverse_dag(dag: Dict[str, List[str]]) -> Dict[str, List[str]]:
    """Flip the direction of every edge in a DAG.

    The input maps each node to its upstream nodes; the result maps each
    node to its downstream nodes. Every key of `dag` appears in the result,
    with an empty list if nothing depends on it.

    Args:
        dag: Adjacency list representation of a DAG.

    Returns:
        Adjacency list representation of the reversed DAG, as a
        `defaultdict(list)`.
    """
    downstream: Dict[str, List[str]] = defaultdict(list)
    for child, parents in dag.items():
        for parent in parents:
            downstream[parent].append(child)

    # Nodes that are nobody's upstream still need an (empty) entry.
    for child in dag:
        downstream.setdefault(child, [])

    return downstream
kube_utils

Utilities for Kubernetes related functions.

Internal interface: no backwards compatibility guarantees. Adjusted from https://github.com/tensorflow/tfx/blob/master/tfx/utils/kube_utils.py.

PodPhase (Enum)

Phase of the Kubernetes pod.

Pod phases are defined in https://kubernetes.io/docs/concepts/workloads/pods/pod-lifecycle/#pod-phase.

Source code in zenml/integrations/kubernetes/orchestrators/kube_utils.py
class PodPhase(enum.Enum):
    """Kubernetes pod phase, as reported in `pod.status.phase`.

    The values mirror the phase strings listed at
    https://kubernetes.io/docs/concepts/workloads/pods/pod-lifecycle/#pod-phase.
    """

    PENDING = "Pending"
    RUNNING = "Running"
    SUCCEEDED = "Succeeded"
    FAILED = "Failed"
    UNKNOWN = "Unknown"
create_edit_service_account(core_api, rbac_api, service_account_name, namespace, cluster_role_binding_name='zenml-edit')

Create a new Kubernetes service account with "edit" rights.

Parameters:

Name Type Description Default
core_api CoreV1Api

Client of Core V1 API of Kubernetes API.

required
rbac_api RbacAuthorizationV1Api

Client of Rbac Authorization V1 API of Kubernetes API.

required
service_account_name str

Name of the service account.

required
namespace str

Kubernetes namespace.

required
cluster_role_binding_name str

Name of the cluster role binding. Defaults to "zenml-edit".

'zenml-edit'
Source code in zenml/integrations/kubernetes/orchestrators/kube_utils.py
def create_edit_service_account(
    core_api: k8s_client.CoreV1Api,
    rbac_api: k8s_client.RbacAuthorizationV1Api,
    service_account_name: str,
    namespace: str,
    cluster_role_binding_name: str = "zenml-edit",
) -> None:
    """Create a new Kubernetes service account with "edit" rights.

    First creates a cluster role binding that binds the `edit` role to the
    service account, then the service account itself. Both create calls are
    wrapped in `_if_not_exists` (presumably so that already-existing
    resources are tolerated; confirm against its definition).

    Args:
        core_api: Client of Core V1 API of Kubernetes API.
        rbac_api: Client of Rbac Authorization V1 API of Kubernetes API.
        service_account_name: Name of the service account.
        namespace: Kubernetes namespace in which the service account is
            created. This argument is required and has no default.
        cluster_role_binding_name: Name of the cluster role binding.
            Defaults to "zenml-edit".
    """
    crb_manifest = build_cluster_role_binding_manifest_for_service_account(
        name=cluster_role_binding_name,
        role_name="edit",
        service_account_name=service_account_name,
        namespace=namespace,
    )
    _if_not_exists(rbac_api.create_cluster_role_binding)(body=crb_manifest)

    sa_manifest = build_service_account_manifest(
        name=service_account_name, namespace=namespace
    )
    _if_not_exists(core_api.create_namespaced_service_account)(
        namespace=namespace,
        body=sa_manifest,
    )
create_mysql_deployment(core_api, apps_api, deployment_name, namespace, storage_capacity='10Gi', volume_name='mysql-pv-volume', volume_claim_name='mysql-pv-claim')

Create a Kubernetes deployment with a MySQL database running on it.

Parameters:

Name Type Description Default
core_api CoreV1Api

Client of Core V1 API of Kubernetes API.

required
apps_api AppsV1Api

Client of Apps V1 API of Kubernetes API.

required
namespace str

Kubernetes namespace.

required
storage_capacity str

Storage capacity of the database. Defaults to "10Gi".

'10Gi'
deployment_name str

Name of the deployment.

required
volume_name str

Name of the persistent volume. Defaults to "mysql-pv-volume".

'mysql-pv-volume'
volume_claim_name str

Name of the persistent volume claim. Defaults to "mysql-pv-claim".

'mysql-pv-claim'
Source code in zenml/integrations/kubernetes/orchestrators/kube_utils.py
def create_mysql_deployment(
    core_api: k8s_client.CoreV1Api,
    apps_api: k8s_client.AppsV1Api,
    deployment_name: str,
    namespace: str,
    storage_capacity: str = "10Gi",
    volume_name: str = "mysql-pv-volume",
    volume_claim_name: str = "mysql-pv-claim",
) -> None:
    """Create a Kubernetes deployment with a MySQL database running on it.

    Creates, in order: a persistent volume claim, a persistent volume, the
    MySQL deployment (referencing the claim), and a service for the
    deployment. Each create call is wrapped in `_if_not_exists` (presumably
    so that already-existing resources are tolerated; confirm against its
    definition).

    Args:
        core_api: Client of Core V1 API of Kubernetes API.
        apps_api: Client of Apps V1 API of Kubernetes API.
        deployment_name: Name of the deployment; also used as the name of
            the service. Required, no default.
        namespace: Kubernetes namespace. Required, no default.
        storage_capacity: Storage capacity of the database.
            Defaults to `"10Gi"`.
        volume_name: Name of the persistent volume.
            Defaults to `"mysql-pv-volume"`.
        volume_claim_name: Name of the persistent volume claim.
            Defaults to `"mysql-pv-claim"`.
    """
    pvc_manifest = build_persistent_volume_claim_manifest(
        name=volume_claim_name,
        namespace=namespace,
        storage_request=storage_capacity,
    )
    _if_not_exists(core_api.create_namespaced_persistent_volume_claim)(
        namespace=namespace,
        body=pvc_manifest,
    )
    pv_manifest = build_persistent_volume_manifest(
        name=volume_name, storage_capacity=storage_capacity
    )
    _if_not_exists(core_api.create_persistent_volume)(body=pv_manifest)
    deployment_manifest = build_mysql_deployment_manifest(
        name=deployment_name,
        namespace=namespace,
        pv_claim_name=volume_claim_name,
    )
    _if_not_exists(apps_api.create_namespaced_deployment)(
        body=deployment_manifest, namespace=namespace
    )
    service_manifest = build_mysql_service_manifest(
        name=deployment_name, namespace=namespace
    )
    _if_not_exists(core_api.create_namespaced_service)(
        namespace=namespace, body=service_manifest
    )
create_namespace(core_api, namespace)

Create a Kubernetes namespace.

Parameters:

Name Type Description Default
core_api CoreV1Api

Client of Core V1 API of Kubernetes API.

required
namespace str

Kubernetes namespace.

required
Source code in zenml/integrations/kubernetes/orchestrators/kube_utils.py
def create_namespace(core_api: k8s_client.CoreV1Api, namespace: str) -> None:
    """Create a Kubernetes namespace.

    The create call is wrapped in `_if_not_exists` (presumably so that an
    already-existing namespace is tolerated; confirm against its
    definition).

    Args:
        core_api: Client of Core V1 API of Kubernetes API.
        namespace: Name of the Kubernetes namespace to create. Required,
            no default.
    """
    manifest = build_namespace_manifest(namespace)
    _if_not_exists(core_api.create_namespace)(body=manifest)
delete_deployment(apps_api, deployment_name, namespace)

Delete a Kubernetes deployment.

Parameters:

Name Type Description Default
apps_api AppsV1Api

Client of Apps V1 API of Kubernetes API.

required
deployment_name str

Name of the deployment to be deleted.

required
namespace str

Kubernetes namespace containing the deployment.

required
Source code in zenml/integrations/kubernetes/orchestrators/kube_utils.py
def delete_deployment(
    apps_api: k8s_client.AppsV1Api, deployment_name: str, namespace: str
) -> None:
    """Delete a Kubernetes deployment with a `Foreground` propagation policy.

    Args:
        apps_api: Client of Apps V1 API of Kubernetes API.
        deployment_name: Name of the deployment to be deleted.
        namespace: Kubernetes namespace containing the deployment.
    """
    apps_api.delete_namespaced_deployment(
        name=deployment_name,
        namespace=namespace,
        body=k8s_client.V1DeleteOptions(),
        propagation_policy="Foreground",
    )
get_pod(core_api, pod_name, namespace)

Get a pod from Kubernetes metadata API.

Parameters:

Name Type Description Default
core_api CoreV1Api

Client of CoreV1Api of Kubernetes API.

required
pod_name str

The name of the pod.

required
namespace str

The namespace of the pod.

required

Exceptions:

Type Description
RuntimeError

When it sees unexpected errors from Kubernetes API.

Returns:

Type Description
Optional[kubernetes.client.models.v1_pod.V1Pod]

The found pod object. None if it's not found.

Source code in zenml/integrations/kubernetes/orchestrators/kube_utils.py
def get_pod(
    core_api: k8s_client.CoreV1Api, pod_name: str, namespace: str
) -> Optional[k8s_client.V1Pod]:
    """Get a pod from Kubernetes metadata API.

    Args:
        core_api: Client of `CoreV1Api` of Kubernetes API.
        pod_name: The name of the pod.
        namespace: The namespace of the pod.

    Raises:
        RuntimeError: When it sees unexpected errors from Kubernetes API
            (any API error other than HTTP 404). The original
            `ApiException` is chained as the cause.

    Returns:
        The found pod object. None if it's not found.
    """
    try:
        return core_api.read_namespaced_pod(name=pod_name, namespace=namespace)
    except k8s_client.rest.ApiException as e:
        if e.status == 404:
            # A missing pod is an expected outcome, reported as None.
            return None
        raise RuntimeError(
            f"Unexpected error while reading pod `{pod_name}` in namespace "
            f"`{namespace}` from the Kubernetes API (status {e.status})."
        ) from e
is_inside_kubernetes()

Check whether we are inside a Kubernetes cluster or on a remote host.

Returns:

Type Description
bool

True if inside a Kubernetes cluster, else False.

Source code in zenml/integrations/kubernetes/orchestrators/kube_utils.py
def is_inside_kubernetes() -> bool:
    """Tell whether this process runs inside a Kubernetes cluster.

    Detection works by attempting to load the in-cluster configuration.

    Returns:
        True if inside a Kubernetes cluster, else False.
    """
    try:
        k8s_config.load_incluster_config()
    except k8s_config.ConfigException:
        # No in-cluster configuration available: we're on a remote host.
        return False
    return True
load_kube_config(context=None)

Load the Kubernetes client config.

Depending on the environment (inside a running Kubernetes cluster or on a remote host), a different location is searched for the config file.

Parameters:

Name Type Description Default
context Optional[str]

Name of the Kubernetes context. If not provided, uses the currently active context.

None
Source code in zenml/integrations/kubernetes/orchestrators/kube_utils.py
def load_kube_config(context: Optional[str] = None) -> None:
    """Load the Kubernetes client configuration.

    The in-cluster configuration is tried first. If that is unavailable
    (i.e. we are on a remote host), the local kubeconfig file is loaded
    instead.

    Args:
        context: Name of the Kubernetes context to use with the kubeconfig
            file. If not provided, the currently active context is used.
    """
    try:
        k8s_config.load_incluster_config()
    except k8s_config.ConfigException:
        # Not inside a cluster: fall back to the kubeconfig file.
        k8s_config.load_kube_config(context=context)
pod_failed(pod)

Check if pod status is 'Failed'.

Parameters:

Name Type Description Default
pod V1Pod

Kubernetes pod.

required

Returns:

Type Description
bool

True if pod status is 'Failed' else False.

Source code in zenml/integrations/kubernetes/orchestrators/kube_utils.py
def pod_failed(pod: k8s_client.V1Pod) -> bool:
    """Tell whether the pod's phase is 'Failed'.

    Args:
        pod: Kubernetes pod.

    Returns:
        True if `pod.status.phase` equals 'Failed', else False.
    """
    phase = pod.status.phase
    return phase == PodPhase.FAILED.value  # type: ignore[no-any-return]
pod_is_done(pod)

Check if pod status is 'Succeeded'.

Parameters:

Name Type Description Default
pod V1Pod

Kubernetes pod.

required

Returns:

Type Description
bool

True if pod status is 'Succeeded' else False.

Source code in zenml/integrations/kubernetes/orchestrators/kube_utils.py
def pod_is_done(pod: k8s_client.V1Pod) -> bool:
    """Tell whether the pod's phase is 'Succeeded'.

    Args:
        pod: Kubernetes pod.

    Returns:
        True if `pod.status.phase` equals 'Succeeded', else False.
    """
    phase = pod.status.phase
    return phase == PodPhase.SUCCEEDED.value  # type: ignore[no-any-return]
pod_is_not_pending(pod)

Check if pod status is not 'Pending'.

Parameters:

Name Type Description Default
pod V1Pod

Kubernetes pod.

required

Returns:

Type Description
bool

False if the pod status is 'Pending' else True.

Source code in zenml/integrations/kubernetes/orchestrators/kube_utils.py
def pod_is_not_pending(pod: k8s_client.V1Pod) -> bool:
    """Tell whether the pod has left the 'Pending' phase.

    Args:
        pod: Kubernetes pod.

    Returns:
        False if `pod.status.phase` equals 'Pending', else True.
    """
    phase = pod.status.phase
    return phase != PodPhase.PENDING.value  # type: ignore[no-any-return]
sanitize_pod_name(pod_name)

Sanitize pod names so they conform to Kubernetes pod naming convention.

Parameters:

Name Type Description Default
pod_name str

Arbitrary input pod name.

required

Returns:

Type Description
str

Sanitized pod name.

Source code in zenml/integrations/kubernetes/orchestrators/kube_utils.py
def sanitize_pod_name(pod_name: str) -> str:
    """Sanitize pod names so they conform to Kubernetes pod naming convention.

    The name is lowercased, and every character outside `[a-z0-9-]` is
    replaced by a hyphen. Runs of hyphens are then collapsed into one, and
    leading and trailing hyphens are removed. Kubernetes (RFC 1123) names
    must start and end with an alphanumeric character.

    Args:
        pod_name: Arbitrary input pod name.

    Returns:
        Sanitized pod name.
    """
    pod_name = re.sub(r"[^a-z0-9-]", "-", pod_name.lower())
    pod_name = re.sub(r"-+", "-", pod_name)
    return pod_name.strip("-")
wait_pod(core_api, pod_name, namespace, exit_condition_lambda, timeout_sec=0, exponential_backoff=False, stream_logs=False)

Wait for a pod to meet an exit condition.

Parameters:

Name Type Description Default
core_api CoreV1Api

Client of CoreV1Api of Kubernetes API.

required
pod_name str

The name of the pod.

required
namespace str

The namespace of the pod.

required
exit_condition_lambda Callable[[kubernetes.client.models.v1_pod.V1Pod], bool]

A lambda which will be called periodically to wait for a pod to exit. The function returns True to exit.

required
timeout_sec int

Timeout in seconds to wait for pod to reach exit condition, or 0 to wait for an unlimited duration. Defaults to unlimited.

0
exponential_backoff bool

Whether to use exponential back off for polling. Defaults to False.

False
stream_logs bool

Whether to stream the pod logs to zenml.logger.info(). Defaults to False.

False

Exceptions:

Type Description
RuntimeError

when the function times out.

Returns:

Type Description
V1Pod

The pod object which meets the exit condition.

Source code in zenml/integrations/kubernetes/orchestrators/kube_utils.py
def wait_pod(
    core_api: k8s_client.CoreV1Api,
    pod_name: str,
    namespace: str,
    exit_condition_lambda: Callable[[k8s_client.V1Pod], bool],
    timeout_sec: int = 0,
    exponential_backoff: bool = False,
    stream_logs: bool = False,
) -> k8s_client.V1Pod:
    """Wait for a pod to meet an exit condition.

    The pod is polled until `exit_condition_lambda` returns True for it, the
    pod is reported as failed, or the timeout expires.

    Args:
        core_api: Client of `CoreV1Api` of Kubernetes API.
        pod_name: The name of the pod.
        namespace: The namespace of the pod.
        exit_condition_lambda: A lambda
            which will be called periodically to wait for a pod to exit. The
            function returns True to exit.
        timeout_sec: Timeout in seconds to wait for pod to reach exit
            condition, or 0 to wait for an unlimited duration.
            Defaults to unlimited.
        exponential_backoff: Whether to use exponential back off for polling.
            Defaults to False.
        stream_logs: Whether to stream the pod logs to
            `zenml.logger.info()`. Defaults to False.

    Raises:
        RuntimeError: when the pod fails or the function times out.

    Returns:
        The pod object which meets the exit condition.
    """
    # Use a monotonic clock: wall-clock time can jump when the system clock
    # is adjusted, which would distort the timeout computation.
    start_time = time.monotonic()

    # Link to exponential back-off algorithm used here:
    # https://cloud.google.com/storage/docs/exponential-backoff
    backoff_interval = 1
    maximum_backoff = 32

    # Number of log lines already forwarded; the full log is re-read on each
    # poll, so only lines past this index are new.
    logged_lines = 0

    while True:
        resp = get_pod(core_api, pod_name, namespace)

        # Stream logs to `zenml.logger.info()`.
        # TODO: can we do this without parsing all logs every time?
        if stream_logs and pod_is_not_pending(resp):
            response = core_api.read_namespaced_pod_log(
                name=pod_name,
                namespace=namespace,
            )
            logs = response.splitlines()
            if len(logs) > logged_lines:
                for line in logs[logged_lines:]:
                    logger.info(line)
                logged_lines = len(logs)

        # Raise an error if the pod failed.
        if pod_failed(resp):
            raise RuntimeError(f"Pod `{namespace}:{pod_name}` failed.")

        # Check if pod is in desired state (e.g. finished / running / ...).
        if exit_condition_lambda(resp):
            return resp

        # Check if wait timed out (`timeout_sec == 0` means wait forever).
        # Total elapsed seconds are used; `timedelta.seconds` would wrap
        # around after one day.
        elapsed_sec = time.monotonic() - start_time
        if timeout_sec != 0 and elapsed_sec >= timeout_sec:
            raise RuntimeError(
                f"Waiting for pod `{namespace}:{pod_name}` timed out after "
                f"{timeout_sec} seconds."
            )

        # Wait (using exponential backoff), but never sleep past the deadline
        # so the timeout is honoured promptly.
        sleep_sec: float = backoff_interval
        if timeout_sec > 0:
            sleep_sec = min(sleep_sec, timeout_sec - elapsed_sec)
        time.sleep(sleep_sec)
        if exponential_backoff and backoff_interval < maximum_backoff:
            backoff_interval *= 2
kubernetes_orchestrator

Kubernetes-native orchestrator.

KubernetesOrchestrator (BaseOrchestrator) pydantic-model

Orchestrator for running ZenML pipelines using native Kubernetes.

Attributes:

Name Type Description
custom_docker_base_image_name Optional[str]

Name of a Docker image that should be used as the base for the image that will be run on Kubernetes pods. If no custom image is given, a basic image of the active ZenML version will be used. Note: This image needs to have ZenML installed, otherwise the pipeline execution will fail. For that reason, you might want to extend the ZenML Docker images found here: https://hub.docker.com/r/zenmldocker/zenml/

kubernetes_context Optional[str]

Optional name of a Kubernetes context to run pipelines in. If not set, the current active context will be used. You can find the active context by running kubectl config current-context.

kubernetes_namespace str

Name of the Kubernetes namespace to be used. If not provided, the zenml namespace will be used.

synchronous bool

If True, running a pipeline using this orchestrator will block until all steps have finished running on Kubernetes.

skip_config_loading bool

If True, don't load the Kubernetes context and clients. This is only useful for unit testing.

Source code in zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator.py
class KubernetesOrchestrator(BaseOrchestrator):
    """Orchestrator for running ZenML pipelines using native Kubernetes.

    Attributes:
        custom_docker_base_image_name: Name of a Docker image that should be
            used as the base for the image that will be run on Kubernetes pods.
            If no custom image is given, a basic image of the active ZenML
            version will be used.
            **Note**: This image needs to have ZenML installed,
            otherwise the pipeline execution will fail. For that reason, you
            might want to extend the ZenML Docker images found here:
            https://hub.docker.com/r/zenmldocker/zenml/
        kubernetes_context: Optional name of a Kubernetes context to run
            pipelines in. If not set, the current active context will be used.
            You can find the active context by running `kubectl config
            current-context`.
        kubernetes_namespace: Name of the Kubernetes namespace to be used.
            If not provided, the `zenml` namespace will be used.
        synchronous: If `True`, running a pipeline using this orchestrator will
            block until all steps have finished running on Kubernetes.
        skip_config_loading: If `True`, don't load the Kubernetes context and
            clients. This is only useful for unit testing.
    """

    custom_docker_base_image_name: Optional[str] = None
    kubernetes_context: Optional[str] = None
    kubernetes_namespace: str = "zenml"
    synchronous: bool = False
    skip_config_loading: bool = False
    _k8s_core_api: k8s_client.CoreV1Api = None
    _k8s_batch_api: k8s_client.BatchV1beta1Api = None
    _k8s_rbac_api: k8s_client.RbacAuthorizationV1Api = None

    FLAVOR: ClassVar[str] = KUBERNETES_ORCHESTRATOR_FLAVOR

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Initialize the Pydantic object and the Kubernetes clients.

        Args:
            *args: The positional arguments to pass to the Pydantic object.
            **kwargs: The keyword arguments to pass to the Pydantic object.
        """
        super().__init__(*args, **kwargs)
        self._initialize_k8s_clients()

    def _initialize_k8s_clients(self) -> None:
        """Initialize the Kubernetes clients."""
        if self.skip_config_loading:
            return
        kube_utils.load_kube_config(context=self.kubernetes_context)
        self._k8s_core_api = k8s_client.CoreV1Api()
        self._k8s_batch_api = k8s_client.BatchV1beta1Api()
        self._k8s_rbac_api = k8s_client.RbacAuthorizationV1Api()

    def get_kubernetes_contexts(self) -> Tuple[List[str], str]:
        """Get list of configured Kubernetes contexts and the active context.

        Raises:
            RuntimeError: if the Kubernetes configuration cannot be loaded.

        Returns:
            context_name: List of configured Kubernetes contexts
            active_context_name: Name of the active Kubernetes context.
        """
        try:
            contexts, active_context = k8s_config.list_kube_config_contexts()
        except k8s_config.config_exception.ConfigException as e:
            raise RuntimeError(
                "Could not load the Kubernetes configuration"
            ) from e

        context_names = [c["name"] for c in contexts]
        active_context_name = active_context["name"]
        return context_names, active_context_name

    @property
    def validator(self) -> Optional[StackValidator]:
        """Defines the validator that checks whether the stack is valid.

        Returns:
            Stack validator.
        """

        def _validate_local_requirements(stack: "Stack") -> Tuple[bool, str]:
            """Validates that the stack contains no local components.

            Args:
                stack: The stack.

            Returns:
                Whether the stack is valid or not.
                An explanation why the stack is invalid, if applicable.
            """
            container_registry = stack.container_registry

            # should not happen, because the stack validation takes care of
            # this, but just in case
            assert container_registry is not None

            # Only compare against the kubeconfig when a context was set
            # explicitly. If `kubernetes_context` is unset, the active
            # context is used (see the class docstring), so there is no
            # named context to look up.
            if (
                not self.skip_config_loading
                and self.kubernetes_context is not None
            ):
                contexts, active_context = self.get_kubernetes_contexts()
                if self.kubernetes_context not in contexts:
                    return False, (
                        f"Could not find a Kubernetes context named "
                        f"'{self.kubernetes_context}' in the local Kubernetes "
                        f"configuration. Please make sure that the Kubernetes "
                        f"cluster is running and that the kubeconfig file is "
                        f"configured correctly. To list all configured "
                        f"contexts, run:\n\n"
                        f"  `kubectl config get-contexts`\n"
                    )
                if self.kubernetes_context != active_context:
                    logger.warning(
                        f"The Kubernetes context '{self.kubernetes_context}' "
                        f"configured for the Kubernetes orchestrator is not "
                        f"the same as the active context in the local "
                        f"Kubernetes configuration. If this is not deliberate,"
                        f" you should update the orchestrator's "
                        f"`kubernetes_context` field by running:\n\n"
                        f"  `zenml orchestrator update {self.name} "
                        f"--kubernetes_context={active_context}`\n"
                        f"To list all configured contexts, run:\n\n"
                        f"  `kubectl config get-contexts`\n"
                        f"To set the active context to be the same as the one "
                        f"configured in the Kubernetes orchestrator and "
                        f"silence this warning, run:\n\n"
                        f"  `kubectl config use-context "
                        f"{self.kubernetes_context}`\n"
                    )

            # Check that all stack components are non-local.
            for stack_comp in stack.components.values():
                if stack_comp.local_path:
                    return False, (
                        f"The Kubernetes orchestrator currently only supports "
                        f"remote stacks, but the '{stack_comp.name}' "
                        f"{stack_comp.TYPE.value} is a local component. "
                        f"Please make sure to only use non-local stack "
                        f"components with a Kubernetes orchestrator."
                    )

            # if the orchestrator is remote, the container registry must
            # also be remote.
            if container_registry.is_local:
                return False, (
                    f"The Kubernetes orchestrator requires a remote container "
                    f"registry, but the '{container_registry.name}' container "
                    f"registry of your active stack points to a local URI "
                    f"'{container_registry.uri}'. Please make sure stacks "
                    f"with a Kubernetes orchestrator always contain remote "
                    f"container registries."
                )

            return True, ""

        return StackValidator(
            required_components={StackComponentType.CONTAINER_REGISTRY},
            custom_validation_function=_validate_local_requirements,
        )

    def get_docker_image_name(self, pipeline_name: str) -> str:
        """Return the full Docker image name including registry and tag.

        Args:
            pipeline_name: Name of a ZenML pipeline.

        Returns:
            Docker image name.
        """
        container_registry = Repository().active_stack.container_registry
        assert container_registry
        registry_uri = container_registry.uri.rstrip("/")
        return f"{registry_uri}/zenml-kubernetes:{pipeline_name}"

    def prepare_pipeline_deployment(
        self,
        pipeline: "BasePipeline",
        stack: "Stack",
        runtime_configuration: "RuntimeConfiguration",
    ) -> None:
        """Build a Docker image and upload it to the container registry.

        Args:
            pipeline: A ZenML pipeline.
            stack: A ZenML stack.
            runtime_configuration: The runtime configuration of the pipeline.
        """
        from zenml.utils import docker_utils

        image_name = self.get_docker_image_name(pipeline.name)

        requirements = {*stack.requirements(), *pipeline.requirements}

        logger.debug("Kubernetes container requirements: %s", requirements)

        docker_utils.build_docker_image(
            build_context_path=get_source_root_path(),
            image_name=image_name,
            dockerignore_path=pipeline.dockerignore_file,
            requirements=requirements,
            base_image=self.custom_docker_base_image_name,
        )

        assert stack.container_registry  # should never happen due to validation
        stack.container_registry.push_image(image_name)

        # Store the Docker image digest in the runtime configuration so it gets
        # tracked in the ZenStore
        image_digest = docker_utils.get_image_digest(image_name) or image_name
        runtime_configuration["docker_image"] = image_digest

    def prepare_or_run_pipeline(
        self,
        sorted_steps: List["BaseStep"],
        pipeline: "BasePipeline",
        pb2_pipeline: Pb2Pipeline,
        stack: "Stack",
        runtime_configuration: "RuntimeConfiguration",
    ) -> Any:
        """Run pipeline in Kubernetes.

        Args:
            sorted_steps: List of steps in execution order.
            pipeline: ZenML pipeline.
            pb2_pipeline: ZenML pipeline in TFX pb2 format.
            stack: ZenML stack.
            runtime_configuration: The runtime configuration of the pipeline.

        Raises:
            RuntimeError: If trying to run from a Jupyter notebook, or if a
                schedule without a CRON expression is given. In synchronous
                mode, `kube_utils.wait_pod` may also raise a RuntimeError if
                the orchestrator pod fails.
        """
        # First check whether the code is running in a notebook.
        if Environment.in_notebook():
            raise RuntimeError(
                "The Kubernetes orchestrator cannot run pipelines in a notebook "
                "environment. The reason is that it is non-trivial to create "
                "a Docker image of a notebook. Please consider refactoring "
                "your notebook cells into separate scripts in a Python module "
                "and run the code outside of a notebook when using this "
                "orchestrator."
            )

        assert runtime_configuration.run_name, "Run name must be set"

        # Validate the schedule before creating any cluster resources (e.g.
        # the service account below) so a misconfigured run fails cleanly.
        schedule = runtime_configuration.schedule
        if schedule and not schedule.cron_expression:
            raise RuntimeError(
                "The Kubernetes orchestrator only supports scheduling via "
                "CRON jobs, but the run was configured with a manual "
                "schedule. Use `Schedule(cron_expression=...)` instead."
            )

        for step in sorted_steps:
            if self.requires_resources_in_orchestration_environment(step):
                logger.warning(
                    "Specifying step resources is not yet supported for "
                    "the Kubernetes orchestrator, ignoring resource "
                    "configuration for step %s.",
                    step.name,
                )

        run_name = runtime_configuration.run_name
        pipeline_name = pipeline.name
        pod_name = kube_utils.sanitize_pod_name(run_name)

        # Get Docker image name (for all pods).
        image_name = self.get_docker_image_name(pipeline.name)
        image_name = get_image_digest(image_name) or image_name

        # Get pipeline DAG as dict {"step": ["upstream_step_1", ...], ...}
        pipeline_dag: Dict[str, List[str]] = {
            step.name: self.get_upstream_step_names(step, pb2_pipeline)
            for step in sorted_steps
        }

        # Build entrypoint command and args for the orchestrator pod.
        # This will internally also build the command/args for all step pods.
        command = (
            KubernetesOrchestratorEntrypointConfiguration.get_entrypoint_command()
        )
        args = KubernetesOrchestratorEntrypointConfiguration.get_entrypoint_arguments(
            run_name=run_name,
            pipeline_name=pipeline_name,
            image_name=image_name,
            kubernetes_namespace=self.kubernetes_namespace,
            pb2_pipeline=pb2_pipeline,
            sorted_steps=sorted_steps,
            pipeline_dag=pipeline_dag,
        )

        # Authorize pod to run Kubernetes commands inside the cluster.
        service_account_name = "zenml-service-account"
        kube_utils.create_edit_service_account(
            core_api=self._k8s_core_api,
            rbac_api=self._k8s_rbac_api,
            service_account_name=service_account_name,
            namespace=self.kubernetes_namespace,
        )

        # Schedule as CRON job if CRON schedule is given (validated above).
        if schedule:
            cron_expression = schedule.cron_expression
            cron_job_manifest = build_cron_job_manifest(
                cron_expression=cron_expression,
                run_name=run_name,
                pod_name=pod_name,
                pipeline_name=pipeline_name,
                image_name=image_name,
                command=command,
                args=args,
                service_account_name=service_account_name,
            )
            self._k8s_batch_api.create_namespaced_cron_job(
                body=cron_job_manifest, namespace=self.kubernetes_namespace
            )
            logger.info(
                f"Scheduling Kubernetes run `{pod_name}` with CRON expression "
                f'`"{cron_expression}"`.'
            )
            return

        # Create and run the orchestrator pod.
        pod_manifest = build_pod_manifest(
            run_name=run_name,
            pod_name=pod_name,
            pipeline_name=pipeline_name,
            image_name=image_name,
            command=command,
            args=args,
            service_account_name=service_account_name,
        )
        self._k8s_core_api.create_namespaced_pod(
            namespace=self.kubernetes_namespace,
            body=pod_manifest,
        )

        # Wait for the orchestrator pod to finish and stream logs.
        if self.synchronous:
            logger.info("Waiting for Kubernetes orchestrator pod...")
            kube_utils.wait_pod(
                core_api=self._k8s_core_api,
                pod_name=pod_name,
                namespace=self.kubernetes_namespace,
                exit_condition_lambda=kube_utils.pod_is_done,
                stream_logs=True,
            )
        else:
            logger.info(
                f"Orchestration started asynchronously in pod "
                f"`{self.kubernetes_namespace}:{pod_name}`. "
                f"Run the following command to inspect the logs: "
                f"`kubectl logs {pod_name} -n {self.kubernetes_namespace}`."
            )
validator: Optional[zenml.stack.stack_validator.StackValidator] property readonly

Defines the validator that checks whether the stack is valid.

Returns:

Type Description
Optional[zenml.stack.stack_validator.StackValidator]

Stack validator.

__init__(self, *args, **kwargs) special

Initialize the Pydantic object and the Kubernetes clients.

Parameters:

Name Type Description Default
*args Any

The positional arguments to pass to the Pydantic object.

()
**kwargs Any

The keyword arguments to pass to the Pydantic object.

{}
Source code in zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator.py
def __init__(self, *args: Any, **kwargs: Any) -> None:
    """Initialize the Pydantic object and the Kubernetes clients.

    Pydantic initialization runs first; the Kubernetes clients are created
    afterwards by `_initialize_k8s_clients`.

    Args:
        *args: Positional arguments forwarded to the Pydantic constructor.
        **kwargs: Keyword arguments forwarded to the Pydantic constructor.
    """
    super().__init__(*args, **kwargs)
    # Client creation depends on validated fields such as
    # `kubernetes_context` and `skip_config_loading`.
    self._initialize_k8s_clients()
get_docker_image_name(self, pipeline_name)

Return the full Docker image name including registry and tag.

Parameters:

Name Type Description Default
pipeline_name str

Name of a ZenML pipeline.

required

Returns:

Type Description
str

Docker image name.

Source code in zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator.py
def get_docker_image_name(self, pipeline_name: str) -> str:
    """Build the fully qualified Docker image name for a pipeline.

    The image is located in the active stack's container registry under the
    `zenml-kubernetes` repository and is tagged with the pipeline name.

    Args:
        pipeline_name: Name of a ZenML pipeline.

    Returns:
        Docker image name.
    """
    registry = Repository().active_stack.container_registry
    assert registry
    # Remove trailing slashes so joining with "/" yields a single separator.
    base_uri = registry.uri.rstrip("/")
    return f"{base_uri}/zenml-kubernetes:{pipeline_name}"
get_kubernetes_contexts(self)

Get list of configured Kubernetes contexts and the active context.

Exceptions:

Type Description
RuntimeError

if the Kubernetes configuration cannot be loaded.

Returns:

Type Description
Tuple[List[str], str]

A tuple of the list of configured Kubernetes context names and the name of the active Kubernetes context.

Source code in zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator.py
def get_kubernetes_contexts(self) -> Tuple[List[str], str]:
    """List the configured Kubernetes contexts and the active one.

    Raises:
        RuntimeError: if the Kubernetes configuration cannot be loaded.

    Returns:
        A tuple of the names of all configured Kubernetes contexts and the
        name of the active Kubernetes context.
    """
    try:
        all_contexts, current = k8s_config.list_kube_config_contexts()
    except k8s_config.config_exception.ConfigException as err:
        raise RuntimeError(
            "Could not load the Kubernetes configuration"
        ) from err

    names = [ctx["name"] for ctx in all_contexts]
    return names, current["name"]
prepare_or_run_pipeline(self, sorted_steps, pipeline, pb2_pipeline, stack, runtime_configuration)

Run pipeline in Kubernetes.

Parameters:

Name Type Description Default
sorted_steps List[BaseStep]

List of steps in execution order.

required
pipeline BasePipeline

ZenML pipeline.

required
pb2_pipeline Pipeline

ZenML pipeline in TFX pb2 format.

required
stack Stack

ZenML stack.

required
runtime_configuration RuntimeConfiguration

The runtime configuration of the pipeline.

required

Exceptions:

Type Description
RuntimeError

If trying to run from a Jupyter notebook, or if a schedule without a CRON expression is given.

Source code in zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator.py
def prepare_or_run_pipeline(
    self,
    sorted_steps: List["BaseStep"],
    pipeline: "BasePipeline",
    pb2_pipeline: Pb2Pipeline,
    stack: "Stack",
    runtime_configuration: "RuntimeConfiguration",
) -> Any:
    """Run pipeline in Kubernetes.

    Args:
        sorted_steps: List of steps in execution order.
        pipeline: ZenML pipeline.
        pb2_pipeline: ZenML pipeline in TFX pb2 format.
        stack: ZenML stack.
        runtime_configuration: The runtime configuration of the pipeline.

    Raises:
        RuntimeError: If trying to run from a Jupyter notebook, or if a
            schedule without a CRON expression is given. In synchronous
            mode, `kube_utils.wait_pod` may also raise a RuntimeError if the
            orchestrator pod fails.
    """
    # First check whether the code is running in a notebook.
    if Environment.in_notebook():
        raise RuntimeError(
            "The Kubernetes orchestrator cannot run pipelines in a notebook "
            "environment. The reason is that it is non-trivial to create "
            "a Docker image of a notebook. Please consider refactoring "
            "your notebook cells into separate scripts in a Python module "
            "and run the code outside of a notebook when using this "
            "orchestrator."
        )

    assert runtime_configuration.run_name, "Run name must be set"

    # Validate the schedule before creating any cluster resources (e.g. the
    # service account below) so a misconfigured run fails cleanly.
    schedule = runtime_configuration.schedule
    if schedule and not schedule.cron_expression:
        raise RuntimeError(
            "The Kubernetes orchestrator only supports scheduling via "
            "CRON jobs, but the run was configured with a manual "
            "schedule. Use `Schedule(cron_expression=...)` instead."
        )

    for step in sorted_steps:
        if self.requires_resources_in_orchestration_environment(step):
            logger.warning(
                "Specifying step resources is not yet supported for "
                "the Kubernetes orchestrator, ignoring resource "
                "configuration for step %s.",
                step.name,
            )

    run_name = runtime_configuration.run_name
    pipeline_name = pipeline.name
    pod_name = kube_utils.sanitize_pod_name(run_name)

    # Get Docker image name (for all pods).
    image_name = self.get_docker_image_name(pipeline.name)
    image_name = get_image_digest(image_name) or image_name

    # Get pipeline DAG as dict {"step": ["upstream_step_1", ...], ...}
    pipeline_dag: Dict[str, List[str]] = {
        step.name: self.get_upstream_step_names(step, pb2_pipeline)
        for step in sorted_steps
    }

    # Build entrypoint command and args for the orchestrator pod.
    # This will internally also build the command/args for all step pods.
    command = (
        KubernetesOrchestratorEntrypointConfiguration.get_entrypoint_command()
    )
    args = KubernetesOrchestratorEntrypointConfiguration.get_entrypoint_arguments(
        run_name=run_name,
        pipeline_name=pipeline_name,
        image_name=image_name,
        kubernetes_namespace=self.kubernetes_namespace,
        pb2_pipeline=pb2_pipeline,
        sorted_steps=sorted_steps,
        pipeline_dag=pipeline_dag,
    )

    # Authorize pod to run Kubernetes commands inside the cluster.
    service_account_name = "zenml-service-account"
    kube_utils.create_edit_service_account(
        core_api=self._k8s_core_api,
        rbac_api=self._k8s_rbac_api,
        service_account_name=service_account_name,
        namespace=self.kubernetes_namespace,
    )

    # Schedule as CRON job if CRON schedule is given (validated above).
    if schedule:
        cron_expression = schedule.cron_expression
        cron_job_manifest = build_cron_job_manifest(
            cron_expression=cron_expression,
            run_name=run_name,
            pod_name=pod_name,
            pipeline_name=pipeline_name,
            image_name=image_name,
            command=command,
            args=args,
            service_account_name=service_account_name,
        )
        self._k8s_batch_api.create_namespaced_cron_job(
            body=cron_job_manifest, namespace=self.kubernetes_namespace
        )
        logger.info(
            f"Scheduling Kubernetes run `{pod_name}` with CRON expression "
            f'`"{cron_expression}"`.'
        )
        return

    # Create and run the orchestrator pod.
    pod_manifest = build_pod_manifest(
        run_name=run_name,
        pod_name=pod_name,
        pipeline_name=pipeline_name,
        image_name=image_name,
        command=command,
        args=args,
        service_account_name=service_account_name,
    )
    self._k8s_core_api.create_namespaced_pod(
        namespace=self.kubernetes_namespace,
        body=pod_manifest,
    )

    # Wait for the orchestrator pod to finish and stream logs.
    if self.synchronous:
        logger.info("Waiting for Kubernetes orchestrator pod...")
        kube_utils.wait_pod(
            core_api=self._k8s_core_api,
            pod_name=pod_name,
            namespace=self.kubernetes_namespace,
            exit_condition_lambda=kube_utils.pod_is_done,
            stream_logs=True,
        )
    else:
        logger.info(
            f"Orchestration started asynchronously in pod "
            f"`{self.kubernetes_namespace}:{pod_name}`. "
            f"Run the following command to inspect the logs: "
            f"`kubectl logs {pod_name} -n {self.kubernetes_namespace}`."
        )
prepare_pipeline_deployment(self, pipeline, stack, runtime_configuration)

Build a Docker image and upload it to the container registry.

Parameters:

Name Type Description Default
pipeline BasePipeline

A ZenML pipeline.

required
stack Stack

A ZenML stack.

required
runtime_configuration RuntimeConfiguration

The runtime configuration of the pipeline.

required
Source code in zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator.py
def prepare_pipeline_deployment(
    self,
    pipeline: "BasePipeline",
    stack: "Stack",
    runtime_configuration: "RuntimeConfiguration",
) -> None:
    """Build the pipeline's Docker image and push it to the stack registry.

    The image digest, or the image name if no digest is returned, is stored
    under the `docker_image` key of the runtime configuration.

    Args:
        pipeline: A ZenML pipeline.
        stack: A ZenML stack.
        runtime_configuration: The runtime configuration of the pipeline.
    """
    from zenml.utils import docker_utils

    target_image = self.get_docker_image_name(pipeline.name)

    # De-duplicated union of stack-level and pipeline-level requirements.
    combined_requirements = set(stack.requirements()) | set(
        pipeline.requirements
    )

    logger.debug(
        "Kubernetes container requirements: %s", combined_requirements
    )

    docker_utils.build_docker_image(
        build_context_path=get_source_root_path(),
        image_name=target_image,
        dockerignore_path=pipeline.dockerignore_file,
        requirements=combined_requirements,
        base_image=self.custom_docker_base_image_name,
    )

    registry = stack.container_registry
    assert registry  # should never happen due to validation
    registry.push_image(target_image)

    # Track the exact image in the ZenStore; fall back to the tag-based name
    # when no digest is available.
    runtime_configuration["docker_image"] = (
        docker_utils.get_image_digest(target_image) or target_image
    )
kubernetes_orchestrator_entrypoint

Entrypoint of the Kubernetes master/orchestrator pod.

main()

Entrypoint of the k8s master/orchestrator pod.

Source code in zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator_entrypoint.py
def main() -> None:
    """Entrypoint of the k8s master/orchestrator pod.

    Parses the CLI arguments, then runs every pipeline step in its own
    Kubernetes pod, following the dependency DAG via `ThreadedDagRunner`.
    """
    # Log to the container's stdout so it can be streamed by the client.
    logger.info("Kubernetes orchestrator pod started.")

    # Unpack the CLI arguments and the JSON-decoded pipeline configuration.
    cli_args = parse_args()
    config = cli_args.pipeline_config
    step_command = config["step_command"]
    fixed_step_args = config["fixed_step_args"]
    step_specific_args = config["step_specific_args"]
    pipeline_dag = config["pipeline_dag"]

    # Kubernetes Core API client used to create and watch the step pods.
    kube_utils.load_kube_config()
    core_api = k8s_client.CoreV1Api()

    # Only relevant for CRON scheduling. Note that this also rewrites the run
    # name inside `fixed_step_args` in place.
    run_name = patch_run_name_for_cron_scheduling(
        cli_args.run_name, fixed_step_args
    )

    def run_step_on_kubernetes(step_name: str) -> None:
        """Create the pod for a single step and block until it is done.

        Args:
            step_name: Name of the step.
        """
        pod_name = kube_utils.sanitize_pod_name(f"{run_name}-{step_name}")

        # Shared args first, then the args specific to this step.
        step_args = [*fixed_step_args, *step_specific_args[step_name]]

        manifest = build_pod_manifest(
            pod_name=pod_name,
            run_name=run_name,
            pipeline_name=cli_args.pipeline_name,
            image_name=cli_args.image_name,
            command=step_command,
            args=step_args,
        )
        core_api.create_namespaced_pod(
            namespace=cli_args.kubernetes_namespace,
            body=manifest,
        )

        logger.info(f"Waiting for pod of step `{step_name}` to start...")
        kube_utils.wait_pod(
            core_api=core_api,
            pod_name=pod_name,
            namespace=cli_args.kubernetes_namespace,
            exit_condition_lambda=kube_utils.pod_is_done,
            stream_logs=True,
        )
        logger.info(f"Pod of step `{step_name}` completed.")

    dag_runner = ThreadedDagRunner(
        dag=pipeline_dag, run_fn=run_step_on_kubernetes
    )
    dag_runner.run()

    logger.info("Orchestration pod completed.")
parse_args()

Parse entrypoint arguments.

Returns:

Type Description
Namespace

Parsed args.

Source code in zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator_entrypoint.py
def parse_args() -> argparse.Namespace:
    """Parse the command line arguments of the orchestrator entrypoint.

    Every option is required; ``--pipeline_config`` is decoded from JSON.

    Returns:
        Parsed args.
    """
    parser = argparse.ArgumentParser()
    plain_string_options = (
        "--run_name",
        "--pipeline_name",
        "--image_name",
        "--kubernetes_namespace",
    )
    for option in plain_string_options:
        parser.add_argument(option, type=str, required=True)
    parser.add_argument("--pipeline_config", type=json.loads, required=True)
    return parser.parse_args()
patch_run_name_for_cron_scheduling(run_name, fixed_step_args)

Adjust run name according to the Kubernetes orchestrator pod name.

This is required for scheduling via CRON jobs, since each job would otherwise have the same run name, which ZenML does not support.

Parameters:

Name Type Description Default
run_name str

Initial run name.

required
fixed_step_args List[str]

Fixed entrypoint args for the step pods. We also need to patch the run name in there.

required

Returns:

Type Description
str

New unique run name.

Source code in zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator_entrypoint.py
def patch_run_name_for_cron_scheduling(
    run_name: str, fixed_step_args: List[str]
) -> str:
    """Make the run name unique when running as a scheduled CRON job.

    Every job spawned by the same CRON schedule would otherwise reuse the
    same run name, which ZenML does not support. A non-CRON orchestrator pod
    is named after the (sanitized) run name, and in that case nothing is
    changed.

    Args:
        run_name: Initial run name.
        fixed_step_args: Fixed entrypoint args for the step pods. The value
            following each ``--run_name`` entry is overwritten in place with
            the new run name.

    Returns:
        New unique run name.
    """
    # Name of the orchestrator pod.
    pod_hostname = socket.gethostname()

    if pod_hostname == kube_utils.sanitize_pod_name(run_name):
        # Not a CRON job: the run name is already unique.
        return run_name

    # Append the last dash-separated token of the pod name to the run name.
    suffix = pod_hostname.split("-")[-1]
    unique_run_name = f"{run_name}-{suffix}"

    # Iterate over the live list, because values are patched in place.
    for position, token in enumerate(fixed_step_args):
        if token == "--run_name":
            fixed_step_args[position + 1] = unique_run_name

    return unique_run_name
kubernetes_orchestrator_entrypoint_configuration

Entrypoint configuration for the Kubernetes master/orchestrator pod.

KubernetesOrchestratorEntrypointConfiguration

Entrypoint configuration for the k8s master/orchestrator pod.

Source code in zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator_entrypoint_configuration.py
class KubernetesOrchestratorEntrypointConfiguration:
    """Entrypoint configuration for the k8s master/orchestrator pod.

    Describes which CLI options the orchestrator entrypoint requires, which
    command starts it, and how the arguments for all step pods are packed
    into a single JSON-encoded pipeline config option.
    """

    @classmethod
    def get_entrypoint_options(cls) -> Set[str]:
        """Gets all the options required for running this entrypoint.

        Returns:
            Entrypoint options.
        """
        return {
            RUN_NAME_OPTION,
            PIPELINE_NAME_OPTION,
            IMAGE_NAME_OPTION,
            NAMESPACE_OPTION,
            PIPELINE_CONFIG_OPTION,
        }

    @classmethod
    def get_entrypoint_command(cls) -> List[str]:
        """Returns a command that runs the entrypoint module.

        Returns:
            Entrypoint command.
        """
        return [
            "python",
            "-m",
            "zenml.integrations.kubernetes.orchestrators.kubernetes_orchestrator_entrypoint",
        ]

    @classmethod
    def get_entrypoint_arguments(
        cls,
        run_name: str,
        pipeline_name: str,
        image_name: str,
        kubernetes_namespace: str,
        pb2_pipeline: Pb2Pipeline,
        sorted_steps: List[BaseStep],
        pipeline_dag: Dict[str, List[str]],
    ) -> List[str]:
        """Gets all arguments that the entrypoint command should be called with.

        The step command and the per-step arguments are serialized into one
        JSON string passed via the pipeline config option. Step arguments are
        split into "fixed" args (sent only once) and step-specific args.

        Args:
            run_name: Name of the ZenML run.
            pipeline_name: Name of the ZenML pipeline.
            image_name: Name of the Docker image.
            kubernetes_namespace: Name of the Kubernetes namespace.
            pb2_pipeline: ZenML pipeline in TFX pb2 format.
            sorted_steps: List of steps in execution order.
            pipeline_dag: For each step, list of steps that need to run before.

        Returns:
            List of entrypoint arguments.
        """

        def _get_step_args(step: BaseStep) -> List[str]:
            """Get the entrypoint args for a specific step.

            Args:
                step: ZenML step for which to get entrypoint args.

            Returns:
                Entrypoint args of the step.
            """
            return (
                KubernetesStepEntrypointConfiguration.get_entrypoint_arguments(
                    step=step,
                    pb2_pipeline=pb2_pipeline,
                    **{RUN_NAME_OPTION: run_name},
                )
            )

        # Get name, command, and args for each step
        step_names = [step.name for step in sorted_steps]
        step_command = (
            KubernetesStepEntrypointConfiguration.get_entrypoint_command()
        )

        # Build and split each step's entrypoint args exactly once, in
        # execution order: (step name, (fixed args, step-specific args)).
        split_args_per_step = [
            (step.name, split_step_args(_get_step_args(step)))
            for step in sorted_steps
        ]

        # Only the first step's fixed args are sent; the orchestrator pod
        # prepends them to every step's specific args, which assumes they are
        # identical for all steps.
        fixed_step_args: List[str] = (
            split_args_per_step[0][1][0] if split_args_per_step else []
        )
        step_specific_args = {
            name: specific_args
            for name, (_, specific_args) in split_args_per_step
        }  # e.g.: {"trainer": train_step_args, ...}

        # Serialize all complex datatype args into a single JSON string
        pipeline_config = {
            "sorted_steps": step_names,
            "step_command": step_command,
            "fixed_step_args": fixed_step_args,
            "step_specific_args": step_specific_args,
            "pipeline_dag": pipeline_dag,
        }
        pipeline_config_json = json.dumps(pipeline_config)

        # Define entrypoint args.
        return [
            f"--{RUN_NAME_OPTION}",
            run_name,
            f"--{PIPELINE_NAME_OPTION}",
            pipeline_name,
            f"--{IMAGE_NAME_OPTION}",
            image_name,
            f"--{NAMESPACE_OPTION}",
            kubernetes_namespace,
            f"--{PIPELINE_CONFIG_OPTION}",
            pipeline_config_json,
        ]
get_entrypoint_arguments(run_name, pipeline_name, image_name, kubernetes_namespace, pb2_pipeline, sorted_steps, pipeline_dag) classmethod

Gets all arguments that the entrypoint command should be called with.

Parameters:

Name Type Description Default
run_name str

Name of the ZenML run.

required
pipeline_name str

Name of the ZenML pipeline.

required
image_name str

Name of the Docker image.

required
kubernetes_namespace str

Name of the Kubernetes namespace.

required
pb2_pipeline Pipeline

ZenML pipeline in TFX pb2 format.

required
sorted_steps List[zenml.steps.base_step.BaseStep]

List of steps in execution order.

required
pipeline_dag Dict[str, List[str]]

For each step, list of steps that need to run before.

required

Returns:

Type Description
List[str]

List of entrypoint arguments.

Source code in zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator_entrypoint_configuration.py
@classmethod
def get_entrypoint_arguments(
    cls,
    run_name: str,
    pipeline_name: str,
    image_name: str,
    kubernetes_namespace: str,
    pb2_pipeline: Pb2Pipeline,
    sorted_steps: List[BaseStep],
    pipeline_dag: Dict[str, List[str]],
) -> List[str]:
    """Gets all arguments that the entrypoint command should be called with.

    The step command and the per-step arguments are serialized into one JSON
    string passed via the pipeline config option. Step arguments are split
    into "fixed" args (sent only once) and step-specific args.

    Args:
        run_name: Name of the ZenML run.
        pipeline_name: Name of the ZenML pipeline.
        image_name: Name of the Docker image.
        kubernetes_namespace: Name of the Kubernetes namespace.
        pb2_pipeline: ZenML pipeline in TFX pb2 format.
        sorted_steps: List of steps in execution order.
        pipeline_dag: For each step, list of steps that need to run before.

    Returns:
        List of entrypoint arguments.
    """

    def _get_step_args(step: BaseStep) -> List[str]:
        """Get the entrypoint args for a specific step.

        Args:
            step: ZenML step for which to get entrypoint args.

        Returns:
            Entrypoint args of the step.
        """
        return (
            KubernetesStepEntrypointConfiguration.get_entrypoint_arguments(
                step=step,
                pb2_pipeline=pb2_pipeline,
                **{RUN_NAME_OPTION: run_name},
            )
        )

    # Get name, command, and args for each step
    step_names = [step.name for step in sorted_steps]
    step_command = (
        KubernetesStepEntrypointConfiguration.get_entrypoint_command()
    )

    # Build and split each step's entrypoint args exactly once, in execution
    # order: (step name, (fixed args, step-specific args)).
    split_args_per_step = [
        (step.name, split_step_args(_get_step_args(step)))
        for step in sorted_steps
    ]

    # Only the first step's fixed args are sent; the orchestrator pod
    # prepends them to every step's specific args, which assumes they are
    # identical for all steps.
    fixed_step_args: List[str] = (
        split_args_per_step[0][1][0] if split_args_per_step else []
    )
    step_specific_args = {
        name: specific_args for name, (_, specific_args) in split_args_per_step
    }  # e.g.: {"trainer": train_step_args, ...}

    # Serialize all complex datatype args into a single JSON string
    pipeline_config = {
        "sorted_steps": step_names,
        "step_command": step_command,
        "fixed_step_args": fixed_step_args,
        "step_specific_args": step_specific_args,
        "pipeline_dag": pipeline_dag,
    }
    pipeline_config_json = json.dumps(pipeline_config)

    # Define entrypoint args.
    return [
        f"--{RUN_NAME_OPTION}",
        run_name,
        f"--{PIPELINE_NAME_OPTION}",
        pipeline_name,
        f"--{IMAGE_NAME_OPTION}",
        image_name,
        f"--{NAMESPACE_OPTION}",
        kubernetes_namespace,
        f"--{PIPELINE_CONFIG_OPTION}",
        pipeline_config_json,
    ]
get_entrypoint_command() classmethod

Returns a command that runs the entrypoint module.

Returns:

Type Description
List[str]

Entrypoint command.

Source code in zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator_entrypoint_configuration.py
@classmethod
def get_entrypoint_command(cls) -> List[str]:
    """Return the command that runs the orchestrator entrypoint module.

    Returns:
        Entrypoint command.
    """
    entrypoint_module = (
        "zenml.integrations.kubernetes.orchestrators."
        "kubernetes_orchestrator_entrypoint"
    )
    return ["python", "-m", entrypoint_module]
get_entrypoint_options() classmethod

Gets all the options required for running this entrypoint.

Returns:

Type Description
Set[str]

Entrypoint options.

Source code in zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator_entrypoint_configuration.py
@classmethod
def get_entrypoint_options(cls) -> Set[str]:
    """Return the names of all options this entrypoint requires.

    Returns:
        Entrypoint options.
    """
    required_options = (
        RUN_NAME_OPTION,
        PIPELINE_NAME_OPTION,
        IMAGE_NAME_OPTION,
        NAMESPACE_OPTION,
        PIPELINE_CONFIG_OPTION,
    )
    return set(required_options)
split_step_args(step_args)

Split step args into fixed and step-specific.

We want to have them separate so we can send the fixed args to the orchestrator pod only once.

Parameters:

Name Type Description Default
step_args List[str]

list of ALL step args. E.g. ["--arg1", "arg1_value", "--arg2", "arg2_value", ...].

required

Returns:

Type Description
Tuple[List[str], List[str]]

Tuple (fixed step args, step-specific args).

Source code in zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator_entrypoint_configuration.py
def split_step_args(step_args: List[str]) -> Tuple[List[str], List[str]]:
    """Partition step args into fixed and step-specific option/value pairs.

    Keeping them apart lets the fixed args be sent to the orchestrator pod
    only once. Every entry starting with ``--`` is treated as an option and
    is taken together with the entry that follows it. An option is
    step-specific when its name (without the leading ``--``) is in
    ``STEP_SPECIFIC_STEP_ENTRYPOINT_OPTIONS``; otherwise it is fixed.

    Args:
        step_args: list of ALL step args.
            E.g. ["--arg1", "arg1_value", "--arg2", "arg2_value", ...].

    Returns:
        Tuple (fixed step args, step-specific args).
    """
    fixed: List[str] = []
    specific: List[str] = []
    for position, token in enumerate(step_args):
        if token.startswith("--"):
            pair = step_args[position : position + 2]  # e.g. ["--name", "Aria"]
            if token[2:] in STEP_SPECIFIC_STEP_ENTRYPOINT_OPTIONS:
                specific += pair
            else:
                fixed += pair
    return fixed, specific
kubernetes_step_entrypoint_configuration

Entrypoint configuration for the Kubernetes worker/step pods.

KubernetesStepEntrypointConfiguration (StepEntrypointConfiguration)

Entrypoint configuration for running steps on Kubernetes.

Source code in zenml/integrations/kubernetes/orchestrators/kubernetes_step_entrypoint_configuration.py
class KubernetesStepEntrypointConfiguration(StepEntrypointConfiguration):
    """Step entrypoint configuration used by the Kubernetes step pods.

    Adds one custom option, the run name, so that every step pod of a
    pipeline run reports the same ZenML run name.
    """

    @classmethod
    def get_custom_entrypoint_options(cls) -> Set[str]:
        """Return the Kubernetes-specific entrypoint options.

        `RUN_NAME_OPTION` is required so that `get_run_name` yields the same
        value in all steps.

        Returns:
            Set of entrypoint options.
        """
        custom_options = {RUN_NAME_OPTION}
        return custom_options

    @classmethod
    def get_custom_entrypoint_arguments(
        cls, step: "BaseStep", *args: Any, **kwargs: Any
    ) -> List[str]:
        """Return the Kubernetes-specific entrypoint arguments.

        Produces the ``--<RUN_NAME_OPTION> <run name>`` pair.

        Args:
            step: ZenML step for which the entrypoint is built.
            args: additional (unused) arguments.
            kwargs: keyword args; needs to include `RUN_NAME_OPTION`.

        Returns:
            List of entrypoint arguments.
        """
        run_name_flag = f"--{RUN_NAME_OPTION}"
        return [run_name_flag, kwargs[RUN_NAME_OPTION]]

    def get_run_name(self, pipeline_name: str) -> str:
        """Return the ZenML run name passed via the entrypoint arguments.

        Args:
            pipeline_name: Name of the ZenML pipeline (unused).

        Returns:
            ZenML run name.
        """
        run_name: str = self.entrypoint_args[RUN_NAME_OPTION]
        return run_name
get_custom_entrypoint_arguments(step, *args, **kwargs) classmethod

Kubernetes specific entrypoint arguments.

Sets the value for the RUN_NAME_OPTION argument.

Parameters:

Name Type Description Default
step BaseStep

ZenML step for which the entrypoint is built.

required
args Any

additional (unused) arguments.

()
kwargs Any

keyword args; needs to include RUN_NAME_OPTION.

{}

Returns:

Type Description
List[str]

List of entrypoint arguments.

Source code in zenml/integrations/kubernetes/orchestrators/kubernetes_step_entrypoint_configuration.py
@classmethod
def get_custom_entrypoint_arguments(
    cls, step: "BaseStep", *args: Any, **kwargs: Any
) -> List[str]:
    """Return the Kubernetes-specific entrypoint arguments.

    Produces the ``--<RUN_NAME_OPTION> <run name>`` pair.

    Args:
        step: ZenML step for which the entrypoint is built.
        args: additional (unused) arguments.
        kwargs: keyword args; needs to include `RUN_NAME_OPTION`.

    Returns:
        List of entrypoint arguments.
    """
    run_name_flag = f"--{RUN_NAME_OPTION}"
    return [run_name_flag, kwargs[RUN_NAME_OPTION]]
get_custom_entrypoint_options() classmethod

Kubernetes specific entrypoint options.

The argument RUN_NAME_OPTION is needed for get_run_name to have consistent values between steps.

Returns:

Type Description
Set[str]

Set of entrypoint options.

Source code in zenml/integrations/kubernetes/orchestrators/kubernetes_step_entrypoint_configuration.py
@classmethod
def get_custom_entrypoint_options(cls) -> Set[str]:
    """Return the Kubernetes-specific entrypoint options.

    `RUN_NAME_OPTION` is required so that `get_run_name` yields the same
    value in all steps.

    Returns:
        Set of entrypoint options.
    """
    custom_options = {RUN_NAME_OPTION}
    return custom_options
get_run_name(self, pipeline_name)

Returns the ZenML run name.

Parameters:

Name Type Description Default
pipeline_name str

Name of the ZenML pipeline (unused).

required

Returns:

Type Description
str

ZenML run name.

Source code in zenml/integrations/kubernetes/orchestrators/kubernetes_step_entrypoint_configuration.py
def get_run_name(self, pipeline_name: str) -> str:
    """Return the ZenML run name passed via the entrypoint arguments.

    The run name comes from `RUN_NAME_OPTION` rather than being derived from
    the pipeline name, so `pipeline_name` is ignored.

    Args:
        pipeline_name: Name of the ZenML pipeline (unused).

    Returns:
        ZenML run name.
    """
    run_name: str = self.entrypoint_args[RUN_NAME_OPTION]
    return run_name
manifest_utils

Utility functions for building manifests for k8s pods.

build_cluster_role_binding_manifest_for_service_account(name, role_name, service_account_name, namespace='default')

Build a manifest for a cluster role binding of a service account.

Parameters:

Name Type Description Default
name str

Name of the cluster role binding.

required
role_name str

Name of the role.

required
service_account_name str

Name of the service account.

required
namespace str

Kubernetes namespace. Defaults to "default".

'default'

Returns:

Type Description
Dict[str, Any]

Manifest for a cluster role binding of a service account.

Source code in zenml/integrations/kubernetes/orchestrators/manifest_utils.py
def build_cluster_role_binding_manifest_for_service_account(
    name: str,
    role_name: str,
    service_account_name: str,
    namespace: str = "default",
) -> Dict[str, Any]:
    """Build a ClusterRoleBinding manifest that grants a role to a service account.

    Args:
        name: Name of the cluster role binding.
        role_name: Name of the role.
        service_account_name: Name of the service account.
        namespace: Kubernetes namespace. Defaults to "default".

    Returns:
        Manifest for a cluster role binding of a service account.
    """
    service_account_subject = {
        "kind": "ServiceAccount",
        "name": service_account_name,
        "namespace": namespace,
    }
    cluster_role_ref = {
        "kind": "ClusterRole",
        "name": role_name,
        "apiGroup": "rbac.authorization.k8s.io",
    }
    return {
        "apiVersion": "rbac.authorization.k8s.io/v1",
        "kind": "ClusterRoleBinding",
        "metadata": {"name": name},
        "subjects": [service_account_subject],
        "roleRef": cluster_role_ref,
    }
build_cron_job_manifest(cron_expression, pod_name, run_name, pipeline_name, image_name, command, args, service_account_name=None)

Create a manifest for launching a pod as scheduled CRON job.

Parameters:

Name Type Description Default
cron_expression str

CRON job schedule expression, e.g. "* * * * *".

required
pod_name str

Name of the pod.

required
run_name str

Name of the ZenML run.

required
pipeline_name str

Name of the ZenML pipeline.

required
image_name str

Name of the Docker image.

required
command List[str]

Command to execute the entrypoint in the pod.

required
args List[str]

Arguments provided to the entrypoint command.

required
service_account_name Optional[str]

Optional name of a service account. Can be used to assign certain roles to a pod, e.g., to allow it to run Kubernetes commands from within the cluster.

None

Returns:

Type Description
Dict[str, Any]

CRON job manifest.

Source code in zenml/integrations/kubernetes/orchestrators/manifest_utils.py
def build_cron_job_manifest(
    cron_expression: str,
    pod_name: str,
    run_name: str,
    pipeline_name: str,
    image_name: str,
    command: List[str],
    args: List[str],
    service_account_name: Optional[str] = None,
) -> Dict[str, Any]:
    """Create a manifest for launching a pod as scheduled CRON job.

    The job template wraps the same pod spec that `build_pod_manifest`
    produces. The pod metadata is reused (the very same dict) for both the
    CronJob and its job template.

    NOTE(review): `batch/v1beta1` CronJob is deprecated upstream and is not
    served by newer Kubernetes versions (1.25+). Moving to `batch/v1` likely
    also requires changing the client API that submits this manifest.
    Confirm before changing.

    Args:
        cron_expression: CRON job schedule expression, e.g. "* * * * *".
        pod_name: Name of the pod.
        run_name: Name of the ZenML run.
        pipeline_name: Name of the ZenML pipeline.
        image_name: Name of the Docker image.
        command: Command to execute the entrypoint in the pod.
        args: Arguments provided to the entrypoint command.
        service_account_name: Optional name of a service account.
            Can be used to assign certain roles to a pod, e.g., to allow it to
            run Kubernetes commands from within the cluster.

    Returns:
        CRON job manifest.
    """
    pod = build_pod_manifest(
        pod_name=pod_name,
        run_name=run_name,
        pipeline_name=pipeline_name,
        image_name=image_name,
        command=command,
        args=args,
        service_account_name=service_account_name,
    )
    pod_metadata = pod["metadata"]
    job_template = {
        "metadata": pod_metadata,
        "spec": {"template": {"spec": pod["spec"]}},
    }
    return {
        "apiVersion": "batch/v1beta1",
        "kind": "CronJob",
        "metadata": pod_metadata,
        "spec": {
            "schedule": cron_expression,
            "jobTemplate": job_template,
        },
    }
build_mysql_deployment_manifest(name='mysql', namespace='default', port=3306, pv_claim_name='mysql-pv-claim')

Build a manifest for deploying a MySQL database.

Parameters:

Name Type Description Default
name str

Name of the deployment. Defaults to "mysql".

'mysql'
namespace str

Kubernetes namespace. Defaults to "default".

'default'
port int

Port where MySQL is running. Defaults to 3306.

3306
pv_claim_name str

Name of the required persistent volume claim. Defaults to "mysql-pv-claim".

'mysql-pv-claim'

Returns:

Type Description
Dict[str, Any]

Manifest for deploying a MySQL database.

Source code in zenml/integrations/kubernetes/orchestrators/manifest_utils.py
def build_mysql_deployment_manifest(
    name: str = "mysql",
    namespace: str = "default",
    port: int = 3306,
    pv_claim_name: str = "mysql-pv-claim",
) -> Dict[str, Any]:
    """Build a manifest for deploying a MySQL database.

    The deployment runs one MySQL container labeled ``app: <name>``. The
    container allows an empty root password and stores its data in the
    given persistent volume claim, mounted at ``/var/lib/mysql``.

    Args:
        name: Name of the deployment. Defaults to "mysql".
        namespace: Kubernetes namespace. Defaults to "default".
        port: Port where MySQL is running. Defaults to 3306.
        pv_claim_name: Name of the required persistent volume claim.
            Defaults to `"mysql-pv-claim"`.

    Returns:
        Manifest for deploying a MySQL database.
    """
    return {
        "apiVersion": "apps/v1",
        "kind": "Deployment",
        "metadata": {"name": name, "namespace": namespace},
        "spec": {
            "selector": {
                "matchLabels": {
                    "app": name,
                },
            },
            "strategy": {
                "type": "Recreate",
            },
            "template": {
                "metadata": {
                    "labels": {"app": name},
                },
                "spec": {
                    "containers": [
                        {
                            "image": "gcr.io/ml-pipeline/mysql:5.6",
                            "name": name,
                            "env": [
                                {
                                    "name": "MYSQL_ALLOW_EMPTY_PASSWORD",
                                    # Plain string `true`. Unlike YAML, Python
                                    # needs no extra quoting, so the value must
                                    # not contain literal quote characters.
                                    "value": "true",
                                }
                            ],
                            "ports": [{"containerPort": port, "name": name}],
                            "volumeMounts": [
                                {
                                    "name": "mysql-persistent-storage",
                                    "mountPath": "/var/lib/mysql",
                                }
                            ],
                        }
                    ],
                    "volumes": [
                        {
                            "name": "mysql-persistent-storage",
                            "persistentVolumeClaim": {
                                "claimName": pv_claim_name
                            },
                        }
                    ],
                },
            },
        },
    }
build_mysql_service_manifest(name='mysql', namespace='default', port=3306)

Build a manifest for a service relating to a deployed MySQL database.

Parameters:

Name Type Description Default
name str

Name of the service. Defaults to "mysql".

'mysql'
namespace str

Kubernetes namespace. Defaults to "default".

'default'
port int

Port where MySQL is running. Defaults to 3306.

3306

Returns:

Type Description
Dict[str, Any]

Manifest for the MySQL service.

Source code in zenml/integrations/kubernetes/orchestrators/manifest_utils.py
def build_mysql_service_manifest(
    name: str = "mysql",
    namespace: str = "default",
    port: int = 3306,
    app_label: str = "mysql",
) -> Dict[str, Any]:
    """Build a manifest for a service relating to a deployed MySQL database.

    The service is headless (``clusterIP: "None"``) and selects pods by
    their ``app`` label.

    Args:
        name: Name of the service. Defaults to "mysql".
        namespace: Kubernetes namespace. Defaults to "default".
        port: Port where MySQL is running. Defaults to 3306.
        app_label: Value of the ``app`` label the service selects pods by.
            It must equal the ``name`` passed to
            `build_mysql_deployment_manifest`, which labels its pods
            ``app: <name>``. Defaults to "mysql".

    Returns:
        Manifest for the MySQL service.
    """
    return {
        "apiVersion": "v1",
        "kind": "Service",
        "metadata": {
            "name": name,
            "namespace": namespace,
        },
        "spec": {
            "selector": {"app": app_label},
            "clusterIP": "None",
            "ports": [{"port": port}],
        },
    }
build_namespace_manifest(namespace)

Build the manifest for a new namespace.

Parameters:

Name Type Description Default
namespace str

Kubernetes namespace.

required

Returns:

Type Description
Dict[str, Any]

Manifest of the new namespace.

Source code in zenml/integrations/kubernetes/orchestrators/manifest_utils.py
def build_namespace_manifest(namespace: str) -> Dict[str, Any]:
    """Build the manifest that creates a Kubernetes namespace.

    Args:
        namespace: Kubernetes namespace.

    Returns:
        Manifest of the new namespace.
    """
    namespace_metadata = {"name": namespace}
    return {
        "apiVersion": "v1",
        "kind": "Namespace",
        "metadata": namespace_metadata,
    }
build_persistent_volume_claim_manifest(name, namespace='default', storage_request='10Gi')

Build a manifest for a persistent volume claim.

Parameters:

Name Type Description Default
name str

Name of the persistent volume claim.

required
namespace str

Kubernetes namespace. Defaults to "default".

'default'
storage_request str

Size of the storage to request. Defaults to "10Gi".

'10Gi'

Returns:

Type Description
Dict[str, Any]

Manifest for a persistent volume claim.

Source code in zenml/integrations/kubernetes/orchestrators/manifest_utils.py
def build_persistent_volume_claim_manifest(
    name: str,
    namespace: str = "default",
    storage_request: str = "10Gi",
) -> Dict[str, Any]:
    """Build a PersistentVolumeClaim manifest.

    The claim requests ``ReadWriteOnce`` access with the ``manual`` storage
    class.

    Args:
        name: Name of the persistent volume claim.
        namespace: Kubernetes namespace. Defaults to "default".
        storage_request: Size of the storage to request. Defaults to `"10Gi"`.

    Returns:
        Manifest for a persistent volume claim.
    """
    claim_spec = {
        "storageClassName": "manual",
        "accessModes": ["ReadWriteOnce"],
        "resources": {"requests": {"storage": storage_request}},
    }
    return {
        "apiVersion": "v1",
        "kind": "PersistentVolumeClaim",
        "metadata": {"name": name, "namespace": namespace},
        "spec": claim_spec,
    }
build_persistent_volume_manifest(name, namespace='default', storage_capacity='10Gi', path='/mnt/data')

Build a manifest for a persistent volume.

Parameters:

Name Type Description Default
name str

Name of the persistent volume.

required
namespace str

Kubernetes namespace. Defaults to "default".

'default'
storage_capacity str

Storage capacity of the volume. Defaults to "10Gi".

'10Gi'
path str

Path where the volume is mounted. Defaults to "/mnt/data".

'/mnt/data'

Returns:

Type Description
Dict[str, Any]

Manifest for a persistent volume.

Source code in zenml/integrations/kubernetes/orchestrators/manifest_utils.py
def build_persistent_volume_manifest(
    name: str,
    namespace: str = "default",
    storage_capacity: str = "10Gi",
    path: str = "/mnt/data",
) -> Dict[str, Any]:
    """Build a PersistentVolume manifest backed by a host path.

    The volume uses the ``manual`` storage class, ``ReadWriteOnce`` access
    and the label ``type: local``.

    NOTE(review): PersistentVolumes are cluster-scoped in Kubernetes, so the
    ``namespace`` placed in the metadata is presumably ignored by the API
    server. Confirm.

    Args:
        name: Name of the persistent volume.
        namespace: Kubernetes namespace. Defaults to "default".
        storage_capacity: Storage capacity of the volume. Defaults to `"10Gi"`.
        path: Path where the volume is mounted. Defaults to `"/mnt/data"`.

    Returns:
        Manifest for a persistent volume.
    """
    volume_metadata = {
        "name": name,
        "namespace": namespace,
        "labels": {"type": "local"},
    }
    volume_spec = {
        "storageClassName": "manual",
        "capacity": {"storage": storage_capacity},
        "accessModes": ["ReadWriteOnce"],
        "hostPath": {"path": path},
    }
    return {
        "apiVersion": "v1",
        "kind": "PersistentVolume",
        "metadata": volume_metadata,
        "spec": volume_spec,
    }
build_pod_manifest(pod_name, run_name, pipeline_name, image_name, command, args, service_account_name=None)

Build a Kubernetes pod manifest for a ZenML run or step.

Parameters:

Name Type Description Default
pod_name str

Name of the pod.

required
run_name str

Name of the ZenML run.

required
pipeline_name str

Name of the ZenML pipeline.

required
image_name str

Name of the Docker image.

required
command List[str]

Command to execute the entrypoint in the pod.

required
args List[str]

Arguments provided to the entrypoint command.

required
service_account_name Optional[str]

Optional name of a service account. Can be used to assign certain roles to a pod, e.g., to allow it to run Kubernetes commands from within the cluster.

None

Returns:

Type Description
Dict[str, Any]

Pod manifest.

Source code in zenml/integrations/kubernetes/orchestrators/manifest_utils.py
def build_pod_manifest(
    pod_name: str,
    run_name: str,
    pipeline_name: str,
    image_name: str,
    command: List[str],
    args: List[str],
    service_account_name: Optional[str] = None,
) -> Dict[str, Any]:
    """Build a Kubernetes pod manifest for a ZenML run or step.

    The pod runs a single container named `main` and is never restarted.
    It is labelled with the run and pipeline names.

    Args:
        pod_name: Name of the pod.
        run_name: Name of the ZenML run.
        pipeline_name: Name of the ZenML pipeline.
        image_name: Name of the Docker image.
        command: Command to execute the entrypoint in the pod.
        args: Arguments provided to the entrypoint command.
        service_account_name: Optional name of a service account.
            Can be used to assign certain roles to a pod, e.g., to allow it to
            run Kubernetes commands from within the cluster. When `None`, no
            `serviceAccountName` key is written to the pod spec.

    Returns:
        Pod manifest.
    """
    main_container = {
        "name": "main",
        "image": image_name,
        "command": command,
        "args": args,
        # Sets ENV_ZENML_ENABLE_REPO_INIT_WARNINGS to "False" inside the pod
        # (presumably to silence repository-initialization warnings).
        "env": [
            {"name": ENV_ZENML_ENABLE_REPO_INIT_WARNINGS, "value": "False"},
        ],
    }

    pod_spec: Dict[str, Any] = {
        "restartPolicy": "Never",
        "containers": [main_container],
    }
    if service_account_name is not None:
        pod_spec["serviceAccountName"] = service_account_name

    return {
        "apiVersion": "v1",
        "kind": "Pod",
        "metadata": {
            "name": pod_name,
            "labels": {"run": run_name, "pipeline": pipeline_name},
        },
        "spec": pod_spec,
    }
build_service_account_manifest(name, namespace='default')

Build the manifest for a service account.

Parameters:

Name Type Description Default
name str

Name of the service account.

required
namespace str

Kubernetes namespace. Defaults to "default".

'default'

Returns:

Type Description
Dict[str, Any]

Manifest for a service account.

Source code in zenml/integrations/kubernetes/orchestrators/manifest_utils.py
def build_service_account_manifest(
    name: str, namespace: str = "default"
) -> Dict[str, Any]:
    """Build the manifest for a service account.

    Args:
        name: Name of the service account.
        namespace: Kubernetes namespace. Defaults to "default".

    Returns:
        Manifest for a service account.
    """
    return {
        "apiVersion": "v1",
        # The resource type must be declared explicitly, as it is for the
        # PersistentVolume and Pod manifests built in this module.
        "kind": "ServiceAccount",
        "metadata": {
            "name": name,
            "namespace": namespace,
        },
    }

label_studio special

Initialization of the Label Studio integration.

LabelStudioIntegration (Integration)

Definition of Label Studio integration for ZenML.

Source code in zenml/integrations/label_studio/__init__.py
class LabelStudioIntegration(Integration):
    """ZenML integration that provides the Label Studio annotator flavor."""

    NAME = LABEL_STUDIO
    REQUIREMENTS = ["label-studio", "label-studio-sdk"]

    @classmethod
    def flavors(cls) -> List[FlavorWrapper]:
        """Declare the stack component flavors for the Label Studio integration.

        Returns:
            A single-element list containing the Label Studio annotator
            flavor.
        """
        annotator_flavor = FlavorWrapper(
            name=LABEL_STUDIO_ANNOTATOR_FLAVOR,
            source="zenml.integrations.label_studio.annotators.LabelStudioAnnotator",
            type=StackComponentType.ANNOTATOR,
            integration=cls.NAME,
        )
        return [annotator_flavor]
flavors() classmethod

Declare the stack component flavors for the Label Studio integration.

Returns:

Type Description
List[zenml.zen_stores.models.flavor_wrapper.FlavorWrapper]

List of stack component flavors for this integration.

Source code in zenml/integrations/label_studio/__init__.py
@classmethod
def flavors(cls) -> List[FlavorWrapper]:
    """Declare the stack component flavors for the Label Studio integration.

    Returns:
        A single-element list containing the Label Studio annotator flavor.
    """
    annotator_flavor = FlavorWrapper(
        name=LABEL_STUDIO_ANNOTATOR_FLAVOR,
        source="zenml.integrations.label_studio.annotators.LabelStudioAnnotator",
        type=StackComponentType.ANNOTATOR,
        integration=cls.NAME,
    )
    return [annotator_flavor]

annotators special

Initialization of the Label Studio annotators submodule.

label_studio_annotator

Implementation of the Label Studio annotation integration.

LabelStudioAnnotator (BaseAnnotator, AuthenticationMixin) pydantic-model

Class to interact with the Label Studio annotation interface.

Attributes:

Name Type Description
port int

The port to use for the annotation interface.

api_key

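The API key to use for authentication. It is not stored as a field; it is read from the `api_key` entry of the annotator's authentication secret.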

Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
class LabelStudioAnnotator(BaseAnnotator, AuthenticationMixin):
    """Class to interact with the Label Studio annotation interface.

    The Label Studio API key is not stored on the component itself; it is
    read from the `api_key` entry of the component's authentication secret
    whenever a client is created.

    Attributes:
        port: The port to use for the annotation interface.
    """

    port: int = DEFAULT_LABEL_STUDIO_PORT

    FLAVOR: ClassVar[str] = LABEL_STUDIO_ANNOTATOR_FLAVOR

    @property
    def validator(self) -> Optional["StackValidator"]:
        """Validates that the stack contains a cloud artifact store.

        The stack must also contain a secrets manager (presumably used to
        resolve the authentication secret holding the API key).

        Returns:
            StackValidator: Validator for the stack.
        """

        def _ensure_cloud_artifact_stores(stack: Stack) -> Tuple[bool, str]:
            # For now this only works on cloud artifact stores.
            return (
                stack.artifact_store.FLAVOR
                in [
                    AZURE_ARTIFACT_STORE_FLAVOR,
                    GCP_ARTIFACT_STORE_FLAVOR,
                    S3_ARTIFACT_STORE_FLAVOR,
                ],
                "Only cloud artifact stores are currently supported",
            )

        return StackValidator(
            required_components={StackComponentType.SECRETS_MANAGER},
            custom_validation_function=_ensure_cloud_artifact_stores,
        )

    def get_url(self) -> str:
        """Gets the top-level URL of the annotation interface.

        Returns:
            The URL of the annotation interface.
        """
        return f"http://localhost:{self.port}"

    def get_url_for_dataset(self, dataset_name: str) -> str:
        """Gets the URL of the annotation interface for the given dataset.

        Args:
            dataset_name: The name of the dataset.

        Returns:
            The URL of the annotation interface.
        """
        # NOTE(review): if no project matches `dataset_name`, the id is None
        # and the returned URL ends in `/projects/None/`.
        project_id = self.get_id_from_name(dataset_name)
        return f"{self.get_url()}/projects/{project_id}/"

    def get_id_from_name(self, dataset_name: str) -> Optional[int]:
        """Gets the ID of the given dataset.

        Args:
            dataset_name: The name (exact project title) of the dataset.

        Returns:
            The ID of the dataset, or None if no project has that title.
        """
        for project in self.get_datasets():
            params = project.get_params()
            if params["title"] == dataset_name:
                return cast(int, params["id"])
        return None

    def get_datasets(self) -> List[Any]:
        """Gets the datasets currently available for annotation.

        Returns:
            A list of datasets.
        """
        datasets = self._get_client().get_projects()
        return cast(List[Any], datasets)

    def get_dataset_names(self) -> List[str]:
        """Gets the names of the datasets.

        Returns:
            A list of dataset names.
        """
        return [
            dataset.get_params()["title"] for dataset in self.get_datasets()
        ]

    def get_dataset_stats(self, dataset_name: str) -> Tuple[int, int]:
        """Gets the statistics of the given dataset.

        Args:
            dataset_name: The name (exact project title) of the dataset.

        Returns:
            A tuple containing (labeled_task_count, unlabeled_task_count) for
                the dataset.

        Raises:
            IndexError: If the dataset does not exist.
        """
        for project in self.get_datasets():
            # Exact match, consistent with `get_id_from_name`; a substring
            # match could select a different project whose title merely
            # contains `dataset_name`.
            if project.get_params()["title"] == dataset_name:
                labeled_task_count = len(project.get_labeled_tasks())
                unlabeled_task_count = len(project.get_unlabeled_tasks())
                return (labeled_task_count, unlabeled_task_count)
        raise IndexError(
            f"Dataset {dataset_name} not found. Please use "
            f"`zenml annotator dataset list` to list all available datasets."
        )

    def launch(self, url: Optional[str]) -> None:
        """Launches the annotation interface.

        Args:
            url: The URL of the annotation interface. Falls back to
                `get_url()` when empty or None.
        """
        if not url:
            url = self.get_url()
        if self._connection_available():
            webbrowser.open(url, new=1, autoraise=True)
        else:
            logger.warning(
                "Could not launch annotation interface "
                "because the connection could not be established."
            )

    def _get_client(self) -> Client:
        """Gets Label Studio client.

        Returns:
            Label Studio client.

        Raises:
            ValueError: when unable to access the Label Studio API key.
        """
        secret = self.get_authentication_secret(ArbitrarySecretSchema)
        if not secret:
            # The secret is falsy here, so interpolating it would only ever
            # print something like 'None'.
            raise ValueError(
                "Unable to access the authentication secret configured for "
                "this annotator; it is required to read the Label Studio "
                "API key."
            )
        try:
            api_key = secret.content["api_key"]
        except KeyError as e:
            raise ValueError(
                "The authentication secret configured for this annotator "
                "has no `api_key` entry holding the Label Studio API key."
            ) from e
        return Client(url=self.get_url(), api_key=api_key)

    def _connection_available(self) -> bool:
        """Checks if the connection to the annotation server is available.

        Returns:
            True if the connection is available, False otherwise.
        """
        try:
            result = self._get_client().check_connection()
            return result.get("status") == "UP"  # type: ignore[no-any-return]
        # TODO: [HIGH] refactor to use a more specific exception
        except Exception:
            logger.error(
                "Connection error: No connection was able to be established to the Label Studio backend."
            )
            return False

    def add_dataset(self, **kwargs: Any) -> Any:
        """Registers a dataset for annotation.

        Args:
            **kwargs: Additional keyword arguments to pass to the Label Studio client.

        Returns:
            A Label Studio Project object.

        Raises:
            ValueError: if 'dataset_name' or 'label_config' isn't provided.
        """
        dataset_name = kwargs.get("dataset_name")
        label_config = kwargs.get("label_config")
        if not dataset_name:
            raise ValueError("`dataset_name` keyword argument is required.")
        elif not label_config:
            raise ValueError("`label_config` keyword argument is required.")

        return self._get_client().start_project(
            title=dataset_name,
            label_config=label_config,
        )

    def delete_dataset(self, **kwargs: Any) -> None:
        """Deletes a dataset from the annotation interface.

        Args:
            **kwargs: Additional keyword arguments to pass to the Label Studio
                client.

        Raises:
            NotImplementedError: If the deletion of a dataset is not supported.
        """
        raise NotImplementedError("Awaiting Label Studio release.")
        # TODO: Awaiting a new Label Studio version to be released with this method
        # ls = self._get_client()
        # dataset_name = kwargs.get("dataset_name")
        # if not dataset_name:
        #     raise ValueError("`dataset_name` keyword argument is required.")

        # dataset_id = self.get_id_from_name(dataset_name)
        # if not dataset_id:
        #     raise ValueError(
        #         f"Dataset name '{dataset_name}' has no corresponding `dataset_id` in Label Studio."
        #     )
        # ls.delete_project(dataset_id)

    def get_dataset(self, **kwargs: Any) -> Any:
        """Gets the dataset with the given name.

        Args:
            **kwargs: Additional keyword arguments to pass to the Label Studio
                client. Only `dataset_name` is read.

        Returns:
            The LabelStudio Dataset object (a 'Project') for the given name.

        Raises:
            ValueError: If the dataset name is not provided or if the dataset
                does not exist.
        """
        # TODO: check for and raise error if client unavailable
        dataset_name = kwargs.get("dataset_name")
        if not dataset_name:
            raise ValueError("`dataset_name` keyword argument is required.")

        dataset_id = self.get_id_from_name(dataset_name)
        if dataset_id is None:
            raise ValueError(
                f"Dataset name '{dataset_name}' has no corresponding `dataset_id` in Label Studio."
            )
        return self._get_client().get_project(dataset_id)

    def get_converted_dataset(
        self, dataset_name: str, output_format: str
    ) -> Dict[Any, Any]:
        """Extract annotated tasks in a specific converted format.

        Args:
            dataset_name: Name of the dataset.
            output_format: Output format.

        Returns:
            A dictionary containing the converted dataset.
        """
        project = self.get_dataset(dataset_name=dataset_name)
        return project.export_tasks(export_type=output_format)  # type: ignore[no-any-return]

    def get_labeled_data(self, **kwargs: Any) -> Any:
        """Gets the labeled data for the given dataset.

        Args:
            **kwargs: Additional keyword arguments to pass to the Label Studio
                client. Only `dataset_name` is read.

        Returns:
            The labeled data.

        Raises:
            ValueError: If the dataset name is not provided or if the dataset
                does not exist.
        """
        # `get_dataset` performs the name validation and id lookup.
        return self.get_dataset(**kwargs).get_labeled_tasks()

    def get_unlabeled_data(self, **kwargs: str) -> Any:
        """Gets the unlabeled data for the given dataset.

        Args:
            **kwargs: Additional keyword arguments to pass to the Label Studio
                client. Only `dataset_name` is read.

        Returns:
            The unlabeled data.

        Raises:
            ValueError: If the dataset name is not provided or if the dataset
                does not exist.
        """
        # `get_dataset` performs the name validation and id lookup.
        return self.get_dataset(**kwargs).get_unlabeled_tasks()

    def register_dataset_for_annotation(
        self,
        config: LabelStudioDatasetRegistrationConfig,
    ) -> Any:
        """Registers a dataset for annotation.

        Reuses the existing project with the configured name if there is one,
        otherwise creates a new project.

        Args:
            config: Configuration for the dataset.

        Returns:
            A Label Studio Project object.
        """
        project_id = self.get_id_from_name(config.dataset_name)
        if project_id is not None:
            return self._get_client().get_project(project_id)
        return self.add_dataset(
            dataset_name=config.dataset_name,
            label_config=config.label_config,
        )

    def _get_import_storage_sources(
        self, storage_type: str, dataset_id: int
    ) -> List[Dict[str, Any]]:
        """Gets the import storage sources of one storage type for a dataset.

        Args:
            storage_type: Storage path segment of the Label Studio API
                (`azure`, `gcs` or `s3`).
            dataset_id: Id of the dataset.

        Returns:
            A list of import storage sources.

        Raises:
            ConnectionError: If the connection to the Label Studio backend is unavailable.
        """
        # TODO: check if client actually is connected etc
        query_url = f"/api/storages/{storage_type}?project={dataset_id}"
        response = self._get_client().make_request(method="GET", url=query_url)
        if response.status_code == 200:
            return cast(List[Dict[str, Any]], response.json())
        raise ConnectionError(
            f"Unable to get list of import storage sources. Client raised HTTP error {response.status_code}."
        )

    def _get_azure_import_storage_sources(
        self, dataset_id: int
    ) -> List[Dict[str, Any]]:
        """Gets a list of all Azure import storage sources.

        Args:
            dataset_id: Id of the dataset.

        Returns:
            A list of Azure import storage sources.

        Raises:
            ConnectionError: If the connection to the Label Studio backend is unavailable.
        """
        return self._get_import_storage_sources("azure", dataset_id)

    def _get_gcs_import_storage_sources(
        self, dataset_id: int
    ) -> List[Dict[str, Any]]:
        """Gets a list of all Google Cloud Storage import storage sources.

        Args:
            dataset_id: Id of the dataset.

        Returns:
            A list of Google Cloud Storage import storage sources.

        Raises:
            ConnectionError: If the connection to the Label Studio backend is unavailable.
        """
        return self._get_import_storage_sources("gcs", dataset_id)

    def _get_s3_import_storage_sources(
        self, dataset_id: int
    ) -> List[Dict[str, Any]]:
        """Gets a list of all AWS S3 import storage sources.

        Args:
            dataset_id: Id of the dataset.

        Returns:
            A list of AWS S3 import storage sources.

        Raises:
            ConnectionError: If the connection to the Label Studio backend is unavailable.
        """
        return self._get_import_storage_sources("s3", dataset_id)

    def _storage_source_already_exists(
        self, uri: str, config: LabelStudioDatasetSyncConfig, dataset: Project
    ) -> bool:
        """Returns whether a storage source already exists.

        A source matches when its bucket, presign settings, regex filter,
        blob-URL flag, title, description and project all equal the values
        derived from `uri`, `config` and `dataset`.

        Args:
            uri: URI of the storage source.
            config: Configuration for the dataset.
            dataset: Label Studio dataset.

        Returns:
            True if the storage source already exists, False otherwise.

        Raises:
            NotImplementedError: If the storage source type is not supported.
        """
        # TODO: check we are already connected
        dataset_id = int(dataset.get_params()["id"])
        if config.storage_type == "azure":
            storage_sources = self._get_azure_import_storage_sources(dataset_id)
        elif config.storage_type == "gcs":
            storage_sources = self._get_gcs_import_storage_sources(dataset_id)
        elif config.storage_type == "s3":
            storage_sources = self._get_s3_import_storage_sources(dataset_id)
        else:
            raise NotImplementedError(
                f"Storage type '{config.storage_type}' not implemented."
            )
        return any(
            (
                source.get("presign") == config.presign
                and source.get("bucket") == uri
                and source.get("regex_filter") == config.regex_filter
                and source.get("use_blob_urls") == config.use_blob_urls
                and source.get("title") == dataset.get_params()["title"]
                and source.get("description") == config.description
                and source.get("presign_ttl") == config.presign_ttl
                and source.get("project") == dataset_id
            )
            for source in storage_sources
        )

    def get_parsed_label_config(self, dataset_id: int) -> Dict[str, Any]:
        """Returns the parsed Label Studio label config for a dataset.

        Args:
            dataset_id: Id of the dataset.

        Returns:
            A dictionary containing the parsed label config.

        Raises:
            ValueError: If no dataset is found for the given id.
        """
        # TODO: check if client actually is connected etc
        dataset = self._get_client().get_project(dataset_id)
        if dataset:
            return cast(Dict[str, Any], dataset.parsed_label_config)
        raise ValueError("No dataset found for the given id.")

    def connect_and_sync_external_storage(
        self,
        uri: str,
        config: LabelStudioDatasetSyncConfig,
        dataset: Project,
    ) -> Optional[Dict[str, Any]]:
        """Syncs the external storage for the given project.

        Args:
            uri: URI of the storage source.
            config: Configuration for the dataset.
            dataset: Label Studio dataset.

        Returns:
            A dictionary containing the sync result.

        Raises:
            ValueError: If the storage type is not supported.
        """
        # TODO: check if proposed storage source has differing / new data
        # if self._storage_source_already_exists(uri, config, dataset):
        #     return None

        storage_connection_args = {
            "prefix": config.prefix,
            "regex_filter": config.regex_filter,
            "use_blob_urls": config.use_blob_urls,
            "presign": config.presign,
            "presign_ttl": config.presign_ttl,
            "title": dataset.get_params()["title"],
            "description": config.description,
        }
        if config.storage_type == "azure":
            if not config.azure_account_name or not config.azure_account_key:
                logger.warning(
                    "Authentication credentials for Azure aren't fully "
                    "provided. Please update the storage synchronization "
                    "settings in the Label Studio web UI as per your needs."
                )
            storage = dataset.connect_azure_import_storage(
                container=uri,
                account_name=config.azure_account_name,
                account_key=config.azure_account_key,
                **storage_connection_args,
            )
        elif config.storage_type == "gcs":
            if not config.google_application_credentials:
                logger.warning(
                    "Authentication credentials for Google Cloud Storage "
                    "aren't fully provided. Please update the storage "
                    "synchronization settings in the Label Studio web UI as "
                    "per your needs."
                )
            storage = dataset.connect_google_import_storage(
                bucket=uri,
                google_application_credentials=config.google_application_credentials,
                **storage_connection_args,
            )
        elif config.storage_type == "s3":
            if not config.aws_access_key_id or not config.aws_secret_access_key:
                logger.warning(
                    "Authentication credentials for S3 aren't fully provided. "
                    "Please update the storage synchronization settings in the "
                    "Label Studio web UI as per your needs."
                )
            storage = dataset.connect_s3_import_storage(
                bucket=uri,
                aws_access_key_id=config.aws_access_key_id,
                aws_secret_access_key=config.aws_secret_access_key,
                aws_session_token=config.aws_session_token,
                region_name=config.s3_region_name,
                s3_endpoint=config.s3_endpoint,
                **storage_connection_args,
            )
        else:
            # The accepted values are exactly the branches above.
            raise ValueError(
                f"Invalid storage type. '{config.storage_type}' is not supported by ZenML's Label Studio integration. Please choose between 'azure', 'gcs' and 's3'."
            )

        synced_storage = self._get_client().sync_storage(
            storage_id=storage["id"], storage_type=storage["type"]
        )
        return cast(Dict[str, Any], synced_storage)

    @property
    def root_directory(self) -> str:
        """Returns path to the root directory.

        Returns:
            Path to the root directory.
        """
        return os.path.join(
            io_utils.get_global_config_directory(),
            "annotators",
            str(self.uuid),
        )

    @property
    def _pid_file_path(self) -> str:
        """Returns path to the daemon PID file.

        Returns:
            Path to the daemon PID file.
        """
        return os.path.join(self.root_directory, "label_studio_daemon.pid")

    @property
    def _log_file(self) -> str:
        """Path of the daemon log file.

        Returns:
            Path to the daemon log file.
        """
        return os.path.join(self.root_directory, "label_studio_daemon.log")

    @property
    def is_provisioned(self) -> bool:
        """If the component provisioned resources to run locally.

        Returns:
            True if the component provisioned resources to run locally.
        """
        return fileio.exists(self.root_directory)

    @property
    def is_running(self) -> bool:
        """If the component is running locally.

        Returns:
            True if the component is running locally, False otherwise. On
            Windows this is always True because no daemon is tracked there.
        """
        if sys.platform != "win32":
            from zenml.utils.daemon import check_if_daemon_is_running

            if not check_if_daemon_is_running(self._pid_file_path):
                return False
        else:
            # Daemon functionality is not supported on Windows, so the PID
            # file won't exist. This if clause exists just for mypy to not
            # complain about missing functions
            pass

        return True

    def provision(self) -> None:
        """Spins up the annotation server backend."""
        fileio.makedirs(self.root_directory)

    def deprovision(self) -> None:
        """Spins down the annotation server backend."""
        # NOTE(review): only the log file is removed; the root directory is
        # kept, so `is_provisioned` still reports True afterwards.
        if fileio.exists(self._log_file):
            fileio.remove(self._log_file)

    def resume(self) -> None:
        """Resumes the annotation interface."""
        if self.is_running:
            logger.info("Local annotation server is already running.")
            return

        self.start_annotator_daemon()

    def suspend(self) -> None:
        """Suspends the annotation interface."""
        if not self.is_running:
            logger.info("Local annotation server is not running.")
            return

        self.stop_annotator_daemon()

    def start_annotator_daemon(self) -> None:
        """Starts the annotation server backend.

        Raises:
            ProvisioningError: If the annotation server backend is already
                running or the port is already occupied.
        """
        command = [
            "label-studio",
            "start",
            "--no-browser",
            "--port",
            f"{self.port}",
        ]

        if sys.platform == "win32":
            # One placeholder, one argument: passing extra arguments makes
            # the logging module fail while formatting the record.
            logger.warning(
                "Daemon functionality not supported on Windows. "
                "In order to access the Label Studio server locally, "
                "please run '%s' in a separate command line shell.",
                " ".join(command),
            )
        elif not networking_utils.port_available(self.port):
            raise ProvisioningError(
                f"Unable to port-forward Label Studio to local "
                f"port {self.port} because the port is occupied. In order to "
                f"access Label Studio locally, please "
                f"change the configuration to use an available "
                f"port or stop the other process currently using the port."
            )
        else:
            from zenml.utils import daemon

            def _daemon_function() -> None:
                """Runs the Label Studio server command inside the daemon."""
                subprocess.check_call(command)

            daemon.run_as_daemon(
                _daemon_function,
                pid_file=self._pid_file_path,
                log_file=self._log_file,
            )
            logger.info(
                "Started Label Studio daemon (check the daemon "
                "logs at `%s` in case you're not able to access the annotation "
                "interface). Please visit `%s/` to use the Label Studio interface.",
                self._log_file,
                self.get_url(),
            )

    def stop_annotator_daemon(self) -> None:
        """Stops the annotation server backend."""
        if fileio.exists(self._pid_file_path):
            if sys.platform == "win32":
                # Daemon functionality is not supported on Windows, so the PID
                # file won't exist. This if clause exists just for mypy to not
                # complain about missing functions
                pass
            else:
                from zenml.utils import daemon

                daemon.stop_daemon(self._pid_file_path)
                fileio.remove(self._pid_file_path)
is_provisioned: bool property readonly

If the component provisioned resources to run locally.

Returns:

Type Description
bool

True if the component provisioned resources to run locally.

is_running: bool property readonly

If the component is running locally.

Returns:

Type Description
bool

True if the component is running locally, False otherwise.

root_directory: str property readonly

Returns path to the root directory.

Returns:

Type Description
str

Path to the root directory.

validator: Optional[StackValidator] property readonly

Validates that the stack contains a cloud artifact store and a secrets manager.

Returns:

Type Description
StackValidator

Validator for the stack.

add_dataset(self, **kwargs)

Registers a dataset for annotation.

Parameters:

Name Type Description Default
**kwargs Any

Additional keyword arguments to pass to the Label Studio client.

{}

Returns:

Type Description
Any

A Label Studio Project object.

Exceptions:

Type Description
ValueError

If 'dataset_name' or 'label_config' isn't provided.

Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def add_dataset(self, **kwargs: Any) -> Any:
    """Create a Label Studio project that holds a dataset for annotation.

    Args:
        **kwargs: Keyword arguments; `dataset_name` is used as the project
            title and `label_config` as its label configuration. Both must
            be present and non-empty.

    Returns:
        A Label Studio Project object.

    Raises:
        ValueError: If `dataset_name` or `label_config` is missing or empty
            (`dataset_name` is checked first).
    """
    for required_key in ("dataset_name", "label_config"):
        if not kwargs.get(required_key):
            raise ValueError(f"`{required_key}` keyword argument is required.")

    client = self._get_client()
    return client.start_project(
        title=kwargs["dataset_name"],
        label_config=kwargs["label_config"],
    )
connect_and_sync_external_storage(self, uri, config, dataset)

Syncs the external storage for the given project.

Parameters:

Name Type Description Default
uri str

URI of the storage source.

required
config LabelStudioDatasetSyncConfig

Configuration for the dataset.

required
dataset Project

Label Studio dataset.

required

Returns:

Type Description
Optional[Dict[str, Any]]

A dictionary containing the sync result.

Exceptions:

Type Description
ValueError

If the storage type is not supported.

Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def connect_and_sync_external_storage(
    self,
    uri: str,
    config: LabelStudioDatasetSyncConfig,
    dataset: Project,
) -> Optional[Dict[str, Any]]:
    """Syncs the external storage for the given project.

    Attaches an import storage of type `config.storage_type` ("azure",
    "gcs" or "s3") that points at `uri` to `dataset`, then asks the Label
    Studio client to sync it. Missing credentials only produce a warning;
    the storage is still connected so its settings can be completed in the
    Label Studio web UI.

    Args:
        uri: URI of the storage source (passed as the Azure container or the
            GCS / S3 bucket).
        config: Configuration for the dataset.
        dataset: Label Studio dataset.

    Returns:
        A dictionary containing the sync result.

    Raises:
        ValueError: If the storage type is not supported.
    """
    # TODO: check if proposed storage source has differing / new data
    # if self._storage_source_already_exists(uri, config, dataset):
    #     return None

    # Settings shared by every storage backend.
    storage_connection_args = {
        "prefix": config.prefix,
        "regex_filter": config.regex_filter,
        "use_blob_urls": config.use_blob_urls,
        "presign": config.presign,
        "presign_ttl": config.presign_ttl,
        "title": dataset.get_params()["title"],
        "description": config.description,
    }
    if config.storage_type == "azure":
        if not config.azure_account_name or not config.azure_account_key:
            logger.warning(
                "Authentication credentials for Azure aren't fully "
                "provided. Please update the storage synchronization "
                "settings in the Label Studio web UI as per your needs."
            )
        storage = dataset.connect_azure_import_storage(
            container=uri,
            account_name=config.azure_account_name,
            account_key=config.azure_account_key,
            **storage_connection_args,
        )
    elif config.storage_type == "gcs":
        if not config.google_application_credentials:
            logger.warning(
                "Authentication credentials for Google Cloud Storage "
                "aren't fully provided. Please update the storage "
                "synchronization settings in the Label Studio web UI as "
                "per your needs."
            )
        storage = dataset.connect_google_import_storage(
            bucket=uri,
            google_application_credentials=config.google_application_credentials,
            **storage_connection_args,
        )
    elif config.storage_type == "s3":
        if not config.aws_access_key_id or not config.aws_secret_access_key:
            logger.warning(
                "Authentication credentials for S3 aren't fully provided. "
                "Please update the storage synchronization settings in the "
                "Label Studio web UI as per your needs."
            )
        storage = dataset.connect_s3_import_storage(
            bucket=uri,
            aws_access_key_id=config.aws_access_key_id,
            aws_secret_access_key=config.aws_secret_access_key,
            aws_session_token=config.aws_session_token,
            region_name=config.s3_region_name,
            s3_endpoint=config.s3_endpoint,
            **storage_connection_args,
        )
    else:
        # List the exact values accepted by the branches above.
        raise ValueError(
            f"Invalid storage type. '{config.storage_type}' is not supported "
            "by ZenML's Label Studio integration. Please choose between "
            "'azure', 'gcs' and 's3'."
        )

    synced_storage = self._get_client().sync_storage(
        storage_id=storage["id"], storage_type=storage["type"]
    )
    return cast(Dict[str, Any], synced_storage)
delete_dataset(self, **kwargs)

Deletes a dataset from the annotation interface.

Parameters:

Name Type Description Default
**kwargs Any

Additional keyword arguments to pass to the Label Studio client.

{}

Exceptions:

Type Description
NotImplementedError

If the deletion of a dataset is not supported.

Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def delete_dataset(self, **kwargs: Any) -> None:
    """Deletes a dataset from the annotation interface.

    Dataset deletion is not available yet, so this method always raises,
    whatever arguments it receives.

    Args:
        **kwargs: Additional keyword arguments to pass to the Label Studio
            client (currently ignored).

    Raises:
        NotImplementedError: Always, because deleting a dataset is not
            supported yet.
    """
    # TODO: Awaiting a new Label Studio version to be released with a
    # project deletion method. The intended implementation resolves
    # `dataset_name` through `self.get_id_from_name(...)`, raises
    # `ValueError` when the name is missing or unknown, and then calls
    # `delete_project(dataset_id)` on the client.
    message = "Awaiting Label Studio release."
    raise NotImplementedError(message)
deprovision(self)

Spins down the annotation server backend.

Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def deprovision(self) -> None:
    """Spins down the annotation server backend.

    The only cleanup done here is deleting the annotator's log file when it
    exists; nothing happens if it is absent.
    """
    log_path = self._log_file
    if not fileio.exists(log_path):
        return
    fileio.remove(log_path)
get_converted_dataset(self, dataset_name, output_format)

Extract annotated tasks in a specific converted format.

Parameters:

Name Type Description Default
dataset_name str

Name of the dataset.

required
output_format str

Output format.

required

Returns:

Type Description
Dict[Any, Any]

A dictionary containing the converted dataset.

Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def get_converted_dataset(
    self, dataset_name: str, output_format: str
) -> Dict[Any, Any]:
    """Exports the annotated tasks of a dataset in a converted format.

    Args:
        dataset_name: Name of the dataset.
        output_format: Export format passed to Label Studio as `export_type`.

    Returns:
        A dictionary containing the converted dataset.
    """
    dataset = self.get_dataset(dataset_name=dataset_name)
    exported = dataset.export_tasks(export_type=output_format)
    return exported  # type: ignore[no-any-return]
get_dataset(self, **kwargs)

Gets the dataset with the given name.

Parameters:

Name Type Description Default
**kwargs Any

Additional keyword arguments to pass to the Label Studio client.

{}

Returns:

Type Description
Any

The LabelStudio Dataset object (a 'Project') for the given name.

Exceptions:

Type Description
ValueError

If the dataset name is not provided or if the dataset does not exist.

Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def get_dataset(self, **kwargs: Any) -> Any:
    """Looks up a dataset (a Label Studio 'Project') by its name.

    Args:
        **kwargs: Keyword arguments identifying the dataset; `dataset_name`
            is required.

    Returns:
        The LabelStudio Dataset object (a 'Project') for the given name.

    Raises:
        ValueError: If the dataset name is not provided or if the dataset
            does not exist.
    """
    # TODO: check for and raise error if client unavailable
    name = kwargs.get("dataset_name")
    if not name:
        raise ValueError("`dataset_name` keyword argument is required.")

    project_id = self.get_id_from_name(name)
    if project_id:
        return self._get_client().get_project(project_id)

    raise ValueError(
        f"Dataset name '{name}' has no corresponding `dataset_id` in Label Studio."
    )
get_dataset_names(self)

Gets the names of the datasets.

Returns:

Type Description
List[str]

A list of dataset names.

Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def get_dataset_names(self) -> List[str]:
    """Lists the names (project titles) of all available datasets.

    Returns:
        A list of dataset names.
    """
    projects = self.get_datasets()
    names = [project.get_params()["title"] for project in projects]
    return names
get_dataset_stats(self, dataset_name)

Gets the statistics of the given dataset.

Parameters:

Name Type Description Default
dataset_name str

The name of the dataset.

required

Returns:

Type Description
Tuple[int, int]

A tuple containing (labeled_task_count, unlabeled_task_count) for the dataset.

Exceptions:

Type Description
IndexError

If the dataset does not exist.

Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def get_dataset_stats(self, dataset_name: str) -> Tuple[int, int]:
    """Gets the statistics of the given dataset.

    The dataset is looked up by exact title match, the same rule used by
    `get_id_from_name`.

    Args:
        dataset_name: The name of the dataset.

    Returns:
        A tuple containing (labeled_task_count, unlabeled_task_count) for
            the dataset.

    Raises:
        IndexError: If the dataset does not exist.
    """
    for project in self.get_datasets():
        # Exact comparison: a substring check would let "cats" match a
        # project titled "cats_v2" and report that project's counts.
        if project.get_params()["title"] == dataset_name:
            labeled_task_count = len(project.get_labeled_tasks())
            unlabeled_task_count = len(project.get_unlabeled_tasks())
            return (labeled_task_count, unlabeled_task_count)
    raise IndexError(
        f"Dataset {dataset_name} not found. Please use "
        f"`zenml annotator dataset list` to list all available datasets."
    )
get_datasets(self)

Gets the datasets currently available for annotation.

Returns:

Type Description
List[Any]

A list of datasets.

Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def get_datasets(self) -> List[Any]:
    """Returns every dataset currently available for annotation.

    Datasets are the Label Studio projects reported by the client.

    Returns:
        A list of datasets.
    """
    return cast(List[Any], self._get_client().get_projects())
get_id_from_name(self, dataset_name)

Gets the ID of the given dataset.

Parameters:

Name Type Description Default
dataset_name str

The name of the dataset.

required

Returns:

Type Description
Optional[int]

The ID of the dataset.

Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def get_id_from_name(self, dataset_name: str) -> Optional[int]:
    """Resolves a dataset name to its Label Studio project ID.

    The first project whose title equals `dataset_name` exactly wins.

    Args:
        dataset_name: The name of the dataset.

    Returns:
        The ID of the dataset, or None if no project has that title.
    """
    for project in self.get_datasets():
        params = project.get_params()
        if params["title"] != dataset_name:
            continue
        return cast(int, params["id"])
    return None
get_labeled_data(self, **kwargs)

Gets the labeled data for the given dataset.

Parameters:

Name Type Description Default
**kwargs Any

Additional keyword arguments to pass to the Label Studio client.

{}

Returns:

Type Description
Any

The labeled data.

Exceptions:

Type Description
ValueError

If the dataset name is not provided or if the dataset does not exist.

Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def get_labeled_data(self, **kwargs: Any) -> Any:
    """Gets the labeled data for the given dataset.

    Args:
        **kwargs: Keyword arguments identifying the dataset; `dataset_name`
            is required. They are forwarded unchanged to `get_dataset`.

    Returns:
        The labeled data (the dataset's labeled tasks).

    Raises:
        ValueError: If the dataset name is not provided or if the dataset
            does not exist.
    """
    # `get_dataset` performs the name validation and ID lookup (raising the
    # same ValueError messages), so reuse it instead of duplicating it here.
    return self.get_dataset(**kwargs).get_labeled_tasks()
get_parsed_label_config(self, dataset_id)

Returns the parsed Label Studio label config for a dataset.

Parameters:

Name Type Description Default
dataset_id int

Id of the dataset.

required

Returns:

Type Description
Dict[str, Any]

A dictionary containing the parsed label config.

Exceptions:

Type Description
ValueError

If no dataset is found for the given id.

Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def get_parsed_label_config(self, dataset_id: int) -> Dict[str, Any]:
    """Fetches the parsed Label Studio label config of a dataset.

    Args:
        dataset_id: Id of the dataset.

    Returns:
        A dictionary containing the parsed label config.

    Raises:
        ValueError: If no dataset is found for the given id.
    """
    # TODO: check if client actually is connected etc
    project = self._get_client().get_project(dataset_id)
    if not project:
        raise ValueError("No dataset found for the given id.")
    return cast(Dict[str, Any], project.parsed_label_config)
get_unlabeled_data(self, **kwargs)

Gets the unlabeled data for the given dataset.

Parameters:

Name Type Description Default
**kwargs str

Additional keyword arguments to pass to the Label Studio client.

{}

Returns:

Type Description
Any

The unlabeled data.

Exceptions:

Type Description
ValueError

If the dataset name is not provided or if the dataset does not exist.

Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def get_unlabeled_data(self, **kwargs: str) -> Any:
    """Gets the unlabeled data for the given dataset.

    Args:
        **kwargs: Keyword arguments identifying the dataset; `dataset_name`
            is required. They are forwarded unchanged to `get_dataset`.

    Returns:
        The unlabeled data (the dataset's unlabeled tasks).

    Raises:
        ValueError: If the dataset name is not provided or if the dataset
            does not exist.
    """
    # `get_dataset` performs the name validation and ID lookup (raising the
    # same ValueError messages), so reuse it instead of duplicating it here.
    return self.get_dataset(**kwargs).get_unlabeled_tasks()
get_url(self)

Gets the top-level URL of the annotation interface.

Returns:

Type Description
str

The URL of the annotation interface.

Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def get_url(self) -> str:
    """Builds the top-level URL of the locally served annotation interface.

    Returns:
        The URL of the annotation interface.
    """
    base = "http://localhost"
    return f"{base}:{self.port}"
get_url_for_dataset(self, dataset_name)

Gets the URL of the annotation interface for the given dataset.

Parameters:

Name Type Description Default
dataset_name str

The name of the dataset.

required

Returns:

Type Description
str

The URL of the annotation interface.

Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def get_url_for_dataset(self, dataset_name: str) -> str:
    """Gets the URL of the annotation interface for the given dataset.

    Args:
        dataset_name: The name of the dataset.

    Returns:
        The URL of the annotation interface.

    Raises:
        ValueError: If no dataset with the given name exists.
    """
    project_id = self.get_id_from_name(dataset_name)
    # Without this check an unknown name yields a bogus ".../projects/None/"
    # URL instead of a clear error.
    if project_id is None:
        raise ValueError(
            f"Dataset name '{dataset_name}' has no corresponding `dataset_id` in Label Studio."
        )
    return f"{self.get_url()}/projects/{project_id}/"
launch(self, url)

Launches the annotation interface.

Parameters:

Name Type Description Default
url Optional[str]

The URL of the annotation interface.

required
Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def launch(self, url: Optional[str]) -> None:
    """Launches the annotation interface.

    Opens `url` in a web browser, falling back to the top-level interface
    URL from `get_url` when `url` is empty. The browser is only opened if a
    connection to the annotation server is available; otherwise a warning
    is logged.

    Args:
        url: The URL of the annotation interface.
    """
    if not url:
        url = self.get_url()
    if self._connection_available():
        webbrowser.open(url, new=1, autoraise=True)
    else:
        logger.warning(
            "Could not launch annotation interface "
            "because the connection could not be established."
        )
provision(self)

Spins up the annotation server backend.

Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def provision(self) -> None:
    """Spins up the annotation server backend.

    This only ensures the annotator's root directory exists; it does not
    itself start the server process.
    """
    root = self.root_directory
    fileio.makedirs(root)
register_dataset_for_annotation(self, config)

Registers a dataset for annotation.

Parameters:

Name Type Description Default
config LabelStudioDatasetRegistrationConfig

Configuration for the dataset.

required

Returns:

Type Description
Any

A Label Studio Project object.

Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def register_dataset_for_annotation(
    self,
    config: LabelStudioDatasetRegistrationConfig,
) -> Any:
    """Returns the Label Studio project for a dataset, creating it if needed.

    An existing project whose title equals `config.dataset_name` is reused;
    otherwise a new one is created with `config.label_config`.

    Args:
        config: Configuration for the dataset.

    Returns:
        A Label Studio Project object.
    """
    existing_id = self.get_id_from_name(config.dataset_name)
    if not existing_id:
        return self.add_dataset(
            dataset_name=config.dataset_name,
            label_config=config.label_config,
        )
    return self._get_client().get_project(existing_id)
resume(self)

Resumes the annotation interface.

Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def resume(self) -> None:
    """Resumes the annotation interface.

    Starts the annotation server daemon unless it is already running, in
    which case only an informational message is logged.
    """
    if self.is_running:
        logger.info("Local annotation server is already running.")
        return

    self.start_annotator_daemon()
start_annotator_daemon(self)

Starts the annotation server backend.

Exceptions:

Type Description
ProvisioningError

If the annotation server backend is already running or the port is already occupied.

Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def start_annotator_daemon(self) -> None:
    """Starts the annotation server backend.

    Runs `label-studio start --no-browser --port <port>` as a daemon,
    using the annotator's PID file and log file. On Windows, where daemons
    are not supported, it only logs the command the user should run
    manually.

    Raises:
        ProvisioningError: If the configured port is already occupied (for
            example because an annotation server is already running on it).
    """
    command = [
        "label-studio",
        "start",
        "--no-browser",
        "--port",
        f"{self.port}",
    ]

    if sys.platform == "win32":
        # Exactly one lazy `%s` argument: the command to run manually.
        logger.warning(
            "Daemon functionality not supported on Windows. "
            "In order to access the Label Studio server locally, "
            "please run '%s' in a separate command line shell.",
            " ".join(command),
        )
    elif not networking_utils.port_available(self.port):
        raise ProvisioningError(
            f"Unable to port-forward Label Studio to local "
            f"port {self.port} because the port is occupied. In order to "
            f"access Label Studio locally, please "
            f"change the configuration to use an available "
            f"port or stop the other process currently using the port."
        )
    else:
        from zenml.utils import daemon

        def _daemon_function() -> None:
            """Runs the Label Studio server process inside the daemon."""
            subprocess.check_call(command)

        daemon.run_as_daemon(
            _daemon_function,
            pid_file=self._pid_file_path,
            log_file=self._log_file,
        )
        logger.info(
            "Started Label Studio daemon (check the daemon "
            "logs at `%s` in case you're not able to access the annotation "
            "interface). Please visit `%s/` to use the Label Studio interface.",
            self._log_file,
            self.get_url(),
        )
stop_annotator_daemon(self)

Stops the annotation server backend.

Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def stop_annotator_daemon(self) -> None:
    """Stops the annotation server backend.

    When the daemon's PID file exists (and the platform is not Windows),
    the daemon recorded in it is stopped and the PID file is removed.
    """
    pid_file = self._pid_file_path
    if not fileio.exists(pid_file):
        return

    if sys.platform == "win32":
        # Daemon functionality is not supported on Windows, so the PID
        # file won't exist. This branch exists just so mypy does not
        # complain about missing functions.
        pass
    else:
        from zenml.utils import daemon

        daemon.stop_daemon(pid_file)
        fileio.remove(pid_file)
suspend(self)

Suspends the annotation interface.

Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def suspend(self) -> None:
    """Suspends the annotation interface.

    Stops the annotation server daemon if it is running; otherwise only an
    informational message is logged.
    """
    if self.is_running:
        self.stop_annotator_daemon()
    else:
        logger.info("Local annotation server is not running.")

label_config_generators special

Initialization of the Label Studio config generators submodule.

label_config_generators

Implementation of label config generators for Label Studio.

generate_basic_object_detection_bounding_boxes_label_config(labels)

Generates a Label Studio config for object detection with bounding boxes.

This is based on the basic config example shown at https://labelstud.io/templates/image_bbox.html.

Parameters:

Name Type Description Default
labels List[str]

A list of labels to be used in the label config.

required

Returns:

Type Description
Tuple[str, str]

A tuple of the generated label config and the label config type.

Exceptions:

Type Description
ValueError

If no labels are provided.

Source code in zenml/integrations/label_studio/label_config_generators/label_config_generators.py
def generate_basic_object_detection_bounding_boxes_label_config(
    labels: List[str],
) -> Tuple[str, str]:
    """Generates a Label Studio config for object detection with bounding boxes.

    This is based on the basic config example shown at
    https://labelstud.io/templates/image_bbox.html.

    Each label is XML-escaped before it is embedded in a single-quoted
    attribute, so labels containing characters such as ``&``, ``<`` or
    ``'`` still produce a well-formed config.

    Args:
        labels: A list of labels to be used in the label config.

    Returns:
        A tuple of the generated label config and the label config type.

    Raises:
        ValueError: If no labels are provided.
    """
    from xml.sax.saxutils import escape

    if not labels:
        raise ValueError("No labels provided")

    label_config_type = AnnotationTasks.OBJECT_DETECTION_BOUNDING_BOXES

    label_config_start = """<View>
    <Image name="image" value="$image"/>
    <RectangleLabels name="label" toName="image">
    """
    # `escape` covers &, < and >; the extra entity handles the single quote
    # that delimits the attribute value below.
    extra_entities = {"'": "&apos;"}
    label_config_choices = "".join(
        f"<Label value='{escape(label, extra_entities)}' />\n"
        for label in labels
    )
    label_config_end = "</RectangleLabels>\n</View>"
    label_config = label_config_start + label_config_choices + label_config_end

    return (
        label_config,
        label_config_type,
    )
generate_image_classification_label_config(labels)

Generates a Label Studio label config for image classification.

This is based on the basic config example shown at https://labelstud.io/templates/image_classification.html.

Parameters:

Name Type Description Default
labels List[str]

A list of labels to be used in the label config.

required

Returns:

Type Description
Tuple[str, str]

A tuple of the generated label config and the label config type.

Exceptions:

Type Description
ValueError

If no labels are provided.

Source code in zenml/integrations/label_studio/label_config_generators/label_config_generators.py
def generate_image_classification_label_config(
    labels: List[str],
) -> Tuple[str, str]:
    """Generates a Label Studio label config for image classification.

    This is based on the basic config example shown at
    https://labelstud.io/templates/image_classification.html.

    Each label is XML-escaped before it is embedded in a single-quoted
    attribute, so labels containing characters such as ``&``, ``<`` or
    ``'`` still produce a well-formed config.

    Args:
        labels: A list of labels to be used in the label config.

    Returns:
        A tuple of the generated label config and the label config type.

    Raises:
        ValueError: If no labels are provided.
    """
    from xml.sax.saxutils import escape

    if not labels:
        raise ValueError("No labels provided")

    label_config_type = AnnotationTasks.IMAGE_CLASSIFICATION

    label_config_start = """<View>
    <Image name="image" value="$image"/>
    <Choices name="choice" toName="image">
    """
    # `escape` covers &, < and >; the extra entity handles the single quote
    # that delimits the attribute value below.
    extra_entities = {"'": "&apos;"}
    label_config_choices = "".join(
        f"<Choice value='{escape(label, extra_entities)}' />\n"
        for label in labels
    )
    label_config_end = "</Choices>\n</View>"

    label_config = label_config_start + label_config_choices + label_config_end
    return (
        label_config,
        label_config_type,
    )

label_studio_utils

Utility functions for the Label Studio annotator integration.

convert_pred_filenames_to_task_ids(preds, tasks, filename_reference, storage_type)

Converts a list of predictions from local file references to task id.

Parameters:

Name Type Description Default
preds List[Dict[str, Any]]

List of predictions.

required
tasks List[Dict[str, Any]]

List of tasks.

required
filename_reference str

Name of the file reference in the predictions.

required
storage_type str

Storage type of the predictions.

required

Returns:

Type Description
List[Dict[str, Any]]

List of predictions using task ids as reference.

Source code in zenml/integrations/label_studio/label_studio_utils.py
def convert_pred_filenames_to_task_ids(
    preds: List[Dict[str, Any]],
    tasks: List[Dict[str, Any]],
    filename_reference: str,
    storage_type: str,
) -> List[Dict[str, Any]]:
    """Re-keys predictions from file names to Label Studio task IDs.

    Each task is indexed by the base name of the URL path stored under
    `task["data"][filename_reference]`. Each prediction's `filename` is then
    reduced to its base name and looked up in that index. A prediction whose
    file has no matching task raises `KeyError`.

    Args:
        preds: List of predictions, each with `filename` and `result` keys.
        tasks: List of tasks, each with `id` and `data` keys.
        filename_reference: Name of the file reference in the predictions.
        storage_type: Storage type of the predictions.

    Returns:
        List of predictions using task ids as reference.
    """
    task_id_by_file = {}
    for task in tasks:
        task_path = urlparse(task["data"][filename_reference]).path
        task_id_by_file[os.path.basename(task_path)] = task["id"]

    # Task URLs from GCS and S3 are URL-encoded (e.g. spaces become %20), so
    # prediction file names must be encoded the same way before matching.
    if storage_type in {"gcs", "s3"}:
        normalized = [
            (quote(pred["filename"]), pred["result"]) for pred in preds
        ]
    else:
        normalized = [(pred["filename"], pred["result"]) for pred in preds]

    converted = []
    for filename, result in normalized:
        task_id = task_id_by_file[os.path.basename(filename)]
        converted.append({"task": int(task_id), "result": result})
    return converted
get_file_extension(path_str)

Return the file extension of the given filename.

Parameters:

Name Type Description Default
path_str str

Path to the file.

required

Returns:

Type Description
str

File extension.

Source code in zenml/integrations/label_studio/label_studio_utils.py
def get_file_extension(path_str: str) -> str:
    """Returns the extension of the file referenced by a path or URL.

    Any URL scheme, host, query string or fragment is ignored; only the
    path component is inspected.

    Args:
        path_str: Path to the file.

    Returns:
        File extension, including the leading dot (empty if there is none).
    """
    url_path = urlparse(path_str).path
    _, extension = os.path.splitext(url_path)
    return extension
is_azure_url(url)

Return whether the given URL is an Azure URL.

Parameters:

Name Type Description Default
url str

URL to check.

required

Returns:

Type Description
bool

True if the URL is an Azure URL, False otherwise.

Source code in zenml/integrations/label_studio/label_studio_utils.py
def is_azure_url(url: str) -> bool:
    """Checks whether a URL points at Azure Blob Storage.

    The check looks for ``blob.core.windows.net`` in the URL's network
    location.

    Args:
        url: URL to check.

    Returns:
        True if the URL is an Azure URL, False otherwise.
    """
    host = urlparse(url).netloc
    return "blob.core.windows.net" in host
is_gcs_url(url)

Return whether the given URL is a GCS URL.

Parameters:

Name Type Description Default
url str

URL to check.

required

Returns:

Type Description
bool

True if the URL is a GCS URL, False otherwise.

Source code in zenml/integrations/label_studio/label_studio_utils.py
def is_gcs_url(url: str) -> bool:
    """Checks whether a URL points at Google Cloud Storage.

    The check looks for ``storage.googleapis.com`` in the URL's network
    location.

    Args:
        url: URL to check.

    Returns:
        True if the URL is a GCS URL, False otherwise.
    """
    host = urlparse(url).netloc
    return "storage.googleapis.com" in host
is_s3_url(url)

Return whether the given URL is an S3 URL.

Parameters:

Name Type Description Default
url str

URL to check.

required

Returns:

Type Description
bool

True if the URL is an S3 URL, False otherwise.

Source code in zenml/integrations/label_studio/label_studio_utils.py
def is_s3_url(url: str) -> bool:
    """Checks whether a URL points at Amazon S3.

    The check looks for ``s3.amazonaws`` in the URL's network location.

    Args:
        url: URL to check.

    Returns:
        True if the URL is an S3 URL, False otherwise.
    """
    host = urlparse(url).netloc
    return "s3.amazonaws" in host

steps special

Standard steps to be used with the Label Studio annotator integration.

label_studio_standard_steps

Implementation of standard steps for the Label Studio annotator integration.

LabelStudioDatasetRegistrationConfig (BaseStepConfig) pydantic-model

Step config when registering a dataset with Label Studio.

Attributes:

Name Type Description
label_config str

The label config to use for the annotation interface.

dataset_name str

Name of the dataset to register.

Source code in zenml/integrations/label_studio/steps/label_studio_standard_steps.py
class LabelStudioDatasetRegistrationConfig(BaseStepConfig):
    """Step config when registering a dataset with Label Studio.

    Consumed by the Label Studio annotator's
    `register_dataset_for_annotation`, which reuses an existing project whose
    title equals `dataset_name` or otherwise creates one with `label_config`.

    Attributes:
        label_config: The label config to use for the annotation interface
            (a Label Studio XML config string, such as those produced by the
            label config generators).
        dataset_name: Name of the dataset to register; used as the Label
            Studio project title.
    """

    # Field order is kept as-is: pydantic models preserve declaration order.
    label_config: str
    dataset_name: str
LabelStudioDatasetSyncConfig (BaseStepConfig) pydantic-model

Step config when syncing data to Label Studio.

Attributes:

Name Type Description
storage_type str

The type of storage to sync to.

label_config_type str

The type of label config to use.

prefix Optional[str]

Specify the prefix within the cloud store to import your data from.

regex_filter Optional[str]

Specify a regex filter to filter the files to import.

use_blob_urls Optional[bool]

Specify whether your data is raw image or video data, or JSON tasks.

presign Optional[bool]

Specify whether or not to create presigned URLs.

presign_ttl Optional[int]

Specify how long to keep presigned URLs active.

description Optional[str]

Specify a description for the dataset.

azure_account_name Optional[str]

Specify the Azure account name to use for the storage.

azure_account_key Optional[str]

Specify the Azure account key to use for the storage.

google_application_credentials Optional[str]

Specify the Google application credentials to use for the storage.

aws_access_key_id Optional[str]

Specify the AWS access key ID to use for the storage.

aws_secret_access_key Optional[str]

Specify the AWS secret access key to use for the storage.

aws_session_token Optional[str]

Specify the AWS session token to use for the storage.

s3_region_name Optional[str]

Specify the S3 region name to use for the storage.

s3_endpoint Optional[str]

Specify the S3 endpoint to use for the storage.

Source code in zenml/integrations/label_studio/steps/label_studio_standard_steps.py
class LabelStudioDatasetSyncConfig(BaseStepConfig):
    """Configuration for the step that syncs data to Label Studio.

    Attributes:
        storage_type: Kind of storage to sync from ("azure", "gcs" or "s3"
            select which credential fields get filled in).
        label_config_type: The type of label config to use; it determines
            how predictions reference their source files.

        prefix: Prefix inside the cloud store from which data is imported.
        regex_filter: Regular expression selecting the files to import.
        use_blob_urls: Whether the data is raw image or video data, or JSON
            tasks.
        presign: Whether presigned URLs should be created.
        presign_ttl: How long presigned URLs stay active.
        description: Free-text description for the dataset.

        azure_account_name: Azure account name used to access the storage.
        azure_account_key: Azure account key used to access the storage.
        google_application_credentials: Google application credentials
            used to access the storage.
        aws_access_key_id: AWS access key ID used to access the storage.
        aws_secret_access_key: AWS secret access key used to access the
            storage.
        aws_session_token: AWS session token used to access the storage.
        s3_region_name: S3 region name used to access the storage.
        s3_endpoint: S3 endpoint used to access the storage.
    """

    # Required settings.
    storage_type: str
    label_config_type: str

    # Import settings, all with defaults.
    prefix: Optional[str] = None
    regex_filter: Optional[str] = ".*"
    use_blob_urls: Optional[bool] = True
    presign: Optional[bool] = True
    presign_ttl: Optional[int] = 1
    description: Optional[str] = ""

    # Credentials for the main cloud providers. They default to None and
    # are filled in by the sync step from the secrets manager.
    azure_account_name: Optional[str] = None
    azure_account_key: Optional[str] = None
    google_application_credentials: Optional[str] = None
    aws_access_key_id: Optional[str] = None
    aws_secret_access_key: Optional[str] = None
    aws_session_token: Optional[str] = None
    s3_region_name: Optional[str] = None
    s3_endpoint: Optional[str] = None
get_labeled_data (BaseStep)

Gets labeled data from the dataset.

Parameters:

Name Type Description Default
dataset_name

Name of the dataset.

required
context

The StepContext.

required

Returns:

Type Description

List of labeled data.

Exceptions:

Type Description
TypeError

If you are trying to use it with an annotator that is not Label Studio.

StackComponentInterfaceError

If no active annotator could be found.

entrypoint(dataset_name, context) staticmethod

Gets labeled data from the dataset.

Parameters:

Name Type Description Default
dataset_name str

Name of the dataset.

required
context StepContext

The StepContext.

required

Returns:

Type Description
List

List of labeled data.

Exceptions:

Type Description
TypeError

If you are trying to use it with an annotator that is not Label Studio.

StackComponentInterfaceError

If no active annotator could be found.

Source code in zenml/integrations/label_studio/steps/label_studio_standard_steps.py
@step(enable_cache=False)
def get_labeled_data(dataset_name: str, context: StepContext) -> List:  # type: ignore[type-arg]
    """Fetch the labeled tasks of a Label Studio dataset.

    Args:
        dataset_name: Name of the dataset.
        context: The StepContext.

    Returns:
        List of labeled data.

    Raises:
        TypeError: If the active annotator is not the Label Studio
            annotator.
        StackComponentInterfaceError: If the stack has no active annotator,
            or the annotator cannot be reached.
    """
    # TODO [MEDIUM]: have this check for new data *since the last time this step ran*
    active_annotator = context.stack.annotator  # type: ignore[union-attr]
    if not active_annotator:
        raise StackComponentInterfaceError("No active annotator.")

    from zenml.integrations.label_studio.annotators.label_studio_annotator import (
        LabelStudioAnnotator,
    )

    if not isinstance(active_annotator, LabelStudioAnnotator):
        raise TypeError(
            "This step can only be used with the Label Studio annotator."
        )

    # Bail out before querying if Label Studio cannot be reached.
    if not active_annotator._connection_available():
        raise StackComponentInterfaceError(
            "Unable to connect to annotator stack component."
        )

    labeled_dataset = active_annotator.get_dataset(dataset_name=dataset_name)
    return labeled_dataset.get_labeled_tasks()  # type: ignore[no-any-return]
get_or_create_dataset (BaseStep)

Gets preexisting dataset or creates a new one.

Parameters:

Name Type Description Default
config

Step config.

required
context

Step context.

required

Returns:

Type Description

The dataset name.

Exceptions:

Type Description
TypeError

If you are trying to use it with an annotator that is not Label Studio.

StackComponentInterfaceError

If no active annotator could be found.

CONFIG_CLASS (BaseStepConfig) pydantic-model

Step config when registering a dataset with Label Studio.

Attributes:

Name Type Description
label_config str

The label config to use for the annotation interface.

dataset_name str

Name of the dataset to register.

Source code in zenml/integrations/label_studio/steps/label_studio_standard_steps.py
class LabelStudioDatasetRegistrationConfig(BaseStepConfig):
    """Step config when registering a dataset with Label Studio.

    Consumed by the `get_or_create_dataset` step: if no existing Label
    Studio dataset has a title equal to `dataset_name`, the whole config is
    handed to the annotator's `register_dataset_for_annotation`.

    Attributes:
        label_config: The label config to use for the annotation interface.
        dataset_name: Name of the dataset to register.
    """

    # Label config for the annotation interface (presumably Label Studio's
    # labeling-config markup — confirm against the annotator implementation).
    label_config: str
    # Dataset title; also used to look up an already existing dataset.
    dataset_name: str
entrypoint(config, context) staticmethod

Gets preexisting dataset or creates a new one.

Parameters:

Name Type Description Default
config LabelStudioDatasetRegistrationConfig

Step config.

required
context StepContext

Step context.

required

Returns:

Type Description
str

The dataset name.

Exceptions:

Type Description
TypeError

If you are trying to use it with an annotator that is not Label Studio.

StackComponentInterfaceError

If no active annotator could be found.

Source code in zenml/integrations/label_studio/steps/label_studio_standard_steps.py
@step(enable_cache=False)
def get_or_create_dataset(
    config: LabelStudioDatasetRegistrationConfig,
    context: StepContext,
) -> str:
    """Gets preexisting dataset or creates a new one.

    Looks for a Label Studio dataset whose title equals
    `config.dataset_name` and returns that title. If none exists, a new
    dataset is registered from `config` and its title is returned.

    Args:
        config: Step config.
        context: Step context.

    Returns:
        The dataset name.

    Raises:
        TypeError: If you are trying to use it with an annotator that is not
            Label Studio.
        StackComponentInterfaceError: If no active annotator could be found,
            or the annotator cannot be reached.
    """
    annotator = context.stack.annotator  # type: ignore[union-attr]
    # Check for a missing annotator first. Otherwise `None` would fail the
    # isinstance check below and surface as a misleading TypeError.
    if not annotator:
        raise StackComponentInterfaceError("No active annotator.")

    from zenml.integrations.label_studio.annotators.label_studio_annotator import (
        LabelStudioAnnotator,
    )

    if not isinstance(annotator, LabelStudioAnnotator):
        raise TypeError(
            "This step can only be used with the Label Studio annotator."
        )

    if not annotator._connection_available():
        raise StackComponentInterfaceError(
            "Unable to connect to annotator stack component."
        )

    # Reuse an existing dataset with the requested title, if any.
    for existing_dataset in annotator.get_datasets():
        title = existing_dataset.get_params()["title"]
        if title == config.dataset_name:
            return cast(str, title)

    new_dataset = annotator.register_dataset_for_annotation(config)
    return cast(str, new_dataset.get_params()["title"])
sync_new_data_to_label_studio (BaseStep)

Syncs new data to Label Studio.

Parameters:

Name Type Description Default
uri

The URI of the data to sync.

required
dataset_name

The name of the dataset to sync to.

required
predictions

The predictions to sync.

required
config

The config for the sync.

required
context

The StepContext.

required

Exceptions:

Type Description
TypeError

If you are trying to use it with an annotator that is not Label Studio.

ValueError

If you are trying to sync from outside ZenML.

StackComponentInterfaceError

If no active annotator could be found.

CONFIG_CLASS (BaseStepConfig) pydantic-model

Step config when syncing data to Label Studio.

Attributes:

Name Type Description
storage_type str

The type of storage to sync to.

label_config_type str

The type of label config to use.

prefix Optional[str]

Specify the prefix within the cloud store to import your data from.

regex_filter Optional[str]

Specify a regex filter to filter the files to import.

use_blob_urls Optional[bool]

Specify whether your data is raw image or video data, or JSON tasks.

presign Optional[bool]

Specify whether or not to create presigned URLs.

presign_ttl Optional[int]

Specify how long to keep presigned URLs active.

description Optional[str]

Specify a description for the dataset.

azure_account_name Optional[str]

Specify the Azure account name to use for the storage.

azure_account_key Optional[str]

Specify the Azure account key to use for the storage.

google_application_credentials Optional[str]

Specify the Google application credentials to use for the storage.

aws_access_key_id Optional[str]

Specify the AWS access key ID to use for the storage.

aws_secret_access_key Optional[str]

Specify the AWS secret access key to use for the storage.

aws_session_token Optional[str]

Specify the AWS session token to use for the storage.

s3_region_name Optional[str]

Specify the S3 region name to use for the storage.

s3_endpoint Optional[str]

Specify the S3 endpoint to use for the storage.

Source code in zenml/integrations/label_studio/steps/label_studio_standard_steps.py
class LabelStudioDatasetSyncConfig(BaseStepConfig):
    """Configuration for the step that syncs data to Label Studio.

    Attributes:
        storage_type: Kind of storage to sync from ("azure", "gcs" or "s3"
            select which credential fields get filled in).
        label_config_type: The type of label config to use; it determines
            how predictions reference their source files.

        prefix: Prefix inside the cloud store from which data is imported.
        regex_filter: Regular expression selecting the files to import.
        use_blob_urls: Whether the data is raw image or video data, or JSON
            tasks.
        presign: Whether presigned URLs should be created.
        presign_ttl: How long presigned URLs stay active.
        description: Free-text description for the dataset.

        azure_account_name: Azure account name used to access the storage.
        azure_account_key: Azure account key used to access the storage.
        google_application_credentials: Google application credentials
            used to access the storage.
        aws_access_key_id: AWS access key ID used to access the storage.
        aws_secret_access_key: AWS secret access key used to access the
            storage.
        aws_session_token: AWS session token used to access the storage.
        s3_region_name: S3 region name used to access the storage.
        s3_endpoint: S3 endpoint used to access the storage.
    """

    # Required settings.
    storage_type: str
    label_config_type: str

    # Import settings, all with defaults.
    prefix: Optional[str] = None
    regex_filter: Optional[str] = ".*"
    use_blob_urls: Optional[bool] = True
    presign: Optional[bool] = True
    presign_ttl: Optional[int] = 1
    description: Optional[str] = ""

    # Credentials for the main cloud providers. They default to None and
    # are filled in by the sync step from the secrets manager.
    azure_account_name: Optional[str] = None
    azure_account_key: Optional[str] = None
    google_application_credentials: Optional[str] = None
    aws_access_key_id: Optional[str] = None
    aws_secret_access_key: Optional[str] = None
    aws_session_token: Optional[str] = None
    s3_region_name: Optional[str] = None
    s3_endpoint: Optional[str] = None
entrypoint(uri, dataset_name, predictions, config, context) staticmethod

Syncs new data to Label Studio.

Parameters:

Name Type Description Default
uri str

The URI of the data to sync.

required
dataset_name str

The name of the dataset to sync to.

required
predictions List[Dict[str, Any]]

The predictions to sync.

required
config LabelStudioDatasetSyncConfig

The config for the sync.

required
context StepContext

The StepContext.

required

Exceptions:

Type Description
TypeError

If you are trying to use it with an annotator that is not Label Studio.

ValueError

If you are trying to sync from outside ZenML.

StackComponentInterfaceError

If no active annotator could be found.

Source code in zenml/integrations/label_studio/steps/label_studio_standard_steps.py
@step(enable_cache=False)
def sync_new_data_to_label_studio(
    uri: str,
    dataset_name: str,
    predictions: List[Dict[str, Any]],
    config: LabelStudioDatasetSyncConfig,
    context: StepContext,
) -> None:
    """Syncs new data to Label Studio.

    Connects the Label Studio dataset to the artifact-store location that
    `uri` points to, filling in the storage credentials on `config` from
    the secrets manager. If `predictions` are given, they are converted to
    reference Label Studio task IDs and uploaded as predictions.

    Note that `config` is modified in place: `prefix` and the credential
    fields for the selected `storage_type` are overwritten.

    Args:
        uri: The URI of the data to sync. Must live inside the active
            artifact store.
        dataset_name: The name of the dataset to sync to.
        predictions: The predictions to sync.
        config: The config for the sync.
        context: The StepContext.

    Raises:
        TypeError: If you are trying to use it with an annotator that is not
            Label Studio.
        ValueError: If you are trying to sync from outside ZenML.
        StackComponentInterfaceError: If no active annotator, artifact store
            or secrets manager could be found, or if the annotator cannot be
            reached.
    """
    annotator = context.stack.annotator  # type: ignore[union-attr]
    artifact_store = context.stack.artifact_store  # type: ignore[union-attr]
    secrets_manager = context.stack.secrets_manager  # type: ignore[union-attr]
    if not annotator or not artifact_store or not secrets_manager:
        raise StackComponentInterfaceError(
            "An active annotator, artifact store and secrets manager are required to run this step."
        )

    from zenml.integrations.label_studio.annotators.label_studio_annotator import (
        LabelStudioAnnotator,
    )

    if not isinstance(annotator, LabelStudioAnnotator):
        raise TypeError(
            "This step can only be used with the Label Studio annotator."
        )

    # Validate the input locally before making any calls to Label Studio.
    if not uri.startswith(artifact_store.path):
        raise ValueError(
            "ZenML only currently supports syncing data passed from other ZenML steps and via the Artifact Store."
        )

    # Fail fast if Label Studio is unreachable rather than querying it first.
    if not annotator._connection_available():
        raise StackComponentInterfaceError(
            "Unable to connect to annotator stack component."
        )

    dataset = annotator.get_dataset(dataset_name=dataset_name)

    parsed_uri = urlparse(uri)
    # The storage prefix is the URI path without its leading slash(es).
    config.prefix = parsed_uri.path.lstrip("/")
    base_uri = parsed_uri.netloc

    # Fill in the storage credentials. Each secret is fetched only once, and
    # the artifact store's authentication secret is only read by the
    # backends that use it.
    if config.storage_type == "azure":
        azure_secret = secrets_manager.get_secret(  # type: ignore[union-attr]
            artifact_store.authentication_secret  # type: ignore[union-attr]
        )
        config.azure_account_name = azure_secret.account_name
        config.azure_account_key = azure_secret.account_key
    elif config.storage_type == "gcs":
        gcs_secret = secrets_manager.get_secret(  # type: ignore[union-attr]
            artifact_store.authentication_secret  # type: ignore[union-attr]
        )
        config.google_application_credentials = gcs_secret.token
    elif config.storage_type == "s3":
        aws_secret = secrets_manager.get_secret(  # type: ignore[union-attr]
            LABEL_STUDIO_AWS_SECRET_NAME
        )
        config.aws_access_key_id = aws_secret.aws_access_key_id
        config.aws_secret_access_key = aws_secret.aws_secret_access_key
        config.aws_session_token = aws_secret.aws_session_token

    # TODO: get existing (CHECK!) or create the sync connection
    annotator.connect_and_sync_external_storage(
        uri=base_uri,
        config=config,
        dataset=dataset,
    )
    if predictions:
        filename_reference = TASK_TO_FILENAME_REFERENCE_MAPPING[
            config.label_config_type
        ]
        preds_with_task_ids = convert_pred_filenames_to_task_ids(
            predictions,
            dataset.tasks,
            filename_reference,
            config.storage_type,
        )
        # TODO: filter out any predictions that exist + have already been
        # made (maybe?). Only pass in preds for tasks without pre-annotations.
        dataset.create_predictions(preds_with_task_ids)

lightgbm special

Initialization of the LightGBM integration.

LightGBMIntegration (Integration)

Definition of lightgbm integration for ZenML.

Source code in zenml/integrations/lightgbm/__init__.py
class LightGBMIntegration(Integration):
    """ZenML integration definition for LightGBM."""

    NAME = LIGHTGBM
    REQUIREMENTS = ["lightgbm>=1.0.0"]

    @classmethod
    def activate(cls) -> None:
        """Activate the integration by importing its materializers module."""
        import zenml.integrations.lightgbm.materializers  # noqa
activate() classmethod

Activates the integration.

Source code in zenml/integrations/lightgbm/__init__.py
@classmethod
def activate(cls) -> None:
    """Activate the integration by importing its materializers module."""
    import zenml.integrations.lightgbm.materializers  # noqa

materializers special

Initialization of the LightGBM materializers.

lightgbm_booster_materializer

Implementation of the LightGBM booster materializer.

LightGBMBoosterMaterializer (BaseMaterializer)

Materializer to read data to and from lightgbm.Booster.

Source code in zenml/integrations/lightgbm/materializers/lightgbm_booster_materializer.py
class LightGBMBoosterMaterializer(BaseMaterializer):
    """Materializer to read data to and from lightgbm.Booster."""

    ASSOCIATED_TYPES = (lgb.Booster,)
    ASSOCIATED_ARTIFACT_TYPES = (ModelArtifact,)

    def handle_input(self, data_type: Type[Any]) -> lgb.Booster:
        """Reads a lightgbm Booster model from its saved model file.

        The model file is copied from the artifact store into a local
        temporary directory, loaded with `lgb.Booster(model_file=...)`, and
        the temporary directory is then removed, even if loading fails.

        Args:
            data_type: A lightgbm Booster type.

        Returns:
            A lightgbm Booster object.
        """
        super().handle_input(data_type)
        filepath = os.path.join(self.artifact.uri, DEFAULT_FILENAME)

        temp_dir = tempfile.mkdtemp(prefix="zenml-temp-")
        try:
            temp_file = os.path.join(temp_dir, DEFAULT_FILENAME)
            # Copy from artifact store to a local temporary file.
            fileio.copy(filepath, temp_file)
            booster = lgb.Booster(model_file=temp_file)
        finally:
            fileio.rmtree(temp_dir)
        return booster

    def handle_return(self, booster: lgb.Booster) -> None:
        """Saves a lightgbm Booster model to the artifact store.

        The model is written with `booster.save_model` into a fresh local
        temporary directory, which avoids writing to a file that is still
        held open. It is then copied into the artifact store, and the
        temporary directory is removed even if saving or copying fails.

        Args:
            booster: A lightgbm Booster model.
        """
        super().handle_return(booster)
        filepath = os.path.join(self.artifact.uri, DEFAULT_FILENAME)

        temp_dir = tempfile.mkdtemp(prefix="zenml-temp-")
        try:
            temp_file = os.path.join(temp_dir, DEFAULT_FILENAME)
            booster.save_model(temp_file)
            # Copy it into artifact store.
            fileio.copy(temp_file, filepath)
        finally:
            fileio.rmtree(temp_dir)
handle_input(self, data_type)

Reads a lightgbm Booster model from its saved model file.

Parameters:

Name Type Description Default
data_type Type[Any]

A lightgbm Booster type.

required

Returns:

Type Description
Booster

A lightgbm Booster object.

Source code in zenml/integrations/lightgbm/materializers/lightgbm_booster_materializer.py
def handle_input(self, data_type: Type[Any]) -> lgb.Booster:
    """Reads a lightgbm Booster model from its saved model file.

    The model file is copied from the artifact store into a local temporary
    directory, loaded with `lgb.Booster(model_file=...)`, and the temporary
    directory is then removed, even if loading fails.

    Args:
        data_type: A lightgbm Booster type.

    Returns:
        A lightgbm Booster object.
    """
    super().handle_input(data_type)
    filepath = os.path.join(self.artifact.uri, DEFAULT_FILENAME)

    temp_dir = tempfile.mkdtemp(prefix="zenml-temp-")
    try:
        temp_file = os.path.join(temp_dir, DEFAULT_FILENAME)
        # Copy from artifact store to a local temporary file.
        fileio.copy(filepath, temp_file)
        booster = lgb.Booster(model_file=temp_file)
    finally:
        fileio.rmtree(temp_dir)
    return booster
handle_return(self, booster)

Saves a lightgbm Booster model to a model file in the artifact store.

Parameters:

Name Type Description Default
booster Booster

A lightgbm Booster model.

required
Source code in zenml/integrations/lightgbm/materializers/lightgbm_booster_materializer.py
def handle_return(self, booster: lgb.Booster) -> None:
    """Saves a lightgbm Booster model to the artifact store.

    The model is written with `booster.save_model` into a fresh local
    temporary directory, which avoids writing to a file that is still held
    open. It is then copied into the artifact store, and the temporary
    directory is removed even if saving or copying fails.

    Args:
        booster: A lightgbm Booster model.
    """
    super().handle_return(booster)
    filepath = os.path.join(self.artifact.uri, DEFAULT_FILENAME)

    temp_dir = tempfile.mkdtemp(prefix="zenml-temp-")
    try:
        temp_file = os.path.join(temp_dir, DEFAULT_FILENAME)
        booster.save_model(temp_file)
        # Copy it into artifact store.
        fileio.copy(temp_file, filepath)
    finally:
        fileio.rmtree(temp_dir)
lightgbm_dataset_materializer

Implementation of the LightGBM materializer.

LightGBMDatasetMaterializer (BaseMaterializer)

Materializer to read data to and from lightgbm.Dataset.

Source code in zenml/integrations/lightgbm/materializers/lightgbm_dataset_materializer.py
class LightGBMDatasetMaterializer(BaseMaterializer):
    """Materializer to read data to and from lightgbm.Dataset."""

    ASSOCIATED_TYPES = (lgb.Dataset,)
    ASSOCIATED_ARTIFACT_TYPES = (DataArtifact,)

    def handle_input(self, data_type: Type[Any]) -> lgb.Dataset:
        """Reads a lightgbm.Dataset binary file and loads it.

        Args:
            data_type: A lightgbm.Dataset type.

        Returns:
            A lightgbm.Dataset object.
        """
        super().handle_input(data_type)
        filepath = os.path.join(self.artifact.uri, DEFAULT_FILENAME)

        # Create a temporary folder
        temp_dir = tempfile.mkdtemp(prefix="zenml-temp-")
        temp_file = os.path.join(temp_dir, DEFAULT_FILENAME)

        # Copy from artifact store to temporary file
        fileio.copy(filepath, temp_file)
        matrix = lgb.Dataset(temp_file, free_raw_data=False)

        # No clean up here: the Dataset is built from the file path, which
        # is only read lazily, so the file must outlive this call.
        # NOTE(review): the temporary directory is therefore never removed.
        return matrix

    def handle_return(self, matrix: lgb.Dataset) -> None:
        """Creates a binary serialization for a lightgbm.Dataset object.

        The dataset is written with `save_binary` into a local temporary
        directory, copied into the artifact store, and the temporary
        directory is removed even if saving or copying fails.

        Args:
            matrix: A lightgbm.Dataset object.
        """
        super().handle_return(matrix)
        filepath = os.path.join(self.artifact.uri, DEFAULT_FILENAME)

        temp_dir = tempfile.mkdtemp(prefix="zenml-temp-")
        try:
            temp_file = os.path.join(temp_dir, DEFAULT_FILENAME)
            matrix.save_binary(temp_file)
            # Copy it into artifact store
            fileio.copy(temp_file, filepath)
        finally:
            fileio.rmtree(temp_dir)
handle_input(self, data_type)

Reads a lightgbm.Dataset binary file and loads it.

Parameters:

Name Type Description Default
data_type Type[Any]

A lightgbm.Dataset type.

required

Returns:

Type Description
Dataset

A lightgbm.Dataset object.

Source code in zenml/integrations/lightgbm/materializers/lightgbm_dataset_materializer.py
def handle_input(self, data_type: Type[Any]) -> lgb.Dataset:
    """Load a lightgbm.Dataset from the binary file in the artifact store.

    Args:
        data_type: A lightgbm.Dataset type.

    Returns:
        A lightgbm.Dataset object backed by a local copy of the file.
    """
    super().handle_input(data_type)
    source_path = os.path.join(self.artifact.uri, DEFAULT_FILENAME)

    local_dir = tempfile.mkdtemp(prefix="zenml-temp-")
    local_copy = os.path.join(str(local_dir), DEFAULT_FILENAME)
    fileio.copy(source_path, local_copy)

    # The local copy is deliberately left on disk: the Dataset is built
    # from the file path and reads it lazily.
    # NOTE(review): this means the temporary directory is never removed.
    return lgb.Dataset(local_copy, free_raw_data=False)
handle_return(self, matrix)

Creates a binary serialization for a lightgbm.Dataset object.

Parameters:

Name Type Description Default
matrix Dataset

A lightgbm.Dataset object.

required
Source code in zenml/integrations/lightgbm/materializers/lightgbm_dataset_materializer.py
def handle_return(self, matrix: lgb.Dataset) -> None:
    """Creates a binary serialization for a lightgbm.Dataset object.

    The dataset is written with `save_binary` into a local temporary
    directory, copied into the artifact store, and the temporary directory
    is removed even if saving or copying fails.

    Args:
        matrix: A lightgbm.Dataset object.
    """
    super().handle_return(matrix)
    filepath = os.path.join(self.artifact.uri, DEFAULT_FILENAME)

    temp_dir = tempfile.mkdtemp(prefix="zenml-temp-")
    try:
        temp_file = os.path.join(temp_dir, DEFAULT_FILENAME)
        matrix.save_binary(temp_file)
        # Copy it into artifact store
        fileio.copy(temp_file, filepath)
    finally:
        fileio.rmtree(temp_dir)

mlflow special

Initialization for the ZenML MLflow integration.

The MLflow integration currently enables you to use MLflow tracking as a convenient way to visualize your experiment runs within the MLflow UI.

MlflowIntegration (Integration)

Definition of MLflow integration for ZenML.

Source code in zenml/integrations/mlflow/__init__.py
class MlflowIntegration(Integration):
    """ZenML integration definition for MLflow.

    Contributes an MLflow model deployer flavor and an MLflow experiment
    tracker flavor.
    """

    NAME = MLFLOW
    REQUIREMENTS = [
        "mlflow>=1.2.0,<1.26.0",
        "mlserver>=0.5.3",
        "mlserver-mlflow>=0.5.3",
    ]

    @classmethod
    def activate(cls) -> None:
        """Activate the integration by importing its services module."""
        import zenml.integrations.mlflow.services  # noqa

    @classmethod
    def flavors(cls) -> List[FlavorWrapper]:
        """List the stack component flavors provided by MLflow.

        Returns:
            The model deployer flavor followed by the experiment tracker
            flavor.
        """
        model_deployer = FlavorWrapper(
            name=MLFLOW_MODEL_DEPLOYER_FLAVOR,
            source="zenml.integrations.mlflow.model_deployers.MLFlowModelDeployer",
            type=StackComponentType.MODEL_DEPLOYER,
            integration=cls.NAME,
        )
        experiment_tracker = FlavorWrapper(
            name=MLFLOW_MODEL_EXPERIMENT_TRACKER_FLAVOR,
            source="zenml.integrations.mlflow.experiment_trackers.MLFlowExperimentTracker",
            type=StackComponentType.EXPERIMENT_TRACKER,
            integration=cls.NAME,
        )
        return [model_deployer, experiment_tracker]
activate() classmethod

Activate the MLflow integration.

Source code in zenml/integrations/mlflow/__init__.py
@classmethod
def activate(cls) -> None:
    """Activate the integration by importing its services module."""
    import zenml.integrations.mlflow.services  # noqa
flavors() classmethod

Declare the stack component flavors for the MLflow integration.

Returns:

Type Description
List[zenml.zen_stores.models.flavor_wrapper.FlavorWrapper]

List of stack component flavors for this integration.

Source code in zenml/integrations/mlflow/__init__.py
@classmethod
def flavors(cls) -> List[FlavorWrapper]:
    """Declare the stack component flavors for the MLflow integration.

    Returns:
        List of stack component flavors for this integration.
    """
    return [
        FlavorWrapper(
            name=MLFLOW_MODEL_DEPLOYER_FLAVOR,
            source="zenml.integrations.mlflow.model_deployers.MLFlowModelDeployer",
            type=StackComponentType.MODEL_DEPLOYER,
            integration=cls.NAME,
        ),
        FlavorWrapper(
            name=MLFLOW_MODEL_EXPERIMENT_TRACKER_FLAVOR,
            source="zenml.integrations.mlflow.experiment_trackers.MLFlowExperimentTracker",
            type=StackComponentType.EXPERIMENT_TRACKER,
            integration=cls.NAME,
        ),
    ]

experiment_trackers special

Initialization of the MLflow experiment tracker.

mlflow_experiment_tracker

Implementation of the MLflow experiment tracker for ZenML.

MLFlowExperimentTracker (BaseExperimentTracker) pydantic-model

Stores Mlflow configuration options.

ZenML should take care of configuring MLflow for you, but should you still need access to the configuration inside your step, you can get it from the step context:

from zenml.steps import StepContext

@enable_mlflow
@step
def my_step(context: StepContext, ...):
    context.stack.experiment_tracker  # get the tracking_uri etc. from here

Attributes:

Name Type Description
tracking_uri Optional[str]

The URI of the MLflow tracking server. If no URI is set, your stack must contain a LocalArtifactStore and ZenML will point MLflow to a subdirectory of your artifact store instead.

tracking_username Optional[str]

Username for authenticating with the MLflow tracking server. When a remote tracking uri is specified, either tracking_token or tracking_username and tracking_password must be specified.

tracking_password Optional[str]

Password for authenticating with the MLflow tracking server. When a remote tracking uri is specified, either tracking_token or tracking_username and tracking_password must be specified.

tracking_token Optional[str]

Token for authenticating with the MLflow tracking server. When a remote tracking uri is specified, either tracking_token or tracking_username and tracking_password must be specified.

tracking_insecure_tls bool

Skips verification of TLS connection to the MLflow tracking server if set to True.

Source code in zenml/integrations/mlflow/experiment_trackers/mlflow_experiment_tracker.py
class MLFlowExperimentTracker(BaseExperimentTracker):
    """Stores MLflow configuration options.

    ZenML should take care of configuring MLflow for you, but should you still
    need access to the configuration inside your step you can do it using a
    step context:
    ```python
    from zenml.steps import StepContext

    @enable_mlflow
    @step
    def my_step(context: StepContext, ...):
        context.stack.experiment_tracker  # get the tracking_uri etc. from here
    ```

    Attributes:
        tracking_uri: The uri of the mlflow tracking server. If no uri is set,
            your stack must contain a `LocalArtifactStore` and ZenML will
            point MLflow to a subdirectory of your artifact store instead.
        tracking_username: Username for authenticating with the MLflow
            tracking server. When a remote tracking uri is specified,
            either `tracking_token` or `tracking_username` and
            `tracking_password` must be specified.
        tracking_password: Password for authenticating with the MLflow
            tracking server. When a remote tracking uri is specified,
            either `tracking_token` or `tracking_username` and
            `tracking_password` must be specified.
        tracking_token: Token for authenticating with the MLflow
            tracking server. When a remote tracking uri is specified,
            either `tracking_token` or `tracking_username` and
            `tracking_password` must be specified.
        tracking_insecure_tls: Skips verification of TLS connection to the
            MLflow tracking server if set to `True`.
    """

    tracking_uri: Optional[str] = None
    tracking_username: Optional[str] = None
    tracking_password: Optional[str] = None
    tracking_token: Optional[str] = None
    tracking_insecure_tls: bool = False

    # Class Configuration
    FLAVOR: ClassVar[str] = MLFLOW_MODEL_EXPERIMENT_TRACKER_FLAVOR

    # NOTE: this decorator must stay above the `validator` property defined
    # further down; that property rebinds the name `validator` in the class
    # namespace once it is executed.
    @validator("tracking_uri")
    def _ensure_valid_tracking_uri(
        cls, tracking_uri: Optional[str] = None
    ) -> Optional[str]:
        """Ensures that the tracking uri is a valid mlflow tracking uri.

        Args:
            tracking_uri: The tracking uri to validate.

        Returns:
            The tracking uri if it is valid.

        Raises:
            ValueError: If the tracking uri is not valid.
        """
        if tracking_uri:
            valid_schemes = DATABASE_ENGINES + ["http", "https", "file"]
            # `str.startswith` accepts a tuple of prefixes; the list itself is
            # kept for the error message below.
            if not tracking_uri.startswith(tuple(valid_schemes)):
                raise ValueError(
                    f"MLflow tracking uri does not start with one of the valid "
                    f"schemes {valid_schemes}. See "
                    f"https://www.mlflow.org/docs/latest/tracking.html#where-runs-are-recorded "
                    f"for more information."
                )
        return tracking_uri

    @root_validator(skip_on_failure=True)
    def _ensure_authentication_if_necessary(
        cls, values: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Ensures that credentials or a token for authentication exist.

        We make this check when running MLflow tracking with a remote backend.

        Args:
            values: The values to validate.

        Returns:
            The validated values.

        Raises:
            ValueError: If neither credentials nor a token are provided.
        """
        tracking_uri = values.get("tracking_uri")

        if tracking_uri and cls.is_remote_tracking_uri(tracking_uri):
            # we need either username + password or a token to authenticate to
            # the remote backend
            basic_auth = values.get("tracking_username") and values.get(
                "tracking_password"
            )
            token_auth = values.get("tracking_token")

            if not (basic_auth or token_auth):
                raise ValueError(
                    f"MLflow experiment tracking with a remote backend "
                    f"{tracking_uri} is only possible when specifying either "
                    f"username and password or an authentication token in your "
                    f"stack component. To update your component, run the "
                    f"following command: `zenml experiment-tracker update "
                    f"{values['name']} --tracking_username=MY_USERNAME "
                    f"--tracking_password=MY_PASSWORD "
                    f"--tracking_token=MY_TOKEN` and specify either your "
                    f"username and password or token."
                )

        return values

    @staticmethod
    def is_remote_tracking_uri(tracking_uri: str) -> bool:
        """Checks whether the given tracking uri is remote or not.

        Args:
            tracking_uri: The tracking uri to check.

        Returns:
            `True` if the tracking uri starts with `http://` or `https://`,
            `False` otherwise.
        """
        return tracking_uri.startswith(("http://", "https://"))

    @staticmethod
    def _local_mlflow_backend() -> str:
        """Gets the local MLflow backend inside the ZenML artifact repository directory.

        The `mlruns` directory is created if it does not exist yet.

        Returns:
            The MLflow tracking URI for the local MLflow backend.
        """
        repo = Repository(skip_repository_check=True)  # type: ignore[call-arg]
        artifact_store = repo.active_stack.artifact_store
        local_mlflow_backend_uri = os.path.join(artifact_store.path, "mlruns")
        # `exist_ok=True` avoids the check-then-create race when several steps
        # initialise the local backend at the same time.
        os.makedirs(local_mlflow_backend_uri, exist_ok=True)
        return "file:" + local_mlflow_backend_uri

    def get_tracking_uri(self) -> str:
        """Returns the configured tracking URI or a local fallback.

        Returns:
            The tracking URI.
        """
        return self.tracking_uri or self._local_mlflow_backend()

    def configure_mlflow(self) -> None:
        """Configures the MLflow tracking URI and any additional credentials."""
        mlflow.set_tracking_uri(self.get_tracking_uri())

        if self.tracking_username:
            os.environ[MLFLOW_TRACKING_USERNAME] = self.tracking_username
        if self.tracking_password:
            os.environ[MLFLOW_TRACKING_PASSWORD] = self.tracking_password
        if self.tracking_token:
            os.environ[MLFLOW_TRACKING_TOKEN] = self.tracking_token
        os.environ[MLFLOW_TRACKING_INSECURE_TLS] = (
            "true" if self.tracking_insecure_tls else "false"
        )

    def prepare_step_run(self) -> None:
        """Sets the MLflow tracking uri and credentials."""
        self.configure_mlflow()

    def cleanup_step_run(self) -> None:
        """Resets the MLflow tracking uri."""
        mlflow.set_tracking_uri("")

    @property
    def local_path(self) -> Optional[str]:
        """Path to the local directory where the MLflow artifacts are stored.

        Returns:
            The path to the local MLflow artifact store directory if the
            tracking URI uses the `file:` scheme, otherwise `None` (remote
            HTTP(S) servers and database backends have no local directory).
        """
        file_scheme = "file:"
        tracking_uri = self.get_tracking_uri()
        if not tracking_uri.startswith(file_scheme):
            return None
        return tracking_uri[len(file_scheme) :]

    @property
    def validator(self) -> Optional["StackValidator"]:
        """Checks the stack has a `LocalArtifactStore` if no tracking uri was specified.

        Returns:
            An optional `StackValidator`.
        """
        if self.tracking_uri:
            # user specified a tracking uri, do nothing
            return None
        else:
            # try to fall back to a tracking uri inside the zenml artifact
            # store. this only works in case of a local artifact store, so we
            # make sure to prevent stack with other artifact stores for now
            return StackValidator(
                custom_validation_function=lambda stack: (
                    isinstance(stack.artifact_store, LocalArtifactStore),
                    "MLflow experiment tracker without a specified tracking "
                    "uri only works with a local artifact store.",
                )
            )

    @property
    def active_experiment(self) -> Optional[Experiment]:
        """Returns the currently active MLflow experiment.

        The experiment is named after the pipeline and is activated via
        `mlflow.set_experiment` as a side effect.

        Returns:
            The active experiment or `None` if no experiment is active.
        """
        step_env = Environment().step_environment

        if not step_env:
            # we're not inside a step
            return None

        mlflow.set_experiment(experiment_name=step_env.pipeline_name)
        return mlflow.get_experiment_by_name(step_env.pipeline_name)

    def _find_active_run(
        self,
    ) -> Tuple[Optional[mlflow.ActiveRun], Optional[str], Optional[str]]:
        """Find the currently active MLflow run.

        Returns:
            The active MLflow run, the experiment id and the run id. All three
            are `None` when called outside of a step.
        """
        step_env = Environment().step_environment
        if not step_env:
            return None, None, None

        # `active_experiment` calls `mlflow.set_experiment` and queries the
        # tracking backend, so evaluate it only once.
        experiment = self.active_experiment
        if not experiment:
            return None, None, None

        experiment_id = experiment.experiment_id

        # TODO [ENG-458]: find a solution to avoid race-conditions while
        #  creating the same MLflow run from parallel steps
        runs = mlflow.search_runs(
            experiment_ids=[experiment_id],
            filter_string=f'tags.mlflow.runName = "{step_env.pipeline_run_id}"',
            output_format="list",
        )

        run_id = runs[0].info.run_id if runs else None

        current_active_run = mlflow.active_run()
        if not (
            current_active_run and current_active_run.info.run_id == run_id
        ):
            current_active_run = None

        return current_active_run, experiment_id, run_id

    @property
    def active_run(self) -> Optional[mlflow.ActiveRun]:
        """Returns the currently active MLflow run.

        If no matching run is active, the run named after the pipeline run is
        (re)started.

        Returns:
            The active MLflow run, or `None` when called outside of a step.
        """
        step_env = Environment().step_environment
        if not step_env:
            # Outside of a step there is no pipeline run to attach to; the
            # previous code crashed here on `step_env.pipeline_run_id`.
            return None

        current_active_run, experiment_id, run_id = self._find_active_run()
        if current_active_run:
            return current_active_run
        return mlflow.start_run(
            run_id=run_id,
            run_name=step_env.pipeline_run_id,
            experiment_id=experiment_id,
        )

    @property
    def active_nested_run(self) -> Optional[mlflow.ActiveRun]:
        """Returns a nested run in the currently active MLflow run.

        Returns:
            The nested MLflow run named after the step, or `None` if there is
            no active parent run.
        """
        step_env = Environment().step_environment
        current_active_run, _, _ = self._find_active_run()
        if not current_active_run:
            return None
        return mlflow.start_run(run_name=step_env.step_name, nested=True)
active_experiment: Optional[mlflow.entities.experiment.Experiment] property readonly

Returns the currently active MLflow experiment.

Returns:

Type Description
Optional[mlflow.entities.experiment.Experiment]

The active experiment or None if no experiment is active.

active_nested_run: Optional[mlflow.tracking.fluent.ActiveRun] property readonly

Returns a nested run in the currently active MLflow run.

Returns:

Type Description
Optional[mlflow.tracking.fluent.ActiveRun]

The nested MLflow run.

active_run: Optional[mlflow.tracking.fluent.ActiveRun] property readonly

Returns the currently active MLflow run.

Returns:

Type Description
Optional[mlflow.tracking.fluent.ActiveRun]

The active MLflow run.

local_path: Optional[str] property readonly

Path to the local directory where the MLflow artifacts are stored.

Returns:

Type Description
Optional[str]

None if configured with a remote tracking URI, otherwise the path to the local MLflow artifact store directory.

validator: Optional[StackValidator] property readonly

Checks the stack has a LocalArtifactStore if no tracking uri was specified.

Returns:

Type Description
Optional[StackValidator]

An optional StackValidator.

cleanup_step_run(self)

Resets the MLflow tracking uri.

Source code in zenml/integrations/mlflow/experiment_trackers/mlflow_experiment_tracker.py
def cleanup_step_run(self) -> None:
    """Reset the MLflow tracking URI after a step has run.

    The tracking URI is overwritten with an empty string. Environment
    variables holding credentials are not touched here.
    """
    empty_uri = ""
    mlflow.set_tracking_uri(empty_uri)
configure_mlflow(self)

Configures the MLflow tracking URI and any additional credentials.

Source code in zenml/integrations/mlflow/experiment_trackers/mlflow_experiment_tracker.py
def configure_mlflow(self) -> None:
    """Point MLflow at the tracking backend and export any credentials.

    The tracking URI comes from ``get_tracking_uri``. Username, password and
    token are exported as environment variables only when they are set. The
    insecure-TLS flag is always exported, as ``"true"`` or ``"false"``.
    """
    mlflow.set_tracking_uri(self.get_tracking_uri())

    optional_credentials = (
        (MLFLOW_TRACKING_USERNAME, self.tracking_username),
        (MLFLOW_TRACKING_PASSWORD, self.tracking_password),
        (MLFLOW_TRACKING_TOKEN, self.tracking_token),
    )
    for env_var, value in optional_credentials:
        if value:
            os.environ[env_var] = value

    insecure_flag = "true" if self.tracking_insecure_tls else "false"
    os.environ[MLFLOW_TRACKING_INSECURE_TLS] = insecure_flag
get_tracking_uri(self)

Returns the configured tracking URI or a local fallback.

Returns:

Type Description
str

The tracking URI.

Source code in zenml/integrations/mlflow/experiment_trackers/mlflow_experiment_tracker.py
def get_tracking_uri(self) -> str:
    """Resolve the tracking URI MLflow should use.

    Returns:
        The configured ``tracking_uri`` when it is non-empty, otherwise the
        URI produced by ``_local_mlflow_backend``.
    """
    if self.tracking_uri:
        return self.tracking_uri
    return self._local_mlflow_backend()
is_remote_tracking_uri(tracking_uri) staticmethod

Checks whether the given tracking uri is remote or not.

Parameters:

Name Type Description Default
tracking_uri str

The tracking uri to check.

required

Returns:

Type Description
bool

True if the tracking uri is remote, False otherwise.

Source code in zenml/integrations/mlflow/experiment_trackers/mlflow_experiment_tracker.py
@staticmethod
def is_remote_tracking_uri(tracking_uri: str) -> bool:
    """Tell whether a tracking URI points at a remote HTTP(S) server.

    Args:
        tracking_uri: The tracking URI to inspect.

    Returns:
        `True` if the URI starts with `http://` or `https://`, `False`
        otherwise.
    """
    remote_prefixes = ("http://", "https://")
    return tracking_uri.startswith(remote_prefixes)
prepare_step_run(self)

Sets the MLflow tracking uri and credentials.

Source code in zenml/integrations/mlflow/experiment_trackers/mlflow_experiment_tracker.py
def prepare_step_run(self) -> None:
    """Configure MLflow before a step runs.

    Delegates to ``configure_mlflow``, which sets the tracking URI and
    exports the configured credentials.
    """
    configure = self.configure_mlflow
    configure()

mlflow_step_decorator

Implementation of the MLflow StepDecorator.

enable_mlflow(_step=None, nested=False)

Decorator to enable mlflow for a step function.

Apply this decorator to a ZenML pipeline step to enable MLflow experiment tracking. The MLflow tracking configuration (tracking URI, experiment name, run name) will be automatically configured before the step code is executed, so the step can simply use the mlflow module to log metrics and artifacts.

The simple usage will log metrics into a run created for the pipeline, like so:

@enable_mlflow
@step
def tf_evaluator(
    x_test: np.ndarray,
    y_test: np.ndarray,
    model: tf.keras.Model,
) -> float:
    _, test_acc = model.evaluate(x_test, y_test, verbose=2)
    mlflow.log_metric("val_accuracy", test_acc)
    return test_acc

You can also log parameters, metrics and artifacts into nested runs, which will be children of the pipeline run. You only need to add the parameter nested=True to the decorator, like so:

@enable_mlflow(nested=True)
@step
def tf_evaluator(
    x_test: np.ndarray,
    y_test: np.ndarray,
    model: tf.keras.Model,
) -> float:
    _, test_acc = model.evaluate(x_test, y_test, verbose=2)
    mlflow.log_param("some_param", 2)
    mlflow.log_metric("val_accuracy", test_acc)
    return test_acc

You can also use this decorator with our class-based API like so:

@enable_mlflow
class TFEvaluator(BaseStep):
    def entrypoint(
        self,
        x_test: np.ndarray,
        y_test: np.ndarray,
        model: tf.keras.Model,
    ) -> float:
        ...

All MLflow artifacts and metrics logged from all the steps in a pipeline run are grouped under a single experiment named after the pipeline. The decorator only accepts the _step and nested arguments, so a custom experiment name cannot be passed to it.

Parameters:

Name Type Description Default
_step Optional[~S]

The decorated step class.

None
nested bool

Controls whether to create a run as a child of the pipeline run. All MLflow logging calls made during a step with nested=True are logged into the child run.

False

Returns:

Type Description
Union[~S, Callable[[~S], ~S]]

The inner decorator which enhances the input step class with mlflow tracking functionality

Source code in zenml/integrations/mlflow/mlflow_step_decorator.py
def enable_mlflow(
    _step: Optional[S] = None, nested: bool = False
) -> Union[S, Callable[[S], S]]:
    """Decorator to enable mlflow for a step function.

    Apply this decorator to a ZenML pipeline step to enable MLflow experiment
    tracking. The MLflow tracking configuration (tracking URI, experiment name,
    run name) will be automatically configured before the step code is executed,
    so the step can simply use the `mlflow` module to log metrics and artifacts.

    The simple usage will log metrics into a run created for the pipeline, like
    so:

    ```python
    @enable_mlflow
    @step
    def tf_evaluator(
        x_test: np.ndarray,
        y_test: np.ndarray,
        model: tf.keras.Model,
    ) -> float:
        _, test_acc = model.evaluate(x_test, y_test, verbose=2)
        mlflow.log_metric("val_accuracy", test_acc)
        return test_acc
    ```
    You can also log parameters, metrics and artifacts into nested runs, which
    will be children of the pipeline run. You only need to add the parameter
    `nested=True` to the decorator, like so:

    ```python
    @enable_mlflow(nested=True)
    @step
    def tf_evaluator(
        x_test: np.ndarray,
        y_test: np.ndarray,
        model: tf.keras.Model,
    ) -> float:
        _, test_acc = model.evaluate(x_test, y_test, verbose=2)
        mlflow.log_param("some_param", 2)
        mlflow.log_metric("val_accuracy", test_acc)
        return test_acc
    ```
    You can also use this decorator with our class-based API like so:

    ```python
    @enable_mlflow
    class TFEvaluator(BaseStep):
        def entrypoint(
            self,
            x_test: np.ndarray,
            y_test: np.ndarray,
            model: tf.keras.Model,
        ) -> float:
            ...
    ```

    All MLflow artifacts and metrics logged from all the steps in a pipeline
    run are grouped under a single experiment named after the pipeline.

    Args:
        _step: The decorated step class.
        nested: Controls whether to create a run as a child of the pipeline
            run. All MLflow logging calls made during a step with
            `nested=True` are logged into the child run.

    Returns:
        The decorated step class if `_step` was given, otherwise a decorator
        that enhances a step class with MLflow tracking functionality.

    Raises:
        RuntimeError: If the decorator is applied to anything other than a
            `BaseStep` subclass (i.e. a `@step`-decorated function or a
            class-based step).
    """

    def inner_decorator(_step: S) -> S:
        # `issubclass` raises a TypeError for non-class arguments (e.g. a plain
        # function that was not wrapped with `@step`), so check for a class
        # first to surface the explanatory error below instead.
        if not (isinstance(_step, type) and issubclass(_step, BaseStep)):
            raise RuntimeError(
                "The `enable_mlflow` decorator can only be applied to a ZenML "
                "`step` decorated function or a BaseStep subclass."
            )
        logger.debug(
            "Applying 'enable_mlflow' decorator to step %s", _step.__name__
        )

        source_fn = getattr(_step, STEP_INNER_FUNC_NAME)
        new_entrypoint = mlflow_step_entrypoint(nested=nested)(source_fn)
        if _step._created_by_functional_api():
            # If the step was created by the functional API, the old entrypoint
            # was a static method -> make sure the new one is as well
            new_entrypoint = staticmethod(new_entrypoint)

        setattr(_step, STEP_INNER_FUNC_NAME, new_entrypoint)
        return _step

    if _step is None:
        return inner_decorator
    else:
        return inner_decorator(_step)
mlflow_step_entrypoint(nested=False)

Decorator for a step entrypoint to enable mlflow.

Parameters:

Name Type Description Default
nested bool

Controls whether to create a run as a child of the pipeline run. All MLflow logging calls made during a step with nested=True are logged into the child run.

False

Returns:

Type Description
Callable[[~F], ~F]

A decorator that enhances the input function with MLflow tracking functionality.

Source code in zenml/integrations/mlflow/mlflow_step_decorator.py
def mlflow_step_entrypoint(nested: bool = False) -> Callable[[F], F]:
    """Build a decorator that runs a step entrypoint inside an MLflow run.

    Args:
        nested: Whether to additionally open a child run of the pipeline run.
            When `True`, all MLflow logging calls made by the step are
            recorded in that child run.

    Returns:
        A decorator that wraps a step entrypoint with MLflow tracking.
    """

    def inner_decorator(func: F) -> F:

        logger.debug(
            "Applying 'mlflow_step_entrypoint' decorator to step entrypoint %s",
            func.__name__,
        )

        @functools.wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:  # noqa
            logger.debug(
                "Setting up MLflow backend before running step entrypoint %s",
                func.__name__,
            )
            active_stack = Repository(  # type: ignore[call-arg]
                skip_repository_check=True
            ).active_stack
            tracker = active_stack.experiment_tracker

            if not isinstance(tracker, MLFlowExperimentTracker):
                raise get_missing_mlflow_experiment_tracker_error()

            # The pipeline-level run must exist before anything else happens.
            parent_run = tracker.active_run
            if not parent_run:
                raise RuntimeError("No active mlflow run configured.")

            if not nested:
                with parent_run:
                    return func(*args, **kwargs)

            # A parent run exists at this point, so a child run is expected;
            # the guard keeps type checkers happy and fails loudly otherwise.
            child_run = tracker.active_nested_run
            if not child_run:
                raise RuntimeError(
                    "No active mlflow run configured to create a nested run."
                )
            with parent_run, child_run:
                return func(*args, **kwargs)

        return cast(F, wrapper)

    return inner_decorator

mlflow_utils

Implementation of utils specific to the MLflow integration.

get_missing_mlflow_experiment_tracker_error()

Returns description of how to add an MLflow experiment tracker to your stack.

Returns:

Type Description
ValueError

An error explaining how to register an MLflow experiment tracker and add it to the active stack (returned, not raised).

Source code in zenml/integrations/mlflow/mlflow_utils.py
def get_missing_mlflow_experiment_tracker_error() -> ValueError:
    """Build the error shown when the active stack lacks an MLflow tracker.

    The error is returned rather than raised, so callers write
    ``raise get_missing_mlflow_experiment_tracker_error()``.

    Returns:
        A new `ValueError` whose message explains how to register an MLflow
        experiment tracker and add it to a stack.
    """
    explanation = (
        "The active stack needs to have a MLflow experiment tracker "
        "component registered to be able to track experiments using "
        "MLflow. You can create a new stack with a MLflow experiment "
        "tracker component or update your existing stack to add this "
        "component, e.g.:\n\n"
    )
    example_commands = [
        "  'zenml experiment-tracker register mlflow_tracker --type=mlflow'",
        "  'zenml stack register stack-name -e mlflow_tracker ...'",
    ]
    message = explanation + "".join(
        f"{command}\n" for command in example_commands
    )
    return ValueError(message)
get_tracking_uri()

Gets the MLflow tracking URI from the active experiment tracking stack component.

noqa: DAR401

Returns:

Type Description
str

MLflow tracking URI.

Source code in zenml/integrations/mlflow/mlflow_utils.py
def get_tracking_uri() -> str:
    """Look up the tracking URI of the active stack's MLflow experiment tracker.

    # noqa: DAR401

    If the active stack has no MLflow experiment tracker, the error built by
    `get_missing_mlflow_experiment_tracker_error` is raised.

    Returns:
        MLflow tracking URI.
    """
    active_stack = Repository().active_stack
    tracker = active_stack.experiment_tracker
    # `isinstance` is already False for `None`, so this covers a missing
    # tracker as well as a tracker of a different flavor.
    if not isinstance(tracker, MLFlowExperimentTracker):
        raise get_missing_mlflow_experiment_tracker_error()
    return tracker.get_tracking_uri()

model_deployers special

Initialization of the MLflow model deployers.

mlflow_model_deployer

Implementation of the MLflow model deployer.

MLFlowModelDeployer (BaseModelDeployer) pydantic-model

MLflow implementation of the BaseModelDeployer.

Attributes:

Name Type Description
service_path str

the path where the local MLflow deployment service configuration, PID and log files are stored.

Source code in zenml/integrations/mlflow/model_deployers/mlflow_model_deployer.py
class MLFlowModelDeployer(BaseModelDeployer):
    """MLflow implementation of the BaseModelDeployer.

    Attributes:
        service_path: the path where the local MLflow deployment service
            configuration, PID and log files are stored.
    """

    service_path: str = ""

    # Class Configuration
    FLAVOR: ClassVar[str] = MLFLOW_MODEL_DEPLOYER_FLAVOR

    @root_validator(skip_on_failure=True)
    def set_service_path(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Sets the service_path attribute value according to the component UUID.

        Args:
            values: the dictionary of values to be validated.

        Returns:
            The validated dictionary of values.
        """
        if values.get("service_path"):
            return values

        # not likely to happen, due to Pydantic validation, but mypy complains
        assert "uuid" in values

        values["service_path"] = cls.get_service_path(values["uuid"])
        return values

    @staticmethod
    def get_service_path(uuid: uuid.UUID) -> str:
        """Get the path where local MLflow service information is stored.

        This includes the deployment service configuration, PID and log files.
        The directory is created if it does not exist yet.

        Args:
            uuid: The UUID of the MLflow model deployer.

        Returns:
            The service path.
        """
        service_path = os.path.join(
            get_global_config_directory(),
            LOCAL_STORES_DIRECTORY_NAME,
            str(uuid),
        )
        create_dir_recursive_if_not_exists(service_path)
        return service_path

    @property
    def local_path(self) -> str:
        """Returns the path to the root directory.

        This is where all configurations for MLflow deployment daemon processes are stored.

        Returns:
            The path to the local service root directory.
        """
        return self.service_path

    @staticmethod
    def get_model_server_info(  # type: ignore[override]
        service_instance: "MLFlowDeploymentService",
    ) -> Dict[str, Optional[str]]:
        """Return implementation specific information relevant to the user.

        Args:
            service_instance: Instance of an MLFlowDeploymentService.

        Returns:
            A dictionary containing the prediction URL, model URI, model
            name, service runtime path and daemon PID (as a string).
        """
        return {
            "PREDICTION_URL": service_instance.endpoint.prediction_url,
            "MODEL_URI": service_instance.config.model_uri,
            "MODEL_NAME": service_instance.config.model_name,
            "SERVICE_PATH": service_instance.status.runtime_path,
            "DAEMON_PID": str(service_instance.status.pid),
        }

    @staticmethod
    def get_active_model_deployer() -> "MLFlowModelDeployer":
        """Returns the MLFlowModelDeployer component of the active stack.

        Returns:
            The MLFlowModelDeployer component of the active stack.

        Raises:
            TypeError: If the active stack does not contain an MLFlowModelDeployer component.
        """
        model_deployer = Repository(  # type: ignore[call-arg]
            skip_repository_check=True
        ).active_stack.model_deployer

        if not model_deployer or not isinstance(
            model_deployer, MLFlowModelDeployer
        ):
            raise TypeError(
                f"The active stack needs to have an MLflow model deployer "
                f"component registered to be able to deploy models with MLflow. "
                f"You can create a new stack with an MLflow model "
                f"deployer component or update your existing stack to add this "
                f"component, e.g.:\n\n"
                f"  'zenml model-deployer register mlflow --flavor={MLFLOW_MODEL_DEPLOYER_FLAVOR}'\n"
                f"  'zenml stack register stack-name -d mlflow ...'\n"
            )
        return model_deployer

    def deploy_model(
        self,
        config: ServiceConfig,
        replace: bool = False,
        timeout: int = DEFAULT_SERVICE_START_STOP_TIMEOUT,
    ) -> BaseService:
        """Create a new MLflow deployment service or update an existing one.

        This should serve the supplied model and deployment configuration.

        This method has two modes of operation, depending on the `replace`
        argument value:

          * if `replace` is False, calling this method will create a new MLflow
            deployment server to reflect the model and other configuration
            parameters specified in the supplied MLflow service `config`.

          * if `replace` is True, this method will first attempt to find an
            existing MLflow deployment service that is *equivalent* to the
            supplied configuration parameters. Two or more MLflow deployment
            services are considered equivalent if they have the same
            `pipeline_name`, `pipeline_step_name` and `model_name` configuration
            parameters. To put it differently, two MLflow deployment services
            are equivalent if they serve versions of the same model deployed by
            the same pipeline step. If an equivalent MLflow deployment is found,
            it will be updated in place to reflect the new configuration
            parameters.

        Callers should set `replace` to True if they want a continuous model
        deployment workflow that doesn't spin up a new MLflow deployment
        server for each new model version. If multiple equivalent MLflow
        deployment servers are found, the first one found (in no guaranteed
        order) is updated and the others are stopped and deleted.

        Args:
            config: the configuration of the model to be deployed with MLflow.
            replace: set this flag to True to find and update an equivalent
                MLflow deployment server with the new model instead of
                creating and starting a new deployment server.
            timeout: the timeout in seconds to wait for the MLflow server
                to be provisioned and successfully started or updated. If set
                to 0, the method will return immediately after the MLflow
                server is provisioned, without waiting for it to fully start.

        Returns:
            The ZenML MLflow deployment service object that can be used to
            interact with the MLflow model server.
        """
        config = cast(MLFlowDeploymentConfig, config)
        service = None

        # if replace is True, reuse one equivalent service and remove the rest
        if replace is True:
            existing_services = self.find_model_server(
                pipeline_name=config.pipeline_name,
                pipeline_step_name=config.pipeline_step_name,
                model_name=config.model_name,
            )

            for existing_service in existing_services:
                if service is None:
                    # keep the first equivalent service and update it in
                    # place below; its runtime directory must NOT be deleted
                    service = cast(MLFlowDeploymentService, existing_service)
                    continue
                try:
                    # stop and delete every other equivalent service
                    self._clean_up_existing_service(
                        existing_service=cast(
                            MLFlowDeploymentService, existing_service
                        ),
                        timeout=timeout,
                        force=True,
                    )
                except (RuntimeError, OSError) as e:
                    # best effort: a stale service that cannot be cleaned up
                    # must not block the new deployment
                    logger.warning(
                        "Failed to clean up existing MLflow deployment "
                        "service %s: %s",
                        existing_service,
                        e,
                    )
        if service:
            logger.info(
                f"Updating an existing MLflow deployment service: {service}"
            )

            # set the root runtime path with the stack component's UUID
            config.root_runtime_path = self.local_path
            service.stop(timeout=timeout, force=True)
            service.update(config)
            service.start(timeout=timeout)
        else:
            # create a new MLFlowDeploymentService instance
            service = self._create_new_service(timeout, config)
            logger.info(f"Created a new MLflow deployment service: {service}")

        return cast(BaseService, service)

    def _clean_up_existing_service(
        self,
        timeout: int,
        force: bool,
        existing_service: MLFlowDeploymentService,
    ) -> None:
        """Stop an existing MLflow deployment service and delete its files.

        Args:
            timeout: the timeout in seconds to wait for the service to stop.
            force: if True, force the service to stop.
            existing_service: the service to stop and remove.
        """
        # stop the older service
        existing_service.stop(timeout=timeout, force=force)

        # delete the service's runtime directory (configuration, PID and log
        # files). Without a recorded runtime path there is nothing to delete;
        # `shutil.rmtree("")` would raise FileNotFoundError instead.
        service_directory_path = existing_service.status.runtime_path
        if service_directory_path and os.path.isdir(service_directory_path):
            shutil.rmtree(service_directory_path)

    # the step will receive a config from the user that mentions the number
    # of workers etc.the step implementation will create a new config using
    # all values from the user and add values like pipeline name, model_uri
    def _create_new_service(
        self, timeout: int, config: MLFlowDeploymentConfig
    ) -> MLFlowDeploymentService:
        """Creates a new MLFlowDeploymentService.

        Args:
            timeout: the timeout in seconds to wait for the MLflow server
                to be provisioned and successfully started or updated.
            config: the configuration of the model to be deployed with MLflow.

        Returns:
            The MLFlowDeploymentService object that can be used to interact
            with the MLflow model server.
        """
        # set the root runtime path with the stack component's UUID
        config.root_runtime_path = self.local_path
        # create a new service for the new model
        service = MLFlowDeploymentService(config)
        service.start(timeout=timeout)

        return service

    def find_model_server(
        self,
        running: bool = False,
        service_uuid: Optional[UUID] = None,
        pipeline_name: Optional[str] = None,
        pipeline_run_id: Optional[str] = None,
        pipeline_step_name: Optional[str] = None,
        model_name: Optional[str] = None,
        model_uri: Optional[str] = None,
        model_type: Optional[str] = None,
    ) -> List[BaseService]:
        """Finds one or more model servers that match the given criteria.

        Args:
            running: If true, only running services will be returned.
            service_uuid: The UUID of the service that was originally used
                to deploy the model.
            pipeline_name: Name of the pipeline that the deployed model was part
                of.
            pipeline_run_id: ID of the pipeline run which the deployed model
                was part of.
            pipeline_step_name: The name of the pipeline model deployment step
                that deployed the model.
            model_name: Name of the deployed model.
            model_uri: URI of the deployed model.
            model_type: Type/format of the deployed model. Not used in this
                MLflow case.

        Returns:
            One or more Service objects representing model servers that match
            the input search criteria.

        Raises:
            TypeError: if a service configuration found under the local path
                does not describe an MLFlowDeploymentService.
        """
        services = []
        # unset criteria become empty strings, which
        # `_matches_search_criteria` treats as wildcards
        config = MLFlowDeploymentConfig(
            model_name=model_name or "",
            model_uri=model_uri or "",
            pipeline_name=pipeline_name or "",
            pipeline_run_id=pipeline_run_id or "",
            pipeline_step_name=pipeline_step_name or "",
        )

        # find all services that match the input criteria
        for root, _, files in os.walk(self.local_path):
            if service_uuid and Path(root).name != str(service_uuid):
                continue
            for file in files:
                if file == SERVICE_DAEMON_CONFIG_FILE_NAME:
                    service_config_path = os.path.join(root, file)
                    logger.debug(
                        "Loading service daemon configuration from %s",
                        service_config_path,
                    )
                    with open(service_config_path, "r") as f:
                        existing_service_config = f.read()
                    existing_service = ServiceRegistry().load_service_from_json(
                        existing_service_config
                    )
                    if not isinstance(
                        existing_service, MLFlowDeploymentService
                    ):
                        raise TypeError(
                            f"Expected service type MLFlowDeploymentService but got "
                            f"{type(existing_service)} instead"
                        )
                    existing_service.update_status()
                    if self._matches_search_criteria(existing_service, config):
                        if not running or existing_service.is_running:
                            services.append(cast(BaseService, existing_service))

        return services

    def _matches_search_criteria(
        self,
        existing_service: MLFlowDeploymentService,
        config: MLFlowDeploymentConfig,
    ) -> bool:
        """Returns true if a service matches the input criteria.

        Criteria left empty (any falsy value) in `config` are ignored. This
        allows listing services just by common pipeline names or step
        names, etc.

        Args:
            existing_service: The materialized Service instance derived from
                the config of the older (existing) service
            config: The MLFlowDeploymentConfig object holding the search
                criteria (pipeline name, run ID, step name, model name and
                model URI).

        Returns:
            True if the service matches the input criteria.
        """
        existing_service_config = existing_service.config

        # (requested value, value of the existing service) pairs
        criteria = (
            (config.pipeline_name, existing_service_config.pipeline_name),
            (config.model_name, existing_service_config.model_name),
            (
                config.pipeline_step_name,
                existing_service_config.pipeline_step_name,
            ),
            (config.pipeline_run_id, existing_service_config.pipeline_run_id),
            # model_uri is accepted by `find_model_server` and must filter too
            (config.model_uri, existing_service_config.model_uri),
        )
        return all(
            not requested or actual == requested
            for requested, actual in criteria
        )

    def stop_model_server(
        self,
        uuid: UUID,
        timeout: int = DEFAULT_SERVICE_START_STOP_TIMEOUT,
        force: bool = False,
    ) -> None:
        """Method to stop a model server.

        Args:
            uuid: UUID of the model server to stop.
            timeout: Timeout in seconds to wait for the service to stop.
            force: If True, force the service to stop.
        """
        # get list of all services
        existing_services = self.find_model_server(service_uuid=uuid)

        # if the service exists, stop it
        if existing_services:
            existing_services[0].stop(timeout=timeout, force=force)

    def start_model_server(
        self, uuid: UUID, timeout: int = DEFAULT_SERVICE_START_STOP_TIMEOUT
    ) -> None:
        """Method to start a model server.

        Args:
            uuid: UUID of the model server to start.
            timeout: Timeout in seconds to wait for the service to start.
        """
        # get list of all services
        existing_services = self.find_model_server(service_uuid=uuid)

        # if the service exists, start it
        if existing_services:
            existing_services[0].start(timeout=timeout)

    def delete_model_server(
        self,
        uuid: UUID,
        timeout: int = DEFAULT_SERVICE_START_STOP_TIMEOUT,
        force: bool = False,
    ) -> None:
        """Method to delete all configuration of a model server.

        Args:
            uuid: UUID of the model server to delete.
            timeout: Timeout in seconds to wait for the service to stop.
            force: If True, force the service to stop.
        """
        # get list of all services
        existing_services = self.find_model_server(service_uuid=uuid)

        # if the service exists, clean it up
        if existing_services:
            service = cast(MLFlowDeploymentService, existing_services[0])
            self._clean_up_existing_service(
                existing_service=service, timeout=timeout, force=force
            )
local_path: str property readonly

Returns the path to the root directory.

This is where all configurations for MLflow deployment daemon processes are stored.

Returns:

Type Description
str

The path to the local service root directory.

delete_model_server(self, uuid, timeout=10, force=False)

Method to delete all configuration of a model server.

Parameters:

Name Type Description Default
uuid UUID

UUID of the model server to delete.

required
timeout int

Timeout in seconds to wait for the service to stop.

10
force bool

If True, force the service to stop.

False
Source code in zenml/integrations/mlflow/model_deployers/mlflow_model_deployer.py
def delete_model_server(
    self,
    uuid: UUID,
    timeout: int = DEFAULT_SERVICE_START_STOP_TIMEOUT,
    force: bool = False,
) -> None:
    """Delete a model server together with its stored configuration.

    The server is looked up by UUID; if one is found, it is handed to
    `_clean_up_existing_service`. Nothing happens when no server matches.

    Args:
        uuid: UUID of the model server to delete.
        timeout: Timeout in seconds to wait for the service to stop.
        force: If True, force the service to stop.
    """
    matches = self.find_model_server(service_uuid=uuid)
    if not matches:
        # no server registered under this UUID: nothing to delete
        return

    target = cast(MLFlowDeploymentService, matches[0])
    self._clean_up_existing_service(
        existing_service=target, timeout=timeout, force=force
    )
deploy_model(self, config, replace=False, timeout=10)

Create a new MLflow deployment service or update an existing one.

This should serve the supplied model and deployment configuration.

This method has two modes of operation, depending on the replace argument value:

  • if replace is False, calling this method will create a new MLflow deployment server to reflect the model and other configuration parameters specified in the supplied MLflow service config.

  • if replace is True, this method will first attempt to find an existing MLflow deployment service that is equivalent to the supplied configuration parameters. Two or more MLflow deployment services are considered equivalent if they have the same pipeline_name, pipeline_step_name and model_name configuration parameters. To put it differently, two MLflow deployment services are equivalent if they serve versions of the same model deployed by the same pipeline step. If an equivalent MLflow deployment is found, it will be updated in place to reflect the new configuration parameters.

Callers should set replace to True if they want a continuous model deployment workflow that doesn't spin up a new MLflow deployment server for each new model version. If multiple equivalent MLflow deployment servers are found, one is selected at random to be updated and the others are deleted.

Parameters:

Name Type Description Default
config ServiceConfig

the configuration of the model to be deployed with MLflow.

required
replace bool

set this flag to True to find and update an equivalent MLflow deployment server with the new model instead of creating and starting a new deployment server.

False
timeout int

the timeout in seconds to wait for the MLflow server to be provisioned and successfully started or updated. If set to 0, the method will return immediately after the MLflow server is provisioned, without waiting for it to fully start.

10

Returns:

Type Description
BaseService

The ZenML MLflow deployment service object that can be used to interact with the MLflow model server.

Source code in zenml/integrations/mlflow/model_deployers/mlflow_model_deployer.py
def deploy_model(
    self,
    config: ServiceConfig,
    replace: bool = False,
    timeout: int = DEFAULT_SERVICE_START_STOP_TIMEOUT,
) -> BaseService:
    """Create a new MLflow deployment service or update an existing one.

    This should serve the supplied model and deployment configuration.

    This method has two modes of operation, depending on the `replace`
    argument value:

      * if `replace` is False, calling this method will create a new MLflow
        deployment server to reflect the model and other configuration
        parameters specified in the supplied MLflow service `config`.

      * if `replace` is True, this method will first attempt to find an
        existing MLflow deployment service that is *equivalent* to the
        supplied configuration parameters. Two or more MLflow deployment
        services are considered equivalent if they have the same
        `pipeline_name`, `pipeline_step_name` and `model_name` configuration
        parameters. To put it differently, two MLflow deployment services
        are equivalent if they serve versions of the same model deployed by
        the same pipeline step. If an equivalent MLflow deployment is found,
        it will be updated in place to reflect the new configuration
        parameters.

    Callers should set `replace` to True if they want a continuous model
    deployment workflow that doesn't spin up a new MLflow deployment
    server for each new model version. If multiple equivalent MLflow
    deployment servers are found, the first one found (in no guaranteed
    order) is updated and the others are stopped and deleted.

    Args:
        config: the configuration of the model to be deployed with MLflow.
        replace: set this flag to True to find and update an equivalent
            MLflow deployment server with the new model instead of
            creating and starting a new deployment server.
        timeout: the timeout in seconds to wait for the MLflow server
            to be provisioned and successfully started or updated. If set
            to 0, the method will return immediately after the MLflow
            server is provisioned, without waiting for it to fully start.

    Returns:
        The ZenML MLflow deployment service object that can be used to
        interact with the MLflow model server.
    """
    config = cast(MLFlowDeploymentConfig, config)
    service = None

    # if replace is True, reuse one equivalent service and remove the rest
    if replace is True:
        existing_services = self.find_model_server(
            pipeline_name=config.pipeline_name,
            pipeline_step_name=config.pipeline_step_name,
            model_name=config.model_name,
        )

        for existing_service in existing_services:
            if service is None:
                # keep the first equivalent service and update it in place
                # below; its runtime directory must NOT be deleted
                service = cast(MLFlowDeploymentService, existing_service)
                continue
            try:
                # stop and delete every other equivalent service
                self._clean_up_existing_service(
                    existing_service=cast(
                        MLFlowDeploymentService, existing_service
                    ),
                    timeout=timeout,
                    force=True,
                )
            except (RuntimeError, OSError) as e:
                # best effort: a stale service that cannot be cleaned up
                # must not block the new deployment
                logger.warning(
                    "Failed to clean up existing MLflow deployment "
                    "service %s: %s",
                    existing_service,
                    e,
                )
    if service:
        logger.info(
            f"Updating an existing MLflow deployment service: {service}"
        )

        # set the root runtime path with the stack component's UUID
        config.root_runtime_path = self.local_path
        service.stop(timeout=timeout, force=True)
        service.update(config)
        service.start(timeout=timeout)
    else:
        # create a new MLFlowDeploymentService instance
        service = self._create_new_service(timeout, config)
        logger.info(f"Created a new MLflow deployment service: {service}")

    return cast(BaseService, service)
find_model_server(self, running=False, service_uuid=None, pipeline_name=None, pipeline_run_id=None, pipeline_step_name=None, model_name=None, model_uri=None, model_type=None)

Finds one or more model servers that match the given criteria.

Parameters:

Name Type Description Default
running bool

If true, only running services will be returned.

False
service_uuid Optional[uuid.UUID]

The UUID of the service that was originally used to deploy the model.

None
pipeline_name Optional[str]

Name of the pipeline that the deployed model was part of.

None
pipeline_run_id Optional[str]

ID of the pipeline run which the deployed model was part of.

None
pipeline_step_name Optional[str]

The name of the pipeline model deployment step that deployed the model.

None
model_name Optional[str]

Name of the deployed model.

None
model_uri Optional[str]

URI of the deployed model.

None
model_type Optional[str]

Type/format of the deployed model. Not used in this MLflow case.

None

Returns:

Type Description
List[zenml.services.service.BaseService]

One or more Service objects representing model servers that match the input search criteria.

Exceptions:

Type Description
TypeError

if a service configuration found under the local path does not describe an MLFlowDeploymentService.

Source code in zenml/integrations/mlflow/model_deployers/mlflow_model_deployer.py
def find_model_server(
    self,
    running: bool = False,
    service_uuid: Optional[UUID] = None,
    pipeline_name: Optional[str] = None,
    pipeline_run_id: Optional[str] = None,
    pipeline_step_name: Optional[str] = None,
    model_name: Optional[str] = None,
    model_uri: Optional[str] = None,
    model_type: Optional[str] = None,
) -> List[BaseService]:
    """Search this deployer's local path for matching MLflow model servers.

    Args:
        running: If true, only running services will be returned.
        service_uuid: The UUID of the service that was originally used
            to deploy the model.
        pipeline_name: Name of the pipeline that the deployed model was part
            of.
        pipeline_run_id: ID of the pipeline run which the deployed model
            was part of.
        pipeline_step_name: The name of the pipeline model deployment step
            that deployed the model.
        model_name: Name of the deployed model.
        model_uri: URI of the deployed model.
        model_type: Type/format of the deployed model. Not used in this
            MLflow case.

    Returns:
        One or more Service objects representing model servers that match
        the input search criteria.

    Raises:
        TypeError: if a service configuration found under the local path
            does not describe an MLFlowDeploymentService.
    """
    # Unset criteria are replaced by empty strings before matching.
    criteria = MLFlowDeploymentConfig(
        model_name=model_name or "",
        model_uri=model_uri or "",
        pipeline_name=pipeline_name or "",
        pipeline_run_id=pipeline_run_id or "",
        pipeline_step_name=pipeline_step_name or "",
    )
    matches: List[BaseService] = []

    for directory, _, file_names in os.walk(self.local_path):
        # with a UUID filter, only directories named after it are searched
        if service_uuid and Path(directory).name != str(service_uuid):
            continue
        if SERVICE_DAEMON_CONFIG_FILE_NAME not in file_names:
            continue

        config_file = os.path.join(directory, SERVICE_DAEMON_CONFIG_FILE_NAME)
        logger.debug(
            "Loading service daemon configuration from %s",
            config_file,
        )
        with open(config_file, "r") as stream:
            serialized_service = stream.read()
        candidate = ServiceRegistry().load_service_from_json(
            serialized_service
        )
        if not isinstance(candidate, MLFlowDeploymentService):
            raise TypeError(
                f"Expected service type MLFlowDeploymentService but got "
                f"{type(candidate)} instead"
            )
        candidate.update_status()

        if not self._matches_search_criteria(candidate, criteria):
            continue
        if running and not candidate.is_running:
            continue
        matches.append(cast(BaseService, candidate))

    return matches
get_active_model_deployer() staticmethod

Returns the MLFlowModelDeployer component of the active stack.

Returns:

Type Description
MLFlowModelDeployer

The MLFlowModelDeployer component of the active stack.

Exceptions:

Type Description
TypeError

If the active stack does not contain an MLFlowModelDeployer component.

Source code in zenml/integrations/mlflow/model_deployers/mlflow_model_deployer.py
@staticmethod
def get_active_model_deployer() -> "MLFlowModelDeployer":
    """Returns the MLFlowModelDeployer component of the active stack.

    Returns:
        The MLFlowModelDeployer component of the active stack.

    Raises:
        TypeError: If the active stack does not contain an MLFlowModelDeployer component.
    """
    model_deployer = Repository(  # type: ignore[call-arg]
        skip_repository_check=True
    ).active_stack.model_deployer

    if not model_deployer or not isinstance(
        model_deployer, MLFlowModelDeployer
    ):
        raise TypeError(
            f"The active stack needs to have an MLflow model deployer "
            f"component registered to be able to deploy models with MLflow. "
            f"You can create a new stack with an MLflow model "
            f"deployer component or update your existing stack to add this "
            f"component, e.g.:\n\n"
            f"  'zenml model-deployer register mlflow --flavor={MLFLOW_MODEL_DEPLOYER_FLAVOR}'\n"
            f"  'zenml stack register stack-name -d mlflow ...'\n"
        )
    return model_deployer
get_model_server_info(service_instance) staticmethod

Return implementation specific information relevant to the user.

Parameters:

Name Type Description Default
service_instance MLFlowDeploymentService

Instance of an MLFlowDeploymentService.

required

Returns:

Type Description
Dict[str, Optional[str]]

A dictionary containing the information.

Source code in zenml/integrations/mlflow/model_deployers/mlflow_model_deployer.py
@staticmethod
def get_model_server_info(  # type: ignore[override]
    service_instance: "MLFlowDeploymentService",
) -> Dict[str, Optional[str]]:
    """Describe an MLflow deployment service for display to the user.

    Args:
        service_instance: the MLflow deployment service to describe.

    Returns:
        A dictionary with the service's prediction URL, model URI, model
        name, runtime path and daemon PID. The PID is passed through `str`,
        so a PID of None is reported as the string "None".
    """
    endpoint = service_instance.endpoint
    config = service_instance.config
    status = service_instance.status
    return dict(
        PREDICTION_URL=endpoint.prediction_url,
        MODEL_URI=config.model_uri,
        MODEL_NAME=config.model_name,
        SERVICE_PATH=status.runtime_path,
        DAEMON_PID=str(status.pid),
    )
get_service_path(uuid) staticmethod

Get the path where local MLflow service information is stored.

This includes the deployment service configuration, PID and log files.

Parameters:

Name Type Description Default
uuid UUID

The UUID of the MLflow model deployer.

required

Returns:

Type Description
str

The service path.

Source code in zenml/integrations/mlflow/model_deployers/mlflow_model_deployer.py
@staticmethod
def get_service_path(uuid: uuid.UUID) -> str:
    """Return (and create if needed) the local MLflow service directory.

    The directory holds the deployment service configuration, PID and log
    files. It lives under the global config directory's local stores
    folder, in a sub-directory named after the given UUID.

    Args:
        uuid: The UUID of the MLflow model deployer.

    Returns:
        The service path.
    """
    path_parts = (
        get_global_config_directory(),
        LOCAL_STORES_DIRECTORY_NAME,
        str(uuid),
    )
    service_path = os.path.join(*path_parts)
    create_dir_recursive_if_not_exists(service_path)
    return service_path
set_service_path(values) classmethod

Sets the service_path attribute value according to the component UUID.

Parameters:

Name Type Description Default
values Dict[str, Any]

the dictionary of values to be validated.

required

Returns:

Type Description
Dict[str, Any]

The validated dictionary of values.

Source code in zenml/integrations/mlflow/model_deployers/mlflow_model_deployer.py
@root_validator(skip_on_failure=True)
def set_service_path(cls, values: Dict[str, Any]) -> Dict[str, Any]:
    """Fill in `service_path` from the component UUID when it is not set.

    An already-populated (truthy) `service_path` is left untouched.

    Args:
        values: the dictionary of values to be validated.

    Returns:
        The validated dictionary of values.
    """
    if not values.get("service_path"):
        # Pydantic validation should already guarantee a UUID here; the
        # assert exists to narrow the type for mypy.
        assert "uuid" in values
        values["service_path"] = cls.get_service_path(values["uuid"])
    return values
start_model_server(self, uuid, timeout=10)

Method to start a model server.

Parameters:

Name Type Description Default
uuid UUID

UUID of the model server to start.

required
timeout int

Timeout in seconds to wait for the service to start.

10
Source code in zenml/integrations/mlflow/model_deployers/mlflow_model_deployer.py
def start_model_server(
    self, uuid: UUID, timeout: int = DEFAULT_SERVICE_START_STOP_TIMEOUT
) -> None:
    """Start the model server registered under `uuid`, if one exists.

    Args:
        uuid: UUID of the model server to start.
        timeout: Timeout in seconds to wait for the service to start.
    """
    matching_servers = self.find_model_server(service_uuid=uuid)
    if not matching_servers:
        # no server with this UUID: nothing to start
        return
    matching_servers[0].start(timeout=timeout)
stop_model_server(self, uuid, timeout=10, force=False)

Method to stop a model server.

Parameters:

Name Type Description Default
uuid UUID

UUID of the model server to stop.

required
timeout int

Timeout in seconds to wait for the service to stop.

10
force bool

If True, force the service to stop.

False
Source code in zenml/integrations/mlflow/model_deployers/mlflow_model_deployer.py
def stop_model_server(
    self,
    uuid: UUID,
    timeout: int = DEFAULT_SERVICE_START_STOP_TIMEOUT,
    force: bool = False,
) -> None:
    """Stop the model server registered under `uuid`, if one exists.

    Args:
        uuid: UUID of the model server to stop.
        timeout: Timeout in seconds to wait for the service to stop.
        force: If True, force the service to stop.
    """
    matching_servers = self.find_model_server(service_uuid=uuid)
    if not matching_servers:
        # no server with this UUID: nothing to stop
        return
    matching_servers[0].stop(timeout=timeout, force=force)

services special

Initialization of the MLflow Service.

mlflow_deployment

Implementation of the MLflow deployment functionality.

MLFlowDeploymentConfig (LocalDaemonServiceConfig) pydantic-model

MLflow model deployment configuration.

Attributes:

Name Type Description
model_uri str

URI of the MLflow model to serve

model_name str

the name of the model

workers int

number of workers to use for the prediction service

mlserver bool

set to True to use the MLflow MLServer backend (see https://github.com/SeldonIO/MLServer). If False, the MLflow built-in scoring server will be used.

Source code in zenml/integrations/mlflow/services/mlflow_deployment.py
class MLFlowDeploymentConfig(LocalDaemonServiceConfig):
    """MLflow model deployment configuration.

    Attributes:
        model_uri: URI of the MLflow model to serve
        model_name: the name of the model
        workers: number of workers to use for the prediction service
        mlserver: set to True to use the MLflow MLServer backend (see
            https://github.com/SeldonIO/MLServer). If False, the
            MLflow built-in scoring server will be used.
    """

    # passed as `model_uri` to `PyFuncBackend.serve` by the deployment service
    model_uri: str
    model_name: str
    # forwarded as `workers=` to the `PyFuncBackend` constructor
    workers: int = 1
    # forwarded as `enable_mlserver=` to `PyFuncBackend.serve`; also selects
    # which prediction/healthcheck URL paths the service endpoint uses
    mlserver: bool = False
MLFlowDeploymentEndpoint (LocalDaemonServiceEndpoint) pydantic-model

A service endpoint exposed by the MLflow deployment daemon.

Attributes:

Name Type Description
config MLFlowDeploymentEndpointConfig

service endpoint configuration

monitor HTTPEndpointHealthMonitor

optional service endpoint health monitor

Source code in zenml/integrations/mlflow/services/mlflow_deployment.py
class MLFlowDeploymentEndpoint(LocalDaemonServiceEndpoint):
    """The HTTP endpoint exposed by the local MLflow deployment daemon.

    Attributes:
        config: service endpoint configuration
        monitor: optional service endpoint health monitor
    """

    config: MLFlowDeploymentEndpointConfig
    monitor: HTTPEndpointHealthMonitor

    @property
    def prediction_url(self) -> Optional[str]:
        """Build the prediction URL from the endpoint URI and the config subpath.

        Returns:
            the prediction URL for the endpoint, or None while the endpoint
            has no URI
        """
        base_uri = self.status.uri
        return f"{base_uri}{self.config.prediction_url_path}" if base_uri else None
prediction_url: Optional[str] property readonly

Gets the prediction URL for the endpoint.

Returns:

Type Description
Optional[str]

the prediction URL for the endpoint

MLFlowDeploymentEndpointConfig (LocalDaemonServiceEndpointConfig) pydantic-model

MLflow daemon service endpoint configuration.

Attributes:

Name Type Description
prediction_url_path str

URI subpath for prediction requests

Source code in zenml/integrations/mlflow/services/mlflow_deployment.py
class MLFlowDeploymentEndpointConfig(LocalDaemonServiceEndpointConfig):
    """MLflow daemon service endpoint configuration.

    Attributes:
        prediction_url_path: URI subpath for prediction requests
    """

    # appended verbatim to the endpoint's base URI to form the prediction URL
    prediction_url_path: str
MLFlowDeploymentService (LocalDaemonService) pydantic-model

MLflow deployment service used to start a local prediction server for MLflow models.

Attributes:

Name Type Description
SERVICE_TYPE ClassVar[zenml.services.service_type.ServiceType]

a service type descriptor with information describing the MLflow deployment service class

config MLFlowDeploymentConfig

service configuration

endpoint MLFlowDeploymentEndpoint

optional service endpoint

Source code in zenml/integrations/mlflow/services/mlflow_deployment.py
def _mlflow_version_at_least(version: str, min_major: int, min_minor: int) -> bool:
    """Check whether an MLflow version string is at least `min_major.min_minor`.

    Only the first two dot-separated components are compared (they must be
    numeric); any patch or suffix component is ignored. A missing minor
    component is treated as 0.

    Args:
        version: the MLflow version string, e.g. "1.26.1".
        min_major: minimum required major version.
        min_minor: minimum required minor version.

    Returns:
        True if (major, minor) of `version` is >= (min_major, min_minor).
    """
    parts = version.split(".")
    major = int(parts[0])
    minor = int(parts[1]) if len(parts) > 1 else 0
    return (major, minor) >= (min_major, min_minor)


class MLFlowDeploymentService(LocalDaemonService):
    """MLflow deployment service used to start a local prediction server for MLflow models.

    Attributes:
        SERVICE_TYPE: a service type descriptor with information describing
            the MLflow deployment service class
        config: service configuration
        endpoint: optional service endpoint
    """

    SERVICE_TYPE = ServiceType(
        name="mlflow-deployment",
        type="model-serving",
        flavor="mlflow",
        description="MLflow prediction service",
    )

    config: MLFlowDeploymentConfig
    endpoint: MLFlowDeploymentEndpoint

    def __init__(
        self,
        config: Union[MLFlowDeploymentConfig, Dict[str, Any]],
        **attrs: Any,
    ) -> None:
        """Initialize the MLflow deployment service.

        When `config` is an `MLFlowDeploymentConfig` and no `endpoint` is
        supplied, an HTTP endpoint is built whose prediction/healthcheck
        paths depend on `config.mlserver`.

        Args:
            config: service configuration
            attrs: additional attributes to set on the service
        """
        # ensure that the endpoint is created before the service is initialized
        # TODO [ENG-700]: implement a service factory or builder for MLflow
        #   deployment services
        if (
            isinstance(config, MLFlowDeploymentConfig)
            and "endpoint" not in attrs
        ):
            if config.mlserver:
                prediction_url_path = MLSERVER_PREDICTION_URL_PATH
                healthcheck_uri_path = MLSERVER_HEALTHCHECK_URL_PATH
                use_head_request = False
            else:
                prediction_url_path = MLFLOW_PREDICTION_URL_PATH
                healthcheck_uri_path = MLFLOW_HEALTHCHECK_URL_PATH
                use_head_request = True

            endpoint = MLFlowDeploymentEndpoint(
                config=MLFlowDeploymentEndpointConfig(
                    protocol=ServiceEndpointProtocol.HTTP,
                    prediction_url_path=prediction_url_path,
                ),
                monitor=HTTPEndpointHealthMonitor(
                    config=HTTPEndpointHealthMonitorConfig(
                        healthcheck_uri_path=healthcheck_uri_path,
                        use_head_request=use_head_request,
                    )
                ),
            )
            attrs["endpoint"] = endpoint
        super().__init__(config=config, **attrs)

    def run(self) -> None:
        """Start the service.

        Runs the MLflow scoring server in the current process until it is
        interrupted with CTRL+C (KeyboardInterrupt), which is logged and
        swallowed.
        """
        logger.info(
            "Starting MLflow prediction service as blocking "
            "process... press CTRL+C once to stop it."
        )

        self.endpoint.prepare_for_start()

        try:
            serve_kwargs: Dict[str, Any] = {}
            # MLflow version 1.26 introduces an additional mandatory
            # `timeout` argument to the `PyFuncBackend.serve` function.
            # Compare (major, minor) rather than the minor component alone,
            # otherwise any 2.x release (minor < 26) would be misclassified.
            if _mlflow_version_at_least(MLFLOW_VERSION, 1, 26):
                serve_kwargs["timeout"] = None

            backend = PyFuncBackend(
                config={},
                no_conda=True,
                workers=self.config.workers,
                install_mlflow=False,
            )
            backend.serve(
                model_uri=self.config.model_uri,
                port=self.endpoint.status.port,
                host="localhost",
                enable_mlserver=self.config.mlserver,
                **serve_kwargs,
            )
        except KeyboardInterrupt:
            logger.info(
                "MLflow prediction service stopped. Resuming normal execution."
            )

    @property
    def prediction_url(self) -> Optional[str]:
        """Get the URI where the prediction service is answering requests.

        Returns:
            The URI where the prediction service can be contacted to process
            HTTP/REST inference requests, or None, if the service isn't running.
        """
        if not self.is_running:
            return None
        return self.endpoint.prediction_url

    def predict(self, request: "NDArray[Any]") -> "NDArray[Any]":
        """Make a prediction using the service.

        The request is sent as JSON `{"instances": request.tolist()}` via an
        HTTP POST to the endpoint's prediction URL.

        Args:
            request: a numpy array representing the request

        Returns:
            A numpy array representing the prediction returned by the service.

        Raises:
            Exception: if the service is not running
            ValueError: if the prediction endpoint is unknown.
        """
        if not self.is_running:
            raise Exception(
                "MLflow prediction service is not running. "
                "Please start the service before making predictions."
            )

        prediction_url = self.endpoint.prediction_url
        if prediction_url is None:
            raise ValueError("No endpoint known for prediction.")

        response = requests.post(
            prediction_url,
            json={"instances": request.tolist()},
        )
        # surface HTTP errors from the scoring server as exceptions
        response.raise_for_status()
        return np.array(response.json())
prediction_url: Optional[str] property readonly

Get the URI where the prediction service is answering requests.

Returns:

Type Description
Optional[str]

The URI where the prediction service can be contacted to process HTTP/REST inference requests, or None, if the service isn't running.

__init__(self, config, **attrs) special

Initialize the MLflow deployment service.

Parameters:

Name Type Description Default
config Union[zenml.integrations.mlflow.services.mlflow_deployment.MLFlowDeploymentConfig, Dict[str, Any]]

service configuration

required
attrs Any

additional attributes to set on the service

{}
Source code in zenml/integrations/mlflow/services/mlflow_deployment.py
def __init__(
    self,
    config: Union[MLFlowDeploymentConfig, Dict[str, Any]],
    **attrs: Any,
) -> None:
    """Create the MLflow deployment service, building a default endpoint if needed.

    A default HTTP endpoint is only built when `config` is an
    `MLFlowDeploymentConfig` instance and the caller did not pass an
    `endpoint` of its own; its URL paths depend on `config.mlserver`.

    Args:
        config: service configuration
        attrs: additional attributes to set on the service
    """
    # TODO [ENG-700]: implement a service factory or builder for MLflow
    #   deployment services
    if isinstance(config, MLFlowDeploymentConfig) and "endpoint" not in attrs:
        # pick (prediction path, healthcheck path, HEAD-request flag) per backend
        if config.mlserver:
            url_path, health_path, head_request = (
                MLSERVER_PREDICTION_URL_PATH,
                MLSERVER_HEALTHCHECK_URL_PATH,
                False,
            )
        else:
            url_path, health_path, head_request = (
                MLFLOW_PREDICTION_URL_PATH,
                MLFLOW_HEALTHCHECK_URL_PATH,
                True,
            )

        endpoint_config = MLFlowDeploymentEndpointConfig(
            protocol=ServiceEndpointProtocol.HTTP,
            prediction_url_path=url_path,
        )
        health_monitor = HTTPEndpointHealthMonitor(
            config=HTTPEndpointHealthMonitorConfig(
                healthcheck_uri_path=health_path,
                use_head_request=head_request,
            )
        )
        attrs["endpoint"] = MLFlowDeploymentEndpoint(
            config=endpoint_config,
            monitor=health_monitor,
        )
    super().__init__(config=config, **attrs)
predict(self, request)

Make a prediction using the service.

Parameters:

Name Type Description Default
request NDArray[Any]

a numpy array representing the request

required

Returns:

Type Description
NDArray[Any]

A numpy array representing the prediction returned by the service.

Exceptions:

Type Description
Exception

if the service is not running

ValueError

if the prediction endpoint is unknown.

Source code in zenml/integrations/mlflow/services/mlflow_deployment.py
def predict(self, request: "NDArray[Any]") -> "NDArray[Any]":
    """Send `request` to the running MLflow prediction server and return its answer.

    Args:
        request: a numpy array representing the request

    Returns:
        A numpy array representing the prediction returned by the service.

    Raises:
        Exception: if the service is not running
        ValueError: if the prediction endpoint is unknown.
    """
    if not self.is_running:
        raise Exception(
            "MLflow prediction service is not running. "
            "Please start the service before making predictions."
        )

    target_url = self.endpoint.prediction_url
    if target_url is None:
        raise ValueError("No endpoint known for prediction.")

    payload = {"instances": request.tolist()}
    reply = requests.post(target_url, json=payload)
    reply.raise_for_status()
    return np.array(reply.json())
run(self)

Start the service.

Source code in zenml/integrations/mlflow/services/mlflow_deployment.py
def _mlflow_version_at_least(version: str, min_major: int, min_minor: int) -> bool:
    """Check whether an MLflow version string is at least `min_major.min_minor`.

    Only the first two dot-separated components are compared (they must be
    numeric); any patch or suffix component is ignored. A missing minor
    component is treated as 0.

    Args:
        version: the MLflow version string, e.g. "1.26.1".
        min_major: minimum required major version.
        min_minor: minimum required minor version.

    Returns:
        True if (major, minor) of `version` is >= (min_major, min_minor).
    """
    parts = version.split(".")
    major = int(parts[0])
    minor = int(parts[1]) if len(parts) > 1 else 0
    return (major, minor) >= (min_major, min_minor)


def run(self) -> None:
    """Start the service.

    Runs the MLflow scoring server in the current process until it is
    interrupted with CTRL+C (KeyboardInterrupt), which is logged and
    swallowed.
    """
    logger.info(
        "Starting MLflow prediction service as blocking "
        "process... press CTRL+C once to stop it."
    )

    self.endpoint.prepare_for_start()

    try:
        serve_kwargs: Dict[str, Any] = {}
        # MLflow version 1.26 introduces an additional mandatory
        # `timeout` argument to the `PyFuncBackend.serve` function.
        # Compare (major, minor) rather than the minor component alone,
        # otherwise any 2.x release (minor < 26) would be misclassified.
        if _mlflow_version_at_least(MLFLOW_VERSION, 1, 26):
            serve_kwargs["timeout"] = None

        backend = PyFuncBackend(
            config={},
            no_conda=True,
            workers=self.config.workers,
            install_mlflow=False,
        )
        backend.serve(
            model_uri=self.config.model_uri,
            port=self.endpoint.status.port,
            host="localhost",
            enable_mlserver=self.config.mlserver,
            **serve_kwargs,
        )
    except KeyboardInterrupt:
        logger.info(
            "MLflow prediction service stopped. Resuming normal execution."
        )

steps special

Initialization of the MLflow standard interface steps.

mlflow_deployer

Implementation of the MLflow model deployer pipeline step.

MLFlowDeployerConfig (BaseStepConfig) pydantic-model

Model deployer step configuration for MLflow.

Attributes:

Name Type Description
model_name str

the name of the MLflow model logged in the MLflow artifact store for the current pipeline.

workers int

number of workers to use for the prediction service

mlserver bool

set to True to use the MLflow MLServer backend (see https://github.com/SeldonIO/MLServer). If False, the MLflow built-in scoring server will be used.

timeout int

the number of seconds to wait for the service to start/stop.

Source code in zenml/integrations/mlflow/steps/mlflow_deployer.py
class MLFlowDeployerConfig(BaseStepConfig):
    """Model deployer step configuration for MLflow.

    Attributes:
        model_name: the name of the MLflow model logged in the MLflow artifact
            store for the current pipeline.
        model_uri: URI of an MLflow model. NOTE(review): the
            `mlflow_model_deployer_step` step derives the model URI from the
            active MLflow run and does not read this field; confirm whether
            it is still used elsewhere.
        workers: number of workers to use for the prediction service
        mlserver: set to True to use the MLflow MLServer backend (see
            https://github.com/SeldonIO/MLServer). If False, the
            MLflow built-in scoring server will be used.
        timeout: the number of seconds to wait for the service to start/stop.
    """

    model_name: str = "model"
    model_uri: str = ""
    workers: int = 1
    mlserver: bool = False
    timeout: int = DEFAULT_SERVICE_START_STOP_TIMEOUT
mlflow_model_deployer_step (BaseStep)

Model deployer pipeline step for MLflow.

noqa: DAR401

Parameters:

Name Type Description Default
deploy_decision

whether to deploy the model or not

required
model

the model artifact to deploy

required
config

configuration for the deployer step

required

Returns:

Type Description

MLflow deployment service

CONFIG_CLASS (BaseStepConfig) pydantic-model

Model deployer step configuration for MLflow.

Attributes:

Name Type Description
model_name str

the name of the MLflow model logged in the MLflow artifact store for the current pipeline.

workers int

number of workers to use for the prediction service

mlserver bool

set to True to use the MLflow MLServer backend (see https://github.com/SeldonIO/MLServer). If False, the MLflow built-in scoring server will be used.

timeout int

the number of seconds to wait for the service to start/stop.

Source code in zenml/integrations/mlflow/steps/mlflow_deployer.py
class MLFlowDeployerConfig(BaseStepConfig):
    """Model deployer step configuration for MLflow.

    Attributes:
        model_name: the name of the MLflow model logged in the MLflow artifact
            store for the current pipeline.
        model_uri: URI of an MLflow model. NOTE(review): the
            `mlflow_model_deployer_step` step derives the model URI from the
            active MLflow run and does not read this field; confirm whether
            it is still used elsewhere.
        workers: number of workers to use for the prediction service
        mlserver: set to True to use the MLflow MLServer backend (see
            https://github.com/SeldonIO/MLServer). If False, the
            MLflow built-in scoring server will be used.
        timeout: the number of seconds to wait for the service to start/stop.
    """

    model_name: str = "model"
    model_uri: str = ""
    workers: int = 1
    mlserver: bool = False
    timeout: int = DEFAULT_SERVICE_START_STOP_TIMEOUT
entrypoint(deploy_decision, model, config) staticmethod

Model deployer pipeline step for MLflow.

noqa: DAR401

Parameters:

Name Type Description Default
deploy_decision bool

whether to deploy the model or not

required
model ModelArtifact

the model artifact to deploy

required
config MLFlowDeployerConfig

configuration for the deployer step

required

Returns:

Type Description
MLFlowDeploymentService

MLflow deployment service

Source code in zenml/integrations/mlflow/steps/mlflow_deployer.py
@enable_mlflow
@step(enable_cache=False)
def mlflow_model_deployer_step(
    deploy_decision: bool,
    model: ModelArtifact,
    config: MLFlowDeployerConfig,
) -> MLFlowDeploymentService:
    """Model deployer pipeline step for MLflow.

    Serves the MLflow model logged under `config.model_name` in the active
    MLflow run with a local MLflow prediction server. An existing server for
    the same pipeline, step and model name is reused (and restarted if it is
    not running) when no model was logged in this run, or when the deploy
    decision is negative. Otherwise a new deployment replaces the old one.
    If the active stack's experiment tracker is not an
    `MLFlowExperimentTracker`, the error built by
    `get_missing_mlflow_experiment_tracker_error()` is raised.

    # noqa: DAR401

    Args:
        deploy_decision: whether to deploy the model or not
        model: the model artifact to deploy. NOTE(review): this argument is
            not read in the body; the model to serve is located through the
            active MLflow run. Presumably it exists so the step runs after
            training; confirm.
        config: configuration for the deployer step

    Returns:
        MLflow deployment service
    """
    model_deployer = MLFlowModelDeployer.get_active_model_deployer()

    # fetch the MLflow artifacts logged during the pipeline run
    experiment_tracker = Repository(  # type: ignore[call-arg]
        skip_repository_check=True
    ).active_stack.experiment_tracker

    if not isinstance(experiment_tracker, MLFlowExperimentTracker):
        raise get_missing_mlflow_experiment_tracker_error()

    client = MlflowClient()
    # stays "" unless the active run logged artifacts under `model_name`
    model_uri = ""
    mlflow_run = experiment_tracker.active_run
    if mlflow_run and client.list_artifacts(
        mlflow_run.info.run_id, config.model_name
    ):
        model_uri = get_artifact_uri(config.model_name)

    # get pipeline name, step name and run id
    step_env = cast(StepEnvironment, Environment()[STEP_ENVIRONMENT_NAME])
    pipeline_name = step_env.pipeline_name
    run_id = step_env.pipeline_run_id
    step_name = step_env.step_name

    # fetch existing services with same pipeline name, step name and model name
    existing_services = model_deployer.find_model_server(
        pipeline_name=pipeline_name,
        pipeline_step_name=step_name,
        model_name=config.model_name,
    )

    # create a config for the new model service
    predictor_cfg = MLFlowDeploymentConfig(
        model_name=config.model_name or "",
        model_uri=model_uri,
        workers=config.workers,
        mlserver=config.mlserver,
        pipeline_name=pipeline_name,
        pipeline_run_id=run_id,
        pipeline_step_name=step_name,
    )

    # Creating a new service with inactive state and status by default;
    # it is replaced by the first matching existing service, if any
    service = MLFlowDeploymentService(predictor_cfg)
    if existing_services:
        service = cast(MLFlowDeploymentService, existing_services[0])

    # check for conditions to deploy the model
    if not model_uri:
        # an MLflow model was not trained in the current run, so we simply reuse
        # the currently running service created for the same model, if any
        if not existing_services:
            logger.warning(
                f"An MLflow model with name `{config.model_name}` was not "
                f"logged in the current pipeline run and no running MLflow "
                f"model server was found. Please ensure that your pipeline "
                f"includes an `@enable_mlflow` decorated step that trains a "
                f"model and logs it to MLflow. This could also happen if "
                f"the current pipeline run did not log an MLflow model  "
                f"because the training step was cached."
            )
            # return an inactive service just because we have to return
            # something
            return service
        logger.info(
            f"An MLflow model with name `{config.model_name}` was not "
            f"trained in the current pipeline run. Reusing the existing "
            f"MLflow model server."
        )
        if not service.is_running:
            service.start(config.timeout)

        # return the existing service
        return service

    # A negative deploy decision only reuses the previous server when one
    # exists for this pipeline/step/model; if none exists, execution falls
    # through and the current model is deployed anyway, so that a model
    # server is available at all times
    if not deploy_decision and existing_services:
        logger.info(
            f"Skipping model deployment because the model quality does not "
            f"meet the criteria. Reusing last model server deployed by step "
            f"'{step_name}' and pipeline '{pipeline_name}' for model "
            f"'{config.model_name}'..."
        )
        # even when the deploy decision is negative, we still need to start
        # the previous model server if it is no longer running, to ensure
        # that a model server is available at all times
        if not service.is_running:
            service.start(config.timeout)
        return service

    # create a new model deployment and replace an old one if it exists
    new_service = cast(
        MLFlowDeploymentService,
        model_deployer.deploy_model(
            replace=True,
            config=predictor_cfg,
            timeout=config.timeout,
        ),
    )

    logger.info(
        f"MLflow deployment service started and reachable at:\n"
        f"    {new_service.prediction_url}\n"
    )

    return new_service
mlflow_deployer_step(enable_cache=True, name=None)

Creates a pipeline step to deploy a given ML model with a local MLflow prediction server.

The returned step can be used in a pipeline to implement continuous deployment for an MLflow model.

Parameters:

Name Type Description Default
enable_cache bool

Specify whether caching is enabled for this step. Note: this deprecated function ignores the argument and always returns the built-in `mlflow_model_deployer_step`.

True
name Optional[str]

Name of the step. Ignored by this deprecated function.

None

Returns:

Type Description
Type[zenml.steps.base_step.BaseStep]

an MLflow model deployer pipeline step

Source code in zenml/integrations/mlflow/steps/mlflow_deployer.py
def mlflow_deployer_step(
    enable_cache: bool = True,
    name: Optional[str] = None,
) -> Type[BaseStep]:
    """Return the built-in MLflow model deployer step (deprecated).

    The returned step can be used in a pipeline to implement continuous
    deployment for an MLflow model. This function only logs a deprecation
    warning and returns `mlflow_model_deployer_step`, which is declared with
    `@step(enable_cache=False)`; use that step directly instead.

    Args:
        enable_cache: Ignored; kept only for backward compatibility. The
            returned step is not reconfigured by this function.
        name: Ignored; kept only for backward compatibility.

    Returns:
        The built-in `mlflow_model_deployer_step` pipeline step.
    """
    logger.warning(
        "The `mlflow_deployer_step` function is deprecated. Please "
        "use the built-in `mlflow_model_deployer_step` step instead."
    )
    return mlflow_model_deployer_step

neural_prophet special

Initialization of the Neural Prophet integration.

NeuralProphetIntegration (Integration)

Definition of NeuralProphet integration for ZenML.

Source code in zenml/integrations/neural_prophet/__init__.py
class NeuralProphetIntegration(Integration):
    """Definition of NeuralProphet integration for ZenML.

    Declares the integration name and its pip requirements. On activation it
    imports the integration's materializers module.
    """

    NAME = NEURAL_PROPHET
    REQUIREMENTS = ["neuralprophet>=0.3.2"]

    @classmethod
    def activate(cls) -> None:
        """Activates the integration.

        The materializers module is imported only for its import-time side
        effects. NOTE(review): presumably importing it registers the
        NeuralProphet materializer with ZenML; confirm.
        """
        from zenml.integrations.neural_prophet import materializers  # noqa
activate() classmethod

Activates the integration.

Source code in zenml/integrations/neural_prophet/__init__.py
@classmethod
def activate(cls) -> None:
    """Activates the integration.

    The materializers module is imported only for its import-time side
    effects. NOTE(review): presumably importing it registers the
    NeuralProphet materializer with ZenML; confirm.
    """
    from zenml.integrations.neural_prophet import materializers  # noqa

materializers special

Initialization of the Neural Prophet materializer.

neural_prophet_materializer

Implementation of the Neural Prophet materializer.

NeuralProphetMaterializer (BaseMaterializer)

Materializer to read/write NeuralProphet models.

Source code in zenml/integrations/neural_prophet/materializers/neural_prophet_materializer.py
class NeuralProphetMaterializer(BaseMaterializer):
    """Materializer to read/write NeuralProphet models.

    Models are stored with `torch.save` at
    `os.path.join(artifact.uri, DEFAULT_FILENAME)` and read back with
    `torch.load` from the same location.
    """

    ASSOCIATED_TYPES = (NeuralProphet,)
    ASSOCIATED_ARTIFACT_TYPES = (ModelArtifact,)

    def handle_input(self, data_type: Type[Any]) -> NeuralProphet:
        """Reads and returns a NeuralProphet model.

        Args:
            data_type: The type of the NeuralProphet model to load.

        Returns:
            A loaded NeuralProphet model.
        """
        super().handle_input(data_type)
        # NOTE: `torch.load` unpickles the stored file, which can execute
        # arbitrary code; only read artifacts from a trusted artifact store.
        return torch.load(  # type: ignore[no-untyped-call]
            os.path.join(self.artifact.uri, DEFAULT_FILENAME)
        )  # noqa

    def handle_return(self, model: NeuralProphet) -> None:
        """Writes a NeuralProphet model.

        Args:
            model: A NeuralProphet model object.
        """
        super().handle_return(model)
        torch.save(model, os.path.join(self.artifact.uri, DEFAULT_FILENAME))
handle_input(self, data_type)

Reads and returns a NeuralProphet model.

Parameters:

Name Type Description Default
data_type Type[Any]

The type of the NeuralProphet model to load.

required

Returns:

Type Description
NeuralProphet

A loaded NeuralProphet model.

Source code in zenml/integrations/neural_prophet/materializers/neural_prophet_materializer.py
def handle_input(self, data_type: Type[Any]) -> NeuralProphet:
    """Reads and returns a NeuralProphet model.

    Args:
        data_type: The type of the NeuralProphet model to load.

    Returns:
        A loaded NeuralProphet model.
    """
    super().handle_input(data_type)
    # NOTE: `torch.load` unpickles the stored file, which can execute
    # arbitrary code; only read artifacts from a trusted artifact store.
    return torch.load(  # type: ignore[no-untyped-call]
        os.path.join(self.artifact.uri, DEFAULT_FILENAME)
    )  # noqa
handle_return(self, model)

Writes a NeuralProphet model.

Parameters:

Name Type Description Default
model NeuralProphet

A NeuralProphet model object.

required
Source code in zenml/integrations/neural_prophet/materializers/neural_prophet_materializer.py
def handle_return(self, model: NeuralProphet) -> None:
    """Persist a NeuralProphet model under the artifact URI with `torch.save`.

    Args:
        model: A NeuralProphet model object.
    """
    super().handle_return(model)
    destination = os.path.join(self.artifact.uri, DEFAULT_FILENAME)
    torch.save(model, destination)

plotly special

Initialization of the Plotly integration.

PlotlyIntegration (Integration)

Definition of Plotly integration for ZenML.

Source code in zenml/integrations/plotly/__init__.py
class PlotlyIntegration(Integration):
    """Definition of Plotly integration for ZenML.

    Only declares the integration name and its pip requirements; this class
    defines no `activate` hook of its own.
    """

    NAME = PLOTLY
    # minimum plotly version required by this integration
    REQUIREMENTS = ["plotly>=5.4.0"]

visualizers special

Initialization of the Plotly Visualizer.

pipeline_lineage_visualizer

Implementation of the Plotly Pipeline Lineage Visualizer.

PipelineLineageVisualizer (BasePipelineVisualizer)

Visualize the lineage of runs in a pipeline using plotly.

Source code in zenml/integrations/plotly/visualizers/pipeline_lineage_visualizer.py
class PipelineLineageVisualizer(BasePipelineVisualizer):
    """Visualize the lineage of runs in a pipeline using plotly."""

    # NOTE: `visualize` is a concrete implementation, so it must not be
    # decorated with `@abstractmethod`; doing so would leave the class
    # abstract and make it impossible to instantiate.
    def visualize(
        self, object: PipelineView, *args: Any, **kwargs: Any
    ) -> Figure:
        """Creates a pipeline lineage diagram using plotly.

        Builds one row per pipeline run, with a `run` column holding the run
        name and one column per step entrypoint name holding that step's id
        as a string. These rows are rendered as a parallel-categories
        diagram. The figure is shown immediately and also returned.

        Args:
            object: The pipeline view to visualize.
            *args: Additional arguments to pass to the visualization.
            **kwargs: Additional keyword arguments to pass to the visualization.

        Returns:
            A plotly figure.
        """
        logger.warning(
            "This integration is not completed yet. Results might be unexpected."
        )

        category_dict = {}
        dimensions = ["run"]
        for run in object.runs:
            row = {"run": run.name}
            for step in run.steps:
                row[step.entrypoint_name] = str(step.id)
                # keep the first-seen order of step columns
                if step.entrypoint_name not in dimensions:
                    dimensions.append(step.entrypoint_name)
            category_dict[run.name] = row

        category_df = pd.DataFrame.from_dict(category_dict, orient="index")

        category_df = category_df.reset_index()

        # `labels` in plotly express must be a column->label dict; the former
        # `labels="status"` string referred to no column, so it is dropped.
        fig = px.parallel_categories(
            category_df,
            dimensions,
            color=None,
        )

        fig.show()
        return fig
visualize(self, object, *args, **kwargs)

Creates a pipeline lineage diagram using plotly.

Parameters:

Name Type Description Default
object PipelineView

The pipeline view to visualize.

required
*args Any

Additional arguments to pass to the visualization.

()
**kwargs Any

Additional keyword arguments to pass to the visualization.

{}

Returns:

Type Description
Figure

A plotly figure.

Source code in zenml/integrations/plotly/visualizers/pipeline_lineage_visualizer.py
def visualize(
    self, object: PipelineView, *args: Any, **kwargs: Any
) -> Figure:
    """Creates a pipeline lineage diagram using plotly.

    Each run of the pipeline becomes one row of a parallel-categories
    diagram. The first dimension is the run name; every further dimension
    is a step entrypoint name whose category is the string ID of that step
    in the run. This is a concrete implementation, so it is not marked
    `@abstractmethod`.

    Args:
        object: The pipeline view to visualize.
        *args: Additional arguments (currently unused).
        **kwargs: Additional keyword arguments (currently unused).

    Returns:
        A plotly figure. The figure is also displayed via `fig.show()`.
    """
    logger.warning(
        "This integration is not completed yet. Results might be unexpected."
    )

    category_dict = {}
    # "run" is always the first axis; step entrypoints follow in the order
    # they are first encountered across runs.
    dimensions = ["run"]
    for run in object.runs:
        row = {"run": run.name}
        for step in run.steps:
            row[step.entrypoint_name] = str(step.id)
            if step.entrypoint_name not in dimensions:
                dimensions.append(step.entrypoint_name)
        category_dict[run.name] = row

    category_df = pd.DataFrame.from_dict(
        category_dict, orient="index"
    ).reset_index()

    # plotly express documents `labels` as a dict mapping column names to
    # display names; the previous `labels="status"` string did not match that
    # contract, so it is no longer passed.
    fig = px.parallel_categories(
        category_df,
        dimensions=dimensions,
        color=None,
    )

    fig.show()
    return fig

pytorch special

Initialization of the PyTorch integration.

PytorchIntegration (Integration)

Definition of PyTorch integration for ZenML.

Source code in zenml/integrations/pytorch/__init__.py
class PytorchIntegration(Integration):
    """ZenML integration that enables PyTorch support.

    Activating the integration imports its materializers package.
    """

    NAME = PYTORCH
    REQUIREMENTS = ["torch"]

    @classmethod
    def activate(cls) -> None:
        """Activates the integration by importing its materializers."""
        # The import is done only for its side effects; the bound name is
        # never used afterwards.
        import zenml.integrations.pytorch.materializers  # noqa: F401
activate() classmethod

Activates the integration.

Source code in zenml/integrations/pytorch/__init__.py
@classmethod
def activate(cls) -> None:
    """Activates the integration by importing its materializers."""
    # Imported purely for its side effects; the bound name is unused.
    import zenml.integrations.pytorch.materializers  # noqa: F401

materializers special

Initialization of the PyTorch Materializer.

pytorch_dataloader_materializer

Implementation of the PyTorch DataLoader materializer.

PyTorchDataLoaderMaterializer (BaseMaterializer)

Materializer to read/write PyTorch dataloaders.

Source code in zenml/integrations/pytorch/materializers/pytorch_dataloader_materializer.py
class PyTorchDataLoaderMaterializer(BaseMaterializer):
    """Materializer that persists PyTorch `DataLoader` objects via `torch.save`."""

    ASSOCIATED_TYPES = (DataLoader,)
    ASSOCIATED_ARTIFACT_TYPES = (DataArtifact,)

    def handle_input(self, data_type: Type[Any]) -> DataLoader[Any]:
        """Loads a dataloader previously written by `handle_return`.

        Args:
            data_type: The type of the dataloader to load.

        Returns:
            The dataloader deserialized from the artifact directory.
        """
        super().handle_input(data_type)
        source_path = os.path.join(self.artifact.uri, DEFAULT_FILENAME)
        with fileio.open(source_path, "rb") as stream:
            # NOTE(review): torch.load unpickles arbitrary objects; only load
            # artifacts from a trusted store.
            loaded = torch.load(stream)  # type: ignore[no-untyped-call]  # noqa
        return cast(DataLoader[Any], loaded)

    def handle_return(self, dataloader: DataLoader[Any]) -> None:
        """Serializes the whole dataloader into the artifact directory.

        Args:
            dataloader: The PyTorch `DataLoader` to save with `torch.save`.
        """
        super().handle_return(dataloader)
        target_path = os.path.join(self.artifact.uri, DEFAULT_FILENAME)
        with fileio.open(target_path, "wb") as stream:
            torch.save(dataloader, stream)
handle_input(self, data_type)

Reads and returns a PyTorch dataloader.

Parameters:

Name Type Description Default
data_type Type[Any]

The type of the dataloader to load.

required

Returns:

Type Description
torch.utils.data.dataloader.DataLoader[Any]

A loaded PyTorch dataloader.

Source code in zenml/integrations/pytorch/materializers/pytorch_dataloader_materializer.py
def handle_input(self, data_type: Type[Any]) -> DataLoader[Any]:
    """Loads a dataloader previously written to the artifact directory.

    Args:
        data_type: The type of the dataloader to load.

    Returns:
        The dataloader deserialized with `torch.load`.
    """
    super().handle_input(data_type)
    source_path = os.path.join(self.artifact.uri, DEFAULT_FILENAME)
    with fileio.open(source_path, "rb") as stream:
        # NOTE(review): torch.load unpickles arbitrary objects; only load
        # artifacts from a trusted store.
        loaded = torch.load(stream)  # type: ignore[no-untyped-call]  # noqa
    return cast(DataLoader[Any], loaded)
handle_return(self, dataloader)

Writes a PyTorch dataloader.

Parameters:

Name Type Description Default
dataloader torch.utils.data.dataloader.DataLoader[Any]

The torch.utils.data.DataLoader to serialize with torch.save.

required
Source code in zenml/integrations/pytorch/materializers/pytorch_dataloader_materializer.py
def handle_return(self, dataloader: DataLoader[Any]) -> None:
    """Serializes the whole dataloader into the artifact directory.

    Args:
        dataloader: The PyTorch `DataLoader` to save with `torch.save`.
    """
    super().handle_return(dataloader)
    target_path = os.path.join(self.artifact.uri, DEFAULT_FILENAME)
    with fileio.open(target_path, "wb") as stream:
        torch.save(dataloader, stream)
pytorch_module_materializer

Implementation of the PyTorch Module materializer.

PyTorchModuleMaterializer (BaseMaterializer)

Materializer to read/write PyTorch models.

Inspired by the guide: https://pytorch.org/tutorials/beginner/saving_loading_models.html

Source code in zenml/integrations/pytorch/materializers/pytorch_module_materializer.py
class PyTorchModuleMaterializer(BaseMaterializer):
    """Materializer that reads and writes PyTorch models.

    Follows the approach described in
    https://pytorch.org/tutorials/beginner/saving_loading_models.html:
    the full model is pickled, and a `state_dict` checkpoint is written
    alongside it.
    """

    ASSOCIATED_TYPES = (Module,)
    ASSOCIATED_ARTIFACT_TYPES = (ModelArtifact,)

    def handle_input(self, data_type: Type[Any]) -> Module:
        """Loads the full pickled model from the artifact directory.

        Only the model file is read; the checkpoint file is not used here.

        Args:
            data_type: The type of the model to load.

        Returns:
            The model deserialized with `torch.load`.
        """
        super().handle_input(data_type)
        model_path = os.path.join(self.artifact.uri, DEFAULT_FILENAME)
        with fileio.open(model_path, "rb") as stream:
            model = torch.load(stream)  # type: ignore[no-untyped-call]  # noqa
        return model

    def handle_return(self, model: Module) -> None:
        """Writes the model as a full pickle plus a `state_dict` checkpoint.

        Args:
            model: The model to save. A `state_dict` checkpoint is only
                written when this is a `torch.nn.Module` instance.
        """
        super().handle_return(model)

        # The full model pickle is what `handle_input` loads back; it suits
        # development use (training, evaluation).
        model_path = os.path.join(self.artifact.uri, DEFAULT_FILENAME)
        with fileio.open(model_path, "wb") as stream:
            torch.save(model, stream)

        # The state_dict checkpoint suits inference-time loading and only
        # exists for real Module instances.
        if not isinstance(model, Module):
            return
        checkpoint_path = os.path.join(self.artifact.uri, CHECKPOINT_FILENAME)
        with fileio.open(checkpoint_path, "wb") as stream:
            torch.save(model.state_dict(), stream)
handle_input(self, data_type)

Reads and returns a PyTorch model.

Only loads the model, not the checkpoint.

Parameters:

Name Type Description Default
data_type Type[Any]

The type of the model to load.

required

Returns:

Type Description
Module

A loaded pytorch model.

Source code in zenml/integrations/pytorch/materializers/pytorch_module_materializer.py
def handle_input(self, data_type: Type[Any]) -> Module:
    """Loads the full pickled model from the artifact directory.

    Only the model file is read; the checkpoint file is not used here.

    Args:
        data_type: The type of the model to load.

    Returns:
        The model deserialized with `torch.load`.
    """
    super().handle_input(data_type)
    model_path = os.path.join(self.artifact.uri, DEFAULT_FILENAME)
    with fileio.open(model_path, "rb") as stream:
        model = torch.load(stream)  # type: ignore[no-untyped-call]  # noqa
    return model
handle_return(self, model)

Writes a PyTorch model, as a model and a checkpoint.

Parameters:

Name Type Description Default
model Module

The model to serialize with torch.save; a state_dict checkpoint is also written when it is a torch.nn.Module instance.

required
Source code in zenml/integrations/pytorch/materializers/pytorch_module_materializer.py
def handle_return(self, model: Module) -> None:
    """Writes the model as a full pickle plus a `state_dict` checkpoint.

    Args:
        model: The model to save. A `state_dict` checkpoint is only written
            when this is a `torch.nn.Module` instance.
    """
    super().handle_return(model)

    # The full model pickle suits development use (training, evaluation).
    model_path = os.path.join(self.artifact.uri, DEFAULT_FILENAME)
    with fileio.open(model_path, "wb") as stream:
        torch.save(model, stream)

    # The state_dict checkpoint suits inference-time loading and only exists
    # for real Module instances.
    if not isinstance(model, Module):
        return
    checkpoint_path = os.path.join(self.artifact.uri, CHECKPOINT_FILENAME)
    with fileio.open(checkpoint_path, "wb") as stream:
        torch.save(model.state_dict(), stream)

pytorch_lightning special

Initialization of the PyTorch Lightning integration.

PytorchLightningIntegration (Integration)

Definition of PyTorch Lightning integration for ZenML.

Source code in zenml/integrations/pytorch_lightning/__init__.py
class PytorchLightningIntegration(Integration):
    """ZenML integration that enables PyTorch Lightning support.

    Activating the integration imports its materializers package.
    """

    NAME = PYTORCH_L
    REQUIREMENTS = ["pytorch_lightning"]

    @classmethod
    def activate(cls) -> None:
        """Activates the integration by importing its materializers."""
        # Imported purely for its side effects; the bound name is unused.
        import zenml.integrations.pytorch_lightning.materializers  # noqa: F401
activate() classmethod

Activates the integration.

Source code in zenml/integrations/pytorch_lightning/__init__.py
@classmethod
def activate(cls) -> None:
    """Activates the integration by importing its materializers."""
    # Imported purely for its side effects; the bound name is unused.
    import zenml.integrations.pytorch_lightning.materializers  # noqa: F401

materializers special

Initialization of the PyTorch Lightning Materializer.

pytorch_lightning_materializer

Implementation of the PyTorch Lightning Materializer.

PyTorchLightningMaterializer (BaseMaterializer)

Materializer to read/write PyTorch models.

Source code in zenml/integrations/pytorch_lightning/materializers/pytorch_lightning_materializer.py
class PyTorchLightningMaterializer(BaseMaterializer):
    """Materializer that stores PyTorch Lightning trainers as checkpoints."""

    ASSOCIATED_TYPES = (Trainer,)
    ASSOCIATED_ARTIFACT_TYPES = (ModelArtifact,)

    def handle_input(self, data_type: Type[Any]) -> Trainer:
        """Builds a new trainer that resumes from the stored checkpoint.

        Args:
            data_type: The type of the trainer to load.

        Returns:
            A freshly constructed PyTorch Lightning trainer pointed at the
            artifact's checkpoint file.
        """
        super().handle_input(data_type)
        checkpoint_path = os.path.join(self.artifact.uri, CHECKPOINT_NAME)
        # NOTE(review): `resume_from_checkpoint` is deprecated in newer
        # PyTorch Lightning releases; confirm against the pinned version.
        return Trainer(resume_from_checkpoint=checkpoint_path)

    def handle_return(self, trainer: Trainer) -> None:
        """Saves the trainer's checkpoint into the artifact directory.

        Args:
            trainer: A PyTorch Lightning trainer object.
        """
        super().handle_return(trainer)
        checkpoint_path = os.path.join(self.artifact.uri, CHECKPOINT_NAME)
        trainer.save_checkpoint(checkpoint_path)
handle_input(self, data_type)

Reads and returns a PyTorch Lightning trainer.

Parameters:

Name Type Description Default
data_type Type[Any]

The type of the trainer to load.

required

Returns:

Type Description
Trainer

A PyTorch Lightning trainer object.

Source code in zenml/integrations/pytorch_lightning/materializers/pytorch_lightning_materializer.py
def handle_input(self, data_type: Type[Any]) -> Trainer:
    """Builds a new trainer that resumes from the stored checkpoint.

    Args:
        data_type: The type of the trainer to load.

    Returns:
        A freshly constructed PyTorch Lightning trainer pointed at the
        artifact's checkpoint file.
    """
    super().handle_input(data_type)
    checkpoint_path = os.path.join(self.artifact.uri, CHECKPOINT_NAME)
    # NOTE(review): `resume_from_checkpoint` is deprecated in newer PyTorch
    # Lightning releases; confirm against the pinned version.
    return Trainer(resume_from_checkpoint=checkpoint_path)
handle_return(self, trainer)

Writes a PyTorch Lightning trainer.

Parameters:

Name Type Description Default
trainer Trainer

A PyTorch Lightning trainer object.

required
Source code in zenml/integrations/pytorch_lightning/materializers/pytorch_lightning_materializer.py
def handle_return(self, trainer: Trainer) -> None:
    """Saves the trainer's checkpoint into the artifact directory.

    Args:
        trainer: A PyTorch Lightning trainer object.
    """
    super().handle_return(trainer)
    checkpoint_path = os.path.join(self.artifact.uri, CHECKPOINT_NAME)
    trainer.save_checkpoint(checkpoint_path)

registry

Implementation of a registry to track ZenML integrations.

IntegrationRegistry

Registry to keep track of ZenML Integrations.

Source code in zenml/integrations/registry.py
class IntegrationRegistry(object):
    """Registry to keep track of ZenML Integrations."""

    def __init__(self) -> None:
        """Initializing the integration registry."""
        self._integrations: Dict[str, Type["Integration"]] = {}

    @property
    def integrations(self) -> Dict[str, Type["Integration"]]:
        """Method to get integrations dictionary.

        Returns:
            A dict of integration key to type of `Integration`.
        """
        return self._integrations

    @integrations.setter
    def integrations(self, i: Any) -> None:
        """Setter method for the integrations property.

        Args:
            i: Value to set the integrations property to.

        Raises:
            IntegrationError: If you try to manually set the integrations property.
        """
        raise IntegrationError(
            "Please do not manually change the integrations within the "
            "registry. If you would like to register a new integration "
            "manually, please use "
            "`integration_registry.register_integration()`."
        )

    def register_integration(
        self, key: str, type_: Type["Integration"]
    ) -> None:
        """Method to register an integration with a given name.

        Registering an existing key replaces the previous integration.

        Args:
            key: Name of the integration.
            type_: Type of the integration.
        """
        self._integrations[key] = type_

    def activate_integrations(self) -> None:
        """Method to activate the integrations which are registered in the registry.

        Integrations whose installation check fails are skipped; a debug
        message is logged either way.
        """
        for name, integration in self._integrations.items():
            if integration.check_installation():
                integration.activate()
                logger.debug(f"Integration `{name}` is activated.")
            else:
                logger.debug(f"Integration `{name}` could not be activated.")

    @property
    def list_integration_names(self) -> List[str]:
        """Get a list of all possible integrations.

        Returns:
            A list of all possible integrations, in registration order.
        """
        return list(self._integrations)

    def _not_found_error(self, integration_name: str) -> KeyError:
        """Builds the KeyError raised for an unknown integration name."""
        return KeyError(
            f"Integration '{integration_name}' not found. "
            f"Currently the following integrations are available: "
            f"{self.list_integration_names}"
        )

    def select_integration_requirements(
        self, integration_name: Optional[str] = None
    ) -> List[str]:
        """Select the requirements for a given integration or all integrations.

        Args:
            integration_name: Name of the integration to check. If empty or
                `None`, the requirements of all integrations are returned.

        Returns:
            A new list of requirements for the integration(s). Mutating it
            does not affect the integration's `REQUIREMENTS`.

        Raises:
            KeyError: If the integration is not found.
        """
        if integration_name:
            if integration_name not in self._integrations:
                raise self._not_found_error(integration_name)
            # Return a copy so callers cannot mutate the class attribute.
            return list(self._integrations[integration_name].REQUIREMENTS)
        return [
            requirement
            for integration in self._integrations.values()
            for requirement in integration.REQUIREMENTS
        ]

    def is_installed(self, integration_name: Optional[str] = None) -> bool:
        """Checks if all requirements for an integration are installed.

        Args:
            integration_name: Name of the integration to check. If empty or
                `None`, all registered integrations are checked.

        Returns:
            True if all requirements are installed, False otherwise.

        Raises:
            KeyError: If the integration is not found.
        """
        if integration_name in self._integrations:
            return self._integrations[integration_name].check_installation()
        if not integration_name:
            # Evaluate every check (no short-circuit), as before.
            all_installed = [
                integration.check_installation()
                for integration in self._integrations.values()
            ]
            return all(all_installed)
        raise self._not_found_error(integration_name)

    def get_installed_integrations(self) -> List[str]:
        """Returns list of installed integrations.

        Returns:
            Names of the integrations in this registry whose installation
            check passes.
        """
        # Use this instance's own mapping rather than the module-level
        # singleton, so the method is correct for any registry instance.
        return [
            name
            for name, integration in self._integrations.items()
            if integration.check_installation()
        ]
integrations: Dict[str, Type[Integration]] property writable

Method to get integrations dictionary.

Returns:

Type Description
Dict[str, Type[Integration]]

A dict of integration key to type of Integration.

list_integration_names: List[str] property readonly

Get a list of all possible integrations.

Returns:

Type Description
List[str]

A list of all possible integrations.

__init__(self) special

Initializing the integration registry.

Source code in zenml/integrations/registry.py
def __init__(self) -> None:
    """Sets up an empty mapping from integration name to integration class."""
    integrations: Dict[str, Type["Integration"]] = dict()
    self._integrations = integrations
activate_integrations(self)

Method to activate the integrations which are registered in the registry.

Source code in zenml/integrations/registry.py
def activate_integrations(self) -> None:
    """Activates each registered integration whose requirements are installed.

    Integrations that fail the installation check are skipped; a debug
    message is logged for every integration either way.
    """
    for integration_name, integration_cls in self._integrations.items():
        if not integration_cls.check_installation():
            logger.debug(
                f"Integration `{integration_name}` could not be activated."
            )
            continue
        integration_cls.activate()
        logger.debug(f"Integration `{integration_name}` is activated.")
get_installed_integrations(self)

Returns list of installed integrations.

Returns:

Type Description
List[str]

List of installed integrations.

Source code in zenml/integrations/registry.py
def get_installed_integrations(self) -> List[str]:
    """Returns list of installed integrations.

    Returns:
        Names of the integrations in this registry whose installation check
        passes, in registration order.
    """
    # Read this instance's own mapping instead of the module-level
    # `integration_registry` singleton, so any registry instance works.
    return [
        name
        for name, integration in self._integrations.items()
        if integration.check_installation()
    ]
is_installed(self, integration_name=None)

Checks if all requirements for an integration are installed.

Parameters:

Name Type Description Default
integration_name Optional[str]

Name of the integration to check.

None

Returns:

Type Description
bool

True if all requirements are installed, False otherwise.

Exceptions:

Type Description
KeyError

If the integration is not found.

Source code in zenml/integrations/registry.py
def is_installed(self, integration_name: Optional[str] = None) -> bool:
    """Reports whether one integration, or every integration, is installed.

    Args:
        integration_name: Name of the integration to check. When empty or
            `None`, every registered integration is checked.

    Returns:
        True if all requirements are installed, False otherwise.

    Raises:
        KeyError: If the integration is not found.
    """
    known_names = self.list_integration_names
    if integration_name in known_names:
        return self._integrations[integration_name].check_installation()
    if integration_name:
        raise KeyError(
            f"Integration '{integration_name}' not found. "
            f"Currently the following integrations are available: "
            f"{known_names}"
        )
    # No specific integration requested: every check runs (a list, not a
    # generator, so nothing is short-circuited) and all must pass.
    results = [
        self._integrations[name].check_installation() for name in known_names
    ]
    return all(results)
register_integration(self, key, type_)

Method to register an integration with a given name.

Parameters:

Name Type Description Default
key str

Name of the integration.

required
type_ Type[Integration]

Type of the integration.

required
Source code in zenml/integrations/registry.py
def register_integration(
    self, key: str, type_: Type["Integration"]
) -> None:
    """Stores an integration class in the registry under the given name.

    Registering an existing name replaces the previous entry.

    Args:
        key: Name of the integration.
        type_: Type of the integration.
    """
    registry = self._integrations
    registry[key] = type_
select_integration_requirements(self, integration_name=None)

Select the requirements for a given integration or all integrations.

Parameters:

Name Type Description Default
integration_name Optional[str]

Name of the integration to check.

None

Returns:

Type Description
List[str]

List of requirements for the integration.

Exceptions:

Type Description
KeyError

If the integration is not found.

Source code in zenml/integrations/registry.py
def select_integration_requirements(
    self, integration_name: Optional[str] = None
) -> List[str]:
    """Select the requirements for a given integration or all integrations.

    Args:
        integration_name: Name of the integration to check. If empty or
            `None`, the requirements of all integrations are returned.

    Returns:
        A new list of requirements for the integration(s). Mutating it does
        not affect the integration's `REQUIREMENTS`.

    Raises:
        KeyError: If the integration is not found.
    """
    if integration_name:
        if integration_name not in self.list_integration_names:
            # Message matches `is_installed`; it previously said "Version".
            raise KeyError(
                f"Integration '{integration_name}' not found. "
                f"Currently the following integrations are available: "
                f"{self.list_integration_names}"
            )
        # Return a copy so callers cannot mutate the class attribute.
        return list(self._integrations[integration_name].REQUIREMENTS)
    return [
        requirement
        for name in self.list_integration_names
        for requirement in self._integrations[name].REQUIREMENTS
    ]

s3 special

Initialization of the S3 integration.

The S3 integration allows the use of cloud artifact stores and file operations on S3 buckets.

S3Integration (Integration)

Definition of S3 integration for ZenML.

Source code in zenml/integrations/s3/__init__.py
class S3Integration(Integration):
    """ZenML integration that contributes an S3-backed artifact store."""

    NAME = S3
    REQUIREMENTS = ["s3fs==2022.3.0"]

    @classmethod
    def flavors(cls) -> List[FlavorWrapper]:
        """Lists the stack component flavors contributed by this integration.

        Returns:
            A single-element list holding the S3 artifact store flavor.
        """
        s3_artifact_store = FlavorWrapper(
            name=S3_ARTIFACT_STORE_FLAVOR,
            source="zenml.integrations.s3.artifact_stores.S3ArtifactStore",
            type=StackComponentType.ARTIFACT_STORE,
            integration=cls.NAME,
        )
        return [s3_artifact_store]
flavors() classmethod

Declare the stack component flavors for the s3 integration.

Returns:

Type Description
List[zenml.zen_stores.models.flavor_wrapper.FlavorWrapper]

List of stack component flavors for this integration.

Source code in zenml/integrations/s3/__init__.py
@classmethod
def flavors(cls) -> List[FlavorWrapper]:
    """Lists the stack component flavors contributed by the S3 integration.

    Returns:
        A single-element list holding the S3 artifact store flavor.
    """
    s3_artifact_store = FlavorWrapper(
        name=S3_ARTIFACT_STORE_FLAVOR,
        source="zenml.integrations.s3.artifact_stores.S3ArtifactStore",
        type=StackComponentType.ARTIFACT_STORE,
        integration=cls.NAME,
    )
    return [s3_artifact_store]

artifact_stores special

Initialization of the S3 Artifact Store.

s3_artifact_store

Implementation of the S3 Artifact Store.

S3ArtifactStore (BaseArtifactStore, AuthenticationMixin) pydantic-model

Artifact Store for S3 based artifacts.

All attributes of this class except path will be passed to the s3fs.S3FileSystem initialization. See here for more information on how to use those configuration options to connect to any S3-compatible storage.

When you want to register an S3ArtifactStore from the CLI and need to pass client_kwargs, config_kwargs or s3_additional_kwargs, you should pass them as a json string:

zenml artifact-store register my_s3_store --type=s3 --path=s3://my_bucket     --client_kwargs='{"endpoint_url": "http://my-s3-endpoint"}'
Source code in zenml/integrations/s3/artifact_stores/s3_artifact_store.py
class S3ArtifactStore(BaseArtifactStore, AuthenticationMixin):
    """Artifact Store for S3 based artifacts.

    All attributes of this class except `path` will be passed to the
    `s3fs.S3FileSystem` initialization. See
    [here](https://s3fs.readthedocs.io/en/latest/) for more information on how
    to use those configuration options to connect to any S3-compatible storage.

    When you want to register an S3ArtifactStore from the CLI and need to pass
    `client_kwargs`, `config_kwargs` or `s3_additional_kwargs`, you should pass
    them as a json string:
    ```
    zenml artifact-store register my_s3_store --type=s3 --path=s3://my_bucket \
    --client_kwargs='{"endpoint_url": "http://my-s3-endpoint"}'
    ```
    """

    key: Optional[str] = None
    secret: Optional[str] = None
    token: Optional[str] = None
    client_kwargs: Optional[Dict[str, Any]] = None
    config_kwargs: Optional[Dict[str, Any]] = None
    s3_additional_kwargs: Optional[Dict[str, Any]] = None
    _filesystem: Optional[s3fs.S3FileSystem] = None

    # Class variables
    FLAVOR: ClassVar[str] = S3_ARTIFACT_STORE_FLAVOR
    SUPPORTED_SCHEMES: ClassVar[Set[str]] = {"s3://"}

    @validator(
        "client_kwargs", "config_kwargs", "s3_additional_kwargs", pre=True
    )
    def _convert_json_string(
        cls, value: Union[None, str, Dict[str, Any]]
    ) -> Optional[Dict[str, Any]]:
        """Converts potential JSON strings passed via the CLI to dictionaries.

        Args:
            value: The value to convert.

        Returns:
            The converted value.

        Raises:
            TypeError: If the value is not a `str`, `Dict` or `None`.
            ValueError: If the value is an invalid json string or a json string
                that does not decode into a dictionary.
        """
        if isinstance(value, str):
            try:
                dict_ = json.loads(value)
            except json.JSONDecodeError as e:
                raise ValueError(f"Invalid json string '{value}'") from e

            if not isinstance(dict_, Dict):
                raise ValueError(
                    f"Json string '{value}' did not decode into a dictionary."
                )

            return dict_
        elif isinstance(value, Dict) or value is None:
            return value
        else:
            raise TypeError(f"{value} is not a json string or a dictionary.")

    def _get_credentials(
        self,
    ) -> Tuple[Optional[str], Optional[str], Optional[str]]:
        """Gets authentication credentials.

        If an authentication secret is configured, the secret values are
        returned. Otherwise we fallback to the plain text component attributes.

        Returns:
            Tuple (key, secret, token) of credentials used to authenticate with
            the S3 filesystem.
        """
        secret = self.get_authentication_secret(
            expected_schema_type=AWSSecretSchema
        )
        if secret:
            return (
                secret.aws_access_key_id,
                secret.aws_secret_access_key,
                secret.aws_session_token,
            )
        else:
            return self.key, self.secret, self.token

    @property
    def filesystem(self) -> s3fs.S3FileSystem:
        """The s3 filesystem to access this artifact store.

        Returns:
            The s3 filesystem.
        """
        if not self._filesystem:
            key, secret, token = self._get_credentials()

            self._filesystem = s3fs.S3FileSystem(
                key=key,
                secret=secret,
                token=token,
                client_kwargs=self.client_kwargs,
                config_kwargs=self.config_kwargs,
                s3_additional_kwargs=self.s3_additional_kwargs,
            )
        return self._filesystem

    def open(self, path: PathType, mode: str = "r") -> Any:
        """Open a file at the given path.

        Args:
            path: Path of the file to open.
            mode: Mode in which to open the file. Currently, only
                'rb' and 'wb' to read and write binary files are supported.

        Returns:
            A file-like object.
        """
        return self.filesystem.open(path=path, mode=mode)

    def copyfile(
        self, src: PathType, dst: PathType, overwrite: bool = False
    ) -> None:
        """Copy a file.

        Args:
            src: The path to copy from.
            dst: The path to copy to.
            overwrite: If a file already exists at the destination, this
                method will overwrite it if overwrite=`True` and
                raise a FileExistsError otherwise.

        Raises:
            FileExistsError: If a file already exists at the destination
                and overwrite is not set to `True`.
        """
        if not overwrite and self.filesystem.exists(dst):
            raise FileExistsError(
                f"Unable to copy to destination '{convert_to_str(dst)}', "
                f"file already exists. Set `overwrite=True` to copy anyway."
            )

        # TODO [ENG-151]: Check if it works with overwrite=True or if we need to
        #  manually remove it first
        self.filesystem.copy(path1=src, path2=dst)

    def exists(self, path: PathType) -> bool:
        """Check whether a path exists.

        Args:
            path: The path to check.

        Returns:
            True if the path exists, False otherwise.
        """
        return self.filesystem.exists(path=path)  # type: ignore[no-any-return]

    def glob(self, pattern: PathType) -> List[PathType]:
        """Return all paths that match the given glob pattern.

        The glob pattern may include:
        - '*' to match any number of characters
        - '?' to match a single character
        - '[...]' to match one of the characters inside the brackets
        - '**' as the full name of a path component to match to search
            in subdirectories of any depth (e.g. '/some_dir/**/some_file)

        Args:
            pattern: The glob pattern to match, see details above.

        Returns:
            A list of paths that match the given glob pattern.
        """
        return [f"s3://{path}" for path in self.filesystem.glob(path=pattern)]

    def isdir(self, path: PathType) -> bool:
        """Check whether a path is a directory.

        Args:
            path: The path to check.

        Returns:
            True if the path is a directory, False otherwise.
        """
        return self.filesystem.isdir(path=path)  # type: ignore[no-any-return]

    def listdir(self, path: PathType) -> List[PathType]:
        """Return the names of the entries directly inside a directory.

        Args:
            path: The directory to list, with or without the ``s3://`` prefix.

        Returns:
            The entry names relative to ``path`` (without the directory
            prefix and without a leading slash). The entry for ``path``
            itself is omitted.
        """
        # Strip the s3 scheme so that ``path`` lines up with the leading part
        # of the listed keys; this lets us slice the directory off each key,
        # as this method is expected to only return filenames.
        path = convert_to_str(path)
        if path.startswith("s3://"):
            path = path[5:]

        def _extract_basename(file_dict: Dict[str, Any]) -> str:
            """Extracts the basename from a file info dict returned by the S3 filesystem.

            Args:
                file_dict: A file info dict returned by the S3 filesystem.

            Returns:
                The basename of the file.
            """
            file_path = cast(str, file_dict["Key"])
            return file_path[len(path) :].lstrip("/")

        # Compute each basename exactly once. s3fs.listdir also returns the
        # root directory itself, whose basename is empty, so it is filtered
        # out here.
        basenames = (
            _extract_basename(dict_)
            for dict_ in self.filesystem.listdir(path=path)
        )
        return [name for name in basenames if name]

    def makedirs(self, path: PathType) -> None:
        """Create a directory, including any missing parent directories.

        An already existing directory is not treated as an error.

        Args:
            path: The directory to create.
        """
        fs = self.filesystem
        fs.makedirs(exist_ok=True, path=path)

    def mkdir(self, path: PathType) -> None:
        """Create a single directory at ``path``.

        Args:
            path: The directory to create.
        """
        fs = self.filesystem
        fs.makedir(path=path)

    def remove(self, path: PathType) -> None:
        """Delete the single file located at ``path``.

        Args:
            path: The file to delete.
        """
        fs = self.filesystem
        fs.rm_file(path=path)

    def rename(
        self, src: PathType, dst: PathType, overwrite: bool = False
    ) -> None:
        """Move the file at ``src`` to ``dst``.

        Args:
            src: The path of the file to rename.
            dst: The new path for the file.
            overwrite: When ``True``, an existing file at ``dst`` is replaced;
                when ``False``, an existing destination causes a
                FileExistsError.

        Raises:
            FileExistsError: If ``dst`` already exists and ``overwrite`` is
                not ``True``.
        """
        # The existence check is only performed when overwriting is disabled.
        if not overwrite and self.filesystem.exists(dst):
            target = convert_to_str(dst)
            raise FileExistsError(
                f"Unable to rename file to '{target}', "
                f"file already exists. Set `overwrite=True` to rename anyway."
            )

        # TODO [ENG-152]: Check if it works with overwrite=True or if we need
        #  to manually remove it first
        self.filesystem.rename(path2=dst, path1=src)

    def rmtree(self, path: PathType) -> None:
        """Recursively delete the directory at ``path`` and its contents.

        Args:
            path: The directory to delete.
        """
        fs = self.filesystem
        fs.delete(recursive=True, path=path)

    def stat(self, path: PathType) -> Dict[str, Any]:
        """Fetch the metadata the S3 filesystem reports for ``path``.

        Args:
            path: The path to inspect.

        Returns:
            The stat information as returned by the underlying filesystem.
        """
        fs = self.filesystem
        return fs.stat(path=path)  # type: ignore[no-any-return]

    def walk(
        self,
        top: PathType,
        topdown: bool = True,
        onerror: Optional[Callable[..., None]] = None,
    ) -> Iterable[Tuple[PathType, List[PathType], List[PathType]]]:
        """Walk the directory tree rooted at ``top``.

        Args:
            top: Path of directory to walk.
            topdown: Unused argument to conform to interface.
            onerror: Unused argument to conform to interface.

        Yields:
            One tuple per visited directory: the directory path (prefixed
                with ``s3://``), the names of its subdirectories and the
                names of its files, as reported by the underlying filesystem.
        """
        # TODO [ENG-153]: Additional params
        listing = self.filesystem.walk(path=top)
        for current_dir, dir_names, file_names in listing:
            yield f"s3://{current_dir}", dir_names, file_names
filesystem: S3FileSystem property readonly

The S3 filesystem used to access this artifact store.

Returns:

Type Description
S3FileSystem

The s3 filesystem.

copyfile(self, src, dst, overwrite=False)

Copy a file.

Parameters:

Name Type Description Default
src Union[bytes, str]

The path to copy from.

required
dst Union[bytes, str]

The path to copy to.

required
overwrite bool

If a file already exists at the destination, this method will overwrite it if overwrite=True and raise a FileExistsError otherwise.

False

Exceptions:

Type Description
FileExistsError

If a file already exists at the destination and overwrite is not set to True.

Source code in zenml/integrations/s3/artifact_stores/s3_artifact_store.py
def copyfile(
    self, src: PathType, dst: PathType, overwrite: bool = False
) -> None:
    """Copy the file at ``src`` to ``dst``.

    Args:
        src: The path to copy from.
        dst: The path to copy to.
        overwrite: When ``True``, an existing file at ``dst`` is replaced;
            when ``False``, an existing destination causes a
            FileExistsError.

    Raises:
        FileExistsError: If ``dst`` already exists and ``overwrite`` is not
            ``True``.
    """
    # The existence check is only performed when overwriting is disabled.
    if not overwrite and self.filesystem.exists(dst):
        destination = convert_to_str(dst)
        raise FileExistsError(
            f"Unable to copy to destination '{destination}', "
            f"file already exists. Set `overwrite=True` to copy anyway."
        )

    # TODO [ENG-151]: Check if it works with overwrite=True or if we need to
    #  manually remove it first
    self.filesystem.copy(path2=dst, path1=src)
exists(self, path)

Check whether a path exists.

Parameters:

Name Type Description Default
path Union[bytes, str]

The path to check.

required

Returns:

Type Description
bool

True if the path exists, False otherwise.

Source code in zenml/integrations/s3/artifact_stores/s3_artifact_store.py
def exists(self, path: PathType) -> bool:
    """Report whether ``path`` is present in the S3 filesystem.

    Args:
        path: Location to look up.

    Returns:
        ``True`` when the underlying filesystem reports the path as existing,
        ``False`` otherwise.
    """
    fs = self.filesystem
    return fs.exists(path=path)  # type: ignore[no-any-return]
glob(self, pattern)

Return all paths that match the given glob pattern.

The glob pattern may include: - '\*' to match any number of characters - '?' to match a single character - '[...]' to match one of the characters inside the brackets - '\*\*' as the full name of a path component to search in subdirectories of any depth (e.g. '/some_dir/\*\*/some_file')

Parameters:

Name Type Description Default
pattern Union[bytes, str]

The glob pattern to match, see details above.

required

Returns:

Type Description
List[Union[bytes, str]]

A list of paths that match the given glob pattern.

Source code in zenml/integrations/s3/artifact_stores/s3_artifact_store.py
def glob(self, pattern: PathType) -> List[PathType]:
    """Return all paths that match the given glob pattern.

    The glob pattern may include:
    - '*' to match any number of characters
    - '?' to match a single character
    - '[...]' to match one of the characters inside the brackets
    - '**' as the full name of a path component to search
        in subdirectories of any depth (e.g. '/some_dir/**/some_file')

    Args:
        pattern: The glob pattern to match, see details above.

    Returns:
        A list of matching paths. The ``s3://`` scheme is prepended to every
        path returned by the underlying filesystem.
    """
    matches = self.filesystem.glob(path=pattern)
    return [f"s3://{match}" for match in matches]
isdir(self, path)

Check whether a path is a directory.

Parameters:

Name Type Description Default
path Union[bytes, str]

The path to check.

required

Returns:

Type Description
bool

True if the path is a directory, False otherwise.

Source code in zenml/integrations/s3/artifact_stores/s3_artifact_store.py
def isdir(self, path: PathType) -> bool:
    """Tell whether ``path`` refers to a directory in the S3 filesystem.

    Args:
        path: Location to inspect.

    Returns:
        ``True`` if the underlying filesystem treats the path as a directory,
        ``False`` otherwise.
    """
    fs = self.filesystem
    return fs.isdir(path=path)  # type: ignore[no-any-return]
listdir(self, path)

Return a list of files in a directory.

Parameters:

Name Type Description Default
path Union[bytes, str]

The path to list.

required

Returns:

Type Description
List[Union[bytes, str]]

A list of paths that are files in the given directory.

Source code in zenml/integrations/s3/artifact_stores/s3_artifact_store.py
def listdir(self, path: PathType) -> List[PathType]:
    """Return the names of the entries directly inside a directory.

    Args:
        path: The directory to list, with or without the ``s3://`` prefix.

    Returns:
        The entry names relative to ``path`` (without the directory prefix
        and without a leading slash). The entry for ``path`` itself is
        omitted.
    """
    # Strip the s3 scheme so that ``path`` lines up with the leading part of
    # the listed keys; this lets us slice the directory off each key, as this
    # method is expected to only return filenames.
    path = convert_to_str(path)
    if path.startswith("s3://"):
        path = path[5:]

    def _extract_basename(file_dict: Dict[str, Any]) -> str:
        """Extracts the basename from a file info dict returned by the S3 filesystem.

        Args:
            file_dict: A file info dict returned by the S3 filesystem.

        Returns:
            The basename of the file.
        """
        file_path = cast(str, file_dict["Key"])
        return file_path[len(path) :].lstrip("/")

    # Compute each basename exactly once. s3fs.listdir also returns the root
    # directory itself, whose basename is empty, so it is filtered out here.
    basenames = (
        _extract_basename(dict_)
        for dict_ in self.filesystem.listdir(path=path)
    )
    return [name for name in basenames if name]
makedirs(self, path)

Create a directory at the given path.

If needed also create missing parent directories.

Parameters:

Name Type Description Default
path Union[bytes, str]

The path to create.

required
Source code in zenml/integrations/s3/artifact_stores/s3_artifact_store.py
def makedirs(self, path: PathType) -> None:
    """Create a directory, including any missing parent directories.

    An already existing directory is not treated as an error.

    Args:
        path: The directory to create.
    """
    fs = self.filesystem
    fs.makedirs(exist_ok=True, path=path)
mkdir(self, path)

Create a directory at the given path.

Parameters:

Name Type Description Default
path Union[bytes, str]

The path to create.

required
Source code in zenml/integrations/s3/artifact_stores/s3_artifact_store.py
def mkdir(self, path: PathType) -> None:
    """Create a single directory at ``path``.

    Args:
        path: The directory to create.
    """
    fs = self.filesystem
    fs.makedir(path=path)
open(self, path, mode='r')

Open a file at the given path.

Parameters:

Name Type Description Default
path Union[bytes, str]

Path of the file to open.

required
mode str

Mode in which to open the file. Currently, only 'rb' and 'wb' to read and write binary files are supported.

'r'

Returns:

Type Description
Any

A file-like object.

Source code in zenml/integrations/s3/artifact_stores/s3_artifact_store.py
def open(self, path: PathType, mode: str = "r") -> Any:
    """Open the file at ``path`` through the S3 filesystem.

    Args:
        path: Path of the file to open.
        mode: Mode in which to open the file. Currently, only 'rb' and 'wb'
            (reading and writing binary files) are supported.

    Returns:
        The file-like object returned by the underlying filesystem.
    """
    fs = self.filesystem
    return fs.open(mode=mode, path=path)
remove(self, path)

Remove the file at the given path.

Parameters:

Name Type Description Default
path Union[bytes, str]

The path of the file to remove.

required
Source code in zenml/integrations/s3/artifact_stores/s3_artifact_store.py
def remove(self, path: PathType) -> None:
    """Delete the single file located at ``path``.

    Args:
        path: The file to delete.
    """
    fs = self.filesystem
    fs.rm_file(path=path)
rename(self, src, dst, overwrite=False)

Rename source file to destination file.

Parameters:

Name Type Description Default
src Union[bytes, str]

The path of the file to rename.

required
dst Union[bytes, str]

The path to rename the source file to.

required
overwrite bool

If a file already exists at the destination, this method will overwrite it if overwrite=True and raise a FileExistsError otherwise.

False

Exceptions:

Type Description
FileExistsError

If a file already exists at the destination and overwrite is not set to True.

Source code in zenml/integrations/s3/artifact_stores/s3_artifact_store.py
def rename(
    self, src: PathType, dst: PathType, overwrite: bool = False
) -> None:
    """Move the file at ``src`` to ``dst``.

    Args:
        src: The path of the file to rename.
        dst: The new path for the file.
        overwrite: When ``True``, an existing file at ``dst`` is replaced;
            when ``False``, an existing destination causes a
            FileExistsError.

    Raises:
        FileExistsError: If ``dst`` already exists and ``overwrite`` is not
            ``True``.
    """
    # The existence check is only performed when overwriting is disabled.
    if not overwrite and self.filesystem.exists(dst):
        target = convert_to_str(dst)
        raise FileExistsError(
            f"Unable to rename file to '{target}', "
            f"file already exists. Set `overwrite=True` to rename anyway."
        )

    # TODO [ENG-152]: Check if it works with overwrite=True or if we need
    #  to manually remove it first
    self.filesystem.rename(path2=dst, path1=src)
rmtree(self, path)

Remove the given directory.

Parameters:

Name Type Description Default
path Union[bytes, str]

The path of the directory to remove.

required
Source code in zenml/integrations/s3/artifact_stores/s3_artifact_store.py
def rmtree(self, path: PathType) -> None:
    """Recursively delete the directory at ``path`` and its contents.

    Args:
        path: The directory to delete.
    """
    fs = self.filesystem
    fs.delete(recursive=True, path=path)
stat(self, path)

Return stat info for the given path.

Parameters:

Name Type Description Default
path Union[bytes, str]

The path to get stat info for.

required

Returns:

Type Description
Dict[str, Any]

A dictionary containing the stat info.

Source code in zenml/integrations/s3/artifact_stores/s3_artifact_store.py
def stat(self, path: PathType) -> Dict[str, Any]:
    """Fetch the metadata the S3 filesystem reports for ``path``.

    Args:
        path: The path to inspect.

    Returns:
        The stat information as returned by the underlying filesystem.
    """
    fs = self.filesystem
    return fs.stat(path=path)  # type: ignore[no-any-return]
walk(self, top, topdown=True, onerror=None)

Return an iterator that walks the contents of the given directory.

Parameters:

Name Type Description Default
top Union[bytes, str]

Path of directory to walk.

required
topdown bool

Unused argument to conform to interface.

True
onerror Optional[Callable[..., NoneType]]

Unused argument to conform to interface.

None

Yields:

Type Description
Iterable[Tuple[Union[bytes, str], List[Union[bytes, str]], List[Union[bytes, str]]]]

An Iterable of Tuples, each of which contains the current directory path, a list of directories inside the current directory, and a list of files inside the current directory.

Source code in zenml/integrations/s3/artifact_stores/s3_artifact_store.py
def walk(
    self,
    top: PathType,
    topdown: bool = True,
    onerror: Optional[Callable[..., None]] = None,
) -> Iterable[Tuple[PathType, List[PathType], List[PathType]]]:
    """Walk the directory tree rooted at ``top``.

    Args:
        top: Path of directory to walk.
        topdown: Unused argument to conform to interface.
        onerror: Unused argument to conform to interface.

    Yields:
        One tuple per visited directory: the directory path (prefixed with
            ``s3://``), the names of its subdirectories and the names of its
            files, as reported by the underlying filesystem.
    """
    # TODO [ENG-153]: Additional params
    listing = self.filesystem.walk(path=top)
    for current_dir, dir_names, file_names in listing:
        yield f"s3://{current_dir}", dir_names, file_names

scipy special

Initialization of the Scipy integration.

ScipyIntegration (Integration)

Definition of scipy integration for ZenML.

Source code in zenml/integrations/scipy/__init__.py
class ScipyIntegration(Integration):
    """ZenML integration that adds support for SciPy objects.

    Activating it imports the SciPy materializers module (presumably so that
    the materializers defined there become available to ZenML).
    """

    REQUIREMENTS = ["scipy"]
    NAME = SCIPY

    @classmethod
    def activate(cls) -> None:
        """Import the SciPy materializers module."""
        import zenml.integrations.scipy.materializers  # noqa
activate() classmethod

Activates the integration.

Source code in zenml/integrations/scipy/__init__.py
@classmethod
def activate(cls) -> None:
    """Import the SciPy materializers module."""
    import zenml.integrations.scipy.materializers  # noqa

materializers special

Initialization of the Scipy materializers.

sparse_materializer

Implementation of the Scipy Sparse Materializer.

SparseMaterializer (BaseMaterializer)

Materializer to read and write scipy sparse matrices.

Source code in zenml/integrations/scipy/materializers/sparse_materializer.py
class SparseMaterializer(BaseMaterializer):
    """Materializer that stores scipy sparse matrices as ``.npz`` files."""

    ASSOCIATED_TYPES = (spmatrix,)
    ASSOCIATED_ARTIFACT_TYPES = (DataArtifact,)

    def handle_input(self, data_type: Type[Any]) -> spmatrix:
        """Load a sparse matrix from the artifact's npz data file.

        Args:
            data_type: The type of the spmatrix to load.

        Returns:
            The sparse matrix read with ``scipy.sparse.load_npz``.
        """
        super().handle_input(data_type)
        data_path = os.path.join(self.artifact.uri, DATA_FILENAME)
        with fileio.open(data_path, "rb") as npz_file:
            return load_npz(npz_file)

    def handle_return(self, mat: spmatrix) -> None:
        """Save a sparse matrix into the artifact's npz data file.

        Args:
            mat: The spmatrix to write.
        """
        super().handle_return(mat)
        data_path = os.path.join(self.artifact.uri, DATA_FILENAME)
        with fileio.open(data_path, "wb") as npz_file:
            save_npz(npz_file, mat)
handle_input(self, data_type)

Reads spmatrix from npz file.

Parameters:

Name Type Description Default
data_type Type[Any]

The type of the spmatrix to load.

required

Returns:

Type Description
spmatrix

A spmatrix object.

Source code in zenml/integrations/scipy/materializers/sparse_materializer.py
def handle_input(self, data_type: Type[Any]) -> spmatrix:
    """Load a sparse matrix from the artifact's npz data file.

    Args:
        data_type: The type of the spmatrix to load.

    Returns:
        The sparse matrix read with ``scipy.sparse.load_npz``.
    """
    super().handle_input(data_type)
    data_path = os.path.join(self.artifact.uri, DATA_FILENAME)
    with fileio.open(data_path, "rb") as npz_file:
        return load_npz(npz_file)
handle_return(self, mat)

Writes a spmatrix to the artifact store as a npz file.

Parameters:

Name Type Description Default
mat spmatrix

The spmatrix to write.

required
Source code in zenml/integrations/scipy/materializers/sparse_materializer.py
def handle_return(self, mat: spmatrix) -> None:
    """Save a sparse matrix into the artifact's npz data file.

    Args:
        mat: The spmatrix to write.
    """
    super().handle_return(mat)
    data_path = os.path.join(self.artifact.uri, DATA_FILENAME)
    with fileio.open(data_path, "wb") as npz_file:
        save_npz(npz_file, mat)

seldon special

Initialization of the Seldon integration.

The Seldon Core integration allows you to use the Seldon Core model serving platform to implement continuous model deployment.

SeldonIntegration (Integration)

Definition of Seldon Core integration for ZenML.

Source code in zenml/integrations/seldon/__init__.py
class SeldonIntegration(Integration):
    """ZenML integration for the Seldon Core model serving platform."""

    NAME = SELDON
    REQUIREMENTS = [
        "kubernetes==18.20.0",
    ]

    @classmethod
    def activate(cls) -> None:
        """Import the Seldon Core secret schemas and services modules."""
        from zenml.integrations.seldon import secret_schemas, services  # noqa

    @classmethod
    def flavors(cls) -> List[FlavorWrapper]:
        """Declare the stack component flavors for the Seldon Core integration.

        Returns:
            A list holding the single Seldon Core model deployer flavor.
        """
        model_deployer_flavor = FlavorWrapper(
            name=SELDON_MODEL_DEPLOYER_FLAVOR,
            source="zenml.integrations.seldon.model_deployers.SeldonModelDeployer",
            type=StackComponentType.MODEL_DEPLOYER,
            integration=cls.NAME,
        )
        return [model_deployer_flavor]
activate() classmethod

Activate the Seldon Core integration.

Source code in zenml/integrations/seldon/__init__.py
@classmethod
def activate(cls) -> None:
    """Import the Seldon Core secret schemas and services modules."""
    from zenml.integrations.seldon import secret_schemas, services  # noqa
flavors() classmethod

Declare the stack component flavors for the Seldon Core integration.

Returns:

Type Description
List[zenml.zen_stores.models.flavor_wrapper.FlavorWrapper]

List of stack component flavors for this integration.

Source code in zenml/integrations/seldon/__init__.py
@classmethod
def flavors(cls) -> List[FlavorWrapper]:
    """Declare the stack component flavors for the Seldon Core integration.

    Returns:
        A list holding the single Seldon Core model deployer flavor.
    """
    model_deployer_flavor = FlavorWrapper(
        name=SELDON_MODEL_DEPLOYER_FLAVOR,
        source="zenml.integrations.seldon.model_deployers.SeldonModelDeployer",
        type=StackComponentType.MODEL_DEPLOYER,
        integration=cls.NAME,
    )
    return [model_deployer_flavor]

model_deployers special

Initialization of the Seldon Model Deployer.

seldon_model_deployer

Implementation of the Seldon Model Deployer.

SeldonModelDeployer (BaseModelDeployer) pydantic-model

Seldon Core model deployer stack component implementation.

Attributes:

Name Type Description
kubernetes_context Optional[str]

the Kubernetes context to use to contact the remote Seldon Core installation. If not specified, the current configuration is used. Depending on where the Seldon model deployer is being used, this can be either a locally active context or an in-cluster Kubernetes configuration (if running inside a pod).

kubernetes_namespace Optional[str]

the Kubernetes namespace where the Seldon Core deployment servers are provisioned and managed by ZenML. If not specified, the namespace set in the current configuration is used. Depending on where the Seldon model deployer is being used, this can be either the current namespace configured in the locally active context or the namespace in the context of which the pod is running (if running inside a pod).

base_url str

the base URL of the Kubernetes ingress used to expose the Seldon Core deployment servers.

secret Optional[str]

the name of a ZenML secret containing the credentials used by Seldon Core storage initializers to authenticate to the Artifact Store (i.e. the storage backend where models are stored - see https://docs.seldon.io/projects/seldon-core/en/latest/servers/overview.html#handling-credentials).

Source code in zenml/integrations/seldon/model_deployers/seldon_model_deployer.py
class SeldonModelDeployer(BaseModelDeployer):
    """Seldon Core model deployer stack component implementation.

    Attributes:
        kubernetes_context: the Kubernetes context to use to contact the remote
            Seldon Core installation. If not specified, the current
            configuration is used. Depending on where the Seldon model deployer
            is being used, this can be either a locally active context or an
            in-cluster Kubernetes configuration (if running inside a pod).
        kubernetes_namespace: the Kubernetes namespace where the Seldon Core
            deployment servers are provisioned and managed by ZenML. If not
            specified, the namespace set in the current configuration is used.
            Depending on where the Seldon model deployer is being used, this can
            be either the current namespace configured in the locally active
            context or the namespace in the context of which the pod is running
            (if running inside a pod).
        base_url: the base URL of the Kubernetes ingress used to expose the
            Seldon Core deployment servers.
        secret: the name of a ZenML secret containing the credentials used by
            Seldon Core storage initializers to authenticate to the Artifact
            Store (i.e. the storage backend where models are stored - see
            https://docs.seldon.io/projects/seldon-core/en/latest/servers/overview.html#handling-credentials).
    """

    # Class Configuration
    FLAVOR: ClassVar[str] = SELDON_MODEL_DEPLOYER_FLAVOR

    kubernetes_context: Optional[str]
    kubernetes_namespace: Optional[str]
    base_url: str
    secret: Optional[str]

    # private attributes
    _client: Optional[SeldonClient] = None

    @staticmethod
    def get_model_server_info(  # type: ignore[override]
        service_instance: "SeldonDeploymentService",
    ) -> Dict[str, Optional[str]]:
        """Collect user-facing details about a Seldon Core deployment service.

        Args:
            service_instance: The SeldonDeploymentService to describe.

        Returns:
            A mapping with the prediction URL, model URI, model name and
            Seldon deployment name of the service.
        """
        info: Dict[str, Optional[str]] = {
            "PREDICTION_URL": service_instance.prediction_url,
            "MODEL_URI": service_instance.config.model_uri,
            "MODEL_NAME": service_instance.config.model_name,
            "SELDON_DEPLOYMENT": service_instance.seldon_deployment_name,
        }
        return info

    @staticmethod
    def get_active_model_deployer() -> "SeldonModelDeployer":
        """Look up the Seldon Core model deployer of the active stack.

        Returns:
            The model deployer of the active stack, if it is a
            SeldonModelDeployer.

        Raises:
            TypeError: if the active stack has no model deployer, or its model
                deployer is not a Seldon Core one.
        """
        repo = Repository(  # type: ignore [call-arg]
            skip_repository_check=True
        )
        model_deployer = repo.active_stack.model_deployer
        if model_deployer and isinstance(model_deployer, SeldonModelDeployer):
            return model_deployer

        raise TypeError(
            f"The active stack needs to have a Seldon Core model deployer "
            f"component registered to be able to deploy models with Seldon "
            f"Core. You can create a new stack with a Seldon Core model "
            f"deployer component or update your existing stack to add this "
            f"component, e.g.:\n\n"
            f"  'zenml model-deployer register seldon --flavor={SELDON_MODEL_DEPLOYER_FLAVOR} "
            f"--kubernetes_context=context-name --kubernetes_namespace="
            f"namespace-name --base_url=https://ingress.cluster.kubernetes'\n"
            f"  'zenml stack create stack-name -d seldon ...'\n"
        )

    @property
    def seldon_client(self) -> SeldonClient:
        """Return the Seldon Core client of this model deployer.

        The client is created on first access from the configured Kubernetes
        context and namespace, then cached in ``self._client``.

        Returns:
            The Seldon Core client.
        """
        client = self._client
        if not client:
            client = SeldonClient(
                context=self.kubernetes_context,
                namespace=self.kubernetes_namespace,
            )
            self._client = client
        return client

    @property
    def kubernetes_secret_name(self) -> Optional[str]:
        """Derive the Kubernetes secret name used for the configured secret.

        When a ZenML secret is configured for this model deployer, a
        Kubernetes secret with this name holds its contents in the remote
        cluster, for use by the Seldon Core storage initializers. The name is
        built from ``zenml-seldon-core-<secret>``. Each run of characters
        other than ASCII letters, digits and ``-`` is replaced by a single
        ``-``. Leading and trailing dashes are then stripped and the result
        is lowercased.

        Returns:
            The Kubernetes secret name, or None if no secret is configured.
        """
        if self.secret:
            raw_name = f"zenml-seldon-core-{self.secret}"
            sanitized = re.sub(r"[^0-9a-zA-Z-]+", "-", raw_name)
            return sanitized.strip("-").lower()
        return None

    def _create_or_update_kubernetes_secret(self) -> Optional[str]:
        """Create or update the Kubernetes secret mirroring the ZenML secret.

        If a ZenML secret is configured for this model deployer
        (``self.secret``), it is fetched from the active stack's secrets
        manager. It is then written to the Kubernetes secret named by
        ``self.kubernetes_secret_name``, using the Seldon Core client, so
        that the information can be passed to Seldon Core deployments.

        Returns:
            The name of the Kubernetes secret that was created or updated, or
            None if no secret was configured.

        Raises:
            RuntimeError: if the active stack has no secrets manager, or if
                the configured ZenML secret is not found in it.
        """
        if not self.secret:
            # No ZenML secret configured: nothing to mirror into Kubernetes.
            return None

        secret_manager = Repository(  # type: ignore [call-arg]
            skip_repository_check=True
        ).active_stack.secrets_manager

        if not secret_manager or not isinstance(
            secret_manager, BaseSecretsManager
        ):
            raise RuntimeError(
                f"The active stack doesn't have a secret manager component. "
                f"The ZenML secret specified in the Seldon Core Model "
                f"Deployer configuration cannot be fetched: {self.secret}."
            )

        try:
            zenml_secret = secret_manager.get_secret(self.secret)
        except KeyError as err:
            # Chain the lookup failure so the original cause stays visible.
            raise RuntimeError(
                f"The ZenML secret '{self.secret}' specified in the "
                f"Seldon Core Model Deployer configuration was not found "
                f"in the active stack's secret manager."
            ) from err

        # should never happen, just making mypy happy
        assert self.kubernetes_secret_name is not None
        self.seldon_client.create_or_update_secret(
            self.kubernetes_secret_name, zenml_secret
        )

        return self.kubernetes_secret_name

    def _delete_kubernetes_secret(self) -> None:
        """Delete this model deployer's Kubernetes secret when it is unused.

        The secret is kept if any Seldon Core deployment returned by
        ``self.find_model_server()`` is still configured to use it.
        """
        secret_name = self.kubernetes_secret_name
        if not secret_name:
            return

        still_referenced = any(
            cast(SeldonDeploymentConfig, service.config).secret_name
            == secret_name
            for service in self.find_model_server()
        )
        if not still_referenced:
            self.seldon_client.delete_secret(secret_name)

    def deploy_model(
        self,
        config: ServiceConfig,
        replace: bool = False,
        timeout: int = DEFAULT_SELDON_DEPLOYMENT_START_STOP_TIMEOUT,
    ) -> BaseService:
        """Create a new Seldon Core deployment or update an existing one.

        # noqa: DAR402

        This should serve the supplied model and deployment configuration.

        This method has two modes of operation, depending on the `replace`
        argument value:

          * if `replace` is False, calling this method will create a new Seldon
            Core deployment server to reflect the model and other configuration
            parameters specified in the supplied Seldon deployment `config`.

          * if `replace` is True, this method will first attempt to find an
            existing Seldon Core deployment that is *equivalent* to the supplied
            configuration parameters. Two or more Seldon Core deployments are
            considered equivalent if they have the same `pipeline_name`,
            `pipeline_step_name` and `model_name` configuration parameters. To
            put it differently, two Seldon Core deployments are equivalent if
            they serve versions of the same model deployed by the same pipeline
            step. If an equivalent Seldon Core deployment is found, it will be
            updated in place to reflect the new configuration parameters. This
            allows an existing Seldon Core deployment to retain its prediction
            URL while performing a rolling update to serve a new model version.

        Callers should set `replace` to True if they want a continuous model
        deployment workflow that doesn't spin up a new Seldon Core deployment
        server for each new model version. If multiple equivalent Seldon Core
        deployments are found, the most recently created deployment is selected
        to be updated and the older ones are stopped (deleted) without waiting
        for them to be deprovisioned. Errors raised while stopping the older
        deployments are ignored.

        If `config` does not name a Kubernetes secret, the secret returned by
        `_create_or_update_kubernetes_secret()` is used instead.

        Args:
            config: the configuration of the model to be deployed with Seldon
                Core.
            replace: set this flag to True to find and update an equivalent
                Seldon Core deployment server with the new model instead of
                starting a new deployment server.
            timeout: the timeout in seconds to wait for the Seldon Core server
                to be provisioned and successfully started or updated. If set
                to 0, the method will return immediately after the Seldon Core
                server is provisioned, without waiting for it to fully start.

        Returns:
            The ZenML Seldon Core deployment service object that can be used to
            interact with the remote Seldon Core server.

        Raises:
            SeldonClientError: if a Seldon Core client error is encountered
                while provisioning the Seldon Core deployment server.
            RuntimeError: if `timeout` is set to a positive value that is
                exceeded while waiting for the Seldon Core deployment server
                to start, or if an operational failure is encountered before
                it reaches a ready state.
        """
        config = cast(SeldonDeploymentConfig, config)
        service = None

        # if a custom Kubernetes secret is not explicitly specified in the
        # SeldonDeploymentConfig, try to create one from the ZenML secret
        # configured for the model deployer
        config.secret_name = (
            config.secret_name or self._create_or_update_kubernetes_secret()
        )

        # if replace is True, find equivalent Seldon Core deployments
        if replace:
            # results are sorted newest first, so the first match is the one
            # we keep and every later match is an older, redundant deployment
            equivalent_services = self.find_model_server(
                running=False,
                pipeline_name=config.pipeline_name,
                pipeline_step_name=config.pipeline_step_name,
                model_name=config.model_name,
            )

            for equivalent_service in equivalent_services:
                if service is None:
                    # keep the most recently created service
                    service = equivalent_service
                    continue
                try:
                    # stop (delete) the OLDER service, not the one being kept,
                    # and don't wait for it to be deprovisioned
                    equivalent_service.stop(timeout=0)
                except RuntimeError:
                    # best-effort cleanup: ignore errors encountered while
                    # stopping old services
                    pass

        if service:
            # update an equivalent service in place
            service.update(config)
            logger.info(
                f"Updating an existing Seldon deployment service: {service}"
            )
        else:
            # create a new service
            service = SeldonDeploymentService(config=config)
            logger.info(f"Creating a new Seldon deployment service: {service}")

        # start the service which in turn provisions the Seldon Core
        # deployment server and waits for it to reach a ready state
        service.start(timeout=timeout)
        return service

    def find_model_server(
        self,
        running: bool = False,
        service_uuid: Optional[UUID] = None,
        pipeline_name: Optional[str] = None,
        pipeline_run_id: Optional[str] = None,
        pipeline_step_name: Optional[str] = None,
        model_name: Optional[str] = None,
        model_uri: Optional[str] = None,
        model_type: Optional[str] = None,
    ) -> List[BaseService]:
        """Look up Seldon Core model services matching the given criteria.

        Matching services are returned newest first, ordered by the creation
        timestamp of their Seldon deployment resource. Deployments that carry
        no creation timestamp are placed at the end.

        Args:
            running: when True, services that are not running are skipped.
            service_uuid: the UUID of the Seldon Core service that was
                originally used to create the Seldon Core deployment resource.
            pipeline_name: name of the pipeline that the deployed model was
                part of.
            pipeline_run_id: ID of the pipeline run which the deployed model
                was part of.
            pipeline_step_name: the name of the pipeline model deployment step
                that deployed the model.
            model_name: the name of the deployed model.
            model_uri: URI of the deployed model.
            model_type: the Seldon Core server implementation used to serve
                the model.

        Returns:
            The Seldon Core service objects representing the Seldon Core model
            servers that match the search criteria.
        """
        # A throwaway deployment config is the canonical way to turn the
        # search criteria into Seldon deployment labels; unset criteria are
        # passed as empty strings.
        label_source = SeldonDeploymentConfig(
            pipeline_name=pipeline_name or "",
            pipeline_run_id=pipeline_run_id or "",
            pipeline_step_name=pipeline_step_name or "",
            model_name=model_name or "",
            model_uri=model_uri or "",
            implementation=model_type or "",
        )
        search_labels = label_source.get_seldon_deployment_labels()
        if service_uuid:
            # the service UUID is not part of the deployment config labels
            search_labels["zenml.service_uuid"] = str(service_uuid)

        def _created_at(deployment):  # type: ignore[no-untyped-def]
            # missing timestamps map to datetime.min so they sort last
            timestamp = deployment.metadata.creationTimestamp
            if not timestamp:
                return datetime.min
            return datetime.strptime(timestamp, "%Y-%m-%dT%H:%M:%SZ")

        matching = self.seldon_client.find_deployments(labels=search_labels)
        matching.sort(key=_created_at, reverse=True)

        # rebuild a service object from each deployment resource, lazily, so
        # that creation and the `running` filter interleave per deployment
        candidates = (
            SeldonDeploymentService.create_from_deployment(deployment=item)
            for item in matching
        )
        return [svc for svc in candidates if not running or svc.is_running]

    def stop_model_server(
        self,
        uuid: UUID,
        timeout: int = DEFAULT_SELDON_DEPLOYMENT_START_STOP_TIMEOUT,
        force: bool = False,
    ) -> None:
        """Stop a Seldon Core model server (not supported).

        Seldon Core model servers cannot be stopped; they can only be deleted.

        Args:
            uuid: UUID of the model server to stop.
            timeout: timeout in seconds to wait for the service to stop.
            force: if True, force the service to stop.

        Raises:
            NotImplementedError: always; stopping Seldon Core model servers is
                not supported.
        """
        message = (
            "Stopping Seldon Core model servers is not implemented. Try "
            "deleting the Seldon Core model server instead."
        )
        raise NotImplementedError(message)

    def start_model_server(
        self,
        uuid: UUID,
        timeout: int = DEFAULT_SELDON_DEPLOYMENT_START_STOP_TIMEOUT,
    ) -> None:
        """Start a Seldon Core model deployment server (not supported).

        Args:
            uuid: UUID of the model server to start.
            timeout: timeout in seconds to wait for the service to become
                active. If set to 0, the method would return immediately after
                provisioning the service, without waiting for it to become
                active.

        Raises:
            NotImplementedError: always; starting Seldon Core model servers is
                not supported.
        """
        message = "Starting Seldon Core model servers is not implemented"
        raise NotImplementedError(message)

    def delete_model_server(
        self,
        uuid: UUID,
        timeout: int = DEFAULT_SELDON_DEPLOYMENT_START_STOP_TIMEOUT,
        force: bool = False,
    ) -> None:
        """Delete the Seldon Core model deployment server with the given UUID.

        Does nothing if no deployment matches `uuid`. After stopping the
        matching service, `_delete_kubernetes_secret()` is invoked so the
        Kubernetes secret used by the Seldon Core storage initializer can be
        cleaned up once it is no longer referenced.

        Args:
            uuid: UUID of the model server to delete.
            timeout: timeout in seconds to wait for the service to stop. If
                set to 0, the method will return immediately after
                deprovisioning the service, without waiting for it to stop.
            force: if True, force the service to stop.
        """
        matches = self.find_model_server(service_uuid=uuid)
        if not matches:
            return

        # matches are sorted newest first; stop the newest one
        newest = matches[0]
        newest.stop(timeout=timeout, force=force)

        # drop the storage-initializer secret if this was its last user
        self._delete_kubernetes_secret()
kubernetes_secret_name: Optional[str] property readonly

Get the Kubernetes secret name associated with this model deployer.

If a secret is configured for this model deployer, a corresponding Kubernetes secret is created in the remote cluster to be used by Seldon Core storage initializers to authenticate to the Artifact Store. This property returns the unique name that is used for this secret.

Returns:

Type Description
Optional[str]

The Seldon Core Kubernetes secret name, or None if no secret is configured.

seldon_client: SeldonClient property readonly

Get the Seldon Core client associated with this model deployer.

Returns:

Type Description
SeldonClient

The Seldon Core client.

delete_model_server(self, uuid, timeout=300, force=False)

Delete a Seldon Core model deployment server.

Parameters:

Name Type Description Default
uuid UUID

UUID of the model server to delete.

required
timeout int

timeout in seconds to wait for the service to stop. If set to 0, the method will return immediately after deprovisioning the service, without waiting for it to stop.

300
force bool

if True, force the service to stop.

False
Source code in zenml/integrations/seldon/model_deployers/seldon_model_deployer.py
def delete_model_server(
    self,
    uuid: UUID,
    timeout: int = DEFAULT_SELDON_DEPLOYMENT_START_STOP_TIMEOUT,
    force: bool = False,
) -> None:
    """Delete the Seldon Core model deployment server with the given UUID.

    Does nothing if no deployment matches `uuid`. After stopping the
    matching service, `_delete_kubernetes_secret()` is invoked so the
    Kubernetes secret used by the Seldon Core storage initializer can be
    cleaned up once it is no longer referenced.

    Args:
        uuid: UUID of the model server to delete.
        timeout: timeout in seconds to wait for the service to stop. If
            set to 0, the method will return immediately after
            deprovisioning the service, without waiting for it to stop.
        force: if True, force the service to stop.
    """
    matches = self.find_model_server(service_uuid=uuid)
    if not matches:
        return

    # matches are sorted newest first; stop the newest one
    newest = matches[0]
    newest.stop(timeout=timeout, force=force)

    # drop the storage-initializer secret if this was its last user
    self._delete_kubernetes_secret()
deploy_model(self, config, replace=False, timeout=300)

Create a new Seldon Core deployment or update an existing one.

noqa: DAR402

This should serve the supplied model and deployment configuration.

This method has two modes of operation, depending on the replace argument value:

  • if replace is False, calling this method will create a new Seldon Core deployment server to reflect the model and other configuration parameters specified in the supplied Seldon deployment config.

  • if replace is True, this method will first attempt to find an existing Seldon Core deployment that is equivalent to the supplied configuration parameters. Two or more Seldon Core deployments are considered equivalent if they have the same pipeline_name, pipeline_step_name and model_name configuration parameters. To put it differently, two Seldon Core deployments are equivalent if they serve versions of the same model deployed by the same pipeline step. If an equivalent Seldon Core deployment is found, it will be updated in place to reflect the new configuration parameters. This allows an existing Seldon Core deployment to retain its prediction URL while performing a rolling update to serve a new model version.

Callers should set replace to True if they want a continuous model deployment workflow that doesn't spin up a new Seldon Core deployment server for each new model version. If multiple equivalent Seldon Core deployments are found, the most recently created deployment is selected to be updated and the others are deleted.

Parameters:

Name Type Description Default
config ServiceConfig

the configuration of the model to be deployed with Seldon Core.

required
replace bool

set this flag to True to find and update an equivalent Seldon Core deployment server with the new model instead of starting a new deployment server.

False
timeout int

the timeout in seconds to wait for the Seldon Core server to be provisioned and successfully started or updated. If set to 0, the method will return immediately after the Seldon Core server is provisioned, without waiting for it to fully start.

300

Returns:

Type Description
BaseService

The ZenML Seldon Core deployment service object that can be used to interact with the remote Seldon Core server.

Exceptions:

Type Description
SeldonClientError

if a Seldon Core client error is encountered while provisioning the Seldon Core deployment server.

RuntimeError

if timeout is set to a positive value that is exceeded while waiting for the Seldon Core deployment server to start, or if an operational failure is encountered before it reaches a ready state.

Source code in zenml/integrations/seldon/model_deployers/seldon_model_deployer.py
def deploy_model(
    self,
    config: ServiceConfig,
    replace: bool = False,
    timeout: int = DEFAULT_SELDON_DEPLOYMENT_START_STOP_TIMEOUT,
) -> BaseService:
    """Create a new Seldon Core deployment or update an existing one.

    # noqa: DAR402

    This should serve the supplied model and deployment configuration.

    This method has two modes of operation, depending on the `replace`
    argument value:

      * if `replace` is False, calling this method will create a new Seldon
        Core deployment server to reflect the model and other configuration
        parameters specified in the supplied Seldon deployment `config`.

      * if `replace` is True, this method will first attempt to find an
        existing Seldon Core deployment that is *equivalent* to the supplied
        configuration parameters. Two or more Seldon Core deployments are
        considered equivalent if they have the same `pipeline_name`,
        `pipeline_step_name` and `model_name` configuration parameters. To
        put it differently, two Seldon Core deployments are equivalent if
        they serve versions of the same model deployed by the same pipeline
        step. If an equivalent Seldon Core deployment is found, it will be
        updated in place to reflect the new configuration parameters. This
        allows an existing Seldon Core deployment to retain its prediction
        URL while performing a rolling update to serve a new model version.

    Callers should set `replace` to True if they want a continuous model
    deployment workflow that doesn't spin up a new Seldon Core deployment
    server for each new model version. If multiple equivalent Seldon Core
    deployments are found, the most recently created deployment is selected
    to be updated and the older ones are stopped (deleted) without waiting
    for them to be deprovisioned. Errors raised while stopping the older
    deployments are ignored.

    If `config` does not name a Kubernetes secret, the secret returned by
    `_create_or_update_kubernetes_secret()` is used instead.

    Args:
        config: the configuration of the model to be deployed with Seldon
            Core.
        replace: set this flag to True to find and update an equivalent
            Seldon Core deployment server with the new model instead of
            starting a new deployment server.
        timeout: the timeout in seconds to wait for the Seldon Core server
            to be provisioned and successfully started or updated. If set
            to 0, the method will return immediately after the Seldon Core
            server is provisioned, without waiting for it to fully start.

    Returns:
        The ZenML Seldon Core deployment service object that can be used to
        interact with the remote Seldon Core server.

    Raises:
        SeldonClientError: if a Seldon Core client error is encountered
            while provisioning the Seldon Core deployment server.
        RuntimeError: if `timeout` is set to a positive value that is
            exceeded while waiting for the Seldon Core deployment server
            to start, or if an operational failure is encountered before
            it reaches a ready state.
    """
    config = cast(SeldonDeploymentConfig, config)
    service = None

    # if a custom Kubernetes secret is not explicitly specified in the
    # SeldonDeploymentConfig, try to create one from the ZenML secret
    # configured for the model deployer
    config.secret_name = (
        config.secret_name or self._create_or_update_kubernetes_secret()
    )

    # if replace is True, find equivalent Seldon Core deployments
    if replace:
        # results are sorted newest first, so the first match is the one
        # we keep and every later match is an older, redundant deployment
        equivalent_services = self.find_model_server(
            running=False,
            pipeline_name=config.pipeline_name,
            pipeline_step_name=config.pipeline_step_name,
            model_name=config.model_name,
        )

        for equivalent_service in equivalent_services:
            if service is None:
                # keep the most recently created service
                service = equivalent_service
                continue
            try:
                # stop (delete) the OLDER service, not the one being kept,
                # and don't wait for it to be deprovisioned
                equivalent_service.stop(timeout=0)
            except RuntimeError:
                # best-effort cleanup: ignore errors encountered while
                # stopping old services
                pass

    if service:
        # update an equivalent service in place
        service.update(config)
        logger.info(
            f"Updating an existing Seldon deployment service: {service}"
        )
    else:
        # create a new service
        service = SeldonDeploymentService(config=config)
        logger.info(f"Creating a new Seldon deployment service: {service}")

    # start the service which in turn provisions the Seldon Core
    # deployment server and waits for it to reach a ready state
    service.start(timeout=timeout)
    return service
find_model_server(self, running=False, service_uuid=None, pipeline_name=None, pipeline_run_id=None, pipeline_step_name=None, model_name=None, model_uri=None, model_type=None)

Find one or more Seldon Core model services that match the given criteria.

The Seldon Core deployment services that meet the search criteria are returned sorted in descending order of their creation time (i.e. more recent deployments first).

Parameters:

Name Type Description Default
running bool

if true, only running services will be returned.

False
service_uuid Optional[uuid.UUID]

the UUID of the Seldon Core service that was originally used to create the Seldon Core deployment resource.

None
pipeline_name Optional[str]

name of the pipeline that the deployed model was part of.

None
pipeline_run_id Optional[str]

ID of the pipeline run which the deployed model was part of.

None
pipeline_step_name Optional[str]

the name of the pipeline model deployment step that deployed the model.

None
model_name Optional[str]

the name of the deployed model.

None
model_uri Optional[str]

URI of the deployed model.

None
model_type Optional[str]

the Seldon Core server implementation used to serve the model

None

Returns:

Type Description
List[zenml.services.service.BaseService]

One or more Seldon Core service objects representing Seldon Core model servers that match the input search criteria.

Source code in zenml/integrations/seldon/model_deployers/seldon_model_deployer.py
def find_model_server(
    self,
    running: bool = False,
    service_uuid: Optional[UUID] = None,
    pipeline_name: Optional[str] = None,
    pipeline_run_id: Optional[str] = None,
    pipeline_step_name: Optional[str] = None,
    model_name: Optional[str] = None,
    model_uri: Optional[str] = None,
    model_type: Optional[str] = None,
) -> List[BaseService]:
    """Look up Seldon Core model services matching the given criteria.

    Matching services are returned newest first, ordered by the creation
    timestamp of their Seldon deployment resource. Deployments that carry
    no creation timestamp are placed at the end.

    Args:
        running: when True, services that are not running are skipped.
        service_uuid: the UUID of the Seldon Core service that was
            originally used to create the Seldon Core deployment resource.
        pipeline_name: name of the pipeline that the deployed model was
            part of.
        pipeline_run_id: ID of the pipeline run which the deployed model
            was part of.
        pipeline_step_name: the name of the pipeline model deployment step
            that deployed the model.
        model_name: the name of the deployed model.
        model_uri: URI of the deployed model.
        model_type: the Seldon Core server implementation used to serve
            the model.

    Returns:
        The Seldon Core service objects representing the Seldon Core model
        servers that match the search criteria.
    """
    # A throwaway deployment config is the canonical way to turn the
    # search criteria into Seldon deployment labels; unset criteria are
    # passed as empty strings.
    label_source = SeldonDeploymentConfig(
        pipeline_name=pipeline_name or "",
        pipeline_run_id=pipeline_run_id or "",
        pipeline_step_name=pipeline_step_name or "",
        model_name=model_name or "",
        model_uri=model_uri or "",
        implementation=model_type or "",
    )
    search_labels = label_source.get_seldon_deployment_labels()
    if service_uuid:
        # the service UUID is not part of the deployment config labels
        search_labels["zenml.service_uuid"] = str(service_uuid)

    def _created_at(deployment):  # type: ignore[no-untyped-def]
        # missing timestamps map to datetime.min so they sort last
        timestamp = deployment.metadata.creationTimestamp
        if not timestamp:
            return datetime.min
        return datetime.strptime(timestamp, "%Y-%m-%dT%H:%M:%SZ")

    matching = self.seldon_client.find_deployments(labels=search_labels)
    matching.sort(key=_created_at, reverse=True)

    # rebuild a service object from each deployment resource, lazily, so
    # that creation and the `running` filter interleave per deployment
    candidates = (
        SeldonDeploymentService.create_from_deployment(deployment=item)
        for item in matching
    )
    return [svc for svc in candidates if not running or svc.is_running]
get_active_model_deployer() staticmethod

Get the Seldon Core model deployer registered in the active stack.

Returns:

Type Description
SeldonModelDeployer

The Seldon Core model deployer registered in the active stack.

Exceptions:

Type Description
TypeError

if the Seldon Core model deployer is not available.

Source code in zenml/integrations/seldon/model_deployers/seldon_model_deployer.py
@staticmethod
def get_active_model_deployer() -> "SeldonModelDeployer":
    """Return the Seldon Core model deployer of the active stack.

    Returns:
        The SeldonModelDeployer component registered in the active stack.

    Raises:
        TypeError: if the active stack has no model deployer, or if its
            model deployer is not a SeldonModelDeployer.
    """
    active_stack = Repository(  # type: ignore [call-arg]
        skip_repository_check=True
    ).active_stack
    deployer = active_stack.model_deployer
    if deployer and isinstance(deployer, SeldonModelDeployer):
        return deployer

    # no suitable deployer: explain how to register one and build a stack
    raise TypeError(
        f"The active stack needs to have a Seldon Core model deployer "
        f"component registered to be able to deploy models with Seldon "
        f"Core. You can create a new stack with a Seldon Core model "
        f"deployer component or update your existing stack to add this "
        f"component, e.g.:\n\n"
        f"  'zenml model-deployer register seldon --flavor={SELDON_MODEL_DEPLOYER_FLAVOR} "
        f"--kubernetes_context=context-name --kubernetes_namespace="
        f"namespace-name --base_url=https://ingress.cluster.kubernetes'\n"
        f"  'zenml stack create stack-name -d seldon ...'\n"
    )
get_model_server_info(service_instance) staticmethod

Return implementation specific information that might be relevant to the user.

Parameters:

Name Type Description Default
service_instance SeldonDeploymentService

Instance of a SeldonDeploymentService

required

Returns:

Type Description
Dict[str, Optional[str]]

Model server information.

Source code in zenml/integrations/seldon/model_deployers/seldon_model_deployer.py
@staticmethod
def get_model_server_info(  # type: ignore[override]
    service_instance: "SeldonDeploymentService",
) -> Dict[str, Optional[str]]:
    """Summarize a Seldon Core deployment service for display to the user.

    Args:
        service_instance: the SeldonDeploymentService to describe.

    Returns:
        A mapping with the service's prediction URL, model URI, model name
        and the name of its Seldon deployment.
    """
    prediction_url = service_instance.prediction_url
    model_config = service_instance.config
    return dict(
        PREDICTION_URL=prediction_url,
        MODEL_URI=model_config.model_uri,
        MODEL_NAME=model_config.model_name,
        SELDON_DEPLOYMENT=service_instance.seldon_deployment_name,
    )
start_model_server(self, uuid, timeout=300)

Start a Seldon Core model deployment server.

Parameters:

Name Type Description Default
uuid UUID

UUID of the model server to start.

required
timeout int

timeout in seconds to wait for the service to become active. If set to 0, the method will return immediately after provisioning the service, without waiting for it to become active.

300

Exceptions:

Type Description
NotImplementedError

since we don't support starting Seldon Core model servers

Source code in zenml/integrations/seldon/model_deployers/seldon_model_deployer.py
def start_model_server(
    self,
    uuid: UUID,
    timeout: int = DEFAULT_SELDON_DEPLOYMENT_START_STOP_TIMEOUT,
) -> None:
    """Start a Seldon Core model deployment server (not supported).

    Args:
        uuid: UUID of the model server to start.
        timeout: timeout in seconds to wait for the service to become
            active. If set to 0, the method would return immediately after
            provisioning the service, without waiting for it to become
            active.

    Raises:
        NotImplementedError: always; starting Seldon Core model servers is
            not supported.
    """
    message = "Starting Seldon Core model servers is not implemented"
    raise NotImplementedError(message)
stop_model_server(self, uuid, timeout=300, force=False)

Stop a Seldon Core model server.

Parameters:

Name Type Description Default
uuid UUID

UUID of the model server to stop.

required
timeout int

timeout in seconds to wait for the service to stop.

300
force bool

if True, force the service to stop.

False

Exceptions:

Type Description
NotImplementedError

stopping Seldon Core model servers is not supported.

Source code in zenml/integrations/seldon/model_deployers/seldon_model_deployer.py
def stop_model_server(
    self,
    uuid: UUID,
    timeout: int = DEFAULT_SELDON_DEPLOYMENT_START_STOP_TIMEOUT,
    force: bool = False,
) -> None:
    """Stop a Seldon Core model server (not supported).

    Seldon Core model servers cannot be stopped; they can only be deleted.

    Args:
        uuid: UUID of the model server to stop.
        timeout: timeout in seconds to wait for the service to stop.
        force: if True, force the service to stop.

    Raises:
        NotImplementedError: always; stopping Seldon Core model servers is
            not supported.
    """
    message = (
        "Stopping Seldon Core model servers is not implemented. Try "
        "deleting the Seldon Core model server instead."
    )
    raise NotImplementedError(message)

secret_schemas special

Initialization for the Seldon secret schemas.

These are secret schemas that can be used to authenticate Seldon to the Artifact Store used to store served ML models.

secret_schemas

Implementation for Seldon secret schemas.

SeldonAzureSecretSchema (BaseSecretSchema) pydantic-model

Seldon Azure Blob Storage credentials.

Based on: https://rclone.org/azureblob/

Attributes:

Name Type Description
rclone_config_azureblob_type Literal['azureblob']

the rclone config type. Must be set to "azureblob" for this schema.

rclone_config_azureblob_account Optional[str]

storage Account Name. Leave blank to use SAS URL or MSI.

rclone_config_azureblob_key Optional[str]

storage Account Key. Leave blank to use SAS URL or MSI.

rclone_config_azureblob_sas_url Optional[str]

SAS URL for container level access only. Leave blank if using account/key or MSI.

rclone_config_azureblob_use_msi bool

use a managed service identity to authenticate (only works in Azure).

Source code in zenml/integrations/seldon/secret_schemas/secret_schemas.py
class SeldonAzureSecretSchema(BaseSecretSchema):
    """Seldon Azure Blob Storage credentials.

    The field names mirror rclone's Azure Blob backend options.

    Based on: https://rclone.org/azureblob/

    Attributes:
        rclone_config_azureblob_type: the rclone config type. Must be set to
            "azureblob" for this schema.
        rclone_config_azureblob_account: storage account name. Leave blank to
            use SAS URL or MSI.
        rclone_config_azureblob_key: storage account key. Leave blank to
            use SAS URL or MSI.
        rclone_config_azureblob_sas_url: SAS URL for container level access
            only. Leave blank if using account/key or MSI.
        rclone_config_azureblob_use_msi: use a managed service identity to
            authenticate (only works in Azure).
    """

    # class-level schema type identifier (ClassVar, so not a model field)
    TYPE: ClassVar[str] = SELDON_AZUREBLOB_SECRET_SCHEMA_TYPE

    # fixed discriminator value; the Literal type only admits "azureblob"
    rclone_config_azureblob_type: Literal["azureblob"] = "azureblob"
    # account name + key form one authentication option
    rclone_config_azureblob_account: Optional[str]
    rclone_config_azureblob_key: Optional[str]
    # alternative option: container-scoped SAS URL
    rclone_config_azureblob_sas_url: Optional[str]
    # alternative option: managed service identity (disabled by default)
    rclone_config_azureblob_use_msi: bool = False
SeldonGSSecretSchema (BaseSecretSchema) pydantic-model

Seldon GCS credentials.

Based on: https://rclone.org/googlecloudstorage/

Attributes:

Name Type Description
rclone_config_gs_type Literal['google cloud storage']

the rclone config type. Must be set to "google cloud storage" for this schema.

rclone_config_gs_client_id Optional[str]

OAuth client id.

rclone_config_gs_client_secret Optional[str]

OAuth client secret.

rclone_config_gs_token Optional[str]

OAuth Access Token as a JSON blob.

rclone_config_gs_project_number Optional[str]

project number.

rclone_config_gs_service_account_credentials Optional[str]

service account credentials JSON blob.

rclone_config_gs_anonymous bool

access public buckets and objects without credentials. Set to True if you only want to download public files without configuring credentials.

rclone_config_gs_auth_url Optional[str]

auth server URL.

rclone_config_gs_token_url Optional[str]

token server URL.

Source code in zenml/integrations/seldon/secret_schemas/secret_schemas.py
class SeldonGSSecretSchema(BaseSecretSchema):
    """Seldon GCS credentials.

    Based on: https://rclone.org/googlecloudstorage/

    Attributes:
        rclone_config_gs_type: the rclone config type. Must be set to "google
            cloud storage" for this schema.
        rclone_config_gs_client_id: OAuth client id.
        rclone_config_gs_client_secret: OAuth client secret.
        rclone_config_gs_token: OAuth Access Token as a JSON blob.
        rclone_config_gs_project_number: project number.
        rclone_config_gs_service_account_credentials: service account
            credentials JSON blob.
        rclone_config_gs_anonymous: access public buckets and objects without
            credentials. Set to True if you only want to download files and
            have not configured any credentials.
        rclone_config_gs_auth_url: auth server URL.
        rclone_config_gs_token_url: token server URL.
    """

    # Class-level constant (ClassVar), so it is not a field of the pydantic
    # schema itself.
    TYPE: ClassVar[str] = SELDON_GS_SECRET_SCHEMA_TYPE

    # The Literal annotation pins this field to the single value rclone
    # expects for Google Cloud Storage remotes.
    rclone_config_gs_type: Literal[
        "google cloud storage"
    ] = "google cloud storage"
    rclone_config_gs_client_id: Optional[str]
    rclone_config_gs_client_secret: Optional[str]
    rclone_config_gs_project_number: Optional[str]
    rclone_config_gs_service_account_credentials: Optional[str]
    rclone_config_gs_anonymous: bool = False
    rclone_config_gs_token: Optional[str]
    rclone_config_gs_auth_url: Optional[str]
    rclone_config_gs_token_url: Optional[str]
SeldonS3SecretSchema (BaseSecretSchema) pydantic-model

Seldon S3 credentials.

Based on: https://rclone.org/s3/#amazon-s3

Attributes:

Name Type Description
rclone_config_s3_type Literal['s3']

the rclone config type. Must be set to "s3" for this schema.

rclone_config_s3_provider str

the S3 provider (e.g. aws, ceph, minio).

rclone_config_s3_env_auth bool

get AWS credentials from EC2/ECS metadata (i.e. with IAM roles configuration). Only applies if access_key_id and secret_access_key are blank.

rclone_config_s3_access_key_id Optional[str]

AWS Access Key ID.

rclone_config_s3_secret_access_key Optional[str]

AWS Secret Access Key.

rclone_config_s3_session_token Optional[str]

AWS Session Token.

rclone_config_s3_region Optional[str]

region to connect to.

rclone_config_s3_endpoint Optional[str]

S3 API endpoint.

Source code in zenml/integrations/seldon/secret_schemas/secret_schemas.py
class SeldonS3SecretSchema(BaseSecretSchema):
    """Seldon S3 credentials.

    Based on: https://rclone.org/s3/#amazon-s3

    Attributes:
        rclone_config_s3_type: the rclone config type. Must be set to "s3" for
            this schema.
        rclone_config_s3_provider: the S3 provider (e.g. aws, ceph, minio).
            Defaults to "aws".
        rclone_config_s3_env_auth: get AWS credentials from EC2/ECS metadata
            (i.e. with IAM roles configuration). Only applies if access_key_id
            and secret_access_key are blank.
        rclone_config_s3_access_key_id: AWS Access Key ID.
        rclone_config_s3_secret_access_key: AWS Secret Access Key.
        rclone_config_s3_session_token: AWS Session Token.
        rclone_config_s3_region: region to connect to.
        rclone_config_s3_endpoint: S3 API endpoint.
    """

    # Class-level constant (ClassVar), so it is not a field of the pydantic
    # schema itself.
    TYPE: ClassVar[str] = SELDON_S3_SECRET_SCHEMA_TYPE

    # The Literal annotation pins this field to the single value rclone
    # expects for S3 remotes.
    rclone_config_s3_type: Literal["s3"] = "s3"
    rclone_config_s3_provider: str = "aws"
    rclone_config_s3_env_auth: bool = False
    rclone_config_s3_access_key_id: Optional[str]
    rclone_config_s3_secret_access_key: Optional[str]
    rclone_config_s3_session_token: Optional[str]
    rclone_config_s3_region: Optional[str]
    rclone_config_s3_endpoint: Optional[str]

seldon_client

Implementation of the Seldon client for ZenML.

SeldonClient

A client for interacting with Seldon Deployments.

Source code in zenml/integrations/seldon/seldon_client.py
class SeldonClient:
    """A client for interacting with Seldon Deployments.

    All operations are scoped to a single Kubernetes namespace. The lookup
    methods only return SeldonDeployment resources that are managed by ZenML.
    """

    def __init__(self, context: Optional[str], namespace: Optional[str]):
        """Initialize a Seldon Core client.

        Args:
            context: the Kubernetes context to use. Only used when the client
                runs outside of a Kubernetes cluster.
            namespace: the Kubernetes namespace to use. May be omitted when
                running inside a cluster, in which case the namespace of the
                current pod is used.
        """
        self._context = context
        self._namespace = namespace
        self._initialize_k8s_clients()

    def _initialize_k8s_clients(self) -> None:
        """Initialize the Kubernetes clients.

        The in-cluster configuration is tried first. If it is not available,
        the local kubeconfig is loaded instead, using `self._context` as the
        context to activate.

        Raises:
            SeldonClientError: if Kubernetes configuration could not be loaded,
                or if no namespace is configured while running outside of a
                cluster.
        """
        try:
            k8s_config.load_incluster_config()
            if not self._namespace:
                # load the namespace in the context of which the
                # current pod is running; use a context manager so the file
                # handle is closed, and strip any surrounding whitespace
                with open(
                    "/var/run/secrets/kubernetes.io/serviceaccount/namespace"
                ) as namespace_file:
                    self._namespace = namespace_file.read().strip()
        except k8s_config.config_exception.ConfigException:
            if not self._namespace:
                raise SeldonClientError(
                    "The Kubernetes namespace must be explicitly "
                    "configured when running outside of a cluster."
                )
            try:
                k8s_config.load_kube_config(
                    context=self._context, persist_config=False
                )
            except k8s_config.config_exception.ConfigException as e:
                raise SeldonClientError(
                    "Could not load the Kubernetes configuration"
                ) from e
        self._core_api = k8s_client.CoreV1Api()
        self._custom_objects_api = k8s_client.CustomObjectsApi()

    @staticmethod
    def sanitize_labels(labels: Dict[str, str]) -> None:
        """Update the label values in place to be valid Kubernetes labels.

        Each run of invalid characters is replaced by a single underscore.
        The result is truncated to 63 characters, and leading/trailing `-`,
        `_` and `.` characters are removed.

        See:
        https://kubernetes.io/docs/concepts/overview/working-with-objects/labels/#syntax-and-character-set

        Args:
            labels: the labels to sanitize. The dictionary is modified in
                place; only its values are changed.
        """
        for key, value in labels.items():
            # Kubernetes labels must be alphanumeric, no longer than
            # 63 characters, and must begin and end with an alphanumeric
            # character ([a-z0-9A-Z])
            labels[key] = re.sub(r"[^0-9a-zA-Z-_\.]+", "_", value)[:63].strip(
                "-_."
            )

    @property
    def namespace(self) -> str:
        """Returns the Kubernetes namespace in use by the client.

        Returns:
            The Kubernetes namespace in use by the client.

        Raises:
            RuntimeError: if the namespace has not been configured.
        """
        if not self._namespace:
            # shouldn't happen if the client is initialized, but we need to
            # appease the mypy type checker
            raise RuntimeError("The Kubernetes namespace is not configured")
        return self._namespace

    def create_deployment(
        self,
        deployment: SeldonDeployment,
        poll_timeout: int = 0,
    ) -> SeldonDeployment:
        """Create a Seldon Core deployment resource.

        Args:
            deployment: the Seldon Core deployment resource to create
            poll_timeout: the maximum time to wait for the deployment to become
                available or to fail. If set to 0, the function will return
                immediately without checking the deployment status. If a timeout
                occurs and the deployment is still pending creation, it will
                be returned anyway and no exception will be raised.

        Returns:
            the created Seldon Core deployment resource with updated status.

        Raises:
            SeldonDeploymentExistsError: if a deployment with the same name
                already exists.
            SeldonClientError: if an unknown error occurs during the creation of
                the deployment.
        """
        try:
            logger.debug(f"Creating SeldonDeployment resource: {deployment}")

            # mark the deployment as managed by ZenML, to differentiate
            # between deployments that are created by ZenML and those that
            # are not
            deployment.mark_as_managed_by_zenml()

            response = self._custom_objects_api.create_namespaced_custom_object(
                group="machinelearning.seldon.io",
                version="v1",
                namespace=self._namespace,
                plural="seldondeployments",
                body=deployment.dict(exclude_none=True),
                _request_timeout=poll_timeout or None,
            )
            logger.debug("Seldon Core API response: %s", response)
        except k8s_client.rest.ApiException as e:
            logger.error(
                "Exception when creating SeldonDeployment resource: %s", str(e)
            )
            if e.status == 409:
                # HTTP 409 Conflict: a resource with this name already exists
                raise SeldonDeploymentExistsError(
                    f"A deployment with the name {deployment.name} "
                    f"already exists in namespace {self._namespace}"
                ) from e
            raise SeldonClientError(
                "Exception when creating SeldonDeployment resource"
            ) from e

        created_deployment = self.get_deployment(name=deployment.name)

        # poll in 5 second steps until the deployment leaves the pending
        # state or the timeout budget is exhausted
        while poll_timeout > 0 and created_deployment.is_pending():
            time.sleep(5)
            poll_timeout -= 5
            created_deployment = self.get_deployment(name=deployment.name)

        return created_deployment

    def delete_deployment(
        self,
        name: str,
        force: bool = False,
        poll_timeout: int = 0,
    ) -> None:
        """Delete a Seldon Core deployment resource managed by ZenML.

        Args:
            name: the name of the Seldon Core deployment resource to delete.
            force: if True, the deployment deletion will be forced (the graceful
                period will be set to zero).
            poll_timeout: the maximum time to wait for the deployment to be
                deleted. If set to 0, the function will return immediately
                without checking the deployment status. If a timeout
                occurs and the deployment still exists, this method will
                return and no exception will be raised.

        Raises:
            SeldonDeploymentNotFoundError: if the deployment does not exist
                or is not managed by ZenML.
            SeldonClientError: if an unknown error occurs during the deployment
                removal.
        """
        try:
            logger.debug(f"Deleting SeldonDeployment resource: {name}")

            # call `get_deployment` to check that the deployment exists
            # and is managed by ZenML. It will raise
            # a SeldonDeploymentNotFoundError otherwise
            self.get_deployment(name=name)

            response = self._custom_objects_api.delete_namespaced_custom_object(
                group="machinelearning.seldon.io",
                version="v1",
                namespace=self._namespace,
                plural="seldondeployments",
                name=name,
                _request_timeout=poll_timeout or None,
                grace_period_seconds=0 if force else None,
            )
            logger.debug("Seldon Core API response: %s", response)
        except k8s_client.rest.ApiException as e:
            logger.error(
                "Exception when deleting SeldonDeployment resource %s: %s",
                name,
                str(e),
            )
            raise SeldonClientError(
                f"Exception when deleting SeldonDeployment resource {name}"
            ) from e

        # poll in 5 second steps until the resource is gone or the timeout
        # budget is exhausted
        while poll_timeout > 0:
            try:
                self.get_deployment(name=name)
            except SeldonDeploymentNotFoundError:
                return
            time.sleep(5)
            poll_timeout -= 5

    def update_deployment(
        self,
        deployment: SeldonDeployment,
        poll_timeout: int = 0,
    ) -> SeldonDeployment:
        """Update a Seldon Core deployment resource.

        Args:
            deployment: the Seldon Core deployment resource to update
            poll_timeout: the maximum time to wait for the deployment to become
                available or to fail. If set to 0, the function will return
                immediately without checking the deployment status. If a timeout
                occurs and the deployment is still pending creation, it will
                be returned anyway and no exception will be raised.

        Returns:
            the updated Seldon Core deployment resource with updated status.

        Raises:
            SeldonDeploymentNotFoundError: if the deployment does not exist
                or is not managed by ZenML.
            SeldonClientError: if an unknown error occurs while updating the
                deployment.
        """
        try:
            logger.debug(
                f"Updating SeldonDeployment resource: {deployment.name}"
            )

            # mark the deployment as managed by ZenML, to differentiate
            # between deployments that are created by ZenML and those that
            # are not
            deployment.mark_as_managed_by_zenml()

            # call `get_deployment` to check that the deployment exists
            # and is managed by ZenML. It will raise
            # a SeldonDeploymentNotFoundError otherwise
            self.get_deployment(name=deployment.name)

            response = self._custom_objects_api.patch_namespaced_custom_object(
                group="machinelearning.seldon.io",
                version="v1",
                namespace=self._namespace,
                plural="seldondeployments",
                name=deployment.name,
                body=deployment.dict(exclude_none=True),
                _request_timeout=poll_timeout or None,
            )
            logger.debug("Seldon Core API response: %s", response)
        except k8s_client.rest.ApiException as e:
            logger.error(
                "Exception when updating SeldonDeployment resource: %s", str(e)
            )
            raise SeldonClientError(
                "Exception when updating SeldonDeployment resource"
            ) from e

        updated_deployment = self.get_deployment(name=deployment.name)

        # poll in 5 second steps until the deployment leaves the pending
        # state or the timeout budget is exhausted
        while poll_timeout > 0 and updated_deployment.is_pending():
            time.sleep(5)
            poll_timeout -= 5
            updated_deployment = self.get_deployment(name=deployment.name)

        return updated_deployment

    def get_deployment(self, name: str) -> SeldonDeployment:
        """Get a ZenML managed Seldon Core deployment resource by name.

        Args:
            name: the name of the Seldon Core deployment resource to fetch.

        Returns:
            The Seldon Core deployment resource.

        Raises:
            SeldonDeploymentNotFoundError: if the deployment resource cannot
                be found, cannot be parsed or is not managed by ZenML.
            SeldonClientError: if an unknown error occurs while fetching
                the deployment.
        """
        try:
            logger.debug(f"Retrieving SeldonDeployment resource: {name}")

            response = self._custom_objects_api.get_namespaced_custom_object(
                group="machinelearning.seldon.io",
                version="v1",
                namespace=self._namespace,
                plural="seldondeployments",
                name=name,
            )
            logger.debug("Seldon Core API response: %s", response)
            try:
                deployment = SeldonDeployment(**response)
            except ValidationError as e:
                logger.error(
                    "Invalid Seldon Core deployment resource: %s\n%s",
                    str(e),
                    str(response),
                )
                raise SeldonDeploymentNotFoundError(
                    f"SeldonDeployment resource {name} could not be parsed"
                ) from e

            # Only Seldon deployments managed by ZenML are returned
            if not deployment.is_managed_by_zenml():
                raise SeldonDeploymentNotFoundError(
                    f"Seldon Deployment {name} is not managed by ZenML"
                )
            return deployment

        except k8s_client.rest.ApiException as e:
            if e.status == 404:
                raise SeldonDeploymentNotFoundError(
                    f"SeldonDeployment resource not found: {name}"
                ) from e
            logger.error(
                "Exception when fetching SeldonDeployment resource %s: %s",
                name,
                str(e),
            )
            raise SeldonClientError(
                f"Unexpected exception when fetching SeldonDeployment "
                f"resource: {name}"
            ) from e

    def find_deployments(
        self,
        name: Optional[str] = None,
        labels: Optional[Dict[str, str]] = None,
        fields: Optional[Dict[str, str]] = None,
    ) -> List[SeldonDeployment]:
        """Find all ZenML-managed Seldon Core deployment resources matching the given criteria.

        Args:
            name: optional name of the deployment resource to find. It is
                combined with (and takes precedence over any
                `metadata.name` entry in) the `fields` selector.
            fields: optional selector to restrict the list of returned
                Seldon deployments by their fields. Defaults to everything.
            labels: optional selector to restrict the list of returned
                Seldon deployments by their labels. Defaults to everything.

        Returns:
            List of Seldon Core deployments that match the given criteria.
            Items that fail to parse are logged and skipped.

        Raises:
            SeldonClientError: if an unknown error occurs while fetching
                the deployments.
        """
        # Work on copies so the ZenML label and name filter added below do
        # not leak into the dictionaries owned by the caller.
        fields = dict(fields or {})
        labels = dict(labels or {})
        # always filter results to only include Seldon deployments managed
        # by ZenML
        labels["app"] = "zenml"
        if name:
            fields["metadata.name"] = name
        field_selector = (
            ",".join(f"{k}={v}" for k, v in fields.items()) if fields else None
        )
        # `labels` always contains at least the ZenML label at this point
        label_selector = ",".join(f"{k}={v}" for k, v in labels.items())
        try:

            logger.debug(
                f"Searching SeldonDeployment resources with label selector "
                f"'{labels or ''}' and field selector '{fields or ''}'"
            )
            response = self._custom_objects_api.list_namespaced_custom_object(
                group="machinelearning.seldon.io",
                version="v1",
                namespace=self._namespace,
                plural="seldondeployments",
                field_selector=field_selector,
                label_selector=label_selector,
            )
            # tolerate a missing or null `items` entry in the response
            items = response.get("items") or []
            logger.debug("Seldon Core API returned %s items", len(items))
            deployments = []
            for item in items:
                try:
                    deployments.append(SeldonDeployment(**item))
                except ValidationError as e:
                    logger.error(
                        "Invalid Seldon Core deployment resource: %s\n%s",
                        str(e),
                        str(item),
                    )
            return deployments
        except k8s_client.rest.ApiException as e:
            logger.error(
                "Exception when searching SeldonDeployment resources with "
                "label selector '%s' and field selector '%s': %s",
                label_selector or "",
                field_selector or "",
                str(e),
            )
            raise SeldonClientError(
                f"Unexpected exception when searching SeldonDeployment "
                f"with labels '{labels or ''}' and field '{fields or ''}'"
            ) from e

    def get_deployment_logs(
        self,
        name: str,
        follow: bool = False,
        tail: Optional[int] = None,
    ) -> Generator[str, bool, None]:
        """Get the logs of a Seldon Core deployment resource.

        Logs are read from the first pod labeled with the deployment id. The
        container named "default" is preferred; if that container has not
        started yet, the first init container is used instead, if there is
        one.

        Args:
            name: the name of the Seldon Core deployment to get logs for.
            follow: if True, the logs will be streamed as they are written
            tail: only retrieve the last NUM lines of log output.

        Returns:
            A generator that can be accessed to get the service logs.

        Yields:
            The next log line. Sending a truthy value into the generator
            stops the stream.

        Raises:
            SeldonClientError: if an unknown error occurs while fetching
                the logs.
        """
        logger.debug(f"Retrieving logs for SeldonDeployment resource: {name}")
        try:
            response = self._core_api.list_namespaced_pod(
                namespace=self._namespace,
                label_selector=f"seldon-deployment-id={name}",
            )
            logger.debug("Kubernetes API response: %s", response)
            pods = response.items
            if not pods:
                raise SeldonClientError(
                    f"The Seldon Core deployment {name} is not currently "
                    f"running: no Kubernetes pods associated with it were found"
                )
            pod = pods[0]
            pod_name = pod.metadata.name

            # `init_containers` and `container_statuses` may be None (no init
            # containers / pod not scheduled yet), so fall back to empty lists
            containers = [c.name for c in pod.spec.containers]
            init_containers = [c.name for c in pod.spec.init_containers or []]
            container_statuses = {
                c.name: c.started or c.restart_count
                for c in pod.status.container_statuses or []
            }

            container = "default"
            if container not in containers:
                container = containers[0]
            # some containers might not be running yet and have no logs to show,
            # so we fall back to an init container when one exists
            if not container_statuses.get(container) and init_containers:
                container = init_containers[0]

            logger.info(
                f"Retrieving logs for pod: `{pod_name}` and container "
                f"`{container}` in namespace `{self._namespace}`"
            )
            response = self._core_api.read_namespaced_pod_log(
                name=pod_name,
                namespace=self._namespace,
                container=container,
                follow=follow,
                tail_lines=tail,
                _preload_content=False,
            )
        except k8s_client.rest.ApiException as e:
            logger.error(
                "Exception when fetching logs for SeldonDeployment resource "
                "%s: %s",
                name,
                str(e),
            )
            raise SeldonClientError(
                f"Unexpected exception when fetching logs for SeldonDeployment "
                f"resource: {name}"
            ) from e

        try:
            while True:
                raw_line = response.readline()
                # Only a read that returns no bytes at all marks the end of
                # the stream; a blank log line still contains b"\n" and must
                # not terminate it.
                if not raw_line:
                    return
                line = raw_line.decode("utf-8", errors="replace").rstrip("\n")
                stop = yield line
                if stop:
                    return
        finally:
            response.release_conn()

    def create_or_update_secret(
        self,
        name: str,
        secret: BaseSecretSchema,
    ) -> None:
        """Create or update a Kubernetes Secret resource.

        Uses the information contained in a ZenML secret. Keys are upper-cased,
        values are base64 encoded, and entries whose value is None are
        skipped.

        Args:
            name: the name of the Secret resource to create.
            secret: a ZenML secret with key-values that should be
                stored in the Secret resource.

        Raises:
            SeldonClientError: if an unknown error occurs during the creation
                or update of the secret. Unexpected Kubernetes API errors are
                wrapped in this exception.
        """
        try:
            logger.debug(f"Creating Secret resource: {name}")

            secret_data = {
                k.upper(): base64.b64encode(str(v).encode("utf-8")).decode(
                    "ascii"
                )
                for k, v in secret.content.items()
                if v is not None
            }

            # use a separate name so the `secret` argument keeps its
            # declared type
            k8s_secret = k8s_client.V1Secret(
                metadata=k8s_client.V1ObjectMeta(
                    name=name,
                    labels={"app": "zenml"},
                ),
                type="Opaque",
                data=secret_data,
            )

            try:
                # check if the secret is already present
                self._core_api.read_namespaced_secret(
                    name=name,
                    namespace=self._namespace,
                )
                # if we got this far, the secret is already present, update it
                # in place
                response = self._core_api.replace_namespaced_secret(
                    name=name,
                    namespace=self._namespace,
                    body=k8s_secret,
                )
            except k8s_client.rest.ApiException as e:
                if e.status != 404:
                    # if an error other than 404 is raised here, treat it
                    # as an unexpected error (wrapped by the outer handler)
                    raise
                response = self._core_api.create_namespaced_secret(
                    namespace=self._namespace,
                    body=k8s_secret,
                )
            logger.debug("Kubernetes API response: %s", response)
        except k8s_client.rest.ApiException as e:
            logger.error("Exception when creating Secret resource: %s", str(e))
            raise SeldonClientError(
                "Exception when creating Secret resource"
            ) from e

    def delete_secret(
        self,
        name: str,
    ) -> None:
        """Delete a Kubernetes Secret resource.

        A secret that does not exist (HTTP 404) is silently ignored.

        Args:
            name: the name of the Kubernetes Secret resource to delete.

        Raises:
            SeldonClientError: if an unknown error occurs during the removal
                of the secret.
        """
        # NOTE(review): unlike delete_deployment, this does not verify that
        # the secret carries the ZenML `app=zenml` label before deleting it.
        try:
            logger.debug(f"Deleting Secret resource: {name}")

            response = self._core_api.delete_namespaced_secret(
                name=name,
                namespace=self._namespace,
            )
            logger.debug("Kubernetes API response: %s", response)
        except k8s_client.rest.ApiException as e:
            if e.status == 404:
                # the secret is no longer present, nothing to do
                return
            logger.error(
                "Exception when deleting Secret resource %s: %s",
                name,
                str(e),
            )
            raise SeldonClientError(
                f"Exception when deleting Secret resource {name}"
            ) from e
namespace: str property readonly

Returns the Kubernetes namespace in use by the client.

Returns:

Type Description
str

The Kubernetes namespace in use by the client.

Exceptions:

Type Description
RuntimeError

if the namespace has not been configured.

__init__(self, context, namespace) special

Initialize a Seldon Core client.

Parameters:

Name Type Description Default
context Optional[str]

the Kubernetes context to use.

required
namespace Optional[str]

the Kubernetes namespace to use.

required
Source code in zenml/integrations/seldon/seldon_client.py
def __init__(self, context: Optional[str], namespace: Optional[str]):
    """Store the connection settings and set up the Kubernetes API clients.

    Args:
        context: name of the Kubernetes context to use (may be None).
        namespace: Kubernetes namespace to operate in (may be None).
    """
    self._context, self._namespace = context, namespace
    self._initialize_k8s_clients()
create_deployment(self, deployment, poll_timeout=0)

Create a Seldon Core deployment resource.

Parameters:

Name Type Description Default
deployment SeldonDeployment

the Seldon Core deployment resource to create

required
poll_timeout int

the maximum time to wait for the deployment to become available or to fail. If set to 0, the function will return immediately without checking the deployment status. If a timeout occurs and the deployment is still pending creation, it will be returned anyway and no exception will be raised.

0

Returns:

Type Description
SeldonDeployment

the created Seldon Core deployment resource with updated status.

Exceptions:

Type Description
SeldonDeploymentExistsError

if a deployment with the same name already exists.

SeldonClientError

if an unknown error occurs during the creation of the deployment.

Source code in zenml/integrations/seldon/seldon_client.py
def create_deployment(
    self,
    deployment: SeldonDeployment,
    poll_timeout: int = 0,
) -> SeldonDeployment:
    """Create a Seldon Core deployment resource.

    Args:
        deployment: the Seldon Core deployment resource to create
        poll_timeout: the maximum time to wait for the deployment to become
            available or to fail. If set to 0, the function will return
            immediately without checking the deployment status. If a timeout
            occurs and the deployment is still pending creation, it will
            be returned anyway and no exception will be raised.

    Returns:
        the created Seldon Core deployment resource with updated status.

    Raises:
        SeldonDeploymentExistsError: if a deployment with the same name
            already exists.
        SeldonClientError: if an unknown error occurs during the creation of
            the deployment.
    """
    try:
        logger.debug(f"Creating SeldonDeployment resource: {deployment}")

        # mark the deployment as managed by ZenML, to differentiate
        # between deployments that are created by ZenML and those that
        # are not
        deployment.mark_as_managed_by_zenml()

        response = self._custom_objects_api.create_namespaced_custom_object(
            group="machinelearning.seldon.io",
            version="v1",
            namespace=self._namespace,
            plural="seldondeployments",
            body=deployment.dict(exclude_none=True),
            _request_timeout=poll_timeout or None,
        )
        logger.debug("Seldon Core API response: %s", response)
    except k8s_client.rest.ApiException as e:
        logger.error(
            "Exception when creating SeldonDeployment resource: %s", str(e)
        )
        if e.status == 409:
            # HTTP 409 Conflict: a resource with this name already exists;
            # chain the API error so the original cause stays visible
            raise SeldonDeploymentExistsError(
                f"A deployment with the name {deployment.name} "
                f"already exists in namespace {self._namespace}"
            ) from e
        raise SeldonClientError(
            "Exception when creating SeldonDeployment resource"
        ) from e

    created_deployment = self.get_deployment(name=deployment.name)

    # poll in 5 second steps until the deployment leaves the pending state
    # or the timeout budget is exhausted
    while poll_timeout > 0 and created_deployment.is_pending():
        time.sleep(5)
        poll_timeout -= 5
        created_deployment = self.get_deployment(name=deployment.name)

    return created_deployment
create_or_update_secret(self, name, secret)

Create or update a Kubernetes Secret resource.

Uses the information contained in a ZenML secret.

Parameters:

Name Type Description Default
name str

the name of the Secret resource to create.

required
secret BaseSecretSchema

a ZenML secret with key-values that should be stored in the Secret resource.

required

Exceptions:

Type Description
SeldonClientError

if an unknown error occurs during the creation of the secret.

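k8s_client.rest.ApiException

unexpected error raised by the Kubernetes API while checking for an existing secret. It is caught internally and surfaces to callers as SeldonClientError.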

Source code in zenml/integrations/seldon/seldon_client.py
def create_or_update_secret(
    self,
    name: str,
    secret: BaseSecretSchema,
) -> None:
    """Create or update a Kubernetes Secret resource.

    Uses the information contained in a ZenML secret. Every key of the
    ZenML secret is upper-cased and every non-`None` value is converted to
    a string and base64-encoded. If a Secret with the given name already
    exists in the namespace it is replaced in place, otherwise a new one
    is created.

    Args:
        name: the name of the Secret resource to create.
        secret: a ZenML secret with key-values that should be
            stored in the Secret resource.

    Raises:
        SeldonClientError: if an unknown error occurs during the creation of
            the secret. Any `k8s_client.rest.ApiException` raised while
            reading, replacing or creating the Secret is wrapped in this
            exception.
    """
    try:
        logger.debug(f"Creating Secret resource: {name}")

        # Kubernetes expects the values in a Secret's `data` field to be
        # base64-encoded strings; `None` values are dropped
        secret_data = {
            k.upper(): base64.b64encode(str(v).encode("utf-8")).decode(
                "ascii"
            )
            for k, v in secret.content.items()
            if v is not None
        }

        # use a dedicated name so the ZenML `secret` argument is not
        # shadowed by the Kubernetes object built from it
        k8s_secret = k8s_client.V1Secret(
            metadata=k8s_client.V1ObjectMeta(
                name=name,
                labels={"app": "zenml"},
            ),
            type="Opaque",
            data=secret_data,
        )

        try:
            # check if the secret is already present
            self._core_api.read_namespaced_secret(
                name=name,
                namespace=self._namespace,
            )
            # if we got this far, the secret is already present, update it
            # in place
            response = self._core_api.replace_namespaced_secret(
                name=name,
                namespace=self._namespace,
                body=k8s_secret,
            )
        except k8s_client.rest.ApiException as e:
            if e.status != 404:
                # if an error other than 404 is raised here, treat it
                # as an unexpected error (wrapped by the outer handler)
                raise
            response = self._core_api.create_namespaced_secret(
                namespace=self._namespace,
                body=k8s_secret,
            )
        logger.debug("Kubernetes API response: %s", response)
    except k8s_client.rest.ApiException as e:
        logger.error("Exception when creating Secret resource: %s", str(e))
        raise SeldonClientError(
            "Exception when creating Secret resource"
        ) from e
delete_deployment(self, name, force=False, poll_timeout=0)

Delete a Seldon Core deployment resource managed by ZenML.

Parameters:

Name Type Description Default
name str

the name of the Seldon Core deployment resource to delete.

required
force bool

if True, the deployment deletion will be forced (the grace period will be set to zero).

False
poll_timeout int

the maximum time to wait for the deployment to be deleted. If set to 0, the function will return immediately without checking the deployment status. If a timeout occurs and the deployment still exists, this method will return and no exception will be raised.

0

Exceptions:

Type Description
SeldonClientError

if an unknown error occurs during the deployment removal.

Source code in zenml/integrations/seldon/seldon_client.py
def delete_deployment(
    self,
    name: str,
    force: bool = False,
    poll_timeout: int = 0,
) -> None:
    """Delete a Seldon Core deployment resource managed by ZenML.

    Args:
        name: the name of the Seldon Core deployment resource to delete.
        force: when True, the deletion is forced by setting the grace
            period to zero.
        poll_timeout: the maximum number of seconds to wait for the
            deployment to disappear. A value of 0 returns immediately
            without checking the deployment status. When the timeout
            expires while the deployment still exists, the method returns
            without raising.

    Raises:
        SeldonClientError: if an unknown error occurs during the deployment
            removal.
        SeldonDeploymentNotFoundError: propagated from `get_deployment` if
            the deployment does not exist or is not managed by ZenML.
    """
    try:
        logger.debug(f"Deleting SeldonDeployment resource: {name}")

        # raises SeldonDeploymentNotFoundError unless the deployment exists
        # and carries the ZenML management label
        self.get_deployment(name=name)

        grace_period = 0 if force else None
        request_timeout = poll_timeout or None
        api_response = self._custom_objects_api.delete_namespaced_custom_object(
            group="machinelearning.seldon.io",
            version="v1",
            namespace=self._namespace,
            plural="seldondeployments",
            name=name,
            _request_timeout=request_timeout,
            grace_period_seconds=grace_period,
        )
        logger.debug("Seldon Core API response: %s", api_response)
    except k8s_client.rest.ApiException as api_error:
        logger.error(
            "Exception when deleting SeldonDeployment resource %s: %s",
            name,
            str(api_error),
        )
        raise SeldonClientError(
            f"Exception when deleting SeldonDeployment resource {name}"
        ) from api_error

    # poll every 5 seconds until the deployment is gone or time runs out
    remaining = poll_timeout
    while remaining > 0:
        try:
            self.get_deployment(name=name)
        except SeldonDeploymentNotFoundError:
            return
        time.sleep(5)
        remaining -= 5
delete_secret(self, name)

Delete a Kubernetes Secret resource managed by ZenML.

Parameters:

Name Type Description Default
name str

the name of the Kubernetes Secret resource to delete.

required

Exceptions:

Type Description
SeldonClientError

if an unknown error occurs during the removal of the secret.

Source code in zenml/integrations/seldon/seldon_client.py
def delete_secret(
    self,
    name: str,
) -> None:
    """Delete a Kubernetes Secret resource managed by ZenML.

    A Secret that no longer exists (HTTP 404) is treated as already
    deleted and is not an error.

    Args:
        name: the name of the Kubernetes Secret resource to delete.

    Raises:
        SeldonClientError: if an unknown error occurs during the removal
            of the secret.
    """
    try:
        logger.debug(f"Deleting Secret resource: {name}")

        api_response = self._core_api.delete_namespaced_secret(
            name=name,
            namespace=self._namespace,
        )
    except k8s_client.rest.ApiException as api_error:
        if api_error.status != 404:
            logger.error(
                "Exception when deleting Secret resource %s: %s",
                name,
                str(api_error),
            )
            raise SeldonClientError(
                f"Exception when deleting Secret resource {name}"
            ) from api_error
        # the secret is no longer present, nothing to do
        return
    logger.debug("Kubernetes API response: %s", api_response)
find_deployments(self, name=None, labels=None, fields=None)

Find all ZenML-managed Seldon Core deployment resources matching the given criteria.

Parameters:

Name Type Description Default
name Optional[str]

optional name of the deployment resource to find.

None
fields Optional[Dict[str, str]]

optional selector to restrict the list of returned Seldon deployments by their fields. Defaults to everything.

None
labels Optional[Dict[str, str]]

optional selector to restrict the list of returned Seldon deployments by their labels. Defaults to everything.

None

Returns:

Type Description
List[zenml.integrations.seldon.seldon_client.SeldonDeployment]

List of Seldon Core deployments that match the given criteria.

Exceptions:

Type Description
SeldonClientError

if an unknown error occurs while fetching the deployments.

Source code in zenml/integrations/seldon/seldon_client.py
def find_deployments(
    self,
    name: Optional[str] = None,
    labels: Optional[Dict[str, str]] = None,
    fields: Optional[Dict[str, str]] = None,
) -> List[SeldonDeployment]:
    """Find all ZenML-managed Seldon Core deployment resources matching the given criteria.

    Args:
        name: optional name of the deployment resource to find. It is
            added to the field selector as `metadata.name`, on top of any
            `fields` passed in.
        fields: optional selector to restrict the list of returned
            Seldon deployments by their fields. Defaults to everything.
        labels: optional selector to restrict the list of returned
            Seldon deployments by their labels. Defaults to everything.
            The `app=zenml` label is always added. The caller's dictionary
            is not modified.

    Returns:
        List of Seldon Core deployments that match the given criteria.
        Items that fail to parse as a `SeldonDeployment` are logged and
        skipped.

    Raises:
        SeldonClientError: if an unknown error occurs while fetching
            the deployments.
    """
    # work on copies so the caller's selector dictionaries are not mutated
    fields = dict(fields or {})
    labels = dict(labels or {})
    # always filter results to only include Seldon deployments managed
    # by ZenML
    labels["app"] = "zenml"
    if name:
        # merge with (rather than replace) any user-supplied field selectors
        fields["metadata.name"] = name
    field_selector = (
        ",".join(f"{k}={v}" for k, v in fields.items()) if fields else None
    )
    label_selector = (
        ",".join(f"{k}={v}" for k, v in labels.items()) if labels else None
    )
    try:

        logger.debug(
            f"Searching SeldonDeployment resources with label selector "
            f"'{labels or ''}' and field selector '{fields or ''}'"
        )
        response = self._custom_objects_api.list_namespaced_custom_object(
            group="machinelearning.seldon.io",
            version="v1",
            namespace=self._namespace,
            plural="seldondeployments",
            field_selector=field_selector,
            label_selector=label_selector,
        )
        # tolerate a missing or null `items` entry consistently
        items = response.get("items") or []
        logger.debug("Seldon Core API returned %s items", len(items))
        deployments = []
        for item in items:
            try:
                deployments.append(SeldonDeployment(**item))
            except ValidationError as e:
                logger.error(
                    "Invalid Seldon Core deployment resource: %s\n%s",
                    str(e),
                    str(item),
                )
        return deployments
    except k8s_client.rest.ApiException as e:
        logger.error(
            "Exception when searching SeldonDeployment resources with "
            "label selector '%s' and field selector '%s': %s",
            label_selector or "",
            field_selector or "",
            str(e),
        )
        raise SeldonClientError(
            f"Unexpected exception when searching SeldonDeployment "
            f"with labels '{labels or ''}' and field '{fields or ''}'"
        ) from e
get_deployment(self, name)

Get a ZenML managed Seldon Core deployment resource by name.

Parameters:

Name Type Description Default
name str

the name of the Seldon Core deployment resource to fetch.

required

Returns:

Type Description
SeldonDeployment

The Seldon Core deployment resource.

Exceptions:

Type Description
SeldonDeploymentNotFoundError

if the deployment resource cannot be found or is not managed by ZenML.

SeldonClientError

if an unknown error occurs while fetching the deployment.

Source code in zenml/integrations/seldon/seldon_client.py
def get_deployment(self, name: str) -> SeldonDeployment:
    """Get a ZenML managed Seldon Core deployment resource by name.

    Args:
        name: the name of the Seldon Core deployment resource to fetch.

    Returns:
        The Seldon Core deployment resource.

    Raises:
        SeldonDeploymentNotFoundError: if the deployment resource cannot
            be found, cannot be parsed, or is not managed by ZenML.
        SeldonClientError: if an unknown error occurs while fetching
            the deployment.
    """
    try:
        logger.debug(f"Retrieving SeldonDeployment resource: {name}")

        response = self._custom_objects_api.get_namespaced_custom_object(
            group="machinelearning.seldon.io",
            version="v1",
            namespace=self._namespace,
            plural="seldondeployments",
            name=name,
        )
        logger.debug("Seldon Core API response: %s", response)
        try:
            deployment = SeldonDeployment(**response)
        except ValidationError as e:
            logger.error(
                "Invalid Seldon Core deployment resource: %s\n%s",
                str(e),
                str(response),
            )
            # chain the validation error so the parsing failure details
            # are preserved in the traceback
            raise SeldonDeploymentNotFoundError(
                f"SeldonDeployment resource {name} could not be parsed"
            ) from e

        # Only Seldon deployments managed by ZenML are returned
        if not deployment.is_managed_by_zenml():
            raise SeldonDeploymentNotFoundError(
                f"Seldon Deployment {name} is not managed by ZenML"
            )
        return deployment

    except k8s_client.rest.ApiException as e:
        if e.status == 404:
            raise SeldonDeploymentNotFoundError(
                f"SeldonDeployment resource not found: {name}"
            ) from e
        logger.error(
            "Exception when fetching SeldonDeployment resource %s: %s",
            name,
            str(e),
        )
        raise SeldonClientError(
            f"Unexpected exception when fetching SeldonDeployment "
            f"resource: {name}"
        ) from e
get_deployment_logs(self, name, follow=False, tail=None)

Get the logs of a Seldon Core deployment resource.

Parameters:

Name Type Description Default
name str

the name of the Seldon Core deployment to get logs for.

required
follow bool

if True, the logs will be streamed as they are written

False
tail Optional[int]

if set, only retrieve the last `tail` lines of log output.

None

Returns:

Type Description
Generator[str, bool, NoneType]

A generator that can be accessed to get the service logs.

Yields:

Type Description
Generator[str, bool, NoneType]

The next log line.

Exceptions:

Type Description
SeldonClientError

if an unknown error occurs while fetching the logs.

Source code in zenml/integrations/seldon/seldon_client.py
def get_deployment_logs(
    self,
    name: str,
    follow: bool = False,
    tail: Optional[int] = None,
) -> Generator[str, bool, None]:
    """Get the logs of a Seldon Core deployment resource.

    The logs are read from the first pod labelled with
    `seldon-deployment-id=<name>`. The container named `default` is used
    if present, otherwise the pod's first container. If that container has
    not started yet, the first init container is used instead, when the
    pod has one.

    Args:
        name: the name of the Seldon Core deployment to get logs for.
        follow: if True, the logs will be streamed as they are written.
        tail: if set, only retrieve the last `tail` lines of log output.

    Returns:
        A generator that can be accessed to get the service logs.

    Yields:
        The next log line, without its trailing newline. Blank log lines
        are yielded as empty strings. Sending a truthy value into the
        generator stops the stream.

    Raises:
        SeldonClientError: if an unknown error occurs while fetching
            the logs.
    """
    logger.debug(f"Retrieving logs for SeldonDeployment resource: {name}")
    try:
        response = self._core_api.list_namespaced_pod(
            namespace=self._namespace,
            label_selector=f"seldon-deployment-id={name}",
        )
        logger.debug("Kubernetes API response: %s", response)
        pods = response.items
        if not pods:
            raise SeldonClientError(
                f"The Seldon Core deployment {name} is not currently "
                f"running: no Kubernetes pods associated with it were found"
            )
        pod = pods[0]
        pod_name = pod.metadata.name

        # `init_containers` and `container_statuses` may be None (no init
        # containers / pod not scheduled yet), so guard the iterations
        containers = [c.name for c in pod.spec.containers]
        init_containers = [c.name for c in pod.spec.init_containers or []]
        container_statuses = {
            c.name: c.started or c.restart_count
            for c in pod.status.container_statuses or []
        }

        container = "default"
        if container not in containers:
            container = containers[0]
        # some containers might not be running yet and have no logs to show,
        # so fall back to the first init container, if there is one
        if not container_statuses.get(container) and init_containers:
            container = init_containers[0]

        logger.info(
            f"Retrieving logs for pod: `{pod_name}` and container "
            f"`{container}` in namespace `{self._namespace}`"
        )
        response = self._core_api.read_namespaced_pod_log(
            name=pod_name,
            namespace=self._namespace,
            container=container,
            follow=follow,
            tail_lines=tail,
            _preload_content=False,
        )
    except k8s_client.rest.ApiException as e:
        logger.error(
            "Exception when fetching logs for SeldonDeployment resource "
            "%s: %s",
            name,
            str(e),
        )
        raise SeldonClientError(
            f"Unexpected exception when fetching logs for SeldonDeployment "
            f"resource: {name}"
        ) from e

    try:
        while True:
            raw_line = response.readline()
            # an empty bytes object marks the end of the stream; a blank
            # log line still contains its newline and must not end it
            if not raw_line:
                return
            line = raw_line.decode("utf-8", errors="replace").rstrip("\n")
            stop = yield line
            if stop:
                return
    finally:
        response.release_conn()
sanitize_labels(labels) staticmethod

Update the label values to be valid Kubernetes labels.

See: https://kubernetes.io/docs/concepts/overview/working-with-objects/labels/#syntax-and-character-set

Parameters:

Name Type Description Default
labels Dict[str, str]

the labels to sanitize.

required
Source code in zenml/integrations/seldon/seldon_client.py
@staticmethod
def sanitize_labels(labels: Dict[str, str]) -> None:
    """Rewrite the label values in place so they are valid Kubernetes labels.

    Each run of unsupported characters is replaced by a single underscore.
    The value is then cut to 63 characters, and any leading or trailing
    `-`, `_` and `.` characters are removed. Keys are left untouched.

    See:
    https://kubernetes.io/docs/concepts/overview/working-with-objects/labels/#syntax-and-character-set

    Args:
        labels: the labels to sanitize.
    """
    # Kubernetes label values may only contain alphanumerics, '-', '_'
    # and '.', must be at most 63 characters long, and must begin and end
    # with an alphanumeric character ([a-z0-9A-Z])
    invalid_chars = r"[^0-9a-zA-Z-_\.]+"
    for label_key in list(labels):
        replaced = re.sub(invalid_chars, "_", labels[label_key])
        truncated = replaced[:63]
        labels[label_key] = truncated.strip("-_.")
update_deployment(self, deployment, poll_timeout=0)

Update a Seldon Core deployment resource.

Parameters:

Name Type Description Default
deployment SeldonDeployment

the Seldon Core deployment resource to update

required
poll_timeout int

the maximum time to wait for the deployment to become available or to fail. If set to 0, the function will return immediately without checking the deployment status. If a timeout occurs and the deployment is still pending creation, it will be returned anyway and no exception will be raised.

0

Returns:

Type Description
SeldonDeployment

the updated Seldon Core deployment resource with updated status.

Exceptions:

Type Description
SeldonClientError

if an unknown error occurs while updating the deployment.

Source code in zenml/integrations/seldon/seldon_client.py
def update_deployment(
    self,
    deployment: SeldonDeployment,
    poll_timeout: int = 0,
) -> SeldonDeployment:
    """Update a Seldon Core deployment resource.

    Args:
        deployment: the Seldon Core deployment resource to update
        poll_timeout: the maximum time to wait for the deployment to become
            available or to fail. If set to 0, the function will return
            immediately without checking the deployment status. If a timeout
            occurs and the deployment is still pending, it will
            be returned anyway and no exception will be raised.

    Returns:
        the updated Seldon Core deployment resource with updated status.

    Raises:
        SeldonClientError: if an unknown error occurs while updating the
            deployment.
        SeldonDeploymentNotFoundError: propagated from `get_deployment` if
            the deployment does not exist or is not managed by ZenML.
    """
    try:
        logger.debug(
            f"Updating SeldonDeployment resource: {deployment.name}"
        )

        # mark the deployment as managed by ZenML, to differentiate
        # between deployments that are created by ZenML and those that
        # are not
        deployment.mark_as_managed_by_zenml()

        # call `get_deployment` to check that the deployment exists
        # and is managed by ZenML. It will raise
        # a SeldonDeploymentNotFoundError otherwise
        self.get_deployment(name=deployment.name)

        response = self._custom_objects_api.patch_namespaced_custom_object(
            group="machinelearning.seldon.io",
            version="v1",
            namespace=self._namespace,
            plural="seldondeployments",
            name=deployment.name,
            body=deployment.dict(exclude_none=True),
            _request_timeout=poll_timeout or None,
        )
        logger.debug("Seldon Core API response: %s", response)
    except k8s_client.rest.ApiException as e:
        logger.error(
            "Exception when updating SeldonDeployment resource: %s", str(e)
        )
        # the message must describe the update, not a creation
        raise SeldonClientError(
            "Exception when updating SeldonDeployment resource"
        ) from e

    updated_deployment = self.get_deployment(name=deployment.name)

    # poll every 5 seconds while the deployment is still pending
    while poll_timeout > 0 and updated_deployment.is_pending():
        time.sleep(5)
        poll_timeout -= 5
        updated_deployment = self.get_deployment(name=deployment.name)

    return updated_deployment
SeldonClientError (Exception)

Base exception class for all exceptions raised by the SeldonClient.

Source code in zenml/integrations/seldon/seldon_client.py
class SeldonClientError(Exception):
    """Base exception class for all exceptions raised by the SeldonClient.

    Catch this type to handle every error that originates from the client.
    """
SeldonClientTimeout (SeldonClientError)

Raised when the Seldon client timed out while waiting for a resource to reach the expected status.

Source code in zenml/integrations/seldon/seldon_client.py
class SeldonClientTimeout(SeldonClientError):
    """Raised when the Seldon client timed out while waiting for a resource to reach the expected status.

    Being a `SeldonClientError` subclass, it is also caught by handlers for
    the generic client error.
    """
SeldonDeployment (BaseModel) pydantic-model

A Seldon Core deployment CRD.

This is a Pydantic representation of some of the fields in the Seldon Core CRD (documented here: https://docs.seldon.io/projects/seldon-core/en/latest/reference/seldon-deployment.html).

Note that not all fields are represented, only those that are relevant to the ZenML integration. The fields that are not represented are silently ignored when the Seldon Deployment is created or updated from an external SeldonDeployment CRD representation.

Attributes:

Name Type Description
kind str

Kubernetes kind field.

apiVersion str

Kubernetes apiVersion field.

metadata SeldonDeploymentMetadata

Kubernetes metadata field.

spec SeldonDeploymentSpec

Seldon Deployment spec entry.

status Optional[zenml.integrations.seldon.seldon_client.SeldonDeploymentStatus]

Seldon Deployment status.

Source code in zenml/integrations/seldon/seldon_client.py
class SeldonDeployment(BaseModel):
    """A Seldon Core deployment CRD.

    This is a Pydantic representation of some of the fields in the Seldon Core
    CRD (documented here:
    https://docs.seldon.io/projects/seldon-core/en/latest/reference/seldon-deployment.html).

    Note that not all fields are represented, only those that are relevant to
    the ZenML integration. The fields that are not represented are silently
    ignored when the Seldon Deployment is created or updated from an external
    SeldonDeployment CRD representation.

    Attributes:
        kind: Kubernetes kind field.
        apiVersion: Kubernetes apiVersion field.
        metadata: Kubernetes metadata field.
        spec: Seldon Deployment spec entry.
        status: Seldon Deployment status.
    """

    kind: str = Field(SELDON_DEPLOYMENT_KIND, const=True)
    apiVersion: str = Field(SELDON_DEPLOYMENT_API_VERSION, const=True)
    metadata: SeldonDeploymentMetadata = Field(
        default_factory=SeldonDeploymentMetadata
    )
    spec: SeldonDeploymentSpec = Field(default_factory=SeldonDeploymentSpec)
    status: Optional[SeldonDeploymentStatus]

    def __str__(self) -> str:
        """Returns a string representation of the Seldon Deployment.

        Returns:
            A string representation of the Seldon Deployment.
        """
        return json.dumps(self.dict(exclude_none=True), indent=4)

    @classmethod
    def build(
        cls,
        name: Optional[str] = None,
        model_uri: Optional[str] = None,
        model_name: Optional[str] = None,
        implementation: Optional[str] = None,
        secret_name: Optional[str] = None,
        labels: Optional[Dict[str, str]] = None,
        annotations: Optional[Dict[str, str]] = None,
    ) -> "SeldonDeployment":
        """Build a basic Seldon Deployment object.

        The result has a single predictor whose graph is one `MODEL`
        predictive unit named `default`.

        Args:
            name: The name of the Seldon Deployment. If not explicitly passed,
                a name is generated as `zenml-<current unix timestamp>`.
            model_uri: The URI of the model.
            model_name: The name of the model.
            implementation: The implementation of the model.
            secret_name: The name of the Kubernetes secret containing
                environment variable values (e.g. with credentials for the
                artifact store) to use with the deployment service.
            labels: A dictionary of labels to apply to the Seldon Deployment.
            annotations: A dictionary of annotations to apply to the Seldon
                Deployment.

        Returns:
            A minimal SeldonDeployment object (an instance of `cls`, so
            subclasses are preserved) built from the provided parameters.
        """
        if not name:
            name = f"zenml-{time.time()}"

        if labels is None:
            labels = {}
        if annotations is None:
            annotations = {}

        # instantiate through `cls` so that subclasses calling `build`
        # get an instance of their own type back
        return cls(
            metadata=SeldonDeploymentMetadata(
                name=name, labels=labels, annotations=annotations
            ),
            spec=SeldonDeploymentSpec(
                name=name,
                predictors=[
                    SeldonDeploymentPredictor(
                        name=model_name or "",
                        graph=SeldonDeploymentPredictiveUnit(
                            name="default",
                            type=SeldonDeploymentPredictiveUnitType.MODEL,
                            modelUri=model_uri or "",
                            implementation=implementation or "",
                            envSecretRefName=secret_name,
                        ),
                    )
                ],
            ),
        )

    def is_managed_by_zenml(self) -> bool:
        """Checks if this Seldon Deployment is managed by ZenML.

        The convention used to differentiate between SeldonDeployment instances
        that are managed by ZenML and those that are not is to set the `app`
        label value to `zenml`.

        Returns:
            True if the Seldon Deployment is managed by ZenML, False
            otherwise.
        """
        return self.metadata.labels.get("app") == "zenml"

    def mark_as_managed_by_zenml(self) -> None:
        """Marks this Seldon Deployment as managed by ZenML.

        The convention used to differentiate between SeldonDeployment instances
        that are managed by ZenML and those that are not is to set the `app`
        label value to `zenml`.
        """
        self.metadata.labels["app"] = "zenml"

    @property
    def name(self) -> str:
        """Returns the name of this Seldon Deployment.

        This is just a shortcut for `self.metadata.name`.

        Returns:
            The name of this Seldon Deployment.
        """
        return self.metadata.name

    @property
    def state(self) -> SeldonDeploymentStatusState:
        """The state of the Seldon Deployment.

        Returns:
            The state of the Seldon Deployment, or `UNKNOWN` if no status
            has been reported yet.
        """
        if not self.status:
            return SeldonDeploymentStatusState.UNKNOWN
        return self.status.state

    def is_pending(self) -> bool:
        """Checks if the Seldon Deployment is in a pending state.

        Returns:
            True if the Seldon Deployment is pending (state `CREATING`),
            False otherwise.
        """
        return self.state == SeldonDeploymentStatusState.CREATING

    def is_available(self) -> bool:
        """Checks if the Seldon Deployment is in an available state.

        Returns:
            True if the Seldon Deployment is available, False otherwise.
        """
        return self.state == SeldonDeploymentStatusState.AVAILABLE

    def is_failed(self) -> bool:
        """Checks if the Seldon Deployment is in a failed state.

        Returns:
            True if the Seldon Deployment is failed, False otherwise.
        """
        return self.state == SeldonDeploymentStatusState.FAILED

    def get_error(self) -> Optional[str]:
        """Get a message describing the error, if in an error state.

        Returns:
            A message describing the error, if in an error state, otherwise
            None.
        """
        if self.status and self.is_failed():
            return self.status.description
        return None

    def get_pending_message(self) -> Optional[str]:
        """Get a message describing the pending conditions of the Seldon Deployment.

        Returns:
            The message of the first `Ready` condition whose status is
            falsy, or None if there is no such condition.
        """
        if not self.status or not self.status.conditions:
            return None
        ready_condition_message = [
            c.message
            for c in self.status.conditions
            if c.type == "Ready" and not c.status
        ]
        if not ready_condition_message:
            return None
        return ready_condition_message[0]

    class Config:
        """Pydantic configuration class."""

        # validate attribute assignments
        validate_assignment = True
        # Ignore extra attributes from the CRD that are not reflected here
        extra = "ignore"
name: str property readonly

Returns the name of this Seldon Deployment.

This is just a shortcut for self.metadata.name.

Returns:

Type Description
str

The name of this Seldon Deployment.

state: SeldonDeploymentStatusState property readonly

The state of the Seldon Deployment.

Returns:

Type Description
SeldonDeploymentStatusState

The state of the Seldon Deployment.

Config

Pydantic configuration class.

Source code in zenml/integrations/seldon/seldon_client.py
class Config:
    """Pydantic model configuration for the Seldon Deployment CRD."""

    # drop any CRD attributes that this model does not declare
    extra = "ignore"
    # re-run validation whenever an attribute is assigned
    validate_assignment = True
__str__(self) special

Returns a string representation of the Seldon Deployment.

Returns:

Type Description
str

A string representation of the Seldon Deployment.

Source code in zenml/integrations/seldon/seldon_client.py
def __str__(self) -> str:
    """Render the Seldon Deployment as an indented JSON document.

    Fields whose value is `None` are omitted.

    Returns:
        A string representation of the Seldon Deployment.
    """
    payload = self.dict(exclude_none=True)
    return json.dumps(payload, indent=4)
build(name=None, model_uri=None, model_name=None, implementation=None, secret_name=None, labels=None, annotations=None) classmethod

Build a basic Seldon Deployment object.

Parameters:

Name Type Description Default
name Optional[str]

The name of the Seldon Deployment. If not explicitly passed, a unique name is autogenerated.

None
model_uri Optional[str]

The URI of the model.

None
model_name Optional[str]

The name of the model.

None
implementation Optional[str]

The implementation of the model.

None
secret_name Optional[str]

The name of the Kubernetes secret containing environment variable values (e.g. with credentials for the artifact store) to use with the deployment service.

None
labels Optional[Dict[str, str]]

A dictionary of labels to apply to the Seldon Deployment.

None
annotations Optional[Dict[str, str]]

A dictionary of annotations to apply to the Seldon Deployment.

None

Returns:

Type Description
SeldonDeployment

A minimal SeldonDeployment object built from the provided parameters.

Source code in zenml/integrations/seldon/seldon_client.py
@classmethod
def build(
    cls,
    name: Optional[str] = None,
    model_uri: Optional[str] = None,
    model_name: Optional[str] = None,
    implementation: Optional[str] = None,
    secret_name: Optional[str] = None,
    labels: Optional[Dict[str, str]] = None,
    annotations: Optional[Dict[str, str]] = None,
) -> "SeldonDeployment":
    """Build a basic Seldon Deployment object.

    The result has a single predictor whose graph is one `MODEL`
    predictive unit named `default`.

    Args:
        name: The name of the Seldon Deployment. If not explicitly passed,
            a name is generated as `zenml-<current unix timestamp>`.
        model_uri: The URI of the model.
        model_name: The name of the model.
        implementation: The implementation of the model.
        secret_name: The name of the Kubernetes secret containing
            environment variable values (e.g. with credentials for the
            artifact store) to use with the deployment service.
        labels: A dictionary of labels to apply to the Seldon Deployment.
        annotations: A dictionary of annotations to apply to the Seldon
            Deployment.

    Returns:
        A minimal SeldonDeployment object (an instance of `cls`, so
        subclasses are preserved) built from the provided parameters.
    """
    if not name:
        name = f"zenml-{time.time()}"

    if labels is None:
        labels = {}
    if annotations is None:
        annotations = {}

    # instantiate through `cls` so that subclasses calling `build` get an
    # instance of their own type back
    return cls(
        metadata=SeldonDeploymentMetadata(
            name=name, labels=labels, annotations=annotations
        ),
        spec=SeldonDeploymentSpec(
            name=name,
            predictors=[
                SeldonDeploymentPredictor(
                    name=model_name or "",
                    graph=SeldonDeploymentPredictiveUnit(
                        name="default",
                        type=SeldonDeploymentPredictiveUnitType.MODEL,
                        modelUri=model_uri or "",
                        implementation=implementation or "",
                        envSecretRefName=secret_name,
                    ),
                )
            ],
        ),
    )
get_error(self)

Get a message describing the error, if in an error state.

Returns:

Type Description
Optional[str]

A message describing the error, if in an error state, otherwise None.

Source code in zenml/integrations/seldon/seldon_client.py
def get_error(self) -> Optional[str]:
    """Return the status description when the deployment has failed.

    Returns:
        The status description if a status is present and the deployment
        reports a failed state; None in every other case.
    """
    status = self.status
    if not status or not self.is_failed():
        return None
    return status.description
get_pending_message(self)

Get a message describing the pending conditions of the Seldon Deployment.

Returns:

Type Description
Optional[str]

A message describing the pending condition of the Seldon Deployment, or None, if no conditions are pending.

Source code in zenml/integrations/seldon/seldon_client.py
def get_pending_message(self) -> Optional[str]:
    """Describe why the Seldon Deployment is not yet ready, if applicable.

    Scans the status conditions for a "Ready" condition whose status is
    falsy and reports the message of the first such condition.

    Returns:
        The message of the first not-ready "Ready" condition, or None when
        there is no status, no conditions, or no such condition.
    """
    status = self.status
    if not (status and status.conditions):
        return None
    return next(
        (
            condition.message
            for condition in status.conditions
            if condition.type == "Ready" and not condition.status
        ),
        None,
    )
is_available(self)

Checks if the Seldon Deployment is in an available state.

Returns:

Type Description
bool

True if the Seldon Deployment is available, False otherwise.

Source code in zenml/integrations/seldon/seldon_client.py
def is_available(self) -> bool:
    """Report whether the Seldon Deployment is in the available state.

    Returns:
        True when the current state equals AVAILABLE, False otherwise.
    """
    current_state = self.state
    return current_state == SeldonDeploymentStatusState.AVAILABLE
is_failed(self)

Checks if the Seldon Deployment is in a failed state.

Returns:

Type Description
bool

True if the Seldon Deployment is failed, False otherwise.

Source code in zenml/integrations/seldon/seldon_client.py
def is_failed(self) -> bool:
    """Report whether the Seldon Deployment is in the failed state.

    Returns:
        True when the current state equals FAILED, False otherwise.
    """
    current_state = self.state
    return current_state == SeldonDeploymentStatusState.FAILED
is_managed_by_zenml(self)

Checks if this Seldon Deployment is managed by ZenML.

The convention used to differentiate between SeldonDeployment instances that are managed by ZenML and those that are not is to set the `app` label value to `zenml`.

Returns:

Type Description
bool

True if the Seldon Deployment is managed by ZenML, False otherwise.

Source code in zenml/integrations/seldon/seldon_client.py
def is_managed_by_zenml(self) -> bool:
    """Tell whether ZenML owns this Seldon Deployment.

    ZenML-managed deployments are recognized by an `app` metadata label
    whose value is `zenml`; any other value (or a missing label) means the
    deployment is not managed by ZenML.

    Returns:
        True for a ZenML-managed deployment, False otherwise.
    """
    app_label = self.metadata.labels.get("app")
    return app_label == "zenml"
is_pending(self)

Checks if the Seldon Deployment is in a pending state.

Returns:

Type Description
bool

True if the Seldon Deployment is pending, False otherwise.

Source code in zenml/integrations/seldon/seldon_client.py
def is_pending(self) -> bool:
    """Report whether the Seldon Deployment is still being created.

    Returns:
        True when the current state equals CREATING, False otherwise.
    """
    current_state = self.state
    return current_state == SeldonDeploymentStatusState.CREATING
mark_as_managed_by_zenml(self)

Marks this Seldon Deployment as managed by ZenML.

The convention used to differentiate between SeldonDeployment instances that are managed by ZenML and those that are not is to set the `app` label value to `zenml`.

Source code in zenml/integrations/seldon/seldon_client.py
def mark_as_managed_by_zenml(self) -> None:
    """Tag this Seldon Deployment as owned by ZenML.

    Ownership is expressed by setting the `app` metadata label to `zenml`,
    overwriting any previous value of that label; other labels are kept.
    """
    self.metadata.labels.update(app="zenml")
SeldonDeploymentExistsError (SeldonClientError)

Raised when a SeldonDeployment resource cannot be created because a resource with the same name already exists.

Source code in zenml/integrations/seldon/seldon_client.py
class SeldonDeploymentExistsError(SeldonClientError):
    """Raised when a SeldonDeployment resource cannot be created because a resource with the same name already exists.

    Derives from SeldonClientError, so handlers that catch that base class
    also catch this error.
    """
SeldonDeploymentMetadata (BaseModel) pydantic-model

Metadata for a Seldon Deployment.

Attributes:

Name Type Description
name str

the name of the Seldon Deployment.

labels Dict[str, str]

Kubernetes labels for the Seldon Deployment.

annotations Dict[str, str]

Kubernetes annotations for the Seldon Deployment.

creationTimestamp Optional[str]

the creation timestamp of the Seldon Deployment.

Source code in zenml/integrations/seldon/seldon_client.py
class SeldonDeploymentMetadata(BaseModel):
    """Metadata for a Seldon Deployment.

    Attributes:
        name: the name of the Seldon Deployment.
        labels: Kubernetes labels for the Seldon Deployment.
        annotations: Kubernetes annotations for the Seldon Deployment.
        creationTimestamp: the creation timestamp of the Seldon Deployment.
    """

    name: str
    # default_factory gives every instance its own fresh dict
    labels: Dict[str, str] = Field(default_factory=dict)
    annotations: Dict[str, str] = Field(default_factory=dict)
    # camelCase mirrors the Kubernetes field name; no explicit default is
    # given, so (pydantic v1 semantics) it is optional and defaults to None
    creationTimestamp: Optional[str]

    class Config:
        """Pydantic configuration class."""

        # validate attribute assignments
        validate_assignment = True
        # Ignore extra attributes from the CRD that are not reflected here
        extra = "ignore"
Config

Pydantic configuration class.

Source code in zenml/integrations/seldon/seldon_client.py
class Config:
    """Pydantic model configuration."""

    # Fields present in the CRD but not declared on the model are dropped.
    extra = "ignore"
    # Assigning to an attribute after construction re-runs validation.
    validate_assignment = True
SeldonDeploymentNotFoundError (SeldonClientError)

Raised when a particular SeldonDeployment resource is not found or is not managed by ZenML.

Source code in zenml/integrations/seldon/seldon_client.py
class SeldonDeploymentNotFoundError(SeldonClientError):
    """Raised when a particular SeldonDeployment resource is not found or is not managed by ZenML.

    Derives from SeldonClientError, so handlers that catch that base class
    also catch this error.
    """
SeldonDeploymentPredictiveUnit (BaseModel) pydantic-model

Seldon Deployment predictive unit.

Attributes:

Name Type Description
name str

the name of the predictive unit.

type Optional[zenml.integrations.seldon.seldon_client.SeldonDeploymentPredictiveUnitType]

predictive unit type.

implementation Optional[str]

the Seldon Core implementation used to serve the model.

modelUri Optional[str]

URI of the model (or models) to serve.

serviceAccountName Optional[str]

the name of the service account to associate with the predictive unit container.

envSecretRefName Optional[str]

the name of a Kubernetes secret that contains environment variables (e.g. credentials) to be configured for the predictive unit container.

children List[zenml.integrations.seldon.seldon_client.SeldonDeploymentPredictiveUnit]

a list of child predictive units that together make up the model serving graph.

Source code in zenml/integrations/seldon/seldon_client.py
class SeldonDeploymentPredictiveUnit(BaseModel):
    """Seldon Deployment predictive unit.

    A predictive unit is one node of a predictor's serving graph; nodes can
    be nested through `children`.

    Attributes:
        name: the name of the predictive unit.
        type: predictive unit type.
        implementation: the Seldon Core implementation used to serve the model.
        modelUri: URI of the model (or models) to serve.
        serviceAccountName: the name of the service account to associate with
            the predictive unit container.
        envSecretRefName: the name of a Kubernetes secret that contains
            environment variables (e.g. credentials) to be configured for the
            predictive unit container.
        children: a list of child predictive units that together make up the
            model serving graph.
    """

    # Required: there is no default, so every unit must be given a name.
    name: str
    # Defaults to a MODEL unit when not specified.
    type: Optional[
        SeldonDeploymentPredictiveUnitType
    ] = SeldonDeploymentPredictiveUnitType.MODEL
    # camelCase field names mirror the Seldon CRD keys.
    implementation: Optional[str]
    modelUri: Optional[str]
    serviceAccountName: Optional[str]
    envSecretRefName: Optional[str]
    # Self-referencing forward reference; presumably resolved by an
    # update_forward_refs() call elsewhere in the module — TODO confirm.
    children: List["SeldonDeploymentPredictiveUnit"] = Field(
        default_factory=list
    )

    class Config:
        """Pydantic configuration class."""

        # validate attribute assignments
        validate_assignment = True
        # Ignore extra attributes from the CRD that are not reflected here
        extra = "ignore"
Config

Pydantic configuration class.

Source code in zenml/integrations/seldon/seldon_client.py
class Config:
    """Pydantic model configuration."""

    # Fields present in the CRD but not declared on the model are dropped.
    extra = "ignore"
    # Assigning to an attribute after construction re-runs validation.
    validate_assignment = True
SeldonDeploymentPredictiveUnitType (StrEnum)

Predictive unit types for a Seldon Deployment.

Source code in zenml/integrations/seldon/seldon_client.py
class SeldonDeploymentPredictiveUnitType(StrEnum):
    """Predictive unit types for a Seldon Deployment.

    Member values are the exact strings used for the predictive unit
    `type` field.
    """

    UNKNOWN_TYPE = "UNKNOWN_TYPE"
    ROUTER = "ROUTER"
    COMBINER = "COMBINER"
    # Default type of SeldonDeploymentPredictiveUnit and the type used by
    # SeldonDeployment.build().
    MODEL = "MODEL"
    TRANSFORMER = "TRANSFORMER"
    OUTPUT_TRANSFORMER = "OUTPUT_TRANSFORMER"
SeldonDeploymentPredictor (BaseModel) pydantic-model

Seldon Deployment predictor.

Attributes:

Name Type Description
name str

the name of the predictor.

replicas int

the number of pod replicas for the predictor.

graph SeldonDeploymentPredictiveUnit

the serving graph composed of one or more predictive units.

Source code in zenml/integrations/seldon/seldon_client.py
class SeldonDeploymentPredictor(BaseModel):
    """Seldon Deployment predictor.

    Attributes:
        name: the name of the predictor.
        replicas: the number of pod replicas for the predictor.
        graph: the serving graph composed of one or more predictive units.
            When omitted, it defaults to a single MODEL unit named
            "default" (the same root unit name `SeldonDeployment.build`
            uses).
    """

    name: str
    replicas: int = 1
    # SeldonDeploymentPredictiveUnit has a required `name`, so the class
    # itself cannot serve as a zero-argument default factory: doing so made
    # every predictor built without an explicit graph fail validation.
    graph: SeldonDeploymentPredictiveUnit = Field(
        default_factory=lambda: SeldonDeploymentPredictiveUnit(name="default")
    )

    class Config:
        """Pydantic configuration class."""

        # validate attribute assignments
        validate_assignment = True
        # Ignore extra attributes from the CRD that are not reflected here
        extra = "ignore"
Config

Pydantic configuration class.

Source code in zenml/integrations/seldon/seldon_client.py
class Config:
    """Pydantic model configuration."""

    # Fields present in the CRD but not declared on the model are dropped.
    extra = "ignore"
    # Assigning to an attribute after construction re-runs validation.
    validate_assignment = True
SeldonDeploymentSpec (BaseModel) pydantic-model

Spec for a Seldon Deployment.

Attributes:

Name Type Description
name str

the name of the Seldon Deployment.

protocol Optional[str]

the API protocol used for the Seldon Deployment.

predictors List[zenml.integrations.seldon.seldon_client.SeldonDeploymentPredictor]

a list of predictors that make up the serving graph.

replicas int

the default number of pod replicas used for the predictors.

Source code in zenml/integrations/seldon/seldon_client.py
class SeldonDeploymentSpec(BaseModel):
    """Spec for a Seldon Deployment.

    Attributes:
        name: the name of the Seldon Deployment.
        protocol: the API protocol used for the Seldon Deployment.
        predictors: a list of predictors that make up the serving graph.
        replicas: the default number of pod replicas used for the predictors.
    """

    name: str
    # No explicit default: optional, left unset unless provided.
    protocol: Optional[str]
    # Required: a spec must always list its predictors.
    predictors: List[SeldonDeploymentPredictor]
    replicas: int = 1

    class Config:
        """Pydantic configuration class."""

        # validate attribute assignments
        validate_assignment = True
        # Ignore extra attributes from the CRD that are not reflected here
        extra = "ignore"
Config

Pydantic configuration class.

Source code in zenml/integrations/seldon/seldon_client.py
class Config:
    """Pydantic model configuration."""

    # Fields present in the CRD but not declared on the model are dropped.
    extra = "ignore"
    # Assigning to an attribute after construction re-runs validation.
    validate_assignment = True
SeldonDeploymentStatus (BaseModel) pydantic-model

The status of a Seldon Deployment.

Attributes:

Name Type Description
state SeldonDeploymentStatusState

the current state of the Seldon Deployment.

description Optional[str]

a human-readable description of the current state.

replicas Optional[int]

the current number of running pod replicas.

address Optional[zenml.integrations.seldon.seldon_client.SeldonDeploymentStatusAddress]

the address where the Seldon Deployment API can be accessed.

conditions List[zenml.integrations.seldon.seldon_client.SeldonDeploymentStatusCondition]

the list of Kubernetes conditions for the Seldon Deployment.

Source code in zenml/integrations/seldon/seldon_client.py
class SeldonDeploymentStatus(BaseModel):
    """The status of a Seldon Deployment.

    Attributes:
        state: the current state of the Seldon Deployment.
        description: a human-readable description of the current state.
        replicas: the current number of running pod replicas.
        address: the address where the Seldon Deployment API can be accessed.
        conditions: the list of Kubernetes conditions for the Seldon
            Deployment (empty when the CRD status reports none).
    """

    state: SeldonDeploymentStatusState = SeldonDeploymentStatusState.UNKNOWN
    description: Optional[str]
    replicas: Optional[int]
    address: Optional[SeldonDeploymentStatusAddress]
    # Default to an empty list: a status object that carries no conditions
    # (consumers such as get_pending_message already handle that case) must
    # still parse instead of failing validation on a missing required field.
    conditions: List[SeldonDeploymentStatusCondition] = Field(
        default_factory=list
    )

    class Config:
        """Pydantic configuration class."""

        # validate attribute assignments
        validate_assignment = True
        # Ignore extra attributes from the CRD that are not reflected here
        extra = "ignore"
Config

Pydantic configuration class.

Source code in zenml/integrations/seldon/seldon_client.py
class Config:
    """Pydantic model configuration."""

    # Fields present in the CRD but not declared on the model are dropped.
    extra = "ignore"
    # Assigning to an attribute after construction re-runs validation.
    validate_assignment = True
SeldonDeploymentStatusAddress (BaseModel) pydantic-model

The status address for a Seldon Deployment.

Attributes:

Name Type Description
url str

the URL where the Seldon Deployment API can be accessed internally.

Source code in zenml/integrations/seldon/seldon_client.py
class SeldonDeploymentStatusAddress(BaseModel):
    """The status address for a Seldon Deployment.

    Attributes:
        url: the URL where the Seldon Deployment API can be accessed internally.
    """

    url: str

    # Same configuration as the other CRD-backed models, so that parsing of
    # the CRD status sub-object behaves consistently with its siblings.
    class Config:
        """Pydantic configuration class."""

        # validate attribute assignments
        validate_assignment = True
        # Ignore extra attributes from the CRD that are not reflected here
        extra = "ignore"
SeldonDeploymentStatusCondition (BaseModel) pydantic-model

The Kubernetes status condition entry for a Seldon Deployment.

Attributes:

Name Type Description
type str

Type of runtime condition.

status bool

Status of the condition.

reason Optional[str]

Brief CamelCase string containing reason for the condition's last transition.

message Optional[str]

Human-readable message indicating details about last transition.

Source code in zenml/integrations/seldon/seldon_client.py
class SeldonDeploymentStatusCondition(BaseModel):
    """The Kubernetes status condition entry for a Seldon Deployment.

    Attributes:
        type: Type of runtime condition.
        status: Status of the condition.
        reason: Brief CamelCase string containing reason for the condition's
            last transition.
        message: Human-readable message indicating details about last
            transition.
    """

    type: str
    # NOTE(review): Kubernetes reports condition status as the string
    # "True", "False" or "Unknown"; pydantic's bool coercion accepts the
    # first two but rejects "Unknown". Confirm Seldon never reports it.
    status: bool
    reason: Optional[str]
    message: Optional[str]

    # Same configuration as the other CRD-backed models, so that parsing of
    # the CRD condition entries behaves consistently with its siblings.
    class Config:
        """Pydantic configuration class."""

        # validate attribute assignments
        validate_assignment = True
        # Ignore extra attributes from the CRD that are not reflected here
        extra = "ignore"
SeldonDeploymentStatusState (StrEnum)

Possible state values for a Seldon Deployment.

Source code in zenml/integrations/seldon/seldon_client.py
class SeldonDeploymentStatusState(StrEnum):
    """Possible state values for a Seldon Deployment."""

    # Default value of SeldonDeploymentStatus.state.
    UNKNOWN = "Unknown"
    # Checked by SeldonDeployment.is_available().
    AVAILABLE = "Available"
    # Checked by SeldonDeployment.is_pending().
    CREATING = "Creating"
    # Checked by SeldonDeployment.is_failed().
    FAILED = "Failed"

services special

Initialization for Seldon services.

seldon_deployment

Implementation of the Seldon Deployment service.

SeldonDeploymentConfig (ServiceConfig) pydantic-model

Seldon Core deployment service configuration.

Attributes:

Name Type Description
model_uri str

URI of the model (or models) to serve.

model_name str

the name of the model. Multiple versions of the same model should use the same model name.

implementation str

the Seldon Core implementation used to serve the model.

replicas int

number of replicas to use for the prediction service.

secret_name Optional[str]

the name of a Kubernetes secret containing additional configuration parameters for the Seldon Core deployment (e.g. credentials to access the Artifact Store).

model_metadata Dict[str, Any]

optional model metadata information (see https://docs.seldon.io/projects/seldon-core/en/latest/reference/apis/metadata.html).

extra_args Dict[str, Any]

additional arguments to pass to the Seldon Core deployment resource configuration.

Source code in zenml/integrations/seldon/services/seldon_deployment.py
class SeldonDeploymentConfig(ServiceConfig):
    """Configuration of a Seldon Core deployment service.

    Attributes:
        model_uri: URI of the model (or models) to serve.
        model_name: the name of the model. Multiple versions of the same model
            should use the same model name.
        implementation: the Seldon Core implementation used to serve the model.
        replicas: number of replicas to use for the prediction service.
        secret_name: the name of a Kubernetes secret containing additional
            configuration parameters for the Seldon Core deployment (e.g.
            credentials to access the Artifact Store).
        model_metadata: optional model metadata information (see
            https://docs.seldon.io/projects/seldon-core/en/latest/reference/apis/metadata.html).
        extra_args: additional arguments to pass to the Seldon Core deployment
            resource configuration.
    """

    model_uri: str = ""
    model_name: str = "default"
    # TODO [ENG-775]: have an enum of all supported Seldon Core implementations
    implementation: str
    replicas: int = 1
    secret_name: Optional[str]
    model_metadata: Dict[str, Any] = Field(default_factory=dict)
    extra_args: Dict[str, Any] = Field(default_factory=dict)

    def get_seldon_deployment_labels(self) -> Dict[str, str]:
        """Derive Kubernetes labels for the Seldon Core deployment.

        Only configuration values that are set (truthy) become labels. The
        resulting mapping is passed through `SeldonClient.sanitize_labels`
        before being returned, so it can serve as a label selector.

        Returns:
            The sanitized label mapping.
        """
        # Ordered (label key, config value) pairs; order fixes dict order.
        candidates = (
            ("zenml.pipeline_name", self.pipeline_name),
            ("zenml.pipeline_run_id", self.pipeline_run_id),
            ("zenml.pipeline_step_name", self.pipeline_step_name),
            ("zenml.model_name", self.model_name),
            ("zenml.model_uri", self.model_uri),
            ("zenml.model_type", self.implementation),
        )
        labels = {key: value for key, value in candidates if value}
        SeldonClient.sanitize_labels(labels)
        return labels

    def get_seldon_deployment_annotations(self) -> Dict[str, str]:
        """Derive Kubernetes annotations for the Seldon Core deployment.

        The annotations carry data that does not belong in labels: the full
        JSON-serialized service configuration (which is later used to
        rebuild this configuration from a remote deployment) and the ZenML
        version.

        Returns:
            The annotation mapping.
        """
        annotations: Dict[str, str] = {}
        annotations["zenml.service_config"] = self.json()
        annotations["zenml.version"] = __version__
        return annotations

    @classmethod
    def create_from_deployment(
        cls, deployment: SeldonDeployment
    ) -> "SeldonDeploymentConfig":
        """Rebuild a service configuration from a deployed Seldon resource.

        Args:
            deployment: the Seldon Core deployment resource.

        Returns:
            The configuration parsed from the resource's
            `zenml.service_config` annotation.

        Raises:
            ValueError: if the annotation is missing or empty, or if its
                content fails validation as a service configuration.
        """
        serialized = deployment.metadata.annotations.get(
            "zenml.service_config"
        )
        if serialized:
            try:
                return cls.parse_raw(serialized)
            except ValidationError as err:
                raise ValueError(
                    f"The loaded Seldon Core deployment resource contains an "
                    f"invalid or incompatible Seldon Core service configuration: "
                    f"{serialized}"
                ) from err
        raise ValueError(
            f"The given deployment resource does not contain a "
            f"'zenml.service_config' annotation: {deployment}"
        )
create_from_deployment(deployment) classmethod

Recreate the configuration of a Seldon Core Service from a deployed instance.

Parameters:

Name Type Description Default
deployment SeldonDeployment

the Seldon Core deployment resource.

required

Returns:

Type Description
SeldonDeploymentConfig

The Seldon Core service configuration corresponding to the given Seldon Core deployment resource.

Exceptions:

Type Description
ValueError

if the given deployment resource does not contain the expected annotations or it contains an invalid or incompatible Seldon Core service configuration.

Source code in zenml/integrations/seldon/services/seldon_deployment.py
@classmethod
def create_from_deployment(
    cls, deployment: SeldonDeployment
) -> "SeldonDeploymentConfig":
    """Rebuild a service configuration from a deployed Seldon resource.

    Args:
        deployment: the Seldon Core deployment resource.

    Returns:
        The configuration parsed from the resource's
        `zenml.service_config` annotation.

    Raises:
        ValueError: if the annotation is missing or empty, or if its
            content fails validation as a service configuration.
    """
    serialized = deployment.metadata.annotations.get(
        "zenml.service_config"
    )
    if serialized:
        try:
            return cls.parse_raw(serialized)
        except ValidationError as err:
            raise ValueError(
                f"The loaded Seldon Core deployment resource contains an "
                f"invalid or incompatible Seldon Core service configuration: "
                f"{serialized}"
            ) from err
    raise ValueError(
        f"The given deployment resource does not contain a "
        f"'zenml.service_config' annotation: {deployment}"
    )
get_seldon_deployment_annotations(self)

Generate annotations for the Seldon Core deployment from the service configuration.

The annotations are used to store additional information about the Seldon Core service that is associated with the deployment that is not available in the labels. One annotation particularly important is the serialized Service configuration itself, which is used to recreate the service configuration from a remote Seldon deployment.

Returns:

Type Description
Dict[str, str]

The annotations for the Seldon Core deployment.

Source code in zenml/integrations/seldon/services/seldon_deployment.py
def get_seldon_deployment_annotations(self) -> Dict[str, str]:
    """Derive Kubernetes annotations for the Seldon Core deployment.

    The annotations carry data that does not belong in labels: the full
    JSON-serialized service configuration (which is later used to rebuild
    the configuration from a remote deployment) and the ZenML version.

    Returns:
        The annotation mapping.
    """
    annotations: Dict[str, str] = {}
    annotations["zenml.service_config"] = self.json()
    annotations["zenml.version"] = __version__
    return annotations
get_seldon_deployment_labels(self)

Generate labels for the Seldon Core deployment from the service configuration.

These labels are attached to the Seldon Core deployment resource and may be used as label selectors in lookup operations.

Returns:

Type Description
Dict[str, str]

The labels for the Seldon Core deployment.

Source code in zenml/integrations/seldon/services/seldon_deployment.py
def get_seldon_deployment_labels(self) -> Dict[str, str]:
    """Derive Kubernetes labels for the Seldon Core deployment.

    Only configuration values that are set (truthy) become labels. The
    resulting mapping is passed through `SeldonClient.sanitize_labels`
    before being returned, so it can serve as a label selector.

    Returns:
        The sanitized label mapping.
    """
    # Ordered (label key, config value) pairs; order fixes dict order.
    candidates = (
        ("zenml.pipeline_name", self.pipeline_name),
        ("zenml.pipeline_run_id", self.pipeline_run_id),
        ("zenml.pipeline_step_name", self.pipeline_step_name),
        ("zenml.model_name", self.model_name),
        ("zenml.model_uri", self.model_uri),
        ("zenml.model_type", self.implementation),
    )
    labels = {key: value for key, value in candidates if value}
    SeldonClient.sanitize_labels(labels)
    return labels
SeldonDeploymentService (BaseService) pydantic-model

A service that represents a Seldon Core deployment server.

Attributes:

Name Type Description
config SeldonDeploymentConfig

service configuration.

status SeldonDeploymentServiceStatus

service status.

Source code in zenml/integrations/seldon/services/seldon_deployment.py
class SeldonDeploymentService(BaseService):
    """A service that represents a Seldon Core deployment server.

    Attributes:
        config: service configuration.
        status: service status.
    """

    SERVICE_TYPE = ServiceType(
        name="seldon-deployment",
        type="model-serving",
        flavor="seldon",
        description="Seldon Core prediction service",
    )

    config: SeldonDeploymentConfig = Field(
        default_factory=SeldonDeploymentConfig
    )
    status: SeldonDeploymentServiceStatus = Field(
        default_factory=SeldonDeploymentServiceStatus
    )

    def _get_client(self) -> SeldonClient:
        """Get the Seldon Core client from the active Seldon Core model deployer.

        Returns:
            The Seldon Core client.
        """
        from zenml.integrations.seldon.model_deployers.seldon_model_deployer import (
            SeldonModelDeployer,
        )

        model_deployer = SeldonModelDeployer.get_active_model_deployer()
        return model_deployer.seldon_client

    def check_status(self) -> Tuple[ServiceState, str]:
        """Check the current operational state of the Seldon Core deployment.

        Returns:
            The operational state of the Seldon Core deployment and a message
            providing additional information about that state (e.g. a
            description of the error, if one is encountered).
        """
        client = self._get_client()
        name = self.seldon_deployment_name
        try:
            deployment = client.get_deployment(name=name)
        except SeldonDeploymentNotFoundError:
            return (ServiceState.INACTIVE, "")

        if deployment.is_available():
            return (
                ServiceState.ACTIVE,
                f"Seldon Core deployment '{name}' is available",
            )

        if deployment.is_failed():
            return (
                ServiceState.ERROR,
                f"Seldon Core deployment '{name}' failed: "
                f"{deployment.get_error()}",
            )

        pending_message = deployment.get_pending_message() or ""
        return (
            ServiceState.PENDING_STARTUP,
            "Seldon Core deployment is being created: " + pending_message,
        )

    @property
    def seldon_deployment_name(self) -> str:
        """Get the name of the Seldon Core deployment.

        The name is derived from the service UUID, so it uniquely
        identifies this service instance.

        Returns:
            The name of the Seldon Core deployment.
        """
        return f"zenml-{str(self.uuid)}"

    def _get_seldon_deployment_labels(self) -> Dict[str, str]:
        """Generate the labels for the Seldon Core deployment from the service configuration.

        The configuration labels are extended with the service UUID label.

        Returns:
            The labels for the Seldon Core deployment.
        """
        labels = self.config.get_seldon_deployment_labels()
        labels["zenml.service_uuid"] = str(self.uuid)
        SeldonClient.sanitize_labels(labels)
        return labels

    @classmethod
    def create_from_deployment(
        cls, deployment: SeldonDeployment
    ) -> "SeldonDeploymentService":
        """Recreate a Seldon Core service from a Seldon Core deployment resource.

        The operational status of the recreated service is refreshed
        before it is returned.

        Args:
            deployment: the Seldon Core deployment resource.

        Returns:
            The Seldon Core service corresponding to the given
            Seldon Core deployment resource.

        Raises:
            ValueError: if the given deployment resource does not contain
                the expected service_uuid label.
        """
        config = SeldonDeploymentConfig.create_from_deployment(deployment)
        uuid = deployment.metadata.labels.get("zenml.service_uuid")
        if not uuid:
            raise ValueError(
                f"The given deployment resource does not contain a valid "
                f"'zenml.service_uuid' label: {deployment}"
            )
        service = cls(uuid=UUID(uuid), config=config)
        service.update_status()
        return service

    def provision(self) -> None:
        """Provision or update the remote Seldon Core deployment instance.

        The deployment is created if it does not exist yet, otherwise it is
        updated so that it matches the current configuration.
        """
        client = self._get_client()

        name = self.seldon_deployment_name

        deployment = SeldonDeployment.build(
            name=name,
            model_uri=self.config.model_uri,
            model_name=self.config.model_name,
            implementation=self.config.implementation,
            secret_name=self.config.secret_name,
            labels=self._get_seldon_deployment_labels(),
            annotations=self.config.get_seldon_deployment_annotations(),
        )
        deployment.spec.replicas = self.config.replicas
        deployment.spec.predictors[0].replicas = self.config.replicas

        # check if the Seldon deployment already exists
        try:
            client.get_deployment(name=name)
            # update the existing deployment
            client.update_deployment(deployment)
        except SeldonDeploymentNotFoundError:
            # create the deployment
            client.create_deployment(deployment=deployment)

    def deprovision(self, force: bool = False) -> None:
        """Deprovision the remote Seldon Core deployment instance.

        A deployment that cannot be found is ignored.

        Args:
            force: if True, the remote deployment instance will be
                forcefully deprovisioned.
        """
        client = self._get_client()
        name = self.seldon_deployment_name
        try:
            client.delete_deployment(name=name, force=force)
        except SeldonDeploymentNotFoundError:
            pass

    def get_logs(
        self,
        follow: bool = False,
        tail: Optional[int] = None,
    ) -> Generator[str, bool, None]:
        """Get the logs of a Seldon Core model deployment.

        Args:
            follow: if True, the logs will be streamed as they are written
            tail: only retrieve the last NUM lines of log output.

        Returns:
            A generator that can be accessed to get the service logs.
        """
        return self._get_client().get_deployment_logs(
            self.seldon_deployment_name,
            follow=follow,
            tail=tail,
        )

    @property
    def prediction_url(self) -> Optional[str]:
        """The prediction URI exposed by the prediction service.

        Returns:
            The prediction URI exposed by the prediction service, or None if
            the service is not yet ready.
        """
        # URLs always use forward slashes; os.path.join would insert
        # backslashes on Windows, so the POSIX flavour is used explicitly.
        import posixpath

        from zenml.integrations.seldon.model_deployers.seldon_model_deployer import (
            SeldonModelDeployer,
        )

        if not self.is_running:
            return None
        namespace = self._get_client().namespace
        model_deployer = SeldonModelDeployer.get_active_model_deployer()
        return posixpath.join(
            model_deployer.base_url,
            "seldon",
            namespace,
            self.seldon_deployment_name,
            "api/v0.1/predictions",
        )

    def predict(self, request: "NDArray[Any]") -> "NDArray[Any]":
        """Make a prediction using the service.

        Args:
            request: a numpy array representing the request

        Returns:
            A numpy array representing the prediction returned by the service.

        Raises:
            Exception: if the service is not yet ready.
            ValueError: if the prediction_url is not set.
        """
        if not self.is_running:
            raise Exception(
                "Seldon prediction service is not running. "
                "Please start the service before making predictions."
            )

        # Resolve the property once: every access re-checks `is_running` and
        # looks up the client and active model deployer, and two reads could
        # disagree.
        prediction_url = self.prediction_url
        if prediction_url is None:
            raise ValueError("`self.prediction_url` is not set, cannot post.")
        response = requests.post(
            prediction_url,
            json={"data": {"ndarray": request.tolist()}},
        )
        response.raise_for_status()
        return np.array(response.json()["data"]["ndarray"])
prediction_url: Optional[str] property readonly

The prediction URI exposed by the prediction service.

Returns:

Type Description
Optional[str]

The prediction URI exposed by the prediction service, or None if the service is not yet ready.

seldon_deployment_name: str property readonly

Get the name of the Seldon Core deployment.

The returned name uniquely identifies this service instance.

Returns:

Type Description
str

The name of the Seldon Core deployment.

check_status(self)

Check the current operational state of the Seldon Core deployment.

Returns:

Type Description
Tuple[zenml.services.service_status.ServiceState, str]

The operational state of the Seldon Core deployment and a message providing additional information about that state (e.g. a description of the error, if one is encountered).

Source code in zenml/integrations/seldon/services/seldon_deployment.py
def check_status(self) -> Tuple[ServiceState, str]:
    """Check the current operational state of the Seldon Core deployment.

    Returns:
        A ``(state, message)`` pair. The state is INACTIVE when the
        deployment cannot be found, ACTIVE when it is available, ERROR when
        it failed, and PENDING_STARTUP otherwise. The message gives extra
        detail, e.g. the error reported by a failed deployment.
    """
    seldon_client = self._get_client()
    deployment_name = self.seldon_deployment_name
    try:
        deployment = seldon_client.get_deployment(name=deployment_name)
    except SeldonDeploymentNotFoundError:
        # the deployment could not be found: report it as inactive
        return (ServiceState.INACTIVE, "")

    if deployment.is_available():
        state = ServiceState.ACTIVE
        message = f"Seldon Core deployment '{deployment_name}' is available"
    elif deployment.is_failed():
        state = ServiceState.ERROR
        message = (
            f"Seldon Core deployment '{deployment_name}' failed: "
            f"{deployment.get_error()}"
        )
    else:
        state = ServiceState.PENDING_STARTUP
        message = "Seldon Core deployment is being created: " + (
            deployment.get_pending_message() or ""
        )
    return state, message
create_from_deployment(deployment) classmethod

Recreate a Seldon Core service from a Seldon Core deployment resource.

The operational status of the recreated service is refreshed before it is returned.

Parameters:

Name Type Description Default
deployment SeldonDeployment

the Seldon Core deployment resource.

required

Returns:

Type Description
SeldonDeploymentService

The Seldon Core service corresponding to the given Seldon Core deployment resource.

Exceptions:

Type Description
ValueError

if the given deployment resource does not contain the expected service_uuid label.

Source code in zenml/integrations/seldon/services/seldon_deployment.py
@classmethod
def create_from_deployment(
    cls, deployment: SeldonDeployment
) -> "SeldonDeploymentService":
    """Rebuild a service instance from an existing Seldon Core deployment.

    The service configuration is recovered from the deployment resource and
    the service UUID from its ``zenml.service_uuid`` label. The status of
    the new service is refreshed before it is returned.

    Args:
        deployment: the Seldon Core deployment resource.

    Returns:
        The Seldon Core service matching the given deployment resource.

    Raises:
        ValueError: if the deployment has a missing or empty
            ``zenml.service_uuid`` label (``UUID`` itself also raises
            ValueError for a malformed value).
    """
    service_config = SeldonDeploymentConfig.create_from_deployment(deployment)
    service_uuid = deployment.metadata.labels.get("zenml.service_uuid")
    if not service_uuid:
        raise ValueError(
            f"The given deployment resource does not contain a valid "
            f"'zenml.service_uuid' label: {deployment}"
        )
    recreated = cls(uuid=UUID(service_uuid), config=service_config)
    recreated.update_status()
    return recreated
deprovision(self, force=False)

Deprovision the remote Seldon Core deployment instance.

Parameters:

Name Type Description Default
force bool

if True, the remote deployment instance will be forcefully deprovisioned.

False
Source code in zenml/integrations/seldon/services/seldon_deployment.py
def deprovision(self, force: bool = False) -> None:
    """Delete the remote Seldon Core deployment backing this service.

    A deployment that cannot be found is treated as already removed.

    Args:
        force: if True, the remote deployment is deleted forcefully.
    """
    seldon_client = self._get_client()
    deployment_name = self.seldon_deployment_name
    try:
        seldon_client.delete_deployment(name=deployment_name, force=force)
    except SeldonDeploymentNotFoundError:
        # nothing left to delete
        return
get_logs(self, follow=False, tail=None)

Get the logs of a Seldon Core model deployment.

Parameters:

Name Type Description Default
follow bool

if True, the logs will be streamed as they are written

False
tail Optional[int]

only retrieve the last NUM lines of log output.

None

Returns:

Type Description
Generator[str, bool, NoneType]

A generator that can be accessed to get the service logs.

Source code in zenml/integrations/seldon/services/seldon_deployment.py
def get_logs(
    self,
    follow: bool = False,
    tail: Optional[int] = None,
) -> Generator[str, bool, None]:
    """Fetch the logs of the Seldon Core deployment backing this service.

    Args:
        follow: when True, keep streaming log lines as they are written.
        tail: if given, only retrieve the last ``tail`` lines of log output.

    Returns:
        A generator yielding the service log lines.
    """
    seldon_client = self._get_client()
    log_stream = seldon_client.get_deployment_logs(
        self.seldon_deployment_name, follow=follow, tail=tail
    )
    return log_stream
predict(self, request)

Make a prediction using the service.

Parameters:

Name Type Description Default
request NDArray[Any]

a numpy array representing the request

required

Returns:

Type Description
NDArray[Any]

A numpy array representing the prediction returned by the service.

Exceptions:

Type Description
Exception

if the service is not yet ready.

ValueError

if the prediction_url is not set.

Source code in zenml/integrations/seldon/services/seldon_deployment.py
def predict(self, request: "NDArray[Any]") -> "NDArray[Any]":
    """Make a prediction using the service.

    The request array is sent as ``{"data": {"ndarray": ...}}`` to the
    service's prediction URL, and the ``data.ndarray`` field of the JSON
    response is converted back into a numpy array.

    Args:
        request: a numpy array representing the request

    Returns:
        A numpy array representing the prediction returned by the service.

    Raises:
        Exception: if the service is not yet ready.
        ValueError: if the prediction_url is not set.
    """
    if not self.is_running:
        raise Exception(
            "Seldon prediction service is not running. "
            "Please start the service before making predictions."
        )

    # Resolve the property once: every access re-checks `is_running` and
    # looks up the client and active model deployer, and two reads could
    # disagree.
    prediction_url = self.prediction_url
    if prediction_url is None:
        raise ValueError("`self.prediction_url` is not set, cannot post.")
    response = requests.post(
        prediction_url,
        json={"data": {"ndarray": request.tolist()}},
    )
    response.raise_for_status()
    return np.array(response.json()["data"]["ndarray"])
provision(self)

Provision or update remote Seldon Core deployment instance.

Afterwards, the deployment matches the current service configuration.

Source code in zenml/integrations/seldon/services/seldon_deployment.py
def provision(self) -> None:
    """Create the remote Seldon Core deployment, or update it if it exists.

    The deployment is rebuilt from the current service configuration,
    including its replica count, so that it matches that configuration.
    """
    seldon_client = self._get_client()
    deployment_name = self.seldon_deployment_name

    desired = SeldonDeployment.build(
        name=deployment_name,
        model_uri=self.config.model_uri,
        model_name=self.config.model_name,
        implementation=self.config.implementation,
        secret_name=self.config.secret_name,
        labels=self._get_seldon_deployment_labels(),
        annotations=self.config.get_seldon_deployment_annotations(),
    )
    replica_count = self.config.replicas
    desired.spec.replicas = replica_count
    desired.spec.predictors[0].replicas = replica_count

    try:
        # probe for an existing deployment; update it in place if found
        seldon_client.get_deployment(name=deployment_name)
        seldon_client.update_deployment(desired)
    except SeldonDeploymentNotFoundError:
        # no deployment with this name yet: create a new one
        seldon_client.create_deployment(deployment=desired)
SeldonDeploymentServiceStatus (ServiceStatus) pydantic-model

Seldon Core deployment service status.

Source code in zenml/integrations/seldon/services/seldon_deployment.py
class SeldonDeploymentServiceStatus(ServiceStatus):
    """Operational status of a Seldon Core deployment service.

    Declares no fields of its own; everything is inherited from
    `ServiceStatus`.
    """

steps special

Initialization for Seldon steps.

seldon_deployer

Implementation of the Seldon Deployer step.

SeldonDeployerStepConfig (BaseStepConfig) pydantic-model

Seldon model deployer step configuration.

Attributes:

Name Type Description
service_config SeldonDeploymentConfig

Seldon Core deployment service configuration.

secrets

a list of ZenML secrets containing additional configuration parameters for the Seldon Core deployment (e.g. credentials to access the Artifact Store where the models are stored). If supplied, the information fetched from these secrets is passed to the Seldon Core deployment server as a list of environment variables.

Source code in zenml/integrations/seldon/steps/seldon_deployer.py
class SeldonDeployerStepConfig(BaseStepConfig):
    """Seldon model deployer step configuration.

    Attributes:
        service_config: Seldon Core deployment service configuration.
        timeout: timeout passed to the model deployer and to the service
            when starting or stopping the Seldon Core deployment (assumed
            to be in seconds).
    """

    service_config: SeldonDeploymentConfig
    timeout: int = DEFAULT_SELDON_DEPLOYMENT_START_STOP_TIMEOUT
seldon_model_deployer_step (BaseStep)

Seldon Core model deployer pipeline step.

This step can be used in a pipeline to implement continuous deployment for a ML model with Seldon Core.

Parameters:

Name Type Description Default
deploy_decision

whether to deploy the model or not

required
config

configuration for the deployer step

required
model

the model artifact to deploy

required
context

the step context

required

Returns:

Type Description

Seldon Core deployment service

CONFIG_CLASS (BaseStepConfig) pydantic-model

Seldon model deployer step configuration.

Attributes:

Name Type Description
service_config SeldonDeploymentConfig

Seldon Core deployment service configuration.

secrets

a list of ZenML secrets containing additional configuration parameters for the Seldon Core deployment (e.g. credentials to access the Artifact Store where the models are stored). If supplied, the information fetched from these secrets is passed to the Seldon Core deployment server as a list of environment variables.

Source code in zenml/integrations/seldon/steps/seldon_deployer.py
class SeldonDeployerStepConfig(BaseStepConfig):
    """Seldon model deployer step configuration.

    Attributes:
        service_config: Seldon Core deployment service configuration.
        timeout: timeout passed to the model deployer and to the service
            when starting or stopping the Seldon Core deployment (assumed
            to be in seconds).
    """

    service_config: SeldonDeploymentConfig
    timeout: int = DEFAULT_SELDON_DEPLOYMENT_START_STOP_TIMEOUT
entrypoint(deploy_decision, config, context, model) staticmethod

Seldon Core model deployer pipeline step.

This step can be used in a pipeline to implement continuous deployment for a ML model with Seldon Core.

Parameters:

Name Type Description Default
deploy_decision bool

whether to deploy the model or not

required
config SeldonDeployerStepConfig

configuration for the deployer step

required
model ModelArtifact

the model artifact to deploy

required
context StepContext

the step context

required

Returns:

Type Description
SeldonDeploymentService

Seldon Core deployment service

Source code in zenml/integrations/seldon/steps/seldon_deployer.py
@step(enable_cache=False)
def seldon_model_deployer_step(
    deploy_decision: bool,
    config: SeldonDeployerStepConfig,
    context: StepContext,
    model: ModelArtifact,
) -> SeldonDeploymentService:
    """Seldon Core model deployer pipeline step.

    This step can be used in a pipeline to implement continuous
    deployment for a ML model with Seldon Core.

    Args:
        deploy_decision: whether to deploy the model or not
        config: configuration for the deployer step
        model: the model artifact to deploy
        context: the step context

    Returns:
        Seldon Core deployment service
    """
    model_deployer = SeldonModelDeployer.get_active_model_deployer()

    # get pipeline name, step name and run id
    step_env = cast(StepEnvironment, Environment()[STEP_ENVIRONMENT_NAME])
    pipeline_name = step_env.pipeline_name
    pipeline_run_id = step_env.pipeline_run_id
    step_name = step_env.step_name

    # update the step configuration with the real pipeline runtime information
    config.service_config.pipeline_name = pipeline_name
    config.service_config.pipeline_run_id = pipeline_run_id
    config.service_config.pipeline_step_name = step_name

    def prepare_service_config(model_uri: str) -> SeldonDeploymentConfig:
        """Prepare the model files for model serving.

        This creates and returns a Seldon service configuration for the model.

        This function ensures that the model files are in the correct format
        and file structure required by the Seldon Core server implementation
        used for model serving.

        Args:
            model_uri: the URI of the model artifact being served

        Returns:
            A copy of the step's Seldon service configuration whose
            `model_uri` points at the model files prepared for serving.

        Raises:
            RuntimeError: if the sklearn model file was not found
        """
        served_model_uri = os.path.join(
            context.get_output_artifact_uri(), "seldon"
        )
        fileio.makedirs(served_model_uri)

        # TODO [ENG-773]: determine how to formalize how models are organized into
        #   folders and sub-folders depending on the model type/format and the
        #   Seldon Core protocol used to serve the model.

        # TODO [ENG-791]: auto-detect built-in Seldon server implementation
        #   from the model artifact type

        # TODO [ENG-792]: validate the model artifact type against the
        #   supported built-in Seldon server implementations
        if config.service_config.implementation == "TENSORFLOW_SERVER":
            # the TensorFlow server expects model artifacts to be
            # stored in numbered subdirectories, each representing a model
            # version
            io_utils.copy_dir(model_uri, os.path.join(served_model_uri, "1"))
        elif config.service_config.implementation == "SKLEARN_SERVER":
            # the sklearn server expects model artifacts to be
            # stored in a file called model.joblib
            sklearn_model_file = os.path.join(model_uri, "model")
            # check the model file itself, not its parent artifact directory
            if not fileio.exists(sklearn_model_file):
                raise RuntimeError(
                    f"Expected sklearn model artifact was not found at "
                    f"{sklearn_model_file}"
                )
            fileio.copy(
                sklearn_model_file,
                os.path.join(served_model_uri, "model.joblib"),
            )
        else:
            # default treatment for all other server implementations is to
            # simply reuse the model from the artifact store path where it
            # is originally stored
            served_model_uri = model_uri

        service_config = config.service_config.copy()
        service_config.model_uri = served_model_uri
        return service_config

    # fetch existing services with same pipeline name, step name and
    # model name
    existing_services = model_deployer.find_model_server(
        pipeline_name=pipeline_name,
        pipeline_step_name=step_name,
        model_name=config.service_config.model_name,
    )

    # if the deploy decision is negative but no model server exists yet for
    # this pipeline/step, we still deploy the current model (below), to
    # ensure that a model server is available at all times
    if not deploy_decision and existing_services:
        logger.info(
            f"Skipping model deployment because the model quality does not "
            f"meet the criteria. Reusing last model server deployed by step "
            f"'{step_name}' and pipeline '{pipeline_name}' for model "
            f"'{config.service_config.model_name}'..."
        )
        service = cast(SeldonDeploymentService, existing_services[0])
        # even when the deploy decision is negative, we still need to start
        # the previous model server if it is no longer running, to ensure that
        # a model server is available at all times
        if not service.is_running:
            service.start(timeout=config.timeout)
        return service

    # invoke the Seldon Core model deployer to create a new service
    # or update an existing one that was previously deployed for the same
    # model
    service_config = prepare_service_config(model.uri)
    service = cast(
        SeldonDeploymentService,
        model_deployer.deploy_model(
            service_config, replace=True, timeout=config.timeout
        ),
    )

    logger.info(
        f"Seldon deployment service started and reachable at:\n"
        f"    {service.prediction_url}\n"
    )

    return service

sklearn special

Initialization of the sklearn integration.

SklearnIntegration (Integration)

Definition of sklearn integration for ZenML.

Source code in zenml/integrations/sklearn/__init__.py
class SklearnIntegration(Integration):
    """ZenML integration definition for scikit-learn."""

    NAME = SKLEARN
    REQUIREMENTS = ["scikit-learn"]

    @classmethod
    def activate(cls) -> None:
        """Activate the integration by importing its materializers module.

        The imported name is unused; presumably the import is done for its
        side effects (materializer registration) — confirm.
        """
        from zenml.integrations.sklearn import materializers  # noqa
activate() classmethod

Activates the integration.

Source code in zenml/integrations/sklearn/__init__.py
@classmethod
def activate(cls) -> None:
    """Activate the sklearn integration by importing its materializers.

    The imported name is unused; presumably the import is done for its
    side effects (materializer registration) — confirm.
    """
    from zenml.integrations.sklearn import materializers  # noqa

helpers special

Initialization for helper functions for the sklearn digits dataset.

digits

Helper functions for the sklearn digits dataset.

get_digits()

Returns the digits dataset in the form of a tuple of numpy arrays.

Returns:

Type Description
Tuple[NDArray[np.float64], NDArray[np.float64], NDArray[np.int64], NDArray[np.int64]]

Tuple of (training_images, testing_images, training_labels, testing_labels)

Source code in zenml/integrations/sklearn/helpers/digits.py
def get_digits() -> Tuple[
    "NDArray[np.float64]",
    "NDArray[np.float64]",
    "NDArray[np.int64]",
    "NDArray[np.int64]",
]:
    """Load the sklearn digits dataset split into train and test halves.

    Each image is flattened into a single feature row. The split is 50/50
    and unshuffled, so the first half of the samples is used for training.

    Returns:
        Tuple of (training_images, testing_images, training_labels, testing_labels)
    """
    dataset = load_digits()
    images = dataset.images
    # one row per image: collapse every remaining axis into features
    features = images.reshape((len(images), -1))

    train_x, test_x, train_y, test_y = train_test_split(
        features, dataset.target, test_size=0.5, shuffle=False
    )
    return train_x, test_x, train_y, test_y
get_digits_model()

Creates a support vector classifier for digits dataset.

Returns:

Type Description
ClassifierMixin

A support vector classifier.

Source code in zenml/integrations/sklearn/helpers/digits.py
def get_digits_model() -> ClassifierMixin:
    """Build a fresh support vector classifier for the digits dataset.

    Returns:
        An ``SVC`` instance configured with ``gamma=0.001``.
    """
    classifier = SVC(gamma=0.001)
    return classifier

materializers special

Initialization of the sklearn materializer.

sklearn_materializer

Implementation of the sklearn materializer.

SklearnMaterializer (BaseMaterializer)

Materializer to read data to and from sklearn.

Source code in zenml/integrations/sklearn/materializers/sklearn_materializer.py
class SklearnMaterializer(BaseMaterializer):
    """Materializer that pickles sklearn models into and out of artifacts."""

    ASSOCIATED_TYPES = (
        BaseEstimator,
        ClassifierMixin,
        ClusterMixin,
        BiclusterMixin,
        OutlierMixin,
        RegressorMixin,
        MetaEstimatorMixin,
        MultiOutputMixin,
        DensityMixin,
        TransformerMixin,
    )
    ASSOCIATED_ARTIFACT_TYPES = (ModelArtifact,)

    def _model_filepath(self) -> str:
        """Return the path of the pickled model inside the artifact URI."""
        return os.path.join(self.artifact.uri, DEFAULT_FILENAME)

    def handle_input(
        self, data_type: Type[Any]
    ) -> Union[
        BaseEstimator,
        ClassifierMixin,
        ClusterMixin,
        BiclusterMixin,
        OutlierMixin,
        RegressorMixin,
        MetaEstimatorMixin,
        MultiOutputMixin,
        DensityMixin,
        TransformerMixin,
    ]:
        """Unpickle the sklearn model stored in the artifact.

        Args:
            data_type: The type of the model.

        Returns:
            The unpickled model.
        """
        super().handle_input(data_type)
        # NOTE: unpickling can execute arbitrary code; only read artifacts
        # from a trusted location.
        with fileio.open(self._model_filepath(), "rb") as model_file:
            return pickle.load(model_file)

    def handle_return(
        self,
        clf: Union[
            BaseEstimator,
            ClassifierMixin,
            ClusterMixin,
            BiclusterMixin,
            OutlierMixin,
            RegressorMixin,
            MetaEstimatorMixin,
            MultiOutputMixin,
            DensityMixin,
            TransformerMixin,
        ],
    ) -> None:
        """Pickle the given sklearn model into the artifact.

        Args:
            clf: A sklearn model.
        """
        super().handle_return(clf)
        with fileio.open(self._model_filepath(), "wb") as model_file:
            pickle.dump(clf, model_file)
handle_input(self, data_type)

Reads a base sklearn model from a pickle file.

Parameters:

Name Type Description Default
data_type Type[Any]

The type of the model.

required

Returns:

Type Description
Union[sklearn.base.BaseEstimator, sklearn.base.ClassifierMixin, sklearn.base.ClusterMixin, sklearn.base.BiclusterMixin, sklearn.base.OutlierMixin, sklearn.base.RegressorMixin, sklearn.base.MetaEstimatorMixin, sklearn.base.MultiOutputMixin, sklearn.base.DensityMixin, sklearn.base.TransformerMixin]

The model.

Source code in zenml/integrations/sklearn/materializers/sklearn_materializer.py
def handle_input(
    self, data_type: Type[Any]
) -> Union[
    BaseEstimator,
    ClassifierMixin,
    ClusterMixin,
    BiclusterMixin,
    OutlierMixin,
    RegressorMixin,
    MetaEstimatorMixin,
    MultiOutputMixin,
    DensityMixin,
    TransformerMixin,
]:
    """Unpickle the sklearn model stored in the artifact.

    Args:
        data_type: The type of the model.

    Returns:
        The unpickled model.
    """
    super().handle_input(data_type)
    model_path = os.path.join(self.artifact.uri, DEFAULT_FILENAME)
    # NOTE: unpickling can execute arbitrary code; only read artifacts
    # from a trusted location.
    with fileio.open(model_path, "rb") as model_file:
        return pickle.load(model_file)
handle_return(self, clf)

Creates a pickle for a sklearn model.

Parameters:

Name Type Description Default
clf Union[sklearn.base.BaseEstimator, sklearn.base.ClassifierMixin, sklearn.base.ClusterMixin, sklearn.base.BiclusterMixin, sklearn.base.OutlierMixin, sklearn.base.RegressorMixin, sklearn.base.MetaEstimatorMixin, sklearn.base.MultiOutputMixin, sklearn.base.DensityMixin, sklearn.base.TransformerMixin]

A sklearn model.

required
Source code in zenml/integrations/sklearn/materializers/sklearn_materializer.py
def handle_return(
    self,
    clf: Union[
        BaseEstimator,
        ClassifierMixin,
        ClusterMixin,
        BiclusterMixin,
        OutlierMixin,
        RegressorMixin,
        MetaEstimatorMixin,
        MultiOutputMixin,
        DensityMixin,
        TransformerMixin,
    ],
) -> None:
    """Serialize a sklearn model into a pickle file in the artifact directory.

    Args:
        clf: The sklearn model to persist.
    """
    super().handle_return(clf)
    target_path = os.path.join(self.artifact.uri, DEFAULT_FILENAME)
    with fileio.open(target_path, "wb") as out_file:
        pickle.dump(clf, out_file)

steps special

Initialization of the sklearn standard steps.

sklearn_evaluator

Implementation of the sklearn evaluator step.

SklearnEvaluator (BaseEvaluatorStep)

Simple sklearn evaluator step implementation.

This uses sklearn to evaluate the performance of a given model on a given test dataset.

Source code in zenml/integrations/sklearn/steps/sklearn_evaluator.py
class SklearnEvaluator(BaseEvaluatorStep):
    """Simple sklearn evaluator step implementation.

    This uses sklearn to evaluate the performance of a given model on a given
    test dataset.
    """

    def entrypoint(  # type: ignore[override]
        self,
        dataset: pd.DataFrame,
        model: BaseEstimator,
        config: SklearnEvaluatorConfig,
    ) -> dict:  # type: ignore[type-arg]
        """Method which is responsible for the computation of the evaluation.

        The label column named by the config is used as ground truth. All
        other columns are passed to ``model.predict``. The input DataFrame is
        not modified.

        Args:
            dataset: a pandas DataFrame which represents the test dataset
            model: a trained sklearn model
            config: the configuration for the step

        Returns:
            a dictionary which has the evaluation report

        Raises:
            KeyError: if ``config.label_class_column`` is not a column of
                ``dataset``.
        """
        label_column = config.label_class_column
        # Select labels and features without `DataFrame.pop`, which would
        # remove the label column from the caller's DataFrame as a side effect.
        labels = dataset[label_column]
        features = dataset.drop(columns=[label_column])

        predictions = model.predict(features)
        # Predictions are binarized with a fixed 0.5 threshold. This assumes a
        # binary task whose predictions are scores or 0/1 labels.
        predicted_classes = [1 if v > 0.5 else 0 for v in predictions]

        report = classification_report(
            labels, predicted_classes, output_dict=True
        )

        return report  # type: ignore[no-any-return]
CONFIG_CLASS (BaseEvaluatorConfig) pydantic-model

Config class for the sklearn evaluator.

Source code in zenml/integrations/sklearn/steps/sklearn_evaluator.py
class SklearnEvaluatorConfig(BaseEvaluatorConfig):
    """Configuration for the sklearn evaluator step."""

    # Name of the dataset column holding the ground-truth labels.
    label_class_column: str
entrypoint(self, dataset, model, config)

Method which is responsible for the computation of the evaluation.

Parameters:

Name Type Description Default
dataset DataFrame

a pandas DataFrame which represents the test dataset

required
model BaseEstimator

a trained sklearn model

required
config SklearnEvaluatorConfig

the configuration for the step

required

Returns:

Type Description
dict

a dictionary which has the evaluation report

Source code in zenml/integrations/sklearn/steps/sklearn_evaluator.py
def entrypoint(  # type: ignore[override]
    self,
    dataset: pd.DataFrame,
    model: BaseEstimator,
    config: SklearnEvaluatorConfig,
) -> dict:  # type: ignore[type-arg]
    """Method which is responsible for the computation of the evaluation.

    The label column named by the config is used as ground truth. All other
    columns are passed to ``model.predict``. The input DataFrame is not
    modified.

    Args:
        dataset: a pandas DataFrame which represents the test dataset
        model: a trained sklearn model
        config: the configuration for the step

    Returns:
        a dictionary which has the evaluation report

    Raises:
        KeyError: if ``config.label_class_column`` is not a column of
            ``dataset``.
    """
    label_column = config.label_class_column
    # Select labels and features without `DataFrame.pop`, which would remove
    # the label column from the caller's DataFrame as a side effect.
    labels = dataset[label_column]
    features = dataset.drop(columns=[label_column])

    predictions = model.predict(features)
    # Predictions are binarized with a fixed 0.5 threshold. This assumes a
    # binary task whose predictions are scores or 0/1 labels.
    predicted_classes = [1 if v > 0.5 else 0 for v in predictions]

    report = classification_report(
        labels, predicted_classes, output_dict=True
    )

    return report  # type: ignore[no-any-return]
SklearnEvaluatorConfig (BaseEvaluatorConfig) pydantic-model

Config class for the sklearn evaluator.

Source code in zenml/integrations/sklearn/steps/sklearn_evaluator.py
class SklearnEvaluatorConfig(BaseEvaluatorConfig):
    """Configuration for the sklearn evaluator step."""

    # Name of the dataset column holding the ground-truth labels.
    label_class_column: str
sklearn_splitter

Implementation of the sklearn splitter.

SklearnSplitter (BaseSplitStep)

A simple sklearn splitter step implementation.

This uses sklearn to split a given dataset into train, test and validation splits.

Source code in zenml/integrations/sklearn/steps/sklearn_splitter.py
class SklearnSplitter(BaseSplitStep):
    """A simple sklearn splitter step implementation.

    This uses sklearn to split a given dataset into train, test and validation
    splits.
    """

    def entrypoint(  # type: ignore[override]
        self,
        dataset: pd.DataFrame,
        config: SklearnSplitterConfig,
    ) -> Output(  # type:ignore[valid-type]
        train=pd.DataFrame, test=pd.DataFrame, validation=pd.DataFrame
    ):
        """Method which is responsible for the splitting logic.

        The test split is carved out of the full dataset first. The remainder
        is then split into train and validation, so that each split receives
        its requested share of the original rows.

        Args:
            dataset: a pandas DataFrame which represents the entire dataset
            config: the configuration for the step

        Returns:
            three DataFrames representing the train, test and validation
            splits

        Raises:
            KeyError: if the ratios dict does not have exactly the keys
                'train', 'test' and 'validation'
            ValueError: if the ratios do not sum up to 1 (within floating
                point tolerance)
        """
        import math

        ratios = config.ratios
        # Dict keys are unique, so this is equivalent to "all three splits
        # are present and there are no other keys".
        if set(ratios) != {"train", "test", "validation"}:
            raise KeyError(
                f"Make sure that you only use 'train', 'test' and "
                f"'validation' as keys in the ratios dict. Current keys: "
                f"{config.ratios.keys()}"
            )

        # Compare with a tolerance: valid ratios such as 0.7/0.2/0.1 can add
        # up to 0.9999999999999999 in binary floating point.
        if not math.isclose(sum(ratios.values()), 1.0):
            raise ValueError(
                f"Make sure that the ratios sum up to 1. Current "
                f"ratios: {config.ratios}"
            )

        train_dataset, test_dataset = train_test_split(
            dataset, test_size=ratios["test"]
        )

        # Rescale the validation share relative to the rows that remain
        # after removing the test split.
        train_dataset, val_dataset = train_test_split(
            train_dataset,
            test_size=(
                ratios["validation"]
                / (ratios["validation"] + ratios["train"])
            ),
        )

        return train_dataset, test_dataset, val_dataset
CONFIG_CLASS (BaseSplitStepConfig) pydantic-model

Config class for the sklearn splitter.

Source code in zenml/integrations/sklearn/steps/sklearn_splitter.py
class SklearnSplitterConfig(BaseSplitStepConfig):
    """Configuration for the sklearn splitter step."""

    # Fraction of rows per split, keyed by 'train', 'test' and 'validation'.
    # The splitter requires exactly these keys, with values summing to 1.
    ratios: Dict[str, float]
entrypoint(self, dataset, config)

Method which is responsible for the splitting logic.

Parameters:

Name Type Description Default
dataset DataFrame

a pandas DataFrame which represents the entire dataset

required
config SklearnSplitterConfig

the configuration for the step

required

Returns:

Type Description
Output(train=DataFrame, test=DataFrame, validation=DataFrame)

three DataFrames representing the splits

Exceptions:

Type Description
KeyError

if the wrong configuration is used

ValueError

if the ratios are not valid

Source code in zenml/integrations/sklearn/steps/sklearn_splitter.py
def entrypoint(  # type: ignore[override]
    self,
    dataset: pd.DataFrame,
    config: SklearnSplitterConfig,
) -> Output(  # type:ignore[valid-type]
    train=pd.DataFrame, test=pd.DataFrame, validation=pd.DataFrame
):
    """Method which is responsible for the splitting logic.

    The test split is carved out of the full dataset first. The remainder is
    then split into train and validation, so that each split receives its
    requested share of the original rows.

    Args:
        dataset: a pandas DataFrame which represents the entire dataset
        config: the configuration for the step

    Returns:
        three DataFrames representing the train, test and validation splits

    Raises:
        KeyError: if the ratios dict does not have exactly the keys
            'train', 'test' and 'validation'
        ValueError: if the ratios do not sum up to 1 (within floating point
            tolerance)
    """
    import math

    ratios = config.ratios
    # Dict keys are unique, so this is equivalent to "all three splits are
    # present and there are no other keys".
    if set(ratios) != {"train", "test", "validation"}:
        raise KeyError(
            f"Make sure that you only use 'train', 'test' and "
            f"'validation' as keys in the ratios dict. Current keys: "
            f"{config.ratios.keys()}"
        )

    # Compare with a tolerance: valid ratios such as 0.7/0.2/0.1 can add up
    # to 0.9999999999999999 in binary floating point.
    if not math.isclose(sum(ratios.values()), 1.0):
        raise ValueError(
            f"Make sure that the ratios sum up to 1. Current "
            f"ratios: {config.ratios}"
        )

    train_dataset, test_dataset = train_test_split(
        dataset, test_size=ratios["test"]
    )

    # Rescale the validation share relative to the rows that remain after
    # removing the test split.
    train_dataset, val_dataset = train_test_split(
        train_dataset,
        test_size=(
            ratios["validation"] / (ratios["validation"] + ratios["train"])
        ),
    )

    return train_dataset, test_dataset, val_dataset
SklearnSplitterConfig (BaseSplitStepConfig) pydantic-model

Config class for the sklearn splitter.

Source code in zenml/integrations/sklearn/steps/sklearn_splitter.py
class SklearnSplitterConfig(BaseSplitStepConfig):
    """Configuration for the sklearn splitter step."""

    # Fraction of rows per split, keyed by 'train', 'test' and 'validation'.
    # The splitter requires exactly these keys, with values summing to 1.
    ratios: Dict[str, float]
sklearn_standard_scaler

Implementation of the sklearn standard scaler step.

SklearnStandardScaler (BasePreprocessorStep)

Simple StandardScaler step implementation.

This uses the StandardScaler from sklearn to transform the numeric columns of a pd.DataFrame.

Source code in zenml/integrations/sklearn/steps/sklearn_standard_scaler.py
class SklearnStandardScaler(BasePreprocessorStep):
    """Simple StandardScaler step implementation.

    This uses the StandardScaler from sklearn to transform the numeric columns
    of a pd.DataFrame.
    """

    def entrypoint(  # type: ignore[override]
        self,
        train_dataset: pd.DataFrame,
        test_dataset: pd.DataFrame,
        validation_dataset: pd.DataFrame,
        statistics: pd.DataFrame,
        schema: pd.DataFrame,
        config: SklearnStandardScalerConfig,
    ) -> Output(  # type:ignore[valid-type]
        train_transformed=pd.DataFrame,
        test_transformed=pd.DataFrame,
        validation_transformed=pd.DataFrame,
    ):
        """Main entrypoint function for the StandardScaler.

        The scaler is not fitted here. Its mean and scale are read from the
        ``"mean"`` and ``"std"`` columns of ``statistics`` (looked up by
        feature name). They are then applied to the selected columns of all
        three datasets, which are modified in place and returned.

        Args:
            train_dataset: pd.DataFrame, the training dataset
            test_dataset: pd.DataFrame, the test dataset
            validation_dataset: pd.DataFrame, the validation dataset
            statistics: pd.DataFrame, the statistics over the train dataset
            schema: pd.DataFrame, the detected schema of the dataset
            config: the configuration for the step

        Returns:
            the transformed train, test and validation datasets as pd.DataFrames
        """
        # Take the value at index label 0 of every schema column. It is
        # compared against dtype names such as "int64" below.
        schema_dict = {k: v[0] for k, v in schema.to_dict().items()}

        # Exclude columns
        feature_set = set(train_dataset.columns) - set(config.exclude_columns)
        for feature, feature_type in schema_dict.items():
            if feature_type in ("int64", "float64"):
                continue
            # Only drop (and warn about) columns that are still selected.
            # A non-numeric column that was already excluded, or that is
            # missing from the train set, would make `set.remove` raise.
            if feature in feature_set:
                feature_set.remove(feature)
                logger.warning(
                    f"{feature} column is not numeric, thus it is excluded "
                    f"from the standard scaling."
                )

        transform_feature_set = feature_set - set(config.ignore_columns)
        # Pandas deprecates (and from 2.0 rejects) sets as indexers, so use a
        # list that follows the train dataset's column order.
        transform_columns = [
            column
            for column in train_dataset.columns
            if column in transform_feature_set
        ]
        if not transform_columns:
            # Nothing to scale; sklearn rejects inputs with zero features.
            return train_dataset, test_dataset, validation_dataset

        # Transform the datasets
        scaler = StandardScaler()
        scaler.mean_ = statistics["mean"][transform_columns].to_numpy()
        scaler.scale_ = statistics["std"][transform_columns].to_numpy()

        for frame in (train_dataset, test_dataset, validation_dataset):
            frame[transform_columns] = scaler.transform(
                frame[transform_columns]
            )

        return train_dataset, test_dataset, validation_dataset
CONFIG_CLASS (BasePreprocessorConfig) pydantic-model

Config class for the sklearn standard scaler.

ignore_columns: a list of column names which should not be scaled

exclude_columns: a list of column names to be excluded from the standard scaling

Source code in zenml/integrations/sklearn/steps/sklearn_standard_scaler.py
class SklearnStandardScalerConfig(BasePreprocessorConfig):
    """Configuration for the sklearn standard scaler step.

    ignore_columns: names of columns that should be left unscaled
    exclude_columns: names of columns left out of the scaled feature set
        (the scaler does not drop them from the returned datasets)
    """

    # Columns kept as-is by the scaler.
    ignore_columns: List[str] = []
    # Columns removed from the candidate feature set before scaling.
    exclude_columns: List[str] = []
entrypoint(self, train_dataset, test_dataset, validation_dataset, statistics, schema, config)

Main entrypoint function for the StandardScaler.

Parameters:

Name Type Description Default
train_dataset DataFrame

pd.DataFrame, the training dataset

required
test_dataset DataFrame

pd.DataFrame, the test dataset

required
validation_dataset DataFrame

pd.DataFrame, the validation dataset

required
statistics DataFrame

pd.DataFrame, the statistics over the train dataset

required
schema DataFrame

pd.DataFrame, the detected schema of the dataset

required
config SklearnStandardScalerConfig

the configuration for the step

required

Returns:

Type Description
Output(train_transformed=DataFrame, test_transformed=DataFrame, validation_transformed=DataFrame)

the transformed train, test and validation datasets as pd.DataFrames

Source code in zenml/integrations/sklearn/steps/sklearn_standard_scaler.py
def entrypoint(  # type: ignore[override]
    self,
    train_dataset: pd.DataFrame,
    test_dataset: pd.DataFrame,
    validation_dataset: pd.DataFrame,
    statistics: pd.DataFrame,
    schema: pd.DataFrame,
    config: SklearnStandardScalerConfig,
) -> Output(  # type:ignore[valid-type]
    train_transformed=pd.DataFrame,
    test_transformed=pd.DataFrame,
    validation_transformed=pd.DataFrame,
):
    """Main entrypoint function for the StandardScaler.

    The scaler is not fitted here. Its mean and scale are read from the
    ``"mean"`` and ``"std"`` columns of ``statistics`` (looked up by feature
    name). They are then applied to the selected columns of all three
    datasets, which are modified in place and returned.

    Args:
        train_dataset: pd.DataFrame, the training dataset
        test_dataset: pd.DataFrame, the test dataset
        validation_dataset: pd.DataFrame, the validation dataset
        statistics: pd.DataFrame, the statistics over the train dataset
        schema: pd.DataFrame, the detected schema of the dataset
        config: the configuration for the step

    Returns:
        the transformed train, test and validation datasets as pd.DataFrames
    """
    # Take the value at index label 0 of every schema column. It is compared
    # against dtype names such as "int64" below.
    schema_dict = {k: v[0] for k, v in schema.to_dict().items()}

    # Exclude columns
    feature_set = set(train_dataset.columns) - set(config.exclude_columns)
    for feature, feature_type in schema_dict.items():
        if feature_type in ("int64", "float64"):
            continue
        # Only drop (and warn about) columns that are still selected. A
        # non-numeric column that was already excluded, or that is missing
        # from the train set, would make `set.remove` raise.
        if feature in feature_set:
            feature_set.remove(feature)
            logger.warning(
                f"{feature} column is not numeric, thus it is excluded "
                f"from the standard scaling."
            )

    transform_feature_set = feature_set - set(config.ignore_columns)
    # Pandas deprecates (and from 2.0 rejects) sets as indexers, so use a
    # list that follows the train dataset's column order.
    transform_columns = [
        column
        for column in train_dataset.columns
        if column in transform_feature_set
    ]
    if not transform_columns:
        # Nothing to scale; sklearn rejects inputs with zero features.
        return train_dataset, test_dataset, validation_dataset

    # Transform the datasets
    scaler = StandardScaler()
    scaler.mean_ = statistics["mean"][transform_columns].to_numpy()
    scaler.scale_ = statistics["std"][transform_columns].to_numpy()

    for frame in (train_dataset, test_dataset, validation_dataset):
        frame[transform_columns] = scaler.transform(frame[transform_columns])

    return train_dataset, test_dataset, validation_dataset
SklearnStandardScalerConfig (BasePreprocessorConfig) pydantic-model

Config class for the sklearn standard scaler.

ignore_columns: a list of column names which should not be scaled

exclude_columns: a list of column names to be excluded from the standard scaling

Source code in zenml/integrations/sklearn/steps/sklearn_standard_scaler.py
class SklearnStandardScalerConfig(BasePreprocessorConfig):
    """Configuration for the sklearn standard scaler step.

    ignore_columns: names of columns that should be left unscaled
    exclude_columns: names of columns left out of the scaled feature set
        (the scaler does not drop them from the returned datasets)
    """

    # Columns kept as-is by the scaler.
    ignore_columns: List[str] = []
    # Columns removed from the candidate feature set before scaling.
    exclude_columns: List[str] = []

slack special

Slack integration for alerter components.

SlackIntegration (Integration)

Definition of a Slack integration for ZenML.

Implemented using Slack SDK.

Source code in zenml/integrations/slack/__init__.py
class SlackIntegration(Integration):
    """Definition of a Slack integration for ZenML.

    Implemented using [Slack SDK](https://pypi.org/project/slack-sdk/).
    """

    NAME = SLACK
    REQUIREMENTS = ["slack-sdk>=3.16.1", "aiohttp>=3.8.1"]

    @classmethod
    def flavors(cls) -> List[FlavorWrapper]:
        """List the stack component flavors contributed by Slack.

        Returns:
            A single-element list holding the Slack alerter flavor.
        """
        alerter_flavor = FlavorWrapper(
            name=SLACK_ALERTER_FLAVOR,
            source="zenml.integrations.slack.alerters.slack_alerter.SlackAlerter",
            type=StackComponentType.ALERTER,
            integration=cls.NAME,
        )
        return [alerter_flavor]
flavors() classmethod

Declare the stack component flavors for the Slack integration.

Returns:

Type Description
List[zenml.zen_stores.models.flavor_wrapper.FlavorWrapper]

List of new flavors defined by the Slack integration.

Source code in zenml/integrations/slack/__init__.py
@classmethod
def flavors(cls) -> List[FlavorWrapper]:
    """List the stack component flavors contributed by Slack.

    Returns:
        A single-element list holding the Slack alerter flavor.
    """
    alerter_flavor = FlavorWrapper(
        name=SLACK_ALERTER_FLAVOR,
        source="zenml.integrations.slack.alerters.slack_alerter.SlackAlerter",
        type=StackComponentType.ALERTER,
        integration=cls.NAME,
    )
    return [alerter_flavor]

alerters special

Alerter components defined by the Slack integration.

slack_alerter

Implementation for slack flavor of alerter component.

SlackAlerter (BaseAlerter) pydantic-model

Send messages to Slack channels.

Attributes:

Name Type Description
slack_token str

The Slack token tied to the Slack account to be used.

Source code in zenml/integrations/slack/alerters/slack_alerter.py
class SlackAlerter(BaseAlerter):
    """Send messages to Slack channels.

    Attributes:
        slack_token: The Slack token tied to the Slack account to be used.
        default_slack_channel_id: ID of the Slack channel used when the
            runtime configuration does not specify one.
    """

    slack_token: str
    default_slack_channel_id: Optional[str] = None

    # Class Configuration
    FLAVOR: ClassVar[str] = SLACK_ALERTER_FLAVOR

    def _get_channel_id(self, config: Optional[BaseAlerterStepConfig]) -> str:
        """Get the Slack channel ID to be used by post/ask.

        A `slack_channel_id` set on a `SlackAlerterConfig` takes precedence
        over the component's `default_slack_channel_id`. A `None` config
        falls back to the default channel.

        Args:
            config: Optional runtime configuration.

        Returns:
            ID of the Slack channel to be used.

        Raises:
            RuntimeError: if config is given but is not of type
                `BaseAlerterStepConfig`.
            ValueError: if a slack channel was neither defined in the config
                nor in the slack alerter component.
        """
        # `config` is Optional in the signature, so reject only objects of
        # the wrong type, not a missing config.
        if config is not None and not isinstance(
            config, BaseAlerterStepConfig
        ):
            raise RuntimeError(
                "The config object must be of type `BaseAlerterStepConfig`."
            )
        if (
            isinstance(config, SlackAlerterConfig)
            and config.slack_channel_id is not None
        ):
            return config.slack_channel_id
        if self.default_slack_channel_id is not None:
            return self.default_slack_channel_id
        raise ValueError(
            "Neither the `SlackAlerterConfig.slack_channel_id` in the runtime "
            "configuration, nor the `default_slack_channel_id` in the alerter "
            "stack component is specified. Please specify at least one."
        )

    def _get_approve_msg_options(
        self, config: Optional[BaseAlerterStepConfig]
    ) -> List[str]:
        """Define which messages will lead to approval during ask().

        Args:
            config: Optional runtime configuration.

        Returns:
            The `approve_msg_options` of a `SlackAlerterConfig` if set,
            otherwise `DEFAULT_APPROVE_MSG_OPTIONS`.
        """
        if (
            isinstance(config, SlackAlerterConfig)
            and config.approve_msg_options is not None
        ):
            return config.approve_msg_options
        return DEFAULT_APPROVE_MSG_OPTIONS

    def _get_disapprove_msg_options(
        self, config: Optional[BaseAlerterStepConfig]
    ) -> List[str]:
        """Define which messages will lead to disapproval during ask().

        Args:
            config: Optional runtime configuration.

        Returns:
            The `disapprove_msg_options` of a `SlackAlerterConfig` if set,
            otherwise `DEFAULT_DISAPPROVE_MSG_OPTIONS`.
        """
        if (
            isinstance(config, SlackAlerterConfig)
            and config.disapprove_msg_options is not None
        ):
            return config.disapprove_msg_options
        return DEFAULT_DISAPPROVE_MSG_OPTIONS

    def post(
        self, message: str, config: Optional[BaseAlerterStepConfig]
    ) -> bool:
        """Post a message to a Slack channel.

        Args:
            message: Message to be posted.
            config: Optional runtime configuration.

        Returns:
            True if operation succeeded, else False
        """
        slack_channel_id = self._get_channel_id(config=config)
        client = WebClient(token=self.slack_token)
        try:
            client.chat_postMessage(channel=slack_channel_id, text=message)
        except SlackApiError as error:
            error_message = error.response["error"]
            logger.error(f"SlackAlerter.post() failed: {error_message}")
            return False
        return True

    def ask(
        self, message: str, config: Optional[BaseAlerterStepConfig]
    ) -> bool:
        """Post a message to a Slack channel and wait for approval.

        Blocks in `rtm.start()` until a message in the channel matches one of
        the approve or disapprove options.

        Args:
            message: Initial message to be posted.
            config: Optional runtime configuration.

        Returns:
            True if a user approved the operation, else False
        """
        rtm = RTMClient(token=self.slack_token)
        slack_channel_id = self._get_channel_id(config=config)
        approve_options = self._get_approve_msg_options(config)
        disapprove_options = self._get_disapprove_msg_options(config)

        approved = False  # will be modified by handle()

        @RTMClient.run_on(event="hello")  # type: ignore
        def post_initial_message(**payload: Any) -> None:
            """Post an initial message in a channel and start listening.

            Args:
                payload: payload of the received Slack event.
            """
            web_client = payload["web_client"]
            web_client.chat_postMessage(channel=slack_channel_id, text=message)

        @RTMClient.run_on(event="message")  # type: ignore
        def handle(**payload: Any) -> None:
            """Listen / handle messages posted in the channel.

            Args:
                payload: payload of the received Slack event.
            """
            nonlocal approved
            event = payload["data"]
            # Use `.get` so that message events lacking these keys
            # (presumably edit/delete subtypes; confirm against Slack's RTM
            # docs) are ignored instead of raising inside the event loop.
            if event.get("channel") != slack_channel_id:
                return
            text = event.get("text")

            # approve request (return True)
            if text in approve_options:
                print(f"User {event.get('user')} approved on slack.")
                approved = True
                rtm.stop()  # type: ignore

            # disapprove request (return False)
            elif text in disapprove_options:
                print(f"User {event.get('user')} disapproved on slack.")
                rtm.stop()  # type:ignore

        # start another thread until `rtm.stop()` is called in handle()
        rtm.start()

        return approved
ask(self, message, config)

Post a message to a Slack channel and wait for approval.

Parameters:

Name Type Description Default
message str

Initial message to be posted.

required
config Optional[zenml.steps.step_interfaces.base_alerter_step.BaseAlerterStepConfig]

Optional runtime configuration.

required

Returns:

Type Description
bool

True if a user approved the operation, else False

Source code in zenml/integrations/slack/alerters/slack_alerter.py
def ask(
    self, message: str, config: Optional[BaseAlerterStepConfig]
) -> bool:
    """Post a message to a Slack channel and wait for approval.

    Blocks in `rtm.start()` until a message in the channel matches one of the
    approve or disapprove options.

    Args:
        message: Initial message to be posted.
        config: Optional runtime configuration.

    Returns:
        True if a user approved the operation, else False
    """
    rtm = RTMClient(token=self.slack_token)
    slack_channel_id = self._get_channel_id(config=config)
    approve_options = self._get_approve_msg_options(config)
    disapprove_options = self._get_disapprove_msg_options(config)

    approved = False  # will be modified by handle()

    @RTMClient.run_on(event="hello")  # type: ignore
    def post_initial_message(**payload: Any) -> None:
        """Post an initial message in a channel and start listening.

        Args:
            payload: payload of the received Slack event.
        """
        web_client = payload["web_client"]
        web_client.chat_postMessage(channel=slack_channel_id, text=message)

    @RTMClient.run_on(event="message")  # type: ignore
    def handle(**payload: Any) -> None:
        """Listen / handle messages posted in the channel.

        Args:
            payload: payload of the received Slack event.
        """
        nonlocal approved
        event = payload["data"]
        # Use `.get` so that message events lacking these keys (presumably
        # edit/delete subtypes; confirm against Slack's RTM docs) are ignored
        # instead of raising inside the event loop.
        if event.get("channel") != slack_channel_id:
            return
        text = event.get("text")

        # approve request (return True)
        if text in approve_options:
            print(f"User {event.get('user')} approved on slack.")
            approved = True
            rtm.stop()  # type: ignore

        # disapprove request (return False)
        elif text in disapprove_options:
            print(f"User {event.get('user')} disapproved on slack.")
            rtm.stop()  # type:ignore

    # start another thread until `rtm.stop()` is called in handle()
    rtm.start()

    return approved
post(self, message, config)

Post a message to a Slack channel.

Parameters:

Name Type Description Default
message str

Message to be posted.

required
config Optional[zenml.steps.step_interfaces.base_alerter_step.BaseAlerterStepConfig]

Optional runtime configuration.

required

Returns:

Type Description
bool

True if operation succeeded, else False

Source code in zenml/integrations/slack/alerters/slack_alerter.py
def post(
    self, message: str, config: Optional[BaseAlerterStepConfig]
) -> bool:
    """Post a message to a Slack channel.

    Args:
        message: Message to be posted.
        config: Optional runtime configuration.

    Returns:
        True if operation succeeded, else False
    """
    slack_channel_id = self._get_channel_id(config=config)
    client = WebClient(token=self.slack_token)
    try:
        # The API response is not needed; success is signalled by the
        # absence of a SlackApiError.
        client.chat_postMessage(channel=slack_channel_id, text=message)
    except SlackApiError as error:
        error_message = error.response["error"]
        logger.error(f"SlackAlerter.post() failed: {error_message}")
        return False
    return True
SlackAlerterConfig (BaseAlerterStepConfig) pydantic-model

Slack alerter config.

Source code in zenml/integrations/slack/alerters/slack_alerter.py
class SlackAlerterConfig(BaseAlerterStepConfig):
    """Runtime configuration for the Slack alerter.

    Every field is optional. Fields left as `None` make the alerter fall back
    to its component-level or built-in defaults.
    """

    # ID of the Slack channel to communicate in. When set, it overrides the
    # alerter's `default_slack_channel_id`.
    slack_channel_id: Optional[str] = None

    # Replies that count as approval in alerter.ask().
    approve_msg_options: Optional[List[str]] = None

    # Replies that count as disapproval in alerter.ask().
    disapprove_msg_options: Optional[List[str]] = None

tensorflow special

Initialization for TensorFlow integration.

TensorflowIntegration (Integration)

Definition of Tensorflow integration for ZenML.

Source code in zenml/integrations/tensorflow/__init__.py
class TensorflowIntegration(Integration):
    """Definition of Tensorflow integration for ZenML."""

    NAME = TENSORFLOW
    REQUIREMENTS = ["tensorflow==2.8.0", "tensorflow_io==0.24.0"]

    @classmethod
    def activate(cls) -> None:
        """Activates the integration.

        Imports `tensorflow_io` for its side effect, then imports the
        `materializers` and `services` submodules of this integration.
        """
        # Imported only for its side effect: it loads the Tensorflow file IO
        # support for S3 and other file systems.
        import tensorflow_io  # type: ignore [import]

        from zenml.integrations.tensorflow import (  # noqa
            materializers,
            services,
        )
activate() classmethod

Activates the integration.

Source code in zenml/integrations/tensorflow/__init__.py
@classmethod
def activate(cls) -> None:
    """Activates the integration.

    Imports `tensorflow_io` for its side effect, then imports the
    `materializers` and `services` submodules of this integration.
    """
    # Imported only for its side effect: it loads the Tensorflow file IO
    # support for S3 and other file systems.
    import tensorflow_io  # type: ignore [import]

    from zenml.integrations.tensorflow import (  # noqa
        materializers,
        services,
    )

materializers special

Initialization for the TensorFlow materializers.

keras_materializer

Implementation of the TensorFlow Keras materializer.

KerasMaterializer (BaseMaterializer)

Materializer to read/write Keras models.

Source code in zenml/integrations/tensorflow/materializers/keras_materializer.py
class KerasMaterializer(BaseMaterializer):
    """Materializer to read/write Keras models."""

    ASSOCIATED_TYPES = (keras.Model,)
    ASSOCIATED_ARTIFACT_TYPES = (ModelArtifact,)

    def handle_input(self, data_type: Type[Any]) -> keras.Model:
        """Reads and returns a Keras model after copying it to temporary path.

        Args:
            data_type: The type of the data to read.

        Returns:
            A tf.keras.Model model.
        """
        super().handle_input(data_type)

        # The context manager deletes the temporary directory on exit, even
        # if copying or loading the model raises.
        with tempfile.TemporaryDirectory() as temp_dir:
            # Copy from artifact store to temporary directory
            io_utils.copy_dir(self.artifact.uri, temp_dir)

            # Load the model from the temporary directory
            model = keras.models.load_model(temp_dir)

        return model

    def handle_return(self, model: keras.Model) -> None:
        """Writes a keras model to the artifact store.

        The model is saved to a local temporary directory first and then
        copied into the artifact URI.

        Args:
            model: A tf.keras.Model model.
        """
        super().handle_return(model)

        # The context manager deletes the temporary directory on exit, even
        # if saving or copying the model raises.
        with tempfile.TemporaryDirectory() as temp_dir:
            model.save(temp_dir)
            io_utils.copy_dir(temp_dir, self.artifact.uri)
handle_input(self, data_type)

Reads and returns a Keras model after copying it to temporary path.

Parameters:

Name Type Description Default
data_type Type[Any]

The type of the data to read.

required

Returns:

Type Description
Model

A tf.keras.Model model.

Source code in zenml/integrations/tensorflow/materializers/keras_materializer.py
def handle_input(self, data_type: Type[Any]) -> keras.Model:
    """Reads and returns a Keras model after copying it to temporary path.

    Args:
        data_type: The type of the data to read.

    Returns:
        A tf.keras.Model model.
    """
    super().handle_input(data_type)

    # The context manager deletes the temporary directory on exit, even if
    # copying or loading the model raises.
    with tempfile.TemporaryDirectory() as temp_dir:
        # Copy from artifact store to temporary directory
        io_utils.copy_dir(self.artifact.uri, temp_dir)

        # Load the model from the temporary directory
        model = keras.models.load_model(temp_dir)

    return model
handle_return(self, model)

Writes a keras model to the artifact store.

Parameters:

Name Type Description Default
model Model

A tf.keras.Model model.

required
Source code in zenml/integrations/tensorflow/materializers/keras_materializer.py
def handle_return(self, model: keras.Model) -> None:
    """Writes a keras model to the artifact store.

    The model is first saved into a local temporary directory and then
    copied into the artifact URI. The temporary directory is removed even
    if saving or copying fails.

    Args:
        model: A tf.keras.Model model.
    """
    super().handle_return(model)

    # The context manager guarantees the temporary directory is deleted
    # even when `model.save` or the copy raises, instead of leaving the
    # cleanup to garbage collection of the TemporaryDirectory object.
    with tempfile.TemporaryDirectory() as temp_dir:
        model.save(temp_dir)
        io_utils.copy_dir(temp_dir, self.artifact.uri)
tf_dataset_materializer

Implementation of the TensorFlow dataset materializer.

TensorflowDatasetMaterializer (BaseMaterializer)

Materializer to read and write tf.data.Dataset objects.

Source code in zenml/integrations/tensorflow/materializers/tf_dataset_materializer.py
class TensorflowDatasetMaterializer(BaseMaterializer):
    """Materializer that reads and writes `tf.data.Dataset` objects.

    The dataset lives under `DEFAULT_FILENAME` inside the artifact URI and is
    handled with `tf.data.experimental.save` / `tf.data.experimental.load`.
    """

    ASSOCIATED_TYPES = (tf.data.Dataset,)
    ASSOCIATED_ARTIFACT_TYPES = (DataArtifact,)

    def handle_input(self, data_type: Type[Any]) -> Any:
        """Load the persisted dataset from the artifact store.

        Args:
            data_type: The type of the data to read.

        Returns:
            The `tf.data.Dataset` object stored under the artifact URI.
        """
        super().handle_input(data_type)
        dataset_path = os.path.join(self.artifact.uri, DEFAULT_FILENAME)
        loaded_dataset = tf.data.experimental.load(dataset_path)
        return loaded_dataset

    def handle_return(self, dataset: tf.data.Dataset) -> None:
        """Persist a dataset into the artifact store.

        Args:
            dataset: The dataset to persist.
        """
        super().handle_return(dataset)
        target_path = os.path.join(self.artifact.uri, DEFAULT_FILENAME)
        # Saved uncompressed and without a custom sharding function.
        tf.data.experimental.save(
            dataset,
            target_path,
            compression=None,
            shard_func=None,
        )
handle_input(self, data_type)

Reads data into tf.data.Dataset.

Parameters:

Name Type Description Default
data_type Type[Any]

The type of the data to read.

required

Returns:

Type Description
Any

A tf.data.Dataset object.

Source code in zenml/integrations/tensorflow/materializers/tf_dataset_materializer.py
def handle_input(self, data_type: Type[Any]) -> Any:
    """Load the persisted `tf.data.Dataset` from the artifact store.

    Args:
        data_type: The type of the data to read.

    Returns:
        The `tf.data.Dataset` object stored under the artifact URI.
    """
    super().handle_input(data_type)
    dataset_path = os.path.join(self.artifact.uri, DEFAULT_FILENAME)
    loaded_dataset = tf.data.experimental.load(dataset_path)
    return loaded_dataset
handle_return(self, dataset)

Persists a tf.data.Dataset object.

Parameters:

Name Type Description Default
dataset DatasetV2

The dataset to persist.

required
Source code in zenml/integrations/tensorflow/materializers/tf_dataset_materializer.py
def handle_return(self, dataset: tf.data.Dataset) -> None:
    """Persist a `tf.data.Dataset` into the artifact store.

    Args:
        dataset: The dataset to persist.
    """
    super().handle_return(dataset)
    target_path = os.path.join(self.artifact.uri, DEFAULT_FILENAME)
    # Saved uncompressed and without a custom sharding function.
    tf.data.experimental.save(
        dataset,
        target_path,
        compression=None,
        shard_func=None,
    )

services special

Initialization for TensorFlow services.

tensorboard_service

Implementation of the TensorBoard service.

TensorboardService (LocalDaemonService) pydantic-model

TensorBoard service.

This can be used to start a local TensorBoard server for one or more models.

Attributes:

Name Type Description
SERVICE_TYPE ClassVar[zenml.services.service_type.ServiceType]

a service type descriptor with information describing the TensorBoard service class

config TensorboardServiceConfig

service configuration

endpoint LocalDaemonServiceEndpoint

optional service endpoint

Source code in zenml/integrations/tensorflow/services/tensorboard_service.py
class TensorboardService(LocalDaemonService):
    """Local daemon service that runs a TensorBoard server.

    The server serves the TensorBoard logs found under the configured logdir,
    which may hold logs for one or more models.

    Attributes:
        SERVICE_TYPE: a service type descriptor with information describing
            the TensorBoard service class
        config: service configuration
        endpoint: optional service endpoint
    """

    SERVICE_TYPE = ServiceType(
        name="tensorboard",
        type="visualization",
        flavor="tensorboard",
        description="TensorBoard visualization service",
    )

    config: TensorboardServiceConfig
    endpoint: LocalDaemonServiceEndpoint

    def __init__(
        self,
        config: Union[TensorboardServiceConfig, Dict[str, Any]],
        **attrs: Any,
    ) -> None:
        """Create the service, adding a default HTTP endpoint when needed.

        Args:
            config: service configuration
            **attrs: additional attributes
        """
        # The endpoint has to exist before the base service is initialized,
        # so a default one is built when a config object is given without an
        # explicit endpoint.
        # TODO [ENG-697]: implement a service factory or builder for TensorBoard
        #   deployment services
        needs_default_endpoint = (
            isinstance(config, TensorboardServiceConfig)
            and "endpoint" not in attrs
        )
        if needs_default_endpoint:
            endpoint_config = LocalDaemonServiceEndpointConfig(
                protocol=ServiceEndpointProtocol.HTTP,
            )
            health_monitor = HTTPEndpointHealthMonitor(
                config=HTTPEndpointHealthMonitorConfig(
                    healthcheck_uri_path="",
                    use_head_request=True,
                )
            )
            attrs["endpoint"] = LocalDaemonServiceEndpoint(
                config=endpoint_config,
                monitor=health_monitor,
            )
        super().__init__(config=config, **attrs)

    def run(self) -> None:
        """Start the TensorBoard server and block until it is interrupted."""
        logger.info(
            "Starting TensorBoard service as blocking "
            "process... press CTRL+C once to stop it."
        )

        self.endpoint.prepare_for_start()

        try:
            server = program.TensorBoard(
                plugins=default.get_plugins(),
                subcommands=[uploader_subcommand.UploaderSubcommand()],
            )
            service_config = self.config
            server.configure(
                logdir=service_config.logdir,
                port=self.endpoint.status.port,
                host="localhost",
                max_reload_threads=service_config.max_reload_threads,
                reload_interval=service_config.reload_interval,
            )
            server.main()
        except KeyboardInterrupt:
            # CTRL+C ends the blocking server; log and return normally.
            logger.info(
                "TensorBoard service stopped. Resuming normal execution."
            )
__init__(self, config, **attrs) special

Initialization for TensorBoard service.

Parameters:

Name Type Description Default
config Union[zenml.integrations.tensorflow.services.tensorboard_service.TensorboardServiceConfig, Dict[str, Any]]

service configuration

required
**attrs Any

additional attributes

{}
Source code in zenml/integrations/tensorflow/services/tensorboard_service.py
def __init__(
    self,
    config: Union[TensorboardServiceConfig, Dict[str, Any]],
    **attrs: Any,
) -> None:
    """Create the service, adding a default HTTP endpoint when needed.

    Args:
        config: service configuration
        **attrs: additional attributes
    """
    # The endpoint has to exist before the base service is initialized, so a
    # default one is built when a config object is given without an explicit
    # endpoint.
    # TODO [ENG-697]: implement a service factory or builder for TensorBoard
    #   deployment services
    needs_default_endpoint = (
        isinstance(config, TensorboardServiceConfig)
        and "endpoint" not in attrs
    )
    if needs_default_endpoint:
        endpoint_config = LocalDaemonServiceEndpointConfig(
            protocol=ServiceEndpointProtocol.HTTP,
        )
        health_monitor = HTTPEndpointHealthMonitor(
            config=HTTPEndpointHealthMonitorConfig(
                healthcheck_uri_path="",
                use_head_request=True,
            )
        )
        attrs["endpoint"] = LocalDaemonServiceEndpoint(
            config=endpoint_config,
            monitor=health_monitor,
        )
    super().__init__(config=config, **attrs)
run(self)

Initialize and run the TensorBoard server.

Source code in zenml/integrations/tensorflow/services/tensorboard_service.py
def run(self) -> None:
    """Start the TensorBoard server and block until it is interrupted."""
    logger.info(
        "Starting TensorBoard service as blocking "
        "process... press CTRL+C once to stop it."
    )

    self.endpoint.prepare_for_start()

    try:
        server = program.TensorBoard(
            plugins=default.get_plugins(),
            subcommands=[uploader_subcommand.UploaderSubcommand()],
        )
        service_config = self.config
        server.configure(
            logdir=service_config.logdir,
            port=self.endpoint.status.port,
            host="localhost",
            max_reload_threads=service_config.max_reload_threads,
            reload_interval=service_config.reload_interval,
        )
        server.main()
    except KeyboardInterrupt:
        # CTRL+C ends the blocking server; log and return normally.
        logger.info(
            "TensorBoard service stopped. Resuming normal execution."
        )
TensorboardServiceConfig (LocalDaemonServiceConfig) pydantic-model

TensorBoard service configuration.

Attributes:

Name Type Description
logdir str

location of TensorBoard log files.

max_reload_threads int

the max number of threads that TensorBoard can use to reload runs. Each thread reloads one run at a time.

reload_interval int

how often the backend should load more data, in seconds. Set to 0 to load just once at startup.

Source code in zenml/integrations/tensorflow/services/tensorboard_service.py
class TensorboardServiceConfig(LocalDaemonServiceConfig):
    """TensorBoard service configuration.

    The TensorBoard service passes these values to the TensorBoard server's
    `configure` call when it starts.

    Attributes:
        logdir: location of TensorBoard log files.
        max_reload_threads: the max number of threads that TensorBoard can use
            to reload runs. Each thread reloads one run at a time.
        reload_interval: how often the backend should load more data, in
            seconds. Set to 0 to load just once at startup.
    """

    # Required: there is no default log directory.
    logdir: str
    max_reload_threads: int = 1
    reload_interval: int = 5

steps special

Initialization for TensorFlow standard steps.

tensorflow_trainer

Implementation of a TensorFlow trainer step.

TensorflowBinaryClassifier (BaseTrainerStep)

A TensorFlow binary classifier.

This simple step implementation creates a TensorFlow feedforward neural network and trains it on a given pd.DataFrame dataset.

Source code in zenml/integrations/tensorflow/steps/tensorflow_trainer.py
class TensorflowBinaryClassifier(BaseTrainerStep):
    """A TensorFlow binary classifier.

    This step creates a TensorFlow feedforward neural network and trains it
    on a given pd.DataFrame dataset.
    """

    def entrypoint(  # type: ignore[override]
        self,
        train_dataset: pd.DataFrame,
        validation_dataset: pd.DataFrame,
        config: TensorflowBinaryClassifierConfig,
    ) -> tf.keras.Model:
        """Main entrypoint for the tensorflow trainer.

        Layer sizes come from `config.layers`:
        - every entry except the last becomes a ReLU dense hidden layer;
        - the last entry is the size of the sigmoid output layer.

        The model is trained with Adam and binary cross-entropy. Neither
        `config` nor the input DataFrames are modified.

        Args:
            train_dataset: pd.DataFrame, the training dataset
            validation_dataset: pd.DataFrame, the validation dataset
            config: the configuration of the step

        Returns:
            the trained tf.keras.Model
        """
        # Index instead of `pop()` so `config.layers` is not shrunk each time
        # the step runs with the same config object. An empty list still
        # raises IndexError here.
        last_layer = config.layers[-1]
        hidden_layers = config.layers[:-1]

        model = tf.keras.Sequential()
        model.add(tf.keras.layers.InputLayer(input_shape=config.input_shape))
        model.add(tf.keras.layers.Flatten())
        for units in hidden_layers:
            model.add(tf.keras.layers.Dense(units, activation="relu"))
        model.add(tf.keras.layers.Dense(last_layer, activation="sigmoid"))

        model.compile(
            optimizer=tf.keras.optimizers.Adam(config.learning_rate),
            loss=tf.keras.losses.BinaryCrossentropy(),
            metrics=config.metrics,
        )

        # Split features and labels without removing the label column from
        # the caller's DataFrames (`DataFrame.pop` works in place).
        target = config.target_column
        train_target = train_dataset[target]
        train_features = train_dataset.drop(columns=[target])
        validation_target = validation_dataset[target]
        validation_features = validation_dataset.drop(columns=[target])

        model.fit(
            x=train_features,
            y=train_target,
            validation_data=(validation_features, validation_target),
            batch_size=config.batch_size,
            epochs=config.epochs,
        )
        model.summary()

        return model
CONFIG_CLASS (BaseTrainerConfig) pydantic-model

Config class for the tensorflow trainer.

Attributes: target_column: the name of the label column; layers: the number of units in the fully connected layers; input_shape: the shape of the input; learning_rate: the learning rate; metrics: the list of metrics to be computed; epochs: the number of epochs; batch_size: the size of the batch.

Source code in zenml/integrations/tensorflow/steps/tensorflow_trainer.py
class TensorflowBinaryClassifierConfig(BaseTrainerConfig):
    """Config class for the tensorflow trainer.

    Attributes:
        target_column: the name of the label column
        layers: the number of units in each fully connected layer; the last
            entry is the size of the output layer
        input_shape: the shape of a single input sample
        learning_rate: the learning rate
        metrics: the list of metrics to be computed
        epochs: the number of epochs
        batch_size: the size of the batch
    """

    target_column: str
    layers: List[int] = [256, 64, 1]
    # `Tuple[int, ...]` accepts input shapes of any rank; `Tuple[int]` only
    # validates one-element tuples such as the default `(8,)`.
    input_shape: Tuple[int, ...] = (8,)
    learning_rate: float = 0.001
    metrics: List[str] = ["accuracy"]
    epochs: int = 50
    batch_size: int = 8
entrypoint(self, train_dataset, validation_dataset, config)

Main entrypoint for the tensorflow trainer.

Parameters:

Name Type Description Default
train_dataset DataFrame

pd.DataFrame, the training dataset

required
validation_dataset DataFrame

pd.DataFrame, the validation dataset

required
config TensorflowBinaryClassifierConfig

the configuration of the step

required

Returns:

Type Description
Model

the trained tf.keras.Model

Source code in zenml/integrations/tensorflow/steps/tensorflow_trainer.py
def entrypoint(  # type: ignore[override]
    self,
    train_dataset: pd.DataFrame,
    validation_dataset: pd.DataFrame,
    config: TensorflowBinaryClassifierConfig,
) -> tf.keras.Model:
    """Main entrypoint for the tensorflow trainer.

    Layer sizes come from `config.layers`:
    - every entry except the last becomes a ReLU dense hidden layer;
    - the last entry is the size of the sigmoid output layer.

    The model is trained with Adam and binary cross-entropy. Neither `config`
    nor the input DataFrames are modified.

    Args:
        train_dataset: pd.DataFrame, the training dataset
        validation_dataset: pd.DataFrame, the validation dataset
        config: the configuration of the step

    Returns:
        the trained tf.keras.Model
    """
    # Index instead of `pop()` so `config.layers` is not shrunk each time the
    # step runs with the same config object. An empty list still raises
    # IndexError here.
    last_layer = config.layers[-1]
    hidden_layers = config.layers[:-1]

    model = tf.keras.Sequential()
    model.add(tf.keras.layers.InputLayer(input_shape=config.input_shape))
    model.add(tf.keras.layers.Flatten())
    for units in hidden_layers:
        model.add(tf.keras.layers.Dense(units, activation="relu"))
    model.add(tf.keras.layers.Dense(last_layer, activation="sigmoid"))

    model.compile(
        optimizer=tf.keras.optimizers.Adam(config.learning_rate),
        loss=tf.keras.losses.BinaryCrossentropy(),
        metrics=config.metrics,
    )

    # Split features and labels without removing the label column from the
    # caller's DataFrames (`DataFrame.pop` works in place).
    target = config.target_column
    train_target = train_dataset[target]
    train_features = train_dataset.drop(columns=[target])
    validation_target = validation_dataset[target]
    validation_features = validation_dataset.drop(columns=[target])

    model.fit(
        x=train_features,
        y=train_target,
        validation_data=(validation_features, validation_target),
        batch_size=config.batch_size,
        epochs=config.epochs,
    )
    model.summary()

    return model
TensorflowBinaryClassifierConfig (BaseTrainerConfig) pydantic-model

Config class for the tensorflow trainer.

Attributes: target_column: the name of the label column; layers: the number of units in the fully connected layers; input_shape: the shape of the input; learning_rate: the learning rate; metrics: the list of metrics to be computed; epochs: the number of epochs; batch_size: the size of the batch.

Source code in zenml/integrations/tensorflow/steps/tensorflow_trainer.py
class TensorflowBinaryClassifierConfig(BaseTrainerConfig):
    """Config class for the tensorflow trainer.

    Attributes:
        target_column: the name of the label column
        layers: the number of units in each fully connected layer; the last
            entry is the size of the output layer
        input_shape: the shape of a single input sample
        learning_rate: the learning rate
        metrics: the list of metrics to be computed
        epochs: the number of epochs
        batch_size: the size of the batch
    """

    target_column: str
    layers: List[int] = [256, 64, 1]
    # `Tuple[int, ...]` accepts input shapes of any rank; `Tuple[int]` only
    # validates one-element tuples such as the default `(8,)`.
    input_shape: Tuple[int, ...] = (8,)
    learning_rate: float = 0.001
    metrics: List[str] = ["accuracy"]
    epochs: int = 50
    batch_size: int = 8

visualizers special

Initialization for TensorFlow visualizer.

tensorboard_visualizer

Implementation of a TensorFlow visualizer step.

TensorboardVisualizer (BaseStepVisualizer)

The implementation of a TensorBoard visualizer.

Source code in zenml/integrations/tensorflow/visualizers/tensorboard_visualizer.py
class TensorboardVisualizer(BaseStepVisualizer):
    """The implementation of a TensorBoard visualizer.

    Starts, or reuses, a local TensorBoard server for the model artifacts of
    a pipeline step. It then displays the server inline in a notebook or
    prints its URL.
    """

    @classmethod
    def find_running_tensorboard_server(
        cls, logdir: str
    ) -> Optional[TensorBoardInfo]:
        """Find a local TensorBoard server instance.

        Searches for a TensorBoard server whose logdir matches the supplied
        location and whose process still exists.

        Args:
            logdir: The logdir location where the TensorBoard server is running.

        Returns:
            The TensorBoardInfo describing the running TensorBoard server or
            None if no server is running for the supplied logdir location.
        """
        for server in get_all():
            if (
                server.logdir == logdir
                and server.pid
                and psutil.pid_exists(server.pid)
            ):
                return server
        return None

    def visualize(
        self,
        object: StepView,
        height: int = 800,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        """Start a TensorBoard server.

        Allows for the visualization of all models logged as artifacts by the
        indicated step. The server will monitor and display all the models
        logged by past and future step runs. Only the first model artifact of
        the step is handled.

        Args:
            object: StepView fetched from run.get_step().
            height: Height of the generated visualization.
            *args: Additional arguments.
            **kwargs: Additional keyword arguments.
        """
        for _, artifact_view in object.outputs.items():
            # filter out anything but model artifacts
            if artifact_view.type == ModelArtifact.TYPE_NAME:
                logdir = os.path.dirname(artifact_view.uri)

                # first check if a TensorBoard server is already running for
                # the same logdir location and use that one
                running_server = self.find_running_tensorboard_server(logdir)
                if running_server:
                    self.visualize_tensorboard(running_server.port, height)
                    return

                if sys.platform == "win32":
                    # Daemon service functionality is currently not supported on Windows
                    print(
                        "You can run:\n"
                        f"[italic green]    tensorboard --logdir {logdir}"
                        "[/italic green]\n"
                        "...to visualize the TensorBoard logs for your trained model."
                    )
                else:
                    # start a new TensorBoard server
                    service = TensorboardService(
                        TensorboardServiceConfig(
                            logdir=logdir,
                        )
                    )
                    service.start(timeout=20)
                    if service.endpoint.status.port:
                        self.visualize_tensorboard(
                            service.endpoint.status.port, height
                        )
                return

    def visualize_tensorboard(
        self,
        port: int,
        height: int,
    ) -> None:
        """Generate a visualization of a TensorBoard.

        Args:
            port: the TCP port where the TensorBoard server is listening for
                requests.
            height: Height of the generated visualization.
        """
        if Environment.in_notebook():

            notebook.display(port, height=height)
            return

        print(
            "You can visit:\n"
            f"[italic green]    http://localhost:{port}/[/italic green]\n"
            "...to visualize the TensorBoard logs for your trained model."
        )

    def stop(
        self,
        object: StepView,
    ) -> None:
        """Stop the TensorBoard server previously started for a pipeline step.

        Args:
            object: StepView fetched from run.get_step().
        """
        for _, artifact_view in object.outputs.items():
            # filter out anything but model artifacts
            if artifact_view.type == ModelArtifact.TYPE_NAME:
                logdir = os.path.dirname(artifact_view.uri)

                # look up the TensorBoard server serving this logdir, if any
                running_server = self.find_running_tensorboard_server(logdir)
                if not running_server:
                    return

                logger.debug(
                    "Stopping tensorboard server with PID '%d' ...",
                    running_server.pid,
                )
                # `kill()` is inside the handler too: the process can exit
                # (NoSuchProcess) or refuse the signal (AccessDenied) between
                # the lookup and the kill; both are psutil.Error subclasses.
                try:
                    psutil.Process(running_server.pid).kill()
                except psutil.Error:
                    logger.error(
                        "Could not find process for PID '%d' ...",
                        running_server.pid,
                    )
                    continue
                return
find_running_tensorboard_server(logdir) classmethod

Find a local TensorBoard server instance.

Searches for a TensorBoard server that is running for the supplied logdir location and returns its description, including its TCP port.

Parameters:

Name Type Description Default
logdir str

The logdir location where the TensorBoard server is running.

required

Returns:

Type Description
Optional[tensorboard.manager.TensorBoardInfo]

The TensorBoardInfo describing the running TensorBoard server or None if no server is running for the supplied logdir location.

Source code in zenml/integrations/tensorflow/visualizers/tensorboard_visualizer.py
@classmethod
def find_running_tensorboard_server(
    cls, logdir: str
) -> Optional[TensorBoardInfo]:
    """Find a local TensorBoard server instance.

    Looks for a TensorBoard server whose logdir equals the supplied location
    and whose process is still alive.

    Args:
        logdir: The logdir location where the TensorBoard server is running.

    Returns:
        The TensorBoardInfo describing the running TensorBoard server or
        None if no server is running for the supplied logdir location.
    """
    # First match wins; `next` falls back to None when nothing matches.
    return next(
        (
            info
            for info in get_all()
            if info.logdir == logdir
            and info.pid
            and psutil.pid_exists(info.pid)
        ),
        None,
    )
stop(self, object)

Stop the TensorBoard server previously started for a pipeline step.

Parameters:

Name Type Description Default
object StepView

StepView fetched from run.get_step().

required
Source code in zenml/integrations/tensorflow/visualizers/tensorboard_visualizer.py
def stop(
    self,
    object: StepView,
) -> None:
    """Stop the TensorBoard server previously started for a pipeline step.

    Args:
        object: StepView fetched from run.get_step().
    """
    for _, artifact_view in object.outputs.items():
        # filter out anything but model artifacts
        if artifact_view.type == ModelArtifact.TYPE_NAME:
            logdir = os.path.dirname(artifact_view.uri)

            # look up the TensorBoard server serving this logdir, if any
            running_server = self.find_running_tensorboard_server(logdir)
            if not running_server:
                return

            logger.debug(
                "Stopping tensorboard server with PID '%d' ...",
                running_server.pid,
            )
            # `kill()` is inside the handler too: the process can exit
            # (NoSuchProcess) or refuse the signal (AccessDenied) between the
            # lookup and the kill; both are psutil.Error subclasses.
            try:
                psutil.Process(running_server.pid).kill()
            except psutil.Error:
                logger.error(
                    "Could not find process for PID '%d' ...",
                    running_server.pid,
                )
                continue
            return
visualize(self, object, height=800, *args, **kwargs)

Start a TensorBoard server.

Allows for the visualization of all models logged as artifacts by the indicated step. The server will monitor and display all the models logged by past and future step runs.

Parameters:

Name Type Description Default
object StepView

StepView fetched from run.get_step().

required
height int

Height of the generated visualization.

800
*args Any

Additional arguments.

()
**kwargs Any

Additional keyword arguments.

{}
Source code in zenml/integrations/tensorflow/visualizers/tensorboard_visualizer.py
def visualize(
    self,
    object: StepView,
    height: int = 800,
    *args: Any,
    **kwargs: Any,
) -> None:
    """Start (or reuse) a TensorBoard server for a step's model artifacts.

    The server watches the directory holding the step's model artifacts, so
    it shows models logged by past and future runs of the step. Only the
    first model artifact found among the step outputs is handled.

    Args:
        object: StepView fetched from run.get_step().
        height: Height of the generated visualization.
        *args: Additional arguments.
        **kwargs: Additional keyword arguments.
    """
    for _, artifact_view in object.outputs.items():
        if artifact_view.type != ModelArtifact.TYPE_NAME:
            # only model artifacts are visualized
            continue

        logdir = os.path.dirname(artifact_view.uri)

        # Prefer an already running server for the same logdir.
        existing_server = self.find_running_tensorboard_server(logdir)
        if existing_server:
            self.visualize_tensorboard(existing_server.port, height)
        elif sys.platform == "win32":
            # Daemon service functionality is currently not supported on Windows
            print(
                "You can run:\n"
                f"[italic green]    tensorboard --logdir {logdir}"
                "[/italic green]\n"
                "...to visualize the TensorBoard logs for your trained model."
            )
        else:
            service = TensorboardService(
                TensorboardServiceConfig(logdir=logdir)
            )
            service.start(timeout=20)
            port = service.endpoint.status.port
            if port:
                self.visualize_tensorboard(port, height)
        return
visualize_tensorboard(self, port, height)

Generate a visualization of a TensorBoard.

Parameters:

Name Type Description Default
port int

the TCP port where the TensorBoard server is listening for requests.

required
height int

Height of the generated visualization.

required
Source code in zenml/integrations/tensorflow/visualizers/tensorboard_visualizer.py
def visualize_tensorboard(
    self,
    port: int,
    height: int,
) -> None:
    """Show a running TensorBoard inline, or print its local URL.

    Args:
        port: the TCP port where the TensorBoard server is listening for
            requests.
        height: Height of the generated visualization.
    """
    if not Environment.in_notebook():
        # Outside notebooks the user has to open the URL manually.
        print(
            "You can visit:\n"
            f"[italic green]    http://localhost:{port}/[/italic green]\n"
            "...to visualize the TensorBoard logs for your trained model."
        )
        return

    notebook.display(port, height=height)
get_step(pipeline_name, step_name)

Get the StepView for the specified pipeline and step name.

Parameters:

Name Type Description Default
pipeline_name str

The name of the pipeline.

required
step_name str

The name of the step.

required

Returns:

Type Description
StepView

The StepView for the specified pipeline and step name.

Exceptions:

Type Description
RuntimeError

If the step is not found.

Source code in zenml/integrations/tensorflow/visualizers/tensorboard_visualizer.py
def get_step(pipeline_name: str, step_name: str) -> StepView:
    """Get the StepView for the specified pipeline and step name.

    The step is looked up in the most recent run of the pipeline.

    Args:
        pipeline_name: The name of the pipeline.
        step_name: The name of the step.

    Returns:
        The StepView for the specified pipeline and step name.

    Raises:
        RuntimeError: If the pipeline is not found, has no runs, or the step
            is not found in its last run.
    """
    repo = Repository()
    pipeline = repo.get_pipeline(pipeline_name)
    if pipeline is None:
        raise RuntimeError(f"No pipeline with name `{pipeline_name}` was found")

    # Without this check an empty run list raised a bare IndexError instead
    # of the documented RuntimeError.
    runs = pipeline.runs
    if not runs:
        raise RuntimeError(
            f"Pipeline `{pipeline_name}` does not have any runs yet"
        )

    last_run = runs[-1]
    step = last_run.get_step(step=step_name)
    if step is None:
        raise RuntimeError(
            f"No pipeline step with name `{step_name}` was found in "
            f"pipeline `{pipeline_name}`"
        )
    return cast(StepView, step)
stop_tensorboard_server(pipeline_name, step_name)

Stop the TensorBoard server previously started for a pipeline step.

Parameters:

Name Type Description Default
pipeline_name str

the name of the pipeline

required
step_name str

pipeline step name

required
Source code in zenml/integrations/tensorflow/visualizers/tensorboard_visualizer.py
def stop_tensorboard_server(pipeline_name: str, step_name: str) -> None:
    """Stop the TensorBoard server started for a pipeline step's last run.

    Args:
        pipeline_name: the name of the pipeline
        step_name: pipeline step name
    """
    TensorboardVisualizer().stop(get_step(pipeline_name, step_name))
visualize_tensorboard(pipeline_name, step_name)

Start a TensorBoard server.

Allows for the visualization of all models logged as output by the named pipeline step. The server will monitor and display all the models logged by past and future step runs.

Parameters:

Name Type Description Default
pipeline_name str

the name of the pipeline

required
step_name str

pipeline step name

required
Source code in zenml/integrations/tensorflow/visualizers/tensorboard_visualizer.py
def visualize_tensorboard(pipeline_name: str, step_name: str) -> None:
    """Start a TensorBoard server for a pipeline step's model outputs.

    The server watches the directory holding the step's model artifacts, so
    it displays models logged by past and future runs of that step.

    Args:
        pipeline_name: the name of the pipeline
        step_name: pipeline step name
    """
    TensorboardVisualizer().visualize(get_step(pipeline_name, step_name))

utils

Utility functions for the integrations module.

get_integration_for_module(module_name)

Gets the integration class for a module inside an integration.

If the module given by module_name is not part of a ZenML integration, this method will return None. If it is part of a ZenML integration, it will return the integration class found inside the integration init file.

Parameters:

Name Type Description Default
module_name str

The name of the module to get the integration for.

required

Returns:

Type Description
Optional[Type[zenml.integrations.integration.Integration]]

The integration class for the module.

Source code in zenml/integrations/utils.py
def get_integration_for_module(
    module_name: str,
) -> Optional[Type[Integration]]:
    """Gets the integration class for a module inside an integration.

    If the module given by `module_name` is not part of a ZenML integration,
    this method will return `None`. If it is part of a ZenML integration,
    it will return the integration class found inside the integration
    __init__ file.

    Args:
        module_name: The name of the module to get the integration for.

    Returns:
        The integration class for the module, or `None` if the module is not
        inside a ZenML integration package or that package does not define an
        `Integration` subclass.
    """
    integration_prefix = "zenml.integrations."
    if not module_name.startswith(integration_prefix):
        return None

    # Reduce e.g. "zenml.integrations.vault.secrets_manager" to the
    # integration package "zenml.integrations.vault".
    integration_name = module_name[len(integration_prefix):].split(".", 1)[0]
    if not integration_name:
        # Malformed name such as "zenml.integrations." -> no integration.
        return None
    integration_module_name = integration_prefix + integration_name

    try:
        integration_module = sys.modules[integration_module_name]
    except KeyError:
        integration_module = importlib.import_module(integration_module_name)

    for _, member in inspect.getmembers(integration_module):
        # Skip the `Integration` base class itself (it may be re-exported by
        # the integration package) and pick the first concrete subclass.
        if (
            member is not Integration
            and isinstance(member, IntegrationMeta)
            and issubclass(member, Integration)
        ):
            return cast(Type[Integration], member)

    return None

get_requirements_for_module(module_name)

Gets requirements for a module inside an integration.

If the module given by `module_name` is not part of a ZenML integration, this method will return an empty list. If it is part of a ZenML integration, it will return the list of requirements specified inside the integration class found inside the integration's `__init__` file.

Parameters:

Name Type Description Default
module_name str

The name of the module to get requirements for.

required

Returns:

Type Description
List[str]

A list of requirements for the module.

Source code in zenml/integrations/utils.py
def get_requirements_for_module(module_name: str) -> List[str]:
    """Look up the pip requirements of the integration that owns a module.

    Modules outside any ZenML integration have no requirements, so an empty
    list is returned for them. Otherwise the `REQUIREMENTS` declared on the
    integration class (found in the integration's __init__ file) is returned.

    Args:
        module_name: The name of the module to get requirements for.

    Returns:
        A list of requirements for the module.
    """
    owning_integration = get_integration_for_module(module_name)
    if not owning_integration:
        return []
    return owning_integration.REQUIREMENTS

vault special

Initialization for the Vault Secrets Manager integration.

The Vault secrets manager integration submodule provides a way to access the HashiCorp Vault secrets manager from within your ZenML pipeline runs.

VaultSecretsManagerIntegration (Integration)

Definition of HashiCorp Vault integration with ZenML.

Source code in zenml/integrations/vault/__init__.py
class VaultSecretsManagerIntegration(Integration):
    """ZenML integration for the HashiCorp Vault secrets manager."""

    NAME = VAULT
    REQUIREMENTS = ["hvac>=0.11.2"]

    @classmethod
    def flavors(cls) -> List[FlavorWrapper]:
        """Return the stack component flavors provided by the Vault integration.

        Returns:
            List of stack component flavors.
        """
        secrets_manager_flavor = FlavorWrapper(
            name=VAULT_SECRETS_MANAGER_FLAVOR,
            source="zenml.integrations.vault.secrets_manager.VaultSecretsManager",
            type=StackComponentType.SECRETS_MANAGER,
            integration=cls.NAME,
        )
        return [secrets_manager_flavor]
flavors() classmethod

Declare the stack component flavors for the Vault integration.

Returns:

Type Description
List[zenml.zen_stores.models.flavor_wrapper.FlavorWrapper]

List of stack component flavors.

Source code in zenml/integrations/vault/__init__.py
@classmethod
def flavors(cls) -> List[FlavorWrapper]:
    """Return the stack component flavors provided by the Vault integration.

    Returns:
        List of stack component flavors.
    """
    secrets_manager_flavor = FlavorWrapper(
        name=VAULT_SECRETS_MANAGER_FLAVOR,
        source="zenml.integrations.vault.secrets_manager.VaultSecretsManager",
        type=StackComponentType.SECRETS_MANAGER,
        integration=cls.NAME,
    )
    return [secrets_manager_flavor]

secrets_manager special

HashiCorp Vault Secrets Manager.

vault_secrets_manager

Implementation of the HashiCorp Vault Secrets Manager integration.

VaultSecretsManager (BaseSecretsManager) pydantic-model

Class to interact with the Vault secrets manager - Key/value Engine.

Attributes:

Name Type Description
url str

The url of the Vault server.

token str

The token to use to authenticate with Vault.

cert Optional[str]

The path to the certificate to use to authenticate with Vault.

verify Optional[str]

Whether to verify the certificate or not.

mount_point str

The mount point of the secrets manager.

Source code in zenml/integrations/vault/secrets_manager/vault_secrets_manager.py
class VaultSecretsManager(BaseSecretsManager):
    """Class to interact with the Vault secrets manager - Key/value Engine.

    Secrets are stored through the KV version 2 API of the hvac client
    (`CLIENT.secrets.kv.v2`) under `mount_point`. The `hvac.Client` is cached
    on the class in `CLIENT` and shared by all instances; it is re-created
    when an instance configured with a different `url`/`token` needs it.

    Attributes:
        url: The url of the Vault server.
        token: The token to use to authenticate with Vault.
        cert: The path to the certificate to use to authenticate with Vault.
            NOTE(review): not currently passed to the hvac client.
        verify: Whether to verify the certificate or not.
            NOTE(review): not currently passed to the hvac client.
        mount_point: The mount point of the secrets manager.
    """

    # Class configuration
    FLAVOR: ClassVar[str] = VAULT_SECRETS_MANAGER_FLAVOR
    SUPPORTS_SCOPING: ClassVar[bool] = True
    CLIENT: ClassVar[Any] = None
    # (url, token) pair the cached CLIENT was built from. `None` means the
    # current CLIENT (if any) was not created by `_ensure_client_connected`.
    CLIENT_CONFIG: ClassVar[Any] = None

    url: str
    token: str
    mount_point: str
    cert: Optional[str]
    verify: Optional[str]

    @classmethod
    def _ensure_client_connected(cls, url: str, token: str) -> None:
        """Ensure the client is connected.

        Initializes the class-level client if it does not exist yet, and
        re-creates it if it was built for a different server URL or token.

        Args:
            url: The url of the Vault server.
            token: The token to use to authenticate with Vault.
        """
        requested_config = (url, token)
        # Never reuse a client built for another server/token. A client that
        # was assigned to `CLIENT` directly (no recorded config) is kept.
        is_stale = (
            cls.CLIENT_CONFIG is not None
            and cls.CLIENT_CONFIG != requested_config
        )
        if cls.CLIENT is None or is_stale:
            # Create a Vault Secrets Manager client
            cls.CLIENT = hvac.Client(
                url=url,
                token=token,
            )
            cls.CLIENT_CONFIG = requested_config

    def _ensure_client_is_authenticated(self) -> None:
        """Ensure the client is authenticated.

        Raises:
            RuntimeError: If the client is not initialized or authenticated.
        """
        self._ensure_client_connected(url=self.url, token=self.token)

        if not self.CLIENT.is_authenticated():
            raise RuntimeError(
                "There was an error authenticating with Vault. Please check "
                "your configuration."
            )

    @classmethod
    def _validate_scope(
        cls,
        scope: SecretsManagerScope,
        namespace: Optional[str],
    ) -> None:
        """Validate the scope and namespace value.

        Only a non-empty namespace is validated; the scope value itself is not
        checked here.

        Args:
            scope: Scope value.
            namespace: Optional namespace value.
        """
        if namespace:
            cls.validate_secret_name_or_namespace(namespace)

    @classmethod
    def validate_secret_name_or_namespace(cls, name: str) -> None:
        """Validate a secret name or namespace.

        For compatibility across secret managers, secret names and namespaces
        may contain only alphanumeric characters and the characters _+=.@-.
        The `/` character is reserved for internal use as the scope delimiter
        and is therefore rejected here.

        Args:
            name: the secret name or namespace

        Raises:
            ValueError: if the secret name or namespace is invalid
        """
        if not re.fullmatch(r"[a-zA-Z0-9_+=\.@\-]*", name):
            raise ValueError(
                f"Invalid secret name or namespace '{name}'. Must contain "
                f"only alphanumeric characters and the characters _+=.@-."
            )

    def register_secret(self, secret: BaseSecretSchema) -> None:
        """Registers a new secret.

        Args:
            secret: The secret to register.

        Raises:
            SecretExistsError: If the secret already exists.
        """
        self._ensure_client_is_authenticated()

        self.validate_secret_name_or_namespace(secret.name)

        # Raise outside the `try` so the "already exists" error can never be
        # swallowed by the `except KeyError` meant for the lookup.
        try:
            self.get_secret(secret.name)
        except KeyError:
            pass
        else:
            raise SecretExistsError(
                f"A Secret with the name '{secret.name}' already exists."
            )

        secret_path = self._get_scoped_secret_name(secret.name)
        secret_value = secret_to_dict(secret, encode=False)

        self.CLIENT.secrets.kv.v2.create_or_update_secret(
            path=secret_path,
            mount_point=self.mount_point,
            secret=secret_value,
        )

        logger.info("Created secret: %s", f"{secret_path}")
        logger.info("Added value to secret.")

    def get_secret(self, secret_name: str) -> BaseSecretSchema:
        """Gets the value of a secret.

        Args:
            secret_name: The name of the secret to get.

        Returns:
            The secret.

        Raises:
            KeyError: If the secret does not exist.
        """
        self._ensure_client_is_authenticated()

        secret_path = self._get_scoped_secret_name(secret_name)

        try:
            secret_items = (
                self.CLIENT.secrets.kv.v2.read_secret_version(
                    path=secret_path,
                    mount_point=self.mount_point,
                )
                .get("data", {})
                .get("data", {})
            )
        except InvalidPath as e:
            raise KeyError(e)

        # The stored payload carries the ZenML schema name next to the secret
        # values; remove it before building the schema instance.
        zenml_schema_name = secret_items.pop(ZENML_SCHEMA_NAME)

        secret_schema = SecretSchemaClassRegistry.get_class(
            secret_schema=zenml_schema_name
        )
        secret_items["name"] = secret_name
        return secret_schema(**secret_items)

    def get_all_secret_keys(self) -> List[str]:
        """List all secrets in Vault without any reformatting.

        This function tries to get all secrets from Vault and returns
        them as a list of strings (all secrets' names).

        Returns:
            A list of all secrets in the secrets manager.
        """
        self._ensure_client_is_authenticated()

        set_of_secrets: Set[str] = set()
        secret_path = "/".join(self._get_scope_path())
        try:
            secrets = self.CLIENT.secrets.kv.v2.list_secrets(
                path=secret_path, mount_point=self.mount_point
            )
        except hvac.exceptions.InvalidPath:
            logger.error(
                f"There are no secrets created within the scope `{secret_path}`"
            )
            return list(set_of_secrets)

        secrets_keys = secrets.get("data", {}).get("keys", [])
        for secret_key in secrets_keys:
            # vault scopes end with / and are not themselves secrets
            if "/" not in secret_key:
                set_of_secrets.add(secret_key)
        return list(set_of_secrets)

    def update_secret(self, secret: BaseSecretSchema) -> None:
        """Update an existing secret.

        Args:
            secret: The secret to update.

        Raises:
            KeyError: If the secret does not exist.
        """
        self._ensure_client_is_authenticated()

        self.validate_secret_name_or_namespace(secret.name)

        if secret.name not in self.get_all_secret_keys():
            raise KeyError(
                f"A Secret with the name '{secret.name}' does not exist."
            )

        secret_path = self._get_scoped_secret_name(secret.name)
        secret_value = secret_to_dict(secret, encode=False)

        self.CLIENT.secrets.kv.v2.create_or_update_secret(
            path=secret_path,
            mount_point=self.mount_point,
            secret=secret_value,
        )

        logger.info("Updated secret: %s", secret_path)
        logger.info("Added value to secret.")

    def delete_secret(self, secret_name: str) -> None:
        """Delete an existing secret.

        Removes the secret's metadata and all of its versions.

        Args:
            secret_name: The name of the secret to delete.
        """
        self._ensure_client_is_authenticated()

        secret_path = self._get_scoped_secret_name(secret_name)

        self.CLIENT.secrets.kv.v2.delete_metadata_and_all_versions(
            path=secret_path,
            mount_point=self.mount_point,
        )

        logger.info("Deleted secret: %s", f"{secret_path}")

    def delete_all_secrets(self) -> None:
        """Delete all existing secrets in the current scope."""
        self._ensure_client_is_authenticated()

        for secret_name in self.get_all_secret_keys():
            self.delete_secret(secret_name)

        logger.info("Deleted all secrets.")
delete_all_secrets(self)

Delete all existing secrets.

Source code in zenml/integrations/vault/secrets_manager/vault_secrets_manager.py
def delete_all_secrets(self) -> None:
    """Remove every secret in the current scope, one secret at a time."""
    self._ensure_client_is_authenticated()

    existing_keys = self.get_all_secret_keys()
    for key in existing_keys:
        self.delete_secret(key)

    logger.info("Deleted all secrets.")
delete_secret(self, secret_name)

Delete an existing secret.

Parameters:

Name Type Description Default
secret_name str

The name of the secret to delete.

required
Source code in zenml/integrations/vault/secrets_manager/vault_secrets_manager.py
def delete_secret(self, secret_name: str) -> None:
    """Remove a secret, including its metadata and every version, from Vault.

    Args:
        secret_name: The name of the secret to delete.
    """
    self._ensure_client_is_authenticated()

    scoped_path = self._get_scoped_secret_name(secret_name)
    kv_store = self.CLIENT.secrets.kv.v2
    kv_store.delete_metadata_and_all_versions(
        path=scoped_path, mount_point=self.mount_point
    )

    logger.info("Deleted secret: %s", f"{scoped_path}")
get_all_secret_keys(self)

List all secrets in Vault without any reformatting.

This function tries to get all secrets from Vault and returns them as a list of strings (all secrets' names).

Returns:

Type Description
List[str]

A list of all secrets in the secrets manager.

Source code in zenml/integrations/vault/secrets_manager/vault_secrets_manager.py
def get_all_secret_keys(self) -> List[str]:
    """Return the names of all secrets stored under the current scope.

    Keys that contain a `/` denote nested Vault scopes rather than secrets
    and are skipped. If Vault reports the scope path as invalid, an error is
    logged and an empty list is returned.

    Returns:
        A list of all secrets in the secrets manager.
    """
    self._ensure_client_is_authenticated()

    scope_path = "/".join(self._get_scope_path())
    try:
        listing = self.CLIENT.secrets.kv.v2.list_secrets(
            path=scope_path, mount_point=self.mount_point
        )
    except hvac.exceptions.InvalidPath:
        logger.error(
            f"There are no secrets created within the scope `{scope_path}`"
        )
        return []

    listed_keys = listing.get("data", {}).get("keys", [])
    return list({key for key in listed_keys if "/" not in key})
get_secret(self, secret_name)

Gets the value of a secret.

Parameters:

Name Type Description Default
secret_name str

The name of the secret to get.

required

Returns:

Type Description
BaseSecretSchema

The secret.

Exceptions:

Type Description
KeyError

If the secret does not exist.

Source code in zenml/integrations/vault/secrets_manager/vault_secrets_manager.py
def get_secret(self, secret_name: str) -> BaseSecretSchema:
    """Fetch a secret from Vault and rebuild it as its ZenML schema object.

    Args:
        secret_name: The name of the secret to get.

    Returns:
        The secret.

    Raises:
        KeyError: If the secret does not exist.
    """
    self._ensure_client_is_authenticated()

    scoped_path = self._get_scoped_secret_name(secret_name)

    try:
        response = self.CLIENT.secrets.kv.v2.read_secret_version(
            path=scoped_path,
            mount_point=self.mount_point,
        )
    except InvalidPath as e:
        raise KeyError(e)
    payload = response.get("data", {}).get("data", {})

    # The schema name is stored alongside the values; strip it off first.
    schema_name = payload.pop(ZENML_SCHEMA_NAME)
    schema_class = SecretSchemaClassRegistry.get_class(
        secret_schema=schema_name
    )
    payload["name"] = secret_name
    return schema_class(**payload)
register_secret(self, secret)

Registers a new secret.

Parameters:

Name Type Description Default
secret BaseSecretSchema

The secret to register.

required

Exceptions:

Type Description
SecretExistsError

If the secret already exists.

Source code in zenml/integrations/vault/secrets_manager/vault_secrets_manager.py
def register_secret(self, secret: BaseSecretSchema) -> None:
    """Registers a new secret.

    Args:
        secret: The secret to register.

    Raises:
        SecretExistsError: If the secret already exists.
    """
    self._ensure_client_is_authenticated()

    self.validate_secret_name_or_namespace(secret.name)

    # Raise outside the `try` so the "already exists" error can never be
    # swallowed by the `except KeyError` meant for the lookup.
    try:
        self.get_secret(secret.name)
    except KeyError:
        pass
    else:
        raise SecretExistsError(
            f"A Secret with the name '{secret.name}' already exists."
        )

    secret_path = self._get_scoped_secret_name(secret.name)
    secret_value = secret_to_dict(secret, encode=False)

    self.CLIENT.secrets.kv.v2.create_or_update_secret(
        path=secret_path,
        mount_point=self.mount_point,
        secret=secret_value,
    )

    logger.info("Created secret: %s", f"{secret_path}")
    logger.info("Added value to secret.")
update_secret(self, secret)

Update an existing secret.

Parameters:

Name Type Description Default
secret BaseSecretSchema

The secret to update.

required

Exceptions:

Type Description
KeyError

If the secret does not exist.

Source code in zenml/integrations/vault/secrets_manager/vault_secrets_manager.py
def update_secret(self, secret: BaseSecretSchema) -> None:
    """Overwrite the stored value of a secret that already exists in scope.

    Args:
        secret: The secret to update.

    Raises:
        KeyError: If the secret does not exist.
    """
    self._ensure_client_is_authenticated()

    self.validate_secret_name_or_namespace(secret.name)

    # Guard clause: only secrets already listed in the scope can be updated.
    if secret.name not in self.get_all_secret_keys():
        raise KeyError(
            f"A Secret with the name '{secret.name}' does not exist."
        )

    scoped_path = self._get_scoped_secret_name(secret.name)
    new_value = secret_to_dict(secret, encode=False)

    self.CLIENT.secrets.kv.v2.create_or_update_secret(
        path=scoped_path,
        mount_point=self.mount_point,
        secret=new_value,
    )

    logger.info("Updated secret: %s", scoped_path)
    logger.info("Added value to secret.")
validate_secret_name_or_namespace(name) classmethod

Validate a secret name or namespace.

For compatibility across secret managers, secret names and namespaces may contain only alphanumeric characters and the characters _+=.@-. The / character is reserved for internal use as the scope delimiter and is therefore rejected.

Parameters:

Name Type Description Default
name str

the secret name or namespace

required

Exceptions:

Type Description
ValueError

if the secret name or namespace is invalid

Source code in zenml/integrations/vault/secrets_manager/vault_secrets_manager.py
@classmethod
def validate_secret_name_or_namespace(cls, name: str) -> None:
    """Check that a secret name or namespace uses only portable characters.

    Accepted characters are ASCII letters, digits and _+=.@-; the `/`
    character is reserved for delimiting scopes and is rejected. An empty
    string is accepted.

    Args:
        name: the secret name or namespace

    Raises:
        ValueError: if the secret name or namespace is invalid
    """
    match = re.fullmatch(r"[a-zA-Z0-9_+=\.@\-]*", name)
    if match is None:
        raise ValueError(
            f"Invalid secret name or namespace '{name}'. Must contain "
            f"only alphanumeric characters and the characters _+=.@-."
        )

wandb special

Initialization for the wandb integration.

The wandb integration currently enables you to use wandb tracking as a convenient way to visualize your experiment runs within the wandb UI.

WandbIntegration (Integration)

Definition of the Weights & Biases (wandb) integration for ZenML.

Source code in zenml/integrations/wandb/__init__.py
class WandbIntegration(Integration):
    """Definition of the Weights & Biases (wandb) integration for ZenML."""

    NAME = WANDB
    REQUIREMENTS = ["wandb>=0.12.12", "Pillow>=9.1.0"]

    @classmethod
    def flavors(cls) -> List[FlavorWrapper]:
        """Declare the stack component flavors for the Weights and Biases integration.

        Returns:
            List of stack component flavors for this integration.
        """
        return [
            FlavorWrapper(
                name=WANDB_EXPERIMENT_TRACKER_FLAVOR,
                source="zenml.integrations.wandb.experiment_trackers.WandbExperimentTracker",
                type=StackComponentType.EXPERIMENT_TRACKER,
                integration=cls.NAME,
            )
        ]
flavors() classmethod

Declare the stack component flavors for the Weights and Biases integration.

Returns:

Type Description
List[zenml.zen_stores.models.flavor_wrapper.FlavorWrapper]

List of stack component flavors for this integration.

Source code in zenml/integrations/wandb/__init__.py
@classmethod
def flavors(cls) -> List[FlavorWrapper]:
    """Return the stack component flavors provided by the wandb integration.

    Returns:
        List of stack component flavors for this integration.
    """
    tracker_flavor = FlavorWrapper(
        name=WANDB_EXPERIMENT_TRACKER_FLAVOR,
        source="zenml.integrations.wandb.experiment_trackers.WandbExperimentTracker",
        type=StackComponentType.EXPERIMENT_TRACKER,
        integration=cls.NAME,
    )
    return [tracker_flavor]

experiment_trackers special

Initialization for the wandb experiment tracker.

wandb_experiment_tracker

Implementation for the wandb experiment tracker.

WandbExperimentTracker (BaseExperimentTracker) pydantic-model

Stores wandb configuration options.

ZenML should take care of configuring wandb for you, but should you still need access to the configuration inside your step you can do it using a step context:

from zenml.steps import StepContext

@enable_wandb
@step
def my_step(context: StepContext, ...):
    context.stack.experiment_tracker  # get the entity, project_name etc. from here

Attributes:

Name Type Description
entity Optional[str]

Name of an existing wandb entity.

project_name Optional[str]

Name of an existing wandb project to log to.

api_key str

API key that should be authorized to log to the configured wandb entity and project.

Source code in zenml/integrations/wandb/experiment_trackers/wandb_experiment_tracker.py
class WandbExperimentTracker(BaseExperimentTracker):
    """Stores wandb configuration options.

    ZenML should take care of configuring wandb for you, but should you still
    need access to the configuration inside your step you can do it using a
    step context:
    ```python
    from zenml.steps import StepContext

    @enable_wandb
    @step
    def my_step(context: StepContext, ...):
        context.stack.experiment_tracker  # get the entity, project_name etc. from here
    ```

    Attributes:
        entity: Name of an existing wandb entity.
        project_name: Name of an existing wandb project to log to.
        api_key: API key that should be authorized to log to the configured
            wandb entity and project.
    """

    api_key: str
    entity: Optional[str] = None
    project_name: Optional[str] = None

    # Class Configuration
    FLAVOR: ClassVar[str] = WANDB_EXPERIMENT_TRACKER_FLAVOR

    def prepare_step_run(self) -> None:
        """Sets the wandb api key as an environment variable of this process."""
        os.environ[WANDB_API_KEY] = self.api_key

    @contextmanager
    def activate_wandb_run(
        self,
        run_name: str,
        tags: Tuple[str, ...] = (),
        settings: Optional[wandb.Settings] = None,
    ) -> Iterator[None]:
        """Activates a wandb run for the duration of this context manager.

        Anything logged to wandb that is run while this context manager is
        active will automatically log to the same wandb run configured by the
        run name passed as an argument to this function. `wandb.finish()` is
        always called on exit, even if the body raises.

        Args:
            run_name: Name of the wandb run to create.
            tags: Tags to attach to the wandb run.
            settings: Additional settings for the wandb run.

        Yields:
            None
        """
        try:
            logger.info(
                f"Initializing wandb with project name: {self.project_name}, "
                f"run_name: {run_name}, entity: {self.entity}."
            )
            wandb.init(
                project=self.project_name,
                name=run_name,
                entity=self.entity,
                settings=settings,
                tags=tags,
            )
            yield
        finally:
            wandb.finish()
activate_wandb_run(self, run_name, tags=(), settings=None)

Activates a wandb run for the duration of this context manager.

Anything logged to wandb that is run while this context manager is active will automatically log to the same wandb run configured by the run name passed as an argument to this function.

Parameters:

Name Type Description Default
run_name str

Name of the wandb run to create.

required
tags Tuple[str, ...]

Tags to attach to the wandb run.

()
settings Optional[wandb.sdk.wandb_settings.Settings]

Additional settings for the wandb run.

None

Yields:

Type Description
Iterator[NoneType]

None

Source code in zenml/integrations/wandb/experiment_trackers/wandb_experiment_tracker.py
@contextmanager
def activate_wandb_run(
    self,
    run_name: str,
    tags: Tuple[str, ...] = (),
    settings: Optional[wandb.Settings] = None,
) -> Iterator[None]:
    """Keep a wandb run open while the `with` block executes.

    Everything logged to wandb inside the block goes to the run created here
    under `run_name`. The run is finished on exit, whether or not the block
    raised.

    Args:
        run_name: Name of the wandb run to create.
        tags: Tags to attach to the wandb run.
        settings: Additional settings for the wandb run.

    Yields:
        None
    """
    try:
        logger.info(
            f"Initializing wandb with project name: {self.project_name}, "
            f"run_name: {run_name}, entity: {self.entity}."
        )
        init_kwargs = dict(
            project=self.project_name,
            name=run_name,
            entity=self.entity,
            settings=settings,
            tags=tags,
        )
        wandb.init(**init_kwargs)
        yield
    finally:
        wandb.finish()
prepare_step_run(self)

Sets the wandb api key.

Source code in zenml/integrations/wandb/experiment_trackers/wandb_experiment_tracker.py
def prepare_step_run(self) -> None:
    """Export the configured wandb API key into this process's environment."""
    os.environ.update({WANDB_API_KEY: self.api_key})

wandb_step_decorator

Implementation for the wandb step decorator.

enable_wandb(_step=None, *, settings=None)

Decorator to enable wandb for a step function.

Apply this decorator to a ZenML pipeline step to enable wandb experiment tracking. The wandb tracking configuration (project name, experiment name, entity) will be automatically configured before the step code is executed, so the step can simply use the wandb module to log metrics and artifacts, like so:

@enable_wandb
@step
def tf_evaluator(
    x_test: np.ndarray,
    y_test: np.ndarray,
    model: tf.keras.Model,
) -> float:
    _, test_acc = model.evaluate(x_test, y_test, verbose=2)
    wandb.log({"val_accuracy": test_acc})
    return test_acc

You can also use this decorator with our class-based API like so:

@enable_wandb
class TFEvaluator(BaseStep):
    def entrypoint(
        self,
        x_test: np.ndarray,
        y_test: np.ndarray,
        model: tf.keras.Model,
    ) -> float:
        ...

The wandb run used by a decorated step is named after the pipeline run ID and the step name. The decorator does not accept an experiment or run name argument; to customize the wandb run, pass a wandb.Settings object via the settings argument.

Parameters:

Name Type Description Default
_step Optional[~S]

The decorated step class.

None
settings Optional[wandb.sdk.wandb_settings.Settings]

wandb settings to use for the step.

None

Returns:

Type Description
Union[~S, Callable[[~S], ~S]]

The inner decorator which enhances the input step class with wandb tracking functionality

Source code in zenml/integrations/wandb/wandb_step_decorator.py
def enable_wandb(
    _step: Optional[S] = None, *, settings: Optional[wandb.Settings] = None
) -> Union[S, Callable[[S], S]]:
    """Decorator to enable wandb for a step function.

    Apply this decorator to a ZenML pipeline step to enable wandb experiment
    tracking. The wandb tracking configuration (project name, run name,
    entity) will be automatically configured before the step code is executed,
    so the step can simply use the `wandb` module to log metrics and artifacts,
    like so:

    ```python
    @enable_wandb
    @step
    def tf_evaluator(
        x_test: np.ndarray,
        y_test: np.ndarray,
        model: tf.keras.Model,
    ) -> float:
        _, test_acc = model.evaluate(x_test, y_test, verbose=2)
        wandb.log({"val_accuracy": test_acc})
        return test_acc
    ```

    You can also use this decorator with our class-based API like so:
    ```
    @enable_wandb
    class TFEvaluator(BaseStep):
        def entrypoint(
            self,
            x_test: np.ndarray,
            y_test: np.ndarray,
            model: tf.keras.Model,
        ) -> float:
            ...
    ```

    The decorator can be applied bare (`@enable_wandb`) or called with
    keyword arguments (`@enable_wandb(settings=wandb.Settings(...))`). The
    wandb run used by a decorated step is named after the pipeline run ID and
    the step name (see `wandb_step_entrypoint`); there is no experiment or run
    name argument, use `settings` to customize the wandb run instead.

    Args:
        _step: The decorated step class.
        settings: wandb settings to use for the step.

    Returns:
        The inner decorator which enhances the input step class with wandb
        tracking functionality
    """

    def inner_decorator(_step: S) -> S:
        """Inner decorator for step enable_wandb.

        Args:
            _step: The decorated step class.

        Returns:
            The decorated step class.

        Raises:
            RuntimeError: If the decorator is not being applied to a ZenML step
                decorated function or a BaseStep subclass.
        """
        # `getattr` fallback: a wrongly decorated object may have no
        # `__name__`; we still want to reach the descriptive error below.
        logger.debug(
            "Applying 'enable_wandb' decorator to step %s",
            getattr(_step, "__name__", _step),
        )
        # `issubclass` raises `TypeError` for non-class arguments, so make
        # sure `_step` is a class first to raise the documented error.
        if not (isinstance(_step, type) and issubclass(_step, BaseStep)):
            raise RuntimeError(
                "The `enable_wandb` decorator can only be applied to a ZenML "
                "`step` decorated function or a BaseStep subclass."
            )
        source_fn = getattr(_step, STEP_INNER_FUNC_NAME)
        new_entrypoint = wandb_step_entrypoint(
            settings=settings,
        )(source_fn)
        if _step._created_by_functional_api():
            # If the step was created by the functional API, the old entrypoint
            # was a static method -> make sure the new one is as well
            new_entrypoint = staticmethod(new_entrypoint)

        setattr(_step, STEP_INNER_FUNC_NAME, new_entrypoint)
        return _step

    # Support both `@enable_wandb` and `@enable_wandb(settings=...)`.
    if _step is None:
        return inner_decorator
    else:
        return inner_decorator(_step)
wandb_step_entrypoint(settings=None)

Decorator for a step entrypoint to enable wandb.

Parameters:

Name Type Description Default
settings Optional[wandb.sdk.wandb_settings.Settings]

wandb settings to use for the step.

None

Returns:

Type Description
Callable[[~F], ~F]

the input function enhanced with wandb profiling functionality

Source code in zenml/integrations/wandb/wandb_step_decorator.py
def wandb_step_entrypoint(
    settings: Optional[wandb.Settings] = None,
) -> Callable[[F], F]:
    """Build a decorator that runs a step entrypoint inside a wandb run.

    Args:
        settings: wandb settings to use for the step.

    Returns:
        the input function enhanced with wandb profiling functionality
    """

    def inner_decorator(func: F) -> F:
        """Wrap ``func`` so that every call executes within a wandb run.

        Args:
            func: The decorated function.

        Returns:
            the input function enhanced with wandb profiling functionality
        """
        logger.debug(
            "Applying 'wandb_step_entrypoint' decorator to step entrypoint %s",
            func.__name__,
        )

        @functools.wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:  # noqa
            """Open a wandb run named after the current step, then call ``func``.

            Args:
                *args: positional arguments to the decorated function.
                **kwargs: keyword arguments to the decorated function.

            Returns:
                The return value of the decorated function.

            Raises:
                ValueError: if the active stack has no active experiment tracker.
            """
            logger.debug(
                "Setting up wandb backend before running step entrypoint %s",
                func.__name__,
            )
            # The run is identified by pipeline run + step, and tagged with the
            # pipeline name and run ID.
            step_env = Environment().step_environment
            pipeline_run_id = step_env.pipeline_run_id
            run_name = f"{pipeline_run_id}_{step_env.step_name}"
            tags = (step_env.pipeline_name, pipeline_run_id)

            active_stack = Repository(  # type: ignore[call-arg]
                skip_repository_check=True
            ).active_stack
            tracker = active_stack.experiment_tracker

            if isinstance(tracker, WandbExperimentTracker):
                with tracker.activate_wandb_run(
                    run_name=run_name, tags=tags, settings=settings
                ):
                    return func(*args, **kwargs)

            raise ValueError(
                "The active stack needs to have a wandb experiment tracker "
                "component registered to be able to track experiments "
                "using wandb. You can create a new stack with a wandb "
                "experiment tracker component or update your existing "
                "stack to add this component, e.g.:\n\n"
                "  'zenml experiment-tracker register wandb_tracker "
                "--type=wandb --entity=<WANDB_ENTITY> --project_name="
                "<WANDB_PROJECT_NAME> --api_key=<WANDB_API_KEY>'\n"
                "  'zenml stack register stack-name -e wandb_tracker ...'\n"
            )

        return cast(F, wrapper)

    return inner_decorator

whylogs special

Initialization of the whylogs integration.

WhylogsIntegration (Integration)

Definition of whylogs integration for ZenML.

Source code in zenml/integrations/whylogs/__init__.py
class WhylogsIntegration(Integration):
    """Definition of [whylogs](https://github.com/whylabs/whylogs) integration for ZenML."""

    NAME = WHYLOGS
    REQUIREMENTS = ["whylogs[viz]~=1.0.5", "whylogs[whylabs]~=1.0.5"]

    @classmethod
    def activate(cls) -> None:
        """Activates the integration.

        Imports the integration's materializers, secret schemas and
        visualizers sub-modules (presumably for their registration side
        effects, since the imported names are otherwise unused).
        """
        from zenml.integrations.whylogs import materializers  # noqa
        from zenml.integrations.whylogs import secret_schemas  # noqa
        from zenml.integrations.whylogs import visualizers  # noqa

    @classmethod
    def flavors(cls) -> List[FlavorWrapper]:
        """Declare the stack component flavors for the whylogs integration.

        Returns:
            List of stack component flavors for this integration.
        """
        return [
            FlavorWrapper(
                name=WHYLOGS_DATA_VALIDATOR_FLAVOR,
                source="zenml.integrations.whylogs.data_validators.WhylogsDataValidator",
                type=StackComponentType.DATA_VALIDATOR,
                integration=cls.NAME,
            ),
        ]
activate() classmethod

Activates the integration.

Source code in zenml/integrations/whylogs/__init__.py
@classmethod
def activate(cls) -> None:
    """Activates the integration by importing its sub-modules."""
    from zenml.integrations.whylogs import (  # noqa
        materializers,
        secret_schemas,
        visualizers,
    )
flavors() classmethod

Declare the stack component flavors for the whylogs integration.

Returns:

Type Description
List[zenml.zen_stores.models.flavor_wrapper.FlavorWrapper]

List of stack component flavors for this integration.

Source code in zenml/integrations/whylogs/__init__.py
@classmethod
def flavors(cls) -> List[FlavorWrapper]:
    """Declare the stack component flavors for the whylogs integration.

    Returns:
        List of stack component flavors for this integration.
    """
    return [
        FlavorWrapper(
            name=WHYLOGS_DATA_VALIDATOR_FLAVOR,
            source="zenml.integrations.whylogs.data_validators.WhylogsDataValidator",
            type=StackComponentType.DATA_VALIDATOR,
            integration=cls.NAME,
        ),
    ]

data_validators special

Initialization of the whylogs data validator for ZenML.

whylogs_data_validator

Implementation of the whylogs data validator.

WhylogsDataValidator (BaseDataValidator, AuthenticationMixin) pydantic-model

Whylogs data validator stack component.

Attributes:

Name Type Description
authentication_secret Optional[str]

Optional ZenML secret with Whylabs credentials. If configured, all the data profiles returned by all pipeline steps will automatically be uploaded to Whylabs in addition to being stored in the ZenML Artifact Store.

Source code in zenml/integrations/whylogs/data_validators/whylogs_data_validator.py
class WhylogsDataValidator(BaseDataValidator, AuthenticationMixin):
    """Whylogs data validator stack component.

    Attributes:
        authentication_secret: Optional ZenML secret with Whylabs credentials.
            If configured, all the data profiles returned by all pipeline steps
            will automatically be uploaded to Whylabs in addition to being
            stored in the ZenML Artifact Store.
    """

    # Class Configuration
    FLAVOR: ClassVar[str] = WHYLOGS_DATA_VALIDATOR_FLAVOR
    NAME: ClassVar[str] = "whylogs"

    def data_profiling(
        self,
        dataset: pd.DataFrame,
        comparison_dataset: Optional[pd.DataFrame] = None,
        profile_list: Optional[Sequence[str]] = None,
        dataset_timestamp: Optional[datetime.datetime] = None,
        **kwargs: Any,
    ) -> DatasetProfileView:
        """Analyze a dataset and generate a data profile with whylogs.

        Args:
            dataset: Target dataset to be profiled.
            comparison_dataset: Optional dataset to be used for data profiles
                that require a baseline for comparison (e.g. data drift
                profiles).
            profile_list: Optional list identifying the categories of whylogs
                data profiles to be generated (unused).
            dataset_timestamp: timestamp to associate with the generated
                dataset profile (Optional). The current time (as a
                timezone-aware UTC datetime) is used if not supplied.
            **kwargs: Extra keyword arguments (unused).

        Returns:
            A whylogs profile view object.
        """
        results = why.log(pandas=dataset)
        profile = results.profile()
        if dataset_timestamp is None:
            # whylogs interprets naive datetimes as local time before
            # converting them to UTC, so a naive `utcnow()` value would be
            # shifted on hosts not running in UTC. Use an aware value instead.
            dataset_timestamp = datetime.datetime.now(datetime.timezone.utc)
        profile.set_dataset_timestamp(dataset_timestamp=dataset_timestamp)
        return profile.view()

    def upload_profile_view(
        self, profile_view: DatasetProfileView, dataset_id: Optional[str] = None
    ) -> None:
        """Upload a whylogs data profile view to Whylabs, if configured to do so.

        Args:
            profile_view: Whylogs profile view to upload.
            dataset_id: Optional dataset identifier to use for the uploaded
                data profile. If omitted, a dataset identifier will be retrieved
                using other means, in order:
                    * the default dataset identifier configured in the Data
                    Validator secret
                    * a dataset ID will be generated automatically based on the
                    current pipeline/step information.

        Raises:
            ValueError: If the dataset ID was not provided and could not be
                retrieved or inferred from other sources.
        """
        secret = self.get_authentication_secret(
            expected_schema_type=WhylabsSecretSchema
        )
        # Uploading is opt-in: without a Whylabs secret there is nothing to do.
        if not secret:
            return

        dataset_id = dataset_id or secret.whylabs_default_dataset_id

        if not dataset_id:
            # use the current pipeline name and the step name to generate a
            # unique dataset name
            try:
                # The step environment is only registered while a step runs.
                step_env = cast(
                    StepEnvironment, Environment()[STEP_ENVIRONMENT_NAME]
                )
            except KeyError as e:
                raise ValueError(
                    "A dataset ID was not specified and could not be "
                    "generated from the current pipeline and step name."
                ) from e
            dataset_id = f"{step_env.pipeline_name}_{step_env.step_name}"

        # Instantiate WhyLabs Writer
        writer = WhyLabsWriter(
            org_id=secret.whylabs_default_org_id,
            api_key=secret.whylabs_api_key,
            dataset_id=dataset_id,
        )

        # pass a profile view to the writer's write method
        writer.write(profile=profile_view)
data_profiling(self, dataset, comparison_dataset=None, profile_list=None, dataset_timestamp=None, **kwargs)

Analyze a dataset and generate a data profile with whylogs.

Parameters:

Name Type Description Default
dataset DataFrame

Target dataset to be profiled.

required
comparison_dataset Optional[pandas.core.frame.DataFrame]

Optional dataset to be used for data profiles that require a baseline for comparison (e.g. data drift profiles).

None
profile_list Optional[Sequence[str]]

Optional list identifying the categories of whylogs data profiles to be generated (unused).

None
dataset_timestamp Optional[datetime.datetime]

timestamp to associate with the generated dataset profile (Optional). The current time is used if not supplied.

None
**kwargs Any

Extra keyword arguments (unused).

{}

Returns:

Type Description
DatasetProfileView

A whylogs profile view object.

Source code in zenml/integrations/whylogs/data_validators/whylogs_data_validator.py
def data_profiling(
    self,
    dataset: pd.DataFrame,
    comparison_dataset: Optional[pd.DataFrame] = None,
    profile_list: Optional[Sequence[str]] = None,
    dataset_timestamp: Optional[datetime.datetime] = None,
    **kwargs: Any,
) -> DatasetProfileView:
    """Analyze a dataset and generate a data profile with whylogs.

    Args:
        dataset: Target dataset to be profiled.
        comparison_dataset: Optional dataset to be used for data profiles
            that require a baseline for comparison (e.g. data drift profiles).
        profile_list: Optional list identifying the categories of whylogs
            data profiles to be generated (unused).
        dataset_timestamp: timestamp to associate with the generated
            dataset profile (Optional). The current time (as a timezone-aware
            UTC datetime) is used if not supplied.
        **kwargs: Extra keyword arguments (unused).

    Returns:
        A whylogs profile view object.
    """
    results = why.log(pandas=dataset)
    profile = results.profile()
    if dataset_timestamp is None:
        # whylogs interprets naive datetimes as local time before converting
        # them to UTC, so a naive `utcnow()` value would be shifted on hosts
        # not running in UTC. Use an aware value instead.
        dataset_timestamp = datetime.datetime.now(datetime.timezone.utc)
    profile.set_dataset_timestamp(dataset_timestamp=dataset_timestamp)
    return profile.view()
upload_profile_view(self, profile_view, dataset_id=None)

Upload a whylogs data profile view to Whylabs, if configured to do so.

Parameters:

Name Type Description Default
profile_view DatasetProfileView

Whylogs profile view to upload.

required
dataset_id Optional[str]

Optional dataset identifier to use for the uploaded data profile. If omitted, a dataset identifier will be retrieved using other means, in order: * the default dataset identifier configured in the Data Validator secret * a dataset ID will be generated automatically based on the current pipeline/step information.

None

Exceptions:

Type Description
ValueError

If the dataset ID was not provided and could not be retrieved or inferred from other sources.

Source code in zenml/integrations/whylogs/data_validators/whylogs_data_validator.py
def upload_profile_view(
    self, profile_view: DatasetProfileView, dataset_id: Optional[str] = None
) -> None:
    """Upload a whylogs data profile view to Whylabs, if configured to do so.

    Args:
        profile_view: Whylogs profile view to upload.
        dataset_id: Optional dataset identifier to use for the uploaded
            data profile. If omitted, a dataset identifier will be retrieved
            using other means, in order:
                * the default dataset identifier configured in the Data
                Validator secret
                * a dataset ID will be generated automatically based on the
                current pipeline/step information.

    Raises:
        ValueError: If the dataset ID was not provided and could not be
            retrieved or inferred from other sources.
    """
    secret = self.get_authentication_secret(
        expected_schema_type=WhylabsSecretSchema
    )
    # Uploading is opt-in: without a Whylabs secret there is nothing to do.
    if not secret:
        return

    dataset_id = dataset_id or secret.whylabs_default_dataset_id

    if not dataset_id:
        # use the current pipeline name and the step name to generate a
        # unique dataset name
        try:
            # The step environment is only registered while a step runs.
            step_env = cast(
                StepEnvironment, Environment()[STEP_ENVIRONMENT_NAME]
            )
        except KeyError as e:
            raise ValueError(
                "A dataset ID was not specified and could not be "
                "generated from the current pipeline and step name."
            ) from e
        dataset_id = f"{step_env.pipeline_name}_{step_env.step_name}"

    # Instantiate WhyLabs Writer
    writer = WhyLabsWriter(
        org_id=secret.whylabs_default_org_id,
        api_key=secret.whylabs_api_key,
        dataset_id=dataset_id,
    )

    # pass a profile view to the writer's write method
    writer.write(profile=profile_view)

materializers special

Initialization of the whylogs materializer.

whylogs_materializer

Implementation of the whylogs materializer.

WhylogsMaterializer (BaseMaterializer)

Materializer to read/write whylogs dataset profile views.

Source code in zenml/integrations/whylogs/materializers/whylogs_materializer.py
class WhylogsMaterializer(BaseMaterializer):
    """Materializer to read/write whylogs dataset profile views."""

    ASSOCIATED_TYPES = (DatasetProfileView,)
    ASSOCIATED_ARTIFACT_TYPES = (StatisticsArtifact,)

    def handle_input(self, data_type: Type[Any]) -> DatasetProfileView:
        """Reads and returns a whylogs dataset profile view.

        The serialized profile is copied from the artifact store into a local
        temporary directory and read from there. The directory is removed even
        if copying or reading fails.

        Args:
            data_type: The type of the data to read.

        Returns:
            A loaded whylogs dataset profile view object.
        """
        super().handle_input(data_type)
        filepath = os.path.join(self.artifact.uri, PROFILE_FILENAME)

        # Stage through a local file (presumably because whylogs reads from a
        # local path while the artifact store may be remote).
        with tempfile.TemporaryDirectory(prefix="zenml-temp-") as temp_dir:
            temp_file = os.path.join(temp_dir, PROFILE_FILENAME)
            fileio.copy(filepath, temp_file)
            return DatasetProfileView.read(temp_file)

    def handle_return(self, profile_view: DatasetProfileView) -> None:
        """Writes a whylogs dataset profile view.

        The profile is written to a local temporary file, copied into the
        artifact store, and then optionally uploaded to Whylabs.

        Args:
            profile_view: A whylogs dataset profile view object.
        """
        super().handle_return(profile_view)
        filepath = os.path.join(self.artifact.uri, PROFILE_FILENAME)

        # The temporary directory is cleaned up even if writing/copying fails.
        with tempfile.TemporaryDirectory(prefix="zenml-temp-") as temp_dir:
            temp_file = os.path.join(temp_dir, PROFILE_FILENAME)
            profile_view.write(temp_file)
            # Copy it into artifact store
            fileio.copy(temp_file, filepath)

        # Use the data validator to upload the profile view to Whylabs,
        # if configured to do so. This logic is only enabled if the pipeline
        # step was decorated with the `enable_whylabs` decorator
        whylabs_enabled = os.environ.get(WHYLABS_LOGGING_ENABLED_ENV)
        if not whylabs_enabled:
            return
        dataset_id = os.environ.get(WHYLABS_DATASET_ID_ENV)
        data_validator = cast(
            WhylogsDataValidator,
            WhylogsDataValidator.get_active_data_validator(),
        )
        data_validator.upload_profile_view(profile_view, dataset_id=dataset_id)
handle_input(self, data_type)

Reads and returns a whylogs dataset profile view.

Parameters:

Name Type Description Default
data_type Type[Any]

The type of the data to read.

required

Returns:

Type Description
DatasetProfileView

A loaded whylogs dataset profile view object.

Source code in zenml/integrations/whylogs/materializers/whylogs_materializer.py
def handle_input(self, data_type: Type[Any]) -> DatasetProfileView:
    """Reads and returns a whylogs dataset profile view.

    The serialized profile is copied from the artifact store into a local
    temporary directory and read from there. The directory is removed even if
    copying or reading fails.

    Args:
        data_type: The type of the data to read.

    Returns:
        A loaded whylogs dataset profile view object.
    """
    super().handle_input(data_type)
    filepath = os.path.join(self.artifact.uri, PROFILE_FILENAME)

    # Stage through a local file (presumably because whylogs reads from a
    # local path while the artifact store may be remote).
    with tempfile.TemporaryDirectory(prefix="zenml-temp-") as temp_dir:
        temp_file = os.path.join(temp_dir, PROFILE_FILENAME)
        fileio.copy(filepath, temp_file)
        return DatasetProfileView.read(temp_file)
handle_return(self, profile_view)

Writes a whylogs dataset profile view.

Parameters:

Name Type Description Default
profile_view DatasetProfileView

A whylogs dataset profile view object.

required
Source code in zenml/integrations/whylogs/materializers/whylogs_materializer.py
def handle_return(self, profile_view: DatasetProfileView) -> None:
    """Writes a whylogs dataset profile view.

    The profile is written to a local temporary file, copied into the
    artifact store, and then optionally uploaded to Whylabs.

    Args:
        profile_view: A whylogs dataset profile view object.
    """
    super().handle_return(profile_view)
    filepath = os.path.join(self.artifact.uri, PROFILE_FILENAME)

    # The temporary directory is cleaned up even if writing/copying fails.
    with tempfile.TemporaryDirectory(prefix="zenml-temp-") as temp_dir:
        temp_file = os.path.join(temp_dir, PROFILE_FILENAME)
        profile_view.write(temp_file)
        # Copy it into artifact store
        fileio.copy(temp_file, filepath)

    # Use the data validator to upload the profile view to Whylabs,
    # if configured to do so. This logic is only enabled if the pipeline
    # step was decorated with the `enable_whylabs` decorator
    whylabs_enabled = os.environ.get(WHYLABS_LOGGING_ENABLED_ENV)
    if not whylabs_enabled:
        return
    dataset_id = os.environ.get(WHYLABS_DATASET_ID_ENV)
    data_validator = cast(
        WhylogsDataValidator,
        WhylogsDataValidator.get_active_data_validator(),
    )
    data_validator.upload_profile_view(profile_view, dataset_id=dataset_id)

secret_schemas special

Initialization for the Whylabs secret schema.

This schema can be used to configure a ZenML secret that authenticates ZenML with the Whylabs platform, so that all whylogs data profiles generated by pipeline steps are automatically logged there.

whylabs_secret_schema

Implementation of the Whylabs secret schema.

WhylabsSecretSchema (BaseSecretSchema) pydantic-model

Whylabs credentials.

Attributes:

Name Type Description
whylabs_default_org_id str

the Whylabs organization ID.

whylabs_api_key str

Whylabs API key.

whylabs_default_dataset_id Optional[str]

default Whylabs dataset ID to use when logging data profiles.

Source code in zenml/integrations/whylogs/secret_schemas/whylabs_secret_schema.py
class WhylabsSecretSchema(BaseSecretSchema):
    """Whylabs credentials.

    Secret schema holding the credentials used to upload whylogs profiles to
    the Whylabs platform.

    Attributes:
        whylabs_default_org_id: the Whylabs organization ID.
        whylabs_api_key: Whylabs API key.
        whylabs_default_dataset_id: default Whylabs dataset ID to use when
            logging data profiles.
    """

    # Schema type identifier for this secret schema.
    TYPE: ClassVar[str] = WHYLABS_SECRET_SCHEMA_TYPE

    # Required fields.
    whylabs_default_org_id: str
    whylabs_api_key: str
    # Optional fallback dataset ID, used when no dataset ID is passed explicitly.
    whylabs_default_dataset_id: Optional[str] = None

steps special

Initialization of the whylogs steps.

whylogs_profiler

Implementation of the whylogs profiler step.

WhylogsProfilerConfig (BaseAnalyzerConfig) pydantic-model

Config class for the WhylogsProfiler step.

Attributes:

Name Type Description
dataset_timestamp Optional[datetime.datetime]

timestamp to associate with the generated dataset profile (Optional). The current time is used if not supplied.

Source code in zenml/integrations/whylogs/steps/whylogs_profiler.py
class WhylogsProfilerConfig(BaseAnalyzerConfig):
    """Config class for the WhylogsProfiler step.

    Attributes:
        dataset_timestamp: timestamp to associate with the generated
            dataset profile (Optional). The current time is used if not
            supplied.
    """

    # Forwarded as `dataset_timestamp` to the data validator's
    # `data_profiling` method by the profiler step.
    dataset_timestamp: Optional[datetime.datetime]
WhylogsProfilerStep (BaseAnalyzerStep)

Generates a whylogs data profile from a given pd.DataFrame.

Source code in zenml/integrations/whylogs/steps/whylogs_profiler.py
class WhylogsProfilerStep(BaseAnalyzerStep):
    """Generates a whylogs data profile from a given pd.DataFrame."""

    @staticmethod
    def entrypoint(  # type: ignore[override]
        dataset: pd.DataFrame,
        config: WhylogsProfilerConfig,
    ) -> DatasetProfileView:
        """Profile ``dataset`` using the active whylogs data validator.

        Args:
            dataset: pd.DataFrame, the given dataset
            config: the configuration of the step

        Returns:
            whylogs profile with statistics generated for the input dataset
        """
        active_validator = WhylogsDataValidator.get_active_data_validator()
        whylogs_validator = cast(WhylogsDataValidator, active_validator)
        timestamp = config.dataset_timestamp
        return whylogs_validator.data_profiling(
            dataset, dataset_timestamp=timestamp
        )
CONFIG_CLASS (BaseAnalyzerConfig) pydantic-model

Config class for the WhylogsProfiler step.

Attributes:

Name Type Description
dataset_timestamp Optional[datetime.datetime]

timestamp to associate with the generated dataset profile (Optional). The current time is used if not supplied.

Source code in zenml/integrations/whylogs/steps/whylogs_profiler.py
class WhylogsProfilerConfig(BaseAnalyzerConfig):
    """Config class for the WhylogsProfiler step.

    Attributes:
        dataset_timestamp: timestamp to associate with the generated
            dataset profile (Optional). The current time is used if not
            supplied.
    """

    # Forwarded as `dataset_timestamp` to the data validator's
    # `data_profiling` method by the profiler step.
    dataset_timestamp: Optional[datetime.datetime]
entrypoint(dataset, config) staticmethod

Main entrypoint function for the whylogs profiler.

Parameters:

Name Type Description Default
dataset DataFrame

pd.DataFrame, the given dataset

required
config WhylogsProfilerConfig

the configuration of the step

required

Returns:

Type Description
DatasetProfileView

whylogs profile with statistics generated for the input dataset

Source code in zenml/integrations/whylogs/steps/whylogs_profiler.py
@staticmethod
def entrypoint(  # type: ignore[override]
    dataset: pd.DataFrame,
    config: WhylogsProfilerConfig,
) -> DatasetProfileView:
    """Profile ``dataset`` using the active whylogs data validator.

    Args:
        dataset: pd.DataFrame, the given dataset
        config: the configuration of the step

    Returns:
        whylogs profile with statistics generated for the input dataset
    """
    active_validator = WhylogsDataValidator.get_active_data_validator()
    whylogs_validator = cast(WhylogsDataValidator, active_validator)
    timestamp = config.dataset_timestamp
    return whylogs_validator.data_profiling(
        dataset, dataset_timestamp=timestamp
    )
whylogs_profiler_step(step_name, config, dataset_id=None)

Shortcut function to create a new instance of the WhylogsProfilerStep step.

The returned WhylogsProfilerStep can be used in a pipeline to generate a whylogs DatasetProfileView from a given pd.DataFrame and save it as an artifact.

Parameters:

Name Type Description Default
step_name str

The name of the step

required
config WhylogsProfilerConfig

The step configuration

required
dataset_id Optional[str]

Optional dataset ID to use to upload the profile to Whylabs.

None

Returns:

Type Description
BaseStep

a WhylogsProfilerStep step instance

Source code in zenml/integrations/whylogs/steps/whylogs_profiler.py
def whylogs_profiler_step(
    step_name: str,
    config: WhylogsProfilerConfig,
    dataset_id: Optional[str] = None,
) -> BaseStep:
    """Create a configured, Whylabs-enabled copy of the WhylogsProfilerStep.

    The returned WhylogsProfilerStep can be used in a pipeline to generate a
    whylogs DatasetProfileView from a given pd.DataFrame and save it as an
    artifact.

    Args:
        step_name: The name of the step
        config: The step configuration
        dataset_id: Optional dataset ID to use to upload the profile to Whylabs.

    Returns:
        a WhylogsProfilerStep step instance
    """
    whylabs_decorator = enable_whylabs(dataset_id=dataset_id)
    renamed_step_class = clone_step(WhylogsProfilerStep, step_name)
    decorated_step_class = whylabs_decorator(renamed_step_class)
    return decorated_step_class(config=config)

visualizers special

Initialization of the whylogs visualizer.

whylogs_visualizer

Implementation of the whylogs visualizer step.

WhylogsVisualizer (BaseStepVisualizer)

The implementation of a Whylogs Visualizer.

Source code in zenml/integrations/whylogs/visualizers/whylogs_visualizer.py
class WhylogsVisualizer(BaseStepVisualizer):
    """The implementation of a Whylogs Visualizer."""

    def visualize(
        self,
        object: StepView,
        reference_step_view: Optional[StepView] = None,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        """Visualize whylogs dataset profiles present as outputs in the step view.

        If the step view has no whylogs profile output, a warning is logged
        and nothing is rendered.

        Args:
            object: StepView fetched from run.get_step().
            reference_step_view: second StepView fetched from run.get_step() to
                use as a reference to visualize data drift
            *args: additional positional arguments to pass to the visualize
                method
            **kwargs: additional keyword arguments to pass to the visualize
                method
        """

        def extract_profile(
            step_view: StepView,
        ) -> Optional[DatasetProfileView]:
            """Extract a whylogs DatasetProfileView from a step view.

            Args:
                step_view: a step view

            Returns:
                A whylogs DatasetProfileView object loaded from the step view,
                if one could be found, otherwise None.
            """
            whylogs_artifact_datatype = (
                f"{DatasetProfileView.__module__}.{DatasetProfileView.__name__}"
            )
            for artifact_view in step_view.outputs.values():
                # filter out anything but whylogs dataset profile artifacts
                if artifact_view.data_type == whylogs_artifact_datatype:
                    profile = artifact_view.read()
                    return cast(DatasetProfileView, profile)
            return None

        profile = extract_profile(object)
        if profile is None:
            # Without a target profile there is nothing to render; passing
            # None on to whylogs would fail with an obscure error.
            logger.warning(
                "No whylogs DatasetProfileView output found in the given "
                "step view. Nothing to visualize."
            )
            return

        reference_profile: Optional[DatasetProfileView] = None
        if reference_step_view:
            reference_profile = extract_profile(reference_step_view)

        self.visualize_profile(profile, reference_profile)

    def visualize_profile(
        self,
        profile: DatasetProfileView,
        reference_profile: Optional[DatasetProfileView] = None,
    ) -> None:
        """Generate a visualization of one or two whylogs dataset profiles.

        In a notebook the report is displayed inline; otherwise it is written
        to a temporary HTML file that is opened in a web browser.

        Args:
            profile: whylogs DatasetProfileView to visualize
            reference_profile: second optional DatasetProfileView to use to
                generate a data drift visualization
        """
        # currently, whylogs doesn't support visualizing a single profile, so
        # we trick it by using the same profile twice, both as reference and
        # target, in a drift report
        if reference_profile is None:
            reference_profile = profile
        visualization = NotebookProfileVisualizer()
        visualization.set_profiles(
            target_profile_view=profile,
            reference_profile_view=reference_profile,
        )
        rendered_html = visualization.summary_drift_report()

        if Environment.in_notebook():
            from IPython.display import display

            display(rendered_html)
            for column in sorted(profile.get_columns()):
                display(visualization.double_histogram(feature_name=column))
        else:
            from pathlib import Path

            logger.warning(
                "The magic functions are only usable in a Jupyter notebook."
            )
            with tempfile.NamedTemporaryFile(
                mode="w", delete=False, suffix=".html", encoding="utf-8"
            ) as f:
                f.write(rendered_html.data)
            # `Path.as_uri` builds a well-formed `file://` URI for both POSIX
            # and Windows absolute paths.
            url = Path(f.name).resolve().as_uri()
            logger.info("Opening %s in a new browser..", f.name)
            webbrowser.open(url, new=2)
visualize(self, object, reference_step_view=None, *args, **kwargs)

Visualize whylogs dataset profiles present as outputs in the step view.

Parameters:

Name Type Description Default
object StepView

StepView fetched from run.get_step().

required
reference_step_view Optional[zenml.post_execution.step.StepView]

second StepView fetched from run.get_step() to use as a reference to visualize data drift

None
*args Any

additional positional arguments to pass to the visualize method

()
**kwargs Any

additional keyword arguments to pass to the visualize method

{}
Source code in zenml/integrations/whylogs/visualizers/whylogs_visualizer.py
def visualize(
    self,
    object: StepView,
    reference_step_view: Optional[StepView] = None,
    *args: Any,
    **kwargs: Any,
) -> None:
    """Visualize whylogs dataset profiles present as outputs in the step view.

    If the step view has no whylogs profile output, a warning is logged and
    nothing is rendered.

    Args:
        object: StepView fetched from run.get_step().
        reference_step_view: second StepView fetched from run.get_step() to
            use as a reference to visualize data drift
        *args: additional positional arguments to pass to the visualize
            method
        **kwargs: additional keyword arguments to pass to the visualize
            method
    """

    def extract_profile(
        step_view: StepView,
    ) -> Optional[DatasetProfileView]:
        """Extract a whylogs DatasetProfileView from a step view.

        Args:
            step_view: a step view

        Returns:
            A whylogs DatasetProfileView object loaded from the step view,
            if one could be found, otherwise None.
        """
        whylogs_artifact_datatype = (
            f"{DatasetProfileView.__module__}.{DatasetProfileView.__name__}"
        )
        for artifact_view in step_view.outputs.values():
            # filter out anything but whylogs dataset profile artifacts
            if artifact_view.data_type == whylogs_artifact_datatype:
                profile = artifact_view.read()
                return cast(DatasetProfileView, profile)
        return None

    profile = extract_profile(object)
    if profile is None:
        # Without a target profile there is nothing to render; passing None
        # on to whylogs would fail with an obscure error.
        logger.warning(
            "No whylogs DatasetProfileView output found in the given "
            "step view. Nothing to visualize."
        )
        return

    reference_profile: Optional[DatasetProfileView] = None
    if reference_step_view:
        reference_profile = extract_profile(reference_step_view)

    self.visualize_profile(profile, reference_profile)
visualize_profile(self, profile, reference_profile=None)

Generate a visualization of one or two whylogs dataset profiles.

Parameters:

Name Type Description Default
profile DatasetProfileView

whylogs DatasetProfileView to visualize

required
reference_profile Optional[whylogs.core.view.dataset_profile_view.DatasetProfileView]

second optional DatasetProfileView to use to generate a data drift visualization

None
Source code in zenml/integrations/whylogs/visualizers/whylogs_visualizer.py
def visualize_profile(
    self,
    profile: DatasetProfileView,
    reference_profile: Optional[DatasetProfileView] = None,
) -> None:
    """Generate a visualization of one or two whylogs dataset profiles.

    In a notebook, the drift summary report and one double histogram per
    column of `profile` are displayed inline. Otherwise, the drift summary
    report is written to a temporary HTML file that is opened in a new
    browser window.

    Args:
        profile: whylogs DatasetProfileView to visualize
        reference_profile: second optional DatasetProfileView to use to
            generate a data drift visualization
    """
    import pathlib

    # currently, whylogs doesn't support visualizing a single profile, so
    # we trick it by using the same profile twice, both as reference and
    # target, in a drift report
    reference_profile = reference_profile or profile
    visualization = NotebookProfileVisualizer()
    visualization.set_profiles(
        target_profile_view=profile,
        reference_profile_view=reference_profile,
    )
    rendered_html = visualization.summary_drift_report()

    if Environment.in_notebook():
        # `IPython.display` is the public location of `display`; importing
        # it from `IPython.core.display` is deprecated.
        from IPython.display import display

        display(rendered_html)
        for column in sorted(profile.get_columns().keys()):
            display(visualization.double_histogram(feature_name=column))
    else:
        logger.warning(
            "The magic functions are only usable in a Jupyter notebook."
        )
        # delete=False: the file must outlive this function so that the
        # browser can still load it after we return.
        with tempfile.NamedTemporaryFile(
            mode="w", delete=False, suffix=".html", encoding="utf-8"
        ) as f:
            f.write(rendered_html.data)
        # Path.as_uri() produces a well-formed file URI on every platform
        # (naive "file:///" + path yields "file:////tmp/..." on POSIX and
        # an un-normalized backslash path on Windows).
        url = pathlib.Path(f.name).as_uri()
        logger.info("Opening %s in a new browser..", f.name)
        webbrowser.open(url, new=2)

whylabs_step_decorator

Implementation of the Whylabs step decorator.

enable_whylabs(_step=None, *, dataset_id=None)

Decorator to enable Whylabs profiling for a step function.

Apply this decorator to a ZenML pipeline step to enable Whylabs profiling.

Note that you also need to have a whylogs Data Validator part of your active stack with the Whylabs credentials configured for this to have effect.

All the whylogs data profile views returned by the step will automatically be uploaded to the Whylabs platform if the active whylogs Data Validator component is configured with Whylabs credentials, e.g.:

import whylogs as why
from whylogs.core import DatasetProfileView
from zenml.integrations.whylogs.whylabs_step_decorator import enable_whylabs

@enable_whylabs(dataset_id="my_model")
@step
def data_loader() -> Output(data=pd.DataFrame, profile=DatasetProfileView,):
    ...
    data = pd.DataFrame(...)
    results = why.log(pandas=data)
    profile = results.profile()
    ...
    return data, profile.view()

You can also use this decorator with our class-based API like so:

@enable_whylabs(dataset_id="my_model")
class DataLoader(BaseStep):
    def entrypoint(self) -> Output(data=pd.DataFrame, profile=DatasetProfileView,):
        ...

Parameters:

Name Type Description Default
_step Optional[~S]

The decorated step class.

None
dataset_id Optional[str]

Optional dataset ID to use for the uploaded profile(s)

None

Returns:

Type Description
Union[~S, Callable[[~S], ~S]]

the inner decorator which enhances the input step class with Whylabs profiling functionality

Source code in zenml/integrations/whylogs/whylabs_step_decorator.py
def enable_whylabs(
    _step: Optional[S] = None,
    *,
    dataset_id: Optional[str] = None,
) -> Union[S, Callable[[S], S]]:
    """Decorator to enable Whylabs profiling for a step function.

    Apply this decorator to a ZenML pipeline step to enable Whylabs profiling.

    Note that you also need to have a whylogs Data Validator part of your active
    stack with the Whylabs credentials configured for this to have effect.

    All the whylogs data profile views returned by the step will automatically
    be uploaded to the Whylabs platform if the active whylogs Data Validator
    component is configured with Whylabs credentials, e.g.:

    ```python
    import whylogs as why
    from whylogs.core import DatasetProfileView
    from zenml.integrations.whylogs.whylabs_step_decorator import enable_whylabs

    @enable_whylabs(dataset_id="my_model")
    @step
    def data_loader() -> Output(data=pd.DataFrame, profile=DatasetProfileView,):
        ...
        data = pd.DataFrame(...)
        results = why.log(pandas=data)
        profile = results.profile()
        ...
        return data, profile.view()
    ```

    You can also use this decorator with our class-based API like so:

    ```python
    @enable_whylabs(dataset_id="my_model")
    class DataLoader(BaseStep):
        def entrypoint(self) -> Output(data=pd.DataFrame, profile=DatasetProfileView,):
            ...
    ```

    The decorator works both with and without arguments: `@enable_whylabs`
    decorates the step directly, while `@enable_whylabs(dataset_id=...)`
    first returns a decorator that is then applied to the step.

    Args:
        _step: The decorated step class.
        dataset_id: Optional dataset ID to use for the uploaded profile(s)

    Returns:
        The step class enhanced with Whylabs profiling functionality when
        `_step` is given, otherwise the inner decorator that applies this
        enhancement to a step class.
    """

    def inner_decorator(step_class: S) -> S:
        # Wrap the step's entrypoint function so that Whylabs logging is
        # enabled while it runs.
        source_fn = getattr(step_class, STEP_INNER_FUNC_NAME)
        new_entrypoint = whylabs_entrypoint(dataset_id)(source_fn)
        if step_class._created_by_functional_api():
            # If the step was created by the functional API, the old entrypoint
            # was a static method -> make sure the new one is as well
            new_entrypoint = staticmethod(new_entrypoint)

        setattr(step_class, STEP_INNER_FUNC_NAME, new_entrypoint)
        return step_class

    if _step is None:
        # Called with arguments, e.g. @enable_whylabs(dataset_id=...)
        return inner_decorator
    # Called bare, e.g. @enable_whylabs
    return inner_decorator(_step)
whylabs_entrypoint(dataset_id=None)

Decorator for a step entrypoint to enable Whylabs logging.

Apply this decorator to a ZenML pipeline step to enable Whylabs profiling.

Note that you also need to have a whylogs Data Validator part of your active stack with the Whylabs credentials configured for this to have effect.

All the whylogs data profile views returned by the step will automatically be uploaded to the Whylabs platform if the active whylogs Data Validator component is configured with Whylabs credentials, e.g.:

For example:

import whylogs as why
from whylogs.core import DatasetProfileView
from zenml.integrations.whylogs.whylabs_step_decorator import whylabs_entrypoint

@step
@whylabs_entrypoint(dataset_id="my_model")
def data_loader() -> Output(data=pd.DataFrame, profile=DatasetProfileView,):
    ...
    data = pd.DataFrame(...)
    results = why.log(pandas=data)
    profile = results.profile()
    ...
    return data, profile.view()

Parameters:

Name Type Description Default
dataset_id Optional[str]

Optional dataset ID to use for the uploaded profile(s)

None

Returns:

Type Description
Callable[[~F], ~F]

the input function enhanced with Whylabs profiling functionality

Source code in zenml/integrations/whylogs/whylabs_step_decorator.py
def whylabs_entrypoint(
    dataset_id: Optional[str] = None,
) -> Callable[[F], F]:
    """Decorator for a step entrypoint to enable Whylabs logging.

    Apply this decorator to a ZenML pipeline step to enable Whylabs profiling.

    Note that you also need to have a whylogs Data Validator part of your active
    stack with the Whylabs credentials configured for this to have effect.

    All the whylogs data profile views returned by the step will automatically
    be uploaded to the Whylabs platform if the active whylogs Data Validator
    component is configured with Whylabs credentials, e.g.:

    ```python
    import whylogs as why
    from whylogs.core import DatasetProfileView
    from zenml.integrations.whylogs.whylabs_step_decorator import whylabs_entrypoint

    @step
    @whylabs_entrypoint(dataset_id="my_model")
    def data_loader() -> Output(data=pd.DataFrame, profile=DatasetProfileView,):
        ...
        data = pd.DataFrame(...)
        results = why.log(pandas=data)
        profile = results.profile()
        ...
        return data, profile.view()
    ```

    While the decorated function runs, the `WHYLABS_LOGGING_ENABLED_ENV`
    environment variable is set to "true" and, if `dataset_id` is given,
    `WHYLABS_DATASET_ID_ENV` is set to it. Afterwards, whatever values these
    variables had before the call are restored (or the variables are
    removed, if they were unset). If `dataset_id` is not given, any existing
    `WHYLABS_DATASET_ID_ENV` value is left untouched. Note that `os.environ`
    is process-wide.

    Args:
        dataset_id: Optional dataset ID to use for the uploaded profile(s)

    Returns:
        the input function enhanced with Whylabs profiling functionality
    """

    def inner_decorator(func: F) -> F:
        @functools.wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:  # noqa
            overrides = {WHYLABS_LOGGING_ENABLED_ENV: "true"}
            if dataset_id:
                overrides[WHYLABS_DATASET_ID_ENV] = dataset_id
            # Remember the previous values so that nested decorated calls or
            # user-provided settings are restored instead of being deleted.
            previous = {name: os.environ.get(name) for name in overrides}
            os.environ.update(overrides)
            try:
                return func(*args, **kwargs)
            finally:
                for name, value in previous.items():
                    if value is None:
                        # pop() instead of del: the wrapped code may already
                        # have removed the variable.
                        os.environ.pop(name, None)
                    else:
                        os.environ[name] = value

        return cast(F, wrapper)

    return inner_decorator

xgboost special

Initialization of the XGBoost integration.

XgboostIntegration (Integration)

Definition of xgboost integration for ZenML.

Source code in zenml/integrations/xgboost/__init__.py
class XgboostIntegration(Integration):
    """ZenML integration definition for the XGBoost library."""

    # pip requirement(s) installed for this integration
    REQUIREMENTS = ["xgboost>=1.0.0"]
    NAME = XGBOOST

    @classmethod
    def activate(cls) -> None:
        """Activates the integration.

        Imports the XGBoost materializers module for its import side effects;
        presumably this is what makes the materializers known to ZenML.
        """
        import zenml.integrations.xgboost.materializers  # noqa
activate() classmethod

Activates the integration.

Source code in zenml/integrations/xgboost/__init__.py
@classmethod
def activate(cls) -> None:
    """Activates the integration.

    Imports the XGBoost materializers module for its import side effects;
    presumably this is what makes the materializers known to ZenML.
    """
    import zenml.integrations.xgboost.materializers  # noqa

materializers special

Initialization of the XGBoost materializers.

xgboost_booster_materializer

Implementation of an XGBoost booster materializer.

XgboostBoosterMaterializer (BaseMaterializer)

Materializer to read data to and from xgboost.Booster.

Source code in zenml/integrations/xgboost/materializers/xgboost_booster_materializer.py
class XgboostBoosterMaterializer(BaseMaterializer):
    """Materializer to read data to and from xgboost.Booster."""

    ASSOCIATED_TYPES = (xgb.Booster,)
    ASSOCIATED_ARTIFACT_TYPES = (ModelArtifact,)

    def handle_input(self, data_type: Type[Any]) -> xgb.Booster:
        """Reads a xgboost Booster model from a serialized JSON file.

        The artifact file is copied from the artifact store into a local
        temporary directory, which is then handed to `Booster.load_model`
        (presumably because the artifact store may not be a local
        filesystem). The temporary directory is always removed, even if the
        copy or the load fails.

        Args:
            data_type: A xgboost Booster type.

        Returns:
            A xgboost Booster object.
        """
        super().handle_input(data_type)
        filepath = os.path.join(self.artifact.uri, DEFAULT_FILENAME)

        # Create a temporary folder
        temp_dir = tempfile.mkdtemp(prefix="zenml-temp-")
        try:
            temp_file = os.path.join(temp_dir, DEFAULT_FILENAME)
            # Copy from artifact store to temporary file
            fileio.copy(filepath, temp_file)
            booster = xgb.Booster()
            booster.load_model(temp_file)
        finally:
            # Cleanup even when copying or loading raises
            fileio.rmtree(temp_dir)
        return booster

    def handle_return(self, booster: xgb.Booster) -> None:
        """Creates a JSON serialization for a xgboost Booster model.

        The model is saved to a local temporary ".json" file, copied into
        the artifact store, and the temporary file is always removed
        afterwards, even if saving or copying fails.

        Args:
            booster: A xgboost Booster model.
        """
        super().handle_return(booster)

        filepath = os.path.join(self.artifact.uri, DEFAULT_FILENAME)

        # Reserve a temporary path with a ".json" suffix (xgboost picks the
        # serialization format from the file extension). Our own handle is
        # closed before xgboost writes to the path, so the file is not held
        # open twice.
        with tempfile.NamedTemporaryFile(
            mode="w", suffix=".json", delete=False
        ) as f:
            temp_file = f.name
        try:
            booster.save_model(temp_file)
            # Copy it into artifact store
            fileio.copy(temp_file, filepath)
        finally:
            # Remove the temporary file even when saving or copying raises
            fileio.remove(temp_file)
handle_input(self, data_type)

Reads a xgboost Booster model from a serialized JSON file.

Parameters:

Name Type Description Default
data_type Type[Any]

A xgboost Booster type.

required

Returns:

Type Description
Booster

A xgboost Booster object.

Source code in zenml/integrations/xgboost/materializers/xgboost_booster_materializer.py
def handle_input(self, data_type: Type[Any]) -> xgb.Booster:
    """Reads a xgboost Booster model from a serialized JSON file.

    The artifact file is copied from the artifact store into a local
    temporary directory, which is then handed to `Booster.load_model`
    (presumably because the artifact store may not be a local filesystem).
    The temporary directory is always removed, even if the copy or the load
    fails.

    Args:
        data_type: A xgboost Booster type.

    Returns:
        A xgboost Booster object.
    """
    super().handle_input(data_type)
    filepath = os.path.join(self.artifact.uri, DEFAULT_FILENAME)

    # Create a temporary folder
    temp_dir = tempfile.mkdtemp(prefix="zenml-temp-")
    try:
        temp_file = os.path.join(temp_dir, DEFAULT_FILENAME)
        # Copy from artifact store to temporary file
        fileio.copy(filepath, temp_file)
        booster = xgb.Booster()
        booster.load_model(temp_file)
    finally:
        # Cleanup even when copying or loading raises
        fileio.rmtree(temp_dir)
    return booster
handle_return(self, booster)

Creates a JSON serialization for a xgboost Booster model.

Parameters:

Name Type Description Default
booster Booster

A xgboost Booster model.

required
Source code in zenml/integrations/xgboost/materializers/xgboost_booster_materializer.py
def handle_return(self, booster: xgb.Booster) -> None:
    """Creates a JSON serialization for a xgboost Booster model.

    The model is saved to a local temporary ".json" file, copied into the
    artifact store, and the temporary file is always removed afterwards,
    even if saving or copying fails.

    Args:
        booster: A xgboost Booster model.
    """
    super().handle_return(booster)

    filepath = os.path.join(self.artifact.uri, DEFAULT_FILENAME)

    # Reserve a temporary path with a ".json" suffix (xgboost picks the
    # serialization format from the file extension). Our own handle is
    # closed before xgboost writes to the path, so the file is not held
    # open twice.
    with tempfile.NamedTemporaryFile(
        mode="w", suffix=".json", delete=False
    ) as f:
        temp_file = f.name
    try:
        booster.save_model(temp_file)
        # Copy it into artifact store
        fileio.copy(temp_file, filepath)
    finally:
        # Remove the temporary file even when saving or copying raises
        fileio.remove(temp_file)
xgboost_dmatrix_materializer

Implementation of the XGBoost dmatrix materializer.

XgboostDMatrixMaterializer (BaseMaterializer)

Materializer to read data to and from xgboost.DMatrix.

Source code in zenml/integrations/xgboost/materializers/xgboost_dmatrix_materializer.py
class XgboostDMatrixMaterializer(BaseMaterializer):
    """Materializer to read data to and from xgboost.DMatrix."""

    ASSOCIATED_TYPES = (xgb.DMatrix,)
    ASSOCIATED_ARTIFACT_TYPES = (DataArtifact,)

    def handle_input(self, data_type: Type[Any]) -> xgb.DMatrix:
        """Reads a xgboost.DMatrix binary file and loads it.

        The artifact file is copied from the artifact store into a local
        temporary directory, which is then handed to `xgb.DMatrix`. The
        temporary directory is always removed, even if the copy or the load
        fails.

        Args:
            data_type: The datatype which should be read.

        Returns:
            Materialized xgboost matrix.
        """
        super().handle_input(data_type)
        filepath = os.path.join(self.artifact.uri, DEFAULT_FILENAME)

        # Create a temporary folder
        temp_dir = tempfile.mkdtemp(prefix="zenml-temp-")
        try:
            temp_file = os.path.join(temp_dir, DEFAULT_FILENAME)
            # Copy from artifact store to temporary file
            fileio.copy(filepath, temp_file)
            matrix = xgb.DMatrix(temp_file)
        finally:
            # Cleanup even when copying or loading raises
            fileio.rmtree(temp_dir)
        return matrix

    def handle_return(self, matrix: xgb.DMatrix) -> None:
        """Creates a binary serialization for a xgboost.DMatrix object.

        The matrix is saved to a local temporary file with
        `DMatrix.save_binary`, copied into the artifact store, and the
        temporary file is always removed afterwards, even if saving or
        copying fails.

        Args:
            matrix: A xgboost.DMatrix object.
        """
        super().handle_return(matrix)
        filepath = os.path.join(self.artifact.uri, DEFAULT_FILENAME)

        # Reserve a temporary path and close our own handle before xgboost
        # writes to it, so the file is not held open twice.
        with tempfile.NamedTemporaryFile(mode="wb", delete=False) as f:
            temp_file = f.name
        try:
            matrix.save_binary(temp_file)
            # Copy it into artifact store
            fileio.copy(temp_file, filepath)
        finally:
            # Remove the temporary file even when saving or copying raises
            fileio.remove(temp_file)
handle_input(self, data_type)

Reads a xgboost.DMatrix binary file and loads it.

Parameters:

Name Type Description Default
data_type Type[Any]

The datatype which should be read.

required

Returns:

Type Description
DMatrix

Materialized xgboost matrix.

Source code in zenml/integrations/xgboost/materializers/xgboost_dmatrix_materializer.py
def handle_input(self, data_type: Type[Any]) -> xgb.DMatrix:
    """Reads a xgboost.DMatrix binary file and loads it.

    The artifact file is copied from the artifact store into a local
    temporary directory, which is then handed to `xgb.DMatrix`. The
    temporary directory is always removed, even if the copy or the load
    fails.

    Args:
        data_type: The datatype which should be read.

    Returns:
        Materialized xgboost matrix.
    """
    super().handle_input(data_type)
    filepath = os.path.join(self.artifact.uri, DEFAULT_FILENAME)

    # Create a temporary folder
    temp_dir = tempfile.mkdtemp(prefix="zenml-temp-")
    try:
        temp_file = os.path.join(temp_dir, DEFAULT_FILENAME)
        # Copy from artifact store to temporary file
        fileio.copy(filepath, temp_file)
        matrix = xgb.DMatrix(temp_file)
    finally:
        # Cleanup even when copying or loading raises
        fileio.rmtree(temp_dir)
    return matrix
handle_return(self, matrix)

Creates a binary serialization for a xgboost.DMatrix object.

Parameters:

Name Type Description Default
matrix DMatrix

A xgboost.DMatrix object.

required
Source code in zenml/integrations/xgboost/materializers/xgboost_dmatrix_materializer.py
def handle_return(self, matrix: xgb.DMatrix) -> None:
    """Creates a binary serialization for a xgboost.DMatrix object.

    The matrix is saved to a local temporary file with
    `DMatrix.save_binary`, copied into the artifact store, and the temporary
    file is always removed afterwards, even if saving or copying fails.

    Args:
        matrix: A xgboost.DMatrix object.
    """
    super().handle_return(matrix)
    filepath = os.path.join(self.artifact.uri, DEFAULT_FILENAME)

    # Reserve a temporary path and close our own handle before xgboost
    # writes to it, so the file is not held open twice.
    with tempfile.NamedTemporaryFile(mode="wb", delete=False) as f:
        temp_file = f.name
    try:
        matrix.save_binary(temp_file)
        # Copy it into artifact store
        fileio.copy(temp_file, filepath)
    finally:
        # Remove the temporary file even when saving or copying raises
        fileio.remove(temp_file)