Skip to content

Integrations

zenml.integrations special

The ZenML integrations module contains sub-modules for each integration that we support. This includes orchestrators like Apache Airflow, visualization tools like the facets library, as well as deep learning libraries like PyTorch.

airflow special

The Airflow integration sub-module powers an alternative to the local orchestrator. You can enable it by registering the Airflow orchestrator with the CLI tool and then bootstrapping it with the `zenml orchestrator up` command.

AirflowIntegration (Integration)

Definition of Airflow Integration for ZenML.

Source code in zenml/integrations/airflow/__init__.py
class AirflowIntegration(Integration):
    """ZenML integration definition for Apache Airflow.

    Declares the pinned Airflow requirement and, when activated, imports the
    Airflow orchestrators module.
    """

    NAME = AIRFLOW
    REQUIREMENTS = ["apache-airflow==2.2.0"]

    @classmethod
    def activate(cls):
        """Imports the airflow orchestrators module for this integration."""
        # The module is never used by name here; importing it is the whole
        # purpose of this method.
        import zenml.integrations.airflow.orchestrators  # noqa
activate() classmethod

Activates all classes required for the airflow integration.

Source code in zenml/integrations/airflow/__init__.py
@classmethod
def activate(cls):
    """Imports the airflow orchestrators module for this integration."""
    # The module is never used by name here; importing it is the whole
    # purpose of this method.
    import zenml.integrations.airflow.orchestrators  # noqa

orchestrators special

airflow_component

Definition for Airflow component for TFX.

AirflowComponent (PythonOperator)

Airflow-specific TFX Component. This class wraps a component run into its own PythonOperator in Airflow.

Source code in zenml/integrations/airflow/orchestrators/airflow_component.py
class AirflowComponent(python.PythonOperator):
    """Airflow-specific TFX Component.

    This class wraps a component run into its own PythonOperator in Airflow.
    """

    def __init__(
        self,
        *,
        parent_dag: airflow.DAG,
        pipeline_node: pipeline_pb2.PipelineNode,
        mlmd_connection: metadata.Metadata,
        pipeline_info: pipeline_pb2.PipelineInfo,
        pipeline_runtime_spec: pipeline_pb2.PipelineRuntimeSpec,
        executor_spec: Optional[message.Message] = None,
        custom_driver_spec: Optional[message.Message] = None
    ) -> None:
        """Constructs an Airflow implementation of TFX component.

        The task id of the resulting operator is the node id of
        `pipeline_node`, and the operator is attached to `parent_dag`.

        Args:
            parent_dag: The airflow DAG that this component is contained in.
            pipeline_node: The specification of the node to launch.
            mlmd_connection: ML metadata connection info.
            pipeline_info: The information of the pipeline that this node
                runs in.
            pipeline_runtime_spec: The runtime information of the pipeline
                that this node runs in.
            executor_spec: Specification for the executor of the node.
            custom_driver_spec: Specification for custom driver.
        """
        # Bind every node-specific argument now so the operator only needs a
        # ready-to-call launcher when the task executes.
        launcher_callable = functools.partial(
            _airflow_component_launcher,
            pipeline_node=pipeline_node,
            mlmd_connection=mlmd_connection,
            pipeline_info=pipeline_info,
            pipeline_runtime_spec=pipeline_runtime_spec,
            executor_spec=executor_spec,
            custom_driver_spec=custom_driver_spec,
        )

        # `provide_context` is intentionally not passed. It is deprecated as
        # of Airflow 2.0 (this integration pins apache-airflow 2.2.0), where
        # PythonOperator passes context automatically, so the flag would only
        # emit a DeprecationWarning.
        super().__init__(
            task_id=pipeline_node.node_info.id,
            python_callable=launcher_callable,
            dag=parent_dag,
        )
__init__(self, *, parent_dag, pipeline_node, mlmd_connection, pipeline_info, pipeline_runtime_spec, executor_spec=None, custom_driver_spec=None) special

Constructs an Airflow implementation of TFX component.

Parameters:

Name Type Description Default
parent_dag DAG

The airflow DAG that this component is contained in.

required
pipeline_node PipelineNode

The specification of the node to launch.

required
mlmd_connection Metadata

ML metadata connection info.

required
pipeline_info PipelineInfo

The information of the pipeline that this node runs in.

required
pipeline_runtime_spec PipelineRuntimeSpec

The runtime information of the pipeline that this node runs in.

required
executor_spec Optional[google.protobuf.message.Message]

Specification for the executor of the node.

None
custom_driver_spec Optional[google.protobuf.message.Message]

Specification for custom driver.

None
Source code in zenml/integrations/airflow/orchestrators/airflow_component.py
def __init__(
    self,
    *,
    parent_dag: airflow.DAG,
    pipeline_node: pipeline_pb2.PipelineNode,
    mlmd_connection: metadata.Metadata,
    pipeline_info: pipeline_pb2.PipelineInfo,
    pipeline_runtime_spec: pipeline_pb2.PipelineRuntimeSpec,
    executor_spec: Optional[message.Message] = None,
    custom_driver_spec: Optional[message.Message] = None
) -> None:
    """Constructs an Airflow implementation of TFX component.

    The task id of the resulting operator is the node id of
    `pipeline_node`, and the operator is attached to `parent_dag`.

    Args:
        parent_dag: The airflow DAG that this component is contained in.
        pipeline_node: The specification of the node to launch.
        mlmd_connection: ML metadata connection info.
        pipeline_info: The information of the pipeline that this node
            runs in.
        pipeline_runtime_spec: The runtime information of the pipeline
            that this node runs in.
        executor_spec: Specification for the executor of the node.
        custom_driver_spec: Specification for custom driver.
    """
    # Bind every node-specific argument now so the operator only needs a
    # ready-to-call launcher when the task executes.
    launcher_callable = functools.partial(
        _airflow_component_launcher,
        pipeline_node=pipeline_node,
        mlmd_connection=mlmd_connection,
        pipeline_info=pipeline_info,
        pipeline_runtime_spec=pipeline_runtime_spec,
        executor_spec=executor_spec,
        custom_driver_spec=custom_driver_spec,
    )

    # `provide_context` is intentionally not passed. It is deprecated as of
    # Airflow 2.0 (this integration pins apache-airflow 2.2.0), where
    # PythonOperator passes context automatically, so the flag would only
    # emit a DeprecationWarning.
    super().__init__(
        task_id=pipeline_node.node_info.id,
        python_callable=launcher_callable,
        dag=parent_dag,
    )
airflow_dag_runner

Definition of the Airflow TFX runner. This is an unmodified copy of the TFX source code (apart from superficial, stylistic changes).

AirflowDagRunner (TfxRunner)

Tfx runner on Airflow.

Source code in zenml/integrations/airflow/orchestrators/airflow_dag_runner.py
class AirflowDagRunner(tfx_runner.TfxRunner):
    """Tfx runner on Airflow.

    Compiles a logical TFX pipeline and turns every compiled node into an
    `AirflowComponent` task of a single Airflow DAG.
    """

    def __init__(
        self,
        config: Optional[Union[Dict[str, Any], AirflowPipelineConfig]] = None,
    ):
        """Creates an instance of AirflowDagRunner.

        Args:
            config: Optional Airflow pipeline config for customizing the
                launching of each component. A plain dict is still accepted
                (emitting a `PendingDeprecationWarning`) and is used as the
                `airflow_dag_config` of a new `AirflowPipelineConfig`. When
                omitted, an empty `AirflowPipelineConfig` is used.
        """
        if config is None:
            # `run()` casts `self._config` to `AirflowPipelineConfig` and
            # reads its `airflow_dag_config`. Default to an empty one so that
            # attribute always exists, instead of depending on what the base
            # runner does with `None`.
            config = AirflowPipelineConfig()
        elif isinstance(config, dict):
            warnings.warn(
                "Pass config as a dict type is going to deprecated in 0.1.16. "
                "Use AirflowPipelineConfig type instead.",
                PendingDeprecationWarning,
            )
            config = AirflowPipelineConfig(airflow_dag_config=config)
        super().__init__(config)

    def run(
        self, pipeline: tfx_pipeline.Pipeline, run_name: str = ""
    ) -> "airflow.DAG":
        """Deploys given logical pipeline on Airflow.

        Args:
            pipeline: Logical pipeline containing pipeline args and comps.
            run_name: Optional name for the run. It is substituted for the
                `pipeline-run-id` runtime parameter of the compiled pipeline.

        Returns:
            An Airflow DAG with one task per pipeline node, wired up
            according to the nodes' upstream dependencies.

        Raises:
            AssertionError: If a node references an upstream node that has
                not been created yet (nodes not in topological order).
        """
        # Only import these when needed.
        import airflow  # noqa

        from zenml.integrations.airflow.orchestrators import airflow_component

        # The user-supplied DAG config is expanded into the DAG kwargs. The
        # pause/catchup flags are always set here.
        # NOTE(review): a DAG config that itself contains `dag_id`,
        # `is_paused_upon_creation` or `catchup` passes the same keyword
        # twice and raises a TypeError.
        airflow_dag = airflow.DAG(
            dag_id=pipeline.pipeline_info.pipeline_name,
            **(
                typing.cast(
                    AirflowPipelineConfig, self._config
                ).airflow_dag_config
            ),
            is_paused_upon_creation=False,
            catchup=False,  # no backfill
        )
        # Default the temporary directory to `<pipeline_root>/.temp/`.
        if "tmp_dir" not in pipeline.additional_pipeline_args:
            tmp_dir = os.path.join(
                pipeline.pipeline_info.pipeline_root, ".temp", ""
            )
            pipeline.additional_pipeline_args["tmp_dir"] = tmp_dir

        for component in pipeline.components:
            if isinstance(component, base_component.BaseComponent):
                component._resolve_pip_dependencies(
                    pipeline.pipeline_info.pipeline_root
                )
            self._replace_runtime_params(component)

        c = compiler.Compiler()
        pipeline = c.compile(pipeline)

        # Substitute the runtime parameter to be a concrete run_id
        runtime_parameter_utils.substitute_runtime_parameter(
            pipeline,
            {
                "pipeline-run-id": run_name,
            },
        )
        deployment_config = runner_utils.extract_local_deployment_config(
            pipeline
        )
        connection_config = deployment_config.metadata_connection_config  # type: ignore[attr-defined] # noqa

        # Maps node id -> created Airflow task, so later nodes can be wired to
        # the tasks of their upstream nodes.
        component_impl_map = {}

        for node in pipeline.nodes:
            pipeline_node = node.pipeline_node
            node_id = pipeline_node.node_info.id
            executor_spec = runner_utils.extract_executor_spec(
                deployment_config, node_id
            )
            custom_driver_spec = runner_utils.extract_custom_driver_spec(
                deployment_config, node_id
            )

            current_airflow_component = airflow_component.AirflowComponent(
                parent_dag=airflow_dag,
                pipeline_node=pipeline_node,
                mlmd_connection=connection_config,
                pipeline_info=pipeline.pipeline_info,
                pipeline_runtime_spec=pipeline.runtime_spec,
                executor_spec=executor_spec,
                custom_driver_spec=custom_driver_spec,
            )
            component_impl_map[node_id] = current_airflow_component
            for upstream_node in node.pipeline_node.upstream_nodes:
                # Explicit raise instead of `assert` so the check still runs
                # under `python -O`. The exception type is unchanged.
                if upstream_node not in component_impl_map:
                    raise AssertionError(
                        "Components is not in topological order"
                    )
                current_airflow_component.set_upstream(
                    component_impl_map[upstream_node]
                )

        return airflow_dag

    def _replace_runtime_params(
        self, comp: base_node.BaseNode
    ) -> base_node.BaseNode:
        """Replaces runtime params for dynamic Airflow parameter execution.

        Each `RuntimeParameter` in `comp.exec_properties` is replaced, in
        place, by a Jinja template that reads the value from the DAG run conf
        and falls back to the parameter's default.

        Args:
            comp: TFX component to be parsed.

        Returns:
            Returns edited component.

        Raises:
            RuntimeError: If a runtime parameter's ptype is not `str`.
        """
        for k, prop in comp.exec_properties.copy().items():
            if isinstance(prop, RuntimeParameter):
                # Airflow only supports string parameters.
                if prop.ptype != str:
                    raise RuntimeError(
                        f"RuntimeParameter in Airflow does not support "
                        f"{prop.ptype}. The only ptype supported is string."
                    )

                # If the default is a template, drop the template markers
                # when inserting it into the .get() default argument below.
                # Otherwise, provide the default as a quoted string.
                # NOTE(review): assumes `prop.default` is a string; a `None`
                # default would fail on `.startswith` — confirm with callers.
                default = cast(str, prop.default)
                if default.startswith("{{") and default.endswith("}}"):
                    default = default[2:-2]
                else:
                    default = json.dumps(default)

                template_field = '{{ dag_run.conf.get("%s", %s) }}' % (
                    prop.name,
                    default,
                )
                comp.exec_properties[k] = template_field
        return comp
__init__(self, config=None) special

Creates an instance of AirflowDagRunner.

Parameters:

Name Type Description Default
config Union[Dict[str, Any], zenml.integrations.airflow.orchestrators.airflow_dag_runner.AirflowPipelineConfig]

Optional Airflow pipeline config for customizing the launching of each component.

None
Source code in zenml/integrations/airflow/orchestrators/airflow_dag_runner.py
def __init__(
    self,
    config: Optional[Union[Dict[str, Any], AirflowPipelineConfig]] = None,
):
    """Creates an instance of AirflowDagRunner.

    Args:
        config: Optional Airflow pipeline config for customizing the
            launching of each component. A plain dict is still accepted
            (emitting a `PendingDeprecationWarning`) and is used as the
            `airflow_dag_config` of a new `AirflowPipelineConfig`. When
            omitted, an empty `AirflowPipelineConfig` is used.
    """
    if config is None:
        # `run()` casts `self._config` to `AirflowPipelineConfig` and reads
        # its `airflow_dag_config`. Default to an empty one so that attribute
        # always exists, instead of depending on what the base runner does
        # with `None`.
        config = AirflowPipelineConfig()
    elif isinstance(config, dict):
        warnings.warn(
            "Pass config as a dict type is going to deprecated in 0.1.16. "
            "Use AirflowPipelineConfig type instead.",
            PendingDeprecationWarning,
        )
        config = AirflowPipelineConfig(airflow_dag_config=config)
    super().__init__(config)
run(self, pipeline, run_name='')

Deploys given logical pipeline on Airflow.

Parameters:

Name Type Description Default
pipeline Pipeline

Logical pipeline containing pipeline args and comps.

required
run_name str

Optional name for the run.

''

Returns:

Type Description
airflow.DAG

An Airflow DAG.

Source code in zenml/integrations/airflow/orchestrators/airflow_dag_runner.py
def run(
    self, pipeline: tfx_pipeline.Pipeline, run_name: str = ""
) -> "airflow.DAG":
    """Deploys given logical pipeline on Airflow.

    Args:
        pipeline: Logical pipeline containing pipeline args and comps.
        run_name: Optional name for the run. It is substituted for the
            `pipeline-run-id` runtime parameter of the compiled pipeline.

    Returns:
        An Airflow DAG with one task per pipeline node, wired up according
        to the nodes' upstream dependencies.

    Raises:
        AssertionError: If a node references an upstream node that has not
            been created yet (nodes not in topological order).
    """
    # Only import these when needed.
    import airflow  # noqa

    from zenml.integrations.airflow.orchestrators import airflow_component

    # The user-supplied DAG config is expanded into the DAG kwargs. The
    # pause/catchup flags are always set here.
    # NOTE(review): a DAG config that itself contains `dag_id`,
    # `is_paused_upon_creation` or `catchup` passes the same keyword twice
    # and raises a TypeError.
    airflow_dag = airflow.DAG(
        dag_id=pipeline.pipeline_info.pipeline_name,
        **(
            typing.cast(
                AirflowPipelineConfig, self._config
            ).airflow_dag_config
        ),
        is_paused_upon_creation=False,
        catchup=False,  # no backfill
    )
    # Default the temporary directory to `<pipeline_root>/.temp/`.
    if "tmp_dir" not in pipeline.additional_pipeline_args:
        tmp_dir = os.path.join(
            pipeline.pipeline_info.pipeline_root, ".temp", ""
        )
        pipeline.additional_pipeline_args["tmp_dir"] = tmp_dir

    for component in pipeline.components:
        if isinstance(component, base_component.BaseComponent):
            component._resolve_pip_dependencies(
                pipeline.pipeline_info.pipeline_root
            )
        self._replace_runtime_params(component)

    c = compiler.Compiler()
    pipeline = c.compile(pipeline)

    # Substitute the runtime parameter to be a concrete run_id
    runtime_parameter_utils.substitute_runtime_parameter(
        pipeline,
        {
            "pipeline-run-id": run_name,
        },
    )
    deployment_config = runner_utils.extract_local_deployment_config(
        pipeline
    )
    connection_config = deployment_config.metadata_connection_config  # type: ignore[attr-defined] # noqa

    # Maps node id -> created Airflow task, so later nodes can be wired to
    # the tasks of their upstream nodes.
    component_impl_map = {}

    for node in pipeline.nodes:
        pipeline_node = node.pipeline_node
        node_id = pipeline_node.node_info.id
        executor_spec = runner_utils.extract_executor_spec(
            deployment_config, node_id
        )
        custom_driver_spec = runner_utils.extract_custom_driver_spec(
            deployment_config, node_id
        )

        current_airflow_component = airflow_component.AirflowComponent(
            parent_dag=airflow_dag,
            pipeline_node=pipeline_node,
            mlmd_connection=connection_config,
            pipeline_info=pipeline.pipeline_info,
            pipeline_runtime_spec=pipeline.runtime_spec,
            executor_spec=executor_spec,
            custom_driver_spec=custom_driver_spec,
        )
        component_impl_map[node_id] = current_airflow_component
        for upstream_node in node.pipeline_node.upstream_nodes:
            # Explicit raise instead of `assert` so the check still runs
            # under `python -O`. The exception type is unchanged.
            if upstream_node not in component_impl_map:
                raise AssertionError("Components is not in topological order")
            current_airflow_component.set_upstream(
                component_impl_map[upstream_node]
            )

    return airflow_dag
AirflowPipelineConfig (PipelineConfig)

Pipeline config for AirflowDagRunner.

Source code in zenml/integrations/airflow/orchestrators/airflow_dag_runner.py
class AirflowPipelineConfig(pipeline_config.PipelineConfig):
    """Pipeline config for AirflowDagRunner.

    Adds `airflow_dag_config`, a dict of Airflow DAG settings, on top of the
    generic TFX pipeline config.
    """

    def __init__(
        self, airflow_dag_config: Optional[Dict[str, Any]] = None, **kwargs: Any
    ):
        """Creates an instance of AirflowPipelineConfig.

        Args:
            airflow_dag_config: Configs of Airflow DAG model. See
                https://airflow.apache.org/_api/airflow/models/dag/index.html#airflow.models.dag.DAG
                for the full spec. A missing or empty value is stored as an
                empty dict.
            **kwargs: keyword args for PipelineConfig.
        """
        dag_config = airflow_dag_config if airflow_dag_config else {}
        super().__init__(**kwargs)
        self.airflow_dag_config = dag_config
__init__(self, airflow_dag_config=None, **kwargs) special

Creates an instance of AirflowPipelineConfig.

Parameters:

Name Type Description Default
airflow_dag_config Optional[Dict[str, Any]]

Configs of Airflow DAG model. See https://airflow.apache.org/_api/airflow/models/dag/index.html#airflow.models.dag.DAG for the full spec.

None
**kwargs Any

keyword args for PipelineConfig.

{}
Source code in zenml/integrations/airflow/orchestrators/airflow_dag_runner.py
def __init__(
    self, airflow_dag_config: Optional[Dict[str, Any]] = None, **kwargs: Any
):
    """Creates an instance of AirflowPipelineConfig.

    Args:
        airflow_dag_config: Configs of Airflow DAG model. See
            https://airflow.apache.org/_api/airflow/models/dag/index.html#airflow.models.dag.DAG
            for the full spec. A missing or empty value is stored as an
            empty dict.
        **kwargs: keyword args for PipelineConfig.
    """
    dag_config = airflow_dag_config if airflow_dag_config else {}
    super().__init__(**kwargs)
    self.airflow_dag_config = dag_config
airflow_orchestrator
AirflowOrchestrator (BaseOrchestrator) pydantic-model

Orchestrator responsible for running pipelines using Airflow.

Source code in zenml/integrations/airflow/orchestrators/airflow_orchestrator.py
class AirflowOrchestrator(BaseOrchestrator):
    """Orchestrator responsible for running pipelines using Airflow."""

    # Derived from the orchestrator UUID by `set_airflow_home`.
    airflow_home: str = ""
    # NOTE(review): `run()` overwrites this field unconditionally, so a
    # user-supplied value is not used — confirm that is intended.
    airflow_config: Optional[Dict[str, Any]] = {}
    schedule_interval_minutes: int = 1

    def __init__(self, **values: Any):
        """Sets environment variables to configure airflow."""
        super().__init__(**values)
        self._set_env()

    @root_validator
    def set_airflow_home(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Sets airflow home according to orchestrator UUID.

        Raises:
            ValueError: If `uuid` is not among the validated values.
        """
        if "uuid" not in values:
            raise ValueError("`uuid` needs to exist for AirflowOrchestrator.")
        values["airflow_home"] = os.path.join(
            zenml.io.utils.get_global_config_directory(),
            AIRFLOW_ROOT_DIR,
            str(values["uuid"]),
        )
        return values

    @property
    def dags_directory(self) -> str:
        """Returns path to the airflow dags directory."""
        return os.path.join(self.airflow_home, "dags")

    @property
    def pid_file(self) -> str:
        """Returns path to the daemon PID file."""
        return os.path.join(self.airflow_home, "airflow_daemon.pid")

    @property
    def log_file(self) -> str:
        """Returns path to the airflow log file."""
        return os.path.join(self.airflow_home, "airflow_orchestrator.log")

    @property
    def password_file(self) -> str:
        """Returns path to the webserver password file."""
        return os.path.join(self.airflow_home, "standalone_admin_password.txt")

    @property
    def is_running(self) -> bool:
        """Returns whether the airflow daemon is currently running.

        Raises:
            RuntimeError: If the daemon is not running but local port 8080
                is already occupied.
        """
        from airflow.cli.commands.standalone_command import StandaloneCommand
        from airflow.jobs.triggerer_job import TriggererJob

        daemon_running = daemon.check_if_daemon_is_running(self.pid_file)

        command = StandaloneCommand()
        webserver_port_open = command.port_open(8080)

        if not daemon_running:
            if webserver_port_open:
                raise RuntimeError(
                    "The airflow daemon does not seem to be running but "
                    "local port 8080 is occupied. Make sure the port is "
                    "available and try again."
                )

            # exit early so we don't check non-existing airflow databases
            return False

        # we can't use StandaloneCommand().is_ready() here as the
        # Airflow SequentialExecutor apparently does not send a heartbeat
        # while running a task which would result in this returning `False`
        # even if Airflow is running.
        airflow_running = webserver_port_open and command.job_running(
            TriggererJob
        )
        return airflow_running

    def _set_env(self) -> None:
        """Sets environment variables to configure airflow."""
        os.environ["AIRFLOW_HOME"] = self.airflow_home
        os.environ["AIRFLOW__CORE__DAGS_FOLDER"] = self.dags_directory
        os.environ["AIRFLOW__CORE__DAG_DISCOVERY_SAFE_MODE"] = "false"
        os.environ["AIRFLOW__CORE__LOAD_EXAMPLES"] = "false"
        # check the DAG folder every 10 seconds for new files
        os.environ["AIRFLOW__SCHEDULER__DAG_DIR_LIST_INTERVAL"] = "10"

    def _copy_to_dag_directory_if_necessary(self, dag_filepath: str):
        """Copies the DAG module to the airflow DAGs directory if it's not
        already located there.

        Both directories are resolved the same way before being compared.
        A relative or non-normalized `dag_filepath` that already points into
        the DAGs directory is therefore recognized, instead of being copied
        onto itself.

        Args:
            dag_filepath: Path to the file in which the DAG is defined.
        """
        dags_directory = fileio.resolve_relative_path(self.dags_directory)
        # `abspath` first so a bare filename (empty dirname) maps to the
        # current working directory.
        dag_file_directory = fileio.resolve_relative_path(
            os.path.dirname(os.path.abspath(dag_filepath))
        )

        if dags_directory == dag_file_directory:
            logger.debug("File is already in airflow DAGs directory.")
        else:
            logger.debug(
                "Copying dag file '%s' to DAGs directory.", dag_filepath
            )
            destination_path = os.path.join(
                dags_directory, os.path.basename(dag_filepath)
            )
            if fileio.file_exists(destination_path):
                logger.info(
                    "File '%s' already exists, overwriting with new DAG file",
                    destination_path,
                )
            fileio.copy(dag_filepath, destination_path, overwrite=True)

    def _log_webserver_credentials(self):
        """Logs URL and credentials to login to the airflow webserver.

        Raises:
            FileNotFoundError: If the password file does not exist.
        """
        if fileio.file_exists(self.password_file):
            with open(self.password_file) as file:
                password = file.read().strip()
        else:
            raise FileNotFoundError(
                f"Can't find password file '{self.password_file}'"
            )
        logger.info(
            "To inspect your DAGs, login to http://0.0.0.0:8080 "
            "with username: admin password: %s",
            password,
        )

    def pre_run(self, pipeline: "BasePipeline", caller_filepath: str) -> None:
        """Checks whether airflow is running and copies the DAG file to the
        airflow DAGs directory.

        Args:
            pipeline: Pipeline that will be run.
            caller_filepath: Path to the file in which `pipeline.run()` was
                called. This contains the airflow DAG that is returned by
                the `run()` method.

        Raises:
            RuntimeError: If airflow is not running.
        """
        if not self.is_running:
            raise RuntimeError(
                "Airflow orchestrator is currently not running. "
                "Run `zenml orchestrator up` to start the "
                "orchestrator of the active stack."
            )

        self._copy_to_dag_directory_if_necessary(dag_filepath=caller_filepath)

    def up(self) -> None:
        """Ensures that Airflow is running.

        Starts `StandaloneCommand` as a daemon if needed, blocks until
        `is_running` reports true, and logs the webserver credentials.
        """
        if self.is_running:
            logger.info("Airflow is already running.")
            self._log_webserver_credentials()
            return

        if not fileio.file_exists(self.dags_directory):
            fileio.create_dir_recursive_if_not_exists(self.dags_directory)

        from airflow.cli.commands.standalone_command import StandaloneCommand

        command = StandaloneCommand()
        # Run the daemon with a working directory inside the current
        # zenml repo so the same repo will be used to run the DAGs
        daemon.run_as_daemon(
            command.run,
            pid_file=self.pid_file,
            log_file=self.log_file,
            working_directory=zenml.io.utils.get_zenml_dir(),
        )

        # NOTE(review): this loop has no timeout; if the daemon never becomes
        # ready, `up()` blocks indefinitely.
        while not self.is_running:
            # Wait until the daemon started all the relevant airflow processes
            time.sleep(0.1)

        self._log_webserver_credentials()

    def down(self) -> None:
        """Stops the airflow daemon if necessary and tears down resources."""
        if self.is_running:
            daemon.stop_daemon(self.pid_file, kill_children=True)

        fileio.rm_dir(self.airflow_home)
        logger.info("Airflow spun down.")

    def run(
        self,
        zenml_pipeline: "BasePipeline",
        run_name: str,
        **kwargs: Any,
    ) -> "airflow.DAG":
        """Prepares the pipeline so it can be run in Airflow.

        Args:
            zenml_pipeline: The pipeline to run.
            run_name: Name of the pipeline run.
            **kwargs: Unused argument to conform with base class signature.

        Returns:
            The Airflow DAG built by `AirflowDagRunner` for the pipeline,
            scheduled every `schedule_interval_minutes` minutes.
        """
        self.airflow_config = {
            "schedule_interval": datetime.timedelta(
                minutes=self.schedule_interval_minutes
            ),
            # We set this in the past and turn catchup off and then it works
            "start_date": datetime.datetime(2019, 1, 1),
        }

        runner = AirflowDagRunner(AirflowPipelineConfig(self.airflow_config))
        tfx_pipeline = create_tfx_pipeline(zenml_pipeline)
        return runner.run(tfx_pipeline, run_name=run_name)
dags_directory: str property readonly

Returns path to the airflow dags directory.

is_running: bool property readonly

Returns whether the airflow daemon is currently running.

log_file: str property readonly

Returns path to the airflow log file.

password_file: str property readonly

Returns path to the webserver password file.

pid_file: str property readonly

Returns path to the daemon PID file.

__init__(self, **values) special

Sets environment variables to configure airflow.

Source code in zenml/integrations/airflow/orchestrators/airflow_orchestrator.py
def __init__(self, **values: Any):
    """Initializes the orchestrator and exports the airflow configuration.

    Once the base model has been initialized with `values`, the airflow
    environment variables are written via `_set_env`.
    """
    super().__init__(**values)
    # Environment setup needs the initialized fields (e.g. `airflow_home`).
    self._set_env()
down(self)

Stops the airflow daemon if necessary and tears down resources.

Source code in zenml/integrations/airflow/orchestrators/airflow_orchestrator.py
def down(self) -> None:
    """Tears down the local Airflow deployment.

    Stops the daemon (including its child processes) when it is running,
    then removes the orchestrator's airflow home directory.
    """
    daemon_is_up = self.is_running
    if daemon_is_up:
        daemon.stop_daemon(self.pid_file, kill_children=True)
    fileio.rm_dir(self.airflow_home)
    logger.info("Airflow spun down.")
pre_run(self, pipeline, caller_filepath)

Checks whether airflow is running and copies the DAG file to the airflow DAGs directory.

Parameters:

Name Type Description Default
pipeline BasePipeline

Pipeline that will be run.

required
caller_filepath str

Path to the file in which pipeline.run() was called. This contains the airflow DAG that is returned by the run() method.

required

Exceptions:

Type Description
RuntimeError

If airflow is not running.

Source code in zenml/integrations/airflow/orchestrators/airflow_orchestrator.py
def pre_run(self, pipeline: "BasePipeline", caller_filepath: str) -> None:
    """Verifies Airflow is up, then stages the DAG file for the scheduler.

    Args:
        pipeline: Pipeline that will be run.
        caller_filepath: Path to the file in which `pipeline.run()` was
            called. This contains the airflow DAG that is returned by
            the `run()` method.

    Raises:
        RuntimeError: If airflow is not running.
    """
    airflow_is_up = self.is_running
    if airflow_is_up:
        self._copy_to_dag_directory_if_necessary(dag_filepath=caller_filepath)
        return
    raise RuntimeError(
        "Airflow orchestrator is currently not running. "
        "Run `zenml orchestrator up` to start the "
        "orchestrator of the active stack."
    )
run(self, zenml_pipeline, run_name, **kwargs)

Prepares the pipeline so it can be run in Airflow.

Parameters:

Name Type Description Default
zenml_pipeline BasePipeline

The pipeline to run.

required
run_name str

Name of the pipeline run.

required
**kwargs Any

Unused argument to conform with base class signature.

{}
Source code in zenml/integrations/airflow/orchestrators/airflow_orchestrator.py
def run(
    self,
    zenml_pipeline: "BasePipeline",
    run_name: str,
    **kwargs: Any,
) -> "airflow.DAG":
    """Builds the Airflow DAG for a ZenML pipeline.

    Args:
        zenml_pipeline: The pipeline to run.
        run_name: Name of the pipeline run.
        **kwargs: Unused argument to conform with base class signature.

    Returns:
        The Airflow DAG produced by `AirflowDagRunner`.
    """
    interval = datetime.timedelta(minutes=self.schedule_interval_minutes)
    # A start date in the past combined with catchup disabled (done by the
    # DAG runner) makes the DAG schedulable right away.
    start = datetime.datetime(2019, 1, 1)
    self.airflow_config = {"schedule_interval": interval, "start_date": start}

    dag_runner = AirflowDagRunner(AirflowPipelineConfig(self.airflow_config))
    return dag_runner.run(create_tfx_pipeline(zenml_pipeline), run_name=run_name)
set_airflow_home(values) classmethod

Sets airflow home according to orchestrator UUID.

Source code in zenml/integrations/airflow/orchestrators/airflow_orchestrator.py
@root_validator
def set_airflow_home(cls, values: Dict[str, Any]) -> Dict[str, Any]:
    """Derives `airflow_home` from the orchestrator UUID.

    The directory is `<global config dir>/<AIRFLOW_ROOT_DIR>/<uuid>`.

    Raises:
        ValueError: If `uuid` is missing from `values`.
    """
    if "uuid" not in values:
        raise ValueError("`uuid` needs to exist for AirflowOrchestrator.")
    config_root = zenml.io.utils.get_global_config_directory()
    uuid_part = str(values["uuid"])
    values["airflow_home"] = os.path.join(config_root, AIRFLOW_ROOT_DIR, uuid_part)
    return values
up(self)

Ensures that Airflow is running.

Source code in zenml/integrations/airflow/orchestrators/airflow_orchestrator.py
def up(self) -> None:
    """Starts the Airflow daemon unless it is already running.

    In both cases the webserver login credentials are logged at the end.
    """
    if self.is_running:
        logger.info("Airflow is already running.")
    else:
        if not fileio.file_exists(self.dags_directory):
            fileio.create_dir_recursive_if_not_exists(self.dags_directory)

        from airflow.cli.commands.standalone_command import StandaloneCommand

        # The daemon's working directory is the current zenml repo, so the
        # DAGs are executed against that same repo.
        daemon.run_as_daemon(
            StandaloneCommand().run,
            pid_file=self.pid_file,
            log_file=self.log_file,
            working_directory=zenml.io.utils.get_zenml_dir(),
        )

        # Poll until the daemon has brought up the relevant airflow
        # processes.
        # NOTE(review): there is no timeout on this wait.
        while not self.is_running:
            time.sleep(0.1)

    self._log_webserver_credentials()

dash special

DashIntegration (Integration)

Definition of Dash integration for ZenML.

Source code in zenml/integrations/dash/__init__.py
class DashIntegration(Integration):
    """ZenML integration definition for Dash-based visualizations."""

    NAME = DASH
    # Pip requirement specifiers (with minimum versions) for this integration.
    REQUIREMENTS = [
        "dash>=2.0.0", "dash-cytoscape>=0.3.0", "dash-bootstrap-components>=1.0.1"
    ]

visualizers special

pipeline_run_lineage_visualizer
PipelineRunLineageVisualizer (BasePipelineRunVisualizer)

Implementation of a lineage diagram via the dash and dash-cytoscape libraries.

Source code in zenml/integrations/dash/visualizers/pipeline_run_lineage_visualizer.py
class PipelineRunLineageVisualizer(BasePipelineRunVisualizer):
    """Implementation of a lineage diagram via the [dash](
    https://plotly.com/dash/) and [dash-cytoscape](
    https://dash.plotly.com/cytoscape) libraries.

    Each step of the run and each of its output artifacts becomes a graph
    node. Edges point from a step to the artifacts it outputs and from an
    artifact to the steps that consume it. Node and edge colors are chosen
    via `STATUS_CLASS_MAPPING`.
    """

    # Prefixes distinguish step node ids from artifact node ids in the graph.
    ARTIFACT_PREFIX = "artifact_"
    STEP_PREFIX = "step_"
    # Maps an execution status to the Cytoscape class name assigned to the
    # corresponding nodes/edges (styled through STYLESHEET).
    STATUS_CLASS_MAPPING = {
        ExecutionStatus.CACHED: "green",
        ExecutionStatus.FAILED: "red",
        ExecutionStatus.RUNNING: "yellow",
        ExecutionStatus.COMPLETED: "blue",
    }

    def visualize(
        self, object: PipelineRunView, *args: Any, **kwargs: Any
    ) -> dash.Dash:
        """Method to visualize pipeline runs via the Dash library. The layout
        puts every layer of the dag in a column.

        Args:
            object: The pipeline run to visualize.
            *args: Unused.
            **kwargs: Unused.

        Returns:
            The Dash app, after `app.run_server()` has been called on it.
        """

        app = dash.Dash(
            __name__,
            external_stylesheets=[
                dbc.themes.BOOTSTRAP,
                dbc.icons.BOOTSTRAP,
            ],
        )
        nodes: List[Dict[str, Any]] = []
        edges: List[Dict[str, Any]] = []
        # Node id of the first step; used as root of the breadthfirst layout.
        first_step_id = None
        for step in object.steps:
            status_class = self.STATUS_CLASS_MAPPING[step.status]
            # Use the producer step id of the first output artifact, falling
            # back to the step's own id when the step has no outputs.
            step_output_artifacts = list(step.outputs.values())
            execution_id = (
                step_output_artifacts[0].producer_step.id
                if step_output_artifacts
                else step.id
            )
            step_id = self.STEP_PREFIX + str(step.id)
            if first_step_id is None:
                first_step_id = step_id
            nodes.append(
                {
                    "data": {
                        "id": step_id,
                        "execution_id": execution_id,
                        "label": f"{execution_id} / {step.name}",
                        "name": step.name,  # redundant for consistency
                        "type": "step",
                        "parameters": step.parameters,
                        "inputs": {k: v.uri for k, v in step.inputs.items()},
                        "outputs": {k: v.uri for k, v in step.outputs.items()},
                    },
                    "classes": status_class,
                }
            )

            for artifact_name, artifact in step.outputs.items():
                artifact_id = self.ARTIFACT_PREFIX + str(artifact.id)
                nodes.append(
                    {
                        "data": {
                            "id": artifact_id,
                            "execution_id": artifact.id,
                            "label": f"{artifact.id} / {artifact_name} "
                            f"({artifact.data_type})",
                            "type": "artifact",
                            "name": artifact_name,
                            "is_cached": artifact.is_cached,
                            "artifact_type": artifact.type,
                            "artifact_data_type": artifact.data_type,
                            "parent_step_id": artifact.parent_step_id,
                            "producer_step_id": artifact.producer_step.id,
                            "uri": artifact.uri,
                        },
                        "classes": f"rectangle {status_class}",
                    }
                )
                # Cached outputs get a dashed edge, all others a solid one.
                edges.append(
                    {
                        "data": {"source": step_id, "target": artifact_id},
                        "classes": f"edge-arrow {status_class}"
                        + (" dashed" if artifact.is_cached else " solid"),
                    }
                )

            for artifact in step.inputs.values():
                # Cached inputs use the CACHED color and a dashed line,
                # independent of the consuming step's status.
                input_class = (
                    f"{self.STATUS_CLASS_MAPPING[ExecutionStatus.CACHED]} dashed"
                    if artifact.is_cached
                    else f"{status_class} solid"
                )
                edges.append(
                    {
                        "data": {
                            "source": self.ARTIFACT_PREFIX + str(artifact.id),
                            "target": step_id,
                        },
                        "classes": f"edge-arrow {input_class}",
                    }
                )

        app.layout = dbc.Row(
            [
                dbc.Container(f"Run: {object.name}", class_name="h1"),
                dbc.Row(
                    [
                        dbc.Col(
                            [
                                dbc.Row(
                                    [
                                        html.Span(
                                            [
                                                html.Span(
                                                    [
                                                        html.I(
                                                            className="bi bi-circle-fill me-1"
                                                        ),
                                                        "Step",
                                                    ],
                                                    className="me-2",
                                                ),
                                                html.Span(
                                                    [
                                                        html.I(
                                                            className="bi bi-square-fill me-1"
                                                        ),
                                                        "Artifact",
                                                    ],
                                                    className="me-4",
                                                ),
                                                dbc.Badge(
                                                    "Completed",
                                                    color=COLOR_BLUE,
                                                    className="me-1",
                                                ),
                                                dbc.Badge(
                                                    "Cached",
                                                    color=COLOR_GREEN,
                                                    className="me-1",
                                                ),
                                                dbc.Badge(
                                                    "Running",
                                                    color=COLOR_YELLOW,
                                                    className="me-1",
                                                ),
                                                dbc.Badge(
                                                    "Failed",
                                                    color=COLOR_RED,
                                                    className="me-1",
                                                ),
                                            ]
                                        ),
                                    ]
                                ),
                                dbc.Row(
                                    [
                                        cyto.Cytoscape(
                                            id="cytoscape",
                                            layout={
                                                "name": "breadthfirst",
                                                "roots": f'[id = "{first_step_id}"]',
                                            },
                                            elements=edges + nodes,
                                            stylesheet=STYLESHEET,
                                            style={
                                                "width": "100%",
                                                "height": "800px",
                                            },
                                            zoom=1,
                                        )
                                    ]
                                ),
                                dbc.Row(
                                    [
                                        dbc.Button(
                                            "Reset",
                                            id="bt-reset",
                                            color="primary",
                                            className="me-1",
                                        )
                                    ]
                                ),
                            ]
                        ),
                        dbc.Col(
                            [
                                dcc.Markdown(id="markdown-selected-node-data"),
                            ]
                        ),
                    ]
                ),
            ],
            className="p-5",
        )

        @app.callback(  # type: ignore[misc]
            Output("markdown-selected-node-data", "children"),
            Input("cytoscape", "selectedNodeData"),
        )
        def display_data(data_list: List[Dict[str, Any]]) -> str:
            """Callback for the markdown panel next to the graph.

            Renders the data of all currently selected nodes as markdown.
            """
            # `selectedNodeData` can be None (no interaction yet) or an empty
            # list (everything deselected); show the hint in both cases.
            if not data_list:
                return "Click on a node in the diagram."

            parts: List[str] = []
            for data in data_list:
                parts.append(f'## {data["execution_id"]} / {data["name"]}\n\n')
                if data["type"] == "artifact":
                    for item in [
                        "artifact_data_type",
                        "is_cached",
                        "producer_step_id",
                        "parent_step_id",
                        "uri",
                    ]:
                        parts.append(f"**{item}**: {data[item]}\n\n")
                elif data["type"] == "step":
                    for heading, key in (
                        ("Inputs", "inputs"),
                        ("Outputs", "outputs"),
                        ("Params", "parameters"),
                    ):
                        # The blank line after each heading keeps the first
                        # entry from being rendered as part of the heading.
                        parts.append(f"### {heading}:\n\n")
                        for k, v in data[key].items():
                            parts.append(f"**{k}**: {v}\n\n")
            return "".join(parts)

        @app.callback(  # type: ignore[misc]
            [Output("cytoscape", "zoom"), Output("cytoscape", "elements")],
            [Input("bt-reset", "n_clicks")],
        )
        def reset_layout(
            n_clicks: int,
        ) -> List[Union[int, List[Dict[str, Collection[str]]]]]:
            """Resets zoom and elements of the graph to their initial state."""
            # n_clicks must be a %-argument, not the message itself, or the
            # logging module reports a string-formatting error.
            logger.debug("Reset button clicked (n_clicks=%s).", n_clicks)
            return [1, edges + nodes]

        app.run_server()
        return app
visualize(self, object, *args, **kwargs)

Method to visualize pipeline runs via the Dash library. The layout puts every layer of the dag in a column.

Source code in zenml/integrations/dash/visualizers/pipeline_run_lineage_visualizer.py
def visualize(
    self, object: PipelineRunView, *args: Any, **kwargs: Any
) -> dash.Dash:
    """Method to visualize pipeline runs via the Dash library. The layout
    puts every layer of the dag in a column.

    Args:
        object: The pipeline run to visualize.
        *args: Unused.
        **kwargs: Unused.

    Returns:
        The Dash app, after `app.run_server()` has been called on it.
    """

    app = dash.Dash(
        __name__,
        external_stylesheets=[
            dbc.themes.BOOTSTRAP,
            dbc.icons.BOOTSTRAP,
        ],
    )
    nodes: List[Dict[str, Any]] = []
    edges: List[Dict[str, Any]] = []
    # Node id of the first step; used as root of the breadthfirst layout.
    first_step_id = None
    for step in object.steps:
        status_class = self.STATUS_CLASS_MAPPING[step.status]
        # Use the producer step id of the first output artifact, falling
        # back to the step's own id when the step has no outputs.
        step_output_artifacts = list(step.outputs.values())
        execution_id = (
            step_output_artifacts[0].producer_step.id
            if step_output_artifacts
            else step.id
        )
        step_id = self.STEP_PREFIX + str(step.id)
        if first_step_id is None:
            first_step_id = step_id
        nodes.append(
            {
                "data": {
                    "id": step_id,
                    "execution_id": execution_id,
                    "label": f"{execution_id} / {step.name}",
                    "name": step.name,  # redundant for consistency
                    "type": "step",
                    "parameters": step.parameters,
                    "inputs": {k: v.uri for k, v in step.inputs.items()},
                    "outputs": {k: v.uri for k, v in step.outputs.items()},
                },
                "classes": status_class,
            }
        )

        for artifact_name, artifact in step.outputs.items():
            artifact_id = self.ARTIFACT_PREFIX + str(artifact.id)
            nodes.append(
                {
                    "data": {
                        "id": artifact_id,
                        "execution_id": artifact.id,
                        "label": f"{artifact.id} / {artifact_name} "
                        f"({artifact.data_type})",
                        "type": "artifact",
                        "name": artifact_name,
                        "is_cached": artifact.is_cached,
                        "artifact_type": artifact.type,
                        "artifact_data_type": artifact.data_type,
                        "parent_step_id": artifact.parent_step_id,
                        "producer_step_id": artifact.producer_step.id,
                        "uri": artifact.uri,
                    },
                    "classes": f"rectangle {status_class}",
                }
            )
            # Cached outputs get a dashed edge, all others a solid one.
            edges.append(
                {
                    "data": {"source": step_id, "target": artifact_id},
                    "classes": f"edge-arrow {status_class}"
                    + (" dashed" if artifact.is_cached else " solid"),
                }
            )

        for artifact in step.inputs.values():
            # Cached inputs use the CACHED color and a dashed line,
            # independent of the consuming step's status.
            input_class = (
                f"{self.STATUS_CLASS_MAPPING[ExecutionStatus.CACHED]} dashed"
                if artifact.is_cached
                else f"{status_class} solid"
            )
            edges.append(
                {
                    "data": {
                        "source": self.ARTIFACT_PREFIX + str(artifact.id),
                        "target": step_id,
                    },
                    "classes": f"edge-arrow {input_class}",
                }
            )

    app.layout = dbc.Row(
        [
            dbc.Container(f"Run: {object.name}", class_name="h1"),
            dbc.Row(
                [
                    dbc.Col(
                        [
                            dbc.Row(
                                [
                                    html.Span(
                                        [
                                            html.Span(
                                                [
                                                    html.I(
                                                        className="bi bi-circle-fill me-1"
                                                    ),
                                                    "Step",
                                                ],
                                                className="me-2",
                                            ),
                                            html.Span(
                                                [
                                                    html.I(
                                                        className="bi bi-square-fill me-1"
                                                    ),
                                                    "Artifact",
                                                ],
                                                className="me-4",
                                            ),
                                            dbc.Badge(
                                                "Completed",
                                                color=COLOR_BLUE,
                                                className="me-1",
                                            ),
                                            dbc.Badge(
                                                "Cached",
                                                color=COLOR_GREEN,
                                                className="me-1",
                                            ),
                                            dbc.Badge(
                                                "Running",
                                                color=COLOR_YELLOW,
                                                className="me-1",
                                            ),
                                            dbc.Badge(
                                                "Failed",
                                                color=COLOR_RED,
                                                className="me-1",
                                            ),
                                        ]
                                    ),
                                ]
                            ),
                            dbc.Row(
                                [
                                    cyto.Cytoscape(
                                        id="cytoscape",
                                        layout={
                                            "name": "breadthfirst",
                                            "roots": f'[id = "{first_step_id}"]',
                                        },
                                        elements=edges + nodes,
                                        stylesheet=STYLESHEET,
                                        style={
                                            "width": "100%",
                                            "height": "800px",
                                        },
                                        zoom=1,
                                    )
                                ]
                            ),
                            dbc.Row(
                                [
                                    dbc.Button(
                                        "Reset",
                                        id="bt-reset",
                                        color="primary",
                                        className="me-1",
                                    )
                                ]
                            ),
                        ]
                    ),
                    dbc.Col(
                        [
                            dcc.Markdown(id="markdown-selected-node-data"),
                        ]
                    ),
                ]
            ),
        ],
        className="p-5",
    )

    @app.callback(  # type: ignore[misc]
        Output("markdown-selected-node-data", "children"),
        Input("cytoscape", "selectedNodeData"),
    )
    def display_data(data_list: List[Dict[str, Any]]) -> str:
        """Callback for the markdown panel next to the graph.

        Renders the data of all currently selected nodes as markdown.
        """
        # `selectedNodeData` can be None (no interaction yet) or an empty
        # list (everything deselected); show the hint in both cases.
        if not data_list:
            return "Click on a node in the diagram."

        parts: List[str] = []
        for data in data_list:
            parts.append(f'## {data["execution_id"]} / {data["name"]}\n\n')
            if data["type"] == "artifact":
                for item in [
                    "artifact_data_type",
                    "is_cached",
                    "producer_step_id",
                    "parent_step_id",
                    "uri",
                ]:
                    parts.append(f"**{item}**: {data[item]}\n\n")
            elif data["type"] == "step":
                for heading, key in (
                    ("Inputs", "inputs"),
                    ("Outputs", "outputs"),
                    ("Params", "parameters"),
                ):
                    # The blank line after each heading keeps the first
                    # entry from being rendered as part of the heading.
                    parts.append(f"### {heading}:\n\n")
                    for k, v in data[key].items():
                        parts.append(f"**{k}**: {v}\n\n")
        return "".join(parts)

    @app.callback(  # type: ignore[misc]
        [Output("cytoscape", "zoom"), Output("cytoscape", "elements")],
        [Input("bt-reset", "n_clicks")],
    )
    def reset_layout(
        n_clicks: int,
    ) -> List[Union[int, List[Dict[str, Collection[str]]]]]:
        """Resets zoom and elements of the graph to their initial state."""
        # n_clicks must be a %-argument, not the message itself, or the
        # logging module reports a string-formatting error.
        logger.debug("Reset button clicked (n_clicks=%s).", n_clicks)
        return [1, edges + nodes]

    app.run_server()
    return app

facets special

The Facets integration provides a simple way to visualize post-execution objects like PipelineView, PipelineRunView and StepView. These objects can be extended using the BaseVisualization class. This integration requires facets-overview to be installed in your Python environment.

FacetsIntegration (Integration)

Definition of Facet integration for ZenML.

Source code in zenml/integrations/facets/__init__.py
class FacetsIntegration(Integration):
    """ZenML integration for [Facets](https://pair-code.github.io/facets/)
    statistics visualizations."""

    NAME = FACETS
    # facets-overview provides the statistics generator; IPython provides the
    # display helpers used to render the result inside a notebook.
    REQUIREMENTS = ["facets-overview>=1.0.0", "IPython"]

visualizers special

facet_statistics_visualizer
FacetStatisticsVisualizer (BaseStepVisualizer)

Visualizer that renders Facets Overview statistics for the pandas DataFrame outputs of a step.

Source code in zenml/integrations/facets/visualizers/facet_statistics_visualizer.py
class FacetStatisticsVisualizer(BaseStepVisualizer):
    """Visualizes statistics of a step's pandas DataFrame outputs with
    [Facets Overview](https://pair-code.github.io/facets/)."""

    def visualize(
        self, object: StepView, magic: bool = False, *args: Any, **kwargs: Any
    ) -> None:
        """Method to visualize components

        Args:
            object: StepView fetched from run.get_step().
            magic: Whether to render in a Jupyter notebook or not.
            *args: Unused.
            **kwargs: Unused.
        """
        # Only DataFrame outputs (including subclasses) can be summarized;
        # every other output is skipped with a warning.
        datasets = []
        for output_name, artifact_view in object.outputs.items():
            df = artifact_view.read()
            if isinstance(df, pd.DataFrame):
                datasets.append({"name": output_name, "table": df})
            else:
                logger.warning(
                    "`%s` is not a pd.DataFrame. You can only visualize "
                    "statistics of steps that output pandas dataframes. "
                    "Skipping this output..",
                    output_name,
                )
        h = self.generate_html(datasets)
        self.generate_facet(h, magic)

    def generate_html(self, datasets: List[Dict[Text, pd.DataFrame]]) -> str:
        """Generates html for facet.

        Args:
            datasets: List of dicts of dataframes to be visualized as stats.
                Each dict has a `name` and a `table` entry.

        Returns:
            HTML template with proto string embedded.
        """
        proto = GenericFeatureStatisticsGenerator().ProtoFromDataFrames(
            datasets
        )
        protostr = base64.b64encode(proto.SerializeToString()).decode("utf-8")

        # `stats.html` sits next to this module and contains the literal
        # placeholder `protostr`, which is replaced with the encoded proto.
        template = os.path.join(
            os.path.abspath(os.path.dirname(__file__)),
            "stats.html",
        )
        html_template = zenml.io.utils.read_file_contents_as_string(template)

        html_ = html_template.replace("protostr", protostr)
        return html_

    def generate_facet(self, html_: str, magic: bool = False) -> None:
        """Generate a Facet Overview

        Args:
            html_: HTML represented as a string.
            magic: Whether to magically materialize facet in a notebook.

        Raises:
            EnvironmentError: If `magic` is True but `ipykernel` has not been
                imported, i.e. we are not running inside a Jupyter notebook.
        """
        if magic:
            if "ipykernel" not in sys.modules:
                raise EnvironmentError(
                    "The magic functions are only usable in a Jupyter notebook."
                )
            display(HTML(html_))
        else:
            from pathlib import Path

            # Only reserve a unique file name here and close the handle
            # before writing: re-opening a still-open NamedTemporaryFile by
            # name fails on Windows. `delete=False` keeps the file around so
            # the browser can load it.
            with tempfile.NamedTemporaryFile(delete=False, suffix=".html") as f:
                path = f.name
            zenml.io.utils.write_file_contents_as_string(path, html_)
            # `as_uri` builds a well-formed file URL on every platform.
            url = Path(path).absolute().as_uri()
            logger.info("Opening %s in a new browser..", path)
            webbrowser.open(url, new=2)
generate_facet(self, html_, magic=False)

Generate a Facet Overview

Parameters:

Name Type Description Default
html_ str

HTML represented as a string.

required
magic bool

Whether to magically materialize facet in a notebook.

False
Source code in zenml/integrations/facets/visualizers/facet_statistics_visualizer.py
def generate_facet(self, html_: str, magic: bool = False) -> None:
    """Generate a Facet Overview

    Args:
        html_: HTML represented as a string.
        magic: Whether to magically materialize facet in a notebook.

    Raises:
        EnvironmentError: If `magic` is True but `ipykernel` has not been
            imported, i.e. we are not running inside a Jupyter notebook.
    """
    if magic:
        if "ipykernel" not in sys.modules:
            raise EnvironmentError(
                "The magic functions are only usable in a Jupyter notebook."
            )
        display(HTML(html_))
    else:
        from pathlib import Path

        # Only reserve a unique file name here and close the handle before
        # writing: re-opening a still-open NamedTemporaryFile by name fails
        # on Windows. `delete=False` keeps the file around for the browser.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".html") as f:
            path = f.name
        zenml.io.utils.write_file_contents_as_string(path, html_)
        # `as_uri` builds a well-formed file URL on every platform.
        url = Path(path).absolute().as_uri()
        logger.info("Opening %s in a new browser..", path)
        webbrowser.open(url, new=2)
generate_html(self, datasets)

Generates html for facet.

Parameters:

Name Type Description Default
datasets List[Dict[str, pandas.core.frame.DataFrame]]

List of dicts of dataframes to be visualized as stats.

required

Returns:

Type Description
str

HTML template with proto string embedded.

Source code in zenml/integrations/facets/visualizers/facet_statistics_visualizer.py
def generate_html(self, datasets: List[Dict[Text, pd.DataFrame]]) -> str:
    """Render the Facets Overview page for the given datasets.

    Args:
        datasets: One dict per dataset, holding a `name` and the `table`
            (DataFrame) whose statistics should be shown.

    Returns:
        The contents of `stats.html` with every occurrence of the literal
        placeholder `protostr` replaced by the base64-encoded statistics
        proto.
    """
    generator = GenericFeatureStatisticsGenerator()
    serialized = generator.ProtoFromDataFrames(datasets).SerializeToString()
    encoded_proto = base64.b64encode(serialized).decode("utf-8")

    # The template lives next to this module.
    module_dir = os.path.abspath(os.path.dirname(__file__))
    template_path = os.path.join(module_dir, "stats.html")
    template = zenml.io.utils.read_file_contents_as_string(template_path)

    return template.replace("protostr", encoded_proto)
visualize(self, object, magic=False, *args, **kwargs)

Method to visualize components

Parameters:

Name Type Description Default
object StepView

StepView fetched from run.get_step().

required
magic bool

Whether to render in a Jupyter notebook or not.

False
Source code in zenml/integrations/facets/visualizers/facet_statistics_visualizer.py
def visualize(
    self, object: StepView, magic: bool = False, *args: Any, **kwargs: Any
) -> None:
    """Method to visualize components

    Args:
        object: StepView fetched from run.get_step().
        magic: Whether to render in a Jupyter notebook or not.
        *args: Unused.
        **kwargs: Unused.
    """
    # Only DataFrame outputs (including subclasses) can be summarized;
    # every other output is skipped with a warning.
    datasets = []
    for output_name, artifact_view in object.outputs.items():
        df = artifact_view.read()
        if isinstance(df, pd.DataFrame):
            datasets.append({"name": output_name, "table": df})
        else:
            logger.warning(
                "`%s` is not a pd.DataFrame. You can only visualize "
                "statistics of steps that output pandas dataframes. "
                "Skipping this output..",
                output_name,
            )
    h = self.generate_html(datasets)
    self.generate_facet(h, magic)

gcp special

The GCP integration submodule provides a way to run ZenML pipelines in a cloud environment. Specifically, it allows the use of cloud artifact stores, metadata stores, and an io module to handle file operations on Google Cloud Storage (GCS).

GcpIntegration (Integration)

Definition of Google Cloud Platform integration for ZenML.

Source code in zenml/integrations/gcp/__init__.py
class GcpIntegration(Integration):
    """ZenML integration for Google Cloud Platform (GCS-backed storage)."""

    NAME = GCP
    # gcsfs provides the filesystem implementation behind `gs://` paths.
    REQUIREMENTS = ["gcsfs"]

    @classmethod
    def activate(cls) -> None:
        """Import the GCP artifact-store and io submodules.

        Importing them is what activates the integration (presumably through
        registration side effects at import time).
        """
        from zenml.integrations.gcp import artifact_stores, io  # noqa
activate() classmethod

Activates the integration.

Source code in zenml/integrations/gcp/__init__.py
@classmethod
def activate(cls) -> None:
    """Import the GCP artifact-store and io submodules.

    Importing them is what activates the integration (presumably through
    registration side effects at import time).
    """
    from zenml.integrations.gcp import artifact_stores, io  # noqa

artifact_stores special

gcp_artifact_store
GCPArtifactStore (BaseArtifactStore) pydantic-model

Artifact Store for Google Cloud Storage based artifacts.

Source code in zenml/integrations/gcp/artifact_stores/gcp_artifact_store.py
class GCPArtifactStore(BaseArtifactStore):
    """Artifact store whose `path` points into Google Cloud Storage."""

    @validator("path")
    def must_be_gcs_path(cls, v: str) -> str:
        """Accept only `path` values that are `gs://` URLs.

        Raises:
            ValueError: If `v` does not start with `gs://`.
        """
        if v.startswith("gs://"):
            return v
        raise ValueError(
            "Must be a valid gcs path, i.e., starting with `gs://`"
        )
must_be_gcs_path(v) classmethod

Validates that the path is a valid gcs path.

Source code in zenml/integrations/gcp/artifact_stores/gcp_artifact_store.py
@validator("path")
def must_be_gcs_path(cls, v: str) -> str:
    """Accept only `path` values that are `gs://` URLs.

    Raises:
        ValueError: If `v` does not start with `gs://`.
    """
    if v.startswith("gs://"):
        return v
    raise ValueError(
        "Must be a valid gcs path, i.e., starting with `gs://`"
    )

io special

gcs_plugin

Plugin that adds Google Cloud Storage support to ZenML. It inherits from the base Filesystem created by TFX and overrides the corresponding methods using gcsfs.

ZenGCS (Filesystem)

Filesystem that delegates to Google Cloud Storage using gcsfs.

Note: To allow TFX to check for various error conditions, we need to raise their custom NotFoundError instead of the builtin python FileNotFoundError.

Source code in zenml/integrations/gcp/io/gcs_plugin.py
class ZenGCS(Filesystem):
    """Filesystem that delegates to Google Cloud Storage using gcsfs.

    **Note**: To allow TFX to check for various error conditions, we need to
    raise their custom `NotFoundError` instead of the builtin python
    FileNotFoundError."""

    SUPPORTED_SCHEMES = ["gs://"]
    # Shared gcsfs client, created lazily by `_ensure_filesystem_set`.
    fs: gcsfs.GCSFileSystem = None

    @classmethod
    def _ensure_filesystem_set(cls) -> None:
        """Ensures that the filesystem is set."""
        if ZenGCS.fs is None:
            ZenGCS.fs = gcsfs.GCSFileSystem()

    @staticmethod
    def open(path: PathType, mode: str = "r") -> Any:
        """Open a file at the given path.

        Args:
            path: Path of the file to open.
            mode: Mode in which to open the file. Currently only
                'rb' and 'wb' to read and write binary files are supported.

        Returns:
            The file-like object returned by gcsfs.

        Raises:
            NotFoundError: If the file does not exist.
        """
        ZenGCS._ensure_filesystem_set()

        try:
            return ZenGCS.fs.open(path=path, mode=mode)
        except FileNotFoundError as e:
            raise NotFoundError() from e

    @staticmethod
    def copy(src: PathType, dst: PathType, overwrite: bool = False) -> None:
        """Copy a file.

        Args:
            src: The path to copy from.
            dst: The path to copy to.
            overwrite: If a file already exists at the destination, this
                method will overwrite it if overwrite=`True` and
                raise a FileExistsError otherwise.

        Raises:
            NotFoundError: If the source file does not exist (TFX's error
                type, raised in place of the builtin `FileNotFoundError`).
            FileExistsError: If a file already exists at the destination
                and overwrite is not set to `True`.
        """
        ZenGCS._ensure_filesystem_set()
        if not overwrite and ZenGCS.fs.exists(dst):
            raise FileExistsError(
                f"Unable to copy to destination '{convert_to_str(dst)}', "
                f"file already exists. Set `overwrite=True` to copy anyway."
            )

        # TODO [ENG-151]: Check if it works with overwrite=True or if we need to
        #  manually remove it first
        try:
            ZenGCS.fs.copy(path1=src, path2=dst)
        except FileNotFoundError as e:
            raise NotFoundError() from e

    @staticmethod
    def exists(path: PathType) -> bool:
        """Check whether a path exists."""
        ZenGCS._ensure_filesystem_set()
        return ZenGCS.fs.exists(path=path)  # type: ignore[no-any-return]

    @staticmethod
    def glob(pattern: PathType) -> List[PathType]:
        """Return all paths that match the given glob pattern.

        The glob pattern may include:
        - '*' to match any number of characters
        - '?' to match a single character
        - '[...]' to match one of the characters inside the brackets
        - '**' as the full name of a path component to match to search
          in subdirectories of any depth (e.g. '/some_dir/**/some_file')

        Args:
            pattern: The glob pattern to match, see details above.

        Returns:
            A list of paths that match the given glob pattern.
        """
        ZenGCS._ensure_filesystem_set()
        return ZenGCS.fs.glob(path=pattern)  # type: ignore[no-any-return]

    @staticmethod
    def isdir(path: PathType) -> bool:
        """Check whether a path is a directory."""
        ZenGCS._ensure_filesystem_set()
        return ZenGCS.fs.isdir(path=path)  # type: ignore[no-any-return]

    @staticmethod
    def listdir(path: PathType) -> List[PathType]:
        """Return the names of the entries inside a directory.

        Like `os.listdir`, only the entry names relative to `path` are
        returned -- not full object paths and not the metadata dictionaries
        that fsspec's `listdir` produces.

        Args:
            path: The directory to list.

        Returns:
            The names of the entries directly inside `path`.

        Raises:
            NotFoundError: If the directory does not exist.
        """
        ZenGCS._ensure_filesystem_set()
        # gcsfs reports object names without the `gs://` scheme, so strip it
        # from the queried path before using it as a prefix.
        prefix = convert_to_str(path)
        if prefix.startswith("gs://"):
            prefix = prefix[len("gs://") :]

        try:
            # fsspec's `listdir` returns one detail dict (with a "name" key)
            # per entry; request that explicitly instead of relying on the
            # default.
            entries = ZenGCS.fs.listdir(path=path, detail=True)
        except FileNotFoundError as e:
            raise NotFoundError() from e

        names: List[PathType] = []
        for entry in entries:
            name = str(entry["name"])[len(prefix) :].strip("/")
            # The listing may include the directory itself; skip it.
            if name:
                names.append(name)
        return names

    @staticmethod
    def makedirs(path: PathType) -> None:
        """Create a directory at the given path. If needed also
        create missing parent directories."""
        ZenGCS._ensure_filesystem_set()
        ZenGCS.fs.makedirs(path=path, exist_ok=True)

    @staticmethod
    def mkdir(path: PathType) -> None:
        """Create a directory at the given path."""
        ZenGCS._ensure_filesystem_set()
        ZenGCS.fs.makedir(path=path)

    @staticmethod
    def remove(path: PathType) -> None:
        """Remove the file at the given path.

        Raises:
            NotFoundError: If the file does not exist.
        """
        ZenGCS._ensure_filesystem_set()
        try:
            ZenGCS.fs.rm_file(path=path)
        except FileNotFoundError as e:
            raise NotFoundError() from e

    @staticmethod
    def rename(src: PathType, dst: PathType, overwrite: bool = False) -> None:
        """Rename source file to destination file.

        Args:
            src: The path of the file to rename.
            dst: The path to rename the source file to.
            overwrite: If a file already exists at the destination, this
                method will overwrite it if overwrite=`True` and
                raise a FileExistsError otherwise.

        Raises:
            NotFoundError: If the source file does not exist (TFX's error
                type, raised in place of the builtin `FileNotFoundError`).
            FileExistsError: If a file already exists at the destination
                and overwrite is not set to `True`.
        """
        ZenGCS._ensure_filesystem_set()
        if not overwrite and ZenGCS.fs.exists(dst):
            raise FileExistsError(
                f"Unable to rename file to '{convert_to_str(dst)}', "
                f"file already exists. Set `overwrite=True` to rename anyway."
            )

        # TODO [ENG-152]: Check if it works with overwrite=True or if we need
        #  to manually remove it first
        try:
            ZenGCS.fs.rename(path1=src, path2=dst)
        except FileNotFoundError as e:
            raise NotFoundError() from e

    @staticmethod
    def rmtree(path: PathType) -> None:
        """Remove the given directory recursively.

        Raises:
            NotFoundError: If the directory does not exist.
        """
        ZenGCS._ensure_filesystem_set()
        try:
            ZenGCS.fs.delete(path=path, recursive=True)
        except FileNotFoundError as e:
            raise NotFoundError() from e

    @staticmethod
    def stat(path: PathType) -> Dict[str, Any]:
        """Return stat info for the given path.

        Raises:
            NotFoundError: If the path does not exist.
        """
        ZenGCS._ensure_filesystem_set()
        try:
            return ZenGCS.fs.stat(path=path)  # type: ignore[no-any-return]
        except FileNotFoundError as e:
            raise NotFoundError() from e

    @staticmethod
    def walk(
        top: PathType,
        topdown: bool = True,
        onerror: Optional[Callable[..., None]] = None,
    ) -> Iterable[Tuple[PathType, List[PathType], List[PathType]]]:
        """Return an iterator that walks the contents of the given directory.

        Args:
            top: Path of directory to walk.
            topdown: Unused argument to conform to interface.
            onerror: Unused argument to conform to interface.

        Returns:
            An Iterable of Tuples, each of which contain the path of the current
            directory path, a list of directories inside the current directory
            and a list of files inside the current directory.
        """
        ZenGCS._ensure_filesystem_set()
        # TODO [ENG-153]: Additional params
        return ZenGCS.fs.walk(path=top)  # type: ignore[no-any-return]
copy(src, dst, overwrite=False) staticmethod

Copy a file.

Parameters:

Name Type Description Default
src Union[bytes, str]

The path to copy from.

required
dst Union[bytes, str]

The path to copy to.

required
overwrite bool

If a file already exists at the destination, this method will overwrite it if overwrite=True and raise a FileExistsError otherwise.

False

Exceptions:

Type Description
NotFoundError

If the source file does not exist (TFX's NotFoundError is raised in place of the builtin FileNotFoundError).

FileExistsError

If a file already exists at the destination and overwrite is not set to True.

Source code in zenml/integrations/gcp/io/gcs_plugin.py
@staticmethod
def copy(src: PathType, dst: PathType, overwrite: bool = False) -> None:
    """Copy a file.

    Args:
        src: The path to copy from.
        dst: The path to copy to.
        overwrite: If a file already exists at the destination, this
            method will overwrite it if overwrite=`True` and
            raise a FileExistsError otherwise.

    Raises:
        NotFoundError: If the source file does not exist (TFX's error type,
            raised in place of the builtin `FileNotFoundError`).
        FileExistsError: If a file already exists at the destination
            and overwrite is not set to `True`.
    """
    ZenGCS._ensure_filesystem_set()
    if not overwrite and ZenGCS.fs.exists(dst):
        raise FileExistsError(
            f"Unable to copy to destination '{convert_to_str(dst)}', "
            f"file already exists. Set `overwrite=True` to copy anyway."
        )

    # TODO [ENG-151]: Check if it works with overwrite=True or if we need to
    #  manually remove it first
    try:
        ZenGCS.fs.copy(path1=src, path2=dst)
    except FileNotFoundError as e:
        raise NotFoundError() from e
exists(path) staticmethod

Check whether a path exists.

Source code in zenml/integrations/gcp/io/gcs_plugin.py
@staticmethod
def exists(path: PathType) -> bool:
    """Tell whether anything exists at `path` on Google Cloud Storage.

    Args:
        path: The path to check.

    Returns:
        `True` if the path exists, `False` otherwise.
    """
    ZenGCS._ensure_filesystem_set()
    gcs = ZenGCS.fs
    return gcs.exists(path=path)  # type: ignore[no-any-return]
glob(pattern) staticmethod

Return all paths that match the given glob pattern. The glob pattern may include: `*` to match any number of characters; `?` to match a single character; `[...]` to match one of the characters inside the brackets; `**` as the full name of a path component to search in subdirectories of any depth (e.g. `/some_dir/**/some_file`).

Parameters:

Name Type Description Default
pattern Union[bytes, str]

The glob pattern to match, see details above.

required

Returns:

Type Description
List[Union[bytes, str]]

A list of paths that match the given glob pattern.

Source code in zenml/integrations/gcp/io/gcs_plugin.py
@staticmethod
def glob(pattern: PathType) -> List[PathType]:
    """Find every path that matches a glob pattern.

    Supported wildcards:

    - `*` matches any number of characters
    - `?` matches a single character
    - `[...]` matches one of the characters inside the brackets
    - `**` as a complete path component matches subdirectories of any
      depth (e.g. `/some_dir/**/some_file`)

    Args:
        pattern: The glob pattern to match.

    Returns:
        A list of paths that match the given glob pattern.
    """
    ZenGCS._ensure_filesystem_set()
    gcs = ZenGCS.fs
    matches = gcs.glob(path=pattern)
    return matches  # type: ignore[no-any-return]
isdir(path) staticmethod

Check whether a path is a directory.

Source code in zenml/integrations/gcp/io/gcs_plugin.py
@staticmethod
def isdir(path: PathType) -> bool:
    """Tell whether `path` is a directory on Google Cloud Storage.

    Args:
        path: The path to inspect.

    Returns:
        `True` if gcsfs considers the path a directory.
    """
    ZenGCS._ensure_filesystem_set()
    gcs = ZenGCS.fs
    return gcs.isdir(path=path)  # type: ignore[no-any-return]
listdir(path) staticmethod

Return a list of files in a directory.

Source code in zenml/integrations/gcp/io/gcs_plugin.py
@staticmethod
def listdir(path: PathType) -> List[PathType]:
    """Return the names of the entries inside a directory.

    Like `os.listdir`, only the entry names relative to `path` are
    returned -- not full object paths and not the metadata dictionaries
    that fsspec's `listdir` produces.

    Args:
        path: The directory to list.

    Returns:
        The names of the entries directly inside `path`.

    Raises:
        NotFoundError: If the directory does not exist.
    """
    ZenGCS._ensure_filesystem_set()
    # gcsfs reports object names without the `gs://` scheme, so strip it
    # from the queried path before using it as a prefix.
    prefix = convert_to_str(path)
    if prefix.startswith("gs://"):
        prefix = prefix[len("gs://") :]

    try:
        # fsspec's `listdir` returns one detail dict (with a "name" key) per
        # entry; request that explicitly instead of relying on the default.
        entries = ZenGCS.fs.listdir(path=path, detail=True)
    except FileNotFoundError as e:
        raise NotFoundError() from e

    names: List[PathType] = []
    for entry in entries:
        name = str(entry["name"])[len(prefix) :].strip("/")
        # The listing may include the directory itself; skip it.
        if name:
            names.append(name)
    return names
makedirs(path) staticmethod

Create a directory at the given path. If needed also create missing parent directories.

Source code in zenml/integrations/gcp/io/gcs_plugin.py
@staticmethod
def makedirs(path: PathType) -> None:
    """Create `path`, including any missing parent directories.

    An already existing directory is not treated as an error
    (`exist_ok=True`).

    Args:
        path: The directory to create.
    """
    ZenGCS._ensure_filesystem_set()
    gcs = ZenGCS.fs
    gcs.makedirs(path=path, exist_ok=True)
mkdir(path) staticmethod

Create a directory at the given path.

Source code in zenml/integrations/gcp/io/gcs_plugin.py
@staticmethod
def mkdir(path: PathType) -> None:
    """Create a single directory at `path` via gcsfs's `makedir`.

    Args:
        path: The directory to create.
    """
    ZenGCS._ensure_filesystem_set()
    gcs = ZenGCS.fs
    gcs.makedir(path=path)
open(path, mode='r') staticmethod

Open a file at the given path.

Parameters:

Name Type Description Default
path Union[bytes, str]

Path of the file to open.

required
mode str

Mode in which to open the file. Currently only 'rb' and 'wb' to read and write binary files are supported.

'r'
Source code in zenml/integrations/gcp/io/gcs_plugin.py
@staticmethod
def open(path: PathType, mode: str = "r") -> Any:
    """Open the Google Cloud Storage file at `path`.

    Args:
        path: Path of the file to open.
        mode: Mode in which to open the file. Currently only
            'rb' and 'wb' to read and write binary files are supported.

    Returns:
        The file-like object returned by gcsfs.

    Raises:
        NotFoundError: If gcsfs reports that the file does not exist.
    """
    ZenGCS._ensure_filesystem_set()
    gcs = ZenGCS.fs
    try:
        handle = gcs.open(path=path, mode=mode)
    except FileNotFoundError as missing:
        raise NotFoundError() from missing
    return handle
remove(path) staticmethod

Remove the file at the given path.

Source code in zenml/integrations/gcp/io/gcs_plugin.py
@staticmethod
def remove(path: PathType) -> None:
    """Delete the single file stored at `path`.

    Args:
        path: The file to delete.

    Raises:
        NotFoundError: If the file does not exist.
    """
    ZenGCS._ensure_filesystem_set()
    gcs = ZenGCS.fs
    try:
        gcs.rm_file(path=path)
    except FileNotFoundError as missing:
        raise NotFoundError() from missing
rename(src, dst, overwrite=False) staticmethod

Rename source file to destination file.

Parameters:

Name Type Description Default
src Union[bytes, str]

The path of the file to rename.

required
dst Union[bytes, str]

The path to rename the source file to.

required
overwrite bool

If a file already exists at the destination, this method will overwrite it if overwrite=True and raise a FileExistsError otherwise.

False

Exceptions:

Type Description
NotFoundError

If the source file does not exist (TFX's NotFoundError is raised in place of the builtin FileNotFoundError).

FileExistsError

If a file already exists at the destination and overwrite is not set to True.

Source code in zenml/integrations/gcp/io/gcs_plugin.py
@staticmethod
def rename(src: PathType, dst: PathType, overwrite: bool = False) -> None:
    """Rename source file to destination file.

    Args:
        src: The path of the file to rename.
        dst: The path to rename the source file to.
        overwrite: If a file already exists at the destination, this
            method will overwrite it if overwrite=`True` and
            raise a FileExistsError otherwise.

    Raises:
        NotFoundError: If the source file does not exist (TFX's error type,
            raised in place of the builtin `FileNotFoundError`).
        FileExistsError: If a file already exists at the destination
            and overwrite is not set to `True`.
    """
    ZenGCS._ensure_filesystem_set()
    if not overwrite and ZenGCS.fs.exists(dst):
        raise FileExistsError(
            f"Unable to rename file to '{convert_to_str(dst)}', "
            f"file already exists. Set `overwrite=True` to rename anyway."
        )

    # TODO [ENG-152]: Check if it works with overwrite=True or if we need
    #  to manually remove it first
    try:
        ZenGCS.fs.rename(path1=src, path2=dst)
    except FileNotFoundError as e:
        raise NotFoundError() from e
rmtree(path) staticmethod

Remove the given directory.

Source code in zenml/integrations/gcp/io/gcs_plugin.py
@staticmethod
def rmtree(path: PathType) -> None:
    """Recursively delete the directory at `path` and everything in it.

    Args:
        path: The directory to delete.

    Raises:
        NotFoundError: If the directory does not exist.
    """
    ZenGCS._ensure_filesystem_set()
    gcs = ZenGCS.fs
    try:
        gcs.delete(path=path, recursive=True)
    except FileNotFoundError as missing:
        raise NotFoundError() from missing
stat(path) staticmethod

Return stat info for the given path.

Source code in zenml/integrations/gcp/io/gcs_plugin.py
@staticmethod
def stat(path: PathType) -> Dict[str, Any]:
    """Fetch the metadata gcsfs reports for `path`.

    Args:
        path: The path to inspect.

    Returns:
        The info dictionary returned by gcsfs.

    Raises:
        NotFoundError: If the path does not exist.
    """
    ZenGCS._ensure_filesystem_set()
    gcs = ZenGCS.fs
    try:
        info = gcs.stat(path=path)
    except FileNotFoundError as missing:
        raise NotFoundError() from missing
    return info  # type: ignore[no-any-return]
walk(top, topdown=True, onerror=None) staticmethod

Return an iterator that walks the contents of the given directory.

Parameters:

Name Type Description Default
top Union[bytes, str]

Path of directory to walk.

required
topdown bool

Unused argument to conform to interface.

True
onerror Optional[Callable[..., NoneType]]

Unused argument to conform to interface.

None

Returns:

Type Description
Iterable[Tuple[Union[bytes, str], List[Union[bytes, str]], List[Union[bytes, str]]]]

An Iterable of Tuples, each of which contain the path of the current directory path, a list of directories inside the current directory and a list of files inside the current directory.

Source code in zenml/integrations/gcp/io/gcs_plugin.py
@staticmethod
def walk(
    top: PathType,
    topdown: bool = True,
    onerror: Optional[Callable[..., None]] = None,
) -> Iterable[Tuple[PathType, List[PathType], List[PathType]]]:
    """Walk the directory tree rooted at `top`.

    Args:
        top: Path of directory to walk.
        topdown: Accepted for interface compatibility; currently ignored.
        onerror: Accepted for interface compatibility; currently ignored.

    Returns:
        An iterable of `(directory, subdirectories, files)` tuples, one per
        visited directory, as produced by gcsfs.
    """
    ZenGCS._ensure_filesystem_set()
    gcs = ZenGCS.fs
    # TODO [ENG-153]: Additional params
    walker = gcs.walk(path=top)
    return walker  # type: ignore[no-any-return]

graphviz special

GraphvizIntegration (Integration)

Definition of Graphviz integration for ZenML.

Source code in zenml/integrations/graphviz/__init__.py
class GraphvizIntegration(Integration):
    """ZenML integration for the Graphviz graph visualization library.

    It only declares its name and pip requirements; activation falls back
    to the inherited `Integration.activate`.
    """

    # Identifier under which this integration is registered.
    NAME = GRAPHVIZ
    # Pip requirement string(s) checked by `check_installation`.
    REQUIREMENTS = ["graphviz>=0.17"]

visualizers special

pipeline_run_dag_visualizer
PipelineRunDagVisualizer (BasePipelineRunVisualizer)

Visualize the lineage of runs in a pipeline.

Source code in zenml/integrations/graphviz/visualizers/pipeline_run_dag_visualizer.py
class PipelineRunDagVisualizer(BasePipelineRunVisualizer):
    """Visualize the lineage of runs in a pipeline."""

    ARTIFACT_DEFAULT_COLOR = "blue"
    ARTIFACT_CACHED_COLOR = "green"
    ARTIFACT_SHAPE = "box"
    ARTIFACT_PREFIX = "artifact_"
    STEP_COLOR = "#431D93"
    STEP_SHAPE = "ellipse"
    STEP_PREFIX = "step_"
    FONT = "Roboto"

    # NOTE: this is the concrete implementation, so it must not be marked
    # `@abstractmethod` -- doing so would make this class abstract.
    def visualize(
        self, object: PipelineRunView, *args: Any, **kwargs: Any
    ) -> graphviz.Digraph:
        """Creates a pipeline lineage diagram using graphviz.

        Every step becomes a node, every output artifact becomes a node with
        an edge from the producing step, and every input artifact gets an
        edge to the consuming step. The diagram is rendered to a temporary
        PNG file and opened for viewing (`view=True`).

        Args:
            object: The pipeline run to visualize.
            *args: Unused.
            **kwargs: Unused.

        Returns:
            The generated `graphviz.Digraph`.
        """
        logger.warning(
            "This integration is not completed yet. Results might be unexpected."
        )

        dot = graphviz.Digraph(comment=object.name)

        for step in object.steps:
            step_node = self.STEP_PREFIX + str(step.id)
            # add each step as a node
            dot.node(step_node, step.name, shape=self.STEP_SHAPE)

            # add each output artifact as a node, linked from its step
            for artifact_name, artifact in step.outputs.items():
                artifact_node = self.ARTIFACT_PREFIX + str(artifact.id)
                dot.node(
                    artifact_node,
                    f"{artifact_name} \n" f"({artifact._data_type})",
                    shape=self.ARTIFACT_SHAPE,
                )
                dot.edge(step_node, artifact_node)

            # link each input artifact to the step that consumes it
            for artifact in step.inputs.values():
                dot.edge(self.ARTIFACT_PREFIX + str(artifact.id), step_node)

        # Only reserve a unique file name here. graphviz writes the DOT
        # source to it (deleted again because of `cleanup=True`) and the
        # image to `<name>.png`. Rendering after the `with` block closes the
        # handle first, which platforms that lock open files require.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".gv") as f:
            filename = f.name
        dot.render(filename=filename, format="png", view=True, cleanup=True)
        return dot
visualize(self, object, *args, **kwargs)

Creates a pipeline lineage diagram using graphviz.

Source code in zenml/integrations/graphviz/visualizers/pipeline_run_dag_visualizer.py
def visualize(
    self, object: PipelineRunView, *args: Any, **kwargs: Any
) -> graphviz.Digraph:
    """Creates a pipeline lineage diagram using graphviz.

    This is a concrete implementation and is intentionally not marked
    `@abstractmethod`. Every step becomes a node, every output artifact
    becomes a node with an edge from the producing step, and every input
    artifact gets an edge to the consuming step. The diagram is rendered to
    a temporary PNG file and opened for viewing (`view=True`).

    Args:
        object: The pipeline run to visualize.
        *args: Unused.
        **kwargs: Unused.

    Returns:
        The generated `graphviz.Digraph`.
    """
    logger.warning(
        "This integration is not completed yet. Results might be unexpected."
    )

    dot = graphviz.Digraph(comment=object.name)

    for step in object.steps:
        step_node = self.STEP_PREFIX + str(step.id)
        # add each step as a node
        dot.node(step_node, step.name, shape=self.STEP_SHAPE)

        # add each output artifact as a node, linked from its step
        for artifact_name, artifact in step.outputs.items():
            artifact_node = self.ARTIFACT_PREFIX + str(artifact.id)
            dot.node(
                artifact_node,
                f"{artifact_name} \n" f"({artifact._data_type})",
                shape=self.ARTIFACT_SHAPE,
            )
            dot.edge(step_node, artifact_node)

        # link each input artifact to the step that consumes it
        for artifact in step.inputs.values():
            dot.edge(self.ARTIFACT_PREFIX + str(artifact.id), step_node)

    # Only reserve a unique file name here. graphviz writes the DOT source to
    # it (deleted again because of `cleanup=True`) and the image to
    # `<name>.png`. Rendering after the `with` block closes the handle first,
    # which platforms that lock open files require.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".gv") as f:
        filename = f.name
    dot.render(filename=filename, format="png", view=True, cleanup=True)
    return dot

integration

Integration

Base class for integrations in ZenML.

Source code in zenml/integrations/integration.py
class Integration(metaclass=IntegrationMeta):
    """Base class for integrations in ZenML.

    Subclasses set `NAME` (the key under which `IntegrationMeta` registers
    them) and `REQUIREMENTS` (pip requirement strings), and may override
    `activate`.
    """

    NAME = "base_integration"

    REQUIREMENTS: List[str] = []

    @classmethod
    def check_installation(cls) -> bool:
        """Check whether all packages listed in `REQUIREMENTS` are installed.

        Each requirement string is resolved with
        `pkg_resources.get_distribution`, which also checks any version
        specifier in the string.

        Returns:
            `True` if every requirement is satisfied, `False` if a package is
            missing or installed in a conflicting version.
        """
        try:
            for r in cls.REQUIREMENTS:
                pkg_resources.get_distribution(r)
            logger.debug(
                f"Integration {cls.NAME} is installed correctly with "
                f"requirements {cls.REQUIREMENTS}."
            )
            return True
        except pkg_resources.DistributionNotFound as e:
            logger.debug(
                f"Unable to find required package '{e.req}' for "
                f"integration {cls.NAME}."
            )
            return False
        except pkg_resources.VersionConflict as e:
            logger.debug(
                f"VersionConflict error when loading integration {cls.NAME}: "
                f"{str(e)}"
            )
            return False

    @staticmethod
    def activate() -> None:
        """Activate the integration; a no-op in the base class.

        Subclasses override this hook, e.g. to import the modules that the
        integration needs.
        """
activate() staticmethod

Abstract method to activate the integration

Source code in zenml/integrations/integration.py
@staticmethod
def activate() -> None:
    """Hook for activating an integration; does nothing in the base class.

    Concrete integrations override it, e.g. to import the modules they need.
    """
    return None
check_installation() classmethod

Method to check whether the required packages are installed

Source code in zenml/integrations/integration.py
@classmethod
def check_installation(cls) -> bool:
    """Check whether all packages listed in `REQUIREMENTS` are installed.

    Each requirement string is resolved with
    `pkg_resources.get_distribution`, which also checks any version
    specifier in the string.

    Returns:
        `True` if every requirement is satisfied, `False` if a package is
        missing or installed in a conflicting version.
    """
    try:
        for r in cls.REQUIREMENTS:
            pkg_resources.get_distribution(r)
        logger.debug(
            f"Integration {cls.NAME} is installed correctly with "
            f"requirements {cls.REQUIREMENTS}."
        )
        return True
    except pkg_resources.DistributionNotFound as e:
        logger.debug(
            f"Unable to find required package '{e.req}' for "
            f"integration {cls.NAME}."
        )
        return False
    except pkg_resources.VersionConflict as e:
        logger.debug(
            f"VersionConflict error when loading integration {cls.NAME}: "
            f"{str(e)}"
        )
        return False

IntegrationMeta (type)

Metaclass responsible for registering different Integration subclasses

Source code in zenml/integrations/integration.py
class IntegrationMeta(type):
    """Metaclass that registers each `Integration` subclass as it is created.

    The base class named `Integration` is skipped; every other class built
    with this metaclass is passed to `integration_registry` under its `NAME`.
    """

    def __new__(
        mcs, name: str, bases: Tuple[Type[Any], ...], dct: Dict[str, Any]
    ) -> "IntegrationMeta":
        """Create the class, then register it unless it is the base class."""
        created = super().__new__(mcs, name, bases, dct)
        integration_cls = cast(Type["Integration"], created)
        if name == "Integration":
            return integration_cls
        integration_registry.register_integration(
            integration_cls.NAME, integration_cls
        )
        return integration_cls
__new__(mcs, name, bases, dct) special staticmethod

Hook into creation of an Integration class.

Source code in zenml/integrations/integration.py
def __new__(
    mcs, name: str, bases: Tuple[Type[Any], ...], dct: Dict[str, Any]
) -> "IntegrationMeta":
    """Create the class, then register it unless it is the base class."""
    created = super().__new__(mcs, name, bases, dct)
    integration_cls = cast(Type["Integration"], created)
    if name == "Integration":
        return integration_cls
    integration_registry.register_integration(
        integration_cls.NAME, integration_cls
    )
    return integration_cls

kubeflow special

The Kubeflow integration sub-module powers an alternative to the local orchestrator. You can enable it by registering the Kubeflow orchestrator with the CLI tool.

KubeflowIntegration (Integration)

Definition of Kubeflow Integration for ZenML.

Source code in zenml/integrations/kubeflow/__init__.py
class KubeflowIntegration(Integration):
    """Definition of Kubeflow Integration for ZenML."""

    NAME = KUBEFLOW
    REQUIREMENTS = ["kfp==1.8.9"]

    @classmethod
    def activate(cls) -> None:
        """Activates all classes required for the kubeflow integration.

        Imports the Kubeflow `metadata` and `orchestrators` sub-modules. The
        imported names are unused here, so the imports are done for their
        side effects.
        """
        from zenml.integrations.kubeflow import metadata  # noqa
        from zenml.integrations.kubeflow import orchestrators  # noqa
activate() classmethod

Activates all classes required for the kubeflow integration.

Source code in zenml/integrations/kubeflow/__init__.py
@classmethod
def activate(cls) -> None:
    """Activates all classes required for the kubeflow integration.

    Imports the Kubeflow `metadata` and `orchestrators` sub-modules. The
    imported names are unused here, so the imports are done for their side
    effects.
    """
    from zenml.integrations.kubeflow import metadata  # noqa
    from zenml.integrations.kubeflow import orchestrators  # noqa

container_entrypoint

Main entrypoint for containers with Kubeflow TFX component executors.

main()

Runs a single step defined by the command line arguments.

Source code in zenml/integrations/kubeflow/container_entrypoint.py
def main() -> None:
    """Runs a single step defined by the command line arguments.

    Container entrypoint for executing one pipeline node on Kubeflow. It
    parses the command line, deserializes the JSON pipeline IR passed in
    `args.tfx_ir`, resolves runtime parameters, looks up the node with id
    `args.node_id`, sets up the metadata connection and the executor class,
    and runs the node through a TFX `launcher.Launcher`. When execution
    returns info, UI metadata is written to `args.metadata_ui_path`.

    Raises:
        RuntimeError: If the node's executor spec has no `class_path`.
    """
    # Log to the container's stdout so Kubeflow Pipelines UI can display logs to
    # the user.
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)
    logging.getLogger().setLevel(logging.INFO)

    args = _parse_command_line_arguments()

    # Rebuild the pipeline protobuf from its JSON representation.
    tfx_pipeline = pipeline_pb2.Pipeline()
    json_format.Parse(args.tfx_ir, tfx_pipeline)
    # NOTE(review): presumably fills runtime parameter values (including the
    # run name) into the IR in place -- confirm in the helper.
    _resolve_runtime_parameters(
        tfx_pipeline, args.run_name, args.runtime_parameter
    )

    node_id = args.node_id
    pipeline_node = _get_pipeline_node(tfx_pipeline, node_id)

    # Extract the deployment config and this node's executor / custom driver
    # specs from the pipeline.
    deployment_config = runner_utils.extract_local_deployment_config(
        tfx_pipeline
    )
    executor_spec = runner_utils.extract_executor_spec(
        deployment_config, node_id
    )
    custom_driver_spec = runner_utils.extract_custom_driver_spec(
        deployment_config, node_id
    )
    # Container executable specs are run via TFX's Kubernetes executor
    # operator.
    custom_executor_operators = {
        executable_spec_pb2.ContainerExecutableSpec: kubernetes_executor_operator.KubernetesExecutorOperator
    }

    # make sure all integrations are activated so all materializers etc. are
    # available
    integration_registry.activate_integrations()

    metadata_store = Repository().get_active_stack().metadata_store
    if isinstance(metadata_store, KubeflowMetadataStore):
        # set up the metadata connection so it connects to the internal kubeflow
        # mysql database
        connection_config = _get_grpc_metadata_connection_config()
    else:
        connection_config = deployment_config.metadata_connection_config  # type: ignore[attr-defined] # noqa

    metadata_connection = metadata.Metadata(connection_config)

    # import the user main module to register all the materializers
    importlib.import_module(args.main_module)

    if hasattr(executor_spec, "class_path"):
        # The target module is everything before the last dot of
        # `class_path`. NOTE(review): `_create_executor_class` presumably
        # generates the step's executor class inside that module so the
        # launcher can import it -- confirm in the helper.
        executor_module_parts = getattr(executor_spec, "class_path").split(".")
        executor_class_target_module_name = ".".join(executor_module_parts[:-1])
        _create_executor_class(
            step_source_module_name=args.step_module,
            step_function_name=args.step_function_name,
            executor_class_target_module_name=executor_class_target_module_name,
            input_artifact_type_mapping=json.loads(args.input_artifact_types),
        )
    else:
        raise RuntimeError(
            f"No class path found inside executor spec: {executor_spec}."
        )

    component_launcher = launcher.Launcher(
        pipeline_node=pipeline_node,
        mlmd_connection=metadata_connection,
        pipeline_info=tfx_pipeline.pipeline_info,
        pipeline_runtime_spec=tfx_pipeline.runtime_spec,
        executor_spec=executor_spec,
        custom_driver_spec=custom_driver_spec,
        custom_executor_operators=custom_executor_operators,
    )
    execution_info = execute_step(component_launcher)

    # NOTE(review): `execute_step` appears able to return a falsy value
    # (e.g. when nothing was executed) -- UI metadata is only dumped otherwise.
    if execution_info:
        _dump_ui_metadata(pipeline_node, execution_info, args.metadata_ui_path)

docker_utils

build_docker_image(build_context_path, image_name, dockerfile_path=None, dockerignore_path=None, requirements=None, use_local_requirements=False, base_image=None)

Builds a docker image.

Parameters:

Name Type Description Default
build_context_path str

Path to a directory that will be sent to the docker daemon as build context.

required
image_name str

The name to use for the created docker image.

required
dockerfile_path Optional[str]

Optional path to a dockerfile. If no value is given, a temporary dockerfile will be created.

None
dockerignore_path Optional[str]

Optional path to a dockerignore file. If no value is given, the .dockerignore in the root of the build context will be used if it exists. Otherwise, all files inside build_context_path are included in the build context.

None
requirements Optional[List[str]]

Optional list of pip requirements to install. This will only be used if no value is given for dockerfile_path.

None
use_local_requirements bool

If True and no values are given for dockerfile_path and requirements, then the packages installed in the environment of the current Python process will be installed in the docker image.

False
base_image Optional[str]

The image to use as base for the docker image.

None
Source code in zenml/integrations/kubeflow/docker_utils.py
def build_docker_image(
    build_context_path: str,
    image_name: str,
    dockerfile_path: Optional[str] = None,
    dockerignore_path: Optional[str] = None,
    requirements: Optional[List[str]] = None,
    use_local_requirements: bool = False,
    base_image: Optional[str] = None,
) -> None:
    """Builds a docker image.

    Args:
        build_context_path: Path to a directory that will be sent to the
            docker daemon as build context.
        image_name: The name to use for the created docker image.
        dockerfile_path: Optional path to a dockerfile. If no value is given,
            a temporary dockerfile will be created.
        dockerignore_path: Optional path to a dockerignore file. If no value is
            given, the .dockerignore in the root of the build context will be
            used if it exists. Otherwise, all files inside `build_context_path`
            are included in the build context.
        requirements: Optional list of pip requirements to install. This
            will only be used if no value is given for `dockerfile_path`.
        use_local_requirements: If `True` and no values are given for
            `dockerfile_path` and `requirements`, then the packages installed
            in the environment of the current python process will be
            installed in the docker image.
        base_image: The image to use as base for the generated dockerfile
            (falls back to `DEFAULT_BASE_IMAGE`). If given, the docker build
            is also asked to pull the latest version of the base image.
    """
    if dockerfile_path:
        # A user-supplied dockerfile is used verbatim, so requirements
        # (explicit or collected from the local environment) are not used.
        dockerfile_contents = zenml.io.utils.read_file_contents_as_string(
            dockerfile_path
        )
    else:
        if not requirements and use_local_requirements:
            local_requirements = get_current_environment_requirements()
            requirements = [
                f"{package}=={version}"
                for package, version in local_requirements.items()
                if package != "zenml"  # exclude ZenML
            ]
            logger.info(
                "Using requirements from local environment to build "
                "docker image: %s",
                requirements,
            )

        dockerfile_contents = generate_dockerfile_contents(
            requirements=requirements,
            base_image=base_image or DEFAULT_BASE_IMAGE,
        )

    build_context = create_custom_build_context(
        build_context_path=build_context_path,
        dockerfile_contents=dockerfile_contents,
        dockerignore_path=dockerignore_path,
    )
    # If a custom base image is provided, make sure to always pull the
    # latest version of that image. If no base image is provided, we use
    # the static default ZenML image so there is no need to constantly pull
    # it.
    always_pull_base_image = bool(base_image)

    logger.info(
        "Building docker image '%s', this might take a while...", image_name
    )
    docker_client = DockerClient.from_env()
    # We use the client api directly here so we can stream the logs
    output_stream = docker_client.images.client.api.build(
        fileobj=build_context,
        custom_context=True,
        tag=image_name,
        pull=always_pull_base_image,
        rm=False,  # don't remove intermediate containers
    )
    _process_stream(output_stream)
    logger.info("Finished building docker image.")
create_custom_build_context(build_context_path, dockerfile_contents, dockerignore_path=None)

Creates a docker build context.

Parameters:

Name Type Description Default
build_context_path str

Path to a directory that will be sent to the docker daemon as build context.

required
dockerfile_contents str

File contents of the Dockerfile to use for the build.

required
dockerignore_path Optional[str]

Optional path to a dockerignore file. If no value is given, the .dockerignore in the root of the build context will be used if it exists. Otherwise, all files inside build_context_path are included in the build context.

None

Returns:

Type Description
Any

Docker build context that can be passed when building a docker image.

Source code in zenml/integrations/kubeflow/docker_utils.py
def create_custom_build_context(
    build_context_path: str,
    dockerfile_contents: str,
    dockerignore_path: Optional[str] = None,
) -> Any:
    """Creates a docker build context.

    Args:
        build_context_path: Path to a directory that will be sent to the
            docker daemon as build context.
        dockerfile_contents: File contents of the Dockerfile to use for the
            build.
        dockerignore_path: Optional path to a dockerignore file. If no value is
            given, the .dockerignore in the root of the build context will be
            used if it exists. Otherwise, all files inside `build_context_path`
            are included in the build context.

    Returns:
        Docker build context that can be passed when building a docker image.
    """
    # Above this size (with no exclude patterns) we nudge the user to add a
    # .dockerignore file.
    large_context_threshold_bytes = 50 * 1024 * 1024

    exclude_patterns = []
    default_dockerignore_path = os.path.join(
        build_context_path, ".dockerignore"
    )
    if dockerignore_path:
        exclude_patterns = _parse_dockerignore(dockerignore_path)
    elif fileio.file_exists(default_dockerignore_path):
        logger.info(
            "Using dockerignore found at path '%s' to create docker "
            "build context.",
            default_dockerignore_path,
        )
        exclude_patterns = _parse_dockerignore(default_dockerignore_path)
    else:
        logger.info(
            "No explicit dockerignore specified and no file called "
            ".dockerignore exists at the build context root (%s). "
            "Creating docker build context with all files inside the build "
            "context root directory.",
            build_context_path,
        )

    logger.debug(
        "Exclude patterns for creating docker build context: %s",
        exclude_patterns,
    )
    no_ignores_found = not exclude_patterns

    files = docker_build_utils.exclude_paths(
        build_context_path, patterns=exclude_patterns
    )
    # The generated Dockerfile is injected into the archive directly instead
    # of being written into the build context directory.
    extra_files = [("Dockerfile", dockerfile_contents)]
    context = docker_build_utils.create_archive(
        root=build_context_path,
        files=sorted(files),
        gzip=False,
        extra_files=extra_files,
    )

    build_context_size = os.path.getsize(context.name)
    if build_context_size > large_context_threshold_bytes and no_ignores_found:
        # The build context exceeds 50MiB and we didn't find any excludes
        # in dockerignore files -> remind to specify a .dockerignore file
        logger.warning(
            "Build context size for docker image: %s. If you believe this is "
            "unreasonably large, make sure to include a .dockerignore file at "
            "the root of your build context (%s) or specify a custom file "
            "when defining your pipeline.",
            string_utils.get_human_readable_filesize(build_context_size),
            default_dockerignore_path,
        )

    return context
generate_dockerfile_contents(base_image, command=None, requirements=None)

Generates a Dockerfile.

Parameters:

Name Type Description Default
base_image str

The image to use as base for the dockerfile.

required
command Optional[str]

The default command that gets executed when running a container of an image created by this dockerfile.

None
requirements Optional[List[str]]

Optional list of pip requirements to install.

None

Returns:

Type Description
str

Content of a dockerfile.

Source code in zenml/integrations/kubeflow/docker_utils.py
def generate_dockerfile_contents(
    base_image: str,
    command: Optional[str] = None,
    requirements: Optional[List[str]] = None,
) -> str:
    """Generates a Dockerfile.

    Args:
        base_image: The image to use as base for the dockerfile.
        command: The default command that gets executed when running a
            container of an image created by this dockerfile.
        requirements: Optional list of pip requirements to install. Each
            requirement is shell-quoted, so version specifiers such as
            `pandas>=1.2` reach pip intact instead of being interpreted by
            the shell of the `RUN` instruction (e.g. `>` as a redirection).

    Returns:
        Content of a dockerfile.
    """
    import shlex

    lines = [f"FROM {base_image}", "WORKDIR /app"]

    if requirements:
        # Plain `name==version` strings are left untouched by shlex.quote.
        quoted_requirements = " ".join(
            shlex.quote(requirement) for requirement in requirements
        )
        lines.append(f"RUN pip install --no-cache {quoted_requirements}")

    lines.append("COPY . .")

    if command:
        lines.append(f"CMD {command}")

    return "\n".join(lines)
get_current_environment_requirements()

Returns a dict of package requirements for the environment that the current python process is running in.

Source code in zenml/integrations/kubeflow/docker_utils.py
def get_current_environment_requirements() -> Dict[str, str]:
    """Collects the packages installed in the running Python environment.

    Returns:
        A mapping from each distribution key in `pkg_resources.working_set`
        to its installed version string.
    """
    installed: Dict[str, str] = {}
    for dist in pkg_resources.working_set:
        installed[dist.key] = dist.version
    return installed
get_image_digest(image_name)

Gets the digest of a docker image.

Parameters:

Name Type Description Default
image_name str

Name of the image to get the digest for.

required

Returns:

Type Description
Optional[str]

Returns the repo digest for the given image if there exists exactly one. If there are zero or multiple repo digests, returns None.

Source code in zenml/integrations/kubeflow/docker_utils.py
def get_image_digest(image_name: str) -> Optional[str]:
    """Gets the digest of a docker image.

    Args:
        image_name: Name of the image to get the digest for.

    Returns:
        The repo digest for the given image if there exists exactly one.
        If there are zero or multiple repo digests, returns `None`.
    """
    docker_client = DockerClient.from_env()
    image = docker_client.images.get(image_name)
    # Treat a missing/null "RepoDigests" entry like an empty list so the
    # documented `None` result is returned instead of raising KeyError
    # (presumably the case for images that were never pushed or pulled).
    repo_digests = image.attrs.get("RepoDigests") or []
    if len(repo_digests) == 1:
        return cast(str, repo_digests[0])

    logger.debug(
        "Found zero or multiple repo digests for docker image '%s': %s",
        image_name,
        repo_digests,
    )
    return None
push_docker_image(image_name)

Pushes a docker image to a container registry.

Parameters:

Name Type Description Default
image_name str

The full name (including a tag) of the image to push.

required
Source code in zenml/integrations/kubeflow/docker_utils.py
def push_docker_image(image_name: str) -> None:
    """Pushes a docker image to its container registry.

    The push output is streamed and handed to `_process_stream`.

    Args:
        image_name: Full image name, tag included, of the image to push.
    """
    logger.info("Pushing docker image '%s'.", image_name)
    push_output = DockerClient.from_env().images.push(image_name, stream=True)
    _process_stream(push_output)
    logger.info("Finished pushing docker image.")

metadata special

kubeflow_metadata_store
KubeflowMetadataStore (MySQLMetadataStore) pydantic-model

Kubeflow MySQL backend for ZenML metadata store.

Source code in zenml/integrations/kubeflow/metadata/kubeflow_metadata_store.py
class KubeflowMetadataStore(MySQLMetadataStore):
    """Kubeflow MySQL backend for ZenML metadata store.

    The class body only declares connection fields with default values;
    everything else comes from `MySQLMetadataStore`.
    """

    # Defaults target a MySQL server reachable on localhost:3306 with the
    # `metadb` database. NOTE(review): presumably this is the Kubeflow
    # metadata DB exposed via a local port-forward -- confirm.
    host: str = "127.0.0.1"
    port: int = 3306
    database: str = "metadb"
    # Default credentials: user `root` with an empty password.
    username: str = "root"
    password: str = ""

orchestrators special

kubeflow_component

Kubeflow Pipelines based implementation of TFX components. These components are lightweight wrappers around the KFP DSL's ContainerOp, and ensure that the container gets called with the right set of input arguments. It also ensures that each component exports named output attributes that are consistent with those provided by the native TFX components, thus ensuring that both types of pipeline definitions are compatible. Note: This requires Kubeflow Pipelines SDK to be installed.

KubeflowComponent

Base component for all Kubeflow pipelines TFX components. Returns a wrapper around a KFP DSL ContainerOp class, and adds named output attributes that match the output names for the corresponding native TFX components.

Source code in zenml/integrations/kubeflow/orchestrators/kubeflow_component.py
class KubeflowComponent:
    """Base component for all Kubeflow pipelines TFX components.

    Wraps a single TFX component in a KFP DSL `ContainerOp` (exposed as
    `self.container_op`) that runs the component inside a container image.
    """

    def __init__(
        self,
        component: tfx_base_component.BaseComponent,
        depends_on: Set[dsl.ContainerOp],
        image: str,
        tfx_ir: pipeline_pb2.Pipeline,  # type: ignore[valid-type]
        pod_labels_to_attach: Dict[str, str],
        main_module: str,
        step_module: str,
        step_function_name: str,
        runtime_parameters: List[data_types.RuntimeParameter],
        metadata_ui_path: str = "/mlpipeline-ui-metadata.json",
    ):
        """Creates a new Kubeflow-based component.

        This class essentially wraps a `dsl.ContainerOp` construct in Kubeflow
        Pipelines.

        Args:
            component: The logical TFX component to wrap.
            depends_on: The set of upstream KFP ContainerOp components that
                this component will depend on.
            image: The container image to use for this component.
            tfx_ir: The TFX intermediate representation of the pipeline.
            pod_labels_to_attach: Dict of pod labels to attach to the GKE pod.
            main_module: Module of the pipeline's `__main__` script, forwarded
                to the container entrypoint as `--main_module`.
            step_module: Module that defines the step, forwarded as
                `--step_module`.
            step_function_name: Name of the step function, forwarded as
                `--step_function_name`.
            runtime_parameters: Runtime parameters of the pipeline.
            metadata_ui_path: File location for metadata-ui-metadata.json file.
        """
        utils.replace_placeholder(component)
        input_artifact_type_mapping = _get_input_artifact_type_mapping(
            component
        )

        arguments = [
            "--node_id",
            component.id,
            "--tfx_ir",
            json_format.MessageToJson(tfx_ir),
            "--metadata_ui_path",
            metadata_ui_path,
            "--main_module",
            main_module,
            "--step_module",
            step_module,
            "--step_function_name",
            step_function_name,
            "--input_artifact_types",
            json.dumps(input_artifact_type_mapping),
            "--run_name",
            # Argo template, resolved by Kubeflow at run time.
            "{{workflow.annotations.pipelines.kubeflow.org/run_name}}",
        ]

        for param in runtime_parameters:
            arguments.append("--runtime_parameter")
            arguments.append(_encode_runtime_parameter(param))

        # Both stores belong to the same active stack; look it up only once.
        active_stack = Repository().get_active_stack()
        artifact_store = active_stack.artifact_store
        metadata_store = active_stack.metadata_store

        # Local stores live on the host filesystem, so their directories are
        # mounted into the container as hostPath volumes at the same path.
        volumes: Dict[str, k8s_client.V1Volume] = {}

        if isinstance(artifact_store, LocalArtifactStore):
            host_path = k8s_client.V1HostPathVolumeSource(
                path=artifact_store.path, type="Directory"
            )
            volumes[artifact_store.path] = k8s_client.V1Volume(
                name="local-artifact-store", host_path=host_path
            )
            logger.debug(
                "Adding host path volume for local artifact store (path: %s) "
                "in kubeflow pipelines container.",
                artifact_store.path,
            )

        if isinstance(metadata_store, SQLiteMetadataStore):
            metadata_store_dir = os.path.dirname(metadata_store.uri)
            host_path = k8s_client.V1HostPathVolumeSource(
                path=metadata_store_dir, type="Directory"
            )
            volumes[metadata_store_dir] = k8s_client.V1Volume(
                name="local-metadata-store", host_path=host_path
            )
            logger.debug(
                "Adding host path volume for local metadata store (uri: %s) "
                "in kubeflow pipelines container.",
                metadata_store.uri,
            )

        self.container_op = dsl.ContainerOp(
            name=component.id,
            command=CONTAINER_ENTRYPOINT_COMMAND,
            image=image,
            arguments=arguments,
            output_artifact_paths={
                "mlpipeline-ui-metadata": metadata_ui_path,
            },
            pvolumes=volumes,
        )

        for op in depends_on:
            self.container_op.after(op)

        # NOTE(review): presumably stops the pipeline from executing again when
        # the user's modules are imported inside the container -- confirm.
        self.container_op.container.add_env_variable(
            k8s_client.V1EnvVar(
                name=ENV_ZENML_PREVENT_PIPELINE_EXECUTION, value="True"
            )
        )

        for k, v in pod_labels_to_attach.items():
            self.container_op.add_pod_label(k, v)
__init__(self, component, depends_on, image, tfx_ir, pod_labels_to_attach, main_module, step_module, step_function_name, runtime_parameters, metadata_ui_path='/mlpipeline-ui-metadata.json') special

Creates a new Kubeflow-based component. This class essentially wraps a dsl.ContainerOp construct in Kubeflow Pipelines.

Parameters:

Name Type Description Default
component BaseComponent

The logical TFX component to wrap.

required
depends_on Set[kfp.dsl._container_op.ContainerOp]

The set of upstream KFP ContainerOp components that this component will depend on.

required
image str

The container image to use for this component.

required
tfx_ir Pipeline

The TFX intermediate representation of the pipeline.

required
pod_labels_to_attach Dict[str, str]

Dict of pod labels to attach to the GKE pod.

required
runtime_parameters List[tfx.orchestration.data_types.RuntimeParameter]

Runtime parameters of the pipeline.

required
metadata_ui_path str

File location for metadata-ui-metadata.json file.

'/mlpipeline-ui-metadata.json'
Source code in zenml/integrations/kubeflow/orchestrators/kubeflow_component.py
def __init__(
    self,
    component: tfx_base_component.BaseComponent,
    depends_on: Set[dsl.ContainerOp],
    image: str,
    tfx_ir: pipeline_pb2.Pipeline,  # type: ignore[valid-type]
    pod_labels_to_attach: Dict[str, str],
    main_module: str,
    step_module: str,
    step_function_name: str,
    runtime_parameters: List[data_types.RuntimeParameter],
    metadata_ui_path: str = "/mlpipeline-ui-metadata.json",
):
    """Creates a new Kubeflow-based component.

    This class essentially wraps a `dsl.ContainerOp` construct in Kubeflow
    Pipelines.

    Args:
        component: The logical TFX component to wrap.
        depends_on: The set of upstream KFP ContainerOp components that
            this component will depend on.
        image: The container image to use for this component.
        tfx_ir: The TFX intermediate representation of the pipeline.
        pod_labels_to_attach: Dict of pod labels to attach to the GKE pod.
        main_module: Module of the pipeline's `__main__` script, forwarded
            to the container entrypoint as `--main_module`.
        step_module: Module that defines the step, forwarded as
            `--step_module`.
        step_function_name: Name of the step function, forwarded as
            `--step_function_name`.
        runtime_parameters: Runtime parameters of the pipeline.
        metadata_ui_path: File location for metadata-ui-metadata.json file.
    """
    utils.replace_placeholder(component)
    input_artifact_type_mapping = _get_input_artifact_type_mapping(
        component
    )

    arguments = [
        "--node_id",
        component.id,
        "--tfx_ir",
        json_format.MessageToJson(tfx_ir),
        "--metadata_ui_path",
        metadata_ui_path,
        "--main_module",
        main_module,
        "--step_module",
        step_module,
        "--step_function_name",
        step_function_name,
        "--input_artifact_types",
        json.dumps(input_artifact_type_mapping),
        "--run_name",
        # Argo template, resolved by Kubeflow at run time.
        "{{workflow.annotations.pipelines.kubeflow.org/run_name}}",
    ]

    for param in runtime_parameters:
        arguments.append("--runtime_parameter")
        arguments.append(_encode_runtime_parameter(param))

    # Both stores belong to the same active stack; look it up only once.
    active_stack = Repository().get_active_stack()
    artifact_store = active_stack.artifact_store
    metadata_store = active_stack.metadata_store

    # Local stores live on the host filesystem, so their directories are
    # mounted into the container as hostPath volumes at the same path.
    volumes: Dict[str, k8s_client.V1Volume] = {}

    if isinstance(artifact_store, LocalArtifactStore):
        host_path = k8s_client.V1HostPathVolumeSource(
            path=artifact_store.path, type="Directory"
        )
        volumes[artifact_store.path] = k8s_client.V1Volume(
            name="local-artifact-store", host_path=host_path
        )
        logger.debug(
            "Adding host path volume for local artifact store (path: %s) "
            "in kubeflow pipelines container.",
            artifact_store.path,
        )

    if isinstance(metadata_store, SQLiteMetadataStore):
        metadata_store_dir = os.path.dirname(metadata_store.uri)
        host_path = k8s_client.V1HostPathVolumeSource(
            path=metadata_store_dir, type="Directory"
        )
        volumes[metadata_store_dir] = k8s_client.V1Volume(
            name="local-metadata-store", host_path=host_path
        )
        logger.debug(
            "Adding host path volume for local metadata store (uri: %s) "
            "in kubeflow pipelines container.",
            metadata_store.uri,
        )

    self.container_op = dsl.ContainerOp(
        name=component.id,
        command=CONTAINER_ENTRYPOINT_COMMAND,
        image=image,
        arguments=arguments,
        output_artifact_paths={
            "mlpipeline-ui-metadata": metadata_ui_path,
        },
        pvolumes=volumes,
    )

    for op in depends_on:
        self.container_op.after(op)

    # NOTE(review): presumably stops the pipeline from executing again when
    # the user's modules are imported inside the container -- confirm.
    self.container_op.container.add_env_variable(
        k8s_client.V1EnvVar(
            name=ENV_ZENML_PREVENT_PIPELINE_EXECUTION, value="True"
        )
    )

    for k, v in pod_labels_to_attach.items():
        self.container_op.add_pod_label(k, v)
kubeflow_dag_runner

The code below is copied from the TFX source repository with minor changes. All credit goes to the TFX team for the core implementation.

KubeflowDagRunner (TfxRunner)

Kubeflow Pipelines runner. Constructs a pipeline definition YAML file based on the TFX logical pipeline.

Source code in zenml/integrations/kubeflow/orchestrators/kubeflow_dag_runner.py
class KubeflowDagRunner(tfx_runner.TfxRunner):
    """Kubeflow Pipelines runner.

    Constructs a pipeline definition YAML file based on the TFX logical
    pipeline.
    """

    def __init__(
        self,
        config: KubeflowDagRunnerConfig,
        output_path: str,
        pod_labels_to_attach: Optional[Dict[str, str]] = None,
    ):
        """Initializes KubeflowDagRunner for compiling a Kubeflow Pipeline.

        Args:
            config: A KubeflowDagRunnerConfig object to specify runtime
                configuration when running the pipeline under Kubeflow.
            output_path: Path where the pipeline definition file will be
                stored.
            pod_labels_to_attach: Optional set of pod labels to attach to the
                GKE pods spun up for this pipeline. If not given (or empty),
                the labels from `get_default_pod_labels()` are used.
        """
        super().__init__(config)
        self._kubeflow_config = config
        self._output_path = output_path
        self._compiler = compiler.Compiler()
        self._tfx_compiler = tfx_compiler.Compiler()
        # Pipeline-level KFP parameters (one per distinct parameter name).
        self._params: List[dsl.PipelineParam] = []
        # Runtime parameters used by each component, keyed by component id.
        self._params_by_component_id: Dict[
            str, List[data_types.RuntimeParameter]
        ] = collections.defaultdict(list)
        self._deduped_parameter_names: Set[str] = set()
        self._pod_labels_to_attach = (
            pod_labels_to_attach or get_default_pod_labels()
        )

    def _parse_parameter_from_component(
        self, component: tfx_base_component.BaseComponent
    ) -> None:
        """Extract embedded RuntimeParameter placeholders from a component.

        Records each RuntimeParameter found in the component's exec
        properties for that component, and appends a `dsl.PipelineParam` to
        the runner the first time a parameter name is seen.

        Args:
            component: a TFX component.
        """
        deduped_parameter_names_for_component = set()
        for parameter in component.exec_properties.values():
            if not isinstance(parameter, data_types.RuntimeParameter):
                continue
            # Ignore pipeline root because it will be added later.
            if parameter.name == tfx_pipeline.ROOT_PARAMETER.name:
                continue
            if parameter.name in deduped_parameter_names_for_component:
                continue

            deduped_parameter_names_for_component.add(parameter.name)
            self._params_by_component_id[component.id].append(parameter)
            if parameter.name not in self._deduped_parameter_names:
                self._deduped_parameter_names.add(parameter.name)
                dsl_parameter = dsl.PipelineParam(
                    name=parameter.name, value=str(parameter.default)
                )
                self._params.append(dsl_parameter)

    def _parse_parameter_from_pipeline(
        self, pipeline: tfx_pipeline.Pipeline
    ) -> None:
        """Extract all the RuntimeParameter placeholders from the pipeline."""
        for component in pipeline.components:
            self._parse_parameter_from_component(component)

    def _construct_pipeline_graph(
        self, pipeline: tfx_pipeline.Pipeline
    ) -> None:
        """Constructs a Kubeflow Pipeline graph.

        Args:
            pipeline: The logical TFX pipeline to base the construction on.
        """
        from zenml.utils import source_utils

        component_to_kfp_op: Dict[base_node.BaseNode, dsl.ContainerOp] = {}
        tfx_ir = self._generate_tfx_ir(pipeline)

        # The `__main__` module is the same for every component, so resolve
        # its module source once instead of on every loop iteration.
        main_module_file = sys.modules["__main__"].__file__
        main_module = source_utils.get_module_source_from_file_path(
            os.path.abspath(main_module_file)
        )

        # Assumption: There is a partial ordering of components in the list,
        # i.e. if component A depends on component B and C, then A appears
        # after B and C in the list.
        for component in pipeline.components:
            # Keep track of the set of upstream dsl.ContainerOps for this
            # component.
            depends_on = {
                component_to_kfp_op[upstream_component]
                for upstream_component in component.upstream_nodes
            }

            # remove the extra pipeline node information
            tfx_node_ir = self._dehydrate_tfx_ir(tfx_ir, component.id)

            # Drop the last dotted segment of the component type to get the
            # module it was defined in; steps defined in the entry script
            # report `__main__`, which is mapped to the resolved main module.
            module_parts = component.component_type.split(".")[:-1]
            if module_parts[0] == "__main__":
                step_module = main_module
            else:
                step_module = ".".join(module_parts)

            kfp_component = KubeflowComponent(
                main_module=main_module,
                step_module=step_module,
                step_function_name=component.id,
                component=component,
                depends_on=depends_on,
                image=self._kubeflow_config.image,
                pod_labels_to_attach=self._pod_labels_to_attach,
                tfx_ir=tfx_node_ir,
                metadata_ui_path=self._kubeflow_config.metadata_ui_path,
                runtime_parameters=self._params_by_component_id[component.id],
            )

            for operator in self._kubeflow_config.pipeline_operator_funcs:
                kfp_component.container_op.apply(operator)

            component_to_kfp_op[component] = kfp_component.container_op

    def _del_unused_field(
        self, node_id: str, message_dict: MutableMapping[str, Any]
    ) -> None:
        """Remove every entry whose key is not `node_id` (in place)."""
        # Iterate over a snapshot of the keys since entries are deleted.
        for item in list(message_dict.keys()):
            if item != node_id:
                del message_dict[item]

    def _dehydrate_tfx_ir(
        self, original_pipeline: pipeline_pb2.Pipeline, node_id: str  # type: ignore[valid-type] # noqa
    ) -> pipeline_pb2.Pipeline:  # type: ignore[valid-type]
        """Dehydrate the TFX IR to remove unused fields.

        Returns a deep copy of `original_pipeline` that keeps only the
        pipeline node `node_id` and only that node's entries in the packed
        deployment config.
        """
        pipeline = copy.deepcopy(original_pipeline)
        for node in pipeline.nodes:  # type: ignore[attr-defined]
            if (
                node.WhichOneof("node") == "pipeline_node"
                and node.pipeline_node.node_info.id == node_id
            ):
                del pipeline.nodes[:]  # type: ignore[attr-defined]
                pipeline.nodes.extend([node])  # type: ignore[attr-defined]
                break

        deployment_config = pipeline_pb2.IntermediateDeploymentConfig()
        pipeline.deployment_config.Unpack(deployment_config)  # type: ignore[attr-defined] # noqa
        self._del_unused_field(node_id, deployment_config.executor_specs)
        self._del_unused_field(node_id, deployment_config.custom_driver_specs)
        self._del_unused_field(
            node_id, deployment_config.node_level_platform_configs
        )
        pipeline.deployment_config.Pack(deployment_config)  # type: ignore[attr-defined] # noqa
        return pipeline

    def _generate_tfx_ir(
        self, pipeline: tfx_pipeline.Pipeline
    ) -> Optional[pipeline_pb2.Pipeline]:  # type: ignore[valid-type]
        """Generate the TFX IR from the logical TFX pipeline."""
        result = self._tfx_compiler.compile(pipeline)
        return result

    def run(self, pipeline: tfx_pipeline.Pipeline) -> None:
        """Compiles and outputs a Kubeflow Pipeline YAML definition file.

        Args:
            pipeline: The logical TFX pipeline to use when building the
                Kubeflow pipeline.
        """
        for component in pipeline.components:
            # TODO(b/187122662): Pass through pip dependencies as a first-class
            # component flag.
            if isinstance(component, tfx_base_component.BaseComponent):
                component._resolve_pip_dependencies(
                    # pylint: disable=protected-access
                    pipeline.pipeline_info.pipeline_root
                )

        def _construct_pipeline() -> None:
            """Creates Kubeflow ContainerOps for each TFX component
            encountered in the pipeline definition."""
            self._construct_pipeline_graph(pipeline)

        # Need to run this first to get self._params populated. Then KFP
        # compiler can correctly match default value with PipelineParam.
        self._parse_parameter_from_pipeline(pipeline)
        # Create workflow spec and write out to package.
        self._compiler._create_and_write_workflow(
            # pylint: disable=protected-access
            pipeline_func=_construct_pipeline,
            pipeline_name=pipeline.pipeline_info.pipeline_name,
            params_list=self._params,
            package_path=self._output_path,
        )
        logger.info(
            "Finished writing kubeflow pipeline definition file '%s'.",
            self._output_path,
        )
__init__(self, config, output_path, pod_labels_to_attach=None) special

Initializes KubeflowDagRunner for compiling a Kubeflow Pipeline.

Parameters:

Name Type Description Default
config KubeflowDagRunnerConfig

A KubeflowDagRunnerConfig object to specify runtime configuration when running the pipeline under Kubeflow.

required
output_path str

Path where the pipeline definition file will be stored.

required
pod_labels_to_attach Optional[Dict[str, str]]

Optional set of pod labels to attach to the GKE pods spun up for this pipeline. Defaults to three labels: 1. add-pod-env: true, 2. the pipeline SDK type, 3. the pipeline's unique ID, where 2 and 3 are used for usage tracking.

None
Source code in zenml/integrations/kubeflow/orchestrators/kubeflow_dag_runner.py
def __init__(
    self,
    config: KubeflowDagRunnerConfig,
    output_path: str,
    pod_labels_to_attach: Optional[Dict[str, str]] = None,
):
    """Initializes KubeflowDagRunner for compiling a Kubeflow Pipeline.

    Args:
      config: KubeflowDagRunnerConfig holding the runtime configuration used
        when the pipeline runs under Kubeflow.
      output_path: Path where the compiled pipeline definition file will be
        written.
      pod_labels_to_attach: Optional pod labels to attach to the pods spun up
        for this pipeline. When omitted or empty, the labels returned by
        `get_default_pod_labels()` are used instead.
    """
    super().__init__(config)
    self._kubeflow_config = config
    self._output_path = output_path

    # Compilers used when `run` is invoked.
    self._compiler = compiler.Compiler()
    self._tfx_compiler = tfx_compiler.Compiler()

    # Runtime-parameter bookkeeping, filled in while parsing the pipeline.
    self._params: List[dsl.PipelineParam] = []
    self._params_by_component_id: Dict[
        str, List[data_types.RuntimeParameter]
    ] = collections.defaultdict(list)
    self._deduped_parameter_names: Set[str] = set()

    # A falsy value (None or an empty dict) selects the default labels.
    if not pod_labels_to_attach:
        pod_labels_to_attach = get_default_pod_labels()
    self._pod_labels_to_attach = pod_labels_to_attach
run(self, pipeline)

Compiles and outputs a Kubeflow Pipeline YAML definition file.

Parameters:

Name Type Description Default
pipeline Pipeline

The logical TFX pipeline to use when building the Kubeflow pipeline.

required
Source code in zenml/integrations/kubeflow/orchestrators/kubeflow_dag_runner.py
def run(self, pipeline: tfx_pipeline.Pipeline) -> None:
    """Compiles the given TFX pipeline into a Kubeflow Pipeline YAML file.

    Args:
      pipeline: The logical TFX pipeline from which the Kubeflow pipeline is
        built. The result is written to the runner's output path.
    """
    for component in pipeline.components:
        # TODO(b/187122662): Pass through pip dependencies as a first-class
        # component flag.
        if not isinstance(component, tfx_base_component.BaseComponent):
            continue
        component._resolve_pip_dependencies(  # pylint: disable=protected-access
            pipeline.pipeline_info.pipeline_root
        )

    # NOTE: name and docstring of this closure are kept verbatim; the KFP
    # compiler receives it as `pipeline_func` and may read its metadata.
    def _construct_pipeline() -> None:
        """Creates Kubeflow ContainerOps for each TFX component
        encountered in the pipeline definition."""
        self._construct_pipeline_graph(pipeline)

    # self._params must be populated before compiling so that the KFP
    # compiler can match default values with their PipelineParam objects.
    self._parse_parameter_from_pipeline(pipeline)

    # Build the workflow spec and write it to the package path.
    output_path = self._output_path
    self._compiler._create_and_write_workflow(  # pylint: disable=protected-access
        pipeline_func=_construct_pipeline,
        pipeline_name=pipeline.pipeline_info.pipeline_name,
        params_list=self._params,
        package_path=output_path,
    )
    logger.info(
        "Finished writing kubeflow pipeline definition file '%s'.",
        output_path,
    )
KubeflowDagRunnerConfig (PipelineConfig)

Runtime configuration parameters specific to execution on Kubeflow.

Source code in zenml/integrations/kubeflow/orchestrators/kubeflow_dag_runner.py
class KubeflowDagRunnerConfig(pipeline_config.PipelineConfig):
    """Runtime configuration parameters specific to execution on Kubeflow."""

    def __init__(
        self,
        image: str,
        pipeline_operator_funcs: Optional[List[OpFunc]] = None,
        supported_launcher_classes: Optional[
            List[Type[base_component_launcher.BaseComponentLauncher]]
        ] = None,
        metadata_ui_path: str = "/mlpipeline-ui-metadata.json",
        **kwargs: Any
    ):
        """Creates a KubeflowDagRunnerConfig object.

        The user can use pipeline_operator_funcs to apply modifications to
        ContainerOps used in the pipeline. For example, to ensure the pipeline
        steps mount a GCP secret and a Persistent Volume, one can create a
        config object like so:

          from kfp import gcp, onprem
          mount_secret_op = gcp.use_gcp_secret("my-secret-name")
          mount_volume_op = onprem.mount_pvc(
            "my-persistent-volume-claim",
            "my-volume-name",
            "/mnt/volume-mount-path")
          config = KubeflowDagRunnerConfig(
            image="my-docker-image",
            pipeline_operator_funcs=[mount_secret_op, mount_volume_op]
          )

        Note that passing pipeline_operator_funcs replaces the defaults from
        get_default_pipeline_operator_funcs() instead of extending them.

        Args:
          image: The docker image to use in the pipeline.
          pipeline_operator_funcs: A list of ContainerOp modifying functions
            that will be applied to every container step in the pipeline.
            None or an empty list selects
            get_default_pipeline_operator_funcs().
          supported_launcher_classes: A list of component launcher classes that
            are supported by the current pipeline. List sequence determines the
            order in which launchers are chosen for each component being run.
            None or an empty list selects the in-process launcher followed by
            the kubernetes launcher.
          metadata_ui_path: File location for the mlpipeline-ui-metadata.json
            file.
          **kwargs: keyword args for PipelineConfig.
        """
        # Falsy values (None or []) fall back to the default launchers.
        supported_launcher_classes = supported_launcher_classes or [
            in_process_component_launcher.InProcessComponentLauncher,
            kubernetes_component_launcher.KubernetesComponentLauncher,
        ]
        super().__init__(
            supported_launcher_classes=supported_launcher_classes, **kwargs
        )
        # Falsy values (None or []) fall back to the default operator funcs.
        self.pipeline_operator_funcs = (
            pipeline_operator_funcs or get_default_pipeline_operator_funcs()
        )
        self.image = image
        self.metadata_ui_path = metadata_ui_path
__init__(self, image, pipeline_operator_funcs=None, supported_launcher_classes=None, metadata_ui_path='/mlpipeline-ui-metadata.json', **kwargs) special

Creates a KubeflowDagRunnerConfig object. The user can use pipeline_operator_funcs to apply modifications to ContainerOps used in the pipeline. For example, to ensure the pipeline steps mount a GCP secret and a Persistent Volume, one can create a config object like so: from kfp import gcp, onprem mount_secret_op = gcp.use_gcp_secret("my-secret-name") mount_volume_op = onprem.mount_pvc("my-persistent-volume-claim", "my-volume-name", "/mnt/volume-mount-path") config = KubeflowDagRunnerConfig(image="my-docker-image", pipeline_operator_funcs=[mount_secret_op, mount_volume_op])

Parameters:

Name Type Description Default
image str

The docker image to use in the pipeline.

required
pipeline_operator_funcs Optional[List[Callable[[kfp.dsl._container_op.ContainerOp], Union[kfp.dsl._container_op.ContainerOp, NoneType]]]]

A list of ContainerOp modifying functions that will be applied to every container step in the pipeline.

None
supported_launcher_classes Optional[List[Type[tfx.orchestration.launcher.base_component_launcher.BaseComponentLauncher]]]

A list of component launcher classes that are supported by the current pipeline. List sequence determines the order in which launchers are chosen for each component being run.

None
metadata_ui_path str

File location for the mlpipeline-ui-metadata.json file.

'/mlpipeline-ui-metadata.json'
**kwargs Any

keyword args for PipelineConfig.

{}
Source code in zenml/integrations/kubeflow/orchestrators/kubeflow_dag_runner.py
def __init__(
    self,
    image: str,
    pipeline_operator_funcs: Optional[List[OpFunc]] = None,
    supported_launcher_classes: Optional[
        List[Type[base_component_launcher.BaseComponentLauncher]]
    ] = None,
    metadata_ui_path: str = "/mlpipeline-ui-metadata.json",
    **kwargs: Any
):
    """Creates a KubeflowDagRunnerConfig object.

    The user can use pipeline_operator_funcs to apply modifications to
    ContainerOps used in the pipeline. For example, to ensure the pipeline
    steps mount a GCP secret and a Persistent Volume, one can create a
    config object like so:

      from kfp import gcp, onprem
      mount_secret_op = gcp.use_gcp_secret("my-secret-name")
      mount_volume_op = onprem.mount_pvc(
        "my-persistent-volume-claim",
        "my-volume-name",
        "/mnt/volume-mount-path")
      config = KubeflowDagRunnerConfig(
        image="my-docker-image",
        pipeline_operator_funcs=[mount_secret_op, mount_volume_op]
      )

    Note that passing pipeline_operator_funcs replaces the defaults from
    get_default_pipeline_operator_funcs() instead of extending them.

    Args:
      image: The docker image to use in the pipeline.
      pipeline_operator_funcs: A list of ContainerOp modifying functions that
        will be applied to every container step in the pipeline. None or an
        empty list selects get_default_pipeline_operator_funcs().
      supported_launcher_classes: A list of component launcher classes that are
        supported by the current pipeline. List sequence determines the order in
        which launchers are chosen for each component being run. None or an
        empty list selects the in-process launcher followed by the kubernetes
        launcher.
      metadata_ui_path: File location for the mlpipeline-ui-metadata.json file.
      **kwargs: keyword args for PipelineConfig.
    """
    # Falsy values (None or []) fall back to the default launchers.
    supported_launcher_classes = supported_launcher_classes or [
        in_process_component_launcher.InProcessComponentLauncher,
        kubernetes_component_launcher.KubernetesComponentLauncher,
    ]
    super().__init__(
        supported_launcher_classes=supported_launcher_classes, **kwargs
    )
    # Falsy values (None or []) fall back to the default operator funcs.
    self.pipeline_operator_funcs = (
        pipeline_operator_funcs or get_default_pipeline_operator_funcs()
    )
    self.image = image
    self.metadata_ui_path = metadata_ui_path
get_default_pipeline_operator_funcs(use_gcp_sa=False)

Returns a default list of pipeline operator functions.

Parameters:

Name Type Description Default
use_gcp_sa bool

If true, mount a GCP service account secret to each pod, with the name _KUBEFLOW_GCP_SECRET_NAME.

False

Returns:

Type Description
List[Callable[[kfp.dsl._container_op.ContainerOp], Optional[kfp.dsl._container_op.ContainerOp]]]

A list of functions with type OpFunc.

Source code in zenml/integrations/kubeflow/orchestrators/kubeflow_dag_runner.py
def get_default_pipeline_operator_funcs(
    use_gcp_sa: bool = False,
) -> List[OpFunc]:
    """Returns a default list of pipeline operator functions.

    Args:
      use_gcp_sa: If true, additionally mount a GCP service account secret to
        each pod, with the name _KUBEFLOW_GCP_SECRET_NAME.

    Returns:
      A list of functions with type OpFunc. The config-map mount is always
      included; when use_gcp_sa is true the GCP secret mount comes first.
    """
    # Mounts configmap containing Metadata gRPC server configuration.
    mount_config_map_op = _mount_config_map_op("metadata-grpc-configmap")
    if not use_gcp_sa:
        return [mount_config_map_op]

    # Enables authentication for GCP services. Only built when requested,
    # since it is discarded otherwise.
    gcp_secret_op = gcp.use_gcp_secret(_KUBEFLOW_GCP_SECRET_NAME)
    return [gcp_secret_op, mount_config_map_op]
get_default_pod_labels()

Returns the default pod label dict for Kubeflow.

Source code in zenml/integrations/kubeflow/orchestrators/kubeflow_dag_runner.py
def get_default_pod_labels() -> Dict[str, str]:
    """Returns the default pod label dict for Kubeflow.

    The dict enables the KFP pod-env label ("add-pod-env") and tags the
    pipeline SDK type as "tfx".
    """
    # KFP default transformers add pod env:
    # https://github.com/kubeflow/pipelines/blob/0.1.32/sdk/python/kfp/compiler/_default_transformers.py
    return {
        "add-pod-env": "true",
        telemetry_utils.LABEL_KFP_SDK_ENV: "tfx",
    }
kubeflow_orchestrator
KubeflowOrchestrator (BaseOrchestrator) pydantic-model

Orchestrator responsible for running pipelines using Kubeflow.

Source code in zenml/integrations/kubeflow/orchestrators/kubeflow_orchestrator.py
class KubeflowOrchestrator(BaseOrchestrator):
    """Orchestrator responsible for running pipelines using Kubeflow.

    Attributes:
        custom_docker_base_image_name: Optional base image used when building
            the pipeline docker image in `pre_run`.
        kubeflow_pipelines_ui_port: Local port on which `up` starts the
            Kubeflow Pipelines UI daemon.
        kubernetes_context: Optional kubernetes context used to authorize the
            KFP client when uploading and running a pipeline.
    """

    custom_docker_base_image_name: Optional[str] = None
    kubeflow_pipelines_ui_port: int = 8080
    kubernetes_context: Optional[str] = None

    def get_docker_image_name(self, pipeline_name: str) -> str:
        """Returns the full docker image name including registry and tag.

        Args:
            pipeline_name: Name of the pipeline; used as the image tag.

        Returns:
            `zenml-kubeflow:<pipeline_name>`, prefixed with the active stack's
            container registry URI when a registry is configured.
        """

        base_image_name = f"zenml-kubeflow:{pipeline_name}"
        container_registry = Repository().get_active_stack().container_registry

        if container_registry:
            # Strip a trailing slash so the joined name never contains '//'.
            registry_uri = container_registry.uri.rstrip("/")
            return f"{registry_uri}/{base_image_name}"
        else:
            return base_image_name

    @property
    def root_directory(self) -> str:
        """Returns path to the root directory for all files concerning
        this orchestrator."""
        return os.path.join(
            zenml.io.utils.get_global_config_directory(),
            "kubeflow",
            str(self.uuid),
        )

    @property
    def pipeline_directory(self) -> str:
        """Returns path to a directory in which the kubeflow pipeline files
        are stored."""
        return os.path.join(self.root_directory, "pipelines")

    def pre_run(self, pipeline: "BasePipeline", caller_filepath: str) -> None:
        """Builds a docker image for the current environment and uploads it to
        a container registry if configured.

        Args:
            pipeline: The pipeline for which the image is built.
            caller_filepath: Not used by this implementation.
        """
        from zenml.integrations.kubeflow.docker_utils import (
            build_docker_image,
            push_docker_image,
        )

        image_name = self.get_docker_image_name(pipeline.name)

        repository_root = Repository().path
        requirements = (
            ["kubernetes"]
            + self._get_stack_requirements()
            + self._get_pipeline_requirements(pipeline)
        )
        logger.debug("Kubeflow docker container requirements: %s", requirements)

        build_docker_image(
            build_context_path=repository_root,
            image_name=image_name,
            dockerignore_path=pipeline.dockerignore_file,
            requirements=requirements,
            base_image=self.custom_docker_base_image_name,
        )

        # Without a registry the image only exists locally.
        if Repository().get_active_stack().container_registry:
            push_docker_image(image_name)

    def run(
        self,
        zenml_pipeline: "BasePipeline",
        run_name: str,
        **kwargs: Any,
    ) -> None:
        """Runs the pipeline on Kubeflow.

        Compiles the pipeline into a YAML file inside `pipeline_directory`,
        then uploads and starts it via the KFP client.

        Args:
            zenml_pipeline: The pipeline to run.
            run_name: Name of the pipeline run. If empty, a timestamp-based
                name is generated.
            **kwargs: Unused kwargs to conform with base signature
        """
        from zenml.integrations.kubeflow.docker_utils import get_image_digest

        image_name = self.get_docker_image_name(zenml_pipeline.name)
        # Prefer the name returned by get_image_digest when it returns one;
        # otherwise keep the tag-based name.
        image_name = get_image_digest(image_name) or image_name

        fileio.make_dirs(self.pipeline_directory)
        pipeline_file_path = os.path.join(
            self.pipeline_directory, f"{zenml_pipeline.name}.yaml"
        )
        runner_config = KubeflowDagRunnerConfig(image=image_name)
        runner = KubeflowDagRunner(
            config=runner_config, output_path=pipeline_file_path
        )
        tfx_pipeline = create_tfx_pipeline(zenml_pipeline)
        runner.run(tfx_pipeline)

        # `%b` is the documented abbreviated-month directive; `%h` is only a
        # platform-specific alias for it.
        run_name = run_name or datetime.now().strftime("%d_%b_%y-%H_%M_%S_%f")
        self._upload_and_run_pipeline(
            pipeline_file_path=pipeline_file_path,
            run_name=run_name,
            enable_cache=zenml_pipeline.enable_cache,
        )

    def _upload_and_run_pipeline(
        self, pipeline_file_path: str, run_name: str, enable_cache: bool
    ) -> None:
        """Tries to upload and run a KFP pipeline.

        HTTP errors from the upload are logged as a warning instead of being
        raised.

        Args:
            pipeline_file_path: Path to the pipeline definition file.
            run_name: A name for the pipeline run that will be started.
            enable_cache: Whether caching is enabled for this pipeline run.
        """
        try:
            if self.kubernetes_context:
                logger.info(
                    "Running in kubernetes context '%s'.",
                    self.kubernetes_context,
                )

            # load kubernetes config to authorize the KFP client
            config.load_kube_config(context=self.kubernetes_context)

            # upload the pipeline to Kubeflow and start it
            client = kfp.Client()
            result = client.create_run_from_pipeline_package(
                pipeline_file_path,
                arguments={},
                run_name=run_name,
                enable_caching=enable_cache,
            )
            logger.info("Started pipeline run with ID '%s'.", result.run_id)
        except urllib3.exceptions.HTTPError as error:
            logger.warning(
                "Failed to upload Kubeflow pipeline: %s. "
                "Please make sure your kube config is configured and the "
                "current context is set correctly.",
                error,
            )

    def _get_stack_requirements(self) -> List[str]:
        """Gets list of requirements for the current active stack.

        Returns:
            The requirements of the modules defining the active stack's
            artifact store and metadata store.
        """
        stack = Repository().get_active_stack()
        requirements = []

        artifact_store_module = stack.artifact_store.__module__
        requirements += get_requirements_for_module(artifact_store_module)

        metadata_store_module = stack.metadata_store.__module__
        requirements += get_requirements_for_module(metadata_store_module)

        return requirements

    def _get_pipeline_requirements(self, pipeline: "BasePipeline") -> List[str]:
        """Gets list of requirements for a pipeline.

        Reads `pipeline.requirements_file` if it is set and exists. Blank
        lines (e.g. a trailing newline) and full-line `#` comments are
        skipped so they are not passed on as empty or invalid requirements.

        Args:
            pipeline: The pipeline whose requirements file is read.

        Returns:
            The stripped requirement lines, or an empty list if no
            requirements file is configured or it does not exist.
        """
        requirements_file = pipeline.requirements_file
        if not (requirements_file and fileio.file_exists(requirements_file)):
            return []

        logger.debug("Using requirements from file %s.", requirements_file)
        with fileio.open(requirements_file, "r") as f:
            lines = f.read().splitlines()

        stripped_lines = (line.strip() for line in lines)
        return [
            line
            for line in stripped_lines
            if line and not line.startswith("#")
        ]

    @property
    def _pid_file_path(self) -> str:
        """Returns path to the daemon PID file."""
        return os.path.join(self.root_directory, "kubeflow_daemon.pid")

    @property
    def _k3d_cluster_name(self) -> str:
        """Returns the K3D cluster name."""
        # K3D only allows cluster names with up to 32 characters, use the
        # first 8 chars of the orchestrator UUID as identifier
        return f"zenml-kubeflow-{str(self.uuid)[:8]}"

    def _get_k3d_registry_name(self, port: int) -> str:
        """Returns the K3D registry name."""
        return f"k3d-zenml-kubeflow-registry.localhost:{port}"

    @property
    def _k3d_registry_config_path(self) -> str:
        """Returns the path to the K3D registry config yaml."""
        return os.path.join(self.root_directory, "k3d_registry.yaml")

    @property
    def is_running(self) -> bool:
        """Returns whether the orchestrator is running."""
        if not local_deployment_utils.check_prerequisites():
            # if any prerequisites are missing there is certainly no
            # local deployment running
            return False

        return local_deployment_utils.k3d_cluster_exists(
            cluster_name=self._k3d_cluster_name
        )

    def up(self) -> None:
        """Spins up a local Kubeflow Pipelines deployment.

        Requires `k3d` and `kubectl` plus a container registry in the active
        stack whose URI ends in `:<port>`. Problems with these prerequisites
        are logged as errors and nothing is deployed.
        """
        if self.is_running:
            logger.info(
                "Found already existing local Kubeflow Pipelines deployment. "
                "If there are any issues with the existing deployment, please "
                "run 'zenml orchestrator down' to delete it."
            )
            return

        if not local_deployment_utils.check_prerequisites():
            logger.error(
                "Unable to spin up local Kubeflow Pipelines deployment: "
                "Please install 'k3d' and 'kubectl' and try again."
            )
            return

        container_registry = Repository().get_active_stack().container_registry
        if not container_registry:
            logger.error(
                "Unable to spin up local Kubeflow Pipelines deployment: "
                "Missing container registry in current stack."
            )
            return

        # The K3D registry name is derived from the registry port, so validate
        # the URI before creating any directories or clusters. A trailing
        # slash is tolerated, consistent with get_docker_image_name.
        _, _, port_string = container_registry.uri.rstrip("/").rpartition(":")
        try:
            container_registry_port = int(port_string)
        except ValueError:
            logger.error(
                "Unable to spin up local Kubeflow Pipelines deployment: "
                "Container registry URI '%s' does not specify a port.",
                container_registry.uri,
            )
            return

        logger.info("Spinning up local Kubeflow Pipelines deployment...")
        fileio.make_dirs(self.root_directory)
        container_registry_name = self._get_k3d_registry_name(
            port=container_registry_port
        )
        local_deployment_utils.write_local_registry_yaml(
            yaml_path=self._k3d_registry_config_path,
            registry_name=container_registry_name,
            registry_uri=container_registry.uri,
        )
        local_deployment_utils.create_k3d_cluster(
            cluster_name=self._k3d_cluster_name,
            registry_name=container_registry_name,
            registry_config_path=self._k3d_registry_config_path,
        )
        kubernetes_context = f"k3d-{self._k3d_cluster_name}"
        local_deployment_utils.deploy_kubeflow_pipelines(
            kubernetes_context=kubernetes_context
        )
        local_deployment_utils.start_kfp_ui_daemon(
            pid_file_path=self._pid_file_path,
            port=self.kubeflow_pipelines_ui_port,
        )

        logger.info(
            f"Finished local Kubeflow Pipelines deployment. The UI should now "
            f"be accessible at "
            f"http://localhost:{self.kubeflow_pipelines_ui_port}/."
        )

    def down(self) -> None:
        """Tears down a local Kubeflow Pipelines deployment.

        Deletes the K3D cluster if it exists and stops the UI daemon if its
        PID file is present.
        """
        if self.is_running:
            local_deployment_utils.delete_k3d_cluster(
                cluster_name=self._k3d_cluster_name
            )

        if fileio.file_exists(self._pid_file_path):
            if sys.platform == "win32":
                # Daemon functionality is not supported on Windows, so the PID
                # file won't exist. This if clause exists just for mypy to not
                # complain about missing functions
                pass
            else:
                from zenml.utils import daemon

                daemon.stop_daemon(self._pid_file_path, kill_children=True)
                fileio.remove(self._pid_file_path)

        logger.info("Local kubeflow pipelines deployment spun down.")
is_running: bool property readonly

Returns whether the orchestrator is running.

pipeline_directory: str property readonly

Returns path to a directory in which the kubeflow pipeline files are stored.

root_directory: str property readonly

Returns path to the root directory for all files concerning this orchestrator.

down(self)

Tears down a local Kubeflow Pipelines deployment.

Source code in zenml/integrations/kubeflow/orchestrators/kubeflow_orchestrator.py
def down(self) -> None:
    """Tears down a local Kubeflow Pipelines deployment.

    The K3D cluster is deleted if it exists. If the UI daemon's PID file is
    present (never on Windows), the daemon and its children are stopped and
    the PID file is removed.
    """
    if self.is_running:
        local_deployment_utils.delete_k3d_cluster(
            cluster_name=self._k3d_cluster_name
        )

    if fileio.file_exists(self._pid_file_path):
        # Daemon functionality is not supported on Windows, so the PID file
        # won't exist there. The explicit platform check also keeps mypy from
        # complaining about daemon functions missing on Windows.
        if sys.platform != "win32":
            from zenml.utils import daemon

            daemon.stop_daemon(self._pid_file_path, kill_children=True)
            fileio.remove(self._pid_file_path)

    logger.info("Local kubeflow pipelines deployment spun down.")
get_docker_image_name(self, pipeline_name)

Returns the full docker image name including registry and tag.

Source code in zenml/integrations/kubeflow/orchestrators/kubeflow_orchestrator.py
def get_docker_image_name(self, pipeline_name: str) -> str:
    """Returns the full docker image name including registry and tag.

    The image is called `zenml-kubeflow:<pipeline_name>`. If the active stack
    has a container registry, its URI (without trailing slash) is prepended.
    """
    image_name = f"zenml-kubeflow:{pipeline_name}"
    registry = Repository().get_active_stack().container_registry

    if not registry:
        return image_name

    registry_prefix = registry.uri.rstrip("/")
    return f"{registry_prefix}/{image_name}"
pre_run(self, pipeline, caller_filepath)

Builds a docker image for the current environment and uploads it to a container registry if configured.

Source code in zenml/integrations/kubeflow/orchestrators/kubeflow_orchestrator.py
def pre_run(self, pipeline: "BasePipeline", caller_filepath: str) -> None:
    """Builds the docker image for the current environment.

    The repository root is used as build context. The image requirements are
    `kubernetes`, the active stack's requirements and the pipeline's
    requirements. The image is pushed only if the active stack has a
    container registry.

    Args:
        pipeline: The pipeline for which the image is built.
        caller_filepath: Not used by this implementation.
    """
    from zenml.integrations.kubeflow.docker_utils import (
        build_docker_image,
        push_docker_image,
    )

    image_name = self.get_docker_image_name(pipeline.name)
    build_context = Repository().path

    requirements = [
        "kubernetes",
        *self._get_stack_requirements(),
        *self._get_pipeline_requirements(pipeline),
    ]
    logger.debug("Kubeflow docker container requirements: %s", requirements)

    build_docker_image(
        build_context_path=build_context,
        image_name=image_name,
        dockerignore_path=pipeline.dockerignore_file,
        requirements=requirements,
        base_image=self.custom_docker_base_image_name,
    )

    registry = Repository().get_active_stack().container_registry
    if registry:
        push_docker_image(image_name)
run(self, zenml_pipeline, run_name, **kwargs)

Runs the pipeline on Kubeflow.

Parameters:

Name Type Description Default
zenml_pipeline BasePipeline

The pipeline to run.

required
run_name str

Name of the pipeline run.

required
**kwargs Any

Unused kwargs to conform with base signature

{}
Source code in zenml/integrations/kubeflow/orchestrators/kubeflow_orchestrator.py
def run(
    self,
    zenml_pipeline: "BasePipeline",
    run_name: str,
    **kwargs: Any,
) -> None:
    """Runs the pipeline on Kubeflow.

    Compiles the pipeline into a YAML file inside `pipeline_directory`, then
    uploads and starts it via the KFP client.

    Args:
        zenml_pipeline: The pipeline to run.
        run_name: Name of the pipeline run. If empty, a timestamp-based name
            is generated.
        **kwargs: Unused kwargs to conform with base signature
    """
    from zenml.integrations.kubeflow.docker_utils import get_image_digest

    image_name = self.get_docker_image_name(zenml_pipeline.name)
    # Prefer the name returned by get_image_digest when it returns one;
    # otherwise keep the tag-based name.
    image_name = get_image_digest(image_name) or image_name

    fileio.make_dirs(self.pipeline_directory)
    pipeline_file_path = os.path.join(
        self.pipeline_directory, f"{zenml_pipeline.name}.yaml"
    )
    runner_config = KubeflowDagRunnerConfig(image=image_name)
    runner = KubeflowDagRunner(
        config=runner_config, output_path=pipeline_file_path
    )
    tfx_pipeline = create_tfx_pipeline(zenml_pipeline)
    runner.run(tfx_pipeline)

    # `%b` is the documented abbreviated-month directive; `%h` is only a
    # platform-specific alias for it.
    run_name = run_name or datetime.now().strftime("%d_%b_%y-%H_%M_%S_%f")
    self._upload_and_run_pipeline(
        pipeline_file_path=pipeline_file_path,
        run_name=run_name,
        enable_cache=zenml_pipeline.enable_cache,
    )
up(self)

Spins up a local Kubeflow Pipelines deployment.

Source code in zenml/integrations/kubeflow/orchestrators/kubeflow_orchestrator.py
def up(self) -> None:
    """Spins up a local Kubeflow Pipelines deployment.

    Requires `k3d` and `kubectl` plus a container registry in the active
    stack whose URI ends in `:<port>`. Problems with these prerequisites are
    logged as errors and nothing is deployed.
    """
    if self.is_running:
        logger.info(
            "Found already existing local Kubeflow Pipelines deployment. "
            "If there are any issues with the existing deployment, please "
            "run 'zenml orchestrator down' to delete it."
        )
        return

    if not local_deployment_utils.check_prerequisites():
        logger.error(
            "Unable to spin up local Kubeflow Pipelines deployment: "
            "Please install 'k3d' and 'kubectl' and try again."
        )
        return

    container_registry = Repository().get_active_stack().container_registry
    if not container_registry:
        logger.error(
            "Unable to spin up local Kubeflow Pipelines deployment: "
            "Missing container registry in current stack."
        )
        return

    # The K3D registry name is derived from the registry port, so validate
    # the URI before creating any directories or clusters. A trailing slash
    # is tolerated, consistent with get_docker_image_name.
    _, _, port_string = container_registry.uri.rstrip("/").rpartition(":")
    try:
        container_registry_port = int(port_string)
    except ValueError:
        logger.error(
            "Unable to spin up local Kubeflow Pipelines deployment: "
            "Container registry URI '%s' does not specify a port.",
            container_registry.uri,
        )
        return

    logger.info("Spinning up local Kubeflow Pipelines deployment...")
    fileio.make_dirs(self.root_directory)
    container_registry_name = self._get_k3d_registry_name(
        port=container_registry_port
    )
    local_deployment_utils.write_local_registry_yaml(
        yaml_path=self._k3d_registry_config_path,
        registry_name=container_registry_name,
        registry_uri=container_registry.uri,
    )
    local_deployment_utils.create_k3d_cluster(
        cluster_name=self._k3d_cluster_name,
        registry_name=container_registry_name,
        registry_config_path=self._k3d_registry_config_path,
    )
    kubernetes_context = f"k3d-{self._k3d_cluster_name}"
    local_deployment_utils.deploy_kubeflow_pipelines(
        kubernetes_context=kubernetes_context
    )
    local_deployment_utils.start_kfp_ui_daemon(
        pid_file_path=self._pid_file_path,
        port=self.kubeflow_pipelines_ui_port,
    )

    logger.info(
        f"Finished local Kubeflow Pipelines deployment. The UI should now "
        f"be accessible at "
        f"http://localhost:{self.kubeflow_pipelines_ui_port}/."
    )
kubeflow_utils

Common utility for Kubeflow-based orchestrator.

replace_placeholder(component)

Replaces the RuntimeParameter placeholders with kfp.dsl.PipelineParam.

Source code in zenml/integrations/kubeflow/orchestrators/kubeflow_utils.py
def replace_placeholder(component: base_node.BaseNode) -> None:
    """Replaces the RuntimeParameter placeholders with kfp.dsl.PipelineParam.

    Each exec property of `component` that is a `data_types.RuntimeParameter`
    is overwritten in place with the string form of a `dsl.PipelineParam`
    carrying the same name. All other exec properties are left untouched.
    """
    properties = component.exec_properties
    # Snapshot the keys first because entries are reassigned while looping.
    for property_name in list(properties):
        value = properties[property_name]
        if isinstance(value, data_types.RuntimeParameter):
            properties[property_name] = str(dsl.PipelineParam(name=value.name))
local_deployment_utils
check_prerequisites()

Checks whether all prerequisites for a local kubeflow pipelines deployment are installed.

Source code in zenml/integrations/kubeflow/orchestrators/local_deployment_utils.py
def check_prerequisites() -> bool:
    """Checks whether all prerequisites for a local kubeflow pipelines
    deployment are installed.

    Returns:
        True if both the `k3d` and the `kubectl` executables can be found via
        `shutil.which`, False otherwise.
    """
    k3d_installed, kubectl_installed = (
        shutil.which(executable) is not None
        for executable in ("k3d", "kubectl")
    )
    logger.debug(
        "Local kubeflow deployment prerequisites: K3D - %s, Kubectl - %s",
        k3d_installed,
        kubectl_installed,
    )
    return k3d_installed and kubectl_installed
create_k3d_cluster(cluster_name, registry_name, registry_config_path)

Creates a K3D cluster.

Parameters:

Name Type Description Default
cluster_name str

Name of the cluster to create.

required
registry_name str

Name of the registry to create for this cluster.

required
registry_config_path str

Path to the registry config file.

required
Source code in zenml/integrations/kubeflow/orchestrators/local_deployment_utils.py
def create_k3d_cluster(
    cluster_name: str, registry_name: str, registry_config_path: str
) -> None:
    """Creates a K3D cluster.

    The cluster gets a newly created registry and mounts the global ZenML
    config directory at the same path inside the cluster.

    Args:
        cluster_name: Name of the cluster to create.
        registry_name: Name of the registry to create for this cluster.
        registry_config_path: Path to the registry config file.

    Raises:
        subprocess.CalledProcessError: If the `k3d` command fails.
    """
    logger.info("Creating local K3D cluster '%s'.", cluster_name)
    config_dir = zenml.io.utils.get_global_config_directory()

    command = ["k3d", "cluster", "create", cluster_name]
    command += ["--registry-create", registry_name]
    command += ["--registry-config", registry_config_path]
    command += ["--volume", f"{config_dir}:{config_dir}"]
    subprocess.check_call(command)

    logger.info("Finished K3D cluster creation.")
delete_k3d_cluster(cluster_name)

Deletes a K3D cluster with the given name.

Source code in zenml/integrations/kubeflow/orchestrators/local_deployment_utils.py
def delete_k3d_cluster(cluster_name: str) -> None:
    """Deletes the K3D cluster called `cluster_name`.

    Raises:
        subprocess.CalledProcessError: If `k3d cluster delete` fails.
    """
    delete_command = ["k3d", "cluster", "delete", cluster_name]
    subprocess.check_call(delete_command)
    logger.info("Deleted local k3d cluster '%s'.", cluster_name)
deploy_kubeflow_pipelines(kubernetes_context)

Deploys Kubeflow Pipelines.

Parameters:

Name Type Description Default
kubernetes_context str

The kubernetes context on which Kubeflow Pipelines should be deployed.

required
Source code in zenml/integrations/kubeflow/orchestrators/local_deployment_utils.py
def deploy_kubeflow_pipelines(kubernetes_context: str) -> None:
    """Deploy Kubeflow Pipelines and block until all of its pods are ready.

    The cluster-scoped resources are applied first. Then the Application CRD
    must become established, and only then is the platform-agnostic (PNS)
    environment applied. After that, the pod status is polled every 30
    seconds until every pod in the ``kubeflow`` namespace is ready.

    Args:
        kubernetes_context: The kubernetes context on which Kubeflow Pipelines
            should be deployed.

    Raises:
        subprocess.CalledProcessError: If any of the ``kubectl`` calls fails.
    """
    kubectl = ["kubectl", "--context", kubernetes_context]
    cluster_scoped_manifest = f"github.com/kubeflow/pipelines/manifests/kustomize/cluster-scoped-resources?ref={KFP_VERSION}"
    platform_manifest = f"github.com/kubeflow/pipelines/manifests/kustomize/env/platform-agnostic-pns?ref={KFP_VERSION}"

    logger.info("Deploying Kubeflow Pipelines.")
    subprocess.check_call([*kubectl, "apply", "-k", cluster_scoped_manifest])
    # Wait (at most 60s) for the applications.app.k8s.io CRD to be
    # established before applying the next set of manifests.
    subprocess.check_call(
        [
            *kubectl,
            "wait",
            "--timeout=60s",
            "--for",
            "condition=established",
            "crd/applications.app.k8s.io",
        ]
    )
    subprocess.check_call([*kubectl, "apply", "-k", platform_manifest])

    logger.info(
        "Waiting for all Kubeflow Pipelines pods to be ready (this might "
        "take a few minutes)."
    )
    pod_status_command = [*kubectl, "--namespace", "kubeflow", "get", "pods"]
    # NOTE(review): this loop has no upper bound; it only ends once all
    # pods report ready.
    while True:
        logger.info("Current pod status:")
        subprocess.check_call(pod_status_command)
        if kubeflow_pipelines_ready(kubernetes_context=kubernetes_context):
            break

        logger.info("One or more pods not ready yet, waiting for 30 seconds...")
        time.sleep(30)

    logger.info("Finished Kubeflow Pipelines setup.")
k3d_cluster_exists(cluster_name)

Checks whether there exists a K3D cluster with the given name.

Source code in zenml/integrations/kubeflow/orchestrators/local_deployment_utils.py
def k3d_cluster_exists(cluster_name: str) -> bool:
    """Report whether ``k3d cluster list`` knows a cluster named ``cluster_name``.

    Args:
        cluster_name: Name of the cluster to look for.

    Returns:
        True if an entry of the JSON cluster listing has a matching ``name``.
    """
    raw_listing = subprocess.check_output(
        ["k3d", "cluster", "list", "--output", "json"]
    )
    return any(
        entry["name"] == cluster_name for entry in json.loads(raw_listing)
    )
kubeflow_pipelines_ready(kubernetes_context)

Returns whether all Kubeflow Pipelines pods are ready.

Parameters:

Name Type Description Default
kubernetes_context str

The kubernetes context in which the pods should be checked.

required
Source code in zenml/integrations/kubeflow/orchestrators/local_deployment_utils.py
def kubeflow_pipelines_ready(kubernetes_context: str) -> bool:
    """Returns whether all Kubeflow Pipelines pods are ready.

    Args:
        kubernetes_context: The kubernetes context in which the pods
            should be checked.

    Returns:
        True if ``kubectl wait --for condition=ready`` with a zero timeout
        succeeds for all pods in the ``kubeflow`` namespace, False if it
        exits with a non-zero status.
    """
    readiness_check = [
        "kubectl",
        "--context",
        kubernetes_context,
        "--namespace",
        "kubeflow",
        "wait",
        "--for",
        "condition=ready",
        "--timeout=0s",
        "pods",
        "--all",
    ]
    try:
        # The zero timeout makes this a single, non-blocking check; output is
        # discarded because only the exit status matters.
        subprocess.check_call(
            readiness_check,
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
        )
    except subprocess.CalledProcessError:
        return False
    return True
start_kfp_ui_daemon(pid_file_path, port)

Starts a daemon process that forwards ports so the Kubeflow Pipelines UI is accessible in the browser.

Parameters:

Name Type Description Default
pid_file_path str

Path where the file with the daemon's process ID should be written.

required
port int

Port on which the UI should be accessible.

required
Source code in zenml/integrations/kubeflow/orchestrators/local_deployment_utils.py
def start_kfp_ui_daemon(pid_file_path: str, port: int) -> None:
    """Start a daemon that port-forwards the Kubeflow Pipelines UI.

    The daemon runs ``kubectl port-forward`` from ``svc/ml-pipeline-ui``
    (port 80) to ``port`` on the local machine. On Windows no daemon is
    started; a warning with the manual command is logged instead.

    Args:
        pid_file_path: Path where the file with the daemon's process ID should
            be written.
        port: Port on which the UI should be accessible.
    """
    command = [
        "kubectl",
        "--namespace",
        "kubeflow",
        "port-forward",
        "svc/ml-pipeline-ui",
        f"{port}:80",
    ]

    if sys.platform == "win32":
        logger.warning(
            f"Daemon functionality not supported on Windows. "
            f"In order to access the Kubeflow Pipelines UI, please run "
            f"'{' '.join(command)}' in a separate command line shell."
        )
        return

    def _daemon_function() -> None:
        """Port-forwards the Kubeflow Pipelines UI pod."""
        subprocess.check_call(command)

    # Imported lazily: only needed on platforms where daemons are supported.
    from zenml.utils import daemon

    daemon.run_as_daemon(_daemon_function, pid_file=pid_file_path)
    logger.info("Started Kubeflow Pipelines UI daemon.")
write_local_registry_yaml(yaml_path, registry_name, registry_uri)

Writes a K3D registry config file.

Parameters:

Name Type Description Default
yaml_path str

Path where the config file should be written to.

required
registry_name str

Name of the registry.

required
registry_uri str

URI of the registry.

required
Source code in zenml/integrations/kubeflow/orchestrators/local_deployment_utils.py
def write_local_registry_yaml(
    yaml_path: str, registry_name: str, registry_uri: str
) -> None:
    """Write a K3D registry config that mirrors ``registry_uri``.

    The resulting YAML maps ``registry_uri`` under ``mirrors`` to a single
    plain-HTTP endpoint, ``http://<registry_name>``.

    Args:
        yaml_path: Path where the config file should be written to.
        registry_name: Name of the registry.
        registry_uri: URI of the registry.
    """
    mirror_endpoint = f"http://{registry_name}"
    mirror_settings = {"endpoint": [mirror_endpoint]}
    registry_config = {"mirrors": {registry_uri: mirror_settings}}
    yaml_utils.write_yaml(yaml_path, registry_config)

plotly special

PlotlyIntegration (Integration)

Definition of Plotly integration for ZenML.

Source code in zenml/integrations/plotly/__init__.py
class PlotlyIntegration(Integration):
    """Definition of Plotly integration for ZenML.

    Registered under the ``PLOTLY`` name; requires ``plotly>=5.4.0``.
    This class does not override ``activate``.
    """

    NAME = PLOTLY
    REQUIREMENTS = ["plotly>=5.4.0"]

visualizers special

pipeline_lineage_visualizer
PipelineLineageVisualizer (BasePipelineVisualizer)

Visualize the lineage of runs in a pipeline using plotly.

Source code in zenml/integrations/plotly/visualizers/pipeline_lineage_visualizer.py
class PipelineLineageVisualizer(BasePipelineVisualizer):
    """Visualize the lineage of runs in a pipeline using plotly."""

    # Concrete implementation: intentionally NOT decorated with
    # @abstractmethod, which would keep the method abstract and (under ABC
    # rules) prevent this visualizer from being instantiated.
    def visualize(
        self, object: PipelineView, *args: Any, **kwargs: Any
    ) -> Figure:
        """Creates a pipeline lineage diagram using plotly.

        Each run becomes one row of a parallel-categories plot. The first
        dimension is the run name; every further dimension is a step name
        whose category is that step's ID (as a string) in the run.

        Args:
            object: The pipeline whose runs should be visualized.
            *args: Unused.
            **kwargs: Unused.

        Returns:
            The plotly figure, which is also displayed via ``fig.show()``.
        """
        logger.warning(
            "This integration is not completed yet. Results might be unexpected."
        )

        category_dict = {}
        dimensions = ["run"]
        for run in object.runs:
            category_dict[run.name] = {"run": run.name}
            for step in run.steps:
                category_dict[run.name].update(
                    {
                        step.name: str(step.id),
                    }
                )
                # Keep each step name once, in first-seen order.
                if step.name not in dimensions:
                    dimensions.append(f"{step.name}")

        # One row per run, one column per dimension.
        category_df = pd.DataFrame.from_dict(category_dict, orient="index")

        category_df = category_df.reset_index()

        # ``labels`` in plotly express is a column->display-label dict; it is
        # left at its default instead of being passed a bare string.
        fig = px.parallel_categories(
            category_df,
            dimensions,
            color=None,
        )

        fig.show()
        return fig
visualize(self, object, *args, **kwargs)

Creates a pipeline lineage diagram using plotly.

Source code in zenml/integrations/plotly/visualizers/pipeline_lineage_visualizer.py
# Concrete implementation: intentionally NOT decorated with @abstractmethod,
# which would keep the method abstract and (under ABC rules) prevent the
# visualizer from being instantiated.
def visualize(
    self, object: PipelineView, *args: Any, **kwargs: Any
) -> Figure:
    """Creates a pipeline lineage diagram using plotly.

    Each run becomes one row of a parallel-categories plot. The first
    dimension is the run name; every further dimension is a step name whose
    category is that step's ID (as a string) in the run.

    Args:
        object: The pipeline whose runs should be visualized.
        *args: Unused.
        **kwargs: Unused.

    Returns:
        The plotly figure, which is also displayed via ``fig.show()``.
    """
    logger.warning(
        "This integration is not completed yet. Results might be unexpected."
    )

    category_dict = {}
    dimensions = ["run"]
    for run in object.runs:
        category_dict[run.name] = {"run": run.name}
        for step in run.steps:
            category_dict[run.name].update(
                {
                    step.name: str(step.id),
                }
            )
            # Keep each step name once, in first-seen order.
            if step.name not in dimensions:
                dimensions.append(f"{step.name}")

    # One row per run, one column per dimension.
    category_df = pd.DataFrame.from_dict(category_dict, orient="index")

    category_df = category_df.reset_index()

    # ``labels`` in plotly express is a column->display-label dict; it is
    # left at its default instead of being passed a bare string.
    fig = px.parallel_categories(
        category_df,
        dimensions,
        color=None,
    )

    fig.show()
    return fig

pytorch special

PytorchIntegration (Integration)

Definition of PyTorch integration for ZenML.

Source code in zenml/integrations/pytorch/__init__.py
class PytorchIntegration(Integration):
    """Definition of PyTorch integration for ZenML."""

    NAME = PYTORCH
    # No version pin on torch.
    REQUIREMENTS = ["torch"]

    @classmethod
    def activate(cls) -> None:
        """Activates the integration.

        Activation consists of importing the ``materializers`` sub-package.
        The import is done at call time, so that torch-based code is only
        loaded once the integration is activated.
        """
        from zenml.integrations.pytorch import materializers  # noqa
activate() classmethod

Activates the integration.

Source code in zenml/integrations/pytorch/__init__.py
@classmethod
def activate(cls) -> None:
    """Activates the integration.

    Activation consists of importing the ``materializers`` sub-package. The
    import is done at call time, so that torch-based code is only loaded once
    the integration is activated.
    """
    from zenml.integrations.pytorch import materializers  # noqa

materializers special

pytorch_materializer
PyTorchMaterializer (BaseMaterializer)

Materializer to read/write PyTorch models.

Source code in zenml/integrations/pytorch/materializers/pytorch_materializer.py
class PyTorchMaterializer(BaseMaterializer):
    """Materializer to read/write PyTorch models.

    Both directions use a single file named ``DEFAULT_FILENAME`` inside the
    artifact URI: it is written with ``torch.save`` and read with
    ``torch.load``.
    """

    ASSOCIATED_TYPES = [Module, TorchDict]
    ASSOCIATED_ARTIFACT_TYPES = [ModelArtifact]

    def handle_input(self, data_type: Type[Any]) -> Union[Module, TorchDict]:
        """Reads and returns a PyTorch model.

        Args:
            data_type: The requested data type (forwarded to the base class).

        Returns:
            A loaded pytorch model.
        """
        super().handle_input(data_type)
        # NOTE(review): torch.load unpickles the file; only use this on
        # artifacts from a trusted artifact store.
        return torch.load(os.path.join(self.artifact.uri, DEFAULT_FILENAME))  # type: ignore[no-untyped-call] # noqa

    def handle_return(self, model: Union[Module, TorchDict]) -> None:
        """Writes a PyTorch model.

        Args:
            model: A torch.nn.Module or a dict to pass into torch.save
        """
        super().handle_return(model)
        torch.save(model, os.path.join(self.artifact.uri, DEFAULT_FILENAME))
handle_input(self, data_type)

Reads and returns a PyTorch model.

Returns:

Type Description
Union[torch.nn.modules.module.Module, zenml.integrations.pytorch.materializers.pytorch_types.TorchDict]

A loaded pytorch model.

Source code in zenml/integrations/pytorch/materializers/pytorch_materializer.py
def handle_input(self, data_type: Type[Any]) -> Union[Module, TorchDict]:
    """Reads and returns a PyTorch model.

    Args:
        data_type: The requested data type (forwarded to the base class).

    Returns:
        A loaded pytorch model.
    """
    super().handle_input(data_type)
    # NOTE(review): torch.load unpickles the file; only use this on artifacts
    # from a trusted artifact store.
    return torch.load(os.path.join(self.artifact.uri, DEFAULT_FILENAME))  # type: ignore[no-untyped-call] # noqa
handle_return(self, model)

Writes a PyTorch model.

Parameters:

Name Type Description Default
model Union[torch.nn.modules.module.Module, zenml.integrations.pytorch.materializers.pytorch_types.TorchDict]

A torch.nn.Module or a dict to pass into torch.save.

required
Source code in zenml/integrations/pytorch/materializers/pytorch_materializer.py
def handle_return(self, model: Union[Module, TorchDict]) -> None:
    """Writes a PyTorch model.

    The object is serialized with ``torch.save`` to the file
    ``DEFAULT_FILENAME`` inside the artifact URI.

    Args:
        model: A torch.nn.Module or a dict to pass into torch.save
    """
    super().handle_return(model)
    torch.save(model, os.path.join(self.artifact.uri, DEFAULT_FILENAME))
pytorch_types
TorchDict (dict, Generic)

A type of dict that represents saving a model.

Source code in zenml/integrations/pytorch/materializers/pytorch_types.py
class TorchDict(Dict[str, Any]):
    """A type of dict that represents saving a model.

    A ``Dict[str, Any]`` subclass that adds no behavior; it serves only as a
    distinct type. NOTE(review): presumably meant for state-dict style
    payloads; confirm with callers.
    """

pytorch_lightning special

PytorchLightningIntegration (Integration)

Definition of PyTorch Lightning integration for ZenML.

Source code in zenml/integrations/pytorch_lightning/__init__.py
class PytorchLightningIntegration(Integration):
    """Definition of PyTorch Lightning integration for ZenML."""

    NAME = PYTORCH_L
    # No version pin on pytorch_lightning.
    REQUIREMENTS = ["pytorch_lightning"]

    @classmethod
    def activate(cls) -> None:
        """Activates the integration.

        Activation consists of importing the ``materializers`` sub-package;
        the import is done at call time rather than at module import.
        """
        from zenml.integrations.pytorch_lightning import materializers  # noqa
activate() classmethod

Activates the integration.

Source code in zenml/integrations/pytorch_lightning/__init__.py
@classmethod
def activate(cls) -> None:
    """Activates the integration.

    Activation consists of importing the ``materializers`` sub-package; the
    import is done at call time rather than at module import.
    """
    from zenml.integrations.pytorch_lightning import materializers  # noqa

materializers special

pytorch_lightning_materializer
PyTorchLightningMaterializer (BaseMaterializer)

Materializer to read/write PyTorch Lightning trainers (stored as checkpoints).

Source code in zenml/integrations/pytorch_lightning/materializers/pytorch_lightning_materializer.py
class PyTorchLightningMaterializer(BaseMaterializer):
    """Materializer to read/write PyTorch Lightning trainers.

    A trainer is stored as a checkpoint file named ``CHECKPOINT_NAME`` inside
    the artifact URI.
    """

    ASSOCIATED_TYPES = [Trainer]
    ASSOCIATED_ARTIFACT_TYPES = [ModelArtifact]

    def handle_input(self, data_type: Type[Any]) -> Trainer:
        """Reads and returns a PyTorch Lightning trainer.

        Args:
            data_type: The requested data type (forwarded to the base class).

        Returns:
            A PyTorch Lightning trainer object.
        """
        super().handle_input(data_type)
        # A new Trainer is built that resumes from the stored checkpoint;
        # no other Trainer constructor arguments are restored here.
        # NOTE(review): `resume_from_checkpoint` may be deprecated in newer
        # PyTorch Lightning releases; confirm against the installed version.
        return Trainer(
            resume_from_checkpoint=os.path.join(
                self.artifact.uri, CHECKPOINT_NAME
            )
        )

    def handle_return(self, trainer: Trainer) -> None:
        """Writes a PyTorch Lightning trainer.

        Args:
            trainer: A PyTorch Lightning trainer object.
        """
        super().handle_return(trainer)
        trainer.save_checkpoint(
            os.path.join(self.artifact.uri, CHECKPOINT_NAME)
        )
handle_input(self, data_type)

Reads and returns a PyTorch Lightning trainer.

Returns:

Type Description
Trainer

A PyTorch Lightning trainer object.

Source code in zenml/integrations/pytorch_lightning/materializers/pytorch_lightning_materializer.py
def handle_input(self, data_type: Type[Any]) -> Trainer:
    """Reads and returns a PyTorch Lightning trainer.

    Args:
        data_type: The requested data type (forwarded to the base class).

    Returns:
        A PyTorch Lightning trainer object.
    """
    super().handle_input(data_type)
    # A new Trainer is built that resumes from the stored checkpoint; no
    # other Trainer constructor arguments are restored here.
    return Trainer(
        resume_from_checkpoint=os.path.join(
            self.artifact.uri, CHECKPOINT_NAME
        )
    )
handle_return(self, trainer)

Writes a PyTorch Lightning trainer.

Parameters:

Name Type Description Default
trainer Trainer

A PyTorch Lightning trainer object.

required
Source code in zenml/integrations/pytorch_lightning/materializers/pytorch_lightning_materializer.py
def handle_return(self, trainer: Trainer) -> None:
    """Writes a PyTorch Lightning trainer.

    The trainer state is saved as a checkpoint file named ``CHECKPOINT_NAME``
    inside the artifact URI.

    Args:
        trainer: A PyTorch Lightning trainer object.
    """
    super().handle_return(trainer)
    trainer.save_checkpoint(
        os.path.join(self.artifact.uri, CHECKPOINT_NAME)
    )

registry

IntegrationRegistry

Registry to keep track of ZenML Integrations

Source code in zenml/integrations/registry.py
class IntegrationRegistry(object):
    """Registry to keep track of ZenML integrations."""

    def __init__(self) -> None:
        """Initializes an empty integration registry."""
        self._integrations: Dict[str, Type["Integration"]] = {}

    @property
    def integrations(self) -> Dict[str, Type["Integration"]]:
        """Method to get integrations dictionary.

        Returns:
            A dict of integration key to type of `Integration`.
        """
        return self._integrations

    @integrations.setter
    def integrations(self, i: Any) -> None:
        """Rejects direct assignment of the integrations dictionary.

        Raises:
            IntegrationError: Always; use `register_integration()` instead.
        """
        raise IntegrationError(
            "Please do not manually change the integrations within the "
            "registry. If you would like to register a new integration "
            "manually, please use "
            "`integration_registry.register_integration()`."
        )

    def register_integration(
        self, key: str, type_: Type["Integration"]
    ) -> None:
        """Registers an integration under the given key.

        Registering an already-used key replaces the previous entry.
        """
        self._integrations[key] = type_

    def activate_integrations(self) -> None:
        """Activates every registered integration that is installed.

        Integrations whose `check_installation()` returns False are skipped
        and logged at debug level.
        """
        for name, integration in self._integrations.items():
            if integration.check_installation():
                integration.activate()
                logger.debug(f"Integration `{name}` is activated.")
            else:
                logger.debug(f"Integration `{name}` could not be activated.")

    @property
    def list_integration_names(self) -> List[str]:
        """Get a list of all registered integration names, in registration order."""
        return list(self._integrations)

    def select_integration_requirements(
        self, integration_name: Optional[str] = None
    ) -> List[str]:
        """Select the requirements for a given integration or all integrations.

        Args:
            integration_name: Name of a registered integration. If falsy, the
                requirements of all integrations are concatenated.

        Returns:
            The requirement strings.

        Raises:
            KeyError: If `integration_name` is not registered.
        """
        if integration_name:
            if integration_name in self.list_integration_names:
                return self._integrations[integration_name].REQUIREMENTS
            else:
                raise KeyError(
                    f"Integration '{integration_name}' does not exist. "
                    f"Currently the following integrations are implemented. "
                    f"{self.list_integration_names}"
                )
        else:
            return [
                requirement
                for name in self.list_integration_names
                for requirement in self._integrations[name].REQUIREMENTS
            ]

    def is_installed(self, integration_name: Optional[str] = None) -> bool:
        """Checks if all requirements for an integration are installed.

        Args:
            integration_name: Name of a registered integration. If falsy,
                every registered integration is checked.

        Returns:
            Whether the named integration is installed or, without a name,
            whether *all* registered integrations are installed.

        Raises:
            KeyError: If `integration_name` is given but not registered.
        """
        if integration_name in self.list_integration_names:
            return self._integrations[integration_name].check_installation()
        elif not integration_name:
            # Every integration's check is run (the list is built eagerly),
            # and only if all of them pass is the result True.
            all_installed = [
                self._integrations[item].check_installation()
                for item in self.list_integration_names
            ]
            return all(all_installed)
        else:
            raise KeyError(
                f"Integration '{integration_name}' not found. "
                f"Currently the following integrations are available: "
                f"{self.list_integration_names}"
            )
integrations: Dict[str, Type[Integration]] property writable

Method to get integrations dictionary.

Returns:

Type Description
Dict[str, Type[Integration]]

A dict of integration key to type of Integration.

list_integration_names: List[str] property readonly

Get a list of all possible integrations

__init__(self) special

Initializing the integration registry

Source code in zenml/integrations/registry.py
def __init__(self) -> None:
    """Initializing the integration registry.

    Starts with an empty mapping from integration key to integration class.
    """
    self._integrations: Dict[str, Type["Integration"]] = {}
activate_integrations(self)

Method to activate the integrations which are registered in the registry.

Source code in zenml/integrations/registry.py
def activate_integrations(self) -> None:
    """Activate each registered integration whose installation check passes.

    Integrations for which `check_installation()` returns False are skipped;
    both outcomes are logged at debug level.
    """
    for name, integration in self._integrations.items():
        if not integration.check_installation():
            logger.debug(f"Integration `{name}` could not be activated.")
            continue
        integration.activate()
        logger.debug(f"Integration `{name}` is activated.")
is_installed(self, integration_name=None)

Checks if all requirements for an integration are installed

Source code in zenml/integrations/registry.py
def is_installed(self, integration_name: Optional[str] = None) -> bool:
    """Checks if all requirements for an integration are installed.

    Args:
        integration_name: Name of a registered integration. If falsy, every
            registered integration is checked.

    Returns:
        Whether the named integration is installed or, without a name,
        whether *all* registered integrations are installed.

    Raises:
        KeyError: If `integration_name` is given but not registered.
    """
    if integration_name in self.list_integration_names:
        return self._integrations[integration_name].check_installation()
    elif not integration_name:
        # Every integration's check is run (the list is built eagerly), and
        # only if all of them pass is the result True.
        all_installed = [
            self._integrations[item].check_installation()
            for item in self.list_integration_names
        ]
        return all(all_installed)
    else:
        raise KeyError(
            f"Integration '{integration_name}' not found. "
            f"Currently the following integrations are available: "
            f"{self.list_integration_names}"
        )
register_integration(self, key, type_)

Method to register an integration with a given name

Source code in zenml/integrations/registry.py
def register_integration(
    self, key: str, type_: Type["Integration"]
) -> None:
    """Method to register an integration with a given name.

    Registering an already-used key replaces the previous entry.

    Args:
        key: Name under which the integration is stored.
        type_: The integration class to register.
    """
    self._integrations[key] = type_
select_integration_requirements(self, integration_name=None)

Select the requirements for a given integration or all integrations

Source code in zenml/integrations/registry.py
def select_integration_requirements(
    self, integration_name: Optional[str] = None
) -> List[str]:
    """Select the requirements for a given integration or all integrations.

    Args:
        integration_name: Name of a registered integration. If falsy, the
            requirements of all integrations are concatenated in registration
            order.

    Returns:
        The requirement strings.

    Raises:
        KeyError: If `integration_name` is not registered.
    """
    if integration_name:
        if integration_name in self.list_integration_names:
            return self._integrations[integration_name].REQUIREMENTS
        else:
            # The lookup is by integration name, not by version.
            raise KeyError(
                f"Integration '{integration_name}' does not exist. "
                f"Currently the following integrations are implemented. "
                f"{self.list_integration_names}"
            )
    else:
        return [
            requirement
            for name in self.list_integration_names
            for requirement in self._integrations[name].REQUIREMENTS
        ]

sklearn special

SklearnIntegration (Integration)

Definition of sklearn integration for ZenML.

Source code in zenml/integrations/sklearn/__init__.py
class SklearnIntegration(Integration):
    """Definition of sklearn integration for ZenML."""

    NAME = SKLEARN
    # No version pin on scikit-learn.
    REQUIREMENTS = ["scikit-learn"]

    @classmethod
    def activate(cls) -> None:
        """Activates the integration.

        Activation consists of importing the ``materializers`` sub-package;
        the import is done at call time rather than at module import.
        """
        from zenml.integrations.sklearn import materializers  # noqa
activate() classmethod

Activates the integration.

Source code in zenml/integrations/sklearn/__init__.py
@classmethod
def activate(cls) -> None:
    """Activates the integration.

    Activation consists of importing the ``materializers`` sub-package; the
    import is done at call time rather than at module import.
    """
    from zenml.integrations.sklearn import materializers  # noqa

helpers special

digits
get_digits()

Returns the digits dataset in the form of a tuple of numpy arrays.

Source code in zenml/integrations/sklearn/helpers/digits.py
def get_digits() -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Load sklearn's digits dataset and split it 50/50 without shuffling.

    Returns:
        ``(X_train, X_test, y_train, y_test)``, where every image has been
        flattened into a single feature row.
    """
    digits = load_digits()
    images = digits.images
    # Flatten each image into a single row of pixel features.
    features = images.reshape((len(images), -1))

    # Half train, half test; shuffle=False keeps the original sample order.
    X_train, X_test, y_train, y_test = train_test_split(
        features, digits.target, test_size=0.5, shuffle=False
    )
    return X_train, X_test, y_train, y_test
get_digits_model()

Creates a support vector classifier for digits dataset.

Source code in zenml/integrations/sklearn/helpers/digits.py
def get_digits_model() -> ClassifierMixin:
    """Creates a support vector classifier for digits dataset.

    Returns:
        A newly constructed (unfitted) ``SVC`` with ``gamma=0.001``; all other
        hyperparameters are left at their defaults.
    """
    return SVC(gamma=0.001)

materializers special

sklearn_materializer
SklearnMaterializer (BaseMaterializer)

Materializer to read data to and from sklearn.

Source code in zenml/integrations/sklearn/materializers/sklearn_materializer.py
class SklearnMaterializer(BaseMaterializer):
    """Materializer to read data to and from sklearn.

    Models are pickled to a single file named ``DEFAULT_FILENAME`` inside the
    artifact URI.
    """

    ASSOCIATED_TYPES = [
        BaseEstimator,
        ClassifierMixin,
        ClusterMixin,
        BiclusterMixin,
        OutlierMixin,
        RegressorMixin,
        MetaEstimatorMixin,
        MultiOutputMixin,
        DensityMixin,
        TransformerMixin,
    ]
    ASSOCIATED_ARTIFACT_TYPES = [ModelArtifact]

    def handle_input(
        self, data_type: Type[Any]
    ) -> Union[
        BaseEstimator,
        ClassifierMixin,
        ClusterMixin,
        BiclusterMixin,
        OutlierMixin,
        RegressorMixin,
        MetaEstimatorMixin,
        MultiOutputMixin,
        DensityMixin,
        TransformerMixin,
    ]:
        """Reads a base sklearn model from a pickle file.

        Args:
            data_type: The requested data type (forwarded to the base class).

        Returns:
            The unpickled sklearn model.
        """
        super().handle_input(data_type)
        filepath = os.path.join(self.artifact.uri, DEFAULT_FILENAME)
        # NOTE(review): pickle.load can execute arbitrary code; only load
        # artifacts from a trusted artifact store.
        with fileio.open(filepath, "rb") as fid:
            clf = pickle.load(fid)
        return clf

    def handle_return(
        self,
        clf: Union[
            BaseEstimator,
            ClassifierMixin,
            ClusterMixin,
            BiclusterMixin,
            OutlierMixin,
            RegressorMixin,
            MetaEstimatorMixin,
            MultiOutputMixin,
            DensityMixin,
            TransformerMixin,
        ],
    ) -> None:
        """Creates a pickle for a sklearn model.

        Args:
            clf: A sklearn model.
        """
        super().handle_return(clf)
        filepath = os.path.join(self.artifact.uri, DEFAULT_FILENAME)
        with fileio.open(filepath, "wb") as fid:
            pickle.dump(clf, fid)
handle_input(self, data_type)

Reads a base sklearn model from a pickle file.

Source code in zenml/integrations/sklearn/materializers/sklearn_materializer.py
def handle_input(
    self, data_type: Type[Any]
) -> Union[
    BaseEstimator,
    ClassifierMixin,
    ClusterMixin,
    BiclusterMixin,
    OutlierMixin,
    RegressorMixin,
    MetaEstimatorMixin,
    MultiOutputMixin,
    DensityMixin,
    TransformerMixin,
]:
    """Reads a base sklearn model from a pickle file.

    Args:
        data_type: The requested data type (forwarded to the base class).

    Returns:
        The unpickled sklearn model.
    """
    super().handle_input(data_type)
    filepath = os.path.join(self.artifact.uri, DEFAULT_FILENAME)
    # NOTE(review): pickle.load can execute arbitrary code; only load
    # artifacts from a trusted artifact store.
    with fileio.open(filepath, "rb") as fid:
        clf = pickle.load(fid)
    return clf
handle_return(self, clf)

Creates a pickle for a sklearn model.

Parameters:

Name Type Description Default
clf Union[sklearn.base.BaseEstimator, sklearn.base.ClassifierMixin, sklearn.base.ClusterMixin, sklearn.base.BiclusterMixin, sklearn.base.OutlierMixin, sklearn.base.RegressorMixin, sklearn.base.MetaEstimatorMixin, sklearn.base.MultiOutputMixin, sklearn.base.DensityMixin, sklearn.base.TransformerMixin]

A sklearn model.

required
Source code in zenml/integrations/sklearn/materializers/sklearn_materializer.py
def handle_return(
    self,
    clf: Union[
        BaseEstimator,
        ClassifierMixin,
        ClusterMixin,
        BiclusterMixin,
        OutlierMixin,
        RegressorMixin,
        MetaEstimatorMixin,
        MultiOutputMixin,
        DensityMixin,
        TransformerMixin,
    ],
) -> None:
    """Creates a pickle for a sklearn model.

    The model is pickled to ``DEFAULT_FILENAME`` inside the artifact URI.

    Args:
        clf: A sklearn model.
    """
    super().handle_return(clf)
    filepath = os.path.join(self.artifact.uri, DEFAULT_FILENAME)
    with fileio.open(filepath, "wb") as fid:
        pickle.dump(clf, fid)

steps special

sklearn_evaluator
SklearnEvaluator (BaseEvaluatorStep)

A simple step implementation which utilizes sklearn to evaluate the performance of a given model on a given test dataset.

Source code in zenml/integrations/sklearn/steps/sklearn_evaluator.py
class SklearnEvaluator(BaseEvaluatorStep):
    """A simple step implementation which utilizes sklearn to evaluate the
    performance of a given model on a given test dataset."""

    def entrypoint(  # type: ignore[override]
        self,
        dataset: pd.DataFrame,
        model: tf.keras.Model,
        config: SklearnEvaluatorConfig,
    ) -> dict:  # type: ignore[type-arg]
        """Method which is responsible for the computation of the evaluation.

        Args:
            dataset: a pandas DataFrame which represents the test dataset,
                including the label column named by
                ``config.label_class_column``; it is not modified.
            model: a trained tensorflow Keras model
            config: the configuration for the step

        Returns:
            a dictionary which has the evaluation report (the output of
            ``classification_report`` with ``output_dict=True``)

        Raises:
            KeyError: If the label column is missing from ``dataset``.
        """
        label_column = config.label_class_column
        # Separate labels from features without mutating the caller's
        # DataFrame (``DataFrame.pop`` would delete the column in place).
        labels = dataset[label_column]
        features = dataset.drop(columns=[label_column])

        predictions = model.predict(features)
        # Threshold each prediction at 0.5 to obtain binary class labels.
        predicted_classes = [1 if v > 0.5 else 0 for v in predictions]

        report = classification_report(
            labels, predicted_classes, output_dict=True
        )

        return report  # type: ignore[no-any-return]
CONFIG_CLASS (BaseEvaluatorConfig) pydantic-model

Config class for the sklearn evaluator

Source code in zenml/integrations/sklearn/steps/sklearn_evaluator.py
class SklearnEvaluatorConfig(BaseEvaluatorConfig):
    """Config class for the sklearn evaluator.

    Attributes:
        label_class_column: Name of the column in the evaluation DataFrame
            that holds the ground-truth labels.
    """

    label_class_column: str
entrypoint(self, dataset, model, config)

Method which is responsible for the computation of the evaluation

Parameters:

Name Type Description Default
dataset DataFrame

a pandas Dataframe which represents the test dataset

required
model Model

a trained tensorflow Keras model

required
config SklearnEvaluatorConfig

the configuration for the step

required

Returns:

Type Description
dict

a dictionary which has the evaluation report

Source code in zenml/integrations/sklearn/steps/sklearn_evaluator.py
def entrypoint(  # type: ignore[override]
    self,
    dataset: pd.DataFrame,
    model: tf.keras.Model,
    config: SklearnEvaluatorConfig,
) -> dict:  # type: ignore[type-arg]
    """Method which is responsible for the computation of the evaluation.

    Args:
        dataset: a pandas DataFrame which represents the test dataset,
            including the label column named by ``config.label_class_column``;
            it is not modified.
        model: a trained tensorflow Keras model
        config: the configuration for the step

    Returns:
        a dictionary which has the evaluation report (the output of
        ``classification_report`` with ``output_dict=True``)

    Raises:
        KeyError: If the label column is missing from ``dataset``.
    """
    label_column = config.label_class_column
    # Separate labels from features without mutating the caller's DataFrame
    # (``DataFrame.pop`` would delete the column in place).
    labels = dataset[label_column]
    features = dataset.drop(columns=[label_column])

    predictions = model.predict(features)
    # Threshold each prediction at 0.5 to obtain binary class labels.
    predicted_classes = [1 if v > 0.5 else 0 for v in predictions]

    report = classification_report(
        labels, predicted_classes, output_dict=True
    )

    return report  # type: ignore[no-any-return]
SklearnEvaluatorConfig (BaseEvaluatorConfig) pydantic-model

Config class for the sklearn evaluator

Source code in zenml/integrations/sklearn/steps/sklearn_evaluator.py
class SklearnEvaluatorConfig(BaseEvaluatorConfig):
    """Config class for the sklearn evaluator.

    Attributes:
        label_class_column: Name of the column in the evaluation DataFrame
            that holds the ground-truth labels.
    """

    label_class_column: str
sklearn_splitter
SklearnSplitter (BaseSplitStep)

A simple step implementation which utilizes sklearn to split a given dataset into train, test and validation splits.

Source code in zenml/integrations/sklearn/steps/sklearn_splitter.py
class SklearnSplitter(BaseSplitStep):
    """A simple step implementation which utilizes sklearn to split a given
    dataset into train, test and validation splits."""

    def entrypoint(  # type: ignore[override]
        self,
        dataset: pd.DataFrame,
        config: SklearnSplitterConfig,
    ) -> Output(  # type:ignore[valid-type]
        train=pd.DataFrame, test=pd.DataFrame, validation=pd.DataFrame
    ):
        """Method which is responsible for the splitting logic.

        The test split is taken first; the remaining rows are then divided
        so that the validation split is ``ratios["validation"]`` of the
        original dataset.

        Args:
            dataset: a pandas DataFrame containing the entire dataset
            config: the configuration for the step

        Returns:
            three dataframes representing the train, test and validation
            splits (in that order)

        Raises:
            KeyError: If ``config.ratios`` does not have exactly the keys
                'train', 'test' and 'validation'.
            ValueError: If the ratios do not sum up to 1.
        """
        import math

        if set(config.ratios) != {"train", "test", "validation"}:
            raise KeyError(
                f"Make sure that you only use 'train', 'test' and "
                f"'validation' as keys in the ratios dict. Current keys: "
                f"{config.ratios.keys()}"
            )

        # Compare with a tolerance: e.g. 0.7 + 0.2 + 0.1 evaluates to
        # 0.9999999999999999 in floating point.
        if not math.isclose(sum(config.ratios.values()), 1.0):
            raise ValueError(
                f"Make sure that the ratios sum up to 1. Current "
                f"ratios: {config.ratios}"
            )

        train_dataset, test_dataset = train_test_split(
            dataset, test_size=config.ratios["test"]
        )

        # Rescale: the second split is relative to the rows left after the
        # test split, i.e. a (train + validation) share of the full dataset.
        train_dataset, val_dataset = train_test_split(
            train_dataset,
            test_size=(
                config.ratios["validation"]
                / (config.ratios["validation"] + config.ratios["train"])
            ),
        )

        return train_dataset, test_dataset, val_dataset
CONFIG_CLASS (BaseSplitStepConfig) pydantic-model

Config class for the sklearn splitter

Source code in zenml/integrations/sklearn/steps/sklearn_splitter.py
class SklearnSplitterConfig(BaseSplitStepConfig):
    """Config class for the sklearn splitter.

    Attributes:
        ratios: fraction of the dataset assigned to each split, keyed by
            split name. The sklearn splitter step expects exactly the keys
            'train', 'test' and 'validation', with values summing to 1.
    """

    # e.g. {"train": 0.7, "test": 0.2, "validation": 0.1}
    ratios: Dict[str, float]
entrypoint(self, dataset, config)

Method which is responsible for the splitting logic

Parameters:

Name Type Description Default
dataset DataFrame

a pandas DataFrame containing the entire dataset

required
config SklearnSplitterConfig

the configuration for the step

required

Returns:

Type Description
Output(train=DataFrame, test=DataFrame, validation=DataFrame)

three dataframes representing the splits

Source code in zenml/integrations/sklearn/steps/sklearn_splitter.py
def entrypoint(  # type: ignore[override]
    self,
    dataset: pd.DataFrame,
    config: SklearnSplitterConfig,
) -> Output(  # type:ignore[valid-type]
    train=pd.DataFrame, test=pd.DataFrame, validation=pd.DataFrame
):
    """Method which is responsible for the splitting logic.

    The dataset is split in two passes: first the test split is carved
    off, then the remainder is split into train and validation so that
    the validation split is (approximately) ``config.ratios["validation"]``
    of the whole dataset.

    Args:
        dataset: a pandas DataFrame containing the entire dataset
        config: the configuration for the step; ``config.ratios`` must
            have exactly the keys 'train', 'test' and 'validation'
    Returns:
        three dataframes representing the train, test and validation
        splits (in that order)
    Raises:
        KeyError: if ``config.ratios`` does not have exactly the keys
            'train', 'test' and 'validation'
        ValueError: if the ratios do not sum up to 1 (within floating
            point tolerance)
    """
    import math

    # Exactly these three keys, no more and no fewer.
    if set(config.ratios) != {"train", "test", "validation"}:
        raise KeyError(
            f"Make sure that you only use 'train', 'test' and "
            f"'validation' as keys in the ratios dict. Current keys: "
            f"{config.ratios.keys()}"
        )

    # Compare with a tolerance: e.g. 0.7 + 0.2 + 0.1 evaluates to
    # 0.9999999999999999 in floating point, which an exact `!= 1` check
    # would wrongly reject.
    if not math.isclose(sum(config.ratios.values()), 1.0):
        raise ValueError(
            f"Make sure that the ratios sum up to 1. Current "
            f"ratios: {config.ratios}"
        )

    train_dataset, test_dataset = train_test_split(
        dataset, test_size=config.ratios["test"]
    )

    # The remaining data is (train + validation), so rescale the
    # validation ratio relative to that remainder.
    train_dataset, val_dataset = train_test_split(
        train_dataset,
        test_size=(
            config.ratios["validation"]
            / (config.ratios["validation"] + config.ratios["train"])
        ),
    )

    return train_dataset, test_dataset, val_dataset
SklearnSplitterConfig (BaseSplitStepConfig) pydantic-model

Config class for the sklearn splitter

Source code in zenml/integrations/sklearn/steps/sklearn_splitter.py
class SklearnSplitterConfig(BaseSplitStepConfig):
    """Config class for the sklearn splitter.

    Attributes:
        ratios: fraction of the dataset assigned to each split, keyed by
            split name. The sklearn splitter step expects exactly the keys
            'train', 'test' and 'validation', with values summing to 1.
    """

    # e.g. {"train": 0.7, "test": 0.2, "validation": 0.1}
    ratios: Dict[str, float]
sklearn_standard_scaler
SklearnStandardScaler (BasePreprocesserStep)

Simple step implementation which utilizes the StandardScaler from sklearn to transform the numeric columns of a pd.DataFrame

Source code in zenml/integrations/sklearn/steps/sklearn_standard_scaler.py
class SklearnStandardScaler(BasePreprocesserStep):
    """Simple step implementation which utilizes the StandardScaler from sklearn
    to transform the numeric columns of a pd.DataFrame"""

    def entrypoint(  # type: ignore[override]
        self,
        train_dataset: pd.DataFrame,
        test_dataset: pd.DataFrame,
        validation_dataset: pd.DataFrame,
        statistics: pd.DataFrame,
        schema: pd.DataFrame,
        config: SklearnStandardScalerConfig,
    ) -> Output(  # type:ignore[valid-type]
        train_transformed=pd.DataFrame,
        test_transformed=pd.DataFrame,
        # NOTE: misspelled output name kept as-is; renaming it may break
        # consumers that refer to this output by name.
        valdation_transformed=pd.DataFrame,
    ):
        """Main entrypoint function for the StandardScaler.

        Scales every numeric (int64/float64) column that is neither in
        ``config.exclude_columns`` nor in ``config.ignore_columns``, using
        the precomputed mean/std from ``statistics`` rather than refitting.
        The input DataFrames are modified in place and returned.

        Args:
            train_dataset: pd.DataFrame, the training dataset
            test_dataset: pd.DataFrame, the test dataset
            validation_dataset: pd.DataFrame, the validation dataset
            statistics: pd.DataFrame, the statistics over the train dataset;
                its "mean" and "std" columns are indexed by feature name
            schema: pd.DataFrame, the detected schema of the dataset; the
                row labelled 0 holds each column's dtype name
            config: the configuration for the step
        Returns:
             the transformed train, test and validation datasets as
             pd.DataFrames
        """
        schema_dict = {k: v[0] for k, v in schema.to_dict().items()}

        # Exclude columns
        feature_set = set(train_dataset.columns) - set(config.exclude_columns)
        for feature, feature_type in schema_dict.items():
            if feature_type in ("int64", "float64"):
                continue
            # The column may already have been excluded or may be missing
            # from the training data; `set.remove` would raise KeyError.
            if feature in feature_set:
                feature_set.remove(feature)
                logger.warning(
                    f"{feature} column is a not numeric, thus it is excluded "
                    f"from the standard scaling."
                )

        ignored = set(config.ignore_columns)
        # Use an ordered list (train column order) instead of a set. Pandas
        # rejects sets as indexers, and one shared order keeps the statistics
        # aligned with the columns being transformed.
        transform_columns = [
            column
            for column in train_dataset.columns
            if column in feature_set and column not in ignored
        ]
        if not transform_columns:
            # Nothing to scale; StandardScaler.transform rejects zero features.
            return train_dataset, test_dataset, validation_dataset

        # Transform the datasets with the precomputed train statistics
        scaler = StandardScaler()
        scaler.mean_ = statistics["mean"][transform_columns].to_numpy(
            dtype=float
        )
        # Constant columns have std 0; like sklearn's own fit, use 1.0 so they
        # are only centred instead of turning into inf/NaN.
        scaler.scale_ = (
            statistics["std"][transform_columns]
            .replace(0.0, 1.0)
            .to_numpy(dtype=float)
        )

        for dataset in (train_dataset, test_dataset, validation_dataset):
            dataset[transform_columns] = scaler.transform(
                dataset[transform_columns]
            )

        return train_dataset, test_dataset, validation_dataset
CONFIG_CLASS (BasePreprocesserConfig) pydantic-model

Config class for the sklearn standard scaler

ignore_columns: a list of column names which should not be scaled. exclude_columns: a list of column names to be excluded from the features considered for scaling (the columns are not dropped from the dataset).

Source code in zenml/integrations/sklearn/steps/sklearn_standard_scaler.py
class SklearnStandardScalerConfig(BasePreprocesserConfig):
    """Config class for the sklearn standard scaler.

    Attributes:
        ignore_columns: a list of column names which should not be scaled
        exclude_columns: a list of column names to be excluded from the
            features considered for scaling. In SklearnStandardScaler these
            columns are only removed from the candidate feature set; they
            are not dropped from the returned DataFrames.
    """

    # Columns kept untouched by the scaler.
    ignore_columns: List[str] = []
    # Columns removed from the candidate feature set before dtype checks.
    exclude_columns: List[str] = []
entrypoint(self, train_dataset, test_dataset, validation_dataset, statistics, schema, config)

Main entrypoint function for the StandardScaler

Parameters:

Name Type Description Default
train_dataset DataFrame

pd.DataFrame, the training dataset

required
test_dataset DataFrame

pd.DataFrame, the test dataset

required
validation_dataset DataFrame

pd.DataFrame, the validation dataset

required
statistics DataFrame

pd.DataFrame, the statistics over the train dataset

required
schema DataFrame

pd.DataFrame, the detected schema of the dataset

required
config SklearnStandardScalerConfig

the configuration for the step

required

Returns:

Type Description
Output(train_transformed=DataFrame, test_transformed=DataFrame, valdation_transformed=DataFrame)

the transformed train, test and validation datasets as pd.DataFrames

Source code in zenml/integrations/sklearn/steps/sklearn_standard_scaler.py
def entrypoint(  # type: ignore[override]
    self,
    train_dataset: pd.DataFrame,
    test_dataset: pd.DataFrame,
    validation_dataset: pd.DataFrame,
    statistics: pd.DataFrame,
    schema: pd.DataFrame,
    config: SklearnStandardScalerConfig,
) -> Output(  # type:ignore[valid-type]
    train_transformed=pd.DataFrame,
    test_transformed=pd.DataFrame,
    # NOTE: misspelled output name kept as-is; renaming it may break
    # consumers that refer to this output by name.
    valdation_transformed=pd.DataFrame,
):
    """Main entrypoint function for the StandardScaler.

    Scales every numeric (int64/float64) column that is neither in
    ``config.exclude_columns`` nor in ``config.ignore_columns``, using
    the precomputed mean/std from ``statistics`` rather than refitting.
    The input DataFrames are modified in place and returned.

    Args:
        train_dataset: pd.DataFrame, the training dataset
        test_dataset: pd.DataFrame, the test dataset
        validation_dataset: pd.DataFrame, the validation dataset
        statistics: pd.DataFrame, the statistics over the train dataset;
            its "mean" and "std" columns are indexed by feature name
        schema: pd.DataFrame, the detected schema of the dataset; the
            row labelled 0 holds each column's dtype name
        config: the configuration for the step
    Returns:
         the transformed train, test and validation datasets as
         pd.DataFrames
    """
    schema_dict = {k: v[0] for k, v in schema.to_dict().items()}

    # Exclude columns
    feature_set = set(train_dataset.columns) - set(config.exclude_columns)
    for feature, feature_type in schema_dict.items():
        if feature_type in ("int64", "float64"):
            continue
        # The column may already have been excluded or may be missing
        # from the training data; `set.remove` would raise KeyError.
        if feature in feature_set:
            feature_set.remove(feature)
            logger.warning(
                f"{feature} column is a not numeric, thus it is excluded "
                f"from the standard scaling."
            )

    ignored = set(config.ignore_columns)
    # Use an ordered list (train column order) instead of a set. Pandas
    # rejects sets as indexers, and one shared order keeps the statistics
    # aligned with the columns being transformed.
    transform_columns = [
        column
        for column in train_dataset.columns
        if column in feature_set and column not in ignored
    ]
    if not transform_columns:
        # Nothing to scale; StandardScaler.transform rejects zero features.
        return train_dataset, test_dataset, validation_dataset

    # Transform the datasets with the precomputed train statistics
    scaler = StandardScaler()
    scaler.mean_ = statistics["mean"][transform_columns].to_numpy(dtype=float)
    # Constant columns have std 0; like sklearn's own fit, use 1.0 so they
    # are only centred instead of turning into inf/NaN.
    scaler.scale_ = (
        statistics["std"][transform_columns]
        .replace(0.0, 1.0)
        .to_numpy(dtype=float)
    )

    for dataset in (train_dataset, test_dataset, validation_dataset):
        dataset[transform_columns] = scaler.transform(
            dataset[transform_columns]
        )

    return train_dataset, test_dataset, validation_dataset
SklearnStandardScalerConfig (BasePreprocesserConfig) pydantic-model

Config class for the sklearn standard scaler

ignore_columns: a list of column names which should not be scaled. exclude_columns: a list of column names to be excluded from the features considered for scaling (the columns are not dropped from the dataset).

Source code in zenml/integrations/sklearn/steps/sklearn_standard_scaler.py
class SklearnStandardScalerConfig(BasePreprocesserConfig):
    """Config class for the sklearn standard scaler.

    Attributes:
        ignore_columns: a list of column names which should not be scaled
        exclude_columns: a list of column names to be excluded from the
            features considered for scaling. In SklearnStandardScaler these
            columns are only removed from the candidate feature set; they
            are not dropped from the returned DataFrames.
    """

    # Columns kept untouched by the scaler.
    ignore_columns: List[str] = []
    # Columns removed from the candidate feature set before dtype checks.
    exclude_columns: List[str] = []

tensorflow special

TensorflowIntegration (Integration)

Definition of Tensorflow integration for ZenML.

Source code in zenml/integrations/tensorflow/__init__.py
class TensorflowIntegration(Integration):
    """Definition of Tensorflow integration for ZenML.

    Declares the ``tensorflow`` package as its only requirement and, when
    activated, imports the integration's materializers module.
    """

    REQUIREMENTS = ["tensorflow"]
    NAME = TENSORFLOW

    @classmethod
    def activate(cls) -> None:
        """Activates the integration by importing its materializers.

        The imported name is never used. The import is presumably needed
        for its side effects (e.g. materializer registration), which is why
        the unused-import lint is silenced.
        """
        from zenml.integrations.tensorflow import materializers  # noqa
activate() classmethod

Activates the integration.

Source code in zenml/integrations/tensorflow/__init__.py
@classmethod
def activate(cls) -> None:
    """Activates the integration by importing its materializers.

    The imported name is never used. The import is presumably needed
    for its side effects (e.g. materializer registration), which is why
    the unused-import lint is silenced.
    """
    from zenml.integrations.tensorflow import materializers  # noqa

materializers special

keras_materializer
KerasMaterializer (BaseMaterializer)

Materializer to read/write Keras models.

Source code in zenml/integrations/tensorflow/materializers/keras_materializer.py
class KerasMaterializer(BaseMaterializer):
    """Materializer that saves Keras models to, and loads them from, the
    artifact URI."""

    ASSOCIATED_TYPES = [keras.Model]
    ASSOCIATED_ARTIFACT_TYPES = [ModelArtifact]

    def handle_input(self, data_type: Type[Any]) -> keras.Model:
        """Load and return the Keras model stored at the artifact URI.

        Args:
            data_type: the type requested by the consuming step; it is only
                forwarded to the base class hook.

        Returns:
            A tf.keras.Model model.
        """
        super().handle_input(data_type)
        saved_model_location = self.artifact.uri
        return keras.models.load_model(saved_model_location)

    def handle_return(self, model: keras.Model) -> None:
        """Save ``model`` to the artifact URI via ``model.save``.

        Args:
            model: A tf.keras.Model model.
        """
        super().handle_return(model)
        target_uri = self.artifact.uri
        model.save(target_uri)
handle_input(self, data_type)

Reads and returns a Keras model.

Returns:

Type Description
Model

A tf.keras.Model model.

Source code in zenml/integrations/tensorflow/materializers/keras_materializer.py
def handle_input(self, data_type: Type[Any]) -> keras.Model:
    """Load and return the Keras model stored at the artifact URI.

    Args:
        data_type: the type requested by the consuming step; it is only
            forwarded to the base class hook.

    Returns:
        A tf.keras.Model model.
    """
    super().handle_input(data_type)
    saved_model_location = self.artifact.uri
    return keras.models.load_model(saved_model_location)
handle_return(self, model)

Writes a keras model.

Parameters:

Name Type Description Default
model Model

A tf.keras.Model model.

required
Source code in zenml/integrations/tensorflow/materializers/keras_materializer.py
def handle_return(self, model: keras.Model) -> None:
    """Save ``model`` to the artifact URI via ``model.save``.

    Args:
        model: A tf.keras.Model model.
    """
    super().handle_return(model)
    target_uri = self.artifact.uri
    model.save(target_uri)
tf_dataset_materializer
TensorflowDatasetMaterializer (BaseMaterializer)

Materializer to read data to and from tf.data.Dataset.

Source code in zenml/integrations/tensorflow/materializers/tf_dataset_materializer.py
class TensorflowDatasetMaterializer(BaseMaterializer):
    """Materializer that persists ``tf.data.Dataset`` objects under the
    artifact URI and reads them back."""

    ASSOCIATED_TYPES = [tf.data.Dataset]
    ASSOCIATED_ARTIFACT_TYPES = [DataArtifact]

    # NOTE(review): tf.data.experimental.save/load are deprecated in newer
    # TensorFlow releases in favour of Dataset.save / Dataset.load; the
    # integration's requirement is unpinned, so confirm the installed version.

    def handle_input(self, data_type: Type[Any]) -> Any:
        """Read the dataset stored at ``<artifact uri>/DEFAULT_FILENAME``.

        Args:
            data_type: the type requested by the consuming step; it is only
                forwarded to the base class hook.

        Returns:
            The dataset returned by ``tf.data.experimental.load``.
        """
        super().handle_input(data_type)
        source = os.path.join(self.artifact.uri, DEFAULT_FILENAME)
        return tf.data.experimental.load(source)

    def handle_return(self, dataset: tf.data.Dataset) -> None:
        """Write ``dataset`` to ``<artifact uri>/DEFAULT_FILENAME``.

        The data is saved uncompressed and without a custom shard function.

        Args:
            dataset: the dataset to persist.
        """
        super().handle_return(dataset)
        destination = os.path.join(self.artifact.uri, DEFAULT_FILENAME)
        tf.data.experimental.save(
            dataset,
            destination,
            compression=None,
            shard_func=None,
        )
handle_input(self, data_type)

Reads data into tf.data.Dataset

Source code in zenml/integrations/tensorflow/materializers/tf_dataset_materializer.py
def handle_input(self, data_type: Type[Any]) -> Any:
    """Read the dataset stored at ``<artifact uri>/DEFAULT_FILENAME``.

    Args:
        data_type: the type requested by the consuming step; it is only
            forwarded to the base class hook.

    Returns:
        The dataset returned by ``tf.data.experimental.load``.
    """
    super().handle_input(data_type)
    source = os.path.join(self.artifact.uri, DEFAULT_FILENAME)
    return tf.data.experimental.load(source)
handle_return(self, dataset)

Persists a tf.data.Dataset object.

Source code in zenml/integrations/tensorflow/materializers/tf_dataset_materializer.py
def handle_return(self, dataset: tf.data.Dataset) -> None:
    """Write ``dataset`` to ``<artifact uri>/DEFAULT_FILENAME``.

    The data is saved uncompressed and without a custom shard function.

    Args:
        dataset: the dataset to persist.
    """
    super().handle_return(dataset)
    destination = os.path.join(self.artifact.uri, DEFAULT_FILENAME)
    tf.data.experimental.save(
        dataset,
        destination,
        compression=None,
        shard_func=None,
    )

steps special

tensorflow_trainer
TensorflowBinaryClassifier (BaseTrainerStep)

Simple step implementation which creates a simple tensorflow feedforward neural network and trains it on a given pd.DataFrame dataset

Source code in zenml/integrations/tensorflow/steps/tensorflow_trainer.py
class TensorflowBinaryClassifier(BaseTrainerStep):
    """Simple step implementation which creates a simple tensorflow feedforward
    neural network and trains it on a given pd.DataFrame dataset
    """

    def entrypoint(  # type: ignore[override]
        self,
        train_dataset: pd.DataFrame,
        validation_dataset: pd.DataFrame,
        config: TensorflowBinaryClassifierConfig,
    ) -> tf.keras.Model:
        """Main entrypoint for the tensorflow trainer.

        Builds a network of ReLU Dense layers (every entry of
        ``config.layers`` except the last), followed by a sigmoid Dense
        output layer whose size is the last entry. It is trained with Adam
        and binary cross-entropy. Neither the config nor the input
        DataFrames are modified.

        Args:
            train_dataset: pd.DataFrame, the training dataset
            validation_dataset: pd.DataFrame, the validation dataset
            config: the configuration of the step
        Returns:
            the trained tf.keras.Model
        Raises:
            ValueError: if ``config.layers`` is empty
            KeyError: if ``config.target_column`` is missing from a dataset
        """
        if not config.layers:
            raise ValueError(
                "`config.layers` must contain at least one layer size."
            )
        # Unpack instead of `config.layers.pop()`: popping mutated the config,
        # so a second run with the same config silently lost a layer.
        *hidden_layers, last_layer = config.layers

        model = tf.keras.Sequential()
        model.add(tf.keras.layers.InputLayer(input_shape=config.input_shape))
        model.add(tf.keras.layers.Flatten())
        for units in hidden_layers:
            model.add(tf.keras.layers.Dense(units, activation="relu"))
        model.add(tf.keras.layers.Dense(last_layer, activation="sigmoid"))

        model.compile(
            optimizer=tf.keras.optimizers.Adam(config.learning_rate),
            loss=tf.keras.losses.BinaryCrossentropy(),
            metrics=config.metrics,
        )

        # Split features/labels without `DataFrame.pop`, which mutated the
        # caller's DataFrames.
        target = config.target_column
        train_target = train_dataset[target]
        train_features = train_dataset.drop(columns=[target])
        validation_target = validation_dataset[target]
        validation_features = validation_dataset.drop(columns=[target])

        model.fit(
            x=train_features,
            y=train_target,
            validation_data=(validation_features, validation_target),
            batch_size=config.batch_size,
            epochs=config.epochs,
        )
        model.summary()

        return model
CONFIG_CLASS (BaseTrainerConfig) pydantic-model

Config class for the tensorflow trainer

target_column: the name of the label column. layers: the number of units in each fully connected layer. input_shape: the shape of the input. learning_rate: the learning rate. metrics: the list of metrics to be computed. epochs: the number of epochs. batch_size: the size of the batch.

Source code in zenml/integrations/tensorflow/steps/tensorflow_trainer.py
class TensorflowBinaryClassifierConfig(BaseTrainerConfig):
    """Config class for the tensorflow trainer.

    Attributes:
        target_column: the name of the label column
        layers: the number of units in each fully connected layer; the last
            entry is the size of the sigmoid output layer
        input_shape: the shape of a single input example (any rank; the
            trainer flattens it)
        learning_rate: the learning rate
        metrics: the list of metrics to be computed
        epochs: the number of epochs
        batch_size: the size of the batch
    """

    target_column: str
    layers: List[int] = [256, 64, 1]
    # `Tuple[int, ...]` accepts shapes of any rank. `Tuple[int]` is a
    # fixed-length, one-element tuple and rejected e.g. (28, 28).
    input_shape: Tuple[int, ...] = (8,)
    learning_rate: float = 0.001
    metrics: List[str] = ["accuracy"]
    epochs: int = 50
    batch_size: int = 8
entrypoint(self, train_dataset, validation_dataset, config)

Main entrypoint for the tensorflow trainer

Parameters:

Name Type Description Default
train_dataset DataFrame

pd.DataFrame, the training dataset

required
validation_dataset DataFrame

pd.DataFrame, the validation dataset

required
config TensorflowBinaryClassifierConfig

the configuration of the step

required

Returns:

Type Description
Model

the trained tf.keras.Model

Source code in zenml/integrations/tensorflow/steps/tensorflow_trainer.py
def entrypoint(  # type: ignore[override]
    self,
    train_dataset: pd.DataFrame,
    validation_dataset: pd.DataFrame,
    config: TensorflowBinaryClassifierConfig,
) -> tf.keras.Model:
    """Main entrypoint for the tensorflow trainer.

    Builds a network of ReLU Dense layers (every entry of
    ``config.layers`` except the last), followed by a sigmoid Dense
    output layer whose size is the last entry. It is trained with Adam
    and binary cross-entropy. Neither the config nor the input
    DataFrames are modified.

    Args:
        train_dataset: pd.DataFrame, the training dataset
        validation_dataset: pd.DataFrame, the validation dataset
        config: the configuration of the step
    Returns:
        the trained tf.keras.Model
    Raises:
        ValueError: if ``config.layers`` is empty
        KeyError: if ``config.target_column`` is missing from a dataset
    """
    if not config.layers:
        raise ValueError("`config.layers` must contain at least one layer size.")
    # Unpack instead of `config.layers.pop()`: popping mutated the config,
    # so a second run with the same config silently lost a layer.
    *hidden_layers, last_layer = config.layers

    model = tf.keras.Sequential()
    model.add(tf.keras.layers.InputLayer(input_shape=config.input_shape))
    model.add(tf.keras.layers.Flatten())
    for units in hidden_layers:
        model.add(tf.keras.layers.Dense(units, activation="relu"))
    model.add(tf.keras.layers.Dense(last_layer, activation="sigmoid"))

    model.compile(
        optimizer=tf.keras.optimizers.Adam(config.learning_rate),
        loss=tf.keras.losses.BinaryCrossentropy(),
        metrics=config.metrics,
    )

    # Split features/labels without `DataFrame.pop`, which mutated the
    # caller's DataFrames.
    target = config.target_column
    train_target = train_dataset[target]
    train_features = train_dataset.drop(columns=[target])
    validation_target = validation_dataset[target]
    validation_features = validation_dataset.drop(columns=[target])

    model.fit(
        x=train_features,
        y=train_target,
        validation_data=(validation_features, validation_target),
        batch_size=config.batch_size,
        epochs=config.epochs,
    )
    model.summary()

    return model
TensorflowBinaryClassifierConfig (BaseTrainerConfig) pydantic-model

Config class for the tensorflow trainer

target_column: the name of the label column. layers: the number of units in each fully connected layer. input_shape: the shape of the input. learning_rate: the learning rate. metrics: the list of metrics to be computed. epochs: the number of epochs. batch_size: the size of the batch.

Source code in zenml/integrations/tensorflow/steps/tensorflow_trainer.py
class TensorflowBinaryClassifierConfig(BaseTrainerConfig):
    """Config class for the tensorflow trainer.

    Attributes:
        target_column: the name of the label column
        layers: the number of units in each fully connected layer; the last
            entry is the size of the sigmoid output layer
        input_shape: the shape of a single input example (any rank; the
            trainer flattens it)
        learning_rate: the learning rate
        metrics: the list of metrics to be computed
        epochs: the number of epochs
        batch_size: the size of the batch
    """

    target_column: str
    layers: List[int] = [256, 64, 1]
    # `Tuple[int, ...]` accepts shapes of any rank. `Tuple[int]` is a
    # fixed-length, one-element tuple and rejected e.g. (28, 28).
    input_shape: Tuple[int, ...] = (8,)
    learning_rate: float = 0.001
    metrics: List[str] = ["accuracy"]
    epochs: int = 50
    batch_size: int = 8

utils

get_integration_for_module(module_name)

Gets the integration class for a module inside an integration.

If the module given by module_name is not part of a ZenML integration, this method will return None. If it is part of a ZenML integration, it will return the integration class found inside the integration's __init__ file.

Source code in zenml/integrations/utils.py
def get_integration_for_module(module_name: str) -> Optional[Type[Integration]]:
    """Gets the integration class for a module inside an integration.

    If the module given by `module_name` is not part of a ZenML integration,
    this method will return `None`. If it is part of a ZenML integration,
    it will return the integration class found inside the integration
    __init__ file.

    Args:
        module_name: fully qualified module name, e.g.
            ``zenml.integrations.airflow.orchestrators``.

    Returns:
        The first member (in name order, as returned by
        ``inspect.getmembers``) of the ``zenml.integrations.<name>`` package
        that is an ``IntegrationMeta`` instance and a subclass of
        ``Integration`` other than ``Integration`` itself, or ``None``.
    """
    if not module_name.startswith("zenml.integrations."):
        return None

    # Keep only the first three dotted components, i.e. the integration
    # package itself: "zenml.integrations.<name>".
    package_name = ".".join(module_name.split(".", 3)[:3])
    if package_name in sys.modules:
        package = sys.modules[package_name]
    else:
        package = importlib.import_module(package_name)

    integration_classes = (
        candidate
        for _, candidate in inspect.getmembers(package)
        if candidate is not Integration
        and isinstance(candidate, IntegrationMeta)
        and issubclass(candidate, Integration)
    )
    found = next(integration_classes, None)
    if found is None:
        return None
    return cast(Type[Integration], found)

get_requirements_for_module(module_name)

Gets requirements for a module inside an integration.

If the module given by module_name is not part of a ZenML integration, this method will return an empty list. If it is part of a ZenML integration, it will return the list of requirements specified inside the integration class found inside the integration's __init__ file.

Source code in zenml/integrations/utils.py
def get_requirements_for_module(module_name: str) -> List[str]:
    """Gets requirements for a module inside an integration.

    If the module given by `module_name` is not part of a ZenML integration,
    this method will return an empty list. If it is part of a ZenML integration,
    it will return the list of requirements specified inside the integration
    class found inside the integration __init__ file.

    Args:
        module_name: fully qualified module name to look up.

    Returns:
        The integration class's ``REQUIREMENTS`` list itself (not a copy),
        or a new empty list when no integration class is found.
    """
    integration_class = get_integration_for_module(module_name)
    if not integration_class:
        return []
    return integration_class.REQUIREMENTS