Integrations
zenml.integrations
special
The ZenML integrations module contains sub-modules for each integration that we
support. This includes orchestrators like Apache Airflow, visualization tools
like the facets library, as well as deep learning libraries like PyTorch.
airflow
special
The Airflow integration sub-module powers an alternative to the local
orchestrator. You can enable it by registering the Airflow orchestrator with
the CLI tool and then bootstrapping it with the `zenml orchestrator up`
command.
AirflowIntegration (Integration)
Definition of Airflow Integration for ZenML.
Source code in zenml/integrations/airflow/__init__.py
class AirflowIntegration(Integration):
    """Definition of Airflow Integration for ZenML."""

    NAME = AIRFLOW
    # Pinned Airflow version the integration is built against.
    REQUIREMENTS = ["apache-airflow==2.2.0"]

    @classmethod
    def activate(cls) -> None:
        """Activates all classes required for the airflow integration."""
        # Imported only for its side effects (hence `noqa`); presumably this
        # registers the Airflow orchestrator classes.
        from zenml.integrations.airflow import orchestrators  # noqa
activate()
classmethod
Activates all classes required for the airflow integration.
Source code in zenml/integrations/airflow/__init__.py
@classmethod
def activate(cls) -> None:
    """Activates all classes required for the airflow integration."""
    # Imported only for its side effects (hence `noqa`); presumably this
    # registers the Airflow orchestrator classes.
    from zenml.integrations.airflow import orchestrators  # noqa
orchestrators
special
The Airflow integration enables the use of Airflow as a pipeline orchestrator.
airflow_component
Definition of the Airflow component for TFX.
AirflowComponent (PythonOperator)
Airflow-specific TFX Component. This class wraps a component run into its own PythonOperator in Airflow.
Source code in zenml/integrations/airflow/orchestrators/airflow_component.py
class AirflowComponent(python.PythonOperator):
    """Airflow-specific TFX Component.

    This class wraps a component run into its own PythonOperator in Airflow.
    """

    def __init__(
        self,
        *,
        parent_dag: airflow.DAG,
        pipeline_node: pipeline_pb2.PipelineNode,
        mlmd_connection: metadata.Metadata,
        pipeline_info: pipeline_pb2.PipelineInfo,
        pipeline_runtime_spec: pipeline_pb2.PipelineRuntimeSpec,
        executor_spec: Optional[message.Message] = None,
        custom_driver_spec: Optional[message.Message] = None,
        custom_executor_operators: Optional[
            Dict[Any, Type[launcher.ExecutorOperator]]
        ] = None,
    ) -> None:
        """Constructs an Airflow implementation of TFX component.

        Args:
            parent_dag: The airflow DAG that this component is contained in.
            pipeline_node: The specification of the node to launch.
            mlmd_connection: ML metadata connection info.
            pipeline_info: The information of the pipeline that this node
                runs in.
            pipeline_runtime_spec: The runtime information of the pipeline
                that this node runs in.
            executor_spec: Specification for the executor of the node.
            custom_driver_spec: Specification for custom driver.
            custom_executor_operators: Map of executable specs to executor
                operators.
        """
        # Bind all node-specific arguments now so Airflow only has to call
        # the launcher when the task executes.
        launcher_callable = functools.partial(
            _airflow_component_launcher,
            pipeline_node=pipeline_node,
            mlmd_connection=mlmd_connection,
            pipeline_info=pipeline_info,
            pipeline_runtime_spec=pipeline_runtime_spec,
            executor_spec=executor_spec,
            custom_driver_spec=custom_driver_spec,
            custom_executor_operators=custom_executor_operators,
        )
        # `provide_context` is deprecated since Airflow 2.0 (this integration
        # pins apache-airflow 2.2.0): PythonOperator passes the task context
        # based on the callable's signature, and the flag itself is only
        # dropped with a DeprecationWarning, so it is no longer passed.
        super().__init__(
            task_id=pipeline_node.node_info.id,
            python_callable=launcher_callable,
            dag=parent_dag,
        )
__init__(self, *, parent_dag, pipeline_node, mlmd_connection, pipeline_info, pipeline_runtime_spec, executor_spec=None, custom_driver_spec=None, custom_executor_operators=None)
special
Constructs an Airflow implementation of TFX component.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
parent_dag |
DAG |
The airflow DAG that this component is contained in. |
required |
pipeline_node |
PipelineNode |
The specification of the node to launch. |
required |
mlmd_connection |
Metadata |
ML metadata connection info. |
required |
pipeline_info |
PipelineInfo |
The information of the pipeline that this node runs in. |
required |
pipeline_runtime_spec |
PipelineRuntimeSpec |
The runtime information of the pipeline that this node runs in. |
required |
executor_spec |
Optional[google.protobuf.message.Message] |
Specification for the executor of the node. |
None |
custom_driver_spec |
Optional[google.protobuf.message.Message] |
Specification for custom driver. |
None |
custom_executor_operators |
Optional[Dict[Any, Type[~ExecutorOperator]]] |
Map of executable specs to executor operators. |
None |
Source code in zenml/integrations/airflow/orchestrators/airflow_component.py
def __init__(
    self,
    *,
    parent_dag: airflow.DAG,
    pipeline_node: pipeline_pb2.PipelineNode,
    mlmd_connection: metadata.Metadata,
    pipeline_info: pipeline_pb2.PipelineInfo,
    pipeline_runtime_spec: pipeline_pb2.PipelineRuntimeSpec,
    executor_spec: Optional[message.Message] = None,
    custom_driver_spec: Optional[message.Message] = None,
    custom_executor_operators: Optional[
        Dict[Any, Type[launcher.ExecutorOperator]]
    ] = None,
) -> None:
    """Constructs an Airflow implementation of TFX component.

    Args:
        parent_dag: The airflow DAG that this component is contained in.
        pipeline_node: The specification of the node to launch.
        mlmd_connection: ML metadata connection info.
        pipeline_info: The information of the pipeline that this node
            runs in.
        pipeline_runtime_spec: The runtime information of the pipeline
            that this node runs in.
        executor_spec: Specification for the executor of the node.
        custom_driver_spec: Specification for custom driver.
        custom_executor_operators: Map of executable specs to executor
            operators.
    """
    # Bind all node-specific arguments now so Airflow only has to call
    # the launcher when the task executes.
    launcher_callable = functools.partial(
        _airflow_component_launcher,
        pipeline_node=pipeline_node,
        mlmd_connection=mlmd_connection,
        pipeline_info=pipeline_info,
        pipeline_runtime_spec=pipeline_runtime_spec,
        executor_spec=executor_spec,
        custom_driver_spec=custom_driver_spec,
        custom_executor_operators=custom_executor_operators,
    )
    # `provide_context` is deprecated since Airflow 2.0 (the integration
    # pins apache-airflow 2.2.0): PythonOperator passes the task context
    # based on the callable's signature, and the flag itself is only
    # dropped with a DeprecationWarning, so it is no longer passed.
    super().__init__(
        task_id=pipeline_node.node_info.id,
        python_callable=launcher_callable,
        dag=parent_dag,
    )
airflow_dag_runner
Definition of the Airflow TFX runner. This is a copy of the TFX source code, unmodified apart from superficial, stylistic changes.
AirflowDagRunner
TFX runner on Airflow.
Source code in zenml/integrations/airflow/orchestrators/airflow_dag_runner.py
class AirflowDagRunner:
    """Tfx runner on Airflow."""

    def __init__(
        self,
        config: Optional[Union[Dict[str, Any], AirflowPipelineConfig]] = None,
    ):
        """Creates an instance of AirflowDagRunner.

        Args:
            config: Optional Airflow pipeline config for customizing the
                launching of each component. A plain dict is still accepted
                (with a PendingDeprecationWarning) and is wrapped in an
                `AirflowPipelineConfig`.
        """
        if isinstance(config, dict):
            warnings.warn(
                "Pass config as a dict type is going to deprecated in 0.1.16. "
                "Use AirflowPipelineConfig type instead.",
                PendingDeprecationWarning,
            )
            config = AirflowPipelineConfig(airflow_dag_config=config)
        # Default to an empty AirflowPipelineConfig rather than a plain
        # PipelineConfig: `run` reads `airflow_dag_config`, which only the
        # Airflow-specific config defines.
        self._config: pipeline_config.PipelineConfig = (
            config if config is not None else AirflowPipelineConfig()
        )

    @property
    def config(self) -> pipeline_config.PipelineConfig:
        """The pipeline config used by this runner."""
        return self._config

    def run(
        self,
        pipeline: "BasePipeline",
        stack: "Stack",
        runtime_configuration: "RuntimeConfiguration",
    ) -> "airflow.DAG":
        """Deploys given logical pipeline on Airflow.

        Args:
            pipeline: Logical pipeline containing pipeline args and comps.
            stack: The current stack that ZenML is running on. Its metadata
                store provides the MLMD connection for every component.
            runtime_configuration: The configuration of the run.

        Returns:
            An Airflow DAG.
        """
        # Only import these when needed.
        import airflow  # noqa

        from zenml.integrations.airflow.orchestrators import airflow_component

        # Merge airflow-specific configs with pipeline args
        tfx_pipeline = create_tfx_pipeline(pipeline, stack=stack)

        if runtime_configuration.schedule:
            catchup = runtime_configuration.schedule.catchup
        else:
            catchup = False

        airflow_dag = airflow.DAG(
            dag_id=tfx_pipeline.pipeline_info.pipeline_name,
            **(
                typing.cast(
                    AirflowPipelineConfig, self._config
                ).airflow_dag_config
            ),
            is_paused_upon_creation=False,
            catchup=catchup,
        )

        pipeline_root = tfx_pipeline.pipeline_info.pipeline_root

        if "tmp_dir" not in tfx_pipeline.additional_pipeline_args:
            tmp_dir = os.path.join(pipeline_root, ".temp", "")
            tfx_pipeline.additional_pipeline_args["tmp_dir"] = tmp_dir

        for component in tfx_pipeline.components:
            if isinstance(component, base_component.BaseComponent):
                component._resolve_pip_dependencies(pipeline_root)
            self._replace_runtime_params(component)

        pb2_pipeline: Pb2Pipeline = compiler.Compiler().compile(tfx_pipeline)

        # Substitute the runtime parameter to be a concrete run_id
        runtime_parameter_utils.substitute_runtime_parameter(
            pb2_pipeline,
            {
                "pipeline-run-id": runtime_configuration.run_name,
            },
        )
        deployment_config = runner_utils.extract_local_deployment_config(
            pb2_pipeline
        )
        # Use the metadata store of the stack this pipeline is deployed on,
        # not whichever stack happens to be active in the repository.
        connection_config = stack.metadata_store.get_tfx_metadata_config()

        # These values are identical for every node; compute them once.
        # NOTE(review): the builtin `hash()` of a str is salted per process
        # (PYTHONHASHSEED), so these context names differ between
        # interpreter runs; consider a hashlib digest if that matters.
        stack_context_name = str(
            hash(json.dumps(stack.dict(), sort_keys=True))
        )
        requirements = " ".join(sorted(pipeline.requirements))
        requirements_context_name = str(hash(requirements))
        steps = list(pipeline.steps.values())

        component_impl_map = {}

        for node in pb2_pipeline.nodes:
            pipeline_node: PipelineNode = node.pipeline_node  # type: ignore[valid-type]

            # Add the stack as context to each pipeline node:
            context_utils.add_context_to_node(
                pipeline_node,
                type_=MetadataContextTypes.STACK.value,
                name=stack_context_name,
                properties=stack.dict(),
            )

            # Add all pydantic objects from runtime_configuration to the context
            context_utils.add_runtime_configuration_to_node(
                pipeline_node, runtime_configuration
            )

            # Add pipeline requirements as a context
            context_utils.add_context_to_node(
                pipeline_node,
                type_=MetadataContextTypes.PIPELINE_REQUIREMENTS.value,
                name=requirements_context_name,
                properties={"pipeline_requirements": requirements},
            )

            node_id = pipeline_node.node_info.id
            executor_spec = runner_utils.extract_executor_spec(
                deployment_config, node_id
            )
            custom_driver_spec = runner_utils.extract_custom_driver_spec(
                deployment_config, node_id
            )

            step = get_step_for_node(pipeline_node, steps=steps)
            custom_executor_operators = {
                executable_spec_pb2.PythonClassExecutableSpec: step.executor_operator
            }

            current_airflow_component = airflow_component.AirflowComponent(
                parent_dag=airflow_dag,
                pipeline_node=pipeline_node,
                mlmd_connection=connection_config,
                pipeline_info=pb2_pipeline.pipeline_info,
                pipeline_runtime_spec=pb2_pipeline.runtime_spec,
                executor_spec=executor_spec,
                custom_driver_spec=custom_driver_spec,
                custom_executor_operators=custom_executor_operators,
            )
            component_impl_map[node_id] = current_airflow_component
            # The compiled pipeline lists nodes in topological order, so every
            # upstream node must already have an Airflow task.
            for upstream_node in pipeline_node.upstream_nodes:
                assert (
                    upstream_node in component_impl_map
                ), "Components is not in topological order"
                current_airflow_component.set_upstream(
                    component_impl_map[upstream_node]
                )

        return airflow_dag

    def _replace_runtime_params(
        self, comp: base_node.BaseNode
    ) -> base_node.BaseNode:
        """Replaces runtime params for dynamic Airflow parameter execution.

        Args:
            comp: TFX component to be parsed.

        Returns:
            Returns edited component.

        Raises:
            RuntimeError: If a RuntimeParameter has a non-string ptype.
        """
        # Iterate over a copy because entries are rewritten in place.
        for k, prop in comp.exec_properties.copy().items():
            if isinstance(prop, RuntimeParameter):
                # Airflow only supports string parameters.
                if prop.ptype != str:
                    raise RuntimeError(
                        f"RuntimeParameter in Airflow does not support "
                        f"{prop.ptype}. The only ptype supported is string."
                    )

                # If the default is a template, drop the template markers
                # when inserting it into the .get() default argument below.
                # Otherwise, provide the default as a quoted string.
                default = cast(str, prop.default)
                if default.startswith("{{") and default.endswith("}}"):
                    default = default[2:-2]
                else:
                    default = json.dumps(default)

                template_field = '{{ dag_run.conf.get("%s", %s) }}' % (
                    prop.name,
                    default,
                )
                comp.exec_properties[k] = template_field
        return comp
__init__(self, config=None)
special
Creates an instance of AirflowDagRunner.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
config |
Union[Dict[str, Any], zenml.integrations.airflow.orchestrators.airflow_dag_runner.AirflowPipelineConfig] |
Optional Airflow pipeline config for customizing the launching of each component. |
None |
Source code in zenml/integrations/airflow/orchestrators/airflow_dag_runner.py
def __init__(
    self,
    config: Optional[Union[Dict[str, Any], AirflowPipelineConfig]] = None,
):
    """Creates an instance of AirflowDagRunner.

    Args:
        config: Optional Airflow pipeline config for customizing the
            launching of each component. A plain dict is still accepted
            (with a PendingDeprecationWarning) and is wrapped in an
            `AirflowPipelineConfig`.
    """
    if isinstance(config, dict):
        warnings.warn(
            "Pass config as a dict type is going to deprecated in 0.1.16. "
            "Use AirflowPipelineConfig type instead.",
            PendingDeprecationWarning,
        )
        config = AirflowPipelineConfig(airflow_dag_config=config)
    # Default to an empty AirflowPipelineConfig rather than a plain
    # PipelineConfig: `run` reads `airflow_dag_config`, which only the
    # Airflow-specific config defines.
    self._config: pipeline_config.PipelineConfig = (
        config if config is not None else AirflowPipelineConfig()
    )
run(self, pipeline, stack, runtime_configuration)
Deploys given logical pipeline on Airflow.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
pipeline |
BasePipeline |
Logical pipeline containing pipeline args and comps. |
required |
stack |
Stack |
The current stack that ZenML is running on |
required |
runtime_configuration |
RuntimeConfiguration |
The configuration of the run |
required |
Returns:
Type | Description |
---|---|
airflow.DAG |
An Airflow DAG. |
Source code in zenml/integrations/airflow/orchestrators/airflow_dag_runner.py
def run(
    self,
    pipeline: "BasePipeline",
    stack: "Stack",
    runtime_configuration: "RuntimeConfiguration",
) -> "airflow.DAG":
    """Deploys given logical pipeline on Airflow.

    Args:
        pipeline: Logical pipeline containing pipeline args and comps.
        stack: The current stack that ZenML is running on. Its metadata
            store provides the MLMD connection for every component.
        runtime_configuration: The configuration of the run.

    Returns:
        An Airflow DAG.
    """
    # Only import these when needed.
    import airflow  # noqa

    from zenml.integrations.airflow.orchestrators import airflow_component

    # Merge airflow-specific configs with pipeline args
    tfx_pipeline = create_tfx_pipeline(pipeline, stack=stack)

    if runtime_configuration.schedule:
        catchup = runtime_configuration.schedule.catchup
    else:
        catchup = False

    airflow_dag = airflow.DAG(
        dag_id=tfx_pipeline.pipeline_info.pipeline_name,
        **(
            typing.cast(
                AirflowPipelineConfig, self._config
            ).airflow_dag_config
        ),
        is_paused_upon_creation=False,
        catchup=catchup,
    )

    pipeline_root = tfx_pipeline.pipeline_info.pipeline_root

    if "tmp_dir" not in tfx_pipeline.additional_pipeline_args:
        tmp_dir = os.path.join(pipeline_root, ".temp", "")
        tfx_pipeline.additional_pipeline_args["tmp_dir"] = tmp_dir

    for component in tfx_pipeline.components:
        if isinstance(component, base_component.BaseComponent):
            component._resolve_pip_dependencies(pipeline_root)
        self._replace_runtime_params(component)

    pb2_pipeline: Pb2Pipeline = compiler.Compiler().compile(tfx_pipeline)

    # Substitute the runtime parameter to be a concrete run_id
    runtime_parameter_utils.substitute_runtime_parameter(
        pb2_pipeline,
        {
            "pipeline-run-id": runtime_configuration.run_name,
        },
    )
    deployment_config = runner_utils.extract_local_deployment_config(
        pb2_pipeline
    )
    # Use the metadata store of the stack this pipeline is deployed on,
    # not whichever stack happens to be active in the repository.
    connection_config = stack.metadata_store.get_tfx_metadata_config()

    # These values are identical for every node; compute them once.
    # NOTE(review): the builtin `hash()` of a str is salted per process
    # (PYTHONHASHSEED), so these context names differ between
    # interpreter runs; consider a hashlib digest if that matters.
    stack_context_name = str(hash(json.dumps(stack.dict(), sort_keys=True)))
    requirements = " ".join(sorted(pipeline.requirements))
    requirements_context_name = str(hash(requirements))
    steps = list(pipeline.steps.values())

    component_impl_map = {}

    for node in pb2_pipeline.nodes:
        pipeline_node: PipelineNode = node.pipeline_node  # type: ignore[valid-type]

        # Add the stack as context to each pipeline node:
        context_utils.add_context_to_node(
            pipeline_node,
            type_=MetadataContextTypes.STACK.value,
            name=stack_context_name,
            properties=stack.dict(),
        )

        # Add all pydantic objects from runtime_configuration to the context
        context_utils.add_runtime_configuration_to_node(
            pipeline_node, runtime_configuration
        )

        # Add pipeline requirements as a context
        context_utils.add_context_to_node(
            pipeline_node,
            type_=MetadataContextTypes.PIPELINE_REQUIREMENTS.value,
            name=requirements_context_name,
            properties={"pipeline_requirements": requirements},
        )

        node_id = pipeline_node.node_info.id
        executor_spec = runner_utils.extract_executor_spec(
            deployment_config, node_id
        )
        custom_driver_spec = runner_utils.extract_custom_driver_spec(
            deployment_config, node_id
        )

        step = get_step_for_node(pipeline_node, steps=steps)
        custom_executor_operators = {
            executable_spec_pb2.PythonClassExecutableSpec: step.executor_operator
        }

        current_airflow_component = airflow_component.AirflowComponent(
            parent_dag=airflow_dag,
            pipeline_node=pipeline_node,
            mlmd_connection=connection_config,
            pipeline_info=pb2_pipeline.pipeline_info,
            pipeline_runtime_spec=pb2_pipeline.runtime_spec,
            executor_spec=executor_spec,
            custom_driver_spec=custom_driver_spec,
            custom_executor_operators=custom_executor_operators,
        )
        component_impl_map[node_id] = current_airflow_component
        # The compiled pipeline lists nodes in topological order, so every
        # upstream node must already have an Airflow task.
        for upstream_node in pipeline_node.upstream_nodes:
            assert (
                upstream_node in component_impl_map
            ), "Components is not in topological order"
            current_airflow_component.set_upstream(
                component_impl_map[upstream_node]
            )

    return airflow_dag
AirflowPipelineConfig (PipelineConfig)
Pipeline config for AirflowDagRunner.
Source code in zenml/integrations/airflow/orchestrators/airflow_dag_runner.py
class AirflowPipelineConfig(pipeline_config.PipelineConfig):
    """Pipeline config for AirflowDagRunner.

    Extends the generic TFX pipeline config with a dictionary of keyword
    arguments that is later forwarded to the Airflow DAG constructor.
    """

    def __init__(
        self,
        airflow_dag_config: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ):
        """Creates an instance of AirflowPipelineConfig.

        Args:
            airflow_dag_config: Configs of Airflow DAG model. See
                https://airflow.apache.org/_api/airflow/models/dag/index.html#airflow.models.dag.DAG
                for the full spec. A falsy value is replaced by a new empty
                dict.
            **kwargs: keyword args for PipelineConfig.
        """
        super().__init__(**kwargs)
        if not airflow_dag_config:
            airflow_dag_config = {}
        self.airflow_dag_config = airflow_dag_config
__init__(self, airflow_dag_config=None, **kwargs)
special
Creates an instance of AirflowPipelineConfig.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
airflow_dag_config |
Optional[Dict[str, Any]] |
Configs of Airflow DAG model. See https://airflow.apache.org/_api/airflow/models/dag/index.html#airflow.models.dag.DAG for the full spec. |
None |
**kwargs |
Any |
keyword args for PipelineConfig. |
{} |
Source code in zenml/integrations/airflow/orchestrators/airflow_dag_runner.py
def __init__(
    self,
    airflow_dag_config: Optional[Dict[str, Any]] = None,
    **kwargs: Any,
):
    """Creates an instance of AirflowPipelineConfig.

    Args:
        airflow_dag_config: Configs of Airflow DAG model. See
            https://airflow.apache.org/_api/airflow/models/dag/index.html#airflow.models.dag.DAG
            for the full spec. A falsy value is replaced by a new empty dict.
        **kwargs: keyword args for PipelineConfig.
    """
    super().__init__(**kwargs)
    if not airflow_dag_config:
        airflow_dag_config = {}
    self.airflow_dag_config = airflow_dag_config
airflow_orchestrator
AirflowOrchestrator (BaseOrchestrator)
pydantic-model
Orchestrator responsible for running pipelines using Airflow.
Source code in zenml/integrations/airflow/orchestrators/airflow_orchestrator.py
class AirflowOrchestrator(BaseOrchestrator):
    """Orchestrator responsible for running pipelines using Airflow."""

    # Overwritten by the `set_airflow_home` validator.
    airflow_home: str = ""
    supports_local_execution = True
    supports_remote_execution = False

    def __init__(self, **values: Any):
        """Sets environment variables to configure airflow.

        Args:
            **values: Field values forwarded to the base orchestrator model.
        """
        super().__init__(**values)
        self._set_env()

    @property
    def flavor(self) -> OrchestratorFlavor:
        """The orchestrator flavor."""
        return OrchestratorFlavor.AIRFLOW

    @root_validator
    def set_airflow_home(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Sets airflow home according to orchestrator UUID.

        Raises:
            ValueError: If no `uuid` is present in `values`.
        """
        if "uuid" not in values:
            raise ValueError("`uuid` needs to exist for AirflowOrchestrator.")
        values["airflow_home"] = os.path.join(
            zenml.io.utils.get_global_config_directory(),
            AIRFLOW_ROOT_DIR,
            str(values["uuid"]),
        )
        return values

    @property
    def dags_directory(self) -> str:
        """Returns path to the airflow dags directory."""
        return os.path.join(self.airflow_home, "dags")

    @property
    def pid_file(self) -> str:
        """Returns path to the daemon PID file."""
        return os.path.join(self.airflow_home, "airflow_daemon.pid")

    @property
    def log_file(self) -> str:
        """Returns path to the airflow log file."""
        return os.path.join(self.airflow_home, "airflow_orchestrator.log")

    @property
    def password_file(self) -> str:
        """Returns path to the webserver password file."""
        return os.path.join(self.airflow_home, "standalone_admin_password.txt")

    def _set_env(self) -> None:
        """Sets environment variables to configure airflow."""
        os.environ["AIRFLOW_HOME"] = self.airflow_home
        os.environ["AIRFLOW__CORE__DAGS_FOLDER"] = self.dags_directory
        os.environ["AIRFLOW__CORE__DAG_DISCOVERY_SAFE_MODE"] = "false"
        os.environ["AIRFLOW__CORE__LOAD_EXAMPLES"] = "false"
        # check the DAG folder every 10 seconds for new files
        os.environ["AIRFLOW__SCHEDULER__DAG_DIR_LIST_INTERVAL"] = "10"

    def _copy_to_dag_directory_if_necessary(self, dag_filepath: str) -> None:
        """Copies the DAG module to the airflow DAGs directory if it's not
        already located there.

        Args:
            dag_filepath: Path to the file in which the DAG is defined.
        """
        dags_directory = fileio.resolve_relative_path(self.dags_directory)
        # Resolve the DAG path as well: a relative path would never compare
        # equal to the resolved DAGs directory, and the file would then be
        # copied onto itself.
        dag_filepath = fileio.resolve_relative_path(dag_filepath)

        if dags_directory == os.path.dirname(dag_filepath):
            logger.debug("File is already in airflow DAGs directory.")
            return

        logger.debug("Copying dag file '%s' to DAGs directory.", dag_filepath)
        destination_path = os.path.join(
            dags_directory, os.path.basename(dag_filepath)
        )
        if fileio.file_exists(destination_path):
            logger.info(
                "File '%s' already exists, overwriting with new DAG file",
                destination_path,
            )
        fileio.copy(dag_filepath, destination_path, overwrite=True)

    def _log_webserver_credentials(self) -> None:
        """Logs URL and credentials to log in to the airflow webserver.

        Raises:
            FileNotFoundError: If the password file does not exist.
        """
        if fileio.file_exists(self.password_file):
            with open(self.password_file) as file:
                password = file.read().strip()
        else:
            raise FileNotFoundError(
                f"Can't find password file '{self.password_file}'"
            )
        logger.info(
            "To inspect your DAGs, login to http://0.0.0.0:8080 "
            "with username: admin password: %s",
            password,
        )

    def runtime_options(self) -> Dict[str, Any]:
        """Runtime options for the airflow orchestrator."""
        return {DAG_FILEPATH_OPTION_KEY: None}

    def prepare_pipeline_deployment(
        self,
        pipeline: "BasePipeline",
        stack: "Stack",
        runtime_configuration: "RuntimeConfiguration",
    ) -> None:
        """Checks whether airflow is running and copies the DAG file to the
        airflow DAGs directory.

        Args:
            pipeline: The pipeline that will be deployed (unused here).
            stack: The stack the pipeline will run on (unused here).
            runtime_configuration: Runtime configuration; must contain the
                DAG filepath under `DAG_FILEPATH_OPTION_KEY`.

        Raises:
            RuntimeError: If airflow is not running or no DAG filepath runtime
                option is provided.
        """
        if not self.is_running:
            raise RuntimeError(
                "Airflow orchestrator is currently not running. Run `zenml "
                "stack up` to provision resources for the active stack."
            )

        try:
            dag_filepath = runtime_configuration[DAG_FILEPATH_OPTION_KEY]
        except KeyError:
            dag_filepath = None

        # `runtime_options` declares this option with a default of `None`,
        # so an unset option may arrive as `None` rather than a missing key.
        if not dag_filepath:
            raise RuntimeError(
                f"No DAG filepath found in runtime configuration. Make sure "
                f"to add the filepath to your airflow DAG file as a runtime "
                f"option (key: '{DAG_FILEPATH_OPTION_KEY}')."
            )

        self._copy_to_dag_directory_if_necessary(dag_filepath=dag_filepath)

    @property
    def is_running(self) -> bool:
        """Returns whether the airflow daemon is currently running.

        Raises:
            RuntimeError: If the daemon is not running but local port 8080
                is occupied.
        """
        from airflow.cli.commands.standalone_command import StandaloneCommand
        from airflow.jobs.triggerer_job import TriggererJob

        daemon_running = daemon.check_if_daemon_is_running(self.pid_file)

        command = StandaloneCommand()
        webserver_port_open = command.port_open(8080)

        if not daemon_running:
            if webserver_port_open:
                raise RuntimeError(
                    "The airflow daemon does not seem to be running but "
                    "local port 8080 is occupied. Make sure the port is "
                    "available and try again."
                )

            # exit early so we don't check non-existing airflow databases
            return False

        # we can't use StandaloneCommand().is_ready() here as the
        # Airflow SequentialExecutor apparently does not send a heartbeat
        # while running a task which would result in this returning `False`
        # even if Airflow is running.
        airflow_running = webserver_port_open and command.job_running(
            TriggererJob
        )
        return airflow_running

    @property
    def is_provisioned(self) -> bool:
        """Returns whether the orchestrator is provisioned, i.e. whether the
        airflow daemon is currently running."""
        return self.is_running

    def provision(self) -> None:
        """Ensures that Airflow is running.

        Starts `airflow standalone` as a daemon and waits until `is_running`
        reports it as up. If starting fails or does not finish within the
        startup timeout, the error is logged and `deprovision` is called;
        the error is not re-raised.
        """
        if self.is_running:
            logger.info("Airflow is already running.")
            self._log_webserver_credentials()
            return

        if not fileio.file_exists(self.dags_directory):
            fileio.create_dir_recursive_if_not_exists(self.dags_directory)

        from airflow.cli.commands.standalone_command import StandaloneCommand

        # Upper bound on the wait below so a daemon that dies during startup
        # cannot make this method hang forever.
        startup_timeout_seconds = 300

        try:
            command = StandaloneCommand()
            # Run the daemon with a working directory inside the current
            # zenml repo so the same repo will be used to run the DAGs
            daemon.run_as_daemon(
                command.run,
                pid_file=self.pid_file,
                log_file=self.log_file,
                working_directory=str(Repository().root),
            )
            deadline = time.monotonic() + startup_timeout_seconds
            while not self.is_running:
                # Wait until the daemon started all the relevant airflow
                # processes
                if time.monotonic() > deadline:
                    raise TimeoutError(
                        f"Airflow did not start within "
                        f"{startup_timeout_seconds} seconds. See "
                        f"'{self.log_file}' for details."
                    )
                time.sleep(0.1)
            self._log_webserver_credentials()
        except Exception as e:
            logger.error(e)
            logger.error(
                "An error occurred while starting the Airflow daemon. If you "
                "want to start it manually, use the commands described in the "
                "official Airflow quickstart guide for running Airflow locally."
            )
            self.deprovision()

    def deprovision(self) -> None:
        """Stops the airflow daemon if necessary and tears down resources."""
        # Check the daemon PID directly instead of `is_running`: the latter is
        # False while the daemon is still starting up, which would leave the
        # daemon process orphaned after its home directory is removed.
        if daemon.check_if_daemon_is_running(self.pid_file):
            daemon.stop_daemon(self.pid_file)

        if fileio.file_exists(self.airflow_home):
            fileio.rm_dir(self.airflow_home)
        logger.info("Airflow spun down.")

    def run_pipeline(
        self,
        pipeline: "BasePipeline",
        stack: "Stack",
        runtime_configuration: "RuntimeConfiguration",
    ) -> Any:
        """Schedules a pipeline to be run on Airflow.

        Args:
            pipeline: The pipeline to run.
            stack: The stack the pipeline runs on.
            runtime_configuration: Runtime configuration; its optional
                `schedule` controls the DAG's interval and start/end dates.

        Returns:
            An Airflow DAG object that corresponds to the ZenML pipeline.
        """
        if runtime_configuration.schedule:
            airflow_config = {
                "schedule_interval": datetime.timedelta(
                    seconds=runtime_configuration.schedule.interval_second
                ),
                "start_date": runtime_configuration.schedule.start_time,
                "end_date": runtime_configuration.schedule.end_time,
            }
        else:
            airflow_config = {
                "schedule_interval": "@once",
                # Scheduled in the past to make sure it runs immediately
                "start_date": datetime.datetime.now()
                - datetime.timedelta(days=7),
            }

        runner = AirflowDagRunner(AirflowPipelineConfig(airflow_config))

        return runner.run(
            pipeline=pipeline,
            stack=stack,
            runtime_configuration=runtime_configuration,
        )
dags_directory: str
property
readonly
Returns path to the airflow dags directory.
flavor: OrchestratorFlavor
property
readonly
The orchestrator flavor.
is_provisioned: bool
property
readonly
Returns whether the orchestrator is provisioned, i.e. whether the airflow daemon is currently running.
is_running: bool
property
readonly
Returns whether the airflow daemon is currently running.
log_file: str
property
readonly
Returns path to the airflow log file.
password_file: str
property
readonly
Returns path to the webserver password file.
pid_file: str
property
readonly
Returns path to the daemon PID file.
__init__(self, **values)
special
Sets environment variables to configure airflow.
Source code in zenml/integrations/airflow/orchestrators/airflow_orchestrator.py
def __init__(self, **values: Any):
    """Initializes the orchestrator and exports the Airflow environment.

    Args:
        **values: Field values forwarded unchanged to the base model.
    """
    # Let the base model (and its validators) populate `airflow_home`
    # before the environment variables that depend on it are exported.
    super().__init__(**values)
    self._set_env()
deprovision(self)
Stops the airflow daemon if necessary and tears down resources.
Source code in zenml/integrations/airflow/orchestrators/airflow_orchestrator.py
def deprovision(self) -> None:
    """Stops the airflow daemon if necessary and tears down resources."""
    # Check the daemon PID directly instead of `is_running`: the latter is
    # False while the daemon is still starting up, which would leave the
    # daemon process orphaned after its home directory is removed.
    if daemon.check_if_daemon_is_running(self.pid_file):
        daemon.stop_daemon(self.pid_file)

    if fileio.file_exists(self.airflow_home):
        fileio.rm_dir(self.airflow_home)
    logger.info("Airflow spun down.")
prepare_pipeline_deployment(self, pipeline, stack, runtime_configuration)
Checks whether airflow is running and copies the DAG file to the airflow DAGs directory.
Exceptions:
Type | Description |
---|---|
RuntimeError |
If airflow is not running or no DAG filepath runtime option is provided. |
Source code in zenml/integrations/airflow/orchestrators/airflow_orchestrator.py
def prepare_pipeline_deployment(
    self,
    pipeline: "BasePipeline",
    stack: "Stack",
    runtime_configuration: "RuntimeConfiguration",
) -> None:
    """Checks whether airflow is running and copies the DAG file to the
    airflow DAGs directory.

    Args:
        pipeline: The pipeline that will be deployed (unused here).
        stack: The stack the pipeline will run on (unused here).
        runtime_configuration: Runtime configuration; must contain the DAG
            filepath under `DAG_FILEPATH_OPTION_KEY`.

    Raises:
        RuntimeError: If airflow is not running or no DAG filepath runtime
            option is provided.
    """
    if not self.is_running:
        raise RuntimeError(
            "Airflow orchestrator is currently not running. Run `zenml "
            "stack up` to provision resources for the active stack."
        )

    try:
        dag_filepath = runtime_configuration[DAG_FILEPATH_OPTION_KEY]
    except KeyError:
        dag_filepath = None

    # The orchestrator declares this runtime option with a default of
    # `None`, so an unset option may arrive as `None` rather than a
    # missing key.
    if not dag_filepath:
        raise RuntimeError(
            f"No DAG filepath found in runtime configuration. Make sure "
            f"to add the filepath to your airflow DAG file as a runtime "
            f"option (key: '{DAG_FILEPATH_OPTION_KEY}')."
        )

    self._copy_to_dag_directory_if_necessary(dag_filepath=dag_filepath)
provision(self)
Ensures that Airflow is running.
Source code in zenml/integrations/airflow/orchestrators/airflow_orchestrator.py
def provision(self) -> None:
    """Ensures that Airflow is running.

    Starts `airflow standalone` as a daemon and waits until `is_running`
    reports it as up. If starting fails or does not finish within the
    startup timeout, the error is logged and `deprovision` is called; the
    error is not re-raised.
    """
    if self.is_running:
        logger.info("Airflow is already running.")
        self._log_webserver_credentials()
        return

    if not fileio.file_exists(self.dags_directory):
        fileio.create_dir_recursive_if_not_exists(self.dags_directory)

    from airflow.cli.commands.standalone_command import StandaloneCommand

    # Upper bound on the wait below so a daemon that dies during startup
    # cannot make this method hang forever.
    startup_timeout_seconds = 300

    try:
        command = StandaloneCommand()
        # Run the daemon with a working directory inside the current
        # zenml repo so the same repo will be used to run the DAGs
        daemon.run_as_daemon(
            command.run,
            pid_file=self.pid_file,
            log_file=self.log_file,
            working_directory=str(Repository().root),
        )
        deadline = time.monotonic() + startup_timeout_seconds
        while not self.is_running:
            # Wait until the daemon started all the relevant airflow
            # processes
            if time.monotonic() > deadline:
                raise TimeoutError(
                    f"Airflow did not start within "
                    f"{startup_timeout_seconds} seconds. See "
                    f"'{self.log_file}' for details."
                )
            time.sleep(0.1)
        self._log_webserver_credentials()
    except Exception as e:
        logger.error(e)
        logger.error(
            "An error occurred while starting the Airflow daemon. If you "
            "want to start it manually, use the commands described in the "
            "official Airflow quickstart guide for running Airflow locally."
        )
        self.deprovision()
run_pipeline(self, pipeline, stack, runtime_configuration)
Schedules a pipeline to be run on Airflow.
Returns:
Type | Description |
---|---|
Any |
An Airflow DAG object that corresponds to the ZenML pipeline. |
Source code in zenml/integrations/airflow/orchestrators/airflow_orchestrator.py
def run_pipeline(
    self,
    pipeline: "BasePipeline",
    stack: "Stack",
    runtime_configuration: "RuntimeConfiguration",
) -> Any:
    """Builds the Airflow DAG for a ZenML pipeline.

    Without a schedule the DAG runs once, starting a week in the past so
    Airflow picks it up immediately. With a schedule, the schedule's
    interval and start/end times are used.

    Args:
        pipeline: The pipeline to run.
        stack: The stack the pipeline runs on.
        runtime_configuration: Runtime configuration, optionally with a
            `schedule`.

    Returns:
        An Airflow DAG object that corresponds to the ZenML pipeline.
    """
    schedule = runtime_configuration.schedule
    if not schedule:
        # Scheduled in the past to make sure it runs immediately
        one_week_ago = datetime.datetime.now() - datetime.timedelta(days=7)
        dag_settings = {
            "schedule_interval": "@once",
            "start_date": one_week_ago,
        }
    else:
        interval = datetime.timedelta(seconds=schedule.interval_second)
        dag_settings = {
            "schedule_interval": interval,
            "start_date": schedule.start_time,
            "end_date": schedule.end_time,
        }

    dag_runner = AirflowDagRunner(AirflowPipelineConfig(dag_settings))
    return dag_runner.run(
        pipeline=pipeline,
        stack=stack,
        runtime_configuration=runtime_configuration,
    )
runtime_options(self)
Runtime options for the airflow orchestrator.
Source code in zenml/integrations/airflow/orchestrators/airflow_orchestrator.py
def runtime_options(self) -> Dict[str, Any]:
    """Returns the runtime options supported by this orchestrator.

    Only the DAG filepath option is declared, with no default value.
    """
    return dict.fromkeys([DAG_FILEPATH_OPTION_KEY])
set_airflow_home(values)
classmethod
Sets airflow home according to orchestrator UUID.
Source code in zenml/integrations/airflow/orchestrators/airflow_orchestrator.py
@root_validator
def set_airflow_home(cls, values: Dict[str, Any]) -> Dict[str, Any]:
    """Derives `airflow_home` from the orchestrator UUID.

    The home directory is `<global config dir>/<AIRFLOW_ROOT_DIR>/<uuid>`.

    Raises:
        ValueError: If no `uuid` is present in `values`.
    """
    if "uuid" in values:
        config_root = zenml.io.utils.get_global_config_directory()
        orchestrator_id = str(values["uuid"])
        values["airflow_home"] = os.path.join(
            config_root, AIRFLOW_ROOT_DIR, orchestrator_id
        )
        return values
    raise ValueError("`uuid` needs to exist for AirflowOrchestrator.")
aws
special
The AWS integration submodule provides a way to run ZenML pipelines in a cloud
environment. Specifically, it allows the use of cloud artifact stores and
provides an `io` module to handle file operations on S3 buckets.
AWSIntegration (Integration)
Definition of AWS integration for ZenML.
Source code in zenml/integrations/aws/__init__.py
class AWSIntegration(Integration):
    """Definition of AWS integration for ZenML."""

    NAME = AWS
    REQUIREMENTS = ["s3fs==2022.2.0"]

    @classmethod
    def activate(cls) -> None:
        """Imports the AWS submodules so their classes become available."""
        # Imported for side effects only; order matches the submodule list.
        from zenml.integrations.aws import (  # noqa
            artifact_stores,
            io,
            secrets_managers,
        )
activate()
classmethod
Activates the integration.
Source code in zenml/integrations/aws/__init__.py
@classmethod
def activate(cls) -> None:
    """Imports the AWS artifact store, io and secrets manager submodules.

    The imported names are unused; the imports are kept only for their side
    effects (hence ``noqa``).
    """
    from zenml.integrations.aws import (  # noqa
        artifact_stores,
        io,
        secrets_managers,
    )
artifact_stores
special
s3_artifact_store
S3ArtifactStore (BaseArtifactStore)
pydantic-model
Artifact Store for Amazon S3 based artifacts.
Source code in zenml/integrations/aws/artifact_stores/s3_artifact_store.py
class S3ArtifactStore(BaseArtifactStore):
    """Artifact store that keeps ZenML artifacts in Amazon S3.

    Only paths using the ``s3://`` scheme are accepted.
    """

    supports_local_execution = True
    supports_remote_execution = True

    @property
    def flavor(self) -> ArtifactStoreFlavor:
        """The artifact store flavor (always ``ArtifactStoreFlavor.S3``)."""
        return ArtifactStoreFlavor.S3

    @validator("path")
    def ensure_s3_path(cls, path: str) -> str:
        """Rejects any ``path`` that does not start with ``s3://``.

        Args:
            path: The configured artifact store path.

        Returns:
            The unchanged ``path``.

        Raises:
            ValueError: If ``path`` does not start with ``s3://``.
        """
        if path.startswith("s3://"):
            return path
        raise ValueError(
            f"Path '{path}' specified for S3ArtifactStore is not a "
            f"valid s3 path, i.e., starting with `s3://`."
        )
flavor: ArtifactStoreFlavor
property
readonly
The artifact store flavor.
ensure_s3_path(path)
classmethod
Ensures that the path is a valid s3 path.
Source code in zenml/integrations/aws/artifact_stores/s3_artifact_store.py
@validator("path")
def ensure_s3_path(cls, path: str) -> str:
    """Rejects any ``path`` that does not start with ``s3://``.

    Args:
        path: The configured artifact store path.

    Returns:
        The unchanged ``path``.

    Raises:
        ValueError: If ``path`` does not start with ``s3://``.
    """
    if path.startswith("s3://"):
        return path
    raise ValueError(
        f"Path '{path}' specified for S3ArtifactStore is not a "
        f"valid s3 path, i.e., starting with `s3://`."
    )
io
special
s3_plugin
Plugin which is created to add S3 storage support to ZenML. It inherits from the base Filesystem created by TFX and overrides the corresponding methods using s3fs.
ZenS3 (Filesystem)
Filesystem that delegates to S3 storage using s3fs.
Note: To allow TFX to check for various error conditions, we need to
raise their custom NotFoundError
instead of the builtin python
FileNotFoundError.
Source code in zenml/integrations/aws/io/s3_plugin.py
class ZenS3(Filesystem):
    """Filesystem that delegates to S3 storage using s3fs.

    All methods share one lazily created ``s3fs.S3FileSystem`` stored on the
    class attribute ``fs``.

    **Note**: To allow TFX to check for various error conditions, we need to
    raise their custom `NotFoundError` instead of the builtin python
    FileNotFoundError, so most methods translate the latter into the former.
    """

    SUPPORTED_SCHEMES = ["s3://"]
    fs: s3fs.S3FileSystem = None

    @classmethod
    def _ensure_filesystem_set(cls) -> None:
        """Creates the shared ``S3FileSystem`` the first time it is needed."""
        if cls.fs is not None:
            return
        cls.fs = s3fs.S3FileSystem()

    @staticmethod
    def open(path: PathType, mode: str = "r") -> Any:
        """Open a file on S3.

        Args:
            path: Path of the file to open.
            mode: Mode in which to open the file. Currently only
                'rb' and 'wb' to read and write binary files are supported.

        Returns:
            The file object returned by s3fs.

        Raises:
            NotFoundError: If s3fs raises ``FileNotFoundError``.
        """
        ZenS3._ensure_filesystem_set()
        try:
            handle = ZenS3.fs.open(path=path, mode=mode)
        except FileNotFoundError as missing:
            raise NotFoundError() from missing
        return handle

    @staticmethod
    def copy(src: PathType, dst: PathType, overwrite: bool = False) -> None:
        """Copy a single file on S3.

        Args:
            src: The path to copy from.
            dst: The path to copy to.
            overwrite: Whether an existing file at ``dst`` may be replaced.
                When False and ``dst`` exists, a FileExistsError is raised.

        Raises:
            FileExistsError: If ``dst`` exists and ``overwrite`` is False.
            NotFoundError: If s3fs raises ``FileNotFoundError`` while copying
                (e.g. a missing source file).
        """
        ZenS3._ensure_filesystem_set()
        # `overwrite` is checked first so no existence lookup is made when
        # overwriting is allowed.
        destination_taken = not overwrite and ZenS3.fs.exists(dst)
        if destination_taken:
            raise FileExistsError(
                f"Unable to copy to destination '{convert_to_str(dst)}', "
                f"file already exists. Set `overwrite=True` to copy anyway."
            )

        # TODO [ENG-151]: Check if it works with overwrite=True or if we need to
        #  manually remove it first
        try:
            ZenS3.fs.copy(path1=src, path2=dst)
        except FileNotFoundError as missing:
            raise NotFoundError() from missing

    @staticmethod
    def exists(path: PathType) -> bool:
        """Return whether ``path`` exists on S3."""
        ZenS3._ensure_filesystem_set()
        path_exists = ZenS3.fs.exists(path=path)
        return path_exists  # type: ignore[no-any-return]

    @staticmethod
    def glob(pattern: PathType) -> List[PathType]:
        """Return all S3 paths matching a glob pattern.

        The pattern may contain:
        - '*' to match any number of characters
        - '?' to match a single character
        - '[...]' to match one of the characters inside the brackets
        - '**' as a full path component to search in subdirectories of
          any depth (e.g. '/some_dir/**/some_file')

        Args:
            pattern: The glob pattern to match, see details above.

        Returns:
            The matching paths, each prefixed with ``s3://``.
        """
        ZenS3._ensure_filesystem_set()
        matches = ZenS3.fs.glob(path=pattern)
        return [f"s3://{match}" for match in matches]

    @staticmethod
    def isdir(path: PathType) -> bool:
        """Return whether ``path`` is a directory on S3."""
        ZenS3._ensure_filesystem_set()
        is_directory = ZenS3.fs.isdir(path=path)
        return is_directory  # type: ignore[no-any-return]

    @staticmethod
    def listdir(path: PathType) -> List[PathType]:
        """Return the names of the entries inside a directory.

        Args:
            path: The directory to list, with or without the ``s3://`` scheme.

        Returns:
            Entry names with the directory prefix (and any leading ``/``)
            stripped from each returned ``Key``.

        Raises:
            NotFoundError: If s3fs raises ``FileNotFoundError``.
        """
        ZenS3._ensure_filesystem_set()
        # Drop an optional scheme so the directory prefix can be cut off the
        # returned keys, leaving only entry names.
        directory = convert_to_str(path)
        if directory.startswith("s3://"):
            directory = directory[len("s3://") :]

        try:
            return [
                cast(str, entry["Key"])[len(directory) :].lstrip("/")
                for entry in ZenS3.fs.listdir(path=directory)
            ]
        except FileNotFoundError as missing:
            raise NotFoundError() from missing

    @staticmethod
    def makedirs(path: PathType) -> None:
        """Create ``path`` on S3 along with any missing parent directories.

        An already existing directory is not an error (``exist_ok=True``).
        """
        ZenS3._ensure_filesystem_set()
        file_system = ZenS3.fs
        file_system.makedirs(path=path, exist_ok=True)

    @staticmethod
    def mkdir(path: PathType) -> None:
        """Create a single directory at ``path`` on S3."""
        ZenS3._ensure_filesystem_set()
        file_system = ZenS3.fs
        file_system.makedir(path=path)

    @staticmethod
    def remove(path: PathType) -> None:
        """Delete the file at ``path``.

        Raises:
            NotFoundError: If s3fs raises ``FileNotFoundError``.
        """
        ZenS3._ensure_filesystem_set()
        try:
            ZenS3.fs.rm_file(path=path)
        except FileNotFoundError as missing:
            raise NotFoundError() from missing

    @staticmethod
    def rename(src: PathType, dst: PathType, overwrite: bool = False) -> None:
        """Rename (move) a file on S3.

        Args:
            src: The path of the file to rename.
            dst: The path to rename the source file to.
            overwrite: Whether an existing file at ``dst`` may be replaced.
                When False and ``dst`` exists, a FileExistsError is raised.

        Raises:
            FileExistsError: If ``dst`` exists and ``overwrite`` is False.
            NotFoundError: If s3fs raises ``FileNotFoundError`` while renaming
                (e.g. a missing source file).
        """
        ZenS3._ensure_filesystem_set()
        # `overwrite` is checked first so no existence lookup is made when
        # overwriting is allowed.
        destination_taken = not overwrite and ZenS3.fs.exists(dst)
        if destination_taken:
            raise FileExistsError(
                f"Unable to rename file to '{convert_to_str(dst)}', "
                f"file already exists. Set `overwrite=True` to rename anyway."
            )

        # TODO [ENG-152]: Check if it works with overwrite=True or if we need
        #  to manually remove it first
        try:
            ZenS3.fs.rename(path1=src, path2=dst)
        except FileNotFoundError as missing:
            raise NotFoundError() from missing

    @staticmethod
    def rmtree(path: PathType) -> None:
        """Recursively delete the directory at ``path``.

        Raises:
            NotFoundError: If s3fs raises ``FileNotFoundError``.
        """
        ZenS3._ensure_filesystem_set()
        try:
            ZenS3.fs.delete(path=path, recursive=True)
        except FileNotFoundError as missing:
            raise NotFoundError() from missing

    @staticmethod
    def stat(path: PathType) -> Dict[str, Any]:
        """Return the s3fs info dict for ``path``.

        Raises:
            NotFoundError: If s3fs raises ``FileNotFoundError``.
        """
        ZenS3._ensure_filesystem_set()
        try:
            info = ZenS3.fs.stat(path=path)
        except FileNotFoundError as missing:
            raise NotFoundError() from missing
        return info  # type: ignore[no-any-return]

    @staticmethod
    def walk(
        top: PathType,
        topdown: bool = True,
        onerror: Optional[Callable[..., None]] = None,
    ) -> Iterable[Tuple[PathType, List[PathType], List[PathType]]]:
        """Lazily walk the directory tree below ``top``.

        Args:
            top: Path of directory to walk.
            topdown: Unused argument to conform to interface.
            onerror: Unused argument to conform to interface.

        Returns:
            An iterable of ``(directory, subdirectories, files)`` tuples as
            produced by s3fs, with ``s3://`` prepended to each directory.
        """
        ZenS3._ensure_filesystem_set()
        # TODO [ENG-153]: Additional params
        yield from (
            (f"s3://{current_dir}", child_dirs, child_files)
            for current_dir, child_dirs, child_files in ZenS3.fs.walk(path=top)
        )
copy(src, dst, overwrite=False)
staticmethod
Copy a file.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
src |
Union[bytes, str] |
The path to copy from. |
required |
dst |
Union[bytes, str] |
The path to copy to. |
required |
overwrite |
bool |
If a file already exists at the destination, this
method will overwrite it if overwrite=`True` and raise a FileExistsError otherwise. |
False |
Exceptions:
Type | Description |
---|---|
FileNotFoundError |
If the source file does not exist. |
FileExistsError |
If a file already exists at the destination
and overwrite is not set to `True`. |
Source code in zenml/integrations/aws/io/s3_plugin.py
@staticmethod
def copy(src: PathType, dst: PathType, overwrite: bool = False) -> None:
    """Copy a single file on S3.

    Args:
        src: The path to copy from.
        dst: The path to copy to.
        overwrite: Whether an existing file at ``dst`` may be replaced.
            When False and ``dst`` exists, a FileExistsError is raised.

    Raises:
        FileExistsError: If ``dst`` exists and ``overwrite`` is False.
        NotFoundError: If s3fs raises ``FileNotFoundError`` while copying
            (e.g. a missing source file).
    """
    ZenS3._ensure_filesystem_set()
    # `overwrite` is checked first so no existence lookup is made when
    # overwriting is allowed.
    destination_taken = not overwrite and ZenS3.fs.exists(dst)
    if destination_taken:
        raise FileExistsError(
            f"Unable to copy to destination '{convert_to_str(dst)}', "
            f"file already exists. Set `overwrite=True` to copy anyway."
        )

    # TODO [ENG-151]: Check if it works with overwrite=True or if we need to
    #  manually remove it first
    try:
        ZenS3.fs.copy(path1=src, path2=dst)
    except FileNotFoundError as missing:
        raise NotFoundError() from missing
exists(path)
staticmethod
Check whether a path exists.
Source code in zenml/integrations/aws/io/s3_plugin.py
@staticmethod
def exists(path: PathType) -> bool:
    """Return whether ``path`` exists on S3."""
    ZenS3._ensure_filesystem_set()
    path_exists = ZenS3.fs.exists(path=path)
    return path_exists  # type: ignore[no-any-return]
glob(pattern)
staticmethod
Return all paths that match the given glob pattern. The glob pattern may include: - '*' to match any number of characters - '?' to match a single character - '[...]' to match one of the characters inside the brackets - '**' as the full name of a path component to search in subdirectories of any depth (e.g. '/some_dir/**/some_file')
Parameters:
Name | Type | Description | Default |
---|---|---|---|
pattern |
Union[bytes, str] |
The glob pattern to match, see details above. |
required |
Returns:
Type | Description |
---|---|
List[Union[bytes, str]] |
A list of paths that match the given glob pattern. |
Source code in zenml/integrations/aws/io/s3_plugin.py
@staticmethod
def glob(pattern: PathType) -> List[PathType]:
    """Return all S3 paths matching a glob pattern.

    The pattern may contain:
    - '*' to match any number of characters
    - '?' to match a single character
    - '[...]' to match one of the characters inside the brackets
    - '**' as a full path component to search in subdirectories of
      any depth (e.g. '/some_dir/**/some_file')

    Args:
        pattern: The glob pattern to match, see details above.

    Returns:
        The matching paths, each prefixed with ``s3://``.
    """
    ZenS3._ensure_filesystem_set()
    matches = ZenS3.fs.glob(path=pattern)
    return [f"s3://{match}" for match in matches]
isdir(path)
staticmethod
Check whether a path is a directory.
Source code in zenml/integrations/aws/io/s3_plugin.py
@staticmethod
def isdir(path: PathType) -> bool:
    """Return whether ``path`` is a directory on S3."""
    ZenS3._ensure_filesystem_set()
    is_directory = ZenS3.fs.isdir(path=path)
    return is_directory  # type: ignore[no-any-return]
listdir(path)
staticmethod
Return a list of files in a directory.
Source code in zenml/integrations/aws/io/s3_plugin.py
@staticmethod
def listdir(path: PathType) -> List[PathType]:
    """Return the names of the entries inside a directory.

    Args:
        path: The directory to list, with or without the ``s3://`` scheme.

    Returns:
        Entry names with the directory prefix (and any leading ``/``)
        stripped from each returned ``Key``.

    Raises:
        NotFoundError: If s3fs raises ``FileNotFoundError``.
    """
    ZenS3._ensure_filesystem_set()
    # Drop an optional scheme so the directory prefix can be cut off the
    # returned keys, leaving only entry names.
    directory = convert_to_str(path)
    if directory.startswith("s3://"):
        directory = directory[len("s3://") :]

    try:
        return [
            cast(str, entry["Key"])[len(directory) :].lstrip("/")
            for entry in ZenS3.fs.listdir(path=directory)
        ]
    except FileNotFoundError as missing:
        raise NotFoundError() from missing
makedirs(path)
staticmethod
Create a directory at the given path. If needed also create missing parent directories.
Source code in zenml/integrations/aws/io/s3_plugin.py
@staticmethod
def makedirs(path: PathType) -> None:
    """Create ``path`` on S3 along with any missing parent directories.

    An already existing directory is not an error (``exist_ok=True``).
    """
    ZenS3._ensure_filesystem_set()
    file_system = ZenS3.fs
    file_system.makedirs(path=path, exist_ok=True)
mkdir(path)
staticmethod
Create a directory at the given path.
Source code in zenml/integrations/aws/io/s3_plugin.py
@staticmethod
def mkdir(path: PathType) -> None:
    """Create a single directory at ``path`` on S3."""
    ZenS3._ensure_filesystem_set()
    file_system = ZenS3.fs
    file_system.makedir(path=path)
open(path, mode='r')
staticmethod
Open a file at the given path.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
path |
Union[bytes, str] |
Path of the file to open. |
required |
mode |
str |
Mode in which to open the file. Currently only 'rb' and 'wb' to read and write binary files are supported. |
'r' |
Source code in zenml/integrations/aws/io/s3_plugin.py
@staticmethod
def open(path: PathType, mode: str = "r") -> Any:
    """Open a file on S3.

    Args:
        path: Path of the file to open.
        mode: Mode in which to open the file. Currently only
            'rb' and 'wb' to read and write binary files are supported.

    Returns:
        The file object returned by s3fs.

    Raises:
        NotFoundError: If s3fs raises ``FileNotFoundError``.
    """
    ZenS3._ensure_filesystem_set()
    try:
        handle = ZenS3.fs.open(path=path, mode=mode)
    except FileNotFoundError as missing:
        raise NotFoundError() from missing
    return handle
remove(path)
staticmethod
Remove the file at the given path.
Source code in zenml/integrations/aws/io/s3_plugin.py
@staticmethod
def remove(path: PathType) -> None:
    """Delete the file at ``path``.

    Raises:
        NotFoundError: If s3fs raises ``FileNotFoundError``.
    """
    ZenS3._ensure_filesystem_set()
    try:
        ZenS3.fs.rm_file(path=path)
    except FileNotFoundError as missing:
        raise NotFoundError() from missing
rename(src, dst, overwrite=False)
staticmethod
Rename source file to destination file.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
src |
Union[bytes, str] |
The path of the file to rename. |
required |
dst |
Union[bytes, str] |
The path to rename the source file to. |
required |
overwrite |
bool |
If a file already exists at the destination, this
method will overwrite it if overwrite=`True` and raise a FileExistsError otherwise. |
False |
Exceptions:
Type | Description |
---|---|
FileNotFoundError |
If the source file does not exist. |
FileExistsError |
If a file already exists at the destination
and overwrite is not set to `True`. |
Source code in zenml/integrations/aws/io/s3_plugin.py
@staticmethod
def rename(src: PathType, dst: PathType, overwrite: bool = False) -> None:
    """Rename (move) a file on S3.

    Args:
        src: The path of the file to rename.
        dst: The path to rename the source file to.
        overwrite: Whether an existing file at ``dst`` may be replaced.
            When False and ``dst`` exists, a FileExistsError is raised.

    Raises:
        FileExistsError: If ``dst`` exists and ``overwrite`` is False.
        NotFoundError: If s3fs raises ``FileNotFoundError`` while renaming
            (e.g. a missing source file).
    """
    ZenS3._ensure_filesystem_set()
    # `overwrite` is checked first so no existence lookup is made when
    # overwriting is allowed.
    destination_taken = not overwrite and ZenS3.fs.exists(dst)
    if destination_taken:
        raise FileExistsError(
            f"Unable to rename file to '{convert_to_str(dst)}', "
            f"file already exists. Set `overwrite=True` to rename anyway."
        )

    # TODO [ENG-152]: Check if it works with overwrite=True or if we need
    #  to manually remove it first
    try:
        ZenS3.fs.rename(path1=src, path2=dst)
    except FileNotFoundError as missing:
        raise NotFoundError() from missing
rmtree(path)
staticmethod
Remove the given directory.
Source code in zenml/integrations/aws/io/s3_plugin.py
@staticmethod
def rmtree(path: PathType) -> None:
    """Recursively delete the directory at ``path``.

    Raises:
        NotFoundError: If s3fs raises ``FileNotFoundError``.
    """
    ZenS3._ensure_filesystem_set()
    try:
        ZenS3.fs.delete(path=path, recursive=True)
    except FileNotFoundError as missing:
        raise NotFoundError() from missing
stat(path)
staticmethod
Return stat info for the given path.
Source code in zenml/integrations/aws/io/s3_plugin.py
@staticmethod
def stat(path: PathType) -> Dict[str, Any]:
    """Return the s3fs info dict for ``path``.

    Raises:
        NotFoundError: If s3fs raises ``FileNotFoundError``.
    """
    ZenS3._ensure_filesystem_set()
    try:
        info = ZenS3.fs.stat(path=path)
    except FileNotFoundError as missing:
        raise NotFoundError() from missing
    return info  # type: ignore[no-any-return]
walk(top, topdown=True, onerror=None)
staticmethod
Return an iterator that walks the contents of the given directory.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
top |
Union[bytes, str] |
Path of directory to walk. |
required |
topdown |
bool |
Unused argument to conform to interface. |
True |
onerror |
Optional[Callable[..., NoneType]] |
Unused argument to conform to interface. |
None |
Returns:
Type | Description |
---|---|
Iterable[Tuple[Union[bytes, str], List[Union[bytes, str]], List[Union[bytes, str]]]] |
An Iterable of Tuples, each of which contain the path of the current directory path, a list of directories inside the current directory and a list of files inside the current directory. |
Source code in zenml/integrations/aws/io/s3_plugin.py
@staticmethod
def walk(
    top: PathType,
    topdown: bool = True,
    onerror: Optional[Callable[..., None]] = None,
) -> Iterable[Tuple[PathType, List[PathType], List[PathType]]]:
    """Lazily walk the directory tree below ``top``.

    Args:
        top: Path of directory to walk.
        topdown: Unused argument to conform to interface.
        onerror: Unused argument to conform to interface.

    Returns:
        An iterable of ``(directory, subdirectories, files)`` tuples as
        produced by s3fs, with ``s3://`` prepended to each directory.
    """
    ZenS3._ensure_filesystem_set()
    # TODO [ENG-153]: Additional params
    yield from (
        (f"s3://{current_dir}", child_dirs, child_files)
        for current_dir, child_dirs, child_files in ZenS3.fs.walk(path=top)
    )
secret_schemas
special
Secret Schema
...
secrets_managers
special
Secret Manager
...
aws_secrets_manager
AWSSecretsManager (BaseSecretsManager)
pydantic-model
Class to interact with the AWS secrets manager.
Source code in zenml/integrations/aws/secrets_managers/aws_secrets_manager.py
class AWSSecretsManager(BaseSecretsManager):
    """Class to interact with the AWS secrets manager.

    Secrets are stored as JSON strings (built by ``jsonify_secret_contents``)
    that carry the ZenML schema name next to the secret's key-value pairs.

    The boto3 client is cached on the class in ``CLIENT`` and shared by all
    instances. The region it was created for is remembered, so an instance
    configured for a different region gets a fresh client instead of
    silently talking to the wrong region.
    """

    supports_local_execution: bool = True
    supports_remote_execution: bool = True
    region_name: str = DEFAULT_AWS_REGION

    # Shared boto3 Secrets Manager client, created lazily.
    CLIENT: ClassVar[Any] = None
    # Region that ``CLIENT`` was created for by ``_ensure_client_connected``.
    # It stays ``None`` for a client assigned from outside, which is then
    # never replaced because of a region mismatch.
    _CLIENT_REGION: ClassVar[Any] = None

    @classmethod
    def _ensure_client_connected(cls, region_name: str) -> None:
        """Creates the shared Secrets Manager client if needed.

        A new client is created when none exists yet, or when the cached one
        was created here for a different region.

        Args:
            region_name: The AWS region the client should be bound to.
        """
        client_missing = cls.CLIENT is None
        region_changed = (
            cls._CLIENT_REGION is not None
            and cls._CLIENT_REGION != region_name
        )
        if client_missing or region_changed:
            # Create a Secrets Manager client
            session = boto3.session.Session()
            cls.CLIENT = session.client(
                service_name="secretsmanager", region_name=region_name
            )
            cls._CLIENT_REGION = region_name

    @property
    def flavor(self) -> SecretsManagerFlavor:
        """The secrets manager flavor.

        Returns:
            The secrets manager flavor."""
        return SecretsManagerFlavor.AWS

    @property
    def type(self) -> StackComponentType:
        """The secrets manager type.

        Returns:
            The secrets manager type."""
        return StackComponentType.SECRETS_MANAGER

    def register_secret(self, secret: BaseSecretSchema) -> None:
        """Registers a new secret.

        Args:
            secret: the secret to register"""
        self._ensure_client_connected(self.region_name)
        secret_value = jsonify_secret_contents(secret)
        kwargs = {"Name": secret.name, "SecretString": secret_value}
        self.CLIENT.create_secret(**kwargs)

    def get_secret(self, secret_name: str) -> BaseSecretSchema:
        """Gets a secret.

        Args:
            secret_name: the name of the secret to get

        Returns:
            The secret, rebuilt as the schema class whose name is stored
            under ``ZENML_SCHEMA_NAME`` in the secret's JSON.

        Raises:
            RuntimeError: if the AWS response contains no ``SecretString``.
                Errors raised by the boto3 client itself are not caught.
        """
        self._ensure_client_connected(self.region_name)
        get_secret_value_response = self.CLIENT.get_secret_value(
            SecretId=secret_name
        )
        if "SecretString" not in get_secret_value_response:
            raise RuntimeError(f"No secrets found within the {secret_name}")
        secret_contents: Dict[str, str] = json.loads(
            get_secret_value_response["SecretString"]
        )

        zenml_schema_name = secret_contents.pop(ZENML_SCHEMA_NAME)
        secret_contents["name"] = secret_name

        secret_schema = SecretSchemaClassRegistry.get_class(
            secret_schema=zenml_schema_name
        )
        return secret_schema(**secret_contents)

    def get_all_secret_keys(self) -> List[str]:
        """Get all secret keys.

        Follows ``NextToken`` so secrets beyond the first page of
        ``list_secrets`` results are included.

        Returns:
            A list of all secret keys."""
        self._ensure_client_connected(self.region_name)
        # TODO [ENG-721]: take out this magic maxresults number
        request: Dict[str, Any] = {"MaxResults": 100}
        secret_names: List[str] = []
        while True:
            response = self.CLIENT.list_secrets(**request)
            secret_names.extend(
                secret["Name"] for secret in response["SecretList"]
            )
            next_token = response.get("NextToken")
            if not next_token:
                return secret_names
            request["NextToken"] = next_token

    def update_secret(self, secret: BaseSecretSchema) -> None:
        """Update an existing secret.

        Args:
            secret: the secret to update"""
        self._ensure_client_connected(self.region_name)
        secret_value = jsonify_secret_contents(secret)
        kwargs = {"SecretId": secret.name, "SecretString": secret_value}
        self.CLIENT.put_secret_value(**kwargs)

    def delete_secret(self, secret_name: str) -> None:
        """Delete an existing secret.

        The request is sent with ``ForceDeleteWithoutRecovery=False``.

        Args:
            secret_name: the name of the secret to delete"""
        self._ensure_client_connected(self.region_name)
        self.CLIENT.delete_secret(
            SecretId=secret_name, ForceDeleteWithoutRecovery=False
        )

    def delete_all_secrets(self, force: bool = False) -> None:
        """Delete all existing secrets.

        Args:
            force: forwarded to AWS as ``ForceDeleteWithoutRecovery`` for
                every secret"""
        self._ensure_client_connected(self.region_name)
        for secret_name in self.get_all_secret_keys():
            self.CLIENT.delete_secret(
                SecretId=secret_name, ForceDeleteWithoutRecovery=force
            )
flavor: SecretsManagerFlavor
property
readonly
The secrets manager flavor.
Returns:
Type | Description |
---|---|
SecretsManagerFlavor |
The secrets manager flavor. |
type: StackComponentType
property
readonly
The secrets manager type.
Returns:
Type | Description |
---|---|
StackComponentType |
The secrets manager type. |
delete_all_secrets(self, force=False)
Delete all existing secrets.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
force |
bool |
whether to force delete all secrets |
False |
Source code in zenml/integrations/aws/secrets_managers/aws_secrets_manager.py
def delete_all_secrets(self, force: bool = False) -> None:
    """Delete every secret listed by ``get_all_secret_keys``.

    Args:
        force: Forwarded to AWS as ``ForceDeleteWithoutRecovery`` for each
            secret.
    """
    self._ensure_client_connected(self.region_name)
    secret_names = self.get_all_secret_keys()
    for name in secret_names:
        self.CLIENT.delete_secret(
            SecretId=name, ForceDeleteWithoutRecovery=force
        )
delete_secret(self, secret_name)
Delete an existing secret.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
secret_name |
str |
the name of the secret to delete |
required |
Source code in zenml/integrations/aws/secrets_managers/aws_secrets_manager.py
def delete_secret(self, secret_name: str) -> None:
    """Delete one secret from the AWS secrets manager.

    The deletion is requested with ``ForceDeleteWithoutRecovery=False``.

    Args:
        secret_name: the name of the secret to delete
    """
    self._ensure_client_connected(self.region_name)
    request = {"SecretId": secret_name, "ForceDeleteWithoutRecovery": False}
    self.CLIENT.delete_secret(**request)
get_all_secret_keys(self)
Get all secret keys.
Returns:
Type | Description |
---|---|
List[str] |
A list of all secret keys. |
Source code in zenml/integrations/aws/secrets_managers/aws_secrets_manager.py
def get_all_secret_keys(self) -> List[str]:
    """Get all secret keys.

    Follows ``NextToken`` so secrets beyond the first page of
    ``list_secrets`` results are included.

    Returns:
        A list of all secret keys."""
    self._ensure_client_connected(self.region_name)
    # TODO [ENG-721]: take out this magic maxresults number
    request: Dict[str, Any] = {"MaxResults": 100}
    secret_names: List[str] = []
    while True:
        response = self.CLIENT.list_secrets(**request)
        secret_names.extend(secret["Name"] for secret in response["SecretList"])
        next_token = response.get("NextToken")
        if not next_token:
            return secret_names
        request["NextToken"] = next_token
get_secret(self, secret_name)
Gets a secret.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
secret_name |
str |
the name of the secret to get |
required |
Returns:
Type | Description |
---|---|
BaseSecretSchema |
The secret. |
Exceptions:
Type | Description |
---|---|
RuntimeError |
if the secret does not exist |
Source code in zenml/integrations/aws/secrets_managers/aws_secrets_manager.py
def get_secret(self, secret_name: str) -> BaseSecretSchema:
    """Fetch a secret from AWS and rebuild it as its ZenML schema class.

    Args:
        secret_name: the name of the secret to get

    Returns:
        An instance of the schema class whose name is stored under
        ``ZENML_SCHEMA_NAME`` in the secret's JSON payload, with ``name``
        set to ``secret_name``.

    Raises:
        RuntimeError: if the AWS response contains no ``SecretString``.
            Errors raised by the boto3 client itself are not caught here.
    """
    self._ensure_client_connected(self.region_name)
    response = self.CLIENT.get_secret_value(SecretId=secret_name)
    if "SecretString" not in response:
        raise RuntimeError(f"No secrets found within the {secret_name}")

    secret_contents: Dict[str, str] = json.loads(response["SecretString"])
    # The schema name is bookkeeping, not a field of the schema itself.
    schema_name = secret_contents.pop(ZENML_SCHEMA_NAME)
    secret_contents["name"] = secret_name

    schema_class = SecretSchemaClassRegistry.get_class(
        secret_schema=schema_name
    )
    return schema_class(**secret_contents)
register_secret(self, secret)
Registers a new secret.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
secret |
BaseSecretSchema |
the secret to register |
required |
Source code in zenml/integrations/aws/secrets_managers/aws_secrets_manager.py
def register_secret(self, secret: BaseSecretSchema) -> None:
    """Create a new secret in the AWS secrets manager.

    The stored value is the JSON string built by ``jsonify_secret_contents``.

    Args:
        secret: the secret to register
    """
    self._ensure_client_connected(self.region_name)
    payload = jsonify_secret_contents(secret)
    self.CLIENT.create_secret(Name=secret.name, SecretString=payload)
update_secret(self, secret)
Update an existing secret.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
secret |
BaseSecretSchema |
the secret to update |
required |
Source code in zenml/integrations/aws/secrets_managers/aws_secrets_manager.py
def update_secret(self, secret: BaseSecretSchema) -> None:
    """Store a new value for an existing secret via ``put_secret_value``.

    The stored value is the JSON string built by ``jsonify_secret_contents``.

    Args:
        secret: the secret to update
    """
    self._ensure_client_connected(self.region_name)
    payload = jsonify_secret_contents(secret)
    self.CLIENT.put_secret_value(SecretId=secret.name, SecretString=payload)
jsonify_secret_contents(secret)
Adds the secret type to the secret contents to persist the schema type in the secrets backend, so that the correct SecretSchema can be retrieved when the secret is queried from the backend.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
secret |
BaseSecretSchema |
should be a subclass of the BaseSecretSchema class |
required |
Returns:
Type | Description |
---|---|
str |
jsonified dictionary containing all key-value pairs and the ZenML schema type |
Source code in zenml/integrations/aws/secrets_managers/aws_secrets_manager.py
def jsonify_secret_contents(secret: BaseSecretSchema) -> str:
    """Adds the secret type to the secret contents to persist the schema
    type in the secrets backend, so that the correct SecretSchema can be
    retrieved when the secret is queried from the backend.

    The dict returned by ``secret.content`` is copied before the schema
    entry is added, so the secret object itself is never modified.

    Args:
        secret: should be a subclass of the BaseSecretSchema class

    Returns:
        jsonified dictionary containing all key-value pairs and the ZenML schema
        type
    """
    # Copy: `secret.content` might hand out a dict the secret keeps a
    # reference to, and the schema key must not leak back into it.
    secret_contents = dict(secret.content)
    secret_contents[ZENML_SCHEMA_NAME] = secret.schema_type
    return json.dumps(secret_contents)
azure
special
The Azure integration submodule provides a way to run ZenML pipelines in a cloud
environment. Specifically, it allows the use of cloud artifact stores and an
`io` module to handle file operations on Azure Blob Storage.
AzureIntegration (Integration)
Definition of Azure integration for ZenML.
Source code in zenml/integrations/azure/__init__.py
class AzureIntegration(Integration):
    """ZenML integration for Microsoft Azure (Blob Storage artifact store)."""

    NAME = AZURE
    REQUIREMENTS = ["adlfs==2021.10.0"]

    @classmethod
    def activate(cls) -> None:
        """Imports the Azure artifact store and io submodules.

        The imported names are unused; the imports are kept only for their
        side effects (hence ``noqa``).
        """
        from zenml.integrations.azure import artifact_stores, io  # noqa
activate()
classmethod
Activates the integration.
Source code in zenml/integrations/azure/__init__.py
@classmethod
def activate(cls) -> None:
    """Imports the Azure artifact store and io submodules.

    The imported names are unused; the imports are kept only for their side
    effects (hence ``noqa``).
    """
    from zenml.integrations.azure import artifact_stores, io  # noqa
artifact_stores
special
azure_artifact_store
AzureArtifactStore (BaseArtifactStore)
pydantic-model
Artifact Store for Microsoft Azure based artifacts.
Source code in zenml/integrations/azure/artifact_stores/azure_artifact_store.py
class AzureArtifactStore(BaseArtifactStore):
    """Artifact store that keeps ZenML artifacts in Microsoft Azure.

    Only paths using the ``abfs://`` or ``az://`` scheme are accepted.
    """

    supports_local_execution = True
    supports_remote_execution = True

    @property
    def flavor(self) -> ArtifactStoreFlavor:
        """The artifact store flavor (always ``ArtifactStoreFlavor.AZURE``)."""
        return ArtifactStoreFlavor.AZURE

    @validator("path")
    def ensure_azure_path(cls, path: str) -> str:
        """Rejects paths that do not use an Azure Blob Storage scheme.

        Args:
            path: The configured artifact store path.

        Returns:
            The unchanged ``path``.

        Raises:
            ValueError: If ``path`` starts with neither ``abfs://`` nor
                ``az://``.
        """
        accepted_prefixes = ["abfs://", "az://"]
        if path.startswith(tuple(accepted_prefixes)):
            return path
        raise ValueError(
            f"Path '{path}' specified for AzureArtifactStore is not a "
            f"valid Azure Blob Storage path, i.e., starting with one of "
            f"{accepted_prefixes}."
        )
flavor: ArtifactStoreFlavor
property
readonly
The artifact store flavor.
ensure_azure_path(path)
classmethod
Ensures that the path is a valid azure path.
Source code in zenml/integrations/azure/artifact_stores/azure_artifact_store.py
@validator("path")
def ensure_azure_path(cls, path: str) -> str:
    """Rejects paths that do not use an Azure Blob Storage scheme.

    Args:
        path: The configured artifact store path.

    Returns:
        The unchanged ``path``.

    Raises:
        ValueError: If ``path`` starts with neither ``abfs://`` nor ``az://``.
    """
    accepted_prefixes = ["abfs://", "az://"]
    if path.startswith(tuple(accepted_prefixes)):
        return path
    raise ValueError(
        f"Path '{path}' specified for AzureArtifactStore is not a "
        f"valid Azure Blob Storage path, i.e., starting with one of "
        f"{accepted_prefixes}."
    )
io
special
azure_plugin
Plugin which is created to add Azure storage support to ZenML. It inherits from the base Filesystem created by TFX and overrides the corresponding methods using adlfs.
ZenAzure (Filesystem)
Filesystem that delegates to Azure storage using adlfs.
To authenticate with Azure Blob Storage, make sure to set one of the following combinations of environment variables: either AZURE_STORAGE_CONNECTION_STRING, or AZURE_STORAGE_ACCOUNT_NAME together with one of AZURE_STORAGE_ACCOUNT_KEY or AZURE_STORAGE_SAS_TOKEN.
Note: To allow TFX to check for various error conditions, we need to
raise their custom NotFoundError
instead of the builtin python
FileNotFoundError.
Source code in zenml/integrations/azure/io/azure_plugin.py
class ZenAzure(Filesystem):
    """Filesystem that delegates to Azure storage using adlfs.

    To authenticate with Azure Blob Storage, make sure to set a
    combination of the following environment variables:
    - AZURE_STORAGE_CONNECTION_STRING
    - AZURE_STORAGE_ACCOUNT_NAME and one of
      [AZURE_STORAGE_ACCOUNT_KEY, AZURE_STORAGE_SAS_TOKEN]

    **Note**: To allow TFX to check for various error conditions, we need to
    raise their custom `NotFoundError` instead of the builtin python
    FileNotFoundError. The methods below that catch a `FileNotFoundError`
    from adlfs re-raise it as `NotFoundError`.
    """

    SUPPORTED_SCHEMES = ["abfs://", "az://"]
    # Shared adlfs filesystem; created lazily by `_ensure_filesystem_set`.
    fs: adlfs.AzureBlobFileSystem = None

    @classmethod
    def _ensure_filesystem_set(cls) -> None:
        """Ensures that the filesystem is set.

        Creates the class-level `AzureBlobFileSystem` (non-anonymous, with
        the listings cache disabled) the first time it is needed and reuses
        it afterwards.
        """
        if cls.fs is None:
            cls.fs = adlfs.AzureBlobFileSystem(
                anon=False,
                use_listings_cache=False,
            )

    @classmethod
    def _split_path(cls, path: PathType) -> Tuple[str, str]:
        """Splits a path into the filesystem prefix and remainder.

        Example:
        ```python
        prefix, remainder = ZenAzure._split_path("az://my_container/test.txt")
        print(prefix, remainder) # "az://" "my_container/test.txt"
        ```

        Args:
            path: The path to split; it is passed through `convert_to_str`
                first.

        Returns:
            A `(prefix, remainder)` tuple. `prefix` is the first matching
            entry of `SUPPORTED_SCHEMES`, or `""` if none matches.
        """
        path = convert_to_str(path)
        prefix = ""
        for potential_prefix in cls.SUPPORTED_SCHEMES:
            if path.startswith(potential_prefix):
                prefix = potential_prefix
                path = path[len(potential_prefix) :]
                break
        return prefix, path

    @staticmethod
    def open(path: PathType, mode: str = "r") -> Any:
        """Open a file at the given path.

        Args:
            path: Path of the file to open.
            mode: Mode in which to open the file. Currently only
                'rb' and 'wb' to read and write binary files are supported.

        Returns:
            The file object returned by the adlfs filesystem's `open`.

        Raises:
            NotFoundError: If adlfs raises `FileNotFoundError` for `path`.
        """
        ZenAzure._ensure_filesystem_set()
        try:
            return ZenAzure.fs.open(path=path, mode=mode)
        except FileNotFoundError as e:
            raise NotFoundError() from e

    @staticmethod
    def copy(src: PathType, dst: PathType, overwrite: bool = False) -> None:
        """Copy a file.

        Args:
            src: The path to copy from.
            dst: The path to copy to.
            overwrite: If a file already exists at the destination, this
                method will overwrite it if overwrite=`True` and
                raise a FileExistsError otherwise.

        Raises:
            NotFoundError: If the source file does not exist (adlfs raised
                `FileNotFoundError` during the copy).
            FileExistsError: If a file already exists at the destination
                and overwrite is not set to `True`.
        """
        ZenAzure._ensure_filesystem_set()
        if not overwrite and ZenAzure.fs.exists(dst):
            raise FileExistsError(
                f"Unable to copy to destination '{convert_to_str(dst)}', "
                f"file already exists. Set `overwrite=True` to copy anyway."
            )
        # TODO [ENG-151]: Check if it works with overwrite=True or if we need to
        #  manually remove it first
        try:
            ZenAzure.fs.copy(path1=src, path2=dst)
        except FileNotFoundError as e:
            raise NotFoundError() from e

    @staticmethod
    def exists(path: PathType) -> bool:
        """Check whether a path exists.

        Args:
            path: The path to check.

        Returns:
            The result of the adlfs filesystem's `exists` for `path`.
        """
        ZenAzure._ensure_filesystem_set()
        return ZenAzure.fs.exists(path=path)  # type: ignore[no-any-return]

    @staticmethod
    def glob(pattern: PathType) -> List[PathType]:
        """Return all paths that match the given glob pattern.

        The glob pattern may include:
        - '*' to match any number of characters
        - '?' to match a single character
        - '[...]' to match one of the characters inside the brackets
        - '**' as the full name of a path component to match to search
          in subdirectories of any depth (e.g. '/some_dir/**/some_file)

        Args:
            pattern: The glob pattern to match, see details above.

        Returns:
            A list of paths that match the given glob pattern. The scheme
            prefix of `pattern` (if any) is prepended to every match.
        """
        ZenAzure._ensure_filesystem_set()
        # adlfs returns matches without the scheme, so re-add the caller's.
        prefix, _ = ZenAzure._split_path(pattern)
        return [f"{prefix}{path}" for path in ZenAzure.fs.glob(path=pattern)]

    @staticmethod
    def isdir(path: PathType) -> bool:
        """Check whether a path is a directory.

        Args:
            path: The path to check.

        Returns:
            The result of the adlfs filesystem's `isdir` for `path`.
        """
        ZenAzure._ensure_filesystem_set()
        return ZenAzure.fs.isdir(path=path)  # type: ignore[no-any-return]

    @staticmethod
    def listdir(path: PathType) -> List[PathType]:
        """Return a list of files in a directory.

        Args:
            path: The directory to list, with or without a scheme prefix.

        Returns:
            The entry names relative to `path` (leading '/' stripped).

        Raises:
            NotFoundError: If adlfs raises `FileNotFoundError` for `path`.
        """
        ZenAzure._ensure_filesystem_set()
        _, path = ZenAzure._split_path(path)

        def _extract_basename(file_dict: Dict[str, Any]) -> str:
            """Extracts the basename from a file info dict returned by the Azure
            filesystem."""
            file_path = cast(str, file_dict["name"])
            base_name = file_path[len(path) :]
            return base_name.lstrip("/")

        try:
            return [
                _extract_basename(dict_)
                for dict_ in ZenAzure.fs.listdir(path=path)
            ]
        except FileNotFoundError as e:
            raise NotFoundError() from e

    @staticmethod
    def makedirs(path: PathType) -> None:
        """Create a directory at the given path. If needed also
        create missing parent directories.

        Args:
            path: The directory to create; an existing directory is not an
                error (`exist_ok=True`).
        """
        ZenAzure._ensure_filesystem_set()
        ZenAzure.fs.makedirs(path=path, exist_ok=True)

    @staticmethod
    def mkdir(path: PathType) -> None:
        """Create a directory at the given path.

        Args:
            path: The directory to create.
        """
        ZenAzure._ensure_filesystem_set()
        ZenAzure.fs.makedir(path=path)

    @staticmethod
    def remove(path: PathType) -> None:
        """Remove the file at the given path.

        Args:
            path: The file to remove.

        Raises:
            NotFoundError: If adlfs raises `FileNotFoundError` for `path`.
        """
        ZenAzure._ensure_filesystem_set()
        try:
            ZenAzure.fs.rm_file(path=path)
        except FileNotFoundError as e:
            raise NotFoundError() from e

    @staticmethod
    def rename(src: PathType, dst: PathType, overwrite: bool = False) -> None:
        """Rename source file to destination file.

        Args:
            src: The path of the file to rename.
            dst: The path to rename the source file to.
            overwrite: If a file already exists at the destination, this
                method will overwrite it if overwrite=`True` and
                raise a FileExistsError otherwise.

        Raises:
            NotFoundError: If the source file does not exist (adlfs raised
                `FileNotFoundError` during the rename).
            FileExistsError: If a file already exists at the destination
                and overwrite is not set to `True`.
        """
        ZenAzure._ensure_filesystem_set()
        if not overwrite and ZenAzure.fs.exists(dst):
            raise FileExistsError(
                f"Unable to rename file to '{convert_to_str(dst)}', "
                f"file already exists. Set `overwrite=True` to rename anyway."
            )
        # TODO [ENG-152]: Check if it works with overwrite=True or if we need
        #  to manually remove it first
        try:
            ZenAzure.fs.rename(path1=src, path2=dst)
        except FileNotFoundError as e:
            raise NotFoundError() from e

    @staticmethod
    def rmtree(path: PathType) -> None:
        """Remove the given directory.

        Args:
            path: The directory to delete recursively.

        Raises:
            NotFoundError: If adlfs raises `FileNotFoundError` for `path`.
        """
        ZenAzure._ensure_filesystem_set()
        try:
            ZenAzure.fs.delete(path=path, recursive=True)
        except FileNotFoundError as e:
            raise NotFoundError() from e

    @staticmethod
    def stat(path: PathType) -> Dict[str, Any]:
        """Return stat info for the given path.

        Args:
            path: The path to inspect.

        Returns:
            The info dict returned by the adlfs filesystem's `stat`.

        Raises:
            NotFoundError: If adlfs raises `FileNotFoundError` for `path`.
        """
        ZenAzure._ensure_filesystem_set()
        try:
            return ZenAzure.fs.stat(path=path)  # type: ignore[no-any-return]
        except FileNotFoundError as e:
            raise NotFoundError() from e

    @staticmethod
    def walk(
        top: PathType,
        topdown: bool = True,
        onerror: Optional[Callable[..., None]] = None,
    ) -> Iterable[Tuple[PathType, List[PathType], List[PathType]]]:
        """Return an iterator that walks the contents of the given directory.

        Args:
            top: Path of directory to walk.
            topdown: Unused argument to conform to interface.
            onerror: Unused argument to conform to interface.

        Returns:
            An Iterable of Tuples, each of which contain the path of the current
            directory path, a list of directories inside the current directory
            and a list of files inside the current directory. The scheme
            prefix of `top` (if any) is prepended to each directory path.
        """
        ZenAzure._ensure_filesystem_set()
        # TODO [ENG-153]: Additional params
        prefix, _ = ZenAzure._split_path(top)
        for directory, subdirectories, files in ZenAzure.fs.walk(path=top):
            yield f"{prefix}{directory}", subdirectories, files
copy(src, dst, overwrite=False)
staticmethod
Copy a file.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
src |
Union[bytes, str] |
The path to copy from. |
required |
dst |
Union[bytes, str] |
The path to copy to. |
required |
overwrite |
bool |
If a file already exists at the destination, this
method will overwrite it if overwrite=`True` and raise a FileExistsError otherwise. |
False |
Exceptions:
Type | Description |
---|---|
NotFoundError |
If the source file does not exist (raised in place of the builtin FileNotFoundError). |
FileExistsError |
If a file already exists at the destination
and overwrite is not set to `True`. |
Source code in zenml/integrations/azure/io/azure_plugin.py
@staticmethod
def copy(src: PathType, dst: PathType, overwrite: bool = False) -> None:
    """Copy a file.

    Args:
        src: The path to copy from.
        dst: The path to copy to.
        overwrite: If a file already exists at the destination, this
            method will overwrite it if overwrite=`True` and
            raise a FileExistsError otherwise.

    Raises:
        NotFoundError: If the source file does not exist (adlfs raised
            `FileNotFoundError` during the copy).
        FileExistsError: If a file already exists at the destination
            and overwrite is not set to `True`.
    """
    ZenAzure._ensure_filesystem_set()
    if not overwrite and ZenAzure.fs.exists(dst):
        raise FileExistsError(
            f"Unable to copy to destination '{convert_to_str(dst)}', "
            f"file already exists. Set `overwrite=True` to copy anyway."
        )
    # TODO [ENG-151]: Check if it works with overwrite=True or if we need to
    #  manually remove it first
    try:
        ZenAzure.fs.copy(path1=src, path2=dst)
    except FileNotFoundError as e:
        raise NotFoundError() from e
exists(path)
staticmethod
Check whether a path exists.
Source code in zenml/integrations/azure/io/azure_plugin.py
@staticmethod
def exists(path: PathType) -> bool:
    """Report whether `path` exists, as answered by the adlfs filesystem."""
    ZenAzure._ensure_filesystem_set()
    filesystem = ZenAzure.fs
    return filesystem.exists(path=path)  # type: ignore[no-any-return]
glob(pattern)
staticmethod
Return all paths that match the given glob pattern. The glob pattern may include: - '*' to match any number of characters - '?' to match a single character - '[...]' to match one of the characters inside the brackets - '**' as the full name of a path component to match to search in subdirectories of any depth (e.g. '/some_dir/**/some_file')
Parameters:
Name | Type | Description | Default |
---|---|---|---|
pattern |
Union[bytes, str] |
The glob pattern to match, see details above. |
required |
Returns:
Type | Description |
---|---|
List[Union[bytes, str]] |
A list of paths that match the given glob pattern. |
Source code in zenml/integrations/azure/io/azure_plugin.py
@staticmethod
def glob(pattern: PathType) -> List[PathType]:
    """Find every path matching a glob pattern.

    Supported wildcards:
    - '*' matches any number of characters
    - '?' matches a single character
    - '[...]' matches one of the bracketed characters
    - '**' as a whole path component searches subdirectories of any
      depth (e.g. '/some_dir/**/some_file)

    Args:
        pattern: The glob pattern to match.

    Returns:
        The matching paths, each prefixed with the scheme of `pattern`
        (empty if `pattern` had no supported scheme).
    """
    ZenAzure._ensure_filesystem_set()
    # adlfs reports matches without a scheme; restore the caller's one.
    scheme = ZenAzure._split_path(pattern)[0]
    matches = ZenAzure.fs.glob(path=pattern)
    return [f"{scheme}{match}" for match in matches]
isdir(path)
staticmethod
Check whether a path is a directory.
Source code in zenml/integrations/azure/io/azure_plugin.py
@staticmethod
def isdir(path: PathType) -> bool:
    """Report whether `path` is a directory, per the adlfs filesystem."""
    ZenAzure._ensure_filesystem_set()
    filesystem = ZenAzure.fs
    return filesystem.isdir(path=path)  # type: ignore[no-any-return]
listdir(path)
staticmethod
Return a list of files in a directory.
Source code in zenml/integrations/azure/io/azure_plugin.py
@staticmethod
def listdir(path: PathType) -> List[PathType]:
    """List the entries of a directory by name.

    The scheme prefix is stripped from `path` before asking adlfs; each
    returned name is the entry's full name with the directory part (and
    any leading '/') removed.

    Raises:
        NotFoundError: If adlfs raises `FileNotFoundError` for `path`.
    """
    ZenAzure._ensure_filesystem_set()
    _, directory = ZenAzure._split_path(path)
    skip = len(directory)
    try:
        return [
            cast(str, entry["name"])[skip:].lstrip("/")
            for entry in ZenAzure.fs.listdir(path=directory)
        ]
    except FileNotFoundError as e:
        raise NotFoundError() from e
makedirs(path)
staticmethod
Create a directory at the given path. If needed also create missing parent directories.
Source code in zenml/integrations/azure/io/azure_plugin.py
@staticmethod
def makedirs(path: PathType) -> None:
    """Create `path` along with any missing parent directories.

    An already existing directory is accepted (`exist_ok=True`).
    """
    ZenAzure._ensure_filesystem_set()
    filesystem = ZenAzure.fs
    filesystem.makedirs(path=path, exist_ok=True)
mkdir(path)
staticmethod
Create a directory at the given path.
Source code in zenml/integrations/azure/io/azure_plugin.py
@staticmethod
def mkdir(path: PathType) -> None:
    """Create a single directory at `path` via the adlfs `makedir` call."""
    ZenAzure._ensure_filesystem_set()
    filesystem = ZenAzure.fs
    filesystem.makedir(path=path)
open(path, mode='r')
staticmethod
Open a file at the given path.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
path |
Union[bytes, str] |
Path of the file to open. |
required |
mode |
str |
Mode in which to open the file. Currently only 'rb' and 'wb' to read and write binary files are supported. |
'r' |
Source code in zenml/integrations/azure/io/azure_plugin.py
@staticmethod
def open(path: PathType, mode: str = "r") -> Any:
    """Open the file stored at `path`.

    Args:
        path: Location of the file.
        mode: File mode. Currently only 'rb' and 'wb' (binary read and
            write) are supported.

    Returns:
        The file object produced by the adlfs filesystem.

    Raises:
        NotFoundError: If adlfs raises `FileNotFoundError` for `path`.
    """
    ZenAzure._ensure_filesystem_set()
    filesystem = ZenAzure.fs
    try:
        handle = filesystem.open(path=path, mode=mode)
    except FileNotFoundError as missing:
        raise NotFoundError() from missing
    return handle
remove(path)
staticmethod
Remove the file at the given path.
Source code in zenml/integrations/azure/io/azure_plugin.py
@staticmethod
def remove(path: PathType) -> None:
    """Delete the single file at `path`.

    Raises:
        NotFoundError: If adlfs raises `FileNotFoundError` for `path`.
    """
    ZenAzure._ensure_filesystem_set()
    filesystem = ZenAzure.fs
    try:
        filesystem.rm_file(path=path)
    except FileNotFoundError as missing:
        raise NotFoundError() from missing
rename(src, dst, overwrite=False)
staticmethod
Rename source file to destination file.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
src |
Union[bytes, str] |
The path of the file to rename. |
required |
dst |
Union[bytes, str] |
The path to rename the source file to. |
required |
overwrite |
bool |
If a file already exists at the destination, this
method will overwrite it if overwrite=`True` and raise a FileExistsError otherwise. |
False |
Exceptions:
Type | Description |
---|---|
NotFoundError |
If the source file does not exist (raised in place of the builtin FileNotFoundError). |
FileExistsError |
If a file already exists at the destination
and overwrite is not set to `True`. |
Source code in zenml/integrations/azure/io/azure_plugin.py
@staticmethod
def rename(src: PathType, dst: PathType, overwrite: bool = False) -> None:
    """Rename source file to destination file.

    Args:
        src: The path of the file to rename.
        dst: The path to rename the source file to.
        overwrite: If a file already exists at the destination, this
            method will overwrite it if overwrite=`True` and
            raise a FileExistsError otherwise.

    Raises:
        NotFoundError: If the source file does not exist (adlfs raised
            `FileNotFoundError` during the rename).
        FileExistsError: If a file already exists at the destination
            and overwrite is not set to `True`.
    """
    ZenAzure._ensure_filesystem_set()
    if not overwrite and ZenAzure.fs.exists(dst):
        raise FileExistsError(
            f"Unable to rename file to '{convert_to_str(dst)}', "
            f"file already exists. Set `overwrite=True` to rename anyway."
        )
    # TODO [ENG-152]: Check if it works with overwrite=True or if we need
    #  to manually remove it first
    try:
        ZenAzure.fs.rename(path1=src, path2=dst)
    except FileNotFoundError as e:
        raise NotFoundError() from e
rmtree(path)
staticmethod
Remove the given directory.
Source code in zenml/integrations/azure/io/azure_plugin.py
@staticmethod
def rmtree(path: PathType) -> None:
    """Recursively delete the directory at `path` and its contents.

    Raises:
        NotFoundError: If adlfs raises `FileNotFoundError` for `path`.
    """
    ZenAzure._ensure_filesystem_set()
    filesystem = ZenAzure.fs
    try:
        filesystem.delete(path=path, recursive=True)
    except FileNotFoundError as missing:
        raise NotFoundError() from missing
stat(path)
staticmethod
Return stat info for the given path.
Source code in zenml/integrations/azure/io/azure_plugin.py
@staticmethod
def stat(path: PathType) -> Dict[str, Any]:
    """Fetch the adlfs info dict describing `path`.

    Raises:
        NotFoundError: If adlfs raises `FileNotFoundError` for `path`.
    """
    ZenAzure._ensure_filesystem_set()
    filesystem = ZenAzure.fs
    try:
        info = filesystem.stat(path=path)
    except FileNotFoundError as missing:
        raise NotFoundError() from missing
    return info  # type: ignore[no-any-return]
walk(top, topdown=True, onerror=None)
staticmethod
Return an iterator that walks the contents of the given directory.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
top |
Union[bytes, str] |
Path of directory to walk. |
required |
topdown |
bool |
Unused argument to conform to interface. |
True |
onerror |
Optional[Callable[..., NoneType]] |
Unused argument to conform to interface. |
None |
Returns:
Type | Description |
---|---|
Iterable[Tuple[Union[bytes, str], List[Union[bytes, str]], List[Union[bytes, str]]]] |
An Iterable of Tuples, each of which contain the path of the current directory path, a list of directories inside the current directory and a list of files inside the current directory. |
Source code in zenml/integrations/azure/io/azure_plugin.py
@staticmethod
def walk(
    top: PathType,
    topdown: bool = True,
    onerror: Optional[Callable[..., None]] = None,
) -> Iterable[Tuple[PathType, List[PathType], List[PathType]]]:
    """Walk the directory tree rooted at `top`.

    Args:
        top: Root directory of the walk.
        topdown: Accepted only to match the filesystem interface; ignored.
        onerror: Accepted only to match the filesystem interface; ignored.

    Returns:
        An iterable of `(dirpath, dirnames, filenames)` tuples, one per
        visited directory, with the scheme of `top` (if any) prepended to
        each `dirpath`.
    """
    ZenAzure._ensure_filesystem_set()
    # TODO [ENG-153]: Additional params
    scheme = ZenAzure._split_path(top)[0]
    for dirpath, dirnames, filenames in ZenAzure.fs.walk(path=top):
        yield f"{scheme}{dirpath}", dirnames, filenames
azureml
special
The AzureML integration submodule provides a way to run ZenML steps in AzureML.
AzureMLIntegration (Integration)
Definition of AzureML integration for ZenML.
Source code in zenml/integrations/azureml/__init__.py
class AzureMLIntegration(Integration):
    """Definition of AzureML integration for ZenML.

    Activating the integration imports its `step_operators` subpackage.
    """

    NAME = AZUREML
    # Pinned AzureML SDK version installed for this integration.
    REQUIREMENTS = ["azureml-core==1.39.0.post1"]

    @classmethod
    def activate(cls) -> None:
        """Activates the integration by importing its step operators."""
        # Imported only for the side effect of loading the module.
        import zenml.integrations.azureml.step_operators  # noqa
activate()
classmethod
Activates the integration.
Source code in zenml/integrations/azureml/__init__.py
@classmethod
def activate(cls) -> None:
    """Activates the integration by importing its step operators."""
    # Imported only for the side effect of loading the module.
    import zenml.integrations.azureml.step_operators  # noqa
step_operators
special
azureml_step_operator
AzureMLStepOperator (BaseStepOperator)
pydantic-model
Step operator to run a step on AzureML.
This class defines code that can set up an AzureML environment and run the ZenML entrypoint command in it.
Attributes:
Name | Type | Description |
---|---|---|
subscription_id |
str |
The Azure account's subscription ID |
resource_group |
str |
The resource group to which the AzureML workspace is deployed. |
workspace_name |
str |
The name of the AzureML Workspace. |
compute_target_name |
str |
The name of the configured ComputeTarget. An instance of it has to be created on the portal if it doesn't exist already. |
environment_name |
Optional[str] |
[Optional] The name of the environment if there already exists one. |
docker_base_image |
Optional[str] |
[Optional] The custom docker base image that the environment should use. |
tenant_id |
Optional[str] |
The Azure Tenant ID. |
service_principal_id |
Optional[str] |
The ID for the service principal that is created to allow apps to access secure resources. |
service_principal_password |
Optional[str] |
Password for the service principal. |
Source code in zenml/integrations/azureml/step_operators/azureml_step_operator.py
class AzureMLStepOperator(BaseStepOperator):
    """Step operator to run a step on AzureML.

    This class defines code that can set up an AzureML environment and run
    the ZenML entrypoint command in it.

    Attributes:
        subscription_id: The Azure account's subscription ID
        resource_group: The resource group to which the AzureML workspace
            is deployed.
        workspace_name: The name of the AzureML Workspace.
        compute_target_name: The name of the configured ComputeTarget.
            An instance of it has to be created on the portal if it doesn't
            exist already.
        environment_name: [Optional] The name of the environment if there
            already exists one.
        docker_base_image: [Optional] The custom docker base image that the
            environment should use.
        tenant_id: The Azure Tenant ID.
        service_principal_id: The ID for the service principal that is created
            to allow apps to access secure resources.
        service_principal_password: Password for the service principal.
    """

    supports_local_execution = True
    supports_remote_execution = True

    subscription_id: str
    resource_group: str
    workspace_name: str
    compute_target_name: str

    # Environment
    environment_name: Optional[str] = None
    docker_base_image: Optional[str] = None

    # Service principal authentication
    # https://docs.microsoft.com/en-us/azure/machine-learning/how-to-setup-authentication#configure-a-service-principal
    tenant_id: Optional[str] = None
    service_principal_id: Optional[str] = None
    service_principal_password: Optional[str] = None

    @property
    def flavor(self) -> StepOperatorFlavor:
        """The step operator flavor."""
        return StepOperatorFlavor.AZUREML

    def _get_authentication(self) -> Optional[AbstractAuthentication]:
        """Build service principal credentials when they are fully configured.

        Returns:
            A `ServicePrincipalAuthentication` if the tenant ID, service
            principal ID and password are all set (non-empty), else `None`.
        """
        credentials_complete = all(
            (
                self.tenant_id,
                self.service_principal_id,
                self.service_principal_password,
            )
        )
        if not credentials_complete:
            return None
        return ServicePrincipalAuthentication(
            tenant_id=self.tenant_id,
            service_principal_id=self.service_principal_id,
            service_principal_password=self.service_principal_password,
        )

    def _prepare_environment(
        self, workspace: Workspace, requirements: List[str], run_name: str
    ) -> Environment:
        """Build the AzureML environment in which the step will run.

        Either a new environment named `zenml-<run_name>` is created, or the
        configured `environment_name` is fetched from the workspace and the
        requirements are added to it as pip packages.

        Args:
            workspace: The AzureML Workspace that has configuration
                for a storage account, container registry among other
                things.
            requirements: Pip requirements to install in the environment.
            run_name: Name of the pipeline run, used to name a newly
                created environment.

        Returns:
            The prepared environment, with its environment variables set.
        """
        if not self.environment_name:
            environment = Environment(name=f"zenml-{run_name}")
            environment.python.conda_dependencies = CondaDependencies.create(
                pip_packages=requirements,
                python_version=ZenMLEnvironment.python_version(),
            )
        else:
            environment = Environment.get(
                workspace=workspace, name=self.environment_name
            )
            if not environment.python.conda_dependencies:
                environment.python.conda_dependencies = (
                    CondaDependencies.create(
                        python_version=ZenMLEnvironment.python_version()
                    )
                )
            for requirement in requirements:
                environment.python.conda_dependencies.add_pip_package(
                    requirement
                )

        if self.docker_base_image:
            # Use the custom image instead of Azure's default base image.
            environment.docker.base_image = self.docker_base_image

        environment_variables = {
            "ENV_ZENML_PREVENT_PIPELINE_EXECUTION": "True",
        }
        # Forward whichever Azure storage credentials are set locally so the
        # remote step can access Azure storage.
        storage_keys = (
            "AZURE_STORAGE_ACCOUNT_KEY",
            "AZURE_STORAGE_ACCOUNT_NAME",
            "AZURE_STORAGE_CONNECTION_STRING",
            "AZURE_STORAGE_SAS_TOKEN",
        )
        candidates = ((key, os.getenv(key)) for key in storage_keys)
        environment_variables.update(
            {key: value for key, value in candidates if value}
        )
        environment_variables[
            ENV_ZENML_CONFIG_PATH
        ] = f"./{CONTAINER_ZENML_CONFIG_DIR}"

        environment.environment_variables = environment_variables
        return environment

    def launch(
        self,
        pipeline_name: str,
        run_name: str,
        requirements: List[str],
        entrypoint_command: List[str],
    ) -> None:
        """Launches a step on AzureML and waits for it to finish.

        Args:
            pipeline_name: Name of the pipeline the step belongs to; used as
                the AzureML experiment name.
            run_name: Name of the pipeline run the step belongs to; used as
                the run's display name.
            requirements: Pip requirements that must be installed inside the
                step operator environment.
            entrypoint_command: Command that executes the step.
        """
        workspace = Workspace.get(
            subscription_id=self.subscription_id,
            resource_group=self.resource_group,
            name=self.workspace_name,
            auth=self._get_authentication(),
        )
        source_directory = get_source_root_path()
        config_path = os.path.join(source_directory, CONTAINER_ZENML_CONFIG_DIR)

        try:
            # Place a copy of the global configuration (with the active
            # profile) in the build context so the configured stacks are
            # available inside the Azure ML environment.
            GlobalConfiguration().copy_config_with_active_profile(
                config_path,
                load_config_path=f"./{CONTAINER_ZENML_CONFIG_DIR}",
            )
            run_config = ScriptRunConfig(
                source_directory=source_directory,
                environment=self._prepare_environment(
                    workspace=workspace,
                    requirements=requirements,
                    run_name=run_name,
                ),
                compute_target=ComputeTarget(
                    workspace=workspace, name=self.compute_target_name
                ),
                command=entrypoint_command,
            )
            experiment = Experiment(workspace=workspace, name=pipeline_name)
            run = experiment.submit(config=run_config)
        finally:
            # The copied configuration is only needed for the submission.
            fileio.rm_dir(config_path)

        run.display_name = run_name
        run.wait_for_completion(show_output=True)
flavor: StepOperatorFlavor
property
readonly
The step operator flavor.
launch(self, pipeline_name, run_name, requirements, entrypoint_command)
Launches a step on AzureML.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
pipeline_name |
str |
Name of the pipeline which the step to be executed is part of. |
required |
run_name |
str |
Name of the pipeline run which the step to be executed is part of. |
required |
entrypoint_command |
List[str] |
Command that executes the step. |
required |
requirements |
List[str] |
List of pip requirements that must be installed inside the step operator environment. |
required |
Source code in zenml/integrations/azureml/step_operators/azureml_step_operator.py
def launch(
    self,
    pipeline_name: str,
    run_name: str,
    requirements: List[str],
    entrypoint_command: List[str],
) -> None:
    """Launches a step on AzureML and waits for it to finish.

    Args:
        pipeline_name: Name of the pipeline the step belongs to; used as
            the AzureML experiment name.
        run_name: Name of the pipeline run the step belongs to; used as
            the run's display name.
        requirements: Pip requirements that must be installed inside the
            step operator environment.
        entrypoint_command: Command that executes the step.
    """
    workspace = Workspace.get(
        subscription_id=self.subscription_id,
        resource_group=self.resource_group,
        name=self.workspace_name,
        auth=self._get_authentication(),
    )
    source_directory = get_source_root_path()
    config_path = os.path.join(source_directory, CONTAINER_ZENML_CONFIG_DIR)

    try:
        # Place a copy of the global configuration (with the active
        # profile) in the build context so the configured stacks are
        # available inside the Azure ML environment.
        GlobalConfiguration().copy_config_with_active_profile(
            config_path,
            load_config_path=f"./{CONTAINER_ZENML_CONFIG_DIR}",
        )
        run_config = ScriptRunConfig(
            source_directory=source_directory,
            environment=self._prepare_environment(
                workspace=workspace,
                requirements=requirements,
                run_name=run_name,
            ),
            compute_target=ComputeTarget(
                workspace=workspace, name=self.compute_target_name
            ),
            command=entrypoint_command,
        )
        experiment = Experiment(workspace=workspace, name=pipeline_name)
        run = experiment.submit(config=run_config)
    finally:
        # The copied configuration is only needed for the submission.
        fileio.rm_dir(config_path)

    run.display_name = run_name
    run.wait_for_completion(show_output=True)
dash
special
DashIntegration (Integration)
Definition of Dash integration for ZenML.
Source code in zenml/integrations/dash/__init__.py
class DashIntegration(Integration):
    """Definition of Dash integration for ZenML.

    Declares the pip requirements: Dash plus its Cytoscape and Bootstrap
    component libraries.
    """

    NAME = DASH
    REQUIREMENTS = [
        "dash>=2.0.0",  # core Dash framework
        "dash-cytoscape>=0.3.0",  # Cytoscape graph component
        "dash-bootstrap-components>=1.0.1",  # Bootstrap layout components
    ]
visualizers
special
pipeline_run_lineage_visualizer
PipelineRunLineageVisualizer (BasePipelineRunVisualizer)
Implementation of a lineage diagram via the dash and dash-cytoscape libraries.
Source code in zenml/integrations/dash/visualizers/pipeline_run_lineage_visualizer.py
class PipelineRunLineageVisualizer(BasePipelineRunVisualizer):
    """Implementation of a lineage diagram via the [dash](
    https://plotly.com/dash/) and [dash-cytoscape](
    https://dash.plotly.com/cytoscape) libraries."""

    ARTIFACT_PREFIX = "artifact_"
    STEP_PREFIX = "step_"
    STATUS_CLASS_MAPPING = {
        ExecutionStatus.CACHED: "green",
        ExecutionStatus.FAILED: "red",
        ExecutionStatus.RUNNING: "yellow",
        ExecutionStatus.COMPLETED: "blue",
    }

    def visualize(
        self, object: PipelineRunView, *args: Any, **kwargs: Any
    ) -> dash.Dash:
        """Method to visualize pipeline runs via the Dash library. The layout
        puts every layer of the dag in a column.

        One Cytoscape node is created per step and per step output artifact.
        Edges go from each step to its outputs and from each input artifact
        to the consuming step. The app is then served with
        `app.run_server()`.

        Args:
            object: The pipeline run to visualize.
            *args: Unused.
            **kwargs: Unused.

        Returns:
            The Dash app, after `app.run_server()` has returned.
        """
        app = dash.Dash(
            __name__,
            external_stylesheets=[
                dbc.themes.BOOTSTRAP,
                dbc.icons.BOOTSTRAP,
            ],
        )
        nodes, edges, first_step_id = [], [], None

        for step in object.steps:
            step_output_artifacts = list(step.outputs.values())
            execution_id = (
                step_output_artifacts[0].producer_step.id
                if step_output_artifacts
                else step.id
            )
            step_id = self.STEP_PREFIX + str(step.id)
            if first_step_id is None:
                first_step_id = step_id
            nodes.append(
                {
                    "data": {
                        "id": step_id,
                        "execution_id": execution_id,
                        "label": f"{execution_id} / {step.entrypoint_name}",
                        "entrypoint_name": step.entrypoint_name,  # redundant for consistency
                        "name": step.name,  # redundant for consistency
                        "type": "step",
                        "parameters": step.parameters,
                        "inputs": {k: v.uri for k, v in step.inputs.items()},
                        "outputs": {k: v.uri for k, v in step.outputs.items()},
                    },
                    "classes": self.STATUS_CLASS_MAPPING[step.status],
                }
            )

            for artifact_name, artifact in step.outputs.items():
                nodes.append(
                    {
                        "data": {
                            "id": self.ARTIFACT_PREFIX + str(artifact.id),
                            "execution_id": artifact.id,
                            "label": f"{artifact.id} / {artifact_name} ("
                            f"{artifact.data_type})",
                            "type": "artifact",
                            "name": artifact_name,
                            "is_cached": artifact.is_cached,
                            "artifact_type": artifact.type,
                            "artifact_data_type": artifact.data_type,
                            "parent_step_id": artifact.parent_step_id,
                            "producer_step_id": artifact.producer_step.id,
                            "uri": artifact.uri,
                        },
                        "classes": f"rectangle "
                        f"{self.STATUS_CLASS_MAPPING[step.status]}",
                    }
                )
                edges.append(
                    {
                        "data": {
                            "source": self.STEP_PREFIX + str(step.id),
                            "target": self.ARTIFACT_PREFIX + str(artifact.id),
                        },
                        "classes": f"edge-arrow "
                        f"{self.STATUS_CLASS_MAPPING[step.status]}"
                        + (" dashed" if artifact.is_cached else " solid"),
                    }
                )

            for artifact_name, artifact in step.inputs.items():
                edges.append(
                    {
                        "data": {
                            "source": self.ARTIFACT_PREFIX + str(artifact.id),
                            "target": self.STEP_PREFIX + str(step.id),
                        },
                        "classes": "edge-arrow "
                        + (
                            f"{self.STATUS_CLASS_MAPPING[ExecutionStatus.CACHED]} dashed"
                            if artifact.is_cached
                            else f"{self.STATUS_CLASS_MAPPING[step.status]} solid"
                        ),
                    }
                )

        cytoscape_layout: Dict[str, Any] = {"name": "breadthfirst"}
        if first_step_id is not None:
            # Root the breadth-first layout at the first step. A run without
            # steps has no node to root at, so no `roots` selector is set
            # (previously this produced the selector '[id = "None"]').
            cytoscape_layout["roots"] = f'[id = "{first_step_id}"]'

        app.layout = dbc.Row(
            [
                dbc.Container(f"Run: {object.name}", class_name="h1"),
                dbc.Row(
                    [
                        dbc.Col(
                            [
                                dbc.Row(
                                    [
                                        html.Span(
                                            [
                                                html.Span(
                                                    [
                                                        html.I(
                                                            className="bi bi-circle-fill me-1"
                                                        ),
                                                        "Step",
                                                    ],
                                                    className="me-2",
                                                ),
                                                html.Span(
                                                    [
                                                        html.I(
                                                            className="bi bi-square-fill me-1"
                                                        ),
                                                        "Artifact",
                                                    ],
                                                    className="me-4",
                                                ),
                                                dbc.Badge(
                                                    "Completed",
                                                    color=COLOR_BLUE,
                                                    className="me-1",
                                                ),
                                                dbc.Badge(
                                                    "Cached",
                                                    color=COLOR_GREEN,
                                                    className="me-1",
                                                ),
                                                dbc.Badge(
                                                    "Running",
                                                    color=COLOR_YELLOW,
                                                    className="me-1",
                                                ),
                                                dbc.Badge(
                                                    "Failed",
                                                    color=COLOR_RED,
                                                    className="me-1",
                                                ),
                                            ]
                                        ),
                                    ]
                                ),
                                dbc.Row(
                                    [
                                        cyto.Cytoscape(
                                            id="cytoscape",
                                            layout=cytoscape_layout,
                                            elements=edges + nodes,
                                            stylesheet=STYLESHEET,
                                            style={
                                                "width": "100%",
                                                "height": "800px",
                                            },
                                            zoom=1,
                                        )
                                    ]
                                ),
                                dbc.Row(
                                    [
                                        dbc.Button(
                                            "Reset",
                                            id="bt-reset",
                                            color="primary",
                                            className="me-1",
                                        )
                                    ]
                                ),
                            ]
                        ),
                        dbc.Col(
                            [
                                dcc.Markdown(id="markdown-selected-node-data"),
                            ]
                        ),
                    ]
                ),
            ],
            className="p-5",
        )

        @app.callback(  # type: ignore[misc]
            Output("markdown-selected-node-data", "children"),
            Input("cytoscape", "selectedNodeData"),
        )
        def display_data(data_list: List[Dict[str, Any]]) -> str:
            """Callback for the text area below the graph.

            Renders the data of every selected node as Markdown; shows a hint
            when nothing is selected (`None` initially, `[]` after
            deselecting).
            """
            if not data_list:
                return "Click on a node in the diagram."

            text = ""
            for data in data_list:
                text += f'## {data["execution_id"]} / {data["name"]}' + "\n\n"
                if data["type"] == "artifact":
                    for item in [
                        "artifact_data_type",
                        "is_cached",
                        "producer_step_id",
                        "parent_step_id",
                        "uri",
                    ]:
                        text += f"**{item}**: {data[item]}" + "\n\n"
                elif data["type"] == "step":
                    text += "### Inputs:" + "\n\n"
                    for k, v in data["inputs"].items():
                        text += f"**{k}**: {v}" + "\n\n"
                    text += "### Outputs:" + "\n\n"
                    for k, v in data["outputs"].items():
                        text += f"**{k}**: {v}" + "\n\n"
                    # The trailing blank line keeps the first parameter from
                    # being rendered as part of the heading.
                    text += "### Params:" + "\n\n"
                    for k, v in data["parameters"].items():
                        text += f"**{k}**: {v}" + "\n\n"
            return text

        @app.callback(  # type: ignore[misc]
            [Output("cytoscape", "zoom"), Output("cytoscape", "elements")],
            [Input("bt-reset", "n_clicks")],
        )
        def reset_layout(
            n_clicks: int,
        ) -> List[Union[int, List[Dict[str, Collection[str]]]]]:
            """Resets the zoom level and the graph elements."""
            # Lazy %-formatting: the message is the format string and
            # `n_clicks` its argument.
            logger.debug("%s clicked in reset button.", n_clicks)
            return [1, edges + nodes]

        app.run_server()
        return app
visualize(self, object, *args, **kwargs)
Method to visualize pipeline runs via the Dash library. The layout puts every layer of the dag in a column.
Source code in zenml/integrations/dash/visualizers/pipeline_run_lineage_visualizer.py
def visualize(
    self, object: PipelineRunView, *args: Any, **kwargs: Any
) -> dash.Dash:
    """Method to visualize pipeline runs via the Dash library.

    Every step of the run becomes a graph node and every output artifact a
    node with the `rectangle` class. Edges go from a step to the artifacts it
    produced and from an artifact to the steps that consume it; edges of
    cached artifacts get the `dashed` class, all others `solid`. The graph
    uses Cytoscape's breadth-first layout rooted at the first step of the
    run. Selecting nodes shows their data in a Markdown panel, and the
    "Reset" button restores the initial zoom level and elements.

    Args:
        object: The pipeline run to visualize.
        *args: Unused.
        **kwargs: Unused.

    Returns:
        The Dash app. `app.run_server()` is called before returning, so this
        call typically blocks until the server is stopped.
    """
    app = dash.Dash(
        __name__,
        external_stylesheets=[
            dbc.themes.BOOTSTRAP,
            dbc.icons.BOOTSTRAP,
        ],
    )
    nodes: List[Dict[str, Any]] = []
    edges: List[Dict[str, Any]] = []
    first_step_id = None
    for step in object.steps:
        step_output_artifacts = list(step.outputs.values())
        # Prefer the id of the step that actually produced the outputs
        # (relevant for cached steps); fall back to this step's own id.
        execution_id = (
            step_output_artifacts[0].producer_step.id
            if step_output_artifacts
            else step.id
        )
        step_id = self.STEP_PREFIX + str(step.id)
        step_status_class = self.STATUS_CLASS_MAPPING[step.status]
        if first_step_id is None:
            first_step_id = step_id
        nodes.append(
            {
                "data": {
                    "id": step_id,
                    "execution_id": execution_id,
                    "label": f"{execution_id} / {step.entrypoint_name}",
                    "entrypoint_name": step.entrypoint_name,  # redundant for consistency
                    "name": step.name,  # redundant for consistency
                    "type": "step",
                    "parameters": step.parameters,
                    "inputs": {k: v.uri for k, v in step.inputs.items()},
                    "outputs": {k: v.uri for k, v in step.outputs.items()},
                },
                "classes": step_status_class,
            }
        )
        for artifact_name, artifact in step.outputs.items():
            artifact_id = self.ARTIFACT_PREFIX + str(artifact.id)
            nodes.append(
                {
                    "data": {
                        "id": artifact_id,
                        "execution_id": artifact.id,
                        "label": f"{artifact.id} / {artifact_name} ("
                        f"{artifact.data_type})",
                        "type": "artifact",
                        "name": artifact_name,
                        "is_cached": artifact.is_cached,
                        "artifact_type": artifact.type,
                        "artifact_data_type": artifact.data_type,
                        "parent_step_id": artifact.parent_step_id,
                        "producer_step_id": artifact.producer_step.id,
                        "uri": artifact.uri,
                    },
                    "classes": f"rectangle {step_status_class}",
                }
            )
            edges.append(
                {
                    "data": {
                        "source": step_id,
                        "target": artifact_id,
                    },
                    "classes": f"edge-arrow {step_status_class}"
                    + (" dashed" if artifact.is_cached else " solid"),
                }
            )
        for artifact_name, artifact in step.inputs.items():
            edges.append(
                {
                    "data": {
                        "source": self.ARTIFACT_PREFIX + str(artifact.id),
                        "target": step_id,
                    },
                    # Cached inputs are always styled as cached, regardless
                    # of the consuming step's status.
                    "classes": "edge-arrow "
                    + (
                        f"{self.STATUS_CLASS_MAPPING[ExecutionStatus.CACHED]} dashed"
                        if artifact.is_cached
                        else f"{step_status_class} solid"
                    ),
                }
            )

    app.layout = dbc.Row(
        [
            dbc.Container(f"Run: {object.name}", class_name="h1"),
            dbc.Row(
                [
                    dbc.Col(
                        [
                            dbc.Row(
                                [
                                    html.Span(
                                        [
                                            html.Span(
                                                [
                                                    html.I(
                                                        className="bi bi-circle-fill me-1"
                                                    ),
                                                    "Step",
                                                ],
                                                className="me-2",
                                            ),
                                            html.Span(
                                                [
                                                    html.I(
                                                        className="bi bi-square-fill me-1"
                                                    ),
                                                    "Artifact",
                                                ],
                                                className="me-4",
                                            ),
                                            dbc.Badge(
                                                "Completed",
                                                color=COLOR_BLUE,
                                                className="me-1",
                                            ),
                                            dbc.Badge(
                                                "Cached",
                                                color=COLOR_GREEN,
                                                className="me-1",
                                            ),
                                            dbc.Badge(
                                                "Running",
                                                color=COLOR_YELLOW,
                                                className="me-1",
                                            ),
                                            dbc.Badge(
                                                "Failed",
                                                color=COLOR_RED,
                                                className="me-1",
                                            ),
                                        ]
                                    ),
                                ]
                            ),
                            dbc.Row(
                                [
                                    cyto.Cytoscape(
                                        id="cytoscape",
                                        layout={
                                            "name": "breadthfirst",
                                            "roots": f'[id = "{first_step_id}"]',
                                        },
                                        elements=edges + nodes,
                                        stylesheet=STYLESHEET,
                                        style={
                                            "width": "100%",
                                            "height": "800px",
                                        },
                                        zoom=1,
                                    )
                                ]
                            ),
                            dbc.Row(
                                [
                                    dbc.Button(
                                        "Reset",
                                        id="bt-reset",
                                        color="primary",
                                        className="me-1",
                                    )
                                ]
                            ),
                        ]
                    ),
                    dbc.Col(
                        [
                            dcc.Markdown(id="markdown-selected-node-data"),
                        ]
                    ),
                ]
            ),
        ],
        className="p-5",
    )

    @app.callback(  # type: ignore[misc]
        Output("markdown-selected-node-data", "children"),
        Input("cytoscape", "selectedNodeData"),
    )
    def display_data(data_list: List[Dict[str, Any]]) -> str:
        """Render the data of the selected nodes as Markdown."""
        # `None` before any selection; presumably `[]` once every node has
        # been deselected. Show the hint in both cases instead of a blank.
        if not data_list:
            return "Click on a node in the diagram."
        parts: List[str] = []
        for data in data_list:
            parts.append(f'## {data["execution_id"]} / {data["name"]}')
            if data["type"] == "artifact":
                for item in (
                    "artifact_data_type",
                    "is_cached",
                    "producer_step_id",
                    "parent_step_id",
                    "uri",
                ):
                    parts.append(f"**{item}**: {data[item]}")
            elif data["type"] == "step":
                for heading, key in (
                    ("### Inputs:", "inputs"),
                    ("### Outputs:", "outputs"),
                    ("### Params:", "parameters"),
                ):
                    parts.append(heading)
                    parts.extend(f"**{k}**: {v}" for k, v in data[key].items())
        # Every heading and entry is its own Markdown paragraph.
        return "".join(part + "\n\n" for part in parts)

    @app.callback(  # type: ignore[misc]
        [Output("cytoscape", "zoom"), Output("cytoscape", "elements")],
        [Input("bt-reset", "n_clicks")],
    )
    def reset_layout(
        n_clicks: int,
    ) -> List[Union[int, List[Dict[str, Collection[str]]]]]:
        """Restore the initial zoom level and graph elements."""
        # The message must be the first argument; values go in lazy args.
        logger.debug("%s clicked in reset button.", n_clicks)
        return [1, edges + nodes]

    app.run_server()
    return app
evidently
special
The Evidently integration provides a way to monitor your models in production. It includes a way to detect data drift and different kinds of model performance issues.
The results of Evidently calculations can either be exported as an interactive dashboard (visualized as an html file or in your Jupyter notebook), or as a JSON file.
EvidentlyIntegration (Integration)
Definition of Evidently integration for ZenML.
Source code in zenml/integrations/evidently/__init__.py
class EvidentlyIntegration(Integration):
    """ZenML integration for [Evidently](https://github.com/evidentlyai/evidently)."""

    # Integration identifier and the pinned pip requirement it installs.
    NAME = EVIDENTLY
    REQUIREMENTS = ["evidently==v0.1.41.dev0"]
steps
special
evidently_profile
EvidentlyProfileConfig (BaseDriftDetectionConfig)
pydantic-model
Config class for Evidently profile steps.
column_mapping: properties of the dataframe's columns used. profile_sections: a list of strings identifying the profile sections to be used. The following are valid options supported by Evidently: "datadrift", "categoricaltargetdrift", "numericaltargetdrift", "classificationmodelperformance", "regressionmodelperformance", "probabilisticmodelperformance".
Source code in zenml/integrations/evidently/steps/evidently_profile.py
class EvidentlyProfileConfig(BaseDriftDetectionConfig):
    """Config class for Evidently profile steps.

    Attributes:
        column_mapping: properties of the dataframe's columns used
        profile_sections: the names of the profile sections to be used.
            The following are valid options supported by Evidently:
            - "datadrift"
            - "categoricaltargetdrift"
            - "numericaltargetdrift"
            - "classificationmodelperformance"
            - "regressionmodelperformance"
            - "probabilisticmodelperformance"
    """

    def get_profile_sections_and_tabs(
        self,
    ) -> Tuple[List[ProfileSection], List[Tab]]:
        """Instantiate the Evidently profile sections and dashboard tabs.

        Returns:
            A tuple of (profile sections, dashboard tabs), with one entry of
            each per name in `profile_sections`, in the same order.

        Raises:
            ValueError: If any name in `profile_sections` is not supported
                by both `profile_mapper` and `dashboard_mapper`.
        """
        # Validate up front: this names the offending entries precisely and
        # keeps a KeyError raised *inside* a constructor from being
        # misreported as an invalid section name.
        supported = [name for name in profile_mapper if name in dashboard_mapper]
        invalid = [
            name for name in self.profile_sections if name not in supported
        ]
        if invalid:
            nl = "\n"
            raise ValueError(
                f"Invalid profile section: {invalid} \n\n"
                f"Valid and supported options are: {nl}- "
                f'{f"{nl}- ".join(supported)}'
            )
        return (
            [profile_mapper[profile]() for profile in self.profile_sections],
            [dashboard_mapper[profile]() for profile in self.profile_sections],
        )

    column_mapping: Optional[ColumnMapping]
    profile_sections: Sequence[str]
EvidentlyProfileStep (BaseDriftDetectionStep)
Simple step implementation which implements Evidently's functionality for creating a profile.
Source code in zenml/integrations/evidently/steps/evidently_profile.py
class EvidentlyProfileStep(BaseDriftDetectionStep):
    """Simple step implementation which implements Evidently's functionality for
    creating a profile."""

    OUTPUT_SPEC = {
        "profile": DataAnalysisArtifact,
        "dashboard": DataAnalysisArtifact,
    }

    def entrypoint(  # type: ignore[override]
        self,
        reference_dataset: pd.DataFrame,
        comparison_dataset: pd.DataFrame,
        config: EvidentlyProfileConfig,
    ) -> Output(  # type:ignore[valid-type]
        profile=dict, dashboard=str
    ):
        """Main entrypoint for the Evidently profile step.

        Computes an Evidently Dashboard and Profile for the profile sections
        selected in `config.profile_sections`, comparing
        `comparison_dataset` against `reference_dataset`.

        Args:
            reference_dataset: a Pandas dataframe
            comparison_dataset: a Pandas dataframe of new data you wish to
                compare against the reference data
            config: the configuration for the step

        Returns:
            profile: dictionary report extracted from an Evidently Profile
                generated for the selected profile sections
            dashboard: HTML report extracted from an Evidently Dashboard
                generated for the selected profile sections
        """
        sections, tabs = config.get_profile_sections_and_tabs()
        # An empty mapping is passed to Evidently as None.
        column_mapping = config.column_mapping or None

        dashboard = Dashboard(tabs=tabs)
        dashboard.calculate(
            reference_dataset,
            comparison_dataset,
            column_mapping=column_mapping,
        )
        profile = Profile(sections=sections)
        profile.calculate(
            reference_dataset,
            comparison_dataset,
            column_mapping=column_mapping,
        )
        return [profile.object(), dashboard.html()]
CONFIG_CLASS (BaseDriftDetectionConfig)
pydantic-model
Config class for Evidently profile steps.
column_mapping: properties of the dataframe's columns used. profile_sections: a list of strings identifying the profile sections to be used. The following are valid options supported by Evidently: "datadrift", "categoricaltargetdrift", "numericaltargetdrift", "classificationmodelperformance", "regressionmodelperformance", "probabilisticmodelperformance".
Source code in zenml/integrations/evidently/steps/evidently_profile.py
class EvidentlyProfileConfig(BaseDriftDetectionConfig):
    """Config class for Evidently profile steps.

    Attributes:
        column_mapping: properties of the dataframe's columns used
        profile_sections: the names of the profile sections to be used.
            The following are valid options supported by Evidently:
            - "datadrift"
            - "categoricaltargetdrift"
            - "numericaltargetdrift"
            - "classificationmodelperformance"
            - "regressionmodelperformance"
            - "probabilisticmodelperformance"
    """

    def get_profile_sections_and_tabs(
        self,
    ) -> Tuple[List[ProfileSection], List[Tab]]:
        """Instantiate the Evidently profile sections and dashboard tabs.

        Returns:
            A tuple of (profile sections, dashboard tabs), with one entry of
            each per name in `profile_sections`, in the same order.

        Raises:
            ValueError: If any name in `profile_sections` is not supported
                by both `profile_mapper` and `dashboard_mapper`.
        """
        # Validate up front: this names the offending entries precisely and
        # keeps a KeyError raised *inside* a constructor from being
        # misreported as an invalid section name.
        supported = [name for name in profile_mapper if name in dashboard_mapper]
        invalid = [
            name for name in self.profile_sections if name not in supported
        ]
        if invalid:
            nl = "\n"
            raise ValueError(
                f"Invalid profile section: {invalid} \n\n"
                f"Valid and supported options are: {nl}- "
                f'{f"{nl}- ".join(supported)}'
            )
        return (
            [profile_mapper[profile]() for profile in self.profile_sections],
            [dashboard_mapper[profile]() for profile in self.profile_sections],
        )

    column_mapping: Optional[ColumnMapping]
    profile_sections: Sequence[str]
entrypoint(self, reference_dataset, comparison_dataset, config)
Main entrypoint for the Evidently profile step.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
reference_dataset |
DataFrame |
a Pandas dataframe |
required |
comparison_dataset |
DataFrame |
a Pandas dataframe of new data you wish to compare against the reference data |
required |
config |
EvidentlyProfileConfig |
the configuration for the step |
required |
Returns:
Type | Description |
---|---|
profile |
dictionary report extracted from an Evidently Profile generated for the selected profile sections |
dashboard |
HTML report extracted from an Evidently Dashboard generated for the selected profile sections |
Source code in zenml/integrations/evidently/steps/evidently_profile.py
def entrypoint(  # type: ignore[override]
    self,
    reference_dataset: pd.DataFrame,
    comparison_dataset: pd.DataFrame,
    config: EvidentlyProfileConfig,
) -> Output(  # type:ignore[valid-type]
    profile=dict, dashboard=str
):
    """Main entrypoint for the Evidently profile step.

    Computes an Evidently Dashboard and Profile for the profile sections
    selected in `config.profile_sections`, comparing `comparison_dataset`
    against `reference_dataset`.

    Args:
        reference_dataset: a Pandas dataframe
        comparison_dataset: a Pandas dataframe of new data you wish to
            compare against the reference data
        config: the configuration for the step

    Returns:
        profile: dictionary report extracted from an Evidently Profile
            generated for the selected profile sections
        dashboard: HTML report extracted from an Evidently Dashboard
            generated for the selected profile sections
    """
    sections, tabs = config.get_profile_sections_and_tabs()
    # An empty mapping is passed to Evidently as None.
    column_mapping = config.column_mapping or None

    dashboard = Dashboard(tabs=tabs)
    dashboard.calculate(
        reference_dataset,
        comparison_dataset,
        column_mapping=column_mapping,
    )
    profile = Profile(sections=sections)
    profile.calculate(
        reference_dataset,
        comparison_dataset,
        column_mapping=column_mapping,
    )
    return [profile.object(), dashboard.html()]
visualizers
special
evidently_visualizer
EvidentlyVisualizer (BaseStepVisualizer)
The implementation of an Evidently Visualizer.
Source code in zenml/integrations/evidently/visualizers/evidently_visualizer.py
class EvidentlyVisualizer(BaseStepVisualizer):
    """The implementation of an Evidently Visualizer."""

    # NOTE: deliberately not decorated with `@abstractmethod`. This is the
    # concrete implementation; marking it abstract would make the class
    # uninstantiable if the base class ever uses `ABCMeta`.
    def visualize(self, object: StepView, *args: Any, **kwargs: Any) -> None:
        """Display every Evidently HTML report produced by a step.

        Args:
            object: StepView fetched from run.get_step().
            *args: Unused.
            **kwargs: Unused.
        """
        for artifact_view in object.outputs.values():
            # Only string-typed DataAnalysisArtifacts (the HTML dashboard)
            # are rendered; e.g. the dict-typed profile output is skipped.
            if (
                artifact_view.type == DataAnalysisArtifact.__name__
                and artifact_view.data_type == "builtins.str"
            ):
                artifact = artifact_view.read()
                self.generate_facet(artifact)

    def generate_facet(self, html_: str) -> None:
        """Display an HTML report.

        Inside a Jupyter notebook the HTML is rendered inline. Otherwise it
        is written to a temporary `.html` file, which is not deleted
        afterwards, and opened in a new browser window.

        Args:
            html_: HTML represented as a string.
        """
        if Environment.in_notebook():
            # `IPython.display` is the public location of these helpers;
            # importing them from `IPython.core.display` is deprecated.
            from IPython.display import HTML, display

            display(HTML(html_))
            return

        logger.warning(
            "The magic functions are only usable in a Jupyter notebook."
        )
        with tempfile.NamedTemporaryFile(
            mode="w", delete=False, suffix=".html", encoding="utf-8"
        ) as f:
            f.write(html_)
        # Launch the browser only after the `with` block has closed (and so
        # flushed) the file; otherwise it may read an incomplete report.
        url = f"file:///{f.name}"
        logger.info("Opening %s in a new browser..", f.name)
        webbrowser.open(url, new=2)
generate_facet(self, html_)
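Display an HTML report: render it inline in a Jupyter notebook, otherwise write it to a temporary file and open it in a browser.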
Parameters:
Name | Type | Description | Default |
---|---|---|---|
html_ |
str |
HTML represented as a string. |
required |
Source code in zenml/integrations/evidently/visualizers/evidently_visualizer.py
def generate_facet(self, html_: str) -> None:
    """Display an HTML report.

    Inside a Jupyter notebook the HTML is rendered inline. Otherwise it is
    written to a temporary `.html` file, which is not deleted afterwards,
    and opened in a new browser window.

    Args:
        html_: HTML represented as a string.
    """
    if Environment.in_notebook():
        # `IPython.display` is the public location of these helpers;
        # importing them from `IPython.core.display` is deprecated.
        from IPython.display import HTML, display

        display(HTML(html_))
        return

    logger.warning(
        "The magic functions are only usable in a Jupyter notebook."
    )
    with tempfile.NamedTemporaryFile(
        mode="w", delete=False, suffix=".html", encoding="utf-8"
    ) as f:
        f.write(html_)
    # Launch the browser only after the `with` block has closed (and so
    # flushed) the file; otherwise it may read an incomplete report.
    url = f"file:///{f.name}"
    logger.info("Opening %s in a new browser..", f.name)
    webbrowser.open(url, new=2)
visualize(self, object, *args, **kwargs)
Method to visualize components
Parameters:
Name | Type | Description | Default |
---|---|---|---|
object |
StepView |
StepView fetched from run.get_step(). |
required |
Source code in zenml/integrations/evidently/visualizers/evidently_visualizer.py
# NOTE: deliberately not decorated with `@abstractmethod`. This is the
# concrete implementation; marking it abstract would make the class
# uninstantiable if the base class ever uses `ABCMeta`.
def visualize(self, object: StepView, *args: Any, **kwargs: Any) -> None:
    """Display every Evidently HTML report produced by a step.

    Args:
        object: StepView fetched from run.get_step().
        *args: Unused.
        **kwargs: Unused.
    """
    for artifact_view in object.outputs.values():
        # Only string-typed DataAnalysisArtifacts (the HTML dashboard) are
        # rendered; e.g. the dict-typed profile output is skipped.
        if (
            artifact_view.type == DataAnalysisArtifact.__name__
            and artifact_view.data_type == "builtins.str"
        ):
            artifact = artifact_view.read()
            self.generate_facet(artifact)
facets
special
The Facets integration provides a simple way to visualize post-execution objects like `PipelineView`, `PipelineRunView` and `StepView`. These objects can be extended using the `BaseVisualization` class. This integration requires `facets-overview` to be installed in your Python environment.
FacetsIntegration (Integration)
Definition of Facet integration for ZenML.
Source code in zenml/integrations/facets/__init__.py
class FacetsIntegration(Integration):
    """ZenML integration for [Facet](https://pair-code.github.io/facets/)."""

    # Integration identifier and the pip requirements it installs.
    NAME = FACETS
    REQUIREMENTS = ["facets-overview>=1.0.0", "IPython"]
visualizers
special
facet_statistics_visualizer
FacetStatisticsVisualizer (BaseStepVisualizer)
Visualizes statistics of a step's pandas DataFrame outputs using Facets Overview.
Source code in zenml/integrations/facets/visualizers/facet_statistics_visualizer.py
class FacetStatisticsVisualizer(BaseStepVisualizer):
    """Visualizes statistics of a step's pandas DataFrame outputs with Facets."""

    # NOTE: deliberately not decorated with `@abstractmethod`. This is the
    # concrete implementation; marking it abstract would make the class
    # uninstantiable if the base class ever uses `ABCMeta`.
    def visualize(
        self, object: StepView, magic: bool = False, *args: Any, **kwargs: Any
    ) -> None:
        """Visualize statistics of all pandas DataFrame outputs of a step.

        Outputs that are not DataFrames are skipped with a warning.

        Args:
            object: StepView fetched from run.get_step().
            magic: Whether to render in a Jupyter notebook or not.
            *args: Unused.
            **kwargs: Unused.
        """
        datasets = []
        for output_name, artifact_view in object.outputs.items():
            df = artifact_view.read()
            # isinstance also accepts DataFrame subclasses.
            if isinstance(df, pd.DataFrame):
                datasets.append({"name": output_name, "table": df})
            else:
                logger.warning(
                    "`%s` is not a pd.DataFrame. You can only visualize "
                    "statistics of steps that output pandas dataframes. "
                    "Skipping this output..",
                    output_name,
                )
        h = self.generate_html(datasets)
        self.generate_facet(h, magic)

    def generate_html(self, datasets: List[Dict[Text, pd.DataFrame]]) -> str:
        """Generates html for facet.

        Args:
            datasets: List of dicts of dataframes to be visualized as stats.

        Returns:
            HTML template with proto string embedded.
        """
        proto = GenericFeatureStatisticsGenerator().ProtoFromDataFrames(
            datasets
        )
        protostr = base64.b64encode(proto.SerializeToString()).decode("utf-8")
        template = os.path.join(
            os.path.abspath(os.path.dirname(__file__)),
            "stats.html",
        )
        html_template = zenml.io.utils.read_file_contents_as_string(template)
        # The bundled template contains the literal placeholder `protostr`.
        html_ = html_template.replace("protostr", protostr)
        return html_

    def generate_facet(self, html_: str, magic: bool = False) -> None:
        """Generate a Facet Overview.

        Args:
            html_: HTML represented as a string.
            magic: Whether to magically materialize facet in a notebook.

        Raises:
            EnvironmentError: If `magic` is set outside a Jupyter notebook.
        """
        if magic:
            if not Environment.in_notebook():
                raise EnvironmentError(
                    "The magic functions are only usable in a Jupyter notebook."
                )
            display(HTML(html_))
            return

        # Only reserve a file name here and close our handle before writing
        # by name: reopening a still-open NamedTemporaryFile by name is not
        # portable (notably on Windows).
        with tempfile.NamedTemporaryFile(delete=False, suffix=".html") as f:
            path = f.name
        zenml.io.utils.write_file_contents_as_string(path, html_)
        url = f"file:///{path}"
        logger.info("Opening %s in a new browser..", path)
        webbrowser.open(url, new=2)
generate_facet(self, html_, magic=False)
Generate a Facet Overview
Parameters:
Name | Type | Description | Default |
---|---|---|---|
html_ |
str |
HTML represented as a string. |
required |
magic |
bool |
Whether to magically materialize facet in a notebook. |
False |
Source code in zenml/integrations/facets/visualizers/facet_statistics_visualizer.py
def generate_facet(self, html_: str, magic: bool = False) -> None:
    """Generate a Facet Overview.

    Args:
        html_: HTML represented as a string.
        magic: Whether to magically materialize facet in a notebook.

    Raises:
        EnvironmentError: If `magic` is set outside a Jupyter notebook.
    """
    if magic:
        if not Environment.in_notebook():
            raise EnvironmentError(
                "The magic functions are only usable in a Jupyter notebook."
            )
        display(HTML(html_))
        return

    # Only reserve a file name here and close our handle before writing by
    # name: reopening a still-open NamedTemporaryFile by name is not
    # portable (notably on Windows).
    with tempfile.NamedTemporaryFile(delete=False, suffix=".html") as f:
        path = f.name
    zenml.io.utils.write_file_contents_as_string(path, html_)
    url = f"file:///{path}"
    logger.info("Opening %s in a new browser..", path)
    webbrowser.open(url, new=2)
generate_html(self, datasets)
Generates html for facet.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
datasets |
List[Dict[str, pandas.core.frame.DataFrame]] |
List of dicts of dataframes to be visualized as stats. |
required |
Returns:
Type | Description |
---|---|
str |
HTML template with proto string embedded. |
Source code in zenml/integrations/facets/visualizers/facet_statistics_visualizer.py
def generate_html(self, datasets: List[Dict[Text, pd.DataFrame]]) -> str:
    """Build the Facets Overview HTML page for some dataframes.

    Args:
        datasets: List of dicts of dataframes to be visualized as stats.

    Returns:
        The bundled `stats.html` template with the base64-encoded statistics
        proto substituted for its `protostr` placeholder.
    """
    generator = GenericFeatureStatisticsGenerator()
    statistics = generator.ProtoFromDataFrames(datasets)
    encoded = base64.b64encode(statistics.SerializeToString()).decode("utf-8")

    here = os.path.abspath(os.path.dirname(__file__))
    template_path = os.path.join(here, "stats.html")
    template_html = zenml.io.utils.read_file_contents_as_string(template_path)
    return template_html.replace("protostr", encoded)
visualize(self, object, magic=False, *args, **kwargs)
Method to visualize components
Parameters:
Name | Type | Description | Default |
---|---|---|---|
object |
StepView |
StepView fetched from run.get_step(). |
required |
magic |
bool |
Whether to render in a Jupyter notebook or not. |
False |
Source code in zenml/integrations/facets/visualizers/facet_statistics_visualizer.py
# NOTE: deliberately not decorated with `@abstractmethod`. This is the
# concrete implementation; marking it abstract would make the class
# uninstantiable if the base class ever uses `ABCMeta`.
def visualize(
    self, object: StepView, magic: bool = False, *args: Any, **kwargs: Any
) -> None:
    """Visualize statistics of all pandas DataFrame outputs of a step.

    Outputs that are not DataFrames are skipped with a warning.

    Args:
        object: StepView fetched from run.get_step().
        magic: Whether to render in a Jupyter notebook or not.
        *args: Unused.
        **kwargs: Unused.
    """
    datasets = []
    for output_name, artifact_view in object.outputs.items():
        df = artifact_view.read()
        # isinstance also accepts DataFrame subclasses.
        if isinstance(df, pd.DataFrame):
            datasets.append({"name": output_name, "table": df})
        else:
            logger.warning(
                "`%s` is not a pd.DataFrame. You can only visualize "
                "statistics of steps that output pandas dataframes. "
                "Skipping this output..",
                output_name,
            )
    h = self.generate_html(datasets)
    self.generate_facet(h, magic)
gcp
special
The GCP integration submodule provides a way to run ZenML pipelines in a cloud
environment. Specifically, it allows the use of cloud artifact stores, metadata
stores, and an `io` module to handle file operations on Google Cloud Storage (GCS).
GcpIntegration (Integration)
Definition of Google Cloud Platform integration for ZenML.
Source code in zenml/integrations/gcp/__init__.py
class GcpIntegration(Integration):
    """ZenML integration for Google Cloud Platform."""

    NAME = GCP
    # gcsfs backs the `gs://` filesystem plugin of this integration.
    REQUIREMENTS = ["gcsfs"]

    @classmethod
    def activate(cls) -> None:
        """Import the GCP artifact store and IO modules to activate them."""
        # These imports are needed only for their import-time side effects.
        from zenml.integrations.gcp import artifact_stores  # noqa
        from zenml.integrations.gcp import io  # noqa
activate()
classmethod
Activates the integration.
Source code in zenml/integrations/gcp/__init__.py
@classmethod
def activate(cls) -> None:
    """Import the GCP artifact store and IO modules to activate them."""
    # These imports are needed only for their import-time side effects.
    from zenml.integrations.gcp import artifact_stores  # noqa
    from zenml.integrations.gcp import io  # noqa
artifact_stores
special
gcp_artifact_store
GCPArtifactStore (BaseArtifactStore)
pydantic-model
Artifact Store for Google Cloud Storage based artifacts.
Source code in zenml/integrations/gcp/artifact_stores/gcp_artifact_store.py
class GCPArtifactStore(BaseArtifactStore):
    """Artifact store that keeps artifacts in Google Cloud Storage."""

    supports_local_execution = True
    supports_remote_execution = True

    @property
    def flavor(self) -> ArtifactStoreFlavor:
        """The artifact store flavor (always GCP)."""
        return ArtifactStoreFlavor.GCP

    @validator("path")
    def ensure_gcs_path(cls, path: str) -> str:
        """Accept `path` only if it is a GCS URI, i.e. starts with `gs://`.

        Raises:
            ValueError: If `path` does not start with `gs://`.
        """
        if path.startswith("gs://"):
            return path
        raise ValueError(
            f"Path '{path}' specified for GCPArtifactStore is not a "
            f"valid gcs path, i.e., starting with `gs://`."
        )
flavor: ArtifactStoreFlavor
property
readonly
The artifact store flavor.
ensure_gcs_path(path)
classmethod
Ensures that the path is a valid gcs path.
Source code in zenml/integrations/gcp/artifact_stores/gcp_artifact_store.py
@validator("path")
def ensure_gcs_path(cls, path: str) -> str:
    """Accept `path` only if it is a GCS URI, i.e. starts with `gs://`.

    Raises:
        ValueError: If `path` does not start with `gs://`.
    """
    if path.startswith("gs://"):
        return path
    raise ValueError(
        f"Path '{path}' specified for GCPArtifactStore is not a "
        f"valid gcs path, i.e., starting with `gs://`."
    )
io
special
gcs_plugin
Plugin that adds Google Cloud Storage support to ZenML. It inherits from the base Filesystem created by TFX and overrides the corresponding functions using gcsfs.
ZenGCS (Filesystem)
Filesystem that delegates to Google Cloud Storage using gcsfs.
Note: To allow TFX to check for various error conditions, we need to raise their custom `NotFoundError` instead of the builtin Python `FileNotFoundError`.
Source code in zenml/integrations/gcp/io/gcs_plugin.py
class ZenGCS(Filesystem):
    """Filesystem that delegates to Google Cloud Storage using gcsfs.

    **Note**: To allow TFX to check for various error conditions, we need to
    raise their custom `NotFoundError` instead of the builtin python
    FileNotFoundError."""

    SUPPORTED_SCHEMES = ["gs://"]

    # Shared gcsfs client, created lazily by `_ensure_filesystem_set`.
    fs: gcsfs.GCSFileSystem = None

    @classmethod
    def _ensure_filesystem_set(cls) -> None:
        """Ensures that the filesystem is set."""
        if ZenGCS.fs is None:
            ZenGCS.fs = gcsfs.GCSFileSystem()

    @staticmethod
    def _with_gs_scheme(path: str) -> str:
        """Prefix `path` with `gs://` unless it already has that scheme.

        gcsfs (like other fsspec filesystems) returns paths with the
        protocol stripped. The scheme is re-added so that callers which
        dispatch on `SUPPORTED_SCHEMES` route the paths back here.
        """
        return path if path.startswith("gs://") else f"gs://{path}"

    @staticmethod
    def open(path: PathType, mode: str = "r") -> Any:
        """Open a file at the given path.

        Args:
            path: Path of the file to open.
            mode: Mode in which to open the file; passed unchanged to
                `gcsfs.GCSFileSystem.open`.

        Raises:
            NotFoundError: If the file does not exist.
        """
        ZenGCS._ensure_filesystem_set()
        try:
            return ZenGCS.fs.open(path=path, mode=mode)
        except FileNotFoundError as e:
            raise NotFoundError() from e

    @staticmethod
    def copy(src: PathType, dst: PathType, overwrite: bool = False) -> None:
        """Copy a file.

        Args:
            src: The path to copy from.
            dst: The path to copy to.
            overwrite: If a file already exists at the destination, this
                method will overwrite it if overwrite=`True` and
                raise a FileExistsError otherwise.

        Raises:
            NotFoundError: If the source file does not exist.
            FileExistsError: If a file already exists at the destination
                and overwrite is not set to `True`.
        """
        ZenGCS._ensure_filesystem_set()
        if not overwrite and ZenGCS.fs.exists(dst):
            raise FileExistsError(
                f"Unable to copy to destination '{convert_to_str(dst)}', "
                f"file already exists. Set `overwrite=True` to copy anyway."
            )

        # TODO [ENG-151]: Check if it works with overwrite=True or if we need to
        #  manually remove it first
        try:
            ZenGCS.fs.copy(path1=src, path2=dst)
        except FileNotFoundError as e:
            raise NotFoundError() from e

    @staticmethod
    def exists(path: PathType) -> bool:
        """Check whether a path exists."""
        ZenGCS._ensure_filesystem_set()
        return ZenGCS.fs.exists(path=path)  # type: ignore[no-any-return]

    @staticmethod
    def glob(pattern: PathType) -> List[PathType]:
        """Return all paths that match the given glob pattern.

        The glob pattern may include:
        - '*' to match any number of characters
        - '?' to match a single character
        - '[...]' to match one of the characters inside the brackets
        - '**' as the full name of a path component to match to search
          in subdirectories of any depth (e.g. '/some_dir/**/some_file)

        Args:
            pattern: The glob pattern to match, see details above.

        Returns:
            A list of `gs://` paths that match the given glob pattern.
        """
        ZenGCS._ensure_filesystem_set()
        return [
            ZenGCS._with_gs_scheme(match)
            for match in ZenGCS.fs.glob(path=pattern)
        ]

    @staticmethod
    def isdir(path: PathType) -> bool:
        """Check whether a path is a directory."""
        ZenGCS._ensure_filesystem_set()
        return ZenGCS.fs.isdir(path=path)  # type: ignore[no-any-return]

    @staticmethod
    def listdir(path: PathType) -> List[PathType]:
        """Return the names of the entries in a directory.

        Args:
            path: The directory to list.

        Returns:
            The base names (not full paths) of the files and directories
            directly inside `path`.

        Raises:
            NotFoundError: If `path` does not exist.
        """
        ZenGCS._ensure_filesystem_set()
        directory = convert_to_str(path)
        if directory.startswith("gs://"):
            directory = directory[len("gs://") :]
        directory = directory.rstrip("/")
        try:
            # fsspec's `listdir` is `ls(detail=True)` and yields info dicts;
            # request plain (scheme-less) paths instead.
            entries = ZenGCS.fs.ls(convert_to_str(path), detail=False)
        except FileNotFoundError as e:
            raise NotFoundError() from e
        names: List[PathType] = []
        for entry in entries:
            entry_path = entry.rstrip("/")
            # GCS may list a directory's placeholder object as an entry of
            # the directory itself; skip it.
            if entry_path == directory:
                continue
            names.append(entry_path.rsplit("/", 1)[-1])
        return names

    @staticmethod
    def makedirs(path: PathType) -> None:
        """Create a directory at the given path. If needed also
        create missing parent directories."""
        ZenGCS._ensure_filesystem_set()
        ZenGCS.fs.makedirs(path=path, exist_ok=True)

    @staticmethod
    def mkdir(path: PathType) -> None:
        """Create a directory at the given path."""
        ZenGCS._ensure_filesystem_set()
        ZenGCS.fs.makedir(path=path)

    @staticmethod
    def remove(path: PathType) -> None:
        """Remove the file at the given path.

        Raises:
            NotFoundError: If the file does not exist.
        """
        ZenGCS._ensure_filesystem_set()
        try:
            ZenGCS.fs.rm_file(path=path)
        except FileNotFoundError as e:
            raise NotFoundError() from e

    @staticmethod
    def rename(src: PathType, dst: PathType, overwrite: bool = False) -> None:
        """Rename source file to destination file.

        Args:
            src: The path of the file to rename.
            dst: The path to rename the source file to.
            overwrite: If a file already exists at the destination, this
                method will overwrite it if overwrite=`True` and
                raise a FileExistsError otherwise.

        Raises:
            NotFoundError: If the source file does not exist.
            FileExistsError: If a file already exists at the destination
                and overwrite is not set to `True`.
        """
        ZenGCS._ensure_filesystem_set()
        if not overwrite and ZenGCS.fs.exists(dst):
            raise FileExistsError(
                f"Unable to rename file to '{convert_to_str(dst)}', "
                f"file already exists. Set `overwrite=True` to rename anyway."
            )

        # TODO [ENG-152]: Check if it works with overwrite=True or if we need
        #  to manually remove it first
        try:
            ZenGCS.fs.rename(path1=src, path2=dst)
        except FileNotFoundError as e:
            raise NotFoundError() from e

    @staticmethod
    def rmtree(path: PathType) -> None:
        """Remove the given directory.

        Raises:
            NotFoundError: If the directory does not exist.
        """
        ZenGCS._ensure_filesystem_set()
        try:
            ZenGCS.fs.delete(path=path, recursive=True)
        except FileNotFoundError as e:
            raise NotFoundError() from e

    @staticmethod
    def stat(path: PathType) -> Dict[str, Any]:
        """Return stat info for the given path.

        Raises:
            NotFoundError: If the path does not exist.
        """
        ZenGCS._ensure_filesystem_set()
        try:
            return ZenGCS.fs.stat(path=path)  # type: ignore[no-any-return]
        except FileNotFoundError as e:
            raise NotFoundError() from e

    @staticmethod
    def walk(
        top: PathType,
        topdown: bool = True,
        onerror: Optional[Callable[..., None]] = None,
    ) -> Iterable[Tuple[PathType, List[PathType], List[PathType]]]:
        """Return an iterator that walks the contents of the given directory.

        Args:
            top: Path of directory to walk.
            topdown: Unused argument to conform to interface.
            onerror: Unused argument to conform to interface.

        Returns:
            An Iterable of Tuples, each of which contain the `gs://` path of
            the current directory, a list of directories inside the current
            directory and a list of files inside the current directory.
        """
        ZenGCS._ensure_filesystem_set()
        # TODO [ENG-153]: Additional params
        return (
            (ZenGCS._with_gs_scheme(root), dirs, files)
            for root, dirs, files in ZenGCS.fs.walk(path=top)
        )
copy(src, dst, overwrite=False)
staticmethod
Copy a file.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
src |
Union[bytes, str] |
The path to copy from. |
required |
dst |
Union[bytes, str] |
The path to copy to. |
required |
overwrite |
bool |
If a file already exists at the destination, this method will overwrite it if overwrite=`True` and raise a FileExistsError otherwise. |
False |
Exceptions:
Type | Description |
---|---|
NotFoundError |
If the source file does not exist. |
FileExistsError |
If a file already exists at the destination and overwrite is not set to `True`. |
Source code in zenml/integrations/gcp/io/gcs_plugin.py
@staticmethod
def copy(src: PathType, dst: PathType, overwrite: bool = False) -> None:
    """Copy a file.

    Args:
        src: The path to copy from.
        dst: The path to copy to.
        overwrite: If a file already exists at the destination, this
            method will overwrite it if overwrite=`True` and
            raise a FileExistsError otherwise.

    Raises:
        NotFoundError: If the source file does not exist (TFX's error type,
            raised in place of the builtin FileNotFoundError).
        FileExistsError: If a file already exists at the destination
            and overwrite is not set to `True`.
    """
    ZenGCS._ensure_filesystem_set()
    if not overwrite and ZenGCS.fs.exists(dst):
        raise FileExistsError(
            f"Unable to copy to destination '{convert_to_str(dst)}', "
            f"file already exists. Set `overwrite=True` to copy anyway."
        )

    # TODO [ENG-151]: Check if it works with overwrite=True or if we need to
    #  manually remove it first
    try:
        ZenGCS.fs.copy(path1=src, path2=dst)
    except FileNotFoundError as e:
        raise NotFoundError() from e
exists(path)
staticmethod
Check whether a path exists.
Source code in zenml/integrations/gcp/io/gcs_plugin.py
@staticmethod
def exists(path: PathType) -> bool:
    """Return whether `path` exists in Google Cloud Storage."""
    ZenGCS._ensure_filesystem_set()
    found = ZenGCS.fs.exists(path=path)
    return found  # type: ignore[no-any-return]
glob(pattern)
staticmethod
Return all paths that match the given glob pattern. The glob pattern may include: `*` to match any number of characters; `?` to match a single character; `[...]` to match one of the characters inside the brackets; `**` as the full name of a path component to search in subdirectories of any depth (e.g. `/some_dir/**/some_file`).
Parameters:
Name | Type | Description | Default |
---|---|---|---|
pattern |
Union[bytes, str] |
The glob pattern to match, see details above. |
required |
Returns:
Type | Description |
---|---|
List[Union[bytes, str]] |
A list of paths that match the given glob pattern. |
Source code in zenml/integrations/gcp/io/gcs_plugin.py
@staticmethod
def glob(pattern: PathType) -> List[PathType]:
    """Return all paths that match the given glob pattern.

    The glob pattern may include:
    - '*' to match any number of characters
    - '?' to match a single character
    - '[...]' to match one of the characters inside the brackets
    - '**' as the full name of a path component to match to search
      in subdirectories of any depth (e.g. '/some_dir/**/some_file)

    Args:
        pattern: The glob pattern to match, see details above.

    Returns:
        A list of `gs://` paths that match the given glob pattern.
    """
    ZenGCS._ensure_filesystem_set()
    # gcsfs returns matches with the protocol stripped; re-add `gs://` so
    # that callers dispatching on SUPPORTED_SCHEMES route them back here.
    return [
        match if match.startswith("gs://") else f"gs://{match}"
        for match in ZenGCS.fs.glob(path=pattern)
    ]
isdir(path)
staticmethod
Check whether a path is a directory.
Source code in zenml/integrations/gcp/io/gcs_plugin.py
@staticmethod
def isdir(path: PathType) -> bool:
    """Return whether `path` refers to a directory.

    Args:
        path: The path to check.

    Returns:
        The result of `isdir` on the filesystem stored in `ZenGCS.fs`.
    """
    ZenGCS._ensure_filesystem_set()
    filesystem = ZenGCS.fs
    return filesystem.isdir(path=path)  # type: ignore[no-any-return]
listdir(path)
staticmethod
Return a list of files in a directory.
Source code in zenml/integrations/gcp/io/gcs_plugin.py
@staticmethod
def listdir(path: PathType) -> List[PathType]:
    """Return the entries of the directory at `path`.

    Args:
        path: The directory to list.

    Returns:
        Whatever the underlying filesystem's `listdir` returns.

    Raises:
        NotFoundError: If the filesystem raises FileNotFoundError.
    """
    ZenGCS._ensure_filesystem_set()
    filesystem = ZenGCS.fs
    try:
        entries = filesystem.listdir(path=path)
    except FileNotFoundError as e:
        raise NotFoundError() from e
    # NOTE(review): fsspec-style `listdir` implementations may return detail
    # dicts instead of plain path strings, which would not match the
    # annotated return type -- confirm against the installed filesystem.
    return entries  # type: ignore[no-any-return]
makedirs(path)
staticmethod
Create a directory at the given path. If needed also create missing parent directories.
Source code in zenml/integrations/gcp/io/gcs_plugin.py
@staticmethod
def makedirs(path: PathType) -> None:
    """Create the directory at `path`, including missing parents.

    Already-existing directories are accepted (`exist_ok=True`).

    Args:
        path: The directory to create.
    """
    ZenGCS._ensure_filesystem_set()
    filesystem = ZenGCS.fs
    filesystem.makedirs(path=path, exist_ok=True)
mkdir(path)
staticmethod
Create a directory at the given path.
Source code in zenml/integrations/gcp/io/gcs_plugin.py
@staticmethod
def mkdir(path: PathType) -> None:
    """Create a single directory at `path`.

    Args:
        path: The directory to create.
    """
    ZenGCS._ensure_filesystem_set()
    filesystem = ZenGCS.fs
    filesystem.makedir(path=path)
open(path, mode='r')
staticmethod
Open a file at the given path.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
path |
Union[bytes, str] |
Path of the file to open. |
required |
mode |
str |
Mode in which to open the file. Currently only 'rb' and 'wb' to read and write binary files are supported. |
'r' |
Source code in zenml/integrations/gcp/io/gcs_plugin.py
@staticmethod
def open(path: PathType, mode: str = "r") -> Any:
    """Open the file at `path` and return the resulting file object.

    Args:
        path: Path of the file to open.
        mode: Mode in which to open the file. Currently only
            'rb' and 'wb' to read and write binary files are supported.

    Returns:
        The file-like object returned by the underlying filesystem.

    Raises:
        NotFoundError: If the filesystem raises FileNotFoundError.
    """
    # NOTE(review): the default mode is 'r' although only 'rb'/'wb' are
    # documented as supported -- confirm which modes the filesystem accepts.
    ZenGCS._ensure_filesystem_set()
    filesystem = ZenGCS.fs
    try:
        handle = filesystem.open(path=path, mode=mode)
    except FileNotFoundError as e:
        raise NotFoundError() from e
    return handle
remove(path)
staticmethod
Remove the file at the given path.
Source code in zenml/integrations/gcp/io/gcs_plugin.py
@staticmethod
def remove(path: PathType) -> None:
    """Delete the single file at `path`.

    Args:
        path: The file to delete.

    Raises:
        NotFoundError: If the filesystem raises FileNotFoundError.
    """
    ZenGCS._ensure_filesystem_set()
    filesystem = ZenGCS.fs
    try:
        filesystem.rm_file(path=path)
    except FileNotFoundError as e:
        raise NotFoundError() from e
rename(src, dst, overwrite=False)
staticmethod
Rename source file to destination file.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
src |
Union[bytes, str] |
The path of the file to rename. |
required |
dst |
Union[bytes, str] |
The path to rename the source file to. |
required |
overwrite |
bool |
If a file already exists at the destination, this
method will overwrite it if overwrite=`True` and raise a FileExistsError otherwise. |
False |
Exceptions:
Type | Description |
---|---|
FileNotFoundError |
If the source file does not exist. |
FileExistsError |
If a file already exists at the destination
and overwrite is not set to `True`. |
Source code in zenml/integrations/gcp/io/gcs_plugin.py
@staticmethod
def rename(src: PathType, dst: PathType, overwrite: bool = False) -> None:
    """Rename source file to destination file.

    Args:
        src: The path of the file to rename.
        dst: The path to rename the source file to.
        overwrite: If a file already exists at the destination, this
            method will overwrite it if overwrite=`True` and
            raise a FileExistsError otherwise.

    Raises:
        NotFoundError: If the filesystem raises FileNotFoundError while
            renaming (e.g. the source file does not exist).
        FileExistsError: If a file already exists at the destination
            and overwrite is not set to `True`.
    """
    ZenGCS._ensure_filesystem_set()
    destination_taken = not overwrite and ZenGCS.fs.exists(dst)
    if destination_taken:
        raise FileExistsError(
            f"Unable to rename file to '{convert_to_str(dst)}', "
            f"file already exists. Set `overwrite=True` to rename anyway."
        )

    # TODO [ENG-152]: Check if it works with overwrite=True or if we need
    #  to manually remove it first
    filesystem = ZenGCS.fs
    try:
        filesystem.rename(path1=src, path2=dst)
    except FileNotFoundError as e:
        raise NotFoundError() from e
rmtree(path)
staticmethod
Remove the given directory.
Source code in zenml/integrations/gcp/io/gcs_plugin.py
@staticmethod
def rmtree(path: PathType) -> None:
    """Recursively delete the directory at `path`.

    Args:
        path: The directory to delete together with its contents.

    Raises:
        NotFoundError: If the filesystem raises FileNotFoundError.
    """
    ZenGCS._ensure_filesystem_set()
    filesystem = ZenGCS.fs
    try:
        filesystem.delete(path=path, recursive=True)
    except FileNotFoundError as e:
        raise NotFoundError() from e
stat(path)
staticmethod
Return stat info for the given path.
Source code in zenml/integrations/gcp/io/gcs_plugin.py
@staticmethod
def stat(path: PathType) -> Dict[str, Any]:
    """Return the stat information the filesystem reports for `path`.

    Args:
        path: The path to inspect.

    Returns:
        The result of the filesystem's `stat` call.

    Raises:
        NotFoundError: If the filesystem raises FileNotFoundError.
    """
    ZenGCS._ensure_filesystem_set()
    filesystem = ZenGCS.fs
    try:
        info = filesystem.stat(path=path)
    except FileNotFoundError as e:
        raise NotFoundError() from e
    return info  # type: ignore[no-any-return]
walk(top, topdown=True, onerror=None)
staticmethod
Return an iterator that walks the contents of the given directory.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
top |
Union[bytes, str] |
Path of directory to walk. |
required |
topdown |
bool |
Unused argument to conform to interface. |
True |
onerror |
Optional[Callable[..., NoneType]] |
Unused argument to conform to interface. |
None |
Returns:
Type | Description |
---|---|
Iterable[Tuple[Union[bytes, str], List[Union[bytes, str]], List[Union[bytes, str]]]] |
An Iterable of Tuples, each of which contains the current directory path, a list of directories inside the current directory and a list of files inside the current directory. |
Source code in zenml/integrations/gcp/io/gcs_plugin.py
@staticmethod
def walk(
    top: PathType,
    topdown: bool = True,
    onerror: Optional[Callable[..., None]] = None,
) -> Iterable[Tuple[PathType, List[PathType], List[PathType]]]:
    """Return an iterator that walks the contents of the given directory.

    Args:
        top: Path of directory to walk.
        topdown: Unused argument to conform to interface.
        onerror: Unused argument to conform to interface.

    Returns:
        An Iterable of Tuples, each of which contains the current directory
        path, a list of directories inside the current directory and a list
        of files inside the current directory.
    """
    ZenGCS._ensure_filesystem_set()
    # TODO [ENG-153]: Additional params
    filesystem = ZenGCS.fs
    walker = filesystem.walk(path=top)
    return walker  # type: ignore[no-any-return]
graphviz
special
GraphvizIntegration (Integration)
Definition of Graphviz integration for ZenML.
Source code in zenml/integrations/graphviz/__init__.py
class GraphvizIntegration(Integration):
    """Definition of Graphviz integration for ZenML.

    Needs the `graphviz` Python package plus the `dot` executable on PATH.
    """

    # Name under which the integration is registered.
    NAME = GRAPHVIZ

    # Python packages that must be installed.
    REQUIREMENTS = ["graphviz>=0.17"]

    # Maps a system package name to the executable looked up on PATH.
    SYSTEM_REQUIREMENTS = dict(graphviz="dot")
visualizers
special
pipeline_run_dag_visualizer
PipelineRunDagVisualizer (BasePipelineRunVisualizer)
Visualize the lineage of runs in a pipeline.
Source code in zenml/integrations/graphviz/visualizers/pipeline_run_dag_visualizer.py
class PipelineRunDagVisualizer(BasePipelineRunVisualizer):
    """Visualize the lineage of runs in a pipeline."""

    # Styling constants; visualize() currently only uses the shapes and
    # node-id prefixes below.
    ARTIFACT_DEFAULT_COLOR = "blue"
    ARTIFACT_CACHED_COLOR = "green"
    ARTIFACT_SHAPE = "box"
    ARTIFACT_PREFIX = "artifact_"
    STEP_COLOR = "#431D93"
    STEP_SHAPE = "ellipse"
    STEP_PREFIX = "step_"
    FONT = "Roboto"

    # NOTE: this is a concrete implementation, so it must not be decorated
    # with @abstractmethod (that would make the class abstract if the base
    # class uses ABCMeta).
    def visualize(
        self, object: PipelineRunView, *args: Any, **kwargs: Any
    ) -> graphviz.Digraph:
        """Creates a pipeline lineage diagram using graphviz.

        Every step becomes a node, every output artifact becomes a node with
        an edge from the producing step, and every input artifact gets an
        edge to the consuming step. The diagram is rendered as PNG with
        `view=True`.

        Args:
            object: The pipeline run to visualize.
            *args: Unused.
            **kwargs: Unused.

        Returns:
            The constructed graphviz digraph.
        """
        logger.warning(
            "This integration is not completed yet. Results might be unexpected."
        )

        dot = graphviz.Digraph(comment=object.name)

        for step in object.steps:
            step_node_id = self.STEP_PREFIX + str(step.id)
            # Add each step as a node.
            dot.node(
                step_node_id,
                step.entrypoint_name,
                shape=self.STEP_SHAPE,
            )

            # Add a node per output artifact and an edge step -> artifact.
            for artifact_name, artifact in step.outputs.items():
                artifact_node_id = self.ARTIFACT_PREFIX + str(artifact.id)
                dot.node(
                    artifact_node_id,
                    f"{artifact_name} \n" f"({artifact._data_type})",
                    shape=self.ARTIFACT_SHAPE,
                )
                dot.edge(step_node_id, artifact_node_id)

            # Add an edge artifact -> step for every input artifact.
            for artifact in step.inputs.values():
                dot.edge(
                    self.ARTIFACT_PREFIX + str(artifact.id),
                    step_node_id,
                )

        # NOTE(review): the temp file uses an ".html" suffix although the
        # output format is PNG -- confirm the intended file naming.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".html") as f:
            dot.render(filename=f.name, format="png", view=True, cleanup=True)

        return dot
visualize(self, object, *args, **kwargs)
Creates a pipeline lineage diagram using graphviz.
Source code in zenml/integrations/graphviz/visualizers/pipeline_run_dag_visualizer.py
def visualize(
    self, object: PipelineRunView, *args: Any, **kwargs: Any
) -> graphviz.Digraph:
    """Creates a pipeline lineage diagram using graphviz.

    Every step becomes a node, every output artifact becomes a node with an
    edge from the producing step, and every input artifact gets an edge to
    the consuming step. The diagram is rendered as PNG with `view=True`.
    This is a concrete implementation, so it is not marked @abstractmethod.

    Args:
        object: The pipeline run to visualize.
        *args: Unused.
        **kwargs: Unused.

    Returns:
        The constructed graphviz digraph.
    """
    logger.warning(
        "This integration is not completed yet. Results might be unexpected."
    )

    dot = graphviz.Digraph(comment=object.name)

    for step in object.steps:
        step_node_id = self.STEP_PREFIX + str(step.id)
        # Add each step as a node.
        dot.node(
            step_node_id,
            step.entrypoint_name,
            shape=self.STEP_SHAPE,
        )

        # Add a node per output artifact and an edge step -> artifact.
        for artifact_name, artifact in step.outputs.items():
            artifact_node_id = self.ARTIFACT_PREFIX + str(artifact.id)
            dot.node(
                artifact_node_id,
                f"{artifact_name} \n" f"({artifact._data_type})",
                shape=self.ARTIFACT_SHAPE,
            )
            dot.edge(step_node_id, artifact_node_id)

        # Add an edge artifact -> step for every input artifact.
        for artifact in step.inputs.values():
            dot.edge(
                self.ARTIFACT_PREFIX + str(artifact.id),
                step_node_id,
            )

    # NOTE(review): the temp file uses an ".html" suffix although the output
    # format is PNG -- confirm the intended file naming.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".html") as f:
        dot.render(filename=f.name, format="png", view=True, cleanup=True)

    return dot
integration
Integration
Base class for integrations in ZenML.
Source code in zenml/integrations/integration.py
class Integration(metaclass=IntegrationMeta):
    """Base class for integration in ZenML.

    Subclasses declare the Python packages (`REQUIREMENTS`) and system
    executables (`SYSTEM_REQUIREMENTS`) they need. Creating a subclass
    triggers `IntegrationMeta`, which registers it by `NAME`.
    """

    # Name under which the integration is registered.
    NAME = "base_integration"

    # Requirement specifiers resolved with pkg_resources.get_distribution.
    REQUIREMENTS: List[str] = []

    # Maps a requirement label to the executable that must be on PATH.
    SYSTEM_REQUIREMENTS: Dict[str, str] = {}

    @classmethod
    def check_installation(cls) -> bool:
        """Check whether the integration's requirements are installed.

        Returns:
            True if every system executable is found on PATH and every
            Python requirement resolves, False otherwise.
        """
        # System executables first: bail out on the first missing one.
        for requirement, command in cls.SYSTEM_REQUIREMENTS.items():
            if shutil.which(command) is not None:
                continue
            logger.debug(
                "Unable to find the required packages for %s on your "
                "system. Please install the packages on your system "
                "and try again.",
                requirement,
            )
            return False

        # Then Python packages, including their version constraints.
        try:
            for r in cls.REQUIREMENTS:
                pkg_resources.get_distribution(r)
        except pkg_resources.DistributionNotFound as e:
            logger.debug(
                f"Unable to find required package '{e.req}' for "
                f"integration {cls.NAME}."
            )
            return False
        except pkg_resources.VersionConflict as e:
            logger.debug(
                f"VersionConflict error when loading installation {cls.NAME}: "
                f"{str(e)}"
            )
            return False

        logger.debug(
            f"Integration {cls.NAME} is installed correctly with "
            f"requirements {cls.REQUIREMENTS}."
        )
        return True

    @staticmethod
    def activate() -> None:
        """Activate the integration; the base implementation does nothing.

        Subclasses override this, e.g. to import the modules they need.
        """
        return None
activate()
staticmethod
Abstract method to activate the integration
Source code in zenml/integrations/integration.py
@staticmethod
def activate() -> None:
    """Activate the integration; the base implementation does nothing.

    Subclasses override this, e.g. to import the modules they need.
    """
    return None
check_installation()
classmethod
Method to check whether the required packages are installed
Source code in zenml/integrations/integration.py
@classmethod
def check_installation(cls) -> bool:
    """Check whether the integration's requirements are installed.

    Returns:
        True if every system executable is found on PATH and every Python
        requirement resolves, False otherwise.
    """
    # System executables first: bail out on the first missing one.
    for requirement, command in cls.SYSTEM_REQUIREMENTS.items():
        if shutil.which(command) is not None:
            continue
        logger.debug(
            "Unable to find the required packages for %s on your "
            "system. Please install the packages on your system "
            "and try again.",
            requirement,
        )
        return False

    # Then Python packages, including their version constraints.
    try:
        for r in cls.REQUIREMENTS:
            pkg_resources.get_distribution(r)
    except pkg_resources.DistributionNotFound as e:
        logger.debug(
            f"Unable to find required package '{e.req}' for "
            f"integration {cls.NAME}."
        )
        return False
    except pkg_resources.VersionConflict as e:
        logger.debug(
            f"VersionConflict error when loading installation {cls.NAME}: "
            f"{str(e)}"
        )
        return False

    logger.debug(
        f"Integration {cls.NAME} is installed correctly with "
        f"requirements {cls.REQUIREMENTS}."
    )
    return True
IntegrationMeta (type)
Metaclass responsible for registering different Integration subclasses
Source code in zenml/integrations/integration.py
class IntegrationMeta(type):
    """Metaclass that registers every Integration subclass on creation."""

    def __new__(
        mcs, name: str, bases: Tuple[Type[Any], ...], dct: Dict[str, Any]
    ) -> "IntegrationMeta":
        """Create the class and register it with the integration registry.

        The base class named "Integration" itself is not registered.
        """
        new_class = super().__new__(mcs, name, bases, dct)
        integration_cls = cast(Type["Integration"], new_class)
        is_base_class = name == "Integration"
        if not is_base_class:
            integration_registry.register_integration(
                integration_cls.NAME, integration_cls
            )
        return integration_cls
__new__(mcs, name, bases, dct)
special
staticmethod
Hook into creation of an Integration class.
Source code in zenml/integrations/integration.py
def __new__(
    mcs, name: str, bases: Tuple[Type[Any], ...], dct: Dict[str, Any]
) -> "IntegrationMeta":
    """Create the class and register it with the integration registry.

    The base class named "Integration" itself is not registered.
    """
    new_class = super().__new__(mcs, name, bases, dct)
    integration_cls = cast(Type["Integration"], new_class)
    is_base_class = name == "Integration"
    if not is_base_class:
        integration_registry.register_integration(
            integration_cls.NAME, integration_cls
        )
    return integration_cls
kubeflow
special
The Kubeflow integration sub-module powers an alternative to the local orchestrator. You can enable it by registering the Kubeflow orchestrator with the CLI tool.
KubeflowIntegration (Integration)
Definition of Kubeflow Integration for ZenML.
Source code in zenml/integrations/kubeflow/__init__.py
class KubeflowIntegration(Integration):
    """Definition of Kubeflow Integration for ZenML."""

    # Name under which the integration is registered.
    NAME = KUBEFLOW
    # Python packages that must be installed.
    REQUIREMENTS = ["kfp==1.8.9"]

    @classmethod
    def activate(cls) -> None:
        """Activates all classes required for the kubeflow integration."""
        # Imported only for their side effects (hence `noqa`); presumably
        # importing these modules registers the Kubeflow metadata store and
        # orchestrator classes.
        from zenml.integrations.kubeflow import metadata_stores  # noqa
        from zenml.integrations.kubeflow import orchestrators  # noqa
activate()
classmethod
Activates all classes required for the kubeflow integration.
Source code in zenml/integrations/kubeflow/__init__.py
@classmethod
def activate(cls) -> None:
    """Activates all classes required for the kubeflow integration."""
    # Imported only for their side effects (hence `noqa`); presumably
    # importing these modules registers the Kubeflow metadata store and
    # orchestrator classes.
    from zenml.integrations.kubeflow import metadata_stores  # noqa
    from zenml.integrations.kubeflow import orchestrators  # noqa
container_entrypoint
Main entrypoint for containers with Kubeflow TFX component executors.
main()
Runs a single step defined by the command line arguments.
Source code in zenml/integrations/kubeflow/container_entrypoint.py
def main() -> None:
"""Runs a single step defined by the command line arguments.
Parses the serialized TFX IR passed via `--tfx_ir`, builds the executor
class for the ZenML step named by `--step_module`/`--step_function_name`,
and executes the node through a TFX `launcher.Launcher` connected to the
active stack's metadata store. If the launch returns execution info, the
UI metadata is written to `--metadata_ui_path`.
Raises:
RuntimeError: If the executor spec has no `class_path` attribute.
"""
# Log to the container's stdout so Kubeflow Pipelines UI can display logs to
# the user.
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
# Also set the root level explicitly, since basicConfig does nothing if the
# root logger already has handlers.
logging.getLogger().setLevel(logging.INFO)
args = _parse_command_line_arguments()
# Rebuild the pipeline IR proto from its JSON serialization.
tfx_pipeline = pipeline_pb2.Pipeline()
json_format.Parse(args.tfx_ir, tfx_pipeline)
run_name = _get_run_name()
# Presumably substitutes the `--runtime_parameter` values into the IR in
# place -- confirm against _resolve_runtime_parameters.
_resolve_runtime_parameters(tfx_pipeline, run_name, args.runtime_parameter)
node_id = args.node_id
pipeline_node = _get_pipeline_node(tfx_pipeline, node_id)
# Look up the executor and custom driver specs for this node from the
# deployment config embedded in the IR.
deployment_config = runner_utils.extract_local_deployment_config(
tfx_pipeline
)
executor_spec = runner_utils.extract_executor_spec(
deployment_config, node_id
)
custom_driver_spec = runner_utils.extract_custom_driver_spec(
deployment_config, node_id
)
# make sure all integrations are activated so all materializers etc. are
# available
integration_registry.activate_integrations()
repo = Repository()
metadata_store = repo.active_stack.metadata_store
metadata_connection = metadata.Metadata(
metadata_store.get_tfx_metadata_config()
)
# import the user main module to register all the materializers
importlib.import_module(args.main_module)
# NOTE(review): the global is set after the import above -- confirm nothing
# executed during that import relies on USER_MAIN_MODULE.
zenml.constants.USER_MAIN_MODULE = args.main_module
# Instantiate the step class named on the command line.
step_module = importlib.import_module(args.step_module)
step_class = getattr(step_module, args.step_function_name)
step_instance = cast(BaseStep, step_class())
if hasattr(executor_spec, "class_path"):
# The generated executor class is placed in the module part of the
# executor's class path (everything before the last dot).
executor_module_parts = getattr(executor_spec, "class_path").split(".")
executor_class_target_module_name = ".".join(executor_module_parts[:-1])
_create_executor_class(
step=step_instance,
executor_class_target_module_name=executor_class_target_module_name,
input_artifact_type_mapping=json.loads(args.input_artifact_types),
)
else:
raise RuntimeError(
f"No class path found inside executor spec: {executor_spec}."
)
# Python-class executables are run with the step's own executor operator.
custom_executor_operators = {
executable_spec_pb2.PythonClassExecutableSpec: step_instance.executor_operator
}
component_launcher = launcher.Launcher(
pipeline_node=pipeline_node,
mlmd_connection=metadata_connection,
pipeline_info=tfx_pipeline.pipeline_info,
pipeline_runtime_spec=tfx_pipeline.runtime_spec,
executor_spec=executor_spec,
custom_driver_spec=custom_driver_spec,
custom_executor_operators=custom_executor_operators,
)
execution_info = execute_step(component_launcher)
# UI metadata is only written when execution info was returned
# (presumably falsy when the step did not execute -- confirm).
if execution_info:
_dump_ui_metadata(pipeline_node, execution_info, args.metadata_ui_path)
metadata_stores
special
kubeflow_metadata_store
KubeflowMetadataStore (MySQLMetadataStore)
pydantic-model
Kubeflow MySQL backend for ZenML metadata store.
Source code in zenml/integrations/kubeflow/metadata_stores/kubeflow_metadata_store.py
class KubeflowMetadataStore(MySQLMetadataStore):
    """Kubeflow MySQL backend for ZenML metadata store.

    Outside a Kubeflow Pipelines pod this uses the parent MySQL config.
    Inside a KFP pod it connects to the metadata gRPC service whose address
    is read from environment variables.
    """

    # MySQL connection defaults.
    host: str = "127.0.0.1"
    port: int = 3306
    database: str = "metadb"
    username: str = "root"
    password: str = ""

    @property
    def flavor(self) -> MetadataStoreFlavor:
        """The metadata store flavor."""
        return MetadataStoreFlavor.KUBEFLOW

    def get_tfx_metadata_config(
        self,
    ) -> Union[
        metadata_store_pb2.ConnectionConfig,
        metadata_store_pb2.MetadataStoreClientConfig,
    ]:
        """Return tfx metadata config for the kubeflow metadata store.

        Returns:
            A gRPC client config built from METADATA_GRPC_SERVICE_HOST/PORT
            when running inside a KFP pod, otherwise the parent MySQL config.

        Raises:
            KeyError: If running inside a KFP pod and either environment
                variable is missing.
        """
        if not inside_kfp_pod():
            return super().get_tfx_metadata_config()

        grpc_config = metadata_store_pb2.MetadataStoreClientConfig()
        grpc_config.host = os.environ["METADATA_GRPC_SERVICE_HOST"]
        grpc_config.port = int(os.environ["METADATA_GRPC_SERVICE_PORT"])
        return grpc_config
flavor: MetadataStoreFlavor
property
readonly
The metadata store flavor.
get_tfx_metadata_config(self)
Return tfx metadata config for the kubeflow metadata store.
Source code in zenml/integrations/kubeflow/metadata_stores/kubeflow_metadata_store.py
def get_tfx_metadata_config(
    self,
) -> Union[
    metadata_store_pb2.ConnectionConfig,
    metadata_store_pb2.MetadataStoreClientConfig,
]:
    """Return tfx metadata config for the kubeflow metadata store.

    Returns:
        A gRPC client config built from METADATA_GRPC_SERVICE_HOST/PORT when
        running inside a KFP pod, otherwise the parent MySQL config.

    Raises:
        KeyError: If running inside a KFP pod and either environment
            variable is missing.
    """
    if not inside_kfp_pod():
        return super().get_tfx_metadata_config()

    grpc_config = metadata_store_pb2.MetadataStoreClientConfig()
    grpc_config.host = os.environ["METADATA_GRPC_SERVICE_HOST"]
    grpc_config.port = int(os.environ["METADATA_GRPC_SERVICE_PORT"])
    return grpc_config
inside_kfp_pod()
Returns if the current python process is running inside a KFP Pod.
Source code in zenml/integrations/kubeflow/metadata_stores/kubeflow_metadata_store.py
def inside_kfp_pod() -> bool:
    """Return whether this process runs inside a Kubeflow Pipelines pod.

    Requires the KFP_POD_NAME environment variable to be set and the
    in-cluster Kubernetes config to load successfully.
    """
    if "KFP_POD_NAME" in os.environ:
        try:
            k8s_config.load_incluster_config()
        except k8s_config.ConfigException:
            return False
        else:
            return True
    return False
orchestrators
special
kubeflow_component
Kubeflow Pipelines based implementation of TFX components. These components are lightweight wrappers around the KFP DSL's ContainerOp, and ensure that the container gets called with the right set of input arguments. It also ensures that each component exports named output attributes that are consistent with those provided by the native TFX components, thus ensuring that both types of pipeline definitions are compatible. Note: This requires Kubeflow Pipelines SDK to be installed.
KubeflowComponent
Base component for all Kubeflow pipelines TFX components. Returns a wrapper around a KFP DSL ContainerOp class, and adds named output attributes that match the output names for the corresponding native TFX components.
Source code in zenml/integrations/kubeflow/orchestrators/kubeflow_component.py
class KubeflowComponent:
    """Base component for all Kubeflow pipelines TFX components.

    Returns a wrapper around a KFP DSL ContainerOp class, and adds named output
    attributes that match the output names for the corresponding native TFX
    components.
    """

    def __init__(
        self,
        component: tfx_base_component.BaseComponent,
        depends_on: Set[dsl.ContainerOp],
        image: str,
        tfx_ir: pipeline_pb2.Pipeline,
        pod_labels_to_attach: Dict[str, str],
        main_module: str,
        step_module: str,
        step_function_name: str,
        runtime_parameters: List[data_types.RuntimeParameter],
    ):
        """Creates a new Kubeflow-based component.

        This class essentially wraps a dsl.ContainerOp construct in Kubeflow
        Pipelines. The op is stored in `self.container_op`.

        Args:
            component: The logical TFX component to wrap.
            depends_on: The set of upstream KFP ContainerOp components that this
                component will depend on.
            image: The container image to use for this component.
            tfx_ir: The TFX intermediate representation of the pipeline.
            pod_labels_to_attach: Dict of pod labels to attach to the GKE pod.
            main_module: Name of the user's main module, passed to the
                container entrypoint as `--main_module`.
            step_module: Name of the module defining the step, passed as
                `--step_module`.
            step_function_name: Name of the step attribute inside
                `step_module`, passed as `--step_function_name`.
            runtime_parameters: Runtime parameters of the pipeline.
        """
        # Path to a metadata file that will be displayed in the KFP UI
        # This metadata file needs to be in a mounted emptyDir to avoid
        # sporadic failures with the (not mature) PNS executor
        # See these links for more information about limitations of PNS +
        # security context:
        # https://www.kubeflow.org/docs/components/pipelines/installation/localcluster-deployment/#deploying-kubeflow-pipelines
        # https://argoproj.github.io/argo-workflows/empty-dir/
        # KFP will switch to the Emissary executor (soon), when this emptyDir
        # mount will not be necessary anymore, but for now it's still in alpha
        # status (https://www.kubeflow.org/docs/components/pipelines/installation/choose-executor/#emissary-executor)
        metadata_ui_path = "/outputs/mlpipeline-ui-metadata.json"
        # Maps container mount paths to volumes.
        volumes: Dict[str, k8s_client.V1Volume] = {
            "/outputs": k8s_client.V1Volume(
                name="outputs", empty_dir=k8s_client.V1EmptyDirVolumeSource()
            ),
        }

        utils.replace_placeholder(component)
        input_artifact_type_mapping = _get_input_artifact_type_mapping(
            component
        )

        # Command line arguments for the container entrypoint.
        arguments = [
            "--node_id",
            component.id,
            "--tfx_ir",
            json_format.MessageToJson(tfx_ir),
            "--metadata_ui_path",
            metadata_ui_path,
            "--main_module",
            main_module,
            "--step_module",
            step_module,
            "--step_function_name",
            step_function_name,
            "--input_artifact_types",
            json.dumps(input_artifact_type_mapping),
        ]

        for param in runtime_parameters:
            arguments.append("--runtime_parameter")
            arguments.append(_encode_runtime_parameter(param))

        stack = Repository().active_stack
        artifact_store = stack.artifact_store
        metadata_store = stack.metadata_store

        # Local stores are mounted into the container as host paths at the
        # same location they have on the host.
        has_local_repos = False

        if isinstance(artifact_store, LocalArtifactStore):
            has_local_repos = True
            host_path = k8s_client.V1HostPathVolumeSource(
                path=artifact_store.path, type="Directory"
            )
            volumes[artifact_store.path] = k8s_client.V1Volume(
                name="local-artifact-store", host_path=host_path
            )
            logger.debug(
                "Adding host path volume for local artifact store (path: %s) "
                "in kubeflow pipelines container.",
                artifact_store.path,
            )

        if isinstance(metadata_store, SQLiteMetadataStore):
            has_local_repos = True
            metadata_store_dir = os.path.dirname(metadata_store.uri)
            host_path = k8s_client.V1HostPathVolumeSource(
                path=metadata_store_dir, type="Directory"
            )
            volumes[metadata_store_dir] = k8s_client.V1Volume(
                name="local-metadata-store", host_path=host_path
            )
            logger.debug(
                "Adding host path volume for local metadata store (uri: %s) "
                "in kubeflow pipelines container.",
                metadata_store.uri,
            )

        self.container_op = dsl.ContainerOp(
            name=component.id,
            command=CONTAINER_ENTRYPOINT_COMMAND,
            image=image,
            arguments=arguments,
            output_artifact_paths={
                "mlpipeline-ui-metadata": metadata_ui_path,
            },
            pvolumes=volumes,
        )

        if has_local_repos:
            if sys.platform == "win32":
                # File permissions are not checked on Windows. This if clause
                # prevents mypy from complaining about unused 'type: ignore'
                # statements (os.getuid/os.getgid are Unix-only).
                pass
            else:
                # Run KFP containers in the context of the local UID/GID
                # to ensure that the artifact and metadata stores can be shared
                # with the local pipeline runs.
                self.container_op.container.security_context = (
                    k8s_client.V1SecurityContext(
                        run_as_user=os.getuid(),
                        run_as_group=os.getgid(),
                    )
                )
                logger.debug(
                    "Setting security context UID and GID to local user/group "
                    "in kubeflow pipelines container."
                )

        for op in depends_on:
            self.container_op.after(op)

        self.container_op.container.add_env_variable(
            k8s_client.V1EnvVar(
                name=ENV_ZENML_PREVENT_PIPELINE_EXECUTION, value="True"
            )
        )
        # Add environment variables for Azure Blob Storage to pod in case they
        # are set locally
        # TODO [ENG-699]: remove this as soon as we implement credential handling
        # NOTE(review): these credentials become plain-text env var values in
        # the generated pipeline definition; presumably visible to anyone
        # who can read that definition.
        for key in [
            "AZURE_STORAGE_ACCOUNT_KEY",
            "AZURE_STORAGE_ACCOUNT_NAME",
            "AZURE_STORAGE_CONNECTION_STRING",
            "AZURE_STORAGE_SAS_TOKEN",
        ]:
            value = os.getenv(key)
            if value:
                self.container_op.container.add_env_variable(
                    k8s_client.V1EnvVar(name=key, value=value)
                )

        for k, v in pod_labels_to_attach.items():
            self.container_op.add_pod_label(k, v)
__init__(self, component, depends_on, image, tfx_ir, pod_labels_to_attach, main_module, step_module, step_function_name, runtime_parameters)
special
Creates a new Kubeflow-based component. This class essentially wraps a dsl.ContainerOp construct in Kubeflow Pipelines.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
component |
BaseComponent |
The logical TFX component to wrap. |
required |
depends_on |
Set[kfp.dsl._container_op.ContainerOp] |
The set of upstream KFP ContainerOp components that this component will depend on. |
required |
image |
str |
The container image to use for this component. |
required |
tfx_ir |
Pipeline |
The TFX intermediate representation of the pipeline. |
required |
pod_labels_to_attach |
Dict[str, str] |
Dict of pod labels to attach to the GKE pod. |
required |
runtime_parameters |
List[tfx.orchestration.data_types.RuntimeParameter] |
Runtime parameters of the pipeline. |
required |
Source code in zenml/integrations/kubeflow/orchestrators/kubeflow_component.py
def __init__(
    self,
    component: tfx_base_component.BaseComponent,
    depends_on: Set[dsl.ContainerOp],
    image: str,
    tfx_ir: pipeline_pb2.Pipeline,
    pod_labels_to_attach: Dict[str, str],
    main_module: str,
    step_module: str,
    step_function_name: str,
    runtime_parameters: List[data_types.RuntimeParameter],
):
    """Creates a new Kubeflow-based component.

    This class essentially wraps a dsl.ContainerOp construct in Kubeflow
    Pipelines. The op is stored in `self.container_op`.

    Args:
        component: The logical TFX component to wrap.
        depends_on: The set of upstream KFP ContainerOp components that this
            component will depend on.
        image: The container image to use for this component.
        tfx_ir: The TFX intermediate representation of the pipeline.
        pod_labels_to_attach: Dict of pod labels to attach to the GKE pod.
        main_module: Name of the user's main module, passed to the container
            entrypoint as `--main_module`.
        step_module: Name of the module defining the step, passed as
            `--step_module`.
        step_function_name: Name of the step attribute inside `step_module`,
            passed as `--step_function_name`.
        runtime_parameters: Runtime parameters of the pipeline.
    """
    # Path to a metadata file that will be displayed in the KFP UI
    # This metadata file needs to be in a mounted emptyDir to avoid
    # sporadic failures with the (not mature) PNS executor
    # See these links for more information about limitations of PNS +
    # security context:
    # https://www.kubeflow.org/docs/components/pipelines/installation/localcluster-deployment/#deploying-kubeflow-pipelines
    # https://argoproj.github.io/argo-workflows/empty-dir/
    # KFP will switch to the Emissary executor (soon), when this emptyDir
    # mount will not be necessary anymore, but for now it's still in alpha
    # status (https://www.kubeflow.org/docs/components/pipelines/installation/choose-executor/#emissary-executor)
    metadata_ui_path = "/outputs/mlpipeline-ui-metadata.json"
    # Maps container mount paths to volumes.
    volumes: Dict[str, k8s_client.V1Volume] = {
        "/outputs": k8s_client.V1Volume(
            name="outputs", empty_dir=k8s_client.V1EmptyDirVolumeSource()
        ),
    }

    utils.replace_placeholder(component)
    input_artifact_type_mapping = _get_input_artifact_type_mapping(component)

    # Command line arguments for the container entrypoint.
    arguments = [
        "--node_id",
        component.id,
        "--tfx_ir",
        json_format.MessageToJson(tfx_ir),
        "--metadata_ui_path",
        metadata_ui_path,
        "--main_module",
        main_module,
        "--step_module",
        step_module,
        "--step_function_name",
        step_function_name,
        "--input_artifact_types",
        json.dumps(input_artifact_type_mapping),
    ]

    for param in runtime_parameters:
        arguments.append("--runtime_parameter")
        arguments.append(_encode_runtime_parameter(param))

    stack = Repository().active_stack
    artifact_store = stack.artifact_store
    metadata_store = stack.metadata_store

    # Local stores are mounted into the container as host paths at the same
    # location they have on the host.
    has_local_repos = False

    if isinstance(artifact_store, LocalArtifactStore):
        has_local_repos = True
        host_path = k8s_client.V1HostPathVolumeSource(
            path=artifact_store.path, type="Directory"
        )
        volumes[artifact_store.path] = k8s_client.V1Volume(
            name="local-artifact-store", host_path=host_path
        )
        logger.debug(
            "Adding host path volume for local artifact store (path: %s) "
            "in kubeflow pipelines container.",
            artifact_store.path,
        )

    if isinstance(metadata_store, SQLiteMetadataStore):
        has_local_repos = True
        metadata_store_dir = os.path.dirname(metadata_store.uri)
        host_path = k8s_client.V1HostPathVolumeSource(
            path=metadata_store_dir, type="Directory"
        )
        volumes[metadata_store_dir] = k8s_client.V1Volume(
            name="local-metadata-store", host_path=host_path
        )
        logger.debug(
            "Adding host path volume for local metadata store (uri: %s) "
            "in kubeflow pipelines container.",
            metadata_store.uri,
        )

    self.container_op = dsl.ContainerOp(
        name=component.id,
        command=CONTAINER_ENTRYPOINT_COMMAND,
        image=image,
        arguments=arguments,
        output_artifact_paths={
            "mlpipeline-ui-metadata": metadata_ui_path,
        },
        pvolumes=volumes,
    )

    if has_local_repos:
        if sys.platform == "win32":
            # File permissions are not checked on Windows. This if clause
            # prevents mypy from complaining about unused 'type: ignore'
            # statements (os.getuid/os.getgid are Unix-only).
            pass
        else:
            # Run KFP containers in the context of the local UID/GID
            # to ensure that the artifact and metadata stores can be shared
            # with the local pipeline runs.
            self.container_op.container.security_context = (
                k8s_client.V1SecurityContext(
                    run_as_user=os.getuid(),
                    run_as_group=os.getgid(),
                )
            )
            logger.debug(
                "Setting security context UID and GID to local user/group "
                "in kubeflow pipelines container."
            )

    for op in depends_on:
        self.container_op.after(op)

    self.container_op.container.add_env_variable(
        k8s_client.V1EnvVar(
            name=ENV_ZENML_PREVENT_PIPELINE_EXECUTION, value="True"
        )
    )
    # Add environment variables for Azure Blob Storage to pod in case they
    # are set locally
    # TODO [ENG-699]: remove this as soon as we implement credential handling
    # NOTE(review): these credentials become plain-text env var values in the
    # generated pipeline definition; presumably visible to anyone who can
    # read that definition.
    for key in [
        "AZURE_STORAGE_ACCOUNT_KEY",
        "AZURE_STORAGE_ACCOUNT_NAME",
        "AZURE_STORAGE_CONNECTION_STRING",
        "AZURE_STORAGE_SAS_TOKEN",
    ]:
        value = os.getenv(key)
        if value:
            self.container_op.container.add_env_variable(
                k8s_client.V1EnvVar(name=key, value=value)
            )

    for k, v in pod_labels_to_attach.items():
        self.container_op.add_pod_label(k, v)
kubeflow_dag_runner
The code below is copied from the TFX source repo with minor changes. All credit goes to the TFX team for the core implementation.
KubeflowDagRunner
Kubeflow Pipelines runner. Constructs a pipeline definition YAML file based on the TFX logical pipeline.
Source code in zenml/integrations/kubeflow/orchestrators/kubeflow_dag_runner.py
class KubeflowDagRunner:
"""Kubeflow Pipelines runner. Constructs a pipeline definition YAML file
based on the TFX logical pipeline.
"""
def __init__(
    self,
    config: KubeflowDagRunnerConfig,
    output_path: str,
    pod_labels_to_attach: Optional[Dict[str, str]] = None,
):
    """Initializes KubeflowDagRunner for compiling a Kubeflow Pipeline.

    Args:
        config: A KubeflowDagRunnerConfig object to specify runtime
            configuration when running the pipeline under Kubeflow.
        output_path: Path where the pipeline definition file will be stored.
        pod_labels_to_attach: Optional set of pod labels to attach to GKE pod
            spun up for this pipeline. When empty or None, the labels from
            `get_default_pod_labels()` are used instead.
    """
    # NOTE(review): `config` is typed as non-optional, so the PipelineConfig
    # fallback only applies to falsy configs -- confirm it is still needed.
    self._config = config or pipeline_config.PipelineConfig()
    self._kubeflow_config = config
    self._output_path = output_path

    # Presumably the KFP compiler (final package) and the TFX IR compiler.
    self._compiler = compiler.Compiler()
    self._tfx_compiler = tfx_compiler.Compiler()

    # Runtime parameter bookkeeping filled by the _parse_parameter_* methods.
    self._params: List[dsl.PipelineParam] = []
    self._params_by_component_id: Dict[
        str, List[data_types.RuntimeParameter]
    ] = collections.defaultdict(list)
    self._deduped_parameter_names: Set[str] = set()

    if pod_labels_to_attach:
        self._pod_labels_to_attach = pod_labels_to_attach
    else:
        self._pod_labels_to_attach = get_default_pod_labels()
@property
def config(self) -> pipeline_config.PipelineConfig:
"""The config property"""
return self._config
def _parse_parameter_from_component(
self, component: tfx_base_component.BaseComponent
) -> None:
"""Extract embedded RuntimeParameter placeholders from a component.
Extract embedded RuntimeParameter placeholders from a component, then
append the corresponding dsl.PipelineParam to KubeflowDagRunner.
Args:
component: a TFX component.
"""
deduped_parameter_names_for_component = set()
for parameter in component.exec_properties.values():
if not isinstance(parameter, data_types.RuntimeParameter):
continue
# Ignore pipeline root because it will be added later.
if parameter.name == TFX_PIPELINE_ROOT_PARAMETER.name:
continue
if parameter.name in deduped_parameter_names_for_component:
continue
deduped_parameter_names_for_component.add(parameter.name)
self._params_by_component_id[component.id].append(parameter)
if parameter.name not in self._deduped_parameter_names:
self._deduped_parameter_names.add(parameter.name)
dsl_parameter = dsl.PipelineParam(
name=parameter.name, value=str(parameter.default)
)
self._params.append(dsl_parameter)
def _parse_parameter_from_pipeline(self, pipeline: TfxPipeline) -> None:
"""Extract all the RuntimeParameter placeholders from the pipeline."""
for component in pipeline.components:
self._parse_parameter_from_component(component)
def _construct_pipeline_graph(
self,
pipeline: "BasePipeline",
tfx_pipeline: TfxPipeline,
stack: "Stack",
runtime_configuration: "RuntimeConfiguration",
) -> None:
"""Constructs a Kubeflow Pipeline graph.
Args:
pipeline: ZenML pipeline instance.
tfx_pipeline: The logical TFX pipeline to base the construction on.
stack: The ZenML stack that the pipeline is running on
runtime_configuration: The runtime configuration
"""
component_to_kfp_op: Dict[base_node.BaseNode, dsl.ContainerOp] = {}
tfx_ir: Pb2Pipeline = self._generate_tfx_ir(tfx_pipeline)
for node in tfx_ir.nodes:
pipeline_node: PipelineNode = node.pipeline_node
# Add the stack as context to each pipeline node:
context_utils.add_context_to_node(
pipeline_node,
type_=MetadataContextTypes.STACK.value,
name=str(hash(json.dumps(stack.dict(), sort_keys=True))),
properties=stack.dict(),
)
# Add all pydantic objects from runtime_configuration to the
# context
context_utils.add_runtime_configuration_to_node(
pipeline_node, runtime_configuration
)
# Add pipeline requirements as a context
requirements = " ".join(sorted(pipeline.requirements))
context_utils.add_context_to_node(
pipeline_node,
type_=MetadataContextTypes.PIPELINE_REQUIREMENTS.value,
name=str(hash(requirements)),
properties={"pipeline_requirements": requirements},
)
# Assumption: There is a partial ordering of components in the list,
# i.e. if component A depends on component B and C, then A appears
# after B and C in the list.
for component in tfx_pipeline.components:
# Keep track of the set of upstream dsl.ContainerOps for this
# component.
depends_on = set()
for upstream_component in component.upstream_nodes:
depends_on.add(component_to_kfp_op[upstream_component])
# remove the extra pipeline node information
tfx_node_ir = self._dehydrate_tfx_ir(tfx_ir, component.id)
main_module = get_module_source_from_module(sys.modules["__main__"])
step_module = component.component_type.split(".")[:-1]
if step_module[0] == "__main__":
step_module = main_module
else:
step_module = ".".join(step_module)
kfp_component = KubeflowComponent(
main_module=main_module,
step_module=step_module,
step_function_name=component.id,
component=component,
depends_on=depends_on,
image=self._kubeflow_config.image,
pod_labels_to_attach=self._pod_labels_to_attach,
tfx_ir=tfx_node_ir,
runtime_parameters=self._params_by_component_id[component.id],
)
for operator in self._kubeflow_config.pipeline_operator_funcs:
kfp_component.container_op.apply(operator)
component_to_kfp_op[component] = kfp_component.container_op
def _del_unused_field(
self, node_id: str, message_dict: MutableMapping[str, Any]
) -> None:
"""Remove fields that are not used by the pipeline."""
for item in list(message_dict.keys()):
if item != node_id:
del message_dict[item]
def _dehydrate_tfx_ir(
self, original_pipeline: Pb2Pipeline, node_id: str
) -> Pb2Pipeline:
"""Dehydrate the TFX IR to remove unused fields."""
pipeline = copy.deepcopy(original_pipeline)
for node in pipeline.nodes:
if (
node.WhichOneof("node") == "pipeline_node"
and node.pipeline_node.node_info.id == node_id
):
del pipeline.nodes[:]
pipeline.nodes.extend([node])
break
deployment_config = IntermediateDeploymentConfig()
pipeline.deployment_config.Unpack(deployment_config)
self._del_unused_field(node_id, deployment_config.executor_specs)
self._del_unused_field(node_id, deployment_config.custom_driver_specs)
self._del_unused_field(
node_id, deployment_config.node_level_platform_configs
)
pipeline.deployment_config.Pack(deployment_config)
return pipeline
def _generate_tfx_ir(self, pipeline: TfxPipeline) -> Pb2Pipeline:
"""Generate the TFX IR from the logical TFX pipeline."""
result = self._tfx_compiler.compile(pipeline)
return result
def run(
self,
pipeline: "BasePipeline",
stack: "Stack",
runtime_configuration: "RuntimeConfiguration",
) -> None:
"""Compiles and outputs a Kubeflow Pipeline YAML definition file.
Args:
pipeline: The logical TFX pipeline to use when building the Kubeflow
pipeline.
stack: The ZenML stack that the pipeline is running on.
runtime_configuration: The runtime configuration.
"""
tfx_pipeline = create_tfx_pipeline(pipeline, stack=stack)
pipeline_root = tfx_pipeline.pipeline_info.pipeline_root
if not isinstance(pipeline_root, str):
raise TypeError(
"TFX Pipeline root may not be a Placeholder, "
"but must be a specific string."
)
for component in tfx_pipeline.components:
# TODO(b/187122662): Pass through pip dependencies as a first-class
# component flag.
if isinstance(component, tfx_base_component.BaseComponent):
component._resolve_pip_dependencies(pipeline_root)
def _construct_pipeline() -> None:
"""Creates Kubeflow ContainerOps for each TFX component
encountered in the pipeline definition."""
self._construct_pipeline_graph(
pipeline, tfx_pipeline, stack, runtime_configuration
)
# Need to run this first to get self._params populated. Then KFP
# compiler can correctly match default value with PipelineParam.
self._parse_parameter_from_pipeline(tfx_pipeline)
# Create workflow spec and write out to package.
self._compiler._create_and_write_workflow(
# pylint: disable=protected-access
pipeline_func=_construct_pipeline,
pipeline_name=tfx_pipeline.pipeline_info.pipeline_name,
params_list=self._params,
package_path=self._output_path,
)
logger.info(
"Finished writing kubeflow pipeline definition file '%s'.",
self._output_path,
)
config: PipelineConfig
property
readonly
The config property
__init__(self, config, output_path, pod_labels_to_attach=None)
special
Initializes KubeflowDagRunner for compiling a Kubeflow Pipeline.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
config |
KubeflowDagRunnerConfig |
A KubeflowDagRunnerConfig object to specify runtime configuration when running the pipeline under Kubeflow. |
required |
output_path |
str |
Path where the pipeline definition file will be stored. |
required |
pod_labels_to_attach |
Optional[Dict[str, str]] |
Optional set of pod labels to attach to the pods spun up for this pipeline. If not given, the labels returned by `get_default_pod_labels()` are used (`add-pod-env: true` and the KFP SDK environment label set to `tfx`). |
None |
Source code in zenml/integrations/kubeflow/orchestrators/kubeflow_dag_runner.py
def __init__(
    self,
    config: KubeflowDagRunnerConfig,
    output_path: str,
    pod_labels_to_attach: Optional[Dict[str, str]] = None,
):
    """Sets up a runner that compiles a pipeline into a Kubeflow definition.

    Args:
        config: A KubeflowDagRunnerConfig object to specify runtime
            configuration when running the pipeline under Kubeflow.
        output_path: Path where the pipeline definition file will be stored.
        pod_labels_to_attach: Optional set of pod labels to attach to the pods
            spun up for this pipeline. When omitted (or empty), the labels
            from `get_default_pod_labels()` are used instead.
    """
    # A falsy config falls back to a plain PipelineConfig for `self._config`.
    self._config = config if config else pipeline_config.PipelineConfig()
    self._kubeflow_config = config
    self._output_path = output_path

    # One compiler for the KFP workflow, one for the TFX IR.
    self._compiler = compiler.Compiler()
    self._tfx_compiler = tfx_compiler.Compiler()

    # Runtime-parameter bookkeeping filled in while parsing components.
    self._params: List[dsl.PipelineParam] = []
    self._params_by_component_id: Dict[
        str, List[data_types.RuntimeParameter]
    ] = collections.defaultdict(list)
    self._deduped_parameter_names: Set[str] = set()

    labels = pod_labels_to_attach
    if not labels:
        labels = get_default_pod_labels()
    self._pod_labels_to_attach = labels
run(self, pipeline, stack, runtime_configuration)
Compiles and outputs a Kubeflow Pipeline YAML definition file.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
pipeline |
BasePipeline |
The ZenML pipeline to compile; it is converted into a logical TFX pipeline before the Kubeflow pipeline is built. |
required |
stack |
Stack |
The ZenML stack that the pipeline is running on. |
required |
runtime_configuration |
RuntimeConfiguration |
The runtime configuration. |
required |
Source code in zenml/integrations/kubeflow/orchestrators/kubeflow_dag_runner.py
def run(
    self,
    pipeline: "BasePipeline",
    stack: "Stack",
    runtime_configuration: "RuntimeConfiguration",
) -> None:
    """Writes the Kubeflow Pipeline YAML definition for `pipeline`.

    Args:
        pipeline: The ZenML pipeline to compile. It is converted into a
            logical TFX pipeline before the Kubeflow pipeline is built.
        stack: The ZenML stack that the pipeline is running on.
        runtime_configuration: The runtime configuration.

    Raises:
        TypeError: If the TFX pipeline root is not a plain string.
    """
    tfx_pipeline = create_tfx_pipeline(pipeline, stack=stack)

    root = tfx_pipeline.pipeline_info.pipeline_root
    if not isinstance(root, str):
        raise TypeError(
            "TFX Pipeline root may not be a Placeholder, "
            "but must be a specific string."
        )

    # TODO(b/187122662): Pass through pip dependencies as a first-class
    # component flag.
    tfx_components = [
        candidate
        for candidate in tfx_pipeline.components
        if isinstance(candidate, tfx_base_component.BaseComponent)
    ]
    for tfx_component in tfx_components:
        tfx_component._resolve_pip_dependencies(root)

    def _construct_pipeline() -> None:
        """Builds a Kubeflow ContainerOp for every TFX component in the
        pipeline definition."""
        self._construct_pipeline_graph(
            pipeline, tfx_pipeline, stack, runtime_configuration
        )

    # `self._params` must be filled before compiling so the KFP compiler can
    # match each PipelineParam with its default value.
    self._parse_parameter_from_pipeline(tfx_pipeline)

    # Build the workflow spec and write it to the output package.
    self._compiler._create_and_write_workflow(  # pylint: disable=protected-access
        pipeline_func=_construct_pipeline,
        pipeline_name=tfx_pipeline.pipeline_info.pipeline_name,
        params_list=self._params,
        package_path=self._output_path,
    )
    logger.info(
        "Finished writing kubeflow pipeline definition file '%s'.",
        self._output_path,
    )
KubeflowDagRunnerConfig (PipelineConfig)
Runtime configuration parameters specific to execution on Kubeflow.
Source code in zenml/integrations/kubeflow/orchestrators/kubeflow_dag_runner.py
class KubeflowDagRunnerConfig(pipeline_config.PipelineConfig):
    """Runtime configuration parameters specific to execution on Kubeflow."""

    def __init__(
        self,
        image: str,
        pipeline_operator_funcs: Optional[List[OpFunc]] = None,
        supported_launcher_classes: Optional[
            List[Type[base_component_launcher.BaseComponentLauncher]]
        ] = None,
        **kwargs: Any
    ):
        """Creates a KubeflowDagRunnerConfig object.

        The user can use pipeline_operator_funcs to apply modifications to
        ContainerOps used in the pipeline. For example, to ensure the pipeline
        steps mount a GCP secret and a Persistent Volume, one can create a
        config object like so:

            from kfp import gcp, onprem
            mount_secret_op = gcp.use_gcp_secret("my-secret-name")
            mount_volume_op = onprem.mount_pvc(
                "my-persistent-volume-claim",
                "my-volume-name",
                "/mnt/volume-mount-path")
            config = KubeflowDagRunnerConfig(
                image="my-image",
                pipeline_operator_funcs=[mount_secret_op, mount_volume_op]
            )

        Note that a non-empty `pipeline_operator_funcs` list replaces the
        defaults from `get_default_pipeline_operator_funcs()` rather than
        extending them.

        Args:
            image: The docker image to use in the pipeline.
            pipeline_operator_funcs: A list of ContainerOp modifying functions
                that will be applied to every container step in the pipeline.
                If omitted or empty, `get_default_pipeline_operator_funcs()`
                is used.
            supported_launcher_classes: A list of component launcher classes that
                are supported by the current pipeline. List sequence determines the
                order in which launchers are chosen for each component being run.
                Defaults to the in-process launcher followed by the Kubernetes
                launcher.
            **kwargs: keyword args for PipelineConfig.
        """
        supported_launcher_classes = supported_launcher_classes or [
            in_process_component_launcher.InProcessComponentLauncher,
            kubernetes_component_launcher.KubernetesComponentLauncher,
        ]
        super().__init__(
            supported_launcher_classes=supported_launcher_classes, **kwargs
        )
        self.pipeline_operator_funcs = (
            pipeline_operator_funcs or get_default_pipeline_operator_funcs()
        )
        self.image = image
__init__(self, image, pipeline_operator_funcs=None, supported_launcher_classes=None, **kwargs)
special
Creates a KubeflowDagRunnerConfig object. The user can use pipeline_operator_funcs to apply modifications to ContainerOps used in the pipeline. For example, to ensure the pipeline steps mount a GCP secret and a Persistent Volume, one can create a config object like so: `from kfp import gcp, onprem`; `mount_secret_op = gcp.use_gcp_secret("my-secret-name")`; `mount_volume_op = onprem.mount_pvc("my-persistent-volume-claim", "my-volume-name", "/mnt/volume-mount-path")`; `config = KubeflowDagRunnerConfig(image="my-image", pipeline_operator_funcs=[mount_secret_op, mount_volume_op])`. Note that a non-empty pipeline_operator_funcs list replaces the default operator functions rather than extending them.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
image |
str |
The docker image to use in the pipeline. |
required |
pipeline_operator_funcs |
Optional[List[Callable[[kfp.dsl._container_op.ContainerOp], Union[kfp.dsl._container_op.ContainerOp, NoneType]]]] |
A list of ContainerOp modifying functions that will be applied to every container step in the pipeline. |
None |
supported_launcher_classes |
Optional[List[Type[tfx.orchestration.launcher.base_component_launcher.BaseComponentLauncher]]] |
A list of component launcher classes that are supported by the current pipeline. List sequence determines the order in which launchers are chosen for each component being run. |
None |
**kwargs |
Any |
keyword args for PipelineConfig. |
{} |
Source code in zenml/integrations/kubeflow/orchestrators/kubeflow_dag_runner.py
def __init__(
    self,
    image: str,
    pipeline_operator_funcs: Optional[List[OpFunc]] = None,
    supported_launcher_classes: Optional[
        List[Type[base_component_launcher.BaseComponentLauncher]]
    ] = None,
    **kwargs: Any
):
    """Creates a KubeflowDagRunnerConfig object.

    The user can use pipeline_operator_funcs to apply modifications to
    ContainerOps used in the pipeline. For example, to ensure the pipeline
    steps mount a GCP secret and a Persistent Volume, one can create a
    config object like so:

        from kfp import gcp, onprem
        mount_secret_op = gcp.use_gcp_secret("my-secret-name")
        mount_volume_op = onprem.mount_pvc(
            "my-persistent-volume-claim",
            "my-volume-name",
            "/mnt/volume-mount-path")
        config = KubeflowDagRunnerConfig(
            image="my-image",
            pipeline_operator_funcs=[mount_secret_op, mount_volume_op]
        )

    Note that a non-empty `pipeline_operator_funcs` list replaces the
    defaults from `get_default_pipeline_operator_funcs()` rather than
    extending them.

    Args:
        image: The docker image to use in the pipeline.
        pipeline_operator_funcs: A list of ContainerOp modifying functions
            that will be applied to every container step in the pipeline.
            If omitted or empty, `get_default_pipeline_operator_funcs()` is
            used.
        supported_launcher_classes: A list of component launcher classes that
            are supported by the current pipeline. List sequence determines the
            order in which launchers are chosen for each component being run.
            Defaults to the in-process launcher followed by the Kubernetes
            launcher.
        **kwargs: keyword args for PipelineConfig.
    """
    supported_launcher_classes = supported_launcher_classes or [
        in_process_component_launcher.InProcessComponentLauncher,
        kubernetes_component_launcher.KubernetesComponentLauncher,
    ]
    super().__init__(
        supported_launcher_classes=supported_launcher_classes, **kwargs
    )
    self.pipeline_operator_funcs = (
        pipeline_operator_funcs or get_default_pipeline_operator_funcs()
    )
    self.image = image
get_default_pipeline_operator_funcs(use_gcp_sa=False)
Returns a default list of pipeline operator functions.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
use_gcp_sa |
bool |
If true, mount a GCP service account secret to each pod, with the name _KUBEFLOW_GCP_SECRET_NAME. |
False |
Returns:
Type | Description |
---|---|
List[Callable[[kfp.dsl._container_op.ContainerOp], Optional[kfp.dsl._container_op.ContainerOp]]] |
A list of functions with type OpFunc. |
Source code in zenml/integrations/kubeflow/orchestrators/kubeflow_dag_runner.py
def get_default_pipeline_operator_funcs(
    use_gcp_sa: bool = False,
) -> List[OpFunc]:
    """Builds the ContainerOp modifiers applied to every step by default.

    Args:
        use_gcp_sa: Whether to additionally mount the GCP service account
            secret named `_KUBEFLOW_GCP_SECRET_NAME` into each pod.

    Returns:
        A list of functions with type OpFunc: the GCP secret mounter (only
        when `use_gcp_sa` is true) followed by the configmap mounter.
    """
    # Both operators are always created, in this order.
    # Authentication for GCP services, used only if requested:
    gcp_secret_op = gcp.use_gcp_secret(_KUBEFLOW_GCP_SECRET_NAME)
    # Mounts the configmap holding the Metadata gRPC server configuration:
    mount_config_map_op = _mount_config_map_op("metadata-grpc-configmap")

    operator_funcs: List[OpFunc] = [mount_config_map_op]
    if use_gcp_sa:
        # Keep the GCP secret mounter ahead of the configmap mounter.
        operator_funcs.insert(0, gcp_secret_op)
    return operator_funcs
get_default_pod_labels()
Returns the default pod label dict for Kubeflow.
Source code in zenml/integrations/kubeflow/orchestrators/kubeflow_dag_runner.py
def get_default_pod_labels() -> Dict[str, str]:
    """Builds the pod labels attached to Kubeflow pods by default.

    Returns:
        A dict with `add-pod-env` set to `"true"` and the KFP SDK
        environment label set to `"tfx"`.
    """
    # KFP default transformers add pod env:
    # https://github.com/kubeflow/pipelines/blob/0.1.32/sdk/python/kfp/compiler/_default_transformers.py
    labels: Dict[str, str] = dict()
    labels["add-pod-env"] = "true"
    labels[telemetry_utils.LABEL_KFP_SDK_ENV] = "tfx"
    return labels
kubeflow_orchestrator
KubeflowOrchestrator (BaseOrchestrator)
pydantic-model
Orchestrator responsible for running pipelines using Kubeflow.
Attributes:
Name | Type | Description |
---|---|---|
custom_docker_base_image_name |
Optional[str] |
Name of a docker image that should be used as the base for the image that will be run on KFP pods. If no custom image is given, a basic image of the active ZenML version will be used. Note: This image needs to have ZenML installed, otherwise the pipeline execution will fail. For that reason, you might want to extend the ZenML docker images found here: https://hub.docker.com/r/zenmldocker/zenml/ |
kubeflow_pipelines_ui_port |
int |
A local port to which the KFP UI will be forwarded. |
kubernetes_context |
Optional[str] |
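Optional name of a kubernetes context to run pipelines in. If not set, the current active context will be used. You can find the active context by running `kubectl config current-context`. |
synchronous |
bool |
If `True`, running a pipeline using this orchestrator will block until all steps finished running on KFP. |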
Source code in zenml/integrations/kubeflow/orchestrators/kubeflow_orchestrator.py
class KubeflowOrchestrator(BaseOrchestrator):
"""Orchestrator responsible for running pipelines using Kubeflow.
Attributes:
custom_docker_base_image_name: Name of a docker image that should be
used as the base for the image that will be run on KFP pods. If no
custom image is given, a basic image of the active ZenML version
will be used. **Note**: This image needs to have ZenML installed,
otherwise the pipeline execution will fail. For that reason, you
might want to extend the ZenML docker images found here:
https://hub.docker.com/r/zenmldocker/zenml/
kubeflow_pipelines_ui_port: A local port to which the KFP UI will be
forwarded.
kubernetes_context: Optional name of a kubernetes context to run
pipelines in. If not set, the current active context will be used.
You can find the active context by running `kubectl config
current-context`.
synchronous: If `True`, running a pipeline using this orchestrator will
block until all steps finished running on KFP.
"""
custom_docker_base_image_name: Optional[str] = None
kubeflow_pipelines_ui_port: int = DEFAULT_KFP_UI_PORT
kubernetes_context: Optional[str] = None
synchronous = False
supports_local_execution = True
supports_remote_execution = True
@property
def flavor(self) -> OrchestratorFlavor:
"""The orchestrator flavor."""
return OrchestratorFlavor.KUBEFLOW
@property
def validator(self) -> Optional[StackValidator]:
"""Validates that the stack contains a container registry."""
return StackValidator(
required_components={StackComponentType.CONTAINER_REGISTRY}
)
def get_docker_image_name(self, pipeline_name: str) -> str:
"""Returns the full docker image name including registry and tag."""
base_image_name = f"zenml-kubeflow:{pipeline_name}"
container_registry = Repository().active_stack.container_registry
if container_registry:
registry_uri = container_registry.uri.rstrip("/")
return f"{registry_uri}/{base_image_name}"
else:
return base_image_name
@property
def root_directory(self) -> str:
"""Returns path to the root directory for all files concerning
this orchestrator."""
return os.path.join(
zenml.io.utils.get_global_config_directory(),
"kubeflow",
str(self.uuid),
)
@property
def pipeline_directory(self) -> str:
"""Returns path to a directory in which the kubeflow pipeline files
are stored."""
return os.path.join(self.root_directory, "pipelines")
def prepare_pipeline_deployment(
self,
pipeline: "BasePipeline",
stack: "Stack",
runtime_configuration: "RuntimeConfiguration",
) -> None:
"""Builds a docker image for the current environment and uploads it to
a container registry if configured.
"""
from zenml.utils.docker_utils import (
build_docker_image,
push_docker_image,
)
image_name = self.get_docker_image_name(pipeline.name)
requirements = {*stack.requirements(), *pipeline.requirements}
logger.debug("Kubeflow docker container requirements: %s", requirements)
build_docker_image(
build_context_path=get_source_root_path(),
image_name=image_name,
dockerignore_path=pipeline.dockerignore_file,
requirements=requirements,
base_image=self.custom_docker_base_image_name,
environment_vars=self._get_environment_vars_from_secrets(
pipeline.secrets
),
)
if stack.container_registry:
push_docker_image(image_name)
def run_pipeline(
self,
pipeline: "BasePipeline",
stack: "Stack",
runtime_configuration: "RuntimeConfiguration",
) -> Any:
"""Runs a pipeline on Kubeflow Pipelines."""
# First check whether its running in a notebok
from zenml.environment import Environment
if Environment.in_notebook():
raise RuntimeError(
"The Kubeflow orchestrator cannot run pipelines in a notebook "
"environment. The reason is that it is non-trivial to create "
"a Docker image of a notebook. Please consider refactoring "
"your notebook cells into separate scripts in a Python module "
"and run the code outside of a notebook when using this "
"orchestrator."
)
from zenml.utils.docker_utils import get_image_digest
image_name = self.get_docker_image_name(pipeline.name)
image_name = get_image_digest(image_name) or image_name
fileio.make_dirs(self.pipeline_directory)
pipeline_file_path = os.path.join(
self.pipeline_directory, f"{pipeline.name}.yaml"
)
runner_config = KubeflowDagRunnerConfig(image=image_name)
runner = KubeflowDagRunner(
config=runner_config, output_path=pipeline_file_path
)
runner.run(
pipeline=pipeline,
stack=stack,
runtime_configuration=runtime_configuration,
)
self._upload_and_run_pipeline(
pipeline_name=pipeline.name,
pipeline_file_path=pipeline_file_path,
runtime_configuration=runtime_configuration,
enable_cache=pipeline.enable_cache,
)
def _upload_and_run_pipeline(
self,
pipeline_name: str,
pipeline_file_path: str,
runtime_configuration: "RuntimeConfiguration",
enable_cache: bool,
) -> None:
"""Tries to upload and run a KFP pipeline.
Args:
pipeline_name: Name of the pipeline.
pipeline_file_path: Path to the pipeline definition file.
runtime_configuration: Runtime configuration of the pipeline run.
enable_cache: Whether caching is enabled for this pipeline run.
"""
try:
if self.kubernetes_context:
logger.info(
"Running in kubernetes context '%s'.",
self.kubernetes_context,
)
# upload the pipeline to Kubeflow and start it
client = kfp.Client(kube_context=self.kubernetes_context)
if runtime_configuration.schedule:
try:
experiment = client.get_experiment(pipeline_name)
logger.info(
"A recurring run has already been created with this "
"pipeline. Creating new recurring run now.."
)
except (ValueError, ApiException):
experiment = client.create_experiment(pipeline_name)
logger.info(
"Creating a new recurring run for pipeline '%s'.. ",
pipeline_name,
)
logger.info(
"You can see all recurring runs under the '%s' experiment.'",
pipeline_name,
)
schedule = runtime_configuration.schedule
result = client.create_recurring_run(
experiment_id=experiment.id,
job_name=runtime_configuration.run_name,
pipeline_package_path=pipeline_file_path,
enable_caching=enable_cache,
start_time=schedule.utc_start_time,
end_time=schedule.utc_end_time,
interval_second=schedule.interval_second,
no_catchup=not schedule.catchup,
)
logger.info("Started recurring run with ID '%s'.", result.id)
else:
logger.info(
"No schedule detected. Creating a one-off pipeline run.."
)
result = client.create_run_from_pipeline_package(
pipeline_file_path,
arguments={},
run_name=runtime_configuration.run_name,
enable_caching=enable_cache,
)
logger.info(
"Started one-off pipeline run with ID '%s'.", result.run_id
)
if self.synchronous:
# TODO [ENG-698]: Allow configuration of the timeout as a
# runtime option
client.wait_for_run_completion(
run_id=result.run_id, timeout=1200
)
except urllib3.exceptions.HTTPError as error:
logger.warning(
"Failed to upload Kubeflow pipeline: %s. "
"Please make sure your kube config is configured and the "
"current context is set correctly.",
error,
)
@property
def _pid_file_path(self) -> str:
"""Returns path to the daemon PID file."""
return os.path.join(self.root_directory, "kubeflow_daemon.pid")
@property
def log_file(self) -> str:
"""Path of the daemon log file."""
return os.path.join(self.root_directory, "kubeflow_daemon.log")
@property
def _k3d_cluster_name(self) -> str:
"""Returns the K3D cluster name."""
# K3D only allows cluster names with up to 32 characters, use the
# first 8 chars of the orchestrator UUID as identifier
return f"zenml-kubeflow-{str(self.uuid)[:8]}"
def _get_k3d_registry_name(self, port: int) -> str:
"""Returns the K3D registry name."""
return f"k3d-zenml-kubeflow-registry.localhost:{port}"
@property
def _k3d_registry_config_path(self) -> str:
"""Returns the path to the K3D registry config yaml."""
return os.path.join(self.root_directory, "k3d_registry.yaml")
def _get_kfp_ui_daemon_port(self) -> int:
"""Port to use for the KFP UI daemon."""
port = self.kubeflow_pipelines_ui_port
if port == DEFAULT_KFP_UI_PORT and not networking_utils.port_available(
port
):
# if the user didn't specify a specific port and the default
# port is occupied, fallback to a random open port
port = networking_utils.find_available_port()
return port
def list_manual_setup_steps(
self, container_registry_name: str, container_registry_path: str
) -> None:
"""Logs manual steps needed to setup the Kubeflow local orchestrator."""
global_config_dir_path = zenml.io.utils.get_global_config_directory()
kubeflow_commands = [
f"> k3d cluster create CLUSTER_NAME --image {local_deployment_utils.K3S_IMAGE_NAME} --registry-create {container_registry_name} --registry-config {container_registry_path} --volume {global_config_dir_path}:{global_config_dir_path}\n",
f"> kubectl --context CLUSTER_NAME apply -k github.com/kubeflow/pipelines/manifests/kustomize/cluster-scoped-resources?ref={KFP_VERSION}&timeout=1m",
"> kubectl --context CLUSTER_NAME wait --timeout=60s --for condition=established crd/applications.app.k8s.io",
f"> kubectl --context CLUSTER_NAME apply -k github.com/kubeflow/pipelines/manifests/kustomize/env/platform-agnostic-pns?ref={KFP_VERSION}&timeout=1m",
f"> kubectl --namespace kubeflow port-forward svc/ml-pipeline-ui {self.kubeflow_pipelines_ui_port}:80",
]
logger.error("Unable to spin up local Kubeflow Pipelines deployment.")
logger.info(
"If you wish to spin up this Kubeflow local orchestrator manually, "
"please enter the following commands (substituting where appropriate):\n"
)
logger.info("\n".join(kubeflow_commands))
@property
def is_provisioned(self) -> bool:
"""Returns if a local k3d cluster for this orchestrator exists."""
if not local_deployment_utils.check_prerequisites():
# if any prerequisites are missing there is certainly no
# local deployment running
return False
return local_deployment_utils.k3d_cluster_exists(
cluster_name=self._k3d_cluster_name
)
@property
def is_running(self) -> bool:
"""Returns if the local k3d cluster for this orchestrator is running."""
if not self.is_provisioned:
return False
return local_deployment_utils.k3d_cluster_running(
cluster_name=self._k3d_cluster_name
)
def provision(self) -> None:
"""Provisions a local Kubeflow Pipelines deployment."""
if self.is_running:
logger.info(
"Found already existing local Kubeflow Pipelines deployment. "
"If there are any issues with the existing deployment, please "
"run 'zenml stack down --force' to delete it."
)
return
if not local_deployment_utils.check_prerequisites():
raise ProvisioningError(
"Unable to provision local Kubeflow Pipelines deployment: "
"Please install 'k3d' and 'kubectl' and try again."
)
container_registry = Repository().active_stack.container_registry
if not container_registry:
raise ProvisioningError(
"Unable to provision local Kubeflow Pipelines deployment: "
"Missing container registry in current stack."
)
if not re.fullmatch(r"localhost:[0-9]{4,5}", container_registry.uri):
raise ProvisioningError(
f"Container registry URI '{container_registry.uri}' doesn't "
f"match the expected format 'localhost:$PORT'. Provisioning "
f"stack resources only works for local container registries."
)
logger.info("Provisioning local Kubeflow Pipelines deployment...")
fileio.make_dirs(self.root_directory)
container_registry_port = int(container_registry.uri.split(":")[-1])
container_registry_name = self._get_k3d_registry_name(
port=container_registry_port
)
local_deployment_utils.write_local_registry_yaml(
yaml_path=self._k3d_registry_config_path,
registry_name=container_registry_name,
registry_uri=container_registry.uri,
)
try:
local_deployment_utils.create_k3d_cluster(
cluster_name=self._k3d_cluster_name,
registry_name=container_registry_name,
registry_config_path=self._k3d_registry_config_path,
)
kubernetes_context = f"k3d-{self._k3d_cluster_name}"
local_deployment_utils.deploy_kubeflow_pipelines(
kubernetes_context=kubernetes_context
)
artifact_store = Repository().active_stack.artifact_store
if isinstance(artifact_store, LocalArtifactStore):
local_deployment_utils.add_hostpath_to_kubeflow_pipelines(
kubernetes_context=kubernetes_context,
local_path=artifact_store.path,
)
local_deployment_utils.start_kfp_ui_daemon(
pid_file_path=self._pid_file_path,
log_file_path=self.log_file,
port=self._get_kfp_ui_daemon_port(),
)
except Exception as e:
logger.error(e)
self.list_manual_setup_steps(
container_registry_name, self._k3d_registry_config_path
)
self.deprovision()
def deprovision(self) -> None:
"""Deprovisions a local Kubeflow Pipelines deployment."""
if self.is_running:
local_deployment_utils.delete_k3d_cluster(
cluster_name=self._k3d_cluster_name
)
local_deployment_utils.stop_kfp_ui_daemon(
pid_file_path=self._pid_file_path
)
if fileio.file_exists(self.log_file):
fileio.remove(self.log_file)
logger.info("Local kubeflow pipelines deployment deprovisioned.")
def resume(self) -> None:
"""Resumes the local k3d cluster."""
if self.is_running:
logger.info("Local kubeflow pipelines deployment already running.")
return
if not self.is_provisioned:
raise ProvisioningError(
"Unable to resume local kubeflow pipelines deployment: No "
"resources provisioned for local deployment."
)
local_deployment_utils.start_k3d_cluster(
cluster_name=self._k3d_cluster_name
)
kubernetes_context = f"k3d-{self._k3d_cluster_name}"
local_deployment_utils.wait_until_kubeflow_pipelines_ready(
kubernetes_context=kubernetes_context
)
local_deployment_utils.start_kfp_ui_daemon(
pid_file_path=self._pid_file_path,
log_file_path=self.log_file,
port=self._get_kfp_ui_daemon_port(),
)
def suspend(self) -> None:
    """Suspends the local k3d cluster.

    Stops the orchestrator's k3d cluster and afterwards the KFP UI daemon.
    If the deployment is not running, this only logs a message.
    """
    if self.is_running:
        local_deployment_utils.stop_k3d_cluster(
            cluster_name=self._k3d_cluster_name
        )
        local_deployment_utils.stop_kfp_ui_daemon(
            pid_file_path=self._pid_file_path
        )
    else:
        logger.info("Local kubeflow pipelines deployment not running.")
def _get_environment_vars_from_secrets(
self, secrets: List[str]
) -> Dict[str, str]:
"""Get key-value pairs from list of secrets provided by the user.
Args:
secrets: List of secrets provided by the user.
Returns:
A dictionary of key-value pairs.
Raises:
ProvisioningError: If the stack has no secrets manager."""
secret_manager = Repository().active_stack.secrets_manager
if not secret_manager:
raise ProvisioningError(
"Unable to provision local Kubeflow Pipelines deployment: "
"Missing secrets manager in current stack."
)
environment_vars: Dict[str, str] = {}
for secret in secrets:
environment_vars.update(secret_manager.get_secret(secret))
return environment_vars
flavor: OrchestratorFlavor
property
readonly
The orchestrator flavor.
is_provisioned: bool
property
readonly
Returns if a local k3d cluster for this orchestrator exists.
is_running: bool
property
readonly
Returns if the local k3d cluster for this orchestrator is running.
log_file: str
property
readonly
Path of the daemon log file.
pipeline_directory: str
property
readonly
Returns path to a directory in which the kubeflow pipeline files are stored.
root_directory: str
property
readonly
Returns path to the root directory for all files concerning this orchestrator.
validator: Optional[zenml.stack.stack_validator.StackValidator]
property
readonly
Validates that the stack contains a container registry.
deprovision(self)
Deprovisions a local Kubeflow Pipelines deployment.
Source code in zenml/integrations/kubeflow/orchestrators/kubeflow_orchestrator.py
def deprovision(self) -> None:
    """Deprovisions a local Kubeflow Pipelines deployment.

    Deletes the orchestrator's k3d cluster if it exists (running or
    suspended), stops the KFP UI daemon and removes the daemon log file.
    """
    # A suspended (stopped) cluster still exists and has to be deleted as
    # well, so check for existence (`is_provisioned`) instead of
    # `is_running`; otherwise `suspend` followed by `deprovision` would leave
    # the cluster behind.
    if self.is_provisioned:
        local_deployment_utils.delete_k3d_cluster(
            cluster_name=self._k3d_cluster_name
        )
    local_deployment_utils.stop_kfp_ui_daemon(
        pid_file_path=self._pid_file_path
    )
    if fileio.file_exists(self.log_file):
        fileio.remove(self.log_file)
    logger.info("Local kubeflow pipelines deployment deprovisioned.")
get_docker_image_name(self, pipeline_name)
Returns the full docker image name including registry and tag.
Source code in zenml/integrations/kubeflow/orchestrators/kubeflow_orchestrator.py
def get_docker_image_name(self, pipeline_name: str) -> str:
    """Returns the full docker image name including registry and tag.

    The image is named `zenml-kubeflow:<pipeline_name>`. If the active stack
    has a container registry, its URI (with trailing slashes stripped) is
    prepended, separated by a single `/`.
    """
    image = f"zenml-kubeflow:{pipeline_name}"
    registry = Repository().active_stack.container_registry
    if not registry:
        return image
    return "{}/{}".format(registry.uri.rstrip("/"), image)
list_manual_setup_steps(self, container_registry_name, container_registry_path)
Logs manual steps needed to setup the Kubeflow local orchestrator.
Source code in zenml/integrations/kubeflow/orchestrators/kubeflow_orchestrator.py
def list_manual_setup_steps(
    self, container_registry_name: str, container_registry_path: str
) -> None:
    """Logs manual steps needed to setup the Kubeflow local orchestrator.

    Args:
        container_registry_name: Name of the k3d registry the cluster should
            create.
        container_registry_path: Path to the k3d registry config file.
    """
    global_config_dir_path = zenml.io.utils.get_global_config_directory()
    # `k3d cluster create CLUSTER_NAME` registers the kubectl context
    # `k3d-CLUSTER_NAME` (the same `k3d-<name>` context the automated
    # provisioning uses), so every kubectl command has to target that
    # context rather than the bare cluster name.
    kubernetes_context = "k3d-CLUSTER_NAME"
    kubeflow_commands = [
        f"> k3d cluster create CLUSTER_NAME --image {local_deployment_utils.K3S_IMAGE_NAME} --registry-create {container_registry_name} --registry-config {container_registry_path} --volume {global_config_dir_path}:{global_config_dir_path}\n",
        f"> kubectl --context {kubernetes_context} apply -k github.com/kubeflow/pipelines/manifests/kustomize/cluster-scoped-resources?ref={KFP_VERSION}&timeout=1m",
        f"> kubectl --context {kubernetes_context} wait --timeout=60s --for condition=established crd/applications.app.k8s.io",
        f"> kubectl --context {kubernetes_context} apply -k github.com/kubeflow/pipelines/manifests/kustomize/env/platform-agnostic-pns?ref={KFP_VERSION}&timeout=1m",
        f"> kubectl --context {kubernetes_context} --namespace kubeflow port-forward svc/ml-pipeline-ui {self.kubeflow_pipelines_ui_port}:80",
    ]
    logger.error("Unable to spin up local Kubeflow Pipelines deployment.")
    logger.info(
        "If you wish to spin up this Kubeflow local orchestrator manually, "
        "please enter the following commands (substituting where appropriate):\n"
    )
    logger.info("\n".join(kubeflow_commands))
prepare_pipeline_deployment(self, pipeline, stack, runtime_configuration)
Builds a docker image for the current environment and uploads it to a container registry if configured.
Source code in zenml/integrations/kubeflow/orchestrators/kubeflow_orchestrator.py
def prepare_pipeline_deployment(
    self,
    pipeline: "BasePipeline",
    stack: "Stack",
    runtime_configuration: "RuntimeConfiguration",
) -> None:
    """Builds a docker image for the current environment and uploads it to
    a container registry if configured.

    Args:
        pipeline: Pipeline whose name, requirements, dockerignore file and
            secrets are used for the image build.
        stack: Stack whose requirements are installed in the image; the
            image is pushed only if this stack has a container registry.
        runtime_configuration: Not used by this implementation.
    """
    from zenml.utils.docker_utils import (
        build_docker_image,
        push_docker_image,
    )

    image_name = self.get_docker_image_name(pipeline.name)
    # Union of stack-level and pipeline-level requirements.
    requirements = set(stack.requirements()) | set(pipeline.requirements)
    logger.debug("Kubeflow docker container requirements: %s", requirements)

    build_context = get_source_root_path()
    build_docker_image(
        build_context_path=build_context,
        image_name=image_name,
        dockerignore_path=pipeline.dockerignore_file,
        requirements=requirements,
        base_image=self.custom_docker_base_image_name,
        environment_vars=self._get_environment_vars_from_secrets(
            pipeline.secrets
        ),
    )

    if not stack.container_registry:
        return
    push_docker_image(image_name)
provision(self)
Provisions a local Kubeflow Pipelines deployment.
Source code in zenml/integrations/kubeflow/orchestrators/kubeflow_orchestrator.py
def provision(self) -> None:
    """Provisions a local Kubeflow Pipelines deployment.

    Validates the prerequisites, writes the k3d registry config, creates the
    k3d cluster, deploys Kubeflow Pipelines on it, mounts the artifact store
    path if the active stack uses a `LocalArtifactStore`, and starts the KFP
    UI daemon. If any of these deployment steps raises, the error is logged
    along with manual setup instructions, and the partial deployment is
    deprovisioned. The exception is not re-raised.

    Raises:
        ProvisioningError: If `k3d`/`kubectl` are missing, the active stack
            has no container registry, or the registry URI is not of the
            form `localhost:<port>`.
    """
    if self.is_running:
        logger.info(
            "Found already existing local Kubeflow Pipelines deployment. "
            "If there are any issues with the existing deployment, please "
            "run 'zenml stack down --force' to delete it."
        )
        return

    if not local_deployment_utils.check_prerequisites():
        raise ProvisioningError(
            "Unable to provision local Kubeflow Pipelines deployment: "
            "Please install 'k3d' and 'kubectl' and try again."
        )

    registry = Repository().active_stack.container_registry
    if not registry:
        raise ProvisioningError(
            "Unable to provision local Kubeflow Pipelines deployment: "
            "Missing container registry in current stack."
        )

    registry_uri = registry.uri
    if not re.fullmatch(r"localhost:[0-9]{4,5}", registry_uri):
        raise ProvisioningError(
            f"Container registry URI '{registry_uri}' doesn't "
            "match the expected format 'localhost:$PORT'. Provisioning "
            "stack resources only works for local container registries."
        )

    logger.info("Provisioning local Kubeflow Pipelines deployment...")
    fileio.make_dirs(self.root_directory)

    # The regex check above guarantees the URI ends in ":<port>".
    registry_port = int(registry_uri.rsplit(":", 1)[-1])
    registry_name = self._get_k3d_registry_name(port=registry_port)
    registry_config_path = self._k3d_registry_config_path
    local_deployment_utils.write_local_registry_yaml(
        yaml_path=registry_config_path,
        registry_name=registry_name,
        registry_uri=registry_uri,
    )

    try:
        cluster_name = self._k3d_cluster_name
        local_deployment_utils.create_k3d_cluster(
            cluster_name=cluster_name,
            registry_name=registry_name,
            registry_config_path=registry_config_path,
        )
        context = f"k3d-{cluster_name}"
        local_deployment_utils.deploy_kubeflow_pipelines(
            kubernetes_context=context
        )

        artifact_store = Repository().active_stack.artifact_store
        if isinstance(artifact_store, LocalArtifactStore):
            # Mount the local artifact store folder so that the KFP UI can
            # load visualizations stored there.
            local_deployment_utils.add_hostpath_to_kubeflow_pipelines(
                kubernetes_context=context,
                local_path=artifact_store.path,
            )

        local_deployment_utils.start_kfp_ui_daemon(
            pid_file_path=self._pid_file_path,
            log_file_path=self.log_file,
            port=self._get_kfp_ui_daemon_port(),
        )
    except Exception as error:
        logger.error(error)
        self.list_manual_setup_steps(registry_name, registry_config_path)
        self.deprovision()
resume(self)
Resumes the local k3d cluster.
Source code in zenml/integrations/kubeflow/orchestrators/kubeflow_orchestrator.py
def resume(self) -> None:
    """Resumes the local k3d cluster.

    Starts the stopped k3d cluster, blocks until all Kubeflow Pipelines
    pods report ready and then restarts the KFP UI daemon. Does nothing
    (apart from logging) if the deployment is already running.

    Raises:
        ProvisioningError: If no local deployment has been provisioned.
    """
    if self.is_running:
        logger.info("Local kubeflow pipelines deployment already running.")
        return

    if not self.is_provisioned:
        raise ProvisioningError(
            "Unable to resume local kubeflow pipelines deployment: No "
            "resources provisioned for local deployment."
        )

    cluster_name = self._k3d_cluster_name
    local_deployment_utils.start_k3d_cluster(cluster_name=cluster_name)

    # The kubectl context of the local cluster is `k3d-<cluster name>`.
    context = "k3d-" + cluster_name
    local_deployment_utils.wait_until_kubeflow_pipelines_ready(
        kubernetes_context=context
    )

    ui_port = self._get_kfp_ui_daemon_port()
    local_deployment_utils.start_kfp_ui_daemon(
        pid_file_path=self._pid_file_path,
        log_file_path=self.log_file,
        port=ui_port,
    )
run_pipeline(self, pipeline, stack, runtime_configuration)
Runs a pipeline on Kubeflow Pipelines.
Source code in zenml/integrations/kubeflow/orchestrators/kubeflow_orchestrator.py
def run_pipeline(
    self,
    pipeline: "BasePipeline",
    stack: "Stack",
    runtime_configuration: "RuntimeConfiguration",
) -> Any:
    """Runs a pipeline on Kubeflow Pipelines.

    Compiles the pipeline into `<pipeline_directory>/<pipeline name>.yaml`
    using the pipeline's docker image and then uploads and runs it. The
    image reference is the value returned by `get_image_digest` when that
    is truthy, otherwise the plain image name.

    Raises:
        RuntimeError: If called from within a notebook environment.
    """
    # Check first whether we're running inside a notebook: creating a docker
    # image of a notebook is not supported.
    from zenml.environment import Environment

    if Environment.in_notebook():
        raise RuntimeError(
            "The Kubeflow orchestrator cannot run pipelines in a notebook "
            "environment. The reason is that it is non-trivial to create "
            "a Docker image of a notebook. Please consider refactoring "
            "your notebook cells into separate scripts in a Python module "
            "and run the code outside of a notebook when using this "
            "orchestrator."
        )

    from zenml.utils.docker_utils import get_image_digest

    plain_image_name = self.get_docker_image_name(pipeline.name)
    image = get_image_digest(plain_image_name) or plain_image_name

    fileio.make_dirs(self.pipeline_directory)
    output_path = os.path.join(
        self.pipeline_directory, f"{pipeline.name}.yaml"
    )

    dag_runner = KubeflowDagRunner(
        config=KubeflowDagRunnerConfig(image=image),
        output_path=output_path,
    )
    dag_runner.run(
        pipeline=pipeline,
        stack=stack,
        runtime_configuration=runtime_configuration,
    )

    self._upload_and_run_pipeline(
        pipeline_name=pipeline.name,
        pipeline_file_path=output_path,
        runtime_configuration=runtime_configuration,
        enable_cache=pipeline.enable_cache,
    )
suspend(self)
Suspends the local k3d cluster.
Source code in zenml/integrations/kubeflow/orchestrators/kubeflow_orchestrator.py
def suspend(self) -> None:
    """Suspends the local k3d cluster.

    Stops the orchestrator's k3d cluster and afterwards the KFP UI daemon.
    If the deployment is not running, this only logs a message.
    """
    if self.is_running:
        local_deployment_utils.stop_k3d_cluster(
            cluster_name=self._k3d_cluster_name
        )
        local_deployment_utils.stop_kfp_ui_daemon(
            pid_file_path=self._pid_file_path
        )
    else:
        logger.info("Local kubeflow pipelines deployment not running.")
kubeflow_utils
Common utility for Kubeflow-based orchestrator.
replace_placeholder(component)
Replaces the RuntimeParameter placeholders with kfp.dsl.PipelineParam.
Source code in zenml/integrations/kubeflow/orchestrators/kubeflow_utils.py
def replace_placeholder(component: base_node.BaseNode) -> None:
    """Replaces the RuntimeParameter placeholders with kfp.dsl.PipelineParam.

    Each exec property of `component` that is a `RuntimeParameter` is
    replaced in place by the string form of a `dsl.PipelineParam` of the
    same name. All other exec properties are left untouched.
    """
    exec_properties = component.exec_properties
    runtime_parameters = [
        (name, value)
        for name, value in exec_properties.items()
        if isinstance(value, data_types.RuntimeParameter)
    ]
    for name, parameter in runtime_parameters:
        exec_properties[name] = str(dsl.PipelineParam(name=parameter.name))
local_deployment_utils
add_hostpath_to_kubeflow_pipelines(kubernetes_context, local_path)
Patches the Kubeflow Pipelines deployment to mount a local folder as a hostpath for visualization purposes.
This function reconfigures the Kubeflow pipelines deployment to use a shared local folder to support loading the Tensorboard viewer and other pipeline visualization results from a local artifact store, as described here:
https://github.com/kubeflow/pipelines/blob/master/docs/config/volume-support.md
Parameters:
Name | Type | Description | Default |
---|---|---|---|
kubernetes_context |
str |
The kubernetes context on which Kubeflow Pipelines should be patched. |
required |
local_path |
str |
The path to the local folder to mount as a hostpath. |
required |
Source code in zenml/integrations/kubeflow/orchestrators/local_deployment_utils.py
def add_hostpath_to_kubeflow_pipelines(
    kubernetes_context: str, local_path: str
) -> None:
    """Patches the Kubeflow Pipelines deployment to mount a local folder as
    a hostpath for visualization purposes.

    This function reconfigures the Kubeflow pipelines deployment to use a
    shared local folder to support loading the Tensorboard viewer and other
    pipeline visualization results from a local artifact store, as described
    here:
    https://github.com/kubeflow/pipelines/blob/master/docs/config/volume-support.md

    Two resources in the `kubeflow` namespace are patched:
    `configmap/ml-pipeline-ui-configmap` (merge patch of the viewer pod
    template) and `deployment/ml-pipeline-ui` (strategic patch). The
    function then waits until all Kubeflow Pipelines pods are ready.

    Args:
        kubernetes_context: The kubernetes context on which Kubeflow Pipelines
            should be patched.
        local_path: The path to the local folder to mount as a hostpath.

    Raises:
        subprocess.CalledProcessError: If one of the `kubectl patch` calls
            fails.
    """
    logger.info("Patching Kubeflow Pipelines to mount a local folder.")

    # The same hostPath volume and mount are added to both the viewer pod
    # template and the UI deployment, so `local_path` is visible at the same
    # path inside both pods.
    volume_name = "local-artifact-store"
    volume_mount = {"mountPath": local_path, "name": volume_name}
    host_path_volume = {
        "hostPath": {"path": local_path, "type": "Directory"},
        "name": volume_name,
    }
    kubectl_patch = [
        "kubectl",
        "--context",
        kubernetes_context,
        "-n",
        "kubeflow",
        "patch",
    ]

    pod_template = {
        "spec": {
            "serviceAccountName": "kubeflow-pipelines-viewer",
            "containers": [{"volumeMounts": [volume_mount]}],
            "volumes": [host_path_volume],
        }
    }
    config_map_data = {
        "data": {
            "viewer-pod-template.json": json.dumps(pod_template, indent=2)
        }
    }
    logger.debug(
        "Adding host path volume for local path `%s` to kubeflow pipeline "
        "viewer pod template configuration.",
        local_path,
    )
    subprocess.check_call(
        kubectl_patch
        + [
            "configmap/ml-pipeline-ui-configmap",
            "--type",
            "merge",
            "-p",
            json.dumps(config_map_data, indent=2),
        ]
    )

    deployment_patch = {
        "spec": {
            "template": {
                "spec": {
                    "containers": [
                        {
                            "name": "ml-pipeline-ui",
                            "volumeMounts": [volume_mount],
                        }
                    ],
                    "volumes": [host_path_volume],
                }
            }
        }
    }
    logger.debug(
        "Adding host path volume for local path `%s` to the kubeflow UI",
        local_path,
    )
    subprocess.check_call(
        kubectl_patch
        + [
            "deployment/ml-pipeline-ui",
            "--type",
            "strategic",
            "-p",
            json.dumps(deployment_patch, indent=2),
        ]
    )

    wait_until_kubeflow_pipelines_ready(kubernetes_context=kubernetes_context)
    logger.info("Finished patching Kubeflow Pipelines setup.")
check_prerequisites()
Checks whether all prerequisites for a local kubeflow pipelines deployment are installed.
Source code in zenml/integrations/kubeflow/orchestrators/local_deployment_utils.py
def check_prerequisites() -> bool:
    """Checks whether all prerequisites for a local kubeflow pipelines
    deployment are installed.

    Returns:
        True if both the `k3d` and the `kubectl` executables can be found
        via `shutil.which`, False otherwise.
    """
    available = {
        tool: shutil.which(tool) is not None for tool in ("k3d", "kubectl")
    }
    logger.debug(
        "Local kubeflow deployment prerequisites: K3D - %s, Kubectl - %s",
        available["k3d"],
        available["kubectl"],
    )
    return available["k3d"] and available["kubectl"]
create_k3d_cluster(cluster_name, registry_name, registry_config_path)
Creates a K3D cluster.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
cluster_name |
str |
Name of the cluster to create. |
required |
registry_name |
str |
Name of the registry to create for this cluster. |
required |
registry_config_path |
str |
Path to the registry config file. |
required |
Source code in zenml/integrations/kubeflow/orchestrators/local_deployment_utils.py
def create_k3d_cluster(
    cluster_name: str, registry_name: str, registry_config_path: str
) -> None:
    """Creates a K3D cluster.

    Runs `k3d cluster create` with the `K3S_IMAGE_NAME` image, the given
    registry options and a `--volume <dir>:<dir>` flag for the ZenML global
    config directory.

    Args:
        cluster_name: Name of the cluster to create.
        registry_name: Name of the registry to create for this cluster.
        registry_config_path: Path to the registry config file.

    Raises:
        subprocess.CalledProcessError: If the k3d command fails.
    """
    logger.info("Creating local K3D cluster '%s'.", cluster_name)
    config_dir = zenml.io.utils.get_global_config_directory()

    command = ["k3d", "cluster", "create", cluster_name]
    command += ["--image", K3S_IMAGE_NAME]
    command += ["--registry-create", registry_name]
    command += ["--registry-config", registry_config_path]
    command += ["--volume", f"{config_dir}:{config_dir}"]
    subprocess.check_call(command)

    logger.info("Finished K3D cluster creation.")
delete_k3d_cluster(cluster_name)
Deletes a K3D cluster with the given name.
Source code in zenml/integrations/kubeflow/orchestrators/local_deployment_utils.py
def delete_k3d_cluster(cluster_name: str) -> None:
    """Deletes a K3D cluster with the given name.

    Raises:
        subprocess.CalledProcessError: If `k3d cluster delete` fails.
    """
    delete_command = ["k3d", "cluster", "delete", cluster_name]
    subprocess.check_call(delete_command)
    logger.info("Deleted local k3d cluster '%s'.", cluster_name)
deploy_kubeflow_pipelines(kubernetes_context)
Deploys Kubeflow Pipelines.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
kubernetes_context |
str |
The kubernetes context on which Kubeflow Pipelines should be deployed. |
required |
Source code in zenml/integrations/kubeflow/orchestrators/local_deployment_utils.py
def deploy_kubeflow_pipelines(kubernetes_context: str) -> None:
    """Deploys Kubeflow Pipelines.

    Applies the cluster-scoped Kubeflow Pipelines manifests, waits for the
    `applications.app.k8s.io` CRD to be established, applies the
    platform-agnostic-pns environment manifests, and finally waits until
    all pods in the `kubeflow` namespace are ready.

    Args:
        kubernetes_context: The kubernetes context on which Kubeflow Pipelines
            should be deployed.

    Raises:
        subprocess.CalledProcessError: If any of the kubectl commands fails.
    """
    logger.info("Deploying Kubeflow Pipelines.")

    kubectl = ["kubectl", "--context", kubernetes_context]
    manifests = "github.com/kubeflow/pipelines/manifests/kustomize"
    steps = [
        [
            "apply",
            "-k",
            f"{manifests}/cluster-scoped-resources?ref={KFP_VERSION}&timeout=1m",
        ],
        [
            "wait",
            "--timeout=60s",
            "--for",
            "condition=established",
            "crd/applications.app.k8s.io",
        ],
        [
            "apply",
            "-k",
            f"{manifests}/env/platform-agnostic-pns?ref={KFP_VERSION}&timeout=1m",
        ],
    ]
    for step in steps:
        subprocess.check_call(kubectl + step)

    wait_until_kubeflow_pipelines_ready(kubernetes_context=kubernetes_context)
    logger.info("Finished Kubeflow Pipelines setup.")
k3d_cluster_exists(cluster_name)
Checks whether there exists a K3D cluster with the given name.
Source code in zenml/integrations/kubeflow/orchestrators/local_deployment_utils.py
def k3d_cluster_exists(cluster_name: str) -> bool:
    """Checks whether there exists a K3D cluster with the given name.

    Parses the JSON produced by `k3d cluster list --output json` and looks
    for an entry whose `name` equals `cluster_name`.
    """
    listing = subprocess.check_output(
        ["k3d", "cluster", "list", "--output", "json"]
    )
    return any(entry["name"] == cluster_name for entry in json.loads(listing))
k3d_cluster_running(cluster_name)
Checks whether the K3D cluster with the given name is running.
Source code in zenml/integrations/kubeflow/orchestrators/local_deployment_utils.py
def k3d_cluster_running(cluster_name: str) -> bool:
    """Checks whether the K3D cluster with the given name is running.

    The cluster counts as running when its `serversRunning` value equals
    its `serversCount` value in the `k3d cluster list --output json`
    output. Returns False if no cluster with that name is listed.
    """
    listing = subprocess.check_output(
        ["k3d", "cluster", "list", "--output", "json"]
    )
    match = next(
        (
            entry
            for entry in json.loads(listing)
            if entry["name"] == cluster_name
        ),
        None,
    )
    if match is None:
        return False
    expected: int = match["serversCount"]
    running: int = match["serversRunning"]
    return running == expected
kubeflow_pipelines_ready(kubernetes_context)
Returns whether all Kubeflow Pipelines pods are ready.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
kubernetes_context |
str |
The kubernetes context in which the pods should be checked. |
required |
Source code in zenml/integrations/kubeflow/orchestrators/local_deployment_utils.py
def kubeflow_pipelines_ready(kubernetes_context: str) -> bool:
    """Returns whether all Kubeflow Pipelines pods are ready.

    Runs a non-blocking (`--timeout=0s`) `kubectl wait` on all pods in the
    `kubeflow` namespace with its output discarded. Success means ready,
    and a non-zero exit code means not ready.

    Args:
        kubernetes_context: The kubernetes context in which the pods
            should be checked.
    """
    wait_command = [
        "kubectl",
        "--context",
        kubernetes_context,
        "--namespace",
        "kubeflow",
        "wait",
        "--for",
        "condition=ready",
        "--timeout=0s",
        "pods",
        "--all",
    ]
    try:
        subprocess.check_call(
            wait_command,
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
        )
    except subprocess.CalledProcessError:
        return False
    return True
start_k3d_cluster(cluster_name)
Starts a K3D cluster with the given name.
Source code in zenml/integrations/kubeflow/orchestrators/local_deployment_utils.py
def start_k3d_cluster(cluster_name: str) -> None:
    """Starts a K3D cluster with the given name.

    Raises:
        subprocess.CalledProcessError: If `k3d cluster start` fails.
    """
    start_command = ["k3d", "cluster", "start", cluster_name]
    subprocess.check_call(start_command)
    logger.info("Started local k3d cluster '%s'.", cluster_name)
start_kfp_ui_daemon(pid_file_path, log_file_path, port)
Starts a daemon process that forwards ports so the Kubeflow Pipelines UI is accessible in the browser.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
pid_file_path |
str |
Path where the file with the daemon's process ID should be written. |
required |
log_file_path |
str |
Path to a file where the daemon logs should be written. |
required |
port |
int |
Port on which the UI should be accessible. |
required |
Source code in zenml/integrations/kubeflow/orchestrators/local_deployment_utils.py
def start_kfp_ui_daemon(
    pid_file_path: str,
    log_file_path: str,
    port: int,
    kubernetes_context: str = "",
) -> None:
    """Starts a daemon process that forwards ports so the Kubeflow Pipelines
    UI is accessible in the browser.

    No daemon is started in two cases: when `port` is not available, or on
    Windows. In both, a warning shows the `kubectl port-forward` command to
    run manually instead.

    Args:
        pid_file_path: Path where the file with the daemon's process ID should
            be written.
        log_file_path: Path to a file where the daemon logs should be written.
        port: Port on which the UI should be accessible.
        kubernetes_context: Optional kubernetes context to port-forward from.
            If empty (the default, matching the previous behavior), kubectl
            uses its current context. That context is not necessarily the
            cluster Kubeflow Pipelines was deployed to, e.g. after the user
            switched contexts.
    """
    command = ["kubectl"]
    if kubernetes_context:
        # Target the cluster explicitly instead of relying on kubectl's
        # current context.
        command += ["--context", kubernetes_context]
    command += [
        "--namespace",
        "kubeflow",
        "port-forward",
        "svc/ml-pipeline-ui",
        f"{port}:80",
    ]

    if not networking_utils.port_available(port):
        modified_command = command.copy()
        # The port mapping is always the last argument.
        modified_command[-1] = "PORT:80"
        logger.warning(
            "Unable to port-forward Kubeflow Pipelines UI to local port %d "
            "because the port is occupied. In order to access the Kubeflow "
            "Pipelines UI at http://localhost:PORT/, please run '%s' in a "
            "separate command line shell (replace PORT with a free port of "
            "your choice).",
            port,
            " ".join(modified_command),
        )
    elif sys.platform == "win32":
        logger.warning(
            "Daemon functionality not supported on Windows. "
            "In order to access the Kubeflow Pipelines UI at "
            "http://localhost:%d/, please run '%s' in a separate command "
            "line shell.",
            port,
            " ".join(command),
        )
    else:
        from zenml.utils import daemon

        def _daemon_function() -> None:
            """Port-forwards the Kubeflow Pipelines UI pod."""
            subprocess.check_call(command)

        daemon.run_as_daemon(
            _daemon_function, pid_file=pid_file_path, log_file=log_file_path
        )
        logger.info(
            "Started Kubeflow Pipelines UI daemon (check the daemon logs at %s "
            "in case you're not able to view the UI). The Kubeflow Pipelines "
            "UI should now be accessible at http://localhost:%d/.",
            log_file_path,
            port,
        )
stop_k3d_cluster(cluster_name)
Stops a K3D cluster with the given name.
Source code in zenml/integrations/kubeflow/orchestrators/local_deployment_utils.py
def stop_k3d_cluster(cluster_name: str) -> None:
    """Stops a K3D cluster with the given name.

    Raises:
        subprocess.CalledProcessError: If `k3d cluster stop` fails.
    """
    stop_command = ["k3d", "cluster", "stop", cluster_name]
    subprocess.check_call(stop_command)
    logger.info("Stopped local k3d cluster '%s'.", cluster_name)
stop_kfp_ui_daemon(pid_file_path)
Stops the KFP UI daemon process if it is running.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
pid_file_path |
str |
Path to the file with the daemon's process ID. |
required |
Source code in zenml/integrations/kubeflow/orchestrators/local_deployment_utils.py
def stop_kfp_ui_daemon(pid_file_path: str) -> None:
    """Stops the KFP UI daemon process if it is running.

    If the PID file exists, the daemon is stopped and the PID file is
    removed. Otherwise nothing happens.

    Args:
        pid_file_path: Path to the file with the daemon's process ID.
    """
    if not fileio.file_exists(pid_file_path):
        return
    # `start_kfp_ui_daemon` never starts a daemon on Windows, so there is no
    # PID file to act on there. The platform check also keeps mypy from
    # complaining about daemon functions that are unavailable on Windows.
    if sys.platform != "win32":
        from zenml.utils import daemon

        daemon.stop_daemon(pid_file_path)
        fileio.remove(pid_file_path)
wait_until_kubeflow_pipelines_ready(kubernetes_context)
Waits until all Kubeflow Pipelines pods are ready.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
kubernetes_context |
str |
The kubernetes context in which the pods should be checked. |
required |
Source code in zenml/integrations/kubeflow/orchestrators/local_deployment_utils.py
def wait_until_kubeflow_pipelines_ready(kubernetes_context: str) -> None:
    """Waits until all Kubeflow Pipelines pods are ready.

    On each iteration this prints the current pod list and checks
    readiness; if not all pods are ready, it sleeps 30 seconds and tries
    again. There is no overall timeout.

    Args:
        kubernetes_context: The kubernetes context in which the pods
            should be checked.
    """
    logger.info(
        "Waiting for all Kubeflow Pipelines pods to be ready (this might "
        "take a few minutes)."
    )
    status_command = [
        "kubectl",
        "--context",
        kubernetes_context,
        "--namespace",
        "kubeflow",
        "get",
        "pods",
    ]

    def _show_status_and_check() -> bool:
        """Prints the pod list and returns whether all pods are ready."""
        logger.info("Current pod status:")
        subprocess.check_call(status_command)
        return kubeflow_pipelines_ready(kubernetes_context=kubernetes_context)

    while not _show_status_and_check():
        logger.info("One or more pods not ready yet, waiting for 30 seconds...")
        time.sleep(30)
write_local_registry_yaml(yaml_path, registry_name, registry_uri)
Writes a K3D registry config file.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
yaml_path |
str |
Path where the config file should be written to. |
required |
registry_name |
str |
Name of the registry. |
required |
registry_uri |
str |
URI of the registry. |
required |
Source code in zenml/integrations/kubeflow/orchestrators/local_deployment_utils.py
def write_local_registry_yaml(
    yaml_path: str, registry_name: str, registry_uri: str
) -> None:
    """Writes a K3D registry config file.

    The config declares `registry_uri` as a mirror whose single endpoint is
    `http://<registry_name>`.

    Args:
        yaml_path: Path where the config file should be written to.
        registry_name: Name of the registry.
        registry_uri: URI of the registry.
    """
    mirror_config = {"endpoint": [f"http://{registry_name}"]}
    yaml_utils.write_yaml(yaml_path, {"mirrors": {registry_uri: mirror_config}})
mlflow
special
The MLflow integration currently enables you to use MLflow tracking as a convenient way to visualize your experiment runs within the MLflow UI.
MlflowIntegration (Integration)
Definition of MLflow integration for ZenML.
Source code in zenml/integrations/mlflow/__init__.py
class MlflowIntegration(Integration):
    """Definition of MLflow integration for ZenML."""

    NAME = MLFLOW
    REQUIREMENTS = [
        "mlflow>=1.2.0",
        "mlserver>=0.5.3",
        "mlserver-mlflow>=0.5.3",
    ]

    @staticmethod
    def activate() -> None:
        """Activate the MLflow integration.

        Imports the MLflow `services` module (not otherwise used here, hence
        the `noqa`), then creates the global MLflow environment component
        and activates it.
        """
        from zenml.integrations.mlflow import services  # noqa
        from zenml.integrations.mlflow.mlflow_environment import (
            MLFlowEnvironment,
        )

        # Create and activate the global MLflow environment
        MLFlowEnvironment().activate()
activate()
staticmethod
Activate the MLflow integration.
Source code in zenml/integrations/mlflow/__init__.py
@staticmethod
def activate() -> None:
    """Activate the MLflow integration.

    Imports the MLflow `services` module (unused here, hence `noqa`), then
    builds the global `MLFlowEnvironment` component and activates it.
    """
    from zenml.integrations.mlflow import services  # noqa
    from zenml.integrations.mlflow.mlflow_environment import MLFlowEnvironment

    global_environment = MLFlowEnvironment()
    global_environment.activate()
mlflow_environment
MLFlowEnvironment (BaseEnvironmentComponent)
Manages the global MLflow environment in the form of an Environment component. To access it inside your step function or in the post-execution workflow:
from zenml.environment import Environment
from zenml.integrations.mlflow.mlflow_environment import MLFLOW_ENVIRONMENT_NAME
@step
def my_step(...):
env = Environment[MLFLOW_ENVIRONMENT_NAME]
do_something_with(env.tracking_uri)
Source code in zenml/integrations/mlflow/mlflow_environment.py
class MLFlowEnvironment(BaseEnvironmentComponent):
    """Manages the global MLflow environment in the form of an Environment
    component. To access it inside your step function or in the post-execution
    workflow:

    ```python
    from zenml.environment import Environment
    from zenml.integrations.mlflow.mlflow_environment import MLFLOW_ENVIRONMENT_NAME

    @step
    def my_step(...):
        env = Environment[MLFLOW_ENVIRONMENT_NAME]
        do_something_with(env.tracking_uri)
    ```
    """

    NAME = MLFLOW_ENVIRONMENT_NAME

    def __init__(self) -> None:
        """Initialize a MLflow environment component.

        The tracking URI is resolved once, at construction time, from the
        artifact store of the currently active stack.
        """
        super().__init__()
        # TODO [ENG-316]: Implement a way to get the mlflow token and set
        # it as env variable at MLFLOW_TRACKING_TOKEN
        self._mlflow_tracking_uri = self._local_mlflow_backend()

    @staticmethod
    def _local_mlflow_backend() -> str:
        """Returns the local mlflow backend inside the zenml artifact
        repository directory.

        The `mlruns` directory below the active artifact store's path is
        created if it does not exist yet.

        Returns:
            The MLflow tracking URI for the local mlflow backend.
        """
        repo = Repository()
        artifact_store = repo.active_stack.artifact_store
        # TODO [medium]: safely access (possibly non-existent) artifact stores
        local_mlflow_backend_uri = os.path.join(artifact_store.path, "mlruns")
        # `exist_ok=True` makes the creation idempotent and avoids the
        # check-then-create race of `os.path.exists` + `os.makedirs`.
        os.makedirs(local_mlflow_backend_uri, exist_ok=True)
        return "file:" + local_mlflow_backend_uri

    def activate(self) -> None:
        """Activate the MLflow environment for the current stack."""
        logger.debug(
            "Setting the MLflow tracking uri to %s", self._mlflow_tracking_uri
        )
        set_tracking_uri(self._mlflow_tracking_uri)
        return super().activate()

    def deactivate(self) -> None:
        """Deactivate the MLflow environment.

        Resets the MLflow tracking URI by setting it to an empty string, then
        delegates to the base class deactivation.
        """
        logger.debug("Resetting the MLflow tracking uri to local")
        set_tracking_uri("")
        return super().deactivate()

    @property
    def tracking_uri(self) -> str:
        """Returns the MLflow tracking URI for the current stack."""
        return self._mlflow_tracking_uri
tracking_uri: str
property
readonly
Returns the MLflow tracking URI for the current stack.
__init__(self)
special
Initialize a MLflow environment component.
Source code in zenml/integrations/mlflow/mlflow_environment.py
def __init__(self) -> None:
    """Initialize a MLflow environment component.

    Resolves the tracking URI via `_local_mlflow_backend` (based on the
    active stack's artifact store) and stores it on the instance.
    """
    super().__init__()
    # TODO [ENG-316]: Implement a way to get the mlflow token and set
    # it as env variable at MLFLOW_TRACKING_TOKEN
    tracking_uri = self._local_mlflow_backend()
    self._mlflow_tracking_uri = tracking_uri
activate(self)
Activate the MLflow environment for the current stack.
Source code in zenml/integrations/mlflow/mlflow_environment.py
def activate(self) -> None:
    """Activate the MLflow environment for the current stack.

    Points MLflow at the tracking URI resolved during initialization, then
    delegates to the base class activation.
    """
    tracking_uri = self._mlflow_tracking_uri
    logger.debug("Setting the MLflow tracking uri to %s", tracking_uri)
    set_tracking_uri(tracking_uri)
    return super().activate()
deactivate(self)
Deactivate the environment component and deregister it from the global Environment.
Exceptions:
Type | Description |
---|---|
RuntimeError |
if the component is not active. |
Source code in zenml/integrations/mlflow/mlflow_environment.py
def deactivate(self) -> None:
    """Deactivate the MLflow environment.

    Clears the MLflow tracking URI (sets it to an empty string), then
    delegates to the base class deactivation.
    """
    logger.debug("Resetting the MLflow tracking uri to local")
    set_tracking_uri("")
    result = super().deactivate()
    return result
MLFlowStepEnvironment (BaseEnvironmentComponent)
Provides information about an MLflow step environment. To access it inside your step function:
from zenml.environment import Environment
from zenml.integrations.mlflow.mlflow_environment import MLFLOW_STEP_ENVIRONMENT_NAME
@step
def my_step(...):
env = Environment[MLFLOW_STEP_ENVIRONMENT_NAME]
do_something_with(env.mlflow_tracking_uri)
Source code in zenml/integrations/mlflow/mlflow_environment.py
class MLFlowStepEnvironment(BaseEnvironmentComponent):
    """Provides information about an MLflow step environment.

    To access it inside your step function:

    ```python
    from zenml.environment import Environment
    from zenml.integrations.mlflow.mlflow_environment import MLFLOW_STEP_ENVIRONMENT_NAME

    @step
    def my_step(...):
        env = Environment()[MLFLOW_STEP_ENVIRONMENT_NAME]
        do_something_with(env.mlflow_run)
    ```
    """

    NAME = MLFLOW_STEP_ENVIRONMENT_NAME

    def __init__(self, experiment_name: str, run_name: str):
        """Initialize a MLflow step environment component.

        Args:
            experiment_name: the experiment name under which all MLflow
                artifacts logged under the current step will be tracked.
                If no MLflow experiment with this name exists, one will
                be created when the environment component is activated.
            run_name: the name of the MLflow run associated with the current
                step. If a run with this name does not exist, one will be
                created when the environment component is activated,
                otherwise the existing run will be reused.
        """
        super().__init__()
        self._experiment_name = experiment_name
        # Populated by `_create_or_reuse_mlflow_run` during `activate`.
        self._experiment = None
        self._run_name = run_name
        # Populated by `_create_or_reuse_mlflow_run` during `activate`.
        self._run = None

    def _create_or_reuse_mlflow_run(self) -> None:
        """Create or reuse an MLflow run for the current step.

        Selects the configured MLflow experiment via `set_experiment`.
        It then looks up a run tagged with the configured run name in that
        experiment. The first match is resumed; if there is no match, a new
        run is started with that name.

        IMPORTANT: this function might cause a race condition. If two or more
        processes call it at the same time and with the same arguments, it could
        lead to a situation where two or more MLflow runs with the same name
        and different IDs are created.

        Raises:
            RuntimeError: if the experiment cannot be retrieved after being
                set, or if no run could be started.
        """
        # Set which experiment is used within mlflow
        logger.debug(
            "Setting the MLflow experiment name to %s", self._experiment_name
        )
        set_experiment(self._experiment_name)
        self._experiment = get_experiment_by_name(self._experiment_name)
        if self._experiment is None:
            raise RuntimeError(
                f"Failed to create or reuse MLflow "
                f"experiment {self._experiment_name}"
            )
        experiment_id = self._experiment.experiment_id

        # TODO [ENG-458]: find a solution to avoid race-conditions while creating
        #   the same MLflow run from parallel steps
        # NOTE(review): the run name is interpolated into the filter string
        # unescaped; a run name containing a double quote would produce an
        # invalid filter expression.
        runs = search_runs(
            experiment_ids=[experiment_id],
            filter_string=f'tags.mlflow.runName = "{self._run_name}"',
            output_format="list",
        )
        if runs:
            # Resume the first existing run with this name.
            run_id = runs[0].info.run_id
            self._run = start_run(run_id=run_id, experiment_id=experiment_id)
        else:
            self._run = start_run(
                run_name=self._run_name, experiment_id=experiment_id
            )

        if self._run is None:
            raise RuntimeError(
                f"Failed to create or reuse MLflow "
                f"run {self._run_name} for experiment {self._experiment_name}"
            )

    def activate(self) -> None:
        """Activate the MLflow environment for the current step.

        The MLflow experiment and run are set up before the base-class
        activation runs.
        """
        self._create_or_reuse_mlflow_run()
        return super().activate()

    def deactivate(self) -> None:
        """Deactivate the environment component.

        This method only delegates to the base class. It does not itself
        end the MLflow run that was started during activation.
        """
        return super().deactivate()

    @property
    def mlflow_experiment_name(self) -> str:
        """Returns the MLflow experiment name for the current step."""
        return self._experiment_name

    @property
    def mlflow_run_name(self) -> str:
        """Returns the MLflow run name for the current step."""
        return self._run_name

    @property
    def mlflow_experiment(self) -> Optional[Experiment]:
        """Returns the MLflow experiment object for the current step.

        This is None until the component has been activated.
        """
        return self._experiment

    @property
    def mlflow_run(self) -> Optional[ActiveRun]:
        """Returns the MLflow run for the current step.

        This is None until the component has been activated.
        """
        return self._run
mlflow_experiment: Optional[mlflow.entities.experiment.Experiment]
property
readonly
Returns the MLflow experiment object for the current step.
mlflow_experiment_name: str
property
readonly
Returns the MLflow experiment name for the current step.
mlflow_run: Optional[mlflow.tracking.fluent.ActiveRun]
property
readonly
Returns the MLflow run for the current step.
mlflow_run_name: str
property
readonly
Returns the MLflow run name for the current step.
__init__(self, experiment_name, run_name)
special
Initialize a MLflow step environment component.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
experiment_name |
str |
the experiment name under which all MLflow artifacts logged under the current step will be tracked. If no MLflow experiment with this name exists, one will be created when the environment component is activated. |
required |
run_name |
str |
the name of the MLflow run associated with the current step. If a run with this name does not exist, one will be created when the environment component is activated, otherwise the existing run will be reused. |
required |
Source code in zenml/integrations/mlflow/mlflow_environment.py
def __init__(self, experiment_name: str, run_name: str):
    """Set up the MLflow step environment with its experiment and run names.

    Only the names are stored here. Per the contract described below, the
    MLflow experiment and run themselves are created or reused when the
    component is activated.

    Args:
        experiment_name: the experiment name under which all MLflow
            artifacts logged under the current step will be tracked.
            If no MLflow experiment with this name exists, one will
            be created when the environment component is activated.
        run_name: the name of the MLflow run associated with the current
            step. If a run with this name does not exist, one will be
            created when the environment component is activated,
            otherwise the existing run will be reused.
    """
    super().__init__()
    # Names supplied by the caller.
    self._experiment_name = experiment_name
    self._run_name = run_name
    # MLflow objects; unset until activation.
    self._experiment = None
    self._run = None
activate(self)
Activate the MLflow environment for the current step.
Source code in zenml/integrations/mlflow/mlflow_environment.py
def activate(self) -> None:
    """Prepare the MLflow experiment/run, then activate via the base class."""
    self._create_or_reuse_mlflow_run()
    activation_result = super().activate()
    return activation_result
deactivate(self)
Deactivate the environment component and deregister it from the global Environment.
Exceptions:
Type | Description |
---|---|
RuntimeError |
if the component is not active. |
Source code in zenml/integrations/mlflow/mlflow_environment.py
def deactivate(self) -> None:
    """Deactivate the component by delegating to the base class.

    This method itself does not end the MLflow run.
    """
    deactivation_result = super().deactivate()
    return deactivation_result
mlflow_step_decorator
enable_mlflow(_step=None, *, experiment_name=None)
Decorator to enable mlflow for a step function.
Apply this decorator to a ZenML pipeline step to enable MLflow experiment
tracking. The MLflow tracking configuration (tracking URI, experiment name,
run name) will be automatically configured before the step code is executed,
so the step can simply use the mlflow
module to log metrics and artifacts,
like so:
@enable_mlflow
@step
def tf_evaluator(
x_test: np.ndarray,
y_test: np.ndarray,
model: tf.keras.Model,
) -> float:
_, test_acc = model.evaluate(x_test, y_test, verbose=2)
mlflow.log_metric("val_accuracy", test_acc)
return test_acc
All MLflow artifacts and metrics logged from all the steps in a pipeline
run are by default grouped under a single experiment named after the
pipeline. To log MLflow artifacts and metrics from a step in a separate
MLflow experiment, pass a custom experiment_name
argument value to the
decorator.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
_step |
Optional[~S] |
The decorated step class. |
None |
experiment_name |
Optional[str] |
optional mlflow experiment name to use for the step. If not provided, the name of the pipeline in the context of which the step is executed will be used as experiment name. |
None |
Returns:
Type | Description |
---|---|
Union[~S, Callable[[~S], ~S]] |
The inner decorator which enhances the input step class with mlflow tracking functionality |
Source code in zenml/integrations/mlflow/mlflow_step_decorator.py
def enable_mlflow(
    _step: Optional[S] = None,
    *,
    experiment_name: Optional[str] = None,
) -> Union[S, Callable[[S], S]]:
    """Decorator to enable mlflow for a step function.

    Apply this decorator to a ZenML pipeline step to enable MLflow experiment
    tracking. The MLflow tracking configuration (tracking URI, experiment name,
    run name) will be automatically configured before the step code is executed,
    so the step can simply use the `mlflow` module to log metrics and artifacts,
    like so:

    ```python
    @enable_mlflow
    @step
    def tf_evaluator(
        x_test: np.ndarray,
        y_test: np.ndarray,
        model: tf.keras.Model,
    ) -> float:
        _, test_acc = model.evaluate(x_test, y_test, verbose=2)
        mlflow.log_metric("val_accuracy", test_acc)
        return test_acc
    ```

    All MLflow artifacts and metrics logged from all the steps in a pipeline
    run are by default grouped under a single experiment named after the
    pipeline. To log MLflow artifacts and metrics from a step in a separate
    MLflow experiment, pass a custom `experiment_name` argument value to the
    decorator.

    Args:
        _step: The decorated step class.
        experiment_name: optional mlflow experiment name to use for the step.
            If not provided, the name of the pipeline in the context of which
            the step is executed will be used as experiment name.

    Returns:
        The inner decorator which enhances the input step class with mlflow
        tracking functionality

    Raises:
        RuntimeError: if the decorated object is not a `BaseStep` subclass
            (for example a plain function not decorated with `@step`).
    """

    def inner_decorator(_step: S) -> S:
        # `issubclass` raises TypeError for non-class arguments (e.g. a plain
        # function), so check for a class first to surface the intended error.
        if not (isinstance(_step, type) and issubclass(_step, BaseStep)):
            raise RuntimeError(
                "The `enable_mlflow` decorator can only be applied to a ZenML "
                "`step` decorated function or a BaseStep subclass."
            )
        logger.debug(
            "Applying 'enable_mlflow' decorator to step %s", _step.__name__
        )

        # Build a subclass of the step whose entrypoint is wrapped so that
        # MLflow is configured before the original entrypoint runs.
        source_fn = getattr(_step, STEP_INNER_FUNC_NAME)
        return cast(
            S,
            type(  # noqa
                _step.__name__,
                (_step,),
                {
                    STEP_INNER_FUNC_NAME: staticmethod(
                        mlflow_step_entrypoint(experiment_name)(source_fn)
                    ),
                    "__module__": _step.__module__,
                },
            ),
        )

    # Support both `@enable_mlflow` and `@enable_mlflow(experiment_name=...)`.
    if _step is None:
        return inner_decorator
    else:
        return inner_decorator(_step)
mlflow_step_entrypoint(experiment_name=None)
Decorator for a step entrypoint to enable mlflow.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
experiment_name |
Optional[str] |
optional mlflow experiment name to use for the step. If not provided, the name of the pipeline in the context of which the step is executed will be used as experiment name. |
None |
Returns:
Type | Description |
---|---|
Callable[[~F], ~F] |
the input function enhanced with mlflow tracking functionality |
Source code in zenml/integrations/mlflow/mlflow_step_decorator.py
def mlflow_step_entrypoint(
    experiment_name: Optional[str] = None,
) -> Callable[[F], F]:
    """Decorator for a step entrypoint to enable mlflow.

    Args:
        experiment_name: optional mlflow experiment name to use for the step.
            If not provided, the name of the pipeline in the context of which
            the step is executed will be used as experiment name.

    Returns:
        the input function enhanced with mlflow tracking functionality
    """

    def inner_decorator(func: F) -> F:
        logger.debug(
            "Applying 'mlflow_step_entrypoint' decorator to step entrypoint %s",
            func.__name__,
        )

        @functools.wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:  # noqa
            logger.debug(
                "Setting up MLflow backend before running step entrypoint %s",
                func.__name__,
            )
            step_env = Environment().step_environment
            experiment = experiment_name or step_env.pipeline_name
            # One MLflow run per pipeline run: the pipeline run id is used
            # as the MLflow run name.
            step_mlflow_env = MLFlowStepEnvironment(
                experiment_name=experiment, run_name=step_env.pipeline_run_id
            )
            # Entering the environment is expected to activate it, which sets
            # up the MLflow run (TODO confirm against BaseEnvironmentComponent).
            with step_mlflow_env:
                mlflow_run = step_mlflow_env.mlflow_run
                # Should never happen, but unlike `assert` this explicit check
                # is not stripped when Python runs with `-O`.
                if mlflow_run is None:
                    raise RuntimeError(
                        "No active MLflow run is available for step "
                        f"entrypoint {func.__name__}."
                    )
                with mlflow_run:
                    return func(*args, **kwargs)

        return cast(F, wrapper)

    return inner_decorator
services
special
mlflow_deployment
MLFlowDeploymentConfig (LocalDaemonServiceConfig)
pydantic-model
MLflow model deployment configuration.
Attributes:
Name | Type | Description |
---|---|---|
model_uri |
str |
URI of the MLflow model to serve |
model_name |
str |
the name of the model |
workers |
int |
number of workers to use for the prediction service |
mlserver |
bool |
set to True to use the MLflow MLServer backend (see https://github.com/SeldonIO/MLServer). If False, the MLflow built-in scoring server will be used. |
Source code in zenml/integrations/mlflow/services/mlflow_deployment.py
class MLFlowDeploymentConfig(LocalDaemonServiceConfig):
    """MLflow model deployment configuration.

    Attributes:
        model_uri: URI of the MLflow model to serve
        model_name: the name of the model
        workers: number of workers to use for the prediction service
        mlserver: set to True to use the MLflow MLServer backend (see
            https://github.com/SeldonIO/MLServer). If False, the
            MLflow built-in scoring server will be used.
    """

    # Required: passed as `model_uri` to the PyFunc backend's `serve` call.
    model_uri: str
    model_name: str
    # Forwarded to the PyFunc backend's `workers` argument.
    workers: int = 1
    # Also selects which prediction/healthcheck URL paths the service endpoint uses.
    mlserver: bool = False
MLFlowDeploymentEndpoint (LocalDaemonServiceEndpoint)
pydantic-model
A service endpoint exposed by the MLflow deployment daemon.
Attributes:
Name | Type | Description |
---|---|---|
config |
MLFlowDeploymentEndpointConfig |
service endpoint configuration |
monitor |
HTTPEndpointHealthMonitor |
optional service endpoint health monitor |
Source code in zenml/integrations/mlflow/services/mlflow_deployment.py
class MLFlowDeploymentEndpoint(LocalDaemonServiceEndpoint):
    """A service endpoint exposed by the MLflow deployment daemon.

    Attributes:
        config: service endpoint configuration
        monitor: optional service endpoint health monitor
    """

    config: MLFlowDeploymentEndpointConfig
    monitor: HTTPEndpointHealthMonitor

    @property
    def prediction_uri(self) -> Optional[str]:
        """Full URI for prediction requests.

        Returns:
            The endpoint base URI with the configured prediction path appended,
            or None when the endpoint status does not carry a URI.
        """
        base_uri = self.status.uri
        return (
            f"{base_uri}{self.config.prediction_uri_path}" if base_uri else None
        )
MLFlowDeploymentEndpointConfig (LocalDaemonServiceEndpointConfig)
pydantic-model
MLflow daemon service endpoint configuration.
Attributes:
Name | Type | Description |
---|---|---|
prediction_uri_path |
str |
URI subpath for prediction requests |
Source code in zenml/integrations/mlflow/services/mlflow_deployment.py
class MLFlowDeploymentEndpointConfig(LocalDaemonServiceEndpointConfig):
    """MLflow daemon service endpoint configuration.

    Attributes:
        prediction_uri_path: URI subpath for prediction requests
    """

    # Appended verbatim to the endpoint's base URI to build the prediction URI.
    prediction_uri_path: str
MLFlowDeploymentService (LocalDaemonService)
pydantic-model
MLFlow deployment service that can be used to start a local prediction server for MLflow models.
Attributes:
Name | Type | Description |
---|---|---|
SERVICE_TYPE |
ClassVar[zenml.services.service_type.ServiceType] |
a service type descriptor with information describing the MLflow deployment service class |
config |
MLFlowDeploymentConfig |
service configuration |
endpoint |
MLFlowDeploymentEndpoint |
optional service endpoint |
Source code in zenml/integrations/mlflow/services/mlflow_deployment.py
class MLFlowDeploymentService(LocalDaemonService):
    """MLFlow deployment service that can be used to start a local prediction
    server for MLflow models.

    Attributes:
        SERVICE_TYPE: a service type descriptor with information describing
            the MLflow deployment service class
        config: service configuration
        endpoint: optional service endpoint
    """

    SERVICE_TYPE = ServiceType(
        name="mlflow-deployment",
        type="model-serving",
        flavor="mlflow",
        description="MLflow prediction service",
    )

    config: MLFlowDeploymentConfig
    endpoint: MLFlowDeploymentEndpoint

    def __init__(
        self,
        config: Union[MLFlowDeploymentConfig, Dict[str, Any]],
        **attrs: Any,
    ) -> None:
        """Initialize the service, creating a default endpoint if needed.

        When `config` is an `MLFlowDeploymentConfig` and no `endpoint` is
        supplied, an HTTP endpoint is built. Its prediction and healthcheck
        paths depend on whether the MLServer backend is enabled.

        Args:
            config: service configuration, as a config object or a dict.
            **attrs: additional attributes forwarded to the base service.
        """
        # ensure that the endpoint is created before the service is initialized
        # TODO [ENG-700]: implement a service factory or builder for MLflow
        #   deployment services
        if (
            isinstance(config, MLFlowDeploymentConfig)
            and "endpoint" not in attrs
        ):
            if config.mlserver:
                prediction_uri_path = MLSERVER_PREDICTION_URL_PATH
                healthcheck_uri_path = MLSERVER_HEALTHCHECK_URL_PATH
                use_head_request = False
            else:
                prediction_uri_path = MLFLOW_PREDICTION_URL_PATH
                healthcheck_uri_path = MLFLOW_HEALTHCHECK_URL_PATH
                use_head_request = True

            endpoint = MLFlowDeploymentEndpoint(
                config=MLFlowDeploymentEndpointConfig(
                    protocol=ServiceEndpointProtocol.HTTP,
                    prediction_uri_path=prediction_uri_path,
                ),
                monitor=HTTPEndpointHealthMonitor(
                    config=HTTPEndpointHealthMonitorConfig(
                        healthcheck_uri_path=healthcheck_uri_path,
                        use_head_request=use_head_request,
                    )
                ),
            )
            attrs["endpoint"] = endpoint
        super().__init__(config=config, **attrs)

    def run(self) -> None:
        """Serve the configured MLflow model in a blocking process.

        Serves on localhost at the endpoint's port until interrupted. A
        KeyboardInterrupt is caught and logged rather than propagated.
        """
        logger.info(
            "Starting MLflow prediction service as blocking "
            "process... press CTRL+C once to stop it."
        )

        self.endpoint.prepare_for_start()

        try:
            backend = PyFuncBackend(
                config={},
                no_conda=True,
                workers=self.config.workers,
                install_mlflow=False,
            )
            backend.serve(
                model_uri=self.config.model_uri,
                port=self.endpoint.status.port,
                host="localhost",
                enable_mlserver=self.config.mlserver,
            )
        except KeyboardInterrupt:
            logger.info(
                "MLflow prediction service stopped. Resuming normal execution."
            )

    @property
    def prediction_uri(self) -> Optional[str]:
        """Get the URI where the prediction service is answering requests.

        Returns:
            The URI where the prediction service can be contacted to process
            HTTP/REST inference requests, or None, if the service isn't running.
        """
        if not self.is_running:
            return None
        return self.endpoint.prediction_uri

    def predict(self, request: "NDArray[Any]") -> "NDArray[Any]":
        """Make a prediction using the service.

        Args:
            request: a numpy array representing the request

        Returns:
            A numpy array representing the prediction returned by the service.

        Raises:
            Exception: if the service is not running or its endpoint has no
                prediction URI.
            requests.HTTPError: if the prediction server returns an error
                status.
        """
        if not self.is_running:
            raise Exception(
                "MLflow prediction service is not running. "
                "Please start the service before making predictions."
            )

        # The endpoint URI is Optional; fail with a clear message instead of
        # letting `requests` reject a None URL.
        prediction_uri = self.endpoint.prediction_uri
        if prediction_uri is None:
            raise Exception(
                "MLflow prediction service endpoint has no URI. "
                "Please make sure the service is fully started."
            )

        response = requests.post(
            prediction_uri,
            json={"instances": request.tolist()},
        )
        response.raise_for_status()
        return np.array(response.json())
prediction_uri: Optional[str]
property
readonly
Get the URI where the prediction service is answering requests.
Returns:
Type | Description |
---|---|
Optional[str] |
The URI where the prediction service can be contacted to process HTTP/REST inference requests, or None, if the service isn't running. |
predict(self, request)
Make a prediction using the service.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
request |
NDArray[Any] |
a numpy array representing the request |
required |
Returns:
Type | Description |
---|---|
NDArray[Any] |
A numpy array representing the prediction returned by the service. |
Source code in zenml/integrations/mlflow/services/mlflow_deployment.py
def predict(self, request: "NDArray[Any]") -> "NDArray[Any]":
    """Make a prediction using the service.

    Args:
        request: a numpy array representing the request

    Returns:
        A numpy array representing the prediction returned by the service.

    Raises:
        Exception: if the service is not running or its endpoint has no
            prediction URI.
        requests.HTTPError: if the prediction server returns an error status.
    """
    if not self.is_running:
        raise Exception(
            "MLflow prediction service is not running. "
            "Please start the service before making predictions."
        )

    # The endpoint URI is Optional; fail with a clear message instead of
    # letting `requests` reject a None URL.
    prediction_uri = self.endpoint.prediction_uri
    if prediction_uri is None:
        raise Exception(
            "MLflow prediction service endpoint has no URI. "
            "Please make sure the service is fully started."
        )

    response = requests.post(
        prediction_uri,
        json={"instances": request.tolist()},
    )
    response.raise_for_status()
    return np.array(response.json())
run(self)
Run the service daemon process associated with this service.
Subclasses must implement this method to provide the service daemon
functionality. This method will be executed in the context of the
running daemon, not in the context of the process that calls the
start
method.
Source code in zenml/integrations/mlflow/services/mlflow_deployment.py
def run(self) -> None:
    """Serve the configured MLflow model in a blocking process.

    The model is served on localhost at the endpoint's port, using the
    configured worker count and backend. CTRL+C stops serving; the
    KeyboardInterrupt is caught and logged so the caller continues normally.
    """
    logger.info(
        "Starting MLflow prediction service as blocking "
        "process... press CTRL+C once to stop it."
    )
    self.endpoint.prepare_for_start()
    try:
        PyFuncBackend(
            config={},
            no_conda=True,
            workers=self.config.workers,
            install_mlflow=False,
        ).serve(
            model_uri=self.config.model_uri,
            port=self.endpoint.status.port,
            host="localhost",
            enable_mlserver=self.config.mlserver,
        )
    except KeyboardInterrupt:
        logger.info(
            "MLflow prediction service stopped. Resuming normal execution."
        )
steps
special
mlflow_deployer
MLFlowDeployerConfig (BaseStepConfig)
pydantic-model
MLflow model deployer step configuration
Attributes:
Name | Type | Description |
---|---|---|
model_name |
str |
the name of the MLflow model logged in the MLflow artifact store for the current pipeline. |
workers |
int |
number of workers to use for the prediction service |
mlserver |
bool |
set to True to use the MLflow MLServer backend (see https://github.com/SeldonIO/MLServer). If False, the MLflow built-in scoring server will be used. |
Source code in zenml/integrations/mlflow/steps/mlflow_deployer.py
class MLFlowDeployerConfig(BaseStepConfig):
    """MLflow model deployer step configuration.

    Attributes:
        model_name: the name of the MLflow model logged in the MLflow artifact
            store for the current pipeline.
        workers: number of workers to use for the prediction service
        mlserver: set to True to use the MLflow MLServer backend (see
            https://github.com/SeldonIO/MLServer). If False, the
            MLflow built-in scoring server will be used.
    """

    # Artifact path looked up in the current MLflow run to find the model.
    model_name: str = "model"
    workers: int = 1
    mlserver: bool = False
mlflow_deployer_step(name=None, enable_cache=None)
Shortcut function to create a pipeline step to deploy a given ML model with a local MLflow prediction server.
The returned step can be used in a pipeline to implement continuous deployment for an MLflow model.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
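name |
Optional[str] |
Optional name for the created step, passed to the `@step` decorator. |
None |
enable_cache |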
Optional[bool] |
Specify whether caching is enabled for this step. If no value is passed, caching is enabled by default |
None |
Returns:
Type | Description |
---|---|
Type[zenml.steps.base_step.BaseStep] |
an MLflow model deployer pipeline step |
Source code in zenml/integrations/mlflow/steps/mlflow_deployer.py
def mlflow_deployer_step(
    name: Optional[str] = None,
    enable_cache: Optional[bool] = None,
) -> Type[BaseStep]:
    """Shortcut function to create a pipeline step to deploy a given ML model
    with a local MLflow prediction server.

    The returned step can be used in a pipeline to implement continuous
    deployment for an MLflow model.

    Args:
        name: optional name for the created step, passed to the `@step`
            decorator.
        enable_cache: Specify whether caching is enabled for this step. If no
            value is passed, caching is enabled by default

    Returns:
        an MLflow model deployer pipeline step
    """
    # enable cache explicitly to compensate for the fact that this step
    # takes in a context object
    if enable_cache is None:
        enable_cache = True

    @enable_mlflow
    @step(enable_cache=enable_cache, name=name)
    def mlflow_model_deployer(
        deploy_decision: bool,
        config: MLFlowDeployerConfig,
        context: StepContext,
    ) -> MLFlowDeploymentService:
        """MLflow model deployer pipeline step

        Args:
            deploy_decision: whether to deploy the model or not
            config: configuration for the deployer step
            context: pipeline step context

        Returns:
            MLflow deployment service

        Raises:
            ValueError: if the service saved by a previous run of this step
                has an unexpected type.
            RuntimeError: if no model was logged in the current run and no
                previous service exists.
        """
        # Find a service created by a previous run of this step
        step_env = cast(StepEnvironment, Environment()[STEP_ENVIRONMENT_NAME])
        last_service = cast(
            MLFlowDeploymentService,
            load_last_service_from_step(
                pipeline_name=step_env.pipeline_name,
                step_name=step_env.step_name,
                step_context=context,
            ),
        )
        # `cast` does not check anything at runtime, so validate explicitly.
        if last_service and not isinstance(
            last_service, MLFlowDeploymentService
        ):
            raise ValueError(
                f"Last service deployed by step {step_env.step_name} and "
                f"pipeline {step_env.pipeline_name} has invalid type. Expected "
                f"MLFlowDeploymentService, found {type(last_service)}."
            )

        mlflow_step_env = cast(
            MLFlowStepEnvironment, Environment()[MLFLOW_STEP_ENVIRONMENT_NAME]
        )
        client = MlflowClient()
        # fetch the MLflow artifacts logged during the pipeline run
        model_uri = None
        mlflow_run = mlflow_step_env.mlflow_run
        if mlflow_run and client.list_artifacts(
            mlflow_run.info.run_id, config.model_name
        ):
            model_uri = get_artifact_uri(config.model_name)

        if not model_uri:
            # an MLflow model was not found in the current run, so we simply reuse
            # the service created during the previous step run
            if not last_service:
                raise RuntimeError(
                    f"An MLflow model with name `{config.model_name}` was not "
                    f"trained in the current pipeline run and no previous "
                    f"service was found."
                )
            return last_service

        # A negative decision only keeps the previous service. Without one,
        # the new model is deployed anyway so that the step always returns a
        # service.
        if not deploy_decision and last_service:
            return last_service

        # stop the service created during the last step run (will be replaced
        # by a new one to serve the new model)
        if last_service:
            last_service.stop(timeout=10)

        # create a new service for the new model
        predictor_cfg = MLFlowDeploymentConfig(
            model_name=config.model_name,
            model_uri=model_uri,
            workers=config.workers,
            mlserver=config.mlserver,
        )
        service = MLFlowDeploymentService(predictor_cfg)
        service.start(timeout=10)

        return service

    return mlflow_model_deployer
plotly
special
PlotlyIntegration (Integration)
Definition of Plotly integration for ZenML.
Source code in zenml/integrations/plotly/__init__.py
class PlotlyIntegration(Integration):
    """Definition of Plotly integration for ZenML."""

    NAME = PLOTLY
    # pip requirement specifiers for this integration.
    REQUIREMENTS = ["plotly>=5.4.0"]
visualizers
special
pipeline_lineage_visualizer
PipelineLineageVisualizer (BasePipelineVisualizer)
Visualize the lineage of runs in a pipeline using plotly.
Source code in zenml/integrations/plotly/visualizers/pipeline_lineage_visualizer.py
class PipelineLineageVisualizer(BasePipelineVisualizer):
    """Visualize the lineage of runs in a pipeline using plotly."""

    # NOTE: this is a concrete implementation, so it must not be marked
    # `@abstractmethod`; doing so would make the class uninstantiable when the
    # base class uses ABCMeta.
    def visualize(
        self, object: PipelineView, *args: Any, **kwargs: Any
    ) -> Figure:
        """Creates a pipeline lineage diagram using plotly.

        Each run becomes one row of a parallel-categories plot. The first
        dimension is the run name. Every other dimension is a step entrypoint
        name, with the step's id as its category value.

        Args:
            object: the pipeline whose runs are visualized.
            *args: unused.
            **kwargs: unused.

        Returns:
            The plotly figure, which is also displayed via `fig.show()`.
        """
        logger.warning(
            "This integration is not completed yet. Results might be unexpected."
        )

        category_dict = {}
        dimensions = ["run"]
        for run in object.runs:
            category_dict[run.name] = {"run": run.name}
            for step in run.steps:
                category_dict[run.name].update(
                    {
                        step.entrypoint_name: str(step.id),
                    }
                )
                # Keep first-seen order of step dimensions, without duplicates.
                if step.entrypoint_name not in dimensions:
                    dimensions.append(step.entrypoint_name)

        category_df = pd.DataFrame.from_dict(category_dict, orient="index")

        category_df = category_df.reset_index()

        # NOTE(review): plotly express documents `labels` as a dict of
        # column -> display name; the string value here looks unintended.
        fig = px.parallel_categories(
            category_df,
            dimensions,
            color=None,
            labels="status",
        )

        fig.show()
        return fig
visualize(self, object, *args, **kwargs)
Creates a pipeline lineage diagram using plotly.
Source code in zenml/integrations/plotly/visualizers/pipeline_lineage_visualizer.py
def visualize(
    self, object: PipelineView, *args: Any, **kwargs: Any
) -> Figure:
    """Creates a pipeline lineage diagram using plotly.

    This is a concrete implementation, so it is not decorated with
    `@abstractmethod`. That decorator would make the class uninstantiable
    when the base class uses ABCMeta.

    Each run becomes one row of a parallel-categories plot. The first
    dimension is the run name. Every other dimension is a step entrypoint
    name, with the step's id as its category value.

    Args:
        object: the pipeline whose runs are visualized.
        *args: unused.
        **kwargs: unused.

    Returns:
        The plotly figure, which is also displayed via `fig.show()`.
    """
    logger.warning(
        "This integration is not completed yet. Results might be unexpected."
    )

    category_dict = {}
    dimensions = ["run"]
    for run in object.runs:
        category_dict[run.name] = {"run": run.name}
        for step in run.steps:
            category_dict[run.name].update(
                {
                    step.entrypoint_name: str(step.id),
                }
            )
            # Keep first-seen order of step dimensions, without duplicates.
            if step.entrypoint_name not in dimensions:
                dimensions.append(step.entrypoint_name)

    category_df = pd.DataFrame.from_dict(category_dict, orient="index")

    category_df = category_df.reset_index()

    # NOTE(review): plotly express documents `labels` as a dict of
    # column -> display name; the string value here looks unintended.
    fig = px.parallel_categories(
        category_df,
        dimensions,
        color=None,
        labels="status",
    )

    fig.show()
    return fig
pytorch
special
PytorchIntegration (Integration)
Definition of PyTorch integration for ZenML.
Source code in zenml/integrations/pytorch/__init__.py
class PytorchIntegration(Integration):
    """Definition of PyTorch integration for ZenML."""

    NAME = PYTORCH
    REQUIREMENTS = ["torch"]

    @classmethod
    def activate(cls) -> None:
        """Activates the integration by importing its materializers module.

        The import is performed only for its side effects; the bound name is
        unused.
        """
        import zenml.integrations.pytorch.materializers  # noqa
activate()
classmethod
Activates the integration.
Source code in zenml/integrations/pytorch/__init__.py
@classmethod
def activate(cls) -> None:
    """Activates the integration by importing its materializers module.

    The import is performed only for its side effects; the bound name is
    unused.
    """
    import zenml.integrations.pytorch.materializers  # noqa
materializers
special
pytorch_materializer
PyTorchMaterializer (BaseMaterializer)
Materializer to read/write Pytorch models.
Source code in zenml/integrations/pytorch/materializers/pytorch_materializer.py
class PyTorchMaterializer(BaseMaterializer):
    """Materializer that reads and writes PyTorch modules and `TorchDict`s.

    The object is stored as a single file named `DEFAULT_FILENAME` inside the
    artifact URI.
    """

    ASSOCIATED_TYPES = (Module, TorchDict)
    ASSOCIATED_ARTIFACT_TYPES = (ModelArtifact,)

    def handle_input(self, data_type: Type[Any]) -> Union[Module, TorchDict]:
        """Load the stored PyTorch object from the artifact.

        Args:
            data_type: the requested type, forwarded to the base class.

        Returns:
            A loaded pytorch model.
        """
        super().handle_input(data_type)
        model_path = os.path.join(self.artifact.uri, DEFAULT_FILENAME)
        # NOTE(review): torch.load unpickles the file; only load artifacts
        # from a trusted artifact store.
        return torch.load(model_path)  # type: ignore[no-untyped-call]

    def handle_return(self, model: Union[Module, TorchDict]) -> None:
        """Save a PyTorch object into the artifact.

        Args:
            model: A torch.nn.Module or a dict to pass into torch.save
        """
        super().handle_return(model)
        model_path = os.path.join(self.artifact.uri, DEFAULT_FILENAME)
        torch.save(model, model_path)
handle_input(self, data_type)
Reads and returns a PyTorch model.
Returns:
Type | Description |
---|---|
Union[torch.nn.modules.module.Module, zenml.integrations.pytorch.materializers.pytorch_types.TorchDict] |
A loaded pytorch model. |
Source code in zenml/integrations/pytorch/materializers/pytorch_materializer.py
def handle_input(self, data_type: Type[Any]) -> Union[Module, TorchDict]:
    """Load the stored PyTorch object from the artifact.

    Args:
        data_type: the requested type, forwarded to the base class.

    Returns:
        A loaded pytorch model.
    """
    super().handle_input(data_type)
    model_path = os.path.join(self.artifact.uri, DEFAULT_FILENAME)
    # NOTE(review): torch.load unpickles the file; only load artifacts
    # from a trusted artifact store.
    return torch.load(model_path)  # type: ignore[no-untyped-call]
handle_return(self, model)
Writes a PyTorch model.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model |
Union[torch.nn.modules.module.Module, zenml.integrations.pytorch.materializers.pytorch_types.TorchDict] |
A torch.nn.Module or a dict to pass into torch.save |
required |
Source code in zenml/integrations/pytorch/materializers/pytorch_materializer.py
def handle_return(self, model: Union[Module, TorchDict]) -> None:
    """Save a PyTorch object into the artifact.

    Args:
        model: A torch.nn.Module or a dict to pass into torch.save
    """
    super().handle_return(model)
    model_path = os.path.join(self.artifact.uri, DEFAULT_FILENAME)
    torch.save(model, model_path)
pytorch_types
TorchDict (dict, Generic)
A type of dict that represents saving a model.
Source code in zenml/integrations/pytorch/materializers/pytorch_types.py
class TorchDict(Dict[str, Any]):
    """A type of dict that represents saving a model."""

    # A distinct dict subclass so it can be listed in PyTorchMaterializer's
    # ASSOCIATED_TYPES alongside torch.nn.Module.
pytorch_lightning
special
PytorchLightningIntegration (Integration)
Definition of PyTorch Lightning integration for ZenML.
Source code in zenml/integrations/pytorch_lightning/__init__.py
class PytorchLightningIntegration(Integration):
    """Definition of PyTorch Lightning integration for ZenML."""

    NAME = PYTORCH_L
    REQUIREMENTS = ["pytorch_lightning"]

    @classmethod
    def activate(cls) -> None:
        """Activates the integration by importing its materializers module.

        The import is performed only for its side effects; the bound name is
        unused.
        """
        import zenml.integrations.pytorch_lightning.materializers  # noqa
activate()
classmethod
Activates the integration.
Source code in zenml/integrations/pytorch_lightning/__init__.py
@classmethod
def activate(cls) -> None:
    """Activates the integration by importing its materializers module.

    The import is performed only for its side effects; the bound name is
    unused.
    """
    import zenml.integrations.pytorch_lightning.materializers  # noqa
materializers
special
pytorch_lightning_materializer
PyTorchLightningMaterializer (BaseMaterializer)
Materializer to read/write PyTorch Lightning trainers (stored as checkpoints).
Source code in zenml/integrations/pytorch_lightning/materializers/pytorch_lightning_materializer.py
class PyTorchLightningMaterializer(BaseMaterializer):
    """Materializer to read/write PyTorch Lightning trainers via checkpoints.

    The trainer state is stored as a checkpoint file named `CHECKPOINT_NAME`
    inside the artifact URI.
    """

    ASSOCIATED_TYPES = (Trainer,)
    ASSOCIATED_ARTIFACT_TYPES = (ModelArtifact,)

    def handle_input(self, data_type: Type[Any]) -> Trainer:
        """Build a trainer that resumes from the stored checkpoint.

        Args:
            data_type: the requested type, forwarded to the base class.

        Returns:
            A PyTorch Lightning trainer object.
        """
        super().handle_input(data_type)
        checkpoint_path = os.path.join(self.artifact.uri, CHECKPOINT_NAME)
        return Trainer(resume_from_checkpoint=checkpoint_path)

    def handle_return(self, trainer: Trainer) -> None:
        """Save the trainer state as a checkpoint in the artifact.

        Args:
            trainer: A PyTorch Lightning trainer object.
        """
        super().handle_return(trainer)
        checkpoint_path = os.path.join(self.artifact.uri, CHECKPOINT_NAME)
        trainer.save_checkpoint(checkpoint_path)
handle_input(self, data_type)
Reads and returns a PyTorch Lightning trainer.
Returns:
Type | Description |
---|---|
Trainer |
A PyTorch Lightning trainer object. |
Source code in zenml/integrations/pytorch_lightning/materializers/pytorch_lightning_materializer.py
def handle_input(self, data_type: Type[Any]) -> Trainer:
    """Build a trainer that resumes from the stored checkpoint.

    Args:
        data_type: the requested type, forwarded to the base class.

    Returns:
        A PyTorch Lightning trainer object.
    """
    super().handle_input(data_type)
    checkpoint_path = os.path.join(self.artifact.uri, CHECKPOINT_NAME)
    return Trainer(resume_from_checkpoint=checkpoint_path)
handle_return(self, trainer)
Writes a PyTorch Lightning trainer.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
trainer |
Trainer |
A PyTorch Lightning trainer object. |
required |
Source code in zenml/integrations/pytorch_lightning/materializers/pytorch_lightning_materializer.py
def handle_return(self, trainer: Trainer) -> None:
    """Save the trainer state as a checkpoint in the artifact.

    Args:
        trainer: A PyTorch Lightning trainer object.
    """
    super().handle_return(trainer)
    checkpoint_path = os.path.join(self.artifact.uri, CHECKPOINT_NAME)
    trainer.save_checkpoint(checkpoint_path)
registry
IntegrationRegistry
Registry to keep track of ZenML Integrations
Source code in zenml/integrations/registry.py
class IntegrationRegistry(object):
    """Registry to keep track of ZenML Integrations."""

    def __init__(self) -> None:
        """Initialize an empty integration registry."""
        self._integrations: Dict[str, Type["Integration"]] = {}

    @property
    def integrations(self) -> Dict[str, Type["Integration"]]:
        """Method to get integrations dictionary.

        Returns:
            A dict of integration key to type of `Integration`.
        """
        return self._integrations

    @integrations.setter
    def integrations(self, i: Any) -> None:
        """Reject direct assignment of the integrations dictionary.

        Raises:
            IntegrationError: always; use `register_integration` instead.
        """
        raise IntegrationError(
            "Please do not manually change the integrations within the "
            "registry. If you would like to register a new integration "
            "manually, please use "
            "`integration_registry.register_integration()`."
        )

    def register_integration(
        self, key: str, type_: Type["Integration"]
    ) -> None:
        """Register an integration under the given key.

        Registering an existing key replaces the previous integration.
        """
        self._integrations[key] = type_

    def activate_integrations(self) -> None:
        """Activate every registered integration whose installation check passes."""
        for name, integration in self._integrations.items():
            if integration.check_installation():
                integration.activate()
                logger.debug(f"Integration `{name}` is activated.")
            else:
                logger.debug(f"Integration `{name}` could not be activated.")

    @property
    def list_integration_names(self) -> List[str]:
        """Get a list of all registered integration names."""
        return list(self._integrations)

    def select_integration_requirements(
        self, integration_name: Optional[str] = None
    ) -> List[str]:
        """Select the requirements for a given integration or all integrations.

        Args:
            integration_name: name of a registered integration. If falsy, the
                requirements of all registered integrations are returned.

        Returns:
            A new list of requirement strings. It is a copy, so mutating it
            does not affect the integration's `REQUIREMENTS`.

        Raises:
            KeyError: if `integration_name` is not registered.
        """
        if integration_name:
            if integration_name not in self._integrations:
                raise KeyError(
                    f"Integration {integration_name} does not exist. "
                    f"Currently the following integrations are implemented. "
                    f"{self.list_integration_names}"
                )
            # Copy so callers cannot mutate the integration's class attribute.
            return list(self._integrations[integration_name].REQUIREMENTS)
        return [
            requirement
            for integration in self._integrations.values()
            for requirement in integration.REQUIREMENTS
        ]

    def is_installed(self, integration_name: Optional[str] = None) -> bool:
        """Check if all requirements for an integration are installed.

        Args:
            integration_name: name of a registered integration. If falsy, all
                registered integrations are checked.

        Returns:
            True if the requested integration (or every integration) passes
            its installation check.

        Raises:
            KeyError: if `integration_name` is given but not registered.
        """
        if integration_name in self._integrations:
            return self._integrations[integration_name].check_installation()
        elif not integration_name:
            # Evaluate every check (no short-circuit), as before.
            all_installed = [
                integration.check_installation()
                for integration in self._integrations.values()
            ]
            return all(all_installed)
        else:
            raise KeyError(
                f"Integration '{integration_name}' not found. "
                f"Currently the following integrations are available: "
                f"{self.list_integration_names}"
            )
integrations: Dict[str, Type[Integration]]
property
writable
Method to get integrations dictionary.
Returns:
Type | Description |
---|---|
Dict[str, Type[Integration]] |
A dict of integration key to type of `Integration`. |
list_integration_names: List[str]
property
readonly
Get a list of all possible integrations
__init__(self)
special
Initializing the integration registry
Source code in zenml/integrations/registry.py
def __init__(self) -> None:
    """Create an empty integration registry."""
    # Maps integration name -> Integration subclass; filled by
    # register_integration().
    self._integrations: Dict[str, Type["Integration"]] = dict()
activate_integrations(self)
Method to activate the integrations which are registered in the registry
Source code in zenml/integrations/registry.py
def activate_integrations(self) -> None:
    """Activate every registered integration that passes its installation check.

    Integrations whose check fails are skipped. A debug message is logged
    for each integration, whether it was activated or skipped.
    """
    for name, integration in self._integrations.items():
        if not integration.check_installation():
            logger.debug(f"Integration `{name}` could not be activated.")
            continue
        integration.activate()
        logger.debug(f"Integration `{name}` is activated.")
is_installed(self, integration_name=None)
Checks if all requirements for an integration are installed
Source code in zenml/integrations/registry.py
def is_installed(self, integration_name: Optional[str] = None) -> bool:
    """Check whether one integration, or all of them, are installed.

    Args:
        integration_name: Name of the integration to check. If falsy, every
            registered integration is checked.

    Returns:
        The result of the named integration's installation check, or, with
        no name, whether every registered integration passes its check.

    Raises:
        KeyError: If a non-empty ``integration_name`` is not registered.
    """
    registered = self.list_integration_names
    if integration_name in registered:
        return self._integrations[integration_name].check_installation()
    if integration_name:
        raise KeyError(
            f"Integration '{integration_name}' not found. "
            f"Currently the following integrations are available: "
            f"{registered}"
        )
    # A list (not a generator) so every integration's check is run,
    # rather than stopping at the first failure.
    return all(
        [
            self._integrations[name].check_installation()
            for name in registered
        ]
    )
register_integration(self, key, type_)
Method to register an integration with a given name
Source code in zenml/integrations/registry.py
def register_integration(
    self, key: str, type_: Type["Integration"]
) -> None:
    """Store ``type_`` in the registry under the name ``key``.

    An integration already registered under the same name is replaced.

    Args:
        key: Name under which the integration is stored.
        type_: The integration class to register.
    """
    self._integrations.update({key: type_})
select_integration_requirements(self, integration_name=None)
Select the requirements for a given integration or all integrations
Source code in zenml/integrations/registry.py
def select_integration_requirements(
    self, integration_name: Optional[str] = None
) -> List[str]:
    """Select the requirements for a given integration or all integrations.

    Args:
        integration_name: Name of a registered integration. If falsy, the
            requirements of every registered integration are returned.

    Returns:
        A new list of requirement strings. Mutating it does not affect the
        integration's ``REQUIREMENTS``.

    Raises:
        KeyError: If a non-empty ``integration_name`` is not registered.
    """
    if integration_name:
        if integration_name not in self._integrations:
            # Same wording as is_installed(); the old text wrongly called
            # the integration a "Version".
            raise KeyError(
                f"Integration '{integration_name}' not found. "
                f"Currently the following integrations are available: "
                f"{self.list_integration_names}"
            )
        # Copy so callers cannot mutate the class-level list.
        return list(self._integrations[integration_name].REQUIREMENTS)
    return [
        requirement
        for name in self.list_integration_names
        for requirement in self._integrations[name].REQUIREMENTS
    ]
sagemaker
special
The Sagemaker integration submodule provides a way to run ZenML steps in Sagemaker.
SagemakerIntegration (Integration)
Definition of Sagemaker integration for ZenML.
Source code in zenml/integrations/sagemaker/__init__.py
class SagemakerIntegration(Integration):
    """ZenML integration definition for Amazon Sagemaker."""

    NAME = SAGEMAKER
    REQUIREMENTS = ["sagemaker==2.77.1"]

    @classmethod
    def activate(cls) -> None:
        """Activate the integration by importing its step operators."""
        # Imported only for its import-time side effects; the name is unused.
        from zenml.integrations.sagemaker import step_operators  # noqa
activate()
classmethod
Activates the integration.
Source code in zenml/integrations/sagemaker/__init__.py
@classmethod
def activate(cls) -> None:
    """Activate the Sagemaker integration by importing its step operators."""
    # Imported only for its import-time side effects; the name is unused.
    from zenml.integrations.sagemaker import step_operators  # noqa
step_operators
special
sagemaker_step_operator
SagemakerStepOperator (BaseStepOperator)
pydantic-model
Step operator to run a step on Sagemaker.
This class defines code that builds an image with the ZenML entrypoint to run using Sagemaker's Estimator.
Attributes:
Name | Type | Description |
---|---|---|
role |
str |
The role that has to be assigned to jobs running in Sagemaker. |
instance_type |
str |
The instance type of the compute where jobs will run. |
base_image |
Optional[str] |
[Optional] The base image to use for building the docker image that will be executed. |
bucket |
Optional[str] |
[Optional] Name of the S3 bucket to use for storing artifacts from the job run. If not provided, a default bucket will be created based on the following format: "sagemaker-{region}-{aws-account-id}". |
experiment_name |
Optional[str] |
[Optional] The name for the experiment to which the job will be associated. If not provided, the job runs would be independent. |
Source code in zenml/integrations/sagemaker/step_operators/sagemaker_step_operator.py
class SagemakerStepOperator(BaseStepOperator):
    """Step operator to run a step on Sagemaker.

    The step is packaged into a docker image that contains the ZenML
    entrypoint. That image is then run through a Sagemaker Estimator.

    Attributes:
        role: The role that has to be assigned to jobs running in Sagemaker.
        instance_type: The instance type of the compute where jobs will run.
        base_image: [Optional] The base image to use for building the docker
            image that will be executed.
        bucket: [Optional] Name of the S3 bucket to use for storing artifacts
            from the job run. If not provided, a default bucket will be
            created based on the following format:
            "sagemaker-{region}-{aws-account-id}".
        experiment_name: [Optional] The name for the experiment to which the
            job will be associated. If not provided, the job runs would be
            independent.
    """

    supports_local_execution = True
    supports_remote_execution = True

    role: str
    instance_type: str
    base_image: Optional[str] = None
    bucket: Optional[str] = None
    experiment_name: Optional[str] = None

    @property
    def flavor(self) -> StepOperatorFlavor:
        """The step operator flavor."""
        return StepOperatorFlavor.SAGEMAKER

    @property
    def validator(self) -> Optional[StackValidator]:
        """Validates that the stack contains a container registry and uses a
        local orchestrator."""

        def _orchestrator_is_local(stack: Stack) -> bool:
            return stack.orchestrator.flavor == OrchestratorFlavor.LOCAL

        return StackValidator(
            required_components={StackComponentType.CONTAINER_REGISTRY},
            custom_validation_function=_orchestrator_is_local,
        )

    def _build_docker_image(
        self,
        pipeline_name: str,
        requirements: List[str],
        entrypoint_command: List[str],
    ) -> str:
        """Build and push the step image to the active container registry.

        Returns:
            The image digest if one is available, otherwise the image name.

        Raises:
            RuntimeError: If the active stack has no container registry.
        """
        container_registry = Repository().active_stack.container_registry
        if not container_registry:
            raise RuntimeError("Missing container registry")

        registry_uri = container_registry.uri.rstrip("/")
        image_name = f"{registry_uri}/zenml-sagemaker:{pipeline_name}"

        docker_utils.build_docker_image(
            build_context_path=get_source_root_path(),
            image_name=image_name,
            entrypoint=" ".join(entrypoint_command),
            requirements=set(requirements),
            base_image=self.base_image,
        )
        docker_utils.push_docker_image(image_name)
        # Fall back to the tagged name when no digest is returned.
        return docker_utils.get_image_digest(image_name) or image_name

    def launch(
        self,
        pipeline_name: str,
        run_name: str,
        requirements: List[str],
        entrypoint_command: List[str],
    ) -> None:
        """Launches a step on Sagemaker.

        Args:
            pipeline_name: Name of the pipeline which the step to be executed
                is part of.
            run_name: Name of the pipeline run which the step to be executed
                is part of.
            requirements: List of pip requirements that must be installed
                inside the step operator environment.
            entrypoint_command: Command that executes the step.
        """
        image_name = self._build_docker_image(
            pipeline_name=pipeline_name,
            requirements=requirements,
            entrypoint_command=entrypoint_command,
        )
        estimator = sagemaker.estimator.Estimator(
            image_name,
            self.role,
            instance_count=1,
            instance_type=self.instance_type,
            sagemaker_session=sagemaker.Session(default_bucket=self.bucket),
        )

        # Sagemaker doesn't allow any underscores in job/experiment/trial names
        job_name = run_name.replace("_", "-")
        experiment_config = (
            {"ExperimentName": self.experiment_name, "TrialName": job_name}
            if self.experiment_name
            else {}
        )
        estimator.fit(
            wait=True,
            experiment_config=experiment_config,
            job_name=job_name,
        )
flavor: StepOperatorFlavor
property
readonly
The step operator flavor.
validator: Optional[zenml.stack.stack_validator.StackValidator]
property
readonly
Validates that the stack contains a container registry and uses a local orchestrator.
launch(self, pipeline_name, run_name, requirements, entrypoint_command)
Launches a step on Sagemaker.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
pipeline_name |
str |
Name of the pipeline which the step to be executed is part of. |
required |
run_name |
str |
Name of the pipeline run which the step to be executed is part of. |
required |
entrypoint_command |
List[str] |
Command that executes the step. |
required |
requirements |
List[str] |
List of pip requirements that must be installed inside the step operator environment. |
required |
Source code in zenml/integrations/sagemaker/step_operators/sagemaker_step_operator.py
def launch(
    self,
    pipeline_name: str,
    run_name: str,
    requirements: List[str],
    entrypoint_command: List[str],
) -> None:
    """Launches a step on Sagemaker.

    Args:
        pipeline_name: Name of the pipeline which the step to be executed
            is part of.
        run_name: Name of the pipeline run which the step to be executed
            is part of.
        requirements: List of pip requirements that must be installed
            inside the step operator environment.
        entrypoint_command: Command that executes the step.
    """
    image_name = self._build_docker_image(
        pipeline_name=pipeline_name,
        requirements=requirements,
        entrypoint_command=entrypoint_command,
    )
    estimator = sagemaker.estimator.Estimator(
        image_name,
        self.role,
        instance_count=1,
        instance_type=self.instance_type,
        sagemaker_session=sagemaker.Session(default_bucket=self.bucket),
    )

    # Sagemaker doesn't allow any underscores in job/experiment/trial names
    job_name = run_name.replace("_", "-")
    experiment_config = (
        {"ExperimentName": self.experiment_name, "TrialName": job_name}
        if self.experiment_name
        else {}
    )
    estimator.fit(
        wait=True,
        experiment_config=experiment_config,
        job_name=job_name,
    )
sklearn
special
SklearnIntegration (Integration)
Definition of sklearn integration for ZenML.
Source code in zenml/integrations/sklearn/__init__.py
class SklearnIntegration(Integration):
    """ZenML integration definition for scikit-learn."""

    NAME = SKLEARN
    REQUIREMENTS = ["scikit-learn"]

    @classmethod
    def activate(cls) -> None:
        """Activate the integration by importing its materializers."""
        # Imported only for its import-time side effects; the name is unused.
        from zenml.integrations.sklearn import materializers  # noqa
activate()
classmethod
Activates the integration.
Source code in zenml/integrations/sklearn/__init__.py
@classmethod
def activate(cls) -> None:
    """Activate the sklearn integration by importing its materializers."""
    # Imported only for its import-time side effects; the name is unused.
    from zenml.integrations.sklearn import materializers  # noqa
helpers
special
digits
get_digits()
Returns the digits dataset in the form of a tuple of numpy arrays.
Source code in zenml/integrations/sklearn/helpers/digits.py
def get_digits() -> Tuple[
    "NDArray[np.float64]",
    "NDArray[np.float64]",
    "NDArray[np.int64]",
    "NDArray[np.int64]",
]:
    """Load sklearn's digits dataset, split 50/50 into train and test.

    The split is not shuffled, so the first half of the samples is the train
    set and the second half is the test set.

    Returns:
        A tuple ``(X_train, X_test, y_train, y_test)``. Each image is
        flattened into a single feature row.
    """
    dataset = load_digits()
    # One row per image: collapse all image dimensions after the first.
    features = dataset.images.reshape((len(dataset.images), -1))
    x_train, x_test, y_train, y_test = train_test_split(
        features, dataset.target, test_size=0.5, shuffle=False
    )
    return x_train, x_test, y_train, y_test
get_digits_model()
Creates a support vector classifier for digits dataset.
Source code in zenml/integrations/sklearn/helpers/digits.py
def get_digits_model() -> ClassifierMixin:
    """Build an unfitted support vector classifier for the digits dataset."""
    classifier = SVC(gamma=0.001)
    return classifier
materializers
special
sklearn_materializer
SklearnMaterializer (BaseMaterializer)
Materializer to read data to and from sklearn.
Source code in zenml/integrations/sklearn/materializers/sklearn_materializer.py
class SklearnMaterializer(BaseMaterializer):
    """Materializer that pickles sklearn models to and from the artifact store."""

    ASSOCIATED_TYPES = (
        BaseEstimator,
        ClassifierMixin,
        ClusterMixin,
        BiclusterMixin,
        OutlierMixin,
        RegressorMixin,
        MetaEstimatorMixin,
        MultiOutputMixin,
        DensityMixin,
        TransformerMixin,
    )
    ASSOCIATED_ARTIFACT_TYPES = (ModelArtifact,)

    def handle_input(
        self, data_type: Type[Any]
    ) -> Union[
        BaseEstimator,
        ClassifierMixin,
        ClusterMixin,
        BiclusterMixin,
        OutlierMixin,
        RegressorMixin,
        MetaEstimatorMixin,
        MultiOutputMixin,
        DensityMixin,
        TransformerMixin,
    ]:
        """Load a pickled sklearn model from the artifact directory.

        Args:
            data_type: The type requested by the consuming step.

        Returns:
            The unpickled sklearn model.
        """
        super().handle_input(data_type)
        model_path = os.path.join(self.artifact.uri, DEFAULT_FILENAME)
        # NOTE: unpickling can execute arbitrary code. Only read artifacts
        # from a trusted artifact store.
        with fileio.open(model_path, "rb") as model_file:
            return pickle.load(model_file)

    def handle_return(
        self,
        clf: Union[
            BaseEstimator,
            ClassifierMixin,
            ClusterMixin,
            BiclusterMixin,
            OutlierMixin,
            RegressorMixin,
            MetaEstimatorMixin,
            MultiOutputMixin,
            DensityMixin,
            TransformerMixin,
        ],
    ) -> None:
        """Pickle an sklearn model into the artifact directory.

        Args:
            clf: The sklearn model to persist.
        """
        super().handle_return(clf)
        model_path = os.path.join(self.artifact.uri, DEFAULT_FILENAME)
        with fileio.open(model_path, "wb") as model_file:
            pickle.dump(clf, model_file)
handle_input(self, data_type)
Reads a base sklearn model from a pickle file.
Source code in zenml/integrations/sklearn/materializers/sklearn_materializer.py
def handle_input(
    self, data_type: Type[Any]
) -> Union[
    BaseEstimator,
    ClassifierMixin,
    ClusterMixin,
    BiclusterMixin,
    OutlierMixin,
    RegressorMixin,
    MetaEstimatorMixin,
    MultiOutputMixin,
    DensityMixin,
    TransformerMixin,
]:
    """Load a pickled sklearn model from the artifact directory.

    Args:
        data_type: The type requested by the consuming step.

    Returns:
        The unpickled sklearn model.
    """
    super().handle_input(data_type)
    model_path = os.path.join(self.artifact.uri, DEFAULT_FILENAME)
    # NOTE: unpickling can execute arbitrary code. Only read artifacts from a
    # trusted artifact store.
    with fileio.open(model_path, "rb") as model_file:
        return pickle.load(model_file)
handle_return(self, clf)
Creates a pickle for a sklearn model.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
clf |
Union[sklearn.base.BaseEstimator, sklearn.base.ClassifierMixin, sklearn.base.ClusterMixin, sklearn.base.BiclusterMixin, sklearn.base.OutlierMixin, sklearn.base.RegressorMixin, sklearn.base.MetaEstimatorMixin, sklearn.base.MultiOutputMixin, sklearn.base.DensityMixin, sklearn.base.TransformerMixin] |
A sklearn model. |
required |
Source code in zenml/integrations/sklearn/materializers/sklearn_materializer.py
def handle_return(
    self,
    clf: Union[
        BaseEstimator,
        ClassifierMixin,
        ClusterMixin,
        BiclusterMixin,
        OutlierMixin,
        RegressorMixin,
        MetaEstimatorMixin,
        MultiOutputMixin,
        DensityMixin,
        TransformerMixin,
    ],
) -> None:
    """Pickle an sklearn model into the artifact directory.

    Args:
        clf: The sklearn model to persist.
    """
    super().handle_return(clf)
    model_path = os.path.join(self.artifact.uri, DEFAULT_FILENAME)
    with fileio.open(model_path, "wb") as model_file:
        pickle.dump(clf, model_file)
steps
special
sklearn_evaluator
SklearnEvaluator (BaseEvaluatorStep)
A simple step implementation which utilizes sklearn to evaluate the performance of a given model on a given test dataset
Source code in zenml/integrations/sklearn/steps/sklearn_evaluator.py
class SklearnEvaluator(BaseEvaluatorStep):
    """Evaluator step that scores a model on a test dataset.

    The scoring uses sklearn's ``classification_report``.
    """

    def entrypoint(  # type: ignore[override]
        self,
        dataset: pd.DataFrame,
        model: tf.keras.Model,
        config: SklearnEvaluatorConfig,
    ) -> dict:  # type: ignore[type-arg]
        """Evaluate ``model`` on ``dataset`` as a binary classifier.

        The label column named by ``config.label_class_column`` is removed
        from ``dataset`` in place (via ``pop``) before predicting.

        Args:
            dataset: a pandas DataFrame which represents the test dataset
            model: a trained tensorflow Keras model
            config: the configuration for the step

        Returns:
            a dictionary which has the evaluation report
        """
        true_labels = dataset.pop(config.label_class_column)
        scores = model.predict(dataset)
        # Threshold each prediction at 0.5 to get a 0/1 class label.
        predicted_labels = [1 if score > 0.5 else 0 for score in scores]
        return classification_report(  # type: ignore[no-any-return]
            true_labels, predicted_labels, output_dict=True
        )
CONFIG_CLASS (BaseEvaluatorConfig)
pydantic-model
Config class for the sklearn evaluator
Source code in zenml/integrations/sklearn/steps/sklearn_evaluator.py
class SklearnEvaluatorConfig(BaseEvaluatorConfig):
    """Configuration for the sklearn evaluator step.

    Attributes:
        label_class_column: Name of the dataset column that holds the true
            labels.
    """

    label_class_column: str
entrypoint(self, dataset, model, config)
Method which is responsible for the computation of the evaluation
Parameters:
Name | Type | Description | Default |
---|---|---|---|
dataset |
DataFrame |
a pandas Dataframe which represents the test dataset |
required |
model |
Model |
a trained tensorflow Keras model |
required |
config |
SklearnEvaluatorConfig |
the configuration for the step |
required |
Returns:
Type | Description |
---|---|
dict |
a dictionary which has the evaluation report |
Source code in zenml/integrations/sklearn/steps/sklearn_evaluator.py
def entrypoint(  # type: ignore[override]
    self,
    dataset: pd.DataFrame,
    model: tf.keras.Model,
    config: SklearnEvaluatorConfig,
) -> dict:  # type: ignore[type-arg]
    """Evaluate ``model`` on ``dataset`` as a binary classifier.

    The label column named by ``config.label_class_column`` is removed from
    ``dataset`` in place (via ``pop``) before predicting.

    Args:
        dataset: a pandas DataFrame which represents the test dataset
        model: a trained tensorflow Keras model
        config: the configuration for the step

    Returns:
        a dictionary which has the evaluation report
    """
    true_labels = dataset.pop(config.label_class_column)
    scores = model.predict(dataset)
    # Threshold each prediction at 0.5 to get a 0/1 class label.
    predicted_labels = [1 if score > 0.5 else 0 for score in scores]
    return classification_report(  # type: ignore[no-any-return]
        true_labels, predicted_labels, output_dict=True
    )
SklearnEvaluatorConfig (BaseEvaluatorConfig)
pydantic-model
Config class for the sklearn evaluator
Source code in zenml/integrations/sklearn/steps/sklearn_evaluator.py
class SklearnEvaluatorConfig(BaseEvaluatorConfig):
    """Configuration for the sklearn evaluator step.

    Attributes:
        label_class_column: Name of the dataset column that holds the true
            labels.
    """

    label_class_column: str
sklearn_splitter
SklearnSplitter (BaseSplitStep)
A simple step implementation which utilizes sklearn to split a given dataset into train, test and validation splits
Source code in zenml/integrations/sklearn/steps/sklearn_splitter.py
class SklearnSplitter(BaseSplitStep):
    """A simple step implementation which utilizes sklearn to split a given
    dataset into train, test and validation splits"""

    def entrypoint(  # type: ignore[override]
        self,
        dataset: pd.DataFrame,
        config: SklearnSplitterConfig,
    ) -> Output(  # type:ignore[valid-type]
        train=pd.DataFrame, test=pd.DataFrame, validation=pd.DataFrame
    ):
        """Method which is responsible for the splitting logic.

        The test split is taken first. The validation split is then carved
        out of the remainder, scaled so that each split ends up with its
        configured fraction of the whole dataset.

        Args:
            dataset: a pandas DataFrame containing the entire dataset
            config: the configuration for the step. ``config.ratios`` must
                have exactly the keys ``train``, ``test`` and ``validation``,
                and its values must sum to 1.

        Returns:
            three dataframes representing the train, test and validation
            splits

        Raises:
            KeyError: if ``config.ratios`` does not have exactly those keys.
            ValueError: if the ratios do not sum to 1 (within floating-point
                tolerance).
        """
        import math

        if set(config.ratios) != {"train", "test", "validation"}:
            raise KeyError(
                f"Make sure that you only use 'train', 'test' and "
                f"'validation' as keys in the ratios dict. Current keys: "
                f"{config.ratios.keys()}"
            )
        # An exact `!= 1` check rejects valid ratios such as 0.7/0.2/0.1,
        # whose float sum is 0.9999999999999999.
        if not math.isclose(sum(config.ratios.values()), 1.0):
            raise ValueError(
                f"Make sure that the ratios sum up to 1. Current "
                f"ratios: {config.ratios}"
            )

        train_dataset, test_dataset = train_test_split(
            dataset, test_size=config.ratios["test"]
        )
        train_dataset, val_dataset = train_test_split(
            train_dataset,
            test_size=(
                config.ratios["validation"]
                / (config.ratios["validation"] + config.ratios["train"])
            ),
        )
        return train_dataset, test_dataset, val_dataset
CONFIG_CLASS (BaseSplitStepConfig)
pydantic-model
Config class for the sklearn splitter
Source code in zenml/integrations/sklearn/steps/sklearn_splitter.py
class SklearnSplitterConfig(BaseSplitStepConfig):
    """Configuration for the sklearn splitter step.

    Attributes:
        ratios: Fraction of the dataset assigned to each split. The splitter
            expects exactly the keys ``train``, ``test`` and ``validation``,
            with values summing to 1.
    """

    ratios: Dict[str, float]
entrypoint(self, dataset, config)
Method which is responsible for the splitting logic
Parameters:
Name | Type | Description | Default |
---|---|---|---|
dataset |
DataFrame |
a pandas DataFrame containing the entire dataset |
required |
config |
SklearnSplitterConfig |
the configuration for the step |
required |
Returns:
Type | Description |
---|---|
Output(train=DataFrame, test=DataFrame, validation=DataFrame) |
three dataframes representing the splits |
Source code in zenml/integrations/sklearn/steps/sklearn_splitter.py
def entrypoint(  # type: ignore[override]
    self,
    dataset: pd.DataFrame,
    config: SklearnSplitterConfig,
) -> Output(  # type:ignore[valid-type]
    train=pd.DataFrame, test=pd.DataFrame, validation=pd.DataFrame
):
    """Method which is responsible for the splitting logic.

    The test split is taken first. The validation split is then carved out
    of the remainder, scaled so that each split ends up with its configured
    fraction of the whole dataset.

    Args:
        dataset: a pandas DataFrame containing the entire dataset
        config: the configuration for the step. ``config.ratios`` must have
            exactly the keys ``train``, ``test`` and ``validation``, and its
            values must sum to 1.

    Returns:
        three dataframes representing the train, test and validation splits

    Raises:
        KeyError: if ``config.ratios`` does not have exactly those keys.
        ValueError: if the ratios do not sum to 1 (within floating-point
            tolerance).
    """
    import math

    if set(config.ratios) != {"train", "test", "validation"}:
        raise KeyError(
            f"Make sure that you only use 'train', 'test' and "
            f"'validation' as keys in the ratios dict. Current keys: "
            f"{config.ratios.keys()}"
        )
    # An exact `!= 1` check rejects valid ratios such as 0.7/0.2/0.1, whose
    # float sum is 0.9999999999999999.
    if not math.isclose(sum(config.ratios.values()), 1.0):
        raise ValueError(
            f"Make sure that the ratios sum up to 1. Current "
            f"ratios: {config.ratios}"
        )

    train_dataset, test_dataset = train_test_split(
        dataset, test_size=config.ratios["test"]
    )
    train_dataset, val_dataset = train_test_split(
        train_dataset,
        test_size=(
            config.ratios["validation"]
            / (config.ratios["validation"] + config.ratios["train"])
        ),
    )
    return train_dataset, test_dataset, val_dataset
SklearnSplitterConfig (BaseSplitStepConfig)
pydantic-model
Config class for the sklearn splitter
Source code in zenml/integrations/sklearn/steps/sklearn_splitter.py
class SklearnSplitterConfig(BaseSplitStepConfig):
    """Configuration for the sklearn splitter step.

    Attributes:
        ratios: Fraction of the dataset assigned to each split. The splitter
            expects exactly the keys ``train``, ``test`` and ``validation``,
            with values summing to 1.
    """

    ratios: Dict[str, float]
sklearn_standard_scaler
SklearnStandardScaler (BasePreprocessorStep)
Simple step implementation which utilizes the StandardScaler from sklearn to transform the numeric columns of a pd.DataFrame
Source code in zenml/integrations/sklearn/steps/sklearn_standard_scaler.py
class SklearnStandardScaler(BasePreprocessorStep):
    """Simple step implementation which utilizes the StandardScaler from sklearn
    to transform the numeric columns of a pd.DataFrame"""

    def entrypoint(  # type: ignore[override]
        self,
        train_dataset: pd.DataFrame,
        test_dataset: pd.DataFrame,
        validation_dataset: pd.DataFrame,
        statistics: pd.DataFrame,
        schema: pd.DataFrame,
        config: SklearnStandardScalerConfig,
    ) -> Output(  # type:ignore[valid-type]
        train_transformed=pd.DataFrame,
        test_transformed=pd.DataFrame,
        validation_transformed=pd.DataFrame,
    ):
        """Main entrypoint function for the StandardScaler.

        The scaler is not fitted here. Its ``mean_`` and ``scale_`` come from
        the ``"mean"`` and ``"std"`` entries of ``statistics``. A column is
        scaled only if it meets all of these conditions:

        - it is present in ``train_dataset``;
        - it is not in ``config.exclude_columns`` or
          ``config.ignore_columns``;
        - it is typed ``int64``/``float64`` in ``schema``.

        The three datasets are modified in place and returned.

        Args:
            train_dataset: pd.DataFrame, the training dataset
            test_dataset: pd.DataFrame, the test dataset
            validation_dataset: pd.DataFrame, the validation dataset
            statistics: pd.DataFrame, the statistics over the train dataset
            schema: pd.DataFrame, the detected schema of the dataset
            config: the configuration for the step

        Returns:
            the transformed train, test and validation datasets as
            pd.DataFrames
        """
        # NOTE(review): assumes `schema` stores each column's dtype name in
        # its row labelled 0. Confirm against the schema-producing step.
        schema_dict = {k: v[0] for k, v in schema.to_dict().items()}

        # Exclude columns
        feature_set = set(train_dataset.columns) - set(config.exclude_columns)
        for feature, feature_type in schema_dict.items():
            # Only drop columns that are still candidates. Before, a
            # non-numeric column that was already excluded, or that was
            # absent from the train dataset, made `set.remove` raise KeyError.
            if feature in feature_set and feature_type not in (
                "int64",
                "float64",
            ):
                feature_set.remove(feature)
                logger.warning(
                    f"{feature} column is not numeric, thus it is excluded "
                    f"from the standard scaling."
                )
        transform_feature_set = feature_set - set(config.ignore_columns)

        # pandas rejects a set as an indexer (deprecated in 1.5, an error in
        # 2.0), so select with a list. Following the train dataset's column
        # order keeps the selection deterministic.
        transform_columns = [
            column
            for column in train_dataset.columns
            if column in transform_feature_set
        ]
        if not transform_columns:
            # Nothing to scale; sklearn rejects inputs with zero features.
            return train_dataset, test_dataset, validation_dataset

        # Transform the datasets
        scaler = StandardScaler()
        # Plain arrays make the arithmetic in `transform` purely positional,
        # matching the order of `transform_columns`.
        scaler.mean_ = statistics["mean"][transform_columns].to_numpy()
        scaler.scale_ = statistics["std"][transform_columns].to_numpy()
        for dataset in (train_dataset, test_dataset, validation_dataset):
            dataset[transform_columns] = scaler.transform(
                dataset[transform_columns]
            )
        return train_dataset, test_dataset, validation_dataset
CONFIG_CLASS (BasePreprocessorConfig)
pydantic-model
Config class for the sklearn standard scaler
ignore_columns: a list of column names which should not be scaled.
exclude_columns: a list of column names to be excluded from the dataset.
Source code in zenml/integrations/sklearn/steps/sklearn_standard_scaler.py
class SklearnStandardScalerConfig(BasePreprocessorConfig):
    """Config class for the sklearn standard scaler.

    Attributes:
        ignore_columns: a list of column names which should not be scaled
        exclude_columns: a list of column names to be excluded from the
            dataset
    """

    # The `[]` defaults are copied per instance by pydantic (this is a
    # pydantic model), so they are not shared between configs.
    ignore_columns: List[str] = []
    exclude_columns: List[str] = []
entrypoint(self, train_dataset, test_dataset, validation_dataset, statistics, schema, config)
Main entrypoint function for the StandardScaler
Parameters:
Name | Type | Description | Default |
---|---|---|---|
train_dataset |
DataFrame |
pd.DataFrame, the training dataset |
required |
test_dataset |
DataFrame |
pd.DataFrame, the test dataset |
required |
validation_dataset |
DataFrame |
pd.DataFrame, the validation dataset |
required |
statistics |
DataFrame |
pd.DataFrame, the statistics over the train dataset |
required |
schema |
DataFrame |
pd.DataFrame, the detected schema of the dataset |
required |
config |
SklearnStandardScalerConfig |
the configuration for the step |
required |
Returns:
Type | Description |
---|---|
Output(train_transformed=DataFrame, test_transformed=DataFrame, validation_transformed=DataFrame) |
the transformed train, test and validation datasets as pd.DataFrames |
Source code in zenml/integrations/sklearn/steps/sklearn_standard_scaler.py
def entrypoint(  # type: ignore[override]
    self,
    train_dataset: pd.DataFrame,
    test_dataset: pd.DataFrame,
    validation_dataset: pd.DataFrame,
    statistics: pd.DataFrame,
    schema: pd.DataFrame,
    config: SklearnStandardScalerConfig,
) -> Output(  # type:ignore[valid-type]
    train_transformed=pd.DataFrame,
    test_transformed=pd.DataFrame,
    validation_transformed=pd.DataFrame,
):
    """Main entrypoint function for the StandardScaler.

    The scaler is not fitted here. Its ``mean_`` and ``scale_`` come from the
    ``"mean"`` and ``"std"`` entries of ``statistics``. A column is scaled
    only if it meets all of these conditions:

    - it is present in ``train_dataset``;
    - it is not in ``config.exclude_columns`` or ``config.ignore_columns``;
    - it is typed ``int64``/``float64`` in ``schema``.

    The three datasets are modified in place and returned.

    Args:
        train_dataset: pd.DataFrame, the training dataset
        test_dataset: pd.DataFrame, the test dataset
        validation_dataset: pd.DataFrame, the validation dataset
        statistics: pd.DataFrame, the statistics over the train dataset
        schema: pd.DataFrame, the detected schema of the dataset
        config: the configuration for the step

    Returns:
        the transformed train, test and validation datasets as
        pd.DataFrames
    """
    # NOTE(review): assumes `schema` stores each column's dtype name in its
    # row labelled 0. Confirm against the schema-producing step.
    schema_dict = {k: v[0] for k, v in schema.to_dict().items()}

    # Exclude columns
    feature_set = set(train_dataset.columns) - set(config.exclude_columns)
    for feature, feature_type in schema_dict.items():
        # Only drop columns that are still candidates. Before, a non-numeric
        # column that was already excluded, or that was absent from the train
        # dataset, made `set.remove` raise KeyError.
        if feature in feature_set and feature_type not in ("int64", "float64"):
            feature_set.remove(feature)
            logger.warning(
                f"{feature} column is not numeric, thus it is excluded "
                f"from the standard scaling."
            )
    transform_feature_set = feature_set - set(config.ignore_columns)

    # pandas rejects a set as an indexer (deprecated in 1.5, an error in 2.0),
    # so select with a list. Following the train dataset's column order keeps
    # the selection deterministic.
    transform_columns = [
        column
        for column in train_dataset.columns
        if column in transform_feature_set
    ]
    if not transform_columns:
        # Nothing to scale; sklearn rejects inputs with zero features.
        return train_dataset, test_dataset, validation_dataset

    # Transform the datasets
    scaler = StandardScaler()
    # Plain arrays make the arithmetic in `transform` purely positional,
    # matching the order of `transform_columns`.
    scaler.mean_ = statistics["mean"][transform_columns].to_numpy()
    scaler.scale_ = statistics["std"][transform_columns].to_numpy()
    for dataset in (train_dataset, test_dataset, validation_dataset):
        dataset[transform_columns] = scaler.transform(dataset[transform_columns])
    return train_dataset, test_dataset, validation_dataset
SklearnStandardScalerConfig (BasePreprocessorConfig)
pydantic-model
Config class for the sklearn standard scaler
ignore_columns: a list of column names which should not be scaled.
exclude_columns: a list of column names to be excluded from the dataset.
Source code in zenml/integrations/sklearn/steps/sklearn_standard_scaler.py
class SklearnStandardScalerConfig(BasePreprocessorConfig):
    """Config class for the sklearn standard scaler.

    Attributes:
        ignore_columns: a list of column names which should not be scaled
        exclude_columns: a list of column names to be excluded from the
            dataset
    """

    # The `[]` defaults are copied per instance by pydantic (this is a
    # pydantic model), so they are not shared between configs.
    ignore_columns: List[str] = []
    exclude_columns: List[str] = []
tensorflow
special
TensorflowIntegration (Integration)
Definition of Tensorflow integration for ZenML.
Source code in zenml/integrations/tensorflow/__init__.py
class TensorflowIntegration(Integration):
    """ZenML integration definition for TensorFlow."""

    NAME = TENSORFLOW
    REQUIREMENTS = ["tensorflow==2.8.0", "tensorflow_io==0.24.0"]

    @classmethod
    def activate(cls) -> None:
        """Activate the integration by importing its modules."""
        # tensorflow_io must be imported explicitly to load TensorFlow's file
        # IO support for S3 and other file systems.
        import tensorflow_io  # type: ignore [import]
        from zenml.integrations.tensorflow import materializers  # noqa
        from zenml.integrations.tensorflow import services  # noqa
activate()
classmethod
Activates the integration.
Source code in zenml/integrations/tensorflow/__init__.py
@classmethod
def activate(cls) -> None:
    """Activate the TensorFlow integration by importing its modules."""
    # tensorflow_io must be imported explicitly to load TensorFlow's file IO
    # support for S3 and other file systems.
    import tensorflow_io  # type: ignore [import]
    from zenml.integrations.tensorflow import materializers  # noqa
    from zenml.integrations.tensorflow import services  # noqa
materializers
special
keras_materializer
KerasMaterializer (BaseMaterializer)
Materializer to read/write Keras models.
Source code in zenml/integrations/tensorflow/materializers/keras_materializer.py
class KerasMaterializer(BaseMaterializer):
    """Materializer that saves and loads Keras models at the artifact URI."""

    ASSOCIATED_TYPES = (keras.Model,)
    ASSOCIATED_ARTIFACT_TYPES = (ModelArtifact,)

    def handle_input(self, data_type: Type[Any]) -> keras.Model:
        """Load a Keras model from the artifact URI.

        Args:
            data_type: The type requested by the consuming step.

        Returns:
            A tf.keras.Model model.
        """
        super().handle_input(data_type)
        model_uri = self.artifact.uri
        return keras.models.load_model(model_uri)

    def handle_return(self, model: keras.Model) -> None:
        """Save a Keras model to the artifact URI.

        Args:
            model: A tf.keras.Model model.
        """
        super().handle_return(model)
        model_uri = self.artifact.uri
        model.save(model_uri)
handle_input(self, data_type)
Reads and returns a Keras model.
Returns:
Type | Description |
---|---|
Model |
A tf.keras.Model model. |
Source code in zenml/integrations/tensorflow/materializers/keras_materializer.py
def handle_input(self, data_type: Type[Any]) -> keras.Model:
    """Load a Keras model from the artifact URI.

    Args:
        data_type: The type requested by the consuming step.

    Returns:
        A tf.keras.Model model.
    """
    super().handle_input(data_type)
    model_uri = self.artifact.uri
    return keras.models.load_model(model_uri)
handle_return(self, model)
Writes a keras model.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model |
Model |
A tf.keras.Model model. |
required |
Source code in zenml/integrations/tensorflow/materializers/keras_materializer.py
def handle_return(self, model: keras.Model) -> None:
    """Save a Keras model to the artifact URI.

    Args:
        model: A tf.keras.Model model.
    """
    super().handle_return(model)
    model_uri = self.artifact.uri
    model.save(model_uri)
tf_dataset_materializer
TensorflowDatasetMaterializer (BaseMaterializer)
Materializer to read data to and from tf.data.Dataset.
Source code in zenml/integrations/tensorflow/materializers/tf_dataset_materializer.py
class TensorflowDatasetMaterializer(BaseMaterializer):
    """Materializer that saves and loads ``tf.data.Dataset`` objects."""

    ASSOCIATED_TYPES = (tf.data.Dataset,)
    ASSOCIATED_ARTIFACT_TYPES = (DataArtifact,)

    def handle_input(self, data_type: Type[Any]) -> Any:
        """Load a ``tf.data.Dataset`` from ``DEFAULT_FILENAME`` in the artifact.

        Args:
            data_type: The type requested by the consuming step.

        Returns:
            The loaded dataset.
        """
        super().handle_input(data_type)
        dataset_path = os.path.join(self.artifact.uri, DEFAULT_FILENAME)
        return tf.data.experimental.load(dataset_path)

    def handle_return(self, dataset: tf.data.Dataset) -> None:
        """Save a ``tf.data.Dataset`` uncompressed, with no custom sharding.

        Args:
            dataset: The dataset to persist.
        """
        super().handle_return(dataset)
        dataset_path = os.path.join(self.artifact.uri, DEFAULT_FILENAME)
        tf.data.experimental.save(
            dataset, dataset_path, compression=None, shard_func=None
        )
handle_input(self, data_type)
Reads data into tf.data.Dataset
Source code in zenml/integrations/tensorflow/materializers/tf_dataset_materializer.py
def handle_input(self, data_type: Type[Any]) -> Any:
    """Reads data into tf.data.Dataset

    Args:
        data_type: The type of the object to read; forwarded to the base
            materializer.

    Returns:
        The dataset loaded with ``tf.data.experimental.load``.
    """
    super().handle_input(data_type)
    return tf.data.experimental.load(
        os.path.join(self.artifact.uri, DEFAULT_FILENAME)
    )
handle_return(self, dataset)
Persists a tf.data.Dataset object.
Source code in zenml/integrations/tensorflow/materializers/tf_dataset_materializer.py
def handle_return(self, dataset: tf.data.Dataset) -> None:
    """Persists a tf.data.Dataset object.

    Args:
        dataset: The dataset to save (with ``compression=None`` and
            ``shard_func=None``).
    """
    super().handle_return(dataset)
    dataset_path = os.path.join(self.artifact.uri, DEFAULT_FILENAME)
    tf.data.experimental.save(
        dataset,
        dataset_path,
        compression=None,
        shard_func=None,
    )
services
special
tensorboard_service
TensorboardService (LocalDaemonService)
pydantic-model
Tensorboard service that can be used to start a local Tensorboard server for one or more models.
Attributes:
Name | Type | Description |
---|---|---|
SERVICE_TYPE |
ClassVar[zenml.services.service_type.ServiceType] |
a service type descriptor with information describing the Tensorboard service class |
config |
TensorboardServiceConfig |
service configuration |
endpoint |
LocalDaemonServiceEndpoint |
optional service endpoint |
Source code in zenml/integrations/tensorflow/services/tensorboard_service.py
class TensorboardService(LocalDaemonService):
    """Local daemon service running a Tensorboard server for one or more
    models.

    Attributes:
        SERVICE_TYPE: descriptor with the information describing the
            Tensorboard service class
        config: the service configuration
        endpoint: optional endpoint of the service
    """

    SERVICE_TYPE = ServiceType(
        name="tensorboard",
        type="visualization",
        flavor="tensorboard",
        description="Tensorboard visualization service",
    )

    config: TensorboardServiceConfig
    endpoint: LocalDaemonServiceEndpoint

    def __init__(
        self,
        config: Union[TensorboardServiceConfig, Dict[str, Any]],
        **attrs: Any,
    ) -> None:
        # The endpoint has to exist before the base service is initialized.
        # It is only built here for a typed config when the caller did not
        # pass an endpoint of their own.
        # TODO [ENG-697]: implement a service factory or builder for Tensorboard
        #   deployment services
        needs_endpoint = (
            isinstance(config, TensorboardServiceConfig)
            and "endpoint" not in attrs
        )
        if needs_endpoint:
            endpoint_config = LocalDaemonServiceEndpointConfig(
                protocol=ServiceEndpointProtocol.HTTP,
            )
            health_monitor = HTTPEndpointHealthMonitor(
                config=HTTPEndpointHealthMonitorConfig(
                    healthcheck_uri_path="",
                    use_head_request=True,
                )
            )
            attrs["endpoint"] = LocalDaemonServiceEndpoint(
                config=endpoint_config,
                monitor=health_monitor,
            )
        super().__init__(config=config, **attrs)

    def run(self) -> None:
        """Run the Tensorboard server as a blocking call until CTRL+C."""
        logger.info(
            "Starting Tensorboard service as blocking "
            "process... press CTRL+C once to stop it."
        )
        self.endpoint.prepare_for_start()
        try:
            server = program.TensorBoard(
                plugins=default.get_plugins(),
                subcommands=[uploader_subcommand.UploaderSubcommand()],
            )
            server.configure(
                logdir=self.config.logdir,
                port=self.endpoint.status.port,
                host="localhost",
                max_reload_threads=self.config.max_reload_threads,
                reload_interval=self.config.reload_interval,
            )
            server.main()
        except KeyboardInterrupt:
            logger.info(
                "Tensorboard service stopped. Resuming normal execution."
            )
run(self)
Run the service daemon process associated with this service.
Subclasses must implement this method to provide the service daemon
functionality. This method will be executed in the context of the
running daemon, not in the context of the process that calls the `start` method.
Source code in zenml/integrations/tensorflow/services/tensorboard_service.py
def run(self) -> None:
    """Run the Tensorboard server as a blocking call until CTRL+C."""
    logger.info(
        "Starting Tensorboard service as blocking "
        "process... press CTRL+C once to stop it."
    )
    self.endpoint.prepare_for_start()
    try:
        server = program.TensorBoard(
            plugins=default.get_plugins(),
            subcommands=[uploader_subcommand.UploaderSubcommand()],
        )
        server.configure(
            logdir=self.config.logdir,
            port=self.endpoint.status.port,
            host="localhost",
            max_reload_threads=self.config.max_reload_threads,
            reload_interval=self.config.reload_interval,
        )
        server.main()
    except KeyboardInterrupt:
        logger.info(
            "Tensorboard service stopped. Resuming normal execution."
        )
TensorboardServiceConfig (LocalDaemonServiceConfig)
pydantic-model
Tensorboard service configuration.
Attributes:
Name | Type | Description |
---|---|---|
logdir |
str |
location of Tensorboard log files. |
max_reload_threads |
int |
the max number of threads that TensorBoard can use to reload runs. Each thread reloads one run at a time. |
reload_interval |
int |
how often the backend should load more data, in seconds. Set to 0 to load just once at startup. |
Source code in zenml/integrations/tensorflow/services/tensorboard_service.py
class TensorboardServiceConfig(LocalDaemonServiceConfig):
    """Configuration of the Tensorboard service.

    Attributes:
        logdir: location of the Tensorboard log files.
        max_reload_threads: maximum number of threads TensorBoard may use to
            reload runs; each thread reloads one run at a time.
        reload_interval: how often, in seconds, the backend loads more data.
            A value of 0 loads the data only once, at startup.
    """

    # Directory holding the Tensorboard logs to serve.
    logdir: str
    # Reload tuning values passed on to ``TensorBoard.configure``.
    max_reload_threads: int = 1
    reload_interval: int = 5
steps
special
tensorflow_trainer
TensorflowBinaryClassifier (BaseTrainerStep)
Simple step implementation which creates a simple TensorFlow feed-forward neural network and trains it on a given pd.DataFrame dataset.
Source code in zenml/integrations/tensorflow/steps/tensorflow_trainer.py
class TensorflowBinaryClassifier(BaseTrainerStep):
    """Simple step implementation which creates a simple tensorflow feedforward
    neural network and trains it on a given pd.DataFrame dataset.
    """

    def entrypoint(  # type: ignore[override]
        self,
        train_dataset: pd.DataFrame,
        validation_dataset: pd.DataFrame,
        config: TensorflowBinaryClassifierConfig,
    ) -> tf.keras.Model:
        """Main entrypoint for the tensorflow trainer.

        Builds a ``tf.keras.Sequential`` network, compiles it with Adam and
        binary cross-entropy, and fits it on the given data. The network has:

        - an input layer of shape ``config.input_shape``,
        - a ``Flatten`` layer,
        - one ReLU ``Dense`` layer per entry of ``config.layers`` except the
          last,
        - a sigmoid ``Dense`` output layer sized by the last entry.

        Neither ``config`` nor the input DataFrames are modified.

        Args:
            train_dataset: pd.DataFrame, the training dataset (features plus
                the ``config.target_column`` label column)
            validation_dataset: pd.DataFrame, the validation dataset (same
                layout as ``train_dataset``)
            config: the configuration of the step

        Returns:
            the trained tf.keras.Model

        Raises:
            IndexError: if ``config.layers`` is empty.
            KeyError: if ``config.target_column`` is missing from a dataset.
        """
        # Read the layer sizes without mutating ``config.layers``.
        # ``config.layers.pop()`` removed the output layer from the config,
        # so reusing the same config built a different (and eventually empty)
        # network.
        last_layer = config.layers[-1]
        hidden_layers = config.layers[:-1]

        model = tf.keras.Sequential()
        model.add(tf.keras.layers.InputLayer(input_shape=config.input_shape))
        model.add(tf.keras.layers.Flatten())
        for units in hidden_layers:
            model.add(tf.keras.layers.Dense(units, activation="relu"))
        model.add(tf.keras.layers.Dense(last_layer, activation="sigmoid"))

        model.compile(
            optimizer=tf.keras.optimizers.Adam(config.learning_rate),
            loss=tf.keras.losses.BinaryCrossentropy(),
            metrics=config.metrics,
        )

        # Split labels from features without ``DataFrame.pop``, which would
        # delete the label column from the caller's DataFrames in place.
        target = config.target_column
        train_target = train_dataset[target]
        train_features = train_dataset.drop(columns=[target])
        validation_target = validation_dataset[target]
        validation_features = validation_dataset.drop(columns=[target])

        model.fit(
            x=train_features,
            y=train_target,
            validation_data=(validation_features, validation_target),
            batch_size=config.batch_size,
            epochs=config.epochs,
        )
        model.summary()
        return model
CONFIG_CLASS (BaseTrainerConfig)
pydantic-model
Config class for the tensorflow trainer
target_column: the name of the label column; layers: the number of units in the fully connected layers; input_shape: the shape of the input; learning_rate: the learning rate; metrics: the list of metrics to be computed; epochs: the number of epochs; batch_size: the size of the batch.
Source code in zenml/integrations/tensorflow/steps/tensorflow_trainer.py
class TensorflowBinaryClassifierConfig(BaseTrainerConfig):
    """Config class for the tensorflow trainer.

    Attributes:
        target_column: the name of the label column
        layers: the number of units in the fully connected layers; the last
            entry is the size of the output layer
        input_shape: the shape of the input (any number of dimensions)
        learning_rate: the learning rate
        metrics: the list of metrics to be computed
        epochs: the number of epochs
        batch_size: the size of the batch
    """

    target_column: str
    layers: List[int] = [256, 64, 1]
    # ``Tuple[int, ...]`` (not ``Tuple[int]``) so that shapes of any rank
    # validate; ``Tuple[int]`` only admits one-element tuples, although the
    # trainer flattens its input and so supports multi-dimensional shapes.
    input_shape: Tuple[int, ...] = (8,)
    learning_rate: float = 0.001
    metrics: List[str] = ["accuracy"]
    epochs: int = 50
    batch_size: int = 8
entrypoint(self, train_dataset, validation_dataset, config)
Main entrypoint for the tensorflow trainer
Parameters:
Name | Type | Description | Default |
---|---|---|---|
train_dataset |
DataFrame |
pd.DataFrame, the training dataset |
required |
validation_dataset |
DataFrame |
pd.DataFrame, the validation dataset |
required |
config |
TensorflowBinaryClassifierConfig |
the configuration of the step |
required |
Returns:
Type | Description |
---|---|
Model |
the trained tf.keras.Model |
Source code in zenml/integrations/tensorflow/steps/tensorflow_trainer.py
def entrypoint(  # type: ignore[override]
    self,
    train_dataset: pd.DataFrame,
    validation_dataset: pd.DataFrame,
    config: TensorflowBinaryClassifierConfig,
) -> tf.keras.Model:
    """Main entrypoint for the tensorflow trainer.

    Builds a ``tf.keras.Sequential`` network, compiles it with Adam and
    binary cross-entropy, and fits it on the given data. The network has:

    - an input layer of shape ``config.input_shape``,
    - a ``Flatten`` layer,
    - one ReLU ``Dense`` layer per entry of ``config.layers`` except the
      last,
    - a sigmoid ``Dense`` output layer sized by the last entry.

    Neither ``config`` nor the input DataFrames are modified.

    Args:
        train_dataset: pd.DataFrame, the training dataset (features plus
            the ``config.target_column`` label column)
        validation_dataset: pd.DataFrame, the validation dataset (same
            layout as ``train_dataset``)
        config: the configuration of the step

    Returns:
        the trained tf.keras.Model

    Raises:
        IndexError: if ``config.layers`` is empty.
        KeyError: if ``config.target_column`` is missing from a dataset.
    """
    # Read the layer sizes without mutating ``config.layers``.
    # ``config.layers.pop()`` removed the output layer from the config, so
    # reusing the same config built a different (and eventually empty)
    # network.
    last_layer = config.layers[-1]
    hidden_layers = config.layers[:-1]

    model = tf.keras.Sequential()
    model.add(tf.keras.layers.InputLayer(input_shape=config.input_shape))
    model.add(tf.keras.layers.Flatten())
    for units in hidden_layers:
        model.add(tf.keras.layers.Dense(units, activation="relu"))
    model.add(tf.keras.layers.Dense(last_layer, activation="sigmoid"))

    model.compile(
        optimizer=tf.keras.optimizers.Adam(config.learning_rate),
        loss=tf.keras.losses.BinaryCrossentropy(),
        metrics=config.metrics,
    )

    # Split labels from features without ``DataFrame.pop``, which would
    # delete the label column from the caller's DataFrames in place.
    target = config.target_column
    train_target = train_dataset[target]
    train_features = train_dataset.drop(columns=[target])
    validation_target = validation_dataset[target]
    validation_features = validation_dataset.drop(columns=[target])

    model.fit(
        x=train_features,
        y=train_target,
        validation_data=(validation_features, validation_target),
        batch_size=config.batch_size,
        epochs=config.epochs,
    )
    model.summary()
    return model
TensorflowBinaryClassifierConfig (BaseTrainerConfig)
pydantic-model
Config class for the tensorflow trainer
target_column: the name of the label column; layers: the number of units in the fully connected layers; input_shape: the shape of the input; learning_rate: the learning rate; metrics: the list of metrics to be computed; epochs: the number of epochs; batch_size: the size of the batch.
Source code in zenml/integrations/tensorflow/steps/tensorflow_trainer.py
class TensorflowBinaryClassifierConfig(BaseTrainerConfig):
    """Config class for the tensorflow trainer.

    Attributes:
        target_column: the name of the label column
        layers: the number of units in the fully connected layers; the last
            entry is the size of the output layer
        input_shape: the shape of the input (any number of dimensions)
        learning_rate: the learning rate
        metrics: the list of metrics to be computed
        epochs: the number of epochs
        batch_size: the size of the batch
    """

    target_column: str
    layers: List[int] = [256, 64, 1]
    # ``Tuple[int, ...]`` (not ``Tuple[int]``) so that shapes of any rank
    # validate; ``Tuple[int]`` only admits one-element tuples, although the
    # trainer flattens its input and so supports multi-dimensional shapes.
    input_shape: Tuple[int, ...] = (8,)
    learning_rate: float = 0.001
    metrics: List[str] = ["accuracy"]
    epochs: int = 50
    batch_size: int = 8
visualizers
special
tensorboard_visualizer
TensorboardVisualizer (BaseStepVisualizer)
The implementation of a Tensorboard Visualizer.
Source code in zenml/integrations/tensorflow/visualizers/tensorboard_visualizer.py
class TensorboardVisualizer(BaseStepVisualizer):
    """The implementation of a Tensorboard Visualizer."""

    @classmethod
    def find_running_tensorboard_server(
        cls, logdir: str
    ) -> Optional[TensorBoardInfo]:
        """Find a local Tensorboard server instance running for the supplied
        logdir location.

        Args:
            logdir: The logdir location the Tensorboard server must be
                serving.

        Returns:
            The TensorBoardInfo describing the running Tensorboard server or
            None if no server is running for the supplied logdir location.
        """
        for server in get_all():
            # skip recorded servers whose process no longer exists
            if (
                server.logdir == logdir
                and server.pid
                and psutil.pid_exists(server.pid)
            ):
                return server
        return None

    def visualize(
        self,
        object: StepView,
        *args: Any,
        height: int = 800,
        **kwargs: Any,
    ) -> None:
        """Start a Tensorboard server to visualize all models logged as
        artifacts by the indicated step. The server will monitor and display
        all the models logged by past and future step runs.

        Only the first model artifact output of the step is handled.

        Args:
            object: StepView fetched from run.get_step().
            *args: Unused.
            height: Height of the generated visualization.
            **kwargs: Unused.
        """
        for _, artifact_view in object.outputs.items():
            # filter out anything but model artifacts
            if artifact_view.type == ModelArtifact.TYPE_NAME:
                logdir = os.path.dirname(artifact_view.uri)

                # first check if a Tensorboard server is already running for
                # the same logdir location and use that one
                running_server = self.find_running_tensorboard_server(logdir)
                if running_server:
                    self.visualize_tensorboard(running_server.port, height)
                    return

                if sys.platform == "win32":
                    # Daemon service functionality is currently not supported on Windows
                    print(
                        "You can run:\n"
                        f"[italic green] tensorboard --logdir {logdir}"
                        "[/italic green]\n"
                        "...to visualize the Tensorboard logs for your trained model."
                    )
                else:
                    # start a new Tensorboard server
                    service = TensorboardService(
                        TensorboardServiceConfig(
                            logdir=logdir,
                        )
                    )
                    service.start(timeout=20)
                    if service.endpoint.status.port:
                        self.visualize_tensorboard(
                            service.endpoint.status.port, height
                        )
                return

    def visualize_tensorboard(
        self,
        port: int,
        height: int,
    ) -> None:
        """Generate a visualization of a Tensorboard.

        Inside a notebook the server is shown with ``notebook.display``;
        otherwise the server URL is printed.

        Args:
            port: the TCP port where the Tensorboard server is listening for
                requests.
            height: Height of the generated visualization.
        """
        if Environment.in_notebook():
            notebook.display(port, height=height)
            return

        print(
            "You can visit:\n"
            f"[italic green] http://localhost:{port}/[/italic green]\n"
            "...to visualize the Tensorboard logs for your trained model."
        )

    def stop(
        self,
        object: StepView,
    ) -> None:
        """Stop the Tensorboard server previously started for a pipeline step.

        Only the first model artifact output of the step is handled. The
        method returns if no server runs for its logdir, or after killing
        that server. It moves on to the next model artifact only if the
        server's process cannot be looked up.

        Args:
            object: StepView fetched from run.get_step().
        """
        for _, artifact_view in object.outputs.items():
            # filter out anything but model artifacts
            if artifact_view.type == ModelArtifact.TYPE_NAME:
                logdir = os.path.dirname(artifact_view.uri)

                # look up the Tensorboard server serving this logdir, if any
                running_server = self.find_running_tensorboard_server(logdir)
                if not running_server:
                    return

                logger.debug(
                    "Stopping tensorboard server with PID '%d' ...",
                    running_server.pid,
                )
                try:
                    p = psutil.Process(running_server.pid)
                except psutil.Error:
                    logger.error(
                        "Could not find process for PID '%d' ...",
                        running_server.pid,
                    )
                    continue
                p.kill()
                return
find_running_tensorboard_server(logdir)
classmethod
Find a local Tensorboard server instance running for the supplied logdir location and return its TensorBoardInfo.
Returns:
Type | Description |
---|---|
Optional[tensorboard.manager.TensorBoardInfo] |
The TensorBoardInfo describing the running Tensorboard server or None if no server is running for the supplied logdir location. |
Source code in zenml/integrations/tensorflow/visualizers/tensorboard_visualizer.py
@classmethod
def find_running_tensorboard_server(
    cls, logdir: str
) -> Optional[TensorBoardInfo]:
    """Find a local Tensorboard server instance running for the supplied
    logdir location.

    Args:
        logdir: The logdir location the Tensorboard server must be serving.

    Returns:
        The TensorBoardInfo describing the running Tensorboard server or
        None if no server is running for the supplied logdir location.
    """
    for server in get_all():
        # skip recorded servers whose process no longer exists
        if (
            server.logdir == logdir
            and server.pid
            and psutil.pid_exists(server.pid)
        ):
            return server
    return None
stop(self, object)
Stop the Tensorboard server previously started for a pipeline step.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
object |
StepView |
StepView fetched from run.get_step(). |
required |
Source code in zenml/integrations/tensorflow/visualizers/tensorboard_visualizer.py
def stop(
    self,
    object: StepView,
) -> None:
    """Stop the Tensorboard server previously started for a pipeline step.

    Only the first model artifact output of the step is handled. The method
    returns if no server runs for its logdir, or after killing that server.
    It moves on to the next model artifact only if the server's process
    cannot be looked up.

    Args:
        object: StepView fetched from run.get_step().
    """
    for _, output in object.outputs.items():
        # only model artifacts can have a Tensorboard server attached
        if output.type != ModelArtifact.TYPE_NAME:
            continue

        server_info = self.find_running_tensorboard_server(
            os.path.dirname(output.uri)
        )
        if not server_info:
            return

        logger.debug(
            "Stopping tensorboard server with PID '%d' ...",
            server_info.pid,
        )
        try:
            process = psutil.Process(server_info.pid)
        except psutil.Error:
            logger.error(
                "Could not find process for PID '%d' ...",
                server_info.pid,
            )
            continue
        process.kill()
        return
visualize(self, object, *args, *, height=800, **kwargs)
Start a Tensorboard server to visualize all models logged as artifacts by the indicated step. The server will monitor and display all the models logged by past and future step runs.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
object |
StepView |
StepView fetched from run.get_step(). |
required |
height |
int |
Height of the generated visualization. |
800 |
Source code in zenml/integrations/tensorflow/visualizers/tensorboard_visualizer.py
def visualize(
    self,
    object: StepView,
    *args: Any,
    height: int = 800,
    **kwargs: Any,
) -> None:
    """Start a Tensorboard server to visualize all models logged as
    artifacts by the indicated step. The server will monitor and display
    all the models logged by past and future step runs.

    Only the first model artifact output of the step is handled.

    Args:
        object: StepView fetched from run.get_step().
        *args: Unused.
        height: Height of the generated visualization.
        **kwargs: Unused.
    """
    for _, output in object.outputs.items():
        # only model artifacts are relevant for Tensorboard
        if output.type != ModelArtifact.TYPE_NAME:
            continue
        logdir = os.path.dirname(output.uri)

        # reuse a server that is already serving this logdir, if any
        existing = self.find_running_tensorboard_server(logdir)
        if existing:
            self.visualize_tensorboard(existing.port, height)
            return

        if sys.platform != "win32":
            # start a new Tensorboard server
            service = TensorboardService(
                TensorboardServiceConfig(
                    logdir=logdir,
                )
            )
            service.start(timeout=20)
            port = service.endpoint.status.port
            if port:
                self.visualize_tensorboard(port, height)
        else:
            # Daemon service functionality is currently not supported on Windows
            print(
                "You can run:\n"
                f"[italic green] tensorboard --logdir {logdir}"
                "[/italic green]\n"
                "...to visualize the Tensorboard logs for your trained model."
            )
        return
visualize_tensorboard(self, port, height)
Generate a visualization of a Tensorboard.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
port |
int |
the TCP port where the Tensorboard server is listening for requests. |
required |
height |
int |
Height of the generated visualization. |
required |
logdir |
The logdir location for the Tensorboard server. |
required |
Source code in zenml/integrations/tensorflow/visualizers/tensorboard_visualizer.py
def visualize_tensorboard(
    self,
    port: int,
    height: int,
) -> None:
    """Generate a visualization of a Tensorboard.

    Inside a notebook the server is shown with ``notebook.display``;
    otherwise the server URL is printed.

    Args:
        port: the TCP port where the Tensorboard server is listening for
            requests.
        height: Height of the generated visualization.
    """
    if Environment.in_notebook():
        notebook.display(port, height=height)
        return

    print(
        "You can visit:\n"
        f"[italic green] http://localhost:{port}/[/italic green]\n"
        "...to visualize the Tensorboard logs for your trained model."
    )
get_step(pipeline_name, step_name)
Get the StepView for the specified pipeline and step name.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
pipeline_name |
str |
The name of the pipeline. |
required |
step_name |
str |
The name of the step. |
required |
Returns:
Type | Description |
---|---|
StepView |
The StepView for the specified pipeline and step name. |
Source code in zenml/integrations/tensorflow/visualizers/tensorboard_visualizer.py
def get_step(pipeline_name: str, step_name: str) -> StepView:
    """Get the StepView for the specified pipeline and step name.

    The step is taken from the last entry of the pipeline's runs.

    Args:
        pipeline_name: The name of the pipeline.
        step_name: The name of the step.

    Returns:
        The StepView for the specified pipeline and step name.

    Raises:
        RuntimeError: If the pipeline does not exist, has no runs, or its
            last run has no step with the given name.
    """
    repo = Repository()
    pipeline = repo.get_pipeline(pipeline_name)
    if pipeline is None:
        raise RuntimeError(f"No pipeline with name `{pipeline_name}` was found")

    runs = pipeline.runs
    # Without this guard an empty run list surfaces as a bare IndexError.
    if not runs:
        raise RuntimeError(f"No runs were found for pipeline `{pipeline_name}`")
    last_run = runs[-1]

    step = last_run.get_step(name=step_name)
    if step is None:
        raise RuntimeError(
            f"No pipeline step with name `{step_name}` was found in "
            f"pipeline `{pipeline_name}`"
        )
    return cast(StepView, step)
stop_tensorboard_server(pipeline_name, step_name)
Stop the Tensorboard server previously started for a pipeline step.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
pipeline_name |
str |
the name of the pipeline |
required |
step_name |
str |
pipeline step name |
required |
Source code in zenml/integrations/tensorflow/visualizers/tensorboard_visualizer.py
def stop_tensorboard_server(pipeline_name: str, step_name: str) -> None:
    """Stop the Tensorboard server previously started for a pipeline step.

    The step is resolved with ``get_step`` from the last run of the pipeline.

    Args:
        pipeline_name: the name of the pipeline
        step_name: pipeline step name
    """
    step_view = get_step(pipeline_name, step_name)
    visualizer = TensorboardVisualizer()
    visualizer.stop(step_view)
visualize_tensorboard(pipeline_name, step_name)
Start a Tensorboard server to visualize all models logged as output by the named pipeline step. The server will monitor and display all the models logged by past and future step runs.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
pipeline_name |
str |
the name of the pipeline |
required |
step_name |
str |
pipeline step name |
required |
Source code in zenml/integrations/tensorflow/visualizers/tensorboard_visualizer.py
def visualize_tensorboard(pipeline_name: str, step_name: str) -> None:
    """Start a Tensorboard server to visualize all models logged as output by
    the named pipeline step. The server will monitor and display all the models
    logged by past and future step runs.

    Args:
        pipeline_name: the name of the pipeline
        step_name: pipeline step name
    """
    step_view = get_step(pipeline_name, step_name)
    visualizer = TensorboardVisualizer()
    visualizer.visualize(step_view)
utils
get_integration_for_module(module_name)
Gets the integration class for a module inside an integration.
If the module given by `module_name` is not part of a ZenML integration,
this method will return `None`. If it is part of a ZenML integration,
it will return the integration class found inside the integration
`__init__` file.
Source code in zenml/integrations/utils.py
def get_integration_for_module(module_name: str) -> Optional[Type[Integration]]:
    """Gets the integration class for a module inside an integration.

    If the module given by `module_name` is not part of a ZenML integration,
    this method will return `None`. If it is part of a ZenML integration,
    it will return the integration class found inside the integration
    __init__ file.

    Args:
        module_name: Fully qualified name of the module.

    Returns:
        The first `Integration` subclass (other than `Integration` itself)
        found in the integration package, or `None`.
    """
    integration_prefix = "zenml.integrations."
    if not module_name.startswith(integration_prefix):
        return None

    # e.g. "zenml.integrations.airflow.orchestrators" -> "zenml.integrations.airflow"
    package_name = ".".join(module_name.split(".", 3)[:3])
    if package_name in sys.modules:
        integration_module = sys.modules[package_name]
    else:
        integration_module = importlib.import_module(package_name)

    for _, member in inspect.getmembers(integration_module):
        is_integration_subclass = (
            member is not Integration
            and isinstance(member, IntegrationMeta)
            and issubclass(member, Integration)
        )
        if is_integration_subclass:
            return cast(Type[Integration], member)
    return None
get_requirements_for_module(module_name)
Gets requirements for a module inside an integration.
If the module given by `module_name` is not part of a ZenML integration,
this method will return an empty list. If it is part of a ZenML integration,
it will return the list of requirements specified inside the integration
class found inside the integration `__init__` file.
Source code in zenml/integrations/utils.py
def get_requirements_for_module(module_name: str) -> List[str]:
    """Gets requirements for a module inside an integration.

    If the module given by `module_name` is not part of a ZenML integration,
    this method will return an empty list. If it is part of a ZenML integration,
    it will return the list of requirements specified inside the integration
    class found inside the integration __init__ file.

    Args:
        module_name: Fully qualified name of the module.

    Returns:
        A new list with the integration's requirements, or an empty list.
    """
    integration = get_integration_for_module(module_name)
    if integration is None:
        return []
    # Return a copy: handing out the class-level ``REQUIREMENTS`` list would
    # let callers that mutate the result alter the integration class itself.
    return list(integration.REQUIREMENTS)
vertex
special
The Vertex integration submodule provides a way to run ZenML pipelines in a Vertex AI environment.
VertexIntegration (Integration)
Definition of Vertex AI integration for ZenML.
Source code in zenml/integrations/vertex/__init__.py
class VertexIntegration(Integration):
    """Definition of Vertex AI integration for ZenML."""

    NAME = VERTEX
    REQUIREMENTS = ["google-cloud-aiplatform>=1.11.0"]

    @classmethod
    def activate(cls) -> None:
        """Activates the integration.

        Activation consists of importing the Vertex step operators module.
        """
        from zenml.integrations.vertex import step_operators  # noqa
activate()
classmethod
Activates the integration.
Source code in zenml/integrations/vertex/__init__.py
@classmethod
def activate(cls) -> None:
    """Activates the integration.

    Activation consists of importing the Vertex step operators module.
    """
    from zenml.integrations.vertex import step_operators  # noqa
orchestrator
special
vertex_ai_orchestrator
VertexOrchestrator (KubeflowOrchestrator)
pydantic-model
Orchestrator responsible for running pipelines on Vertex AI.
Source code in zenml/integrations/vertex/orchestrator/vertex_ai_orchestrator.py
class VertexOrchestrator(KubeflowOrchestrator):
    """Orchestrator responsible for running pipelines on Vertex AI.

    NOTE: pipeline execution is not implemented yet; ``run_pipeline``
    always raises ``NotImplementedError``.
    """

    supports_local_execution = False
    supports_remote_execution = True

    @property
    def flavor(self) -> OrchestratorFlavor:
        """The orchestrator flavor."""
        return OrchestratorFlavor.VERTEX

    def run_pipeline(
        self,
        pipeline: "BasePipeline",
        stack: "Stack",
        runtime_configuration: "RuntimeConfiguration",
    ) -> Any:
        """Runs a pipeline on Vertex AI using the Kubeflow orchestrator.

        Not implemented yet.

        Args:
            pipeline: The pipeline to run.
            stack: The stack on which the pipeline would run.
            runtime_configuration: Runtime configuration of the pipeline run.

        Raises:
            NotImplementedError: Always; Vertex AI orchestration is not
                available yet.
        """
        # TODO: planned implementation -- call the Kubeflow orchestrator's
        #   `super().run_pipeline(pipeline, stack, runtime_configuration)`,
        #   then `aiplatform.init(project=..., location=...)` and submit a
        #   `pipeline_jobs.PipelineJob(template_path=..., display_name=...)`.
        raise NotImplementedError("Vertex AI orchestration is coming soon!")
flavor: OrchestratorFlavor
property
readonly
The orchestrator flavor.
run_pipeline(self, pipeline, stack, runtime_configuration)
Runs a pipeline on Vertex AI using the Kubeflow orchestrator. Not implemented yet: currently always raises `NotImplementedError`.
Source code in zenml/integrations/vertex/orchestrator/vertex_ai_orchestrator.py
def run_pipeline(
    self,
    pipeline: "BasePipeline",
    stack: "Stack",
    runtime_configuration: "RuntimeConfiguration",
) -> Any:
    """Runs a pipeline on Vertex AI using the Kubeflow orchestrator.

    Not implemented yet.

    Args:
        pipeline: The pipeline to run.
        stack: The stack on which the pipeline would run.
        runtime_configuration: Runtime configuration of the pipeline run.

    Raises:
        NotImplementedError: Always; Vertex AI orchestration is not
            available yet.
    """
    # TODO: planned implementation -- call the Kubeflow orchestrator's
    #   `super().run_pipeline(pipeline, stack, runtime_configuration)`,
    #   then `aiplatform.init(project=..., location=...)` and submit a
    #   `pipeline_jobs.PipelineJob(template_path=..., display_name=...)`.
    raise NotImplementedError("Vertex AI orchestration is coming soon!")
step_operators
special
vertex_step_operator
Code heavily inspired by the TFX implementation: https://github.com/tensorflow/tfx/blob/master/tfx/extensions/google_cloud_ai_platform/training_clients.py
VertexStepOperator (BaseStepOperator)
pydantic-model
Step operator to run a step on Vertex AI.
This class defines code that can setup a Vertex AI environment and run the ZenML entrypoint command in it.
Attributes:
Name | Type | Description |
---|---|---|
region |
str |
Region name, e.g., `europe-west1`. |
project |
Optional[str] |
[Optional] GCP project name. If left None, inferred from the environment. |
accelerator_type |
Optional[str] |
[Optional] Accelerator type from list: https://cloud.google.com/vertex-ai/docs/reference/rest/v1/MachineSpec#AcceleratorType |
accelerator_count |
int |
[Optional] Defines number of accelerators to be used for the job. |
machine_type |
str |
[Optional] Machine type specified here: https://cloud.google.com/vertex-ai/docs/training/configure-compute#machine-types |
base_image |
Optional[str] |
[Optional] Base image for building the custom job container. |
encryption_spec_key_name |
Optional[str] |
[Optional]: Encryption spec key name. |
service_account_path |
Optional[str] |
[Optional]: Path to service account file specifying credentials of the GCP user. If not provided, falls back to Default Credentials. |
Source code in zenml/integrations/vertex/step_operators/vertex_step_operator.py
class VertexStepOperator(BaseStepOperator):
"""Step operator to run a step on Vertex AI.
This class defines code that can setup a Vertex AI environment and run the
ZenML entrypoint command in it.
Attributes:
region: Region name, e.g., `europe-west1`.
project: [Optional] GCP project name. If left None, inferred from the environment.
accelerator_type: [Optional] Accelerator type from list: https://cloud.google.com/vertex-ai/docs/reference/rest/v1/MachineSpec#AcceleratorType
accelerator_count: [Optional] Defines number of accelerators to be used for the job.
machine_type: [Optional] Machine type specified here: https://cloud.google.com/vertex-ai/docs/training/configure-compute#machine-types
base_image: [Optional] Base image for building the custom job container.
encryption_spec_key_name: [Optional]: Encryption spec key name.
service_account_path: [Optional]: Path to service account file specifiying credentials of the GCP user. If not provided, falls back
to Default Credentials.
"""
supports_local_execution = True
supports_remote_execution = True
region: str
project: Optional[str] = None
accelerator_type: Optional[str] = None
accelerator_count: int = 0
machine_type: str = "n1-standard-4"
base_image: Optional[str] = None
# customer managed encryption key resource name
# will be applied to all Vertex AI resources if set
encryption_spec_key_name: Optional[str] = None
# path to google service account
# environment default credentials used if not set
service_account_path: Optional[str] = None
@property
def flavor(self) -> StepOperatorFlavor:
"""The step operator flavor."""
return StepOperatorFlavor.VERTEX
@property
def validator(self) -> Optional[StackValidator]:
"""Validates that the stack contains a container registry."""
def _ensure_local_orchestrator(stack: Stack) -> bool:
# For now this only works on local orchestrator and GCP artifact
# store
return (stack.orchestrator.flavor == OrchestratorFlavor.LOCAL) and (
stack.artifact_store.flavor == ArtifactStoreFlavor.GCP
)
return StackValidator(
required_components={StackComponentType.CONTAINER_REGISTRY},
custom_validation_function=_ensure_local_orchestrator,
)
@property_validator("accelerator_type")
def validate_accelerator_enum(cls, accelerator_type: Optional[str]) -> None:
accepted_vals = list(
aiplatform.gapic.AcceleratorType.__members__.keys()
)
if accelerator_type and accelerator_type.upper() not in accepted_vals:
raise RuntimeError(
f"Accelerator must be one of the following: {accepted_vals}"
)
def _get_authentication(
self,
) -> Tuple[Optional[auth_credentials.Credentials], Optional[str]]:
if self.service_account_path:
credentials, project_id = load_credentials_from_file(
self.service_account_path
)
else:
credentials, project_id = default()
return credentials, project_id
def _build_and_push_docker_image(
self,
pipeline_name: str,
requirements: List[str],
entrypoint_command: List[str],
) -> str:
repo = Repository()
container_registry = repo.active_stack.container_registry
if not container_registry:
raise RuntimeError("Missing container registry")
registry_uri = container_registry.uri.rstrip("/")
image_name = f"{registry_uri}/zenml-vertex:{pipeline_name}"
docker_utils.build_docker_image(
build_context_path=get_source_root_path(),
image_name=image_name,
entrypoint=" ".join(entrypoint_command),
requirements=set(requirements),
base_image=self.base_image,
)
docker_utils.push_docker_image(image_name)
return docker_utils.get_image_digest(image_name) or image_name
def launch(
    self,
    pipeline_name: str,
    run_name: str,
    requirements: List[str],
    entrypoint_command: List[str],
) -> None:
    """Launches a step on Vertex AI.

    Builds and pushes a Docker image for the step, submits it as a Vertex AI
    custom job and blocks, polling, until the job reaches one of the
    `VERTEX_JOB_STATES_COMPLETED` states.

    If no `project` is configured on this operator, it is set to the project
    ID resolved during authentication.

    Args:
        pipeline_name: Name of the pipeline which the step to be executed
            is part of.
        run_name: Name of the pipeline run which the step to be executed
            is part of.
        requirements: List of pip requirements that must be installed
            inside the step operator environment.
        entrypoint_command: Command that executes the step.

    Raises:
        RuntimeError: If the job ends in one of the
            `VERTEX_JOB_STATES_FAILED` states.
        ConnectionError: If polling the job keeps failing after
            `CONNECTION_ERROR_RETRY_LIMIT` consecutive retries.
    """
    # Tag the job with the ZenML version that submitted it.
    job_labels = {"source": f"zenml-{__version__.replace('.', '_')}"}

    # Step 1: Authenticate with Google
    credentials, project_id = self._get_authentication()
    if self.project:
        if self.project != project_id:
            logger.warning(
                f"Authenticated with project {project_id}, but this "
                f"operator is configured to use project {self.project}."
            )
    else:
        self.project = project_id

    # Step 2: Build and push image
    image_name = self._build_and_push_docker_image(
        pipeline_name=pipeline_name,
        requirements=requirements,
        entrypoint_command=entrypoint_command,
    )

    # Step 3: Launch the job
    # The AI Platform services require regional API endpoints.
    client_options = {"api_endpoint": self.region + VERTEX_ENDPOINT_SUFFIX}
    # Initialize client that will be used to create and send requests.
    # This client only needs to be created once, and can be reused for
    # multiple requests.
    client = aiplatform.gapic.JobServiceClient(
        credentials=credentials, client_options=client_options
    )
    custom_job = {
        "display_name": run_name,
        "job_spec": {
            "worker_pool_specs": [
                {
                    "machine_spec": {
                        "machine_type": self.machine_type,
                        "accelerator_type": self.accelerator_type,
                        "accelerator_count": self.accelerator_count
                        if self.accelerator_type
                        else 0,
                    },
                    "replica_count": 1,
                    "container_spec": {
                        "image_uri": image_name,
                        "command": [],
                        "args": [],
                    },
                }
            ]
        },
        "labels": job_labels,
        "encryption_spec": {"kmsKeyName": self.encryption_spec_key_name}
        if self.encryption_spec_key_name
        else {},
    }
    logger.debug("Vertex AI Job=%s", custom_job)

    parent = f"projects/{self.project}/locations/{self.region}"
    logger.info(
        "Submitting custom job='%s', path='%s' to Vertex AI Training.",
        custom_job["display_name"],
        parent,
    )
    response = client.create_custom_job(
        parent=parent, custom_job=custom_job
    )
    # The message needs a %s placeholder. An extra argument without one
    # makes the logging module report a formatting error.
    logger.debug("Vertex AI response: %s", response)

    # Step 4: Monitor the job
    # Poll the job state every POLLING_INTERVAL_IN_SECONDS until it reaches
    # one of the VERTEX_JOB_STATES_COMPLETED states.
    #
    # If a ConnectionError occurs while polling, the GET request is retried
    # with a freshly created API client to refresh the connection. See
    # https://github.com/googleapis/google-api-python-client/issues/218
    # for a detailed description of the problem. After
    # CONNECTION_ERROR_RETRY_LIMIT consecutive failures the ConnectionError
    # is re-raised.
    retry_count = 0
    job_id = response.name
    while response.state not in VERTEX_JOB_STATES_COMPLETED:
        time.sleep(POLLING_INTERVAL_IN_SECONDS)
        try:
            response = client.get_custom_job(name=job_id)
            retry_count = 0
        # Handle transient connection error.
        except ConnectionError as err:
            if retry_count < CONNECTION_ERROR_RETRY_LIMIT:
                retry_count += 1
                logger.warning(
                    "ConnectionError (%s) encountered when polling job: %s. Trying to "
                    "recreate the API client.",
                    err,
                    job_id,
                )
                # Recreate the Python API client with the same credentials.
                # Otherwise a configured service account would be replaced
                # by default credentials.
                client = aiplatform.gapic.JobServiceClient(
                    credentials=credentials, client_options=client_options
                )
            else:
                logger.error(
                    "Request failed after %s retries.",
                    CONNECTION_ERROR_RETRY_LIMIT,
                )
                raise

    if response.state in VERTEX_JOB_STATES_FAILED:
        err_msg = (
            "Job '{}' did not succeed. Detailed response {}.".format(
                job_id, response
            )
        )
        logger.error(err_msg)
        raise RuntimeError(err_msg)

    # Cloud training complete
    logger.info("Job '%s' successful.", job_id)
flavor: StepOperatorFlavor
property
readonly
The step operator flavor.
validator: Optional[zenml.stack.stack_validator.StackValidator]
property
readonly
Validates that the stack contains a container registry and, for now, uses a local orchestrator together with a GCP artifact store.
launch(self, pipeline_name, run_name, requirements, entrypoint_command)
Launches a step on Vertex AI.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
pipeline_name |
str |
Name of the pipeline which the step to be executed is part of. |
required |
run_name |
str |
Name of the pipeline run which the step to be executed is part of. |
required |
entrypoint_command |
List[str] |
Command that executes the step. |
required |
requirements |
List[str] |
List of pip requirements that must be installed inside the step operator environment. |
required |
Source code in zenml/integrations/vertex/step_operators/vertex_step_operator.py
def launch(
    self,
    pipeline_name: str,
    run_name: str,
    requirements: List[str],
    entrypoint_command: List[str],
) -> None:
    """Launches a step on Vertex AI.

    Builds and pushes a Docker image for the step, submits it as a Vertex AI
    custom job and blocks, polling, until the job reaches one of the
    `VERTEX_JOB_STATES_COMPLETED` states.

    If no `project` is configured on this operator, it is set to the project
    ID resolved during authentication.

    Args:
        pipeline_name: Name of the pipeline which the step to be executed
            is part of.
        run_name: Name of the pipeline run which the step to be executed
            is part of.
        requirements: List of pip requirements that must be installed
            inside the step operator environment.
        entrypoint_command: Command that executes the step.

    Raises:
        RuntimeError: If the job ends in one of the
            `VERTEX_JOB_STATES_FAILED` states.
        ConnectionError: If polling the job keeps failing after
            `CONNECTION_ERROR_RETRY_LIMIT` consecutive retries.
    """
    # Tag the job with the ZenML version that submitted it.
    job_labels = {"source": f"zenml-{__version__.replace('.', '_')}"}

    # Step 1: Authenticate with Google
    credentials, project_id = self._get_authentication()
    if self.project:
        if self.project != project_id:
            logger.warning(
                f"Authenticated with project {project_id}, but this "
                f"operator is configured to use project {self.project}."
            )
    else:
        self.project = project_id

    # Step 2: Build and push image
    image_name = self._build_and_push_docker_image(
        pipeline_name=pipeline_name,
        requirements=requirements,
        entrypoint_command=entrypoint_command,
    )

    # Step 3: Launch the job
    # The AI Platform services require regional API endpoints.
    client_options = {"api_endpoint": self.region + VERTEX_ENDPOINT_SUFFIX}
    # Initialize client that will be used to create and send requests.
    # This client only needs to be created once, and can be reused for
    # multiple requests.
    client = aiplatform.gapic.JobServiceClient(
        credentials=credentials, client_options=client_options
    )
    custom_job = {
        "display_name": run_name,
        "job_spec": {
            "worker_pool_specs": [
                {
                    "machine_spec": {
                        "machine_type": self.machine_type,
                        "accelerator_type": self.accelerator_type,
                        "accelerator_count": self.accelerator_count
                        if self.accelerator_type
                        else 0,
                    },
                    "replica_count": 1,
                    "container_spec": {
                        "image_uri": image_name,
                        "command": [],
                        "args": [],
                    },
                }
            ]
        },
        "labels": job_labels,
        "encryption_spec": {"kmsKeyName": self.encryption_spec_key_name}
        if self.encryption_spec_key_name
        else {},
    }
    logger.debug("Vertex AI Job=%s", custom_job)

    parent = f"projects/{self.project}/locations/{self.region}"
    logger.info(
        "Submitting custom job='%s', path='%s' to Vertex AI Training.",
        custom_job["display_name"],
        parent,
    )
    response = client.create_custom_job(
        parent=parent, custom_job=custom_job
    )
    # The message needs a %s placeholder. An extra argument without one
    # makes the logging module report a formatting error.
    logger.debug("Vertex AI response: %s", response)

    # Step 4: Monitor the job
    # Poll the job state every POLLING_INTERVAL_IN_SECONDS until it reaches
    # one of the VERTEX_JOB_STATES_COMPLETED states.
    #
    # If a ConnectionError occurs while polling, the GET request is retried
    # with a freshly created API client to refresh the connection. See
    # https://github.com/googleapis/google-api-python-client/issues/218
    # for a detailed description of the problem. After
    # CONNECTION_ERROR_RETRY_LIMIT consecutive failures the ConnectionError
    # is re-raised.
    retry_count = 0
    job_id = response.name
    while response.state not in VERTEX_JOB_STATES_COMPLETED:
        time.sleep(POLLING_INTERVAL_IN_SECONDS)
        try:
            response = client.get_custom_job(name=job_id)
            retry_count = 0
        # Handle transient connection error.
        except ConnectionError as err:
            if retry_count < CONNECTION_ERROR_RETRY_LIMIT:
                retry_count += 1
                logger.warning(
                    "ConnectionError (%s) encountered when polling job: %s. Trying to "
                    "recreate the API client.",
                    err,
                    job_id,
                )
                # Recreate the Python API client with the same credentials.
                # Otherwise a configured service account would be replaced
                # by default credentials.
                client = aiplatform.gapic.JobServiceClient(
                    credentials=credentials, client_options=client_options
                )
            else:
                logger.error(
                    "Request failed after %s retries.",
                    CONNECTION_ERROR_RETRY_LIMIT,
                )
                raise

    if response.state in VERTEX_JOB_STATES_FAILED:
        err_msg = (
            "Job '{}' did not succeed. Detailed response {}.".format(
                job_id, response
            )
        )
        logger.error(err_msg)
        raise RuntimeError(err_msg)

    # Cloud training complete
    logger.info("Job '%s' successful.", job_id)
whylogs
special
WhylogsIntegration (Integration)
Definition of whylogs integration for ZenML.
Source code in zenml/integrations/whylogs/__init__.py
class WhylogsIntegration(Integration):
    """ZenML integration for [whylogs](https://github.com/whylabs/whylogs)."""

    NAME = WHYLOGS
    REQUIREMENTS = ["whylogs>=0.6.22", "pybars3>=0.9.7"]

    @classmethod
    def activate(cls) -> None:
        """Activates the integration by importing its submodules.

        The imported names are unused (hence `# noqa`). The imports are
        kept only for their side effects.
        """
        from zenml.integrations.whylogs import (  # noqa
            materializers,
            visualizers,
        )
activate()
classmethod
Activates the integration.
Source code in zenml/integrations/whylogs/__init__.py
@classmethod
def activate(cls) -> None:
    """Activates the integration by importing its submodules.

    The imported names are unused (hence `# noqa`). The imports are kept
    only for their side effects.
    """
    from zenml.integrations.whylogs import (  # noqa
        materializers,
        visualizers,
    )
materializers
special
whylogs_materializer
WhylogsMaterializer (BaseMaterializer)
Materializer to read/write whylogs dataset profiles.
Source code in zenml/integrations/whylogs/materializers/whylogs_materializer.py
class WhylogsMaterializer(BaseMaterializer):
    """Reads and writes whylogs `DatasetProfile` artifacts.

    A profile is stored as delimited protobuf in the file `PROFILE_FILENAME`
    inside the artifact URI.
    """

    ASSOCIATED_TYPES = (DatasetProfile,)
    ASSOCIATED_ARTIFACT_TYPES = (StatisticsArtifact,)

    def handle_input(self, data_type: Type[Any]) -> DatasetProfile:
        """Loads the whylogs profile stored in the artifact.

        Args:
            data_type: The type requested for the artifact data.

        Returns:
            The first `DatasetProfile` parsed from the stored file.
        """
        super().handle_input(data_type)
        profile_path = os.path.join(self.artifact.uri, PROFILE_FILENAME)
        with fileio.open(profile_path, "rb") as profile_file:
            raw_bytes = profile_file.read()
        return DatasetProfile.parse_delimited(raw_bytes)[0]

    def handle_return(self, profile: DatasetProfile) -> None:
        """Serializes a whylogs profile into the artifact.

        When the `WHYLABS_DEFAULT_ORG_ID` environment variable is set, the
        profile is additionally passed to `upload_profile`.

        Args:
            profile: A DatasetProfile object from whylogs.
        """
        super().handle_return(profile)
        profile_path = os.path.join(self.artifact.uri, PROFILE_FILENAME)
        serialized = profile.serialize_delimited()
        with fileio.open(profile_path, "wb") as profile_file:
            profile_file.write(serialized)

        # TODO [ENG-439]: uploading profiles to whylabs should be enabled and
        #  configurable at step level or pipeline level instead of being
        #  globally enabled.
        if not os.environ.get("WHYLABS_DEFAULT_ORG_ID"):
            return
        upload_profile(profile)
handle_input(self, data_type)
Reads and returns a whylogs DatasetProfile.
Returns:
Type | Description |
---|---|
DatasetProfile |
A loaded whylogs DatasetProfile. |
Source code in zenml/integrations/whylogs/materializers/whylogs_materializer.py
def handle_input(self, data_type: Type[Any]) -> DatasetProfile:
    """Loads the whylogs profile stored in the artifact.

    Args:
        data_type: The type requested for the artifact data.

    Returns:
        The first `DatasetProfile` parsed from the stored file.
    """
    super().handle_input(data_type)
    profile_path = os.path.join(self.artifact.uri, PROFILE_FILENAME)
    with fileio.open(profile_path, "rb") as profile_file:
        raw_bytes = profile_file.read()
    return DatasetProfile.parse_delimited(raw_bytes)[0]
handle_return(self, profile)
Writes a whylogs DatasetProfile.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
profile |
DatasetProfile |
A DatasetProfile object from whylogs. |
required |
Source code in zenml/integrations/whylogs/materializers/whylogs_materializer.py
def handle_return(self, profile: DatasetProfile) -> None:
    """Serializes a whylogs profile into the artifact.

    When the `WHYLABS_DEFAULT_ORG_ID` environment variable is set, the
    profile is additionally passed to `upload_profile`.

    Args:
        profile: A DatasetProfile object from whylogs.
    """
    super().handle_return(profile)
    profile_path = os.path.join(self.artifact.uri, PROFILE_FILENAME)
    serialized = profile.serialize_delimited()
    with fileio.open(profile_path, "wb") as profile_file:
        profile_file.write(serialized)

    # TODO [ENG-439]: uploading profiles to whylabs should be enabled and
    #  configurable at step level or pipeline level instead of being
    #  globally enabled.
    if not os.environ.get("WHYLABS_DEFAULT_ORG_ID"):
        return
    upload_profile(profile)
steps
special
whylogs_profiler
WhylogsProfilerConfig (BaseAnalyzerConfig)
pydantic-model
Config class for the WhylogsProfiler step.
Attributes:
Name | Type | Description |
---|---|---|
dataset_name |
Optional[str] |
the name of the dataset (Optional). If not specified, the pipeline step name is used |
dataset_timestamp |
Optional[datetime.datetime] |
timestamp to associate with the generated dataset profile (Optional). The current time is used if not supplied. |
tags |
Optional[Dict[str, str]] |
custom metadata tags associated with the whylogs profile |
Also see WhylogsContext.profile_dataframe.
Source code in zenml/integrations/whylogs/steps/whylogs_profiler.py
class WhylogsProfilerConfig(BaseAnalyzerConfig):
    """Config class for the WhylogsProfiler step.

    Attributes:
        dataset_name: the name of the dataset (Optional). If not specified,
            the pipeline step name is used
        dataset_timestamp: timestamp to associate with the generated
            dataset profile (Optional). The current time is used if not
            supplied.
        tags: custom metadata tags associated with the whylogs profile

    Also see `WhylogsContext.profile_dataframe`.
    """

    dataset_name: Optional[str] = None
    # Explicit `None` default, like the sibling fields. It matches pydantic's
    # implicit default for Optional fields and keeps the field optional.
    dataset_timestamp: Optional[datetime.datetime] = None
    tags: Optional[Dict[str, str]] = None
WhylogsProfilerStep (BaseAnalyzerStep)
Simple step implementation which generates a whylogs data profile from a given pd.DataFrame.
Source code in zenml/integrations/whylogs/steps/whylogs_profiler.py
class WhylogsProfilerStep(BaseAnalyzerStep):
    """Simple step implementation which generates a whylogs data profile from
    a given pd.DataFrame."""

    def entrypoint(  # type: ignore[override]
        self,
        dataset: pd.DataFrame,
        config: WhylogsProfilerConfig,
        context: StepContext,
    ) -> DatasetProfile:
        """Main entrypoint function for the whylogs profiler

        Args:
            dataset: pd.DataFrame, the given dataset
            config: the configuration of the step (dataset name, dataset
                timestamp and tags are all forwarded to the profiler)
            context: the context of the step

        Returns:
            whylogs profile with statistics generated for the input dataset
        """
        whylogs_context = WhylogsContext(context)
        # Forward every profiling option from the config, including the
        # dataset timestamp.
        profile = whylogs_context.profile_dataframe(
            dataset,
            dataset_name=config.dataset_name,
            dataset_timestamp=config.dataset_timestamp,
            tags=config.tags,
        )
        return profile
CONFIG_CLASS (BaseAnalyzerConfig)
pydantic-model
Config class for the WhylogsProfiler step.
Attributes:
Name | Type | Description |
---|---|---|
dataset_name |
Optional[str] |
the name of the dataset (Optional). If not specified, the pipeline step name is used |
dataset_timestamp |
Optional[datetime.datetime] |
timestamp to associate with the generated dataset profile (Optional). The current time is used if not supplied. |
tags |
Optional[Dict[str, str]] |
custom metadata tags associated with the whylogs profile |
Also see WhylogsContext.profile_dataframe.
Source code in zenml/integrations/whylogs/steps/whylogs_profiler.py
class WhylogsProfilerConfig(BaseAnalyzerConfig):
    """Config class for the WhylogsProfiler step.

    Attributes:
        dataset_name: the name of the dataset (Optional). If not specified,
            the pipeline step name is used
        dataset_timestamp: timestamp to associate with the generated
            dataset profile (Optional). The current time is used if not
            supplied.
        tags: custom metadata tags associated with the whylogs profile

    Also see `WhylogsContext.profile_dataframe`.
    """

    dataset_name: Optional[str] = None
    # Explicit `None` default, like the sibling fields. It matches pydantic's
    # implicit default for Optional fields and keeps the field optional.
    dataset_timestamp: Optional[datetime.datetime] = None
    tags: Optional[Dict[str, str]] = None
entrypoint(self, dataset, config, context)
Main entrypoint function for the whylogs profiler
Parameters:
Name | Type | Description | Default |
---|---|---|---|
dataset |
DataFrame |
pd.DataFrame, the given dataset |
required |
config |
WhylogsProfilerConfig |
the configuration of the step |
required |
context |
StepContext |
the context of the step |
required |
Returns:
Type | Description |
---|---|
DatasetProfile |
whylogs profile with statistics generated for the input dataset |
Source code in zenml/integrations/whylogs/steps/whylogs_profiler.py
def entrypoint(  # type: ignore[override]
    self,
    dataset: pd.DataFrame,
    config: WhylogsProfilerConfig,
    context: StepContext,
) -> DatasetProfile:
    """Main entrypoint function for the whylogs profiler

    Args:
        dataset: pd.DataFrame, the given dataset
        config: the configuration of the step (dataset name, dataset
            timestamp and tags are all forwarded to the profiler)
        context: the context of the step

    Returns:
        whylogs profile with statistics generated for the input dataset
    """
    whylogs_context = WhylogsContext(context)
    # Forward every profiling option from the config, including the
    # dataset timestamp.
    profile = whylogs_context.profile_dataframe(
        dataset,
        dataset_name=config.dataset_name,
        dataset_timestamp=config.dataset_timestamp,
        tags=config.tags,
    )
    return profile
whylogs_profiler_step(step_name, enable_cache=None, dataset_name=None, dataset_timestamp=None, tags=None)
Shortcut function to create a new instance of the WhylogsProfilerStep step.
The returned WhylogsProfilerStep can be used in a pipeline to generate a whylogs DatasetProfile from a given pd.DataFrame and save it as an artifact.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
step_name |
str |
The name of the step |
required |
enable_cache |
Optional[bool] |
Specify whether caching is enabled for this step. If no value is passed, caching is enabled by default |
None |
dataset_name |
Optional[str] |
the dataset name to be used for the whylogs profile (Optional). If not specified, the step name is used |
None |
dataset_timestamp |
Optional[datetime.datetime] |
timestamp to associate with the generated dataset profile (Optional). The current time is used if not supplied. |
None |
tags |
Optional[Dict[str, str]] |
custom metadata tags associated with the whylogs profile |
None |
Returns:
Type | Description |
---|---|
WhylogsProfilerStep |
a WhylogsProfilerStep step instance |
Source code in zenml/integrations/whylogs/steps/whylogs_profiler.py
def whylogs_profiler_step(
    step_name: str,
    enable_cache: Optional[bool] = None,
    dataset_name: Optional[str] = None,
    dataset_timestamp: Optional[datetime.datetime] = None,
    tags: Optional[Dict[str, str]] = None,
) -> WhylogsProfilerStep:
    """Creates a named, configured instance of the WhylogsProfilerStep.

    The returned step can be used in a pipeline to generate a whylogs
    DatasetProfile from a given pd.DataFrame and store it as an artifact.

    Args:
        step_name: The name of the step (also used as the name of the
            dynamically created step class)
        enable_cache: Whether caching is enabled for this step. `None`
            means caching is enabled.
        dataset_name: the dataset name to be used for the whylogs profile
            (Optional). If not specified, the step name is used
        dataset_timestamp: timestamp to associate with the generated
            dataset profile (Optional). The current time is used if not
            supplied.
        tags: custom metadata tags associated with the whylogs profile

    Returns:
        a WhylogsProfilerStep step instance
    """
    # This step takes a context object, so caching has to be switched on
    # explicitly when the caller left it unspecified.
    cache_enabled = True if enable_cache is None else enable_cache
    instance_configuration = {
        PARAM_ENABLE_CACHE: cache_enabled,
        PARAM_CREATED_BY_FUNCTIONAL_API: True,
    }
    step_class = type(
        step_name,
        (WhylogsProfilerStep,),
        {INSTANCE_CONFIGURATION: instance_configuration},
    )
    step_config = WhylogsProfilerConfig(
        dataset_name=dataset_name,
        dataset_timestamp=dataset_timestamp,
        tags=tags,
    )
    return cast(WhylogsProfilerStep, step_class(step_config))
visualizers
special
whylogs_visualizer
WhylogsPlots (StrEnum)
All supported whylogs plot types.
Source code in zenml/integrations/whylogs/visualizers/whylogs_visualizer.py
class WhylogsPlots(StrEnum):
    """All supported whylogs plot types.

    Each value is the name of the `ProfileVisualizer` method that draws the
    plot. `WhylogsVisualizer` looks it up with `getattr`. The member order
    is significant: `list(WhylogsPlots)` is the default plot order when no
    plots are requested.
    """

    DISTRIBUTION = "plot_distribution"
    MISSING_VALUES = "plot_missing_values"
    UNIQUENESS = "plot_uniqueness"
    DATA_TYPES = "plot_data_types"
    STRING_LENGTH = "plot_string_length"
    TOKEN_LENGTH = "plot_token_length"
    CHAR_POS = "plot_char_pos"
    STRING = "plot_string"
WhylogsVisualizer (BaseStepVisualizer)
The implementation of a Whylogs Visualizer.
Source code in zenml/integrations/whylogs/visualizers/whylogs_visualizer.py
class WhylogsVisualizer(BaseStepVisualizer):
    """The implementation of a Whylogs Visualizer."""

    def visualize(
        self,
        object: StepView,
        *args: Any,
        plots: Optional[List[WhylogsPlots]] = None,
        **kwargs: Any,
    ) -> None:
        """Visualize all whylogs dataset profiles present as outputs in the
        step view

        Args:
            object: StepView fetched from run.get_step().
            *args: Unused.
            plots: optional list of whylogs plots to visualize. Defaults to
                using all available plot types if not set
            **kwargs: Unused.
        """
        whylogs_artifact_datatype = (
            f"{DatasetProfile.__module__}.{DatasetProfile.__name__}"
        )
        for artifact_name, artifact_view in object.outputs.items():
            # filter out anything but whylog dataset profile artifacts
            if artifact_view.data_type == whylogs_artifact_datatype:
                profile = artifact_view.read()
                # whylogs doesn't currently support visualizing multiple
                # non-related profiles side-by-side, so we open them in
                # separate viewers for now
                self.visualize_profile(artifact_name, profile, plots)

    @staticmethod
    def _get_plot_method(
        visualizer: ProfileVisualizer, plot: WhylogsPlots
    ) -> Any:
        """Get the Whylogs ProfileVisualizer plot method corresponding to a
        WhylogsPlots enum value.

        Args:
            visualizer: a ProfileVisualizer instance
            plot: a WhylogsPlots enum value

        Raises:
            ValueError: if the supplied WhylogsPlots enum value does not
                correspond to a valid ProfileVisualizer plot method

        Returns:
            The ProfileVisualizer plot method corresponding to the input
            WhylogsPlots enum value
        """
        plot_method = getattr(visualizer, plot, None)
        if plot_method is None:
            nl = "\n"
            raise ValueError(
                f"Invalid whylogs plot type: {plot} \n\n"
                f"Valid and supported options are: {nl}- "
                f'{f"{nl}- ".join(WhylogsPlots.names())}'
            )
        return plot_method

    def visualize_profile(
        self,
        name: str,
        profile: DatasetProfile,
        plots: Optional[List[WhylogsPlots]] = None,
    ) -> None:
        """Generate a visualization of a whylogs dataset profile.

        In a notebook, every requested plot is displayed for every column
        of the profile. Otherwise the profile is rendered into a temporary
        HTML file by `profile_viewer`.

        Args:
            name: name identifying the profile if multiple profiles are
                displayed at the same time
            profile: whylogs DatasetProfile to visualize
            plots: optional list of whylogs plots to visualize. Defaults to
                using all available plot types if not set
        """
        if Environment.in_notebook():
            from IPython.core.display import display

            if not plots:
                # default to using all plots if none are supplied
                plots = list(WhylogsPlots)
            # One visualizer loaded with this profile serves every
            # column/plot combination; there is no need to rebuild it in
            # the loop.
            visualizer = ProfileVisualizer()
            visualizer.set_profiles([profile])
            for column in sorted(profile.columns):
                for plot in plots:
                    plot_method = self._get_plot_method(visualizer, plot)
                    display(plot_method(column))
        else:
            logger.warning(
                "The magic functions are only usable in a Jupyter notebook."
            )
            # Only reserve a unique file name here. The handle is closed
            # before `profile_viewer` writes to the path, because reopening
            # an open NamedTemporaryFile by name fails on Windows.
            with tempfile.NamedTemporaryFile(
                delete=False, suffix=f"-{name}.html"
            ) as f:
                output_path = f.name
            logger.info("Opening %s in a new browser..", output_path)
            profile_viewer([profile], output_path=output_path)
visualize(self, object, *args, *, plots=None, **kwargs)
Visualize all whylogs dataset profiles present as outputs in the step view
Parameters:
Name | Type | Description | Default |
---|---|---|---|
object |
StepView |
StepView fetched from run.get_step(). |
required |
plots |
Optional[List[zenml.integrations.whylogs.visualizers.whylogs_visualizer.WhylogsPlots]] |
optional list of whylogs plots to visualize. Defaults to using all available plot types if not set |
None |
Source code in zenml/integrations/whylogs/visualizers/whylogs_visualizer.py
def visualize(
    self,
    object: StepView,
    *args: Any,
    plots: Optional[List[WhylogsPlots]] = None,
    **kwargs: Any,
) -> None:
    """Visualizes every whylogs profile produced as an output of a step.

    Args:
        object: StepView fetched from run.get_step().
        *args: Unused.
        plots: optional list of whylogs plots to visualize. Defaults to
            using all available plot types if not set
        **kwargs: Unused.
    """
    profile_type_name = ".".join(
        (DatasetProfile.__module__, DatasetProfile.__name__)
    )
    # Only outputs whose data type is a whylogs DatasetProfile are shown.
    profile_outputs = (
        (output_name, output_view)
        for output_name, output_view in object.outputs.items()
        if output_view.data_type == profile_type_name
    )
    # whylogs can't show unrelated profiles side by side, so each profile
    # gets its own viewer.
    for output_name, output_view in profile_outputs:
        loaded_profile = output_view.read()
        self.visualize_profile(output_name, loaded_profile, plots)
visualize_profile(self, name, profile, plots=None)
Generate a visualization of a whylogs dataset profile.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
name |
str |
name identifying the profile if multiple profiles are displayed at the same time |
required |
profile |
DatasetProfile |
whylogs DatasetProfile to visualize |
required |
plots |
Optional[List[zenml.integrations.whylogs.visualizers.whylogs_visualizer.WhylogsPlots]] |
optional list of whylogs plots to visualize. Defaults to using all available plot types if not set |
None |
Source code in zenml/integrations/whylogs/visualizers/whylogs_visualizer.py
def visualize_profile(
    self,
    name: str,
    profile: DatasetProfile,
    plots: Optional[List[WhylogsPlots]] = None,
) -> None:
    """Generate a visualization of a whylogs dataset profile.

    In a notebook, every requested plot is displayed for every column of
    the profile. Otherwise the profile is rendered into a temporary HTML
    file by `profile_viewer`.

    Args:
        name: name identifying the profile if multiple profiles are
            displayed at the same time
        profile: whylogs DatasetProfile to visualize
        plots: optional list of whylogs plots to visualize. Defaults to
            using all available plot types if not set
    """
    if Environment.in_notebook():
        from IPython.core.display import display

        if not plots:
            # default to using all plots if none are supplied
            plots = list(WhylogsPlots)
        # One visualizer loaded with this profile serves every column/plot
        # combination; there is no need to rebuild it in the loop.
        visualizer = ProfileVisualizer()
        visualizer.set_profiles([profile])
        for column in sorted(profile.columns):
            for plot in plots:
                plot_method = self._get_plot_method(visualizer, plot)
                display(plot_method(column))
    else:
        logger.warning(
            "The magic functions are only usable in a Jupyter notebook."
        )
        # Only reserve a unique file name here. The handle is closed before
        # `profile_viewer` writes to the path, because reopening an open
        # NamedTemporaryFile by name fails on Windows.
        with tempfile.NamedTemporaryFile(
            delete=False, suffix=f"-{name}.html"
        ) as f:
            output_path = f.name
        logger.info("Opening %s in a new browser..", output_path)
        profile_viewer([profile], output_path=output_path)
whylogs_context
WhylogsContext
This is a step context extension that can be used to facilitate whylogs data logging and profiling inside a step function.
It acts as a wrapper built around the whylogs API that transparently incorporates ZenML specific information into the generated whylogs dataset profiles that can be used to associate whylogs profiles with the corresponding ZenML step run that produces them.
It also simplifies the whylogs profile generation process by abstracting away some of the whylogs specific details, such as whylogs session and logger initialization and management.
Source code in zenml/integrations/whylogs/whylogs_context.py
class WhylogsContext:
    """This is a step context extension that can be used to facilitate whylogs
    data logging and profiling inside a step function.

    It acts as a wrapper built around the whylogs API that transparently
    incorporates ZenML specific information into the generated whylogs dataset
    profiles that can be used to associate whylogs profiles with the
    corresponding ZenML step run that produces them.

    It also simplifies the whylogs profile generation process by abstracting
    away some of the whylogs specific details, such as whylogs session and
    logger initialization and management.
    """

    # Created lazily by `get_whylogs_session` and then reused.
    _session: Optional[Session] = None

    def __init__(
        self,
        step_context: StepContext,
        project: Optional[str] = None,
        pipeline: Optional[str] = None,
        tags: Optional[Dict[str, str]] = None,
    ) -> None:
        """Create a ZenML whylogs context based on a generic step context.

        Args:
            step_context: a StepContext instance that provides information
                about the currently running step, such as the step name
            project: optional project name to use for the whylogs session
            pipeline: optional pipeline name to use for the whylogs session
            tags: optional dictionary of tags to apply to all whylogs
                profiles generated through this context
        """
        self._step_context = step_context
        self._project = project
        self._pipeline = pipeline
        self._tags = tags

    def get_whylogs_session(
        self,
    ) -> Session:
        """Get the whylogs session associated with the current step.

        The session is created on first use and cached on this context.
        Project and pipeline names default to the step name.

        Returns:
            The whylogs Session instance associated with the current step
        """
        if self._session is not None:
            return self._session
        self._session = Session(
            project=self._project or self._step_context.step_name,
            pipeline=self._pipeline or self._step_context.step_name,
            # keeping the writers list empty, serialization is done in the
            # materializer
            writers=[],
        )
        return self._session

    def profile_dataframe(
        self,
        df: pd.DataFrame,
        dataset_name: Optional[str] = None,
        dataset_timestamp: Optional[datetime.datetime] = None,
        tags: Optional[Dict[str, str]] = None,
    ) -> DatasetProfile:
        """Generate whylogs statistics for a Pandas dataframe.

        Tags are merged in this order, with later sources winning: the
        context-level tags, then `zenml.step` (and `datasetId` unless
        already set), then the `tags` argument.

        Args:
            df: a Pandas dataframe to profile.
            dataset_name: the name of the dataset (Optional). If not specified,
                the pipeline step name is used
            dataset_timestamp: timestamp to associate with the generated
                dataset profile (Optional). The current time is used if not
                supplied.
            tags: custom metadata tags associated with the whylogs profile

        Returns:
            A whylogs DatasetProfile with the statistics generated from the
            input dataset.
        """
        session = self.get_whylogs_session()
        # TODO [ENG-437]: use a default whylogs dataset_name that is unique across
        #  multiple pipelines
        dataset_name = dataset_name or self._step_context.step_name
        # Copy so that the context-level tags are never mutated.
        final_tags = self._tags.copy() if self._tags else dict()
        # TODO [ENG-438]: add more zenml specific tags to the whylogs profile, such
        #  as the pipeline name and run ID
        final_tags["zenml.step"] = self._step_context.step_name
        # the datasetId tag is used to identify dataset profiles in whylabs.
        # dataset profiles with the same datasetID are considered to belong
        # to the same dataset/model.
        final_tags.setdefault("datasetId", dataset_name)
        if tags:
            final_tags.update(tags)
        # Named `profile_logger` so it does not shadow the module `logger`.
        profile_logger = session.logger(
            dataset_name, dataset_timestamp=dataset_timestamp, tags=final_tags
        )
        profile_logger.log_dataframe(df)
        return profile_logger.close()
__init__(self, step_context, project=None, pipeline=None, tags=None)
special
Create a ZenML whylogs context based on a generic step context.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
step_context |
StepContext |
a StepContext instance that provides information about the currently running step, such as the step name |
required |
project |
Optional[str] |
optional project name to use for the whylogs session |
None |
pipeline |
Optional[str] |
optional pipeline name to use for the whylogs session |
None |
tags |
Optional[Dict[str, str]] |
optional dictionary of tags to apply to all whylogs profiles generated through this context |
None |
Source code in zenml/integrations/whylogs/whylogs_context.py
def __init__(
    self,
    step_context: StepContext,
    project: Optional[str] = None,
    pipeline: Optional[str] = None,
    tags: Optional[Dict[str, str]] = None,
) -> None:
    """Binds a new whylogs context to the given ZenML step context.

    Args:
        step_context: Context of the currently running step. Its step name
            is used as a fallback for the whylogs project, pipeline and
            dataset names.
        project: optional project name to use for the whylogs session
        pipeline: optional pipeline name to use for the whylogs session
        tags: optional dictionary of tags applied to every whylogs profile
            generated through this context
    """
    self._step_context, self._project = step_context, project
    self._pipeline, self._tags = pipeline, tags
get_whylogs_session(self)
Get the whylogs session associated with the current step.
Returns:
Type | Description |
---|---|
Session |
The whylogs Session instance associated with the current step |
Source code in zenml/integrations/whylogs/whylogs_context.py
def get_whylogs_session(
    self,
) -> Session:
    """Return the whylogs session for the current step, creating it lazily.

    The first call builds a ``Session`` and caches it on the instance; every
    later call returns that same cached object. If no explicit project or
    pipeline name was given, the current step name is used instead.

    Returns:
        The whylogs Session instance associated with the current step
    """
    if self._session is None:
        self._session = Session(
            project=self._project or self._step_context.step_name,
            pipeline=self._pipeline or self._step_context.step_name,
            # no writers are configured here: serialization of profiles is
            # handled by the materializer instead
            writers=[],
        )
    return self._session
profile_dataframe(self, df, dataset_name=None, dataset_timestamp=None, tags=None)
Generate whylogs statistics for a Pandas dataframe.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
df |
DataFrame |
a Pandas dataframe to profile. |
required |
dataset_name |
Optional[str] |
the name of the dataset (Optional). If not specified, the pipeline step name is used |
None |
dataset_timestamp |
Optional[datetime.datetime] |
timestamp to associate with the generated dataset profile (Optional). The current time is used if not supplied. |
None |
tags |
Optional[Dict[str, str]] |
custom metadata tags associated with the whylogs profile |
None |
Returns:
Type | Description |
---|---|
DatasetProfile |
A whylogs DatasetProfile with the statistics generated from the input dataset. |
Source code in zenml/integrations/whylogs/whylogs_context.py
def profile_dataframe(
    self,
    df: pd.DataFrame,
    dataset_name: Optional[str] = None,
    dataset_timestamp: Optional[datetime.datetime] = None,
    tags: Optional[Dict[str, str]] = None,
) -> DatasetProfile:
    """Profile a Pandas dataframe with whylogs and return the result.

    Tags are merged in increasing priority. First come the context-wide
    tags. Then the ``zenml.step`` tag is set to the current step name. Then
    ``datasetId`` is filled in if not already present. Finally any per-call
    ``tags`` are applied and override everything before them.

    Args:
        df: a Pandas dataframe to profile.
        dataset_name: the name of the dataset (Optional). If not specified,
            the pipeline step name is used
        dataset_timestamp: timestamp to associate with the generated
            dataset profile (Optional). The current time is used if not
            supplied.
        tags: custom metadata tags associated with the whylogs profile

    Returns:
        A whylogs DatasetProfile with the statistics generated from the
        input dataset.
    """
    whylogs_session = self.get_whylogs_session()

    # TODO [ENG-437]: use a default whylogs dataset_name that is unique across
    # multiple pipelines
    profile_name = dataset_name or self._step_context.step_name

    # start from a private copy so the context-wide tags are never mutated
    merged_tags = dict(self._tags) if self._tags else {}
    # TODO [ENG-438]: add more zenml specific tags to the whylogs profile, such
    # as the pipeline name and run ID
    merged_tags["zenml.step"] = self._step_context.step_name
    # whylabs groups dataset profiles sharing the same datasetId tag as
    # belonging to the same dataset/model
    merged_tags.setdefault("datasetId", profile_name)
    merged_tags.update(tags or {})

    profile_logger = session_logger = whylogs_session.logger(
        profile_name, dataset_timestamp=dataset_timestamp, tags=merged_tags
    )
    del session_logger
    profile_logger.log_dataframe(df)
    return profile_logger.close()
whylogs_step_decorator
enable_whylogs(_step=None, *, project=None, pipeline=None, tags=None)
Decorator to enable whylogs profiling for a step function.
Apply this decorator to a ZenML pipeline step to enable whylogs profiling.
The decorated function will be given access to a StepContext whylogs
field that facilitates access to the whylogs dataset profiling API,
like so:
@enable_whylogs
@step(enable_cache=True)
def data_loader(
context: StepContext,
) -> Output(data=pd.DataFrame, profile=DatasetProfile,):
...
data = pd.DataFrame(...)
profile = context.whylogs.profile_dataframe(data, dataset_name="input_data")
...
return data, profile
Parameters:
Name | Type | Description | Default |
---|---|---|---|
_step |
Optional[~S] |
The decorated step class. |
None |
project |
Optional[str] |
optional project name to use for the whylogs session |
None |
pipeline |
Optional[str] |
optional pipeline name to use for the whylogs session |
None |
tags |
Optional[Dict[str, str]] |
optional dictionary of tags to apply to all profiles generated by this step |
None |
Returns:
Type | Description |
---|---|
Union[~S, Callable[[~S], ~S]] |
the inner decorator which enhances the input step class with whylogs profiling functionality |
Source code in zenml/integrations/whylogs/whylogs_step_decorator.py
def enable_whylogs(
    _step: Optional[S] = None,
    *,
    project: Optional[str] = None,
    pipeline: Optional[str] = None,
    tags: Optional[Dict[str, str]] = None,
) -> Union[S, Callable[[S], S]]:
    """Decorator to enable whylogs profiling for a step function.

    Apply this decorator to a ZenML pipeline step to enable whylogs profiling.
    The decorated function will be given access to a StepContext `whylogs`
    field that facilitates access to the whylogs dataset profiling API,
    like so:

    ```python
    @enable_whylogs
    @step(enable_cache=True)
    def data_loader(
        context: StepContext,
    ) -> Output(data=pd.DataFrame, profile=DatasetProfile,):
        ...
        data = pd.DataFrame(...)
        profile = context.whylogs.profile_dataframe(data, dataset_name="input_data")
        ...
        return data, profile
    ```

    The decorator can be used either bare (``@enable_whylogs``) or with
    keyword arguments (``@enable_whylogs(project=...)``).

    Args:
        _step: The decorated step class.
        project: optional project name to use for the whylogs session
        pipeline: optional pipeline name to use for the whylogs session
        tags: optional dictionary of tags to apply to all profiles generated
            by this step

    Returns:
        the inner decorator which enhances the input step class with whylogs
        profiling functionality
    """

    def inner_decorator(_step: S) -> S:
        """Return a subclass of `_step` whose entrypoint is wrapped by whylogs."""
        source_fn = _step.entrypoint
        # Build a new class with the same name and module as the original
        # step class, overriding only the STEP_INNER_FUNC_NAME attribute with
        # the whylogs-wrapped entrypoint (presumably the attribute the step
        # machinery calls at run time; confirm against STEP_INNER_FUNC_NAME).
        return cast(
            S,
            type(  # noqa
                _step.__name__,
                (_step,),
                {
                    STEP_INNER_FUNC_NAME: staticmethod(
                        whylogs_entrypoint(project, pipeline, tags)(source_fn)
                    ),
                    "__module__": _step.__module__,
                },
            ),
        )

    # Called with arguments, e.g. `@enable_whylogs(project=...)`: no step class
    # yet, so hand back the decorator to be applied next.
    if _step is None:
        return inner_decorator
    else:
        return inner_decorator(_step)
whylogs_entrypoint(project=None, pipeline=None, tags=None)
Decorator for a step entrypoint to enable whylogs.
Apply this decorator to a ZenML pipeline step to enable whylogs profiling.
The decorated function will be given access to a StepContext whylogs
field that facilitates access to the whylogs dataset profiling API,
like so:
Example (Python):
@step(enable_cache=True)
@whylogs_entrypoint()
def data_loader(
context: StepContext,
) -> Output(data=pd.DataFrame, profile=DatasetProfile,):
...
data = pd.DataFrame(...)
profile = context.whylogs.profile_dataframe(data, dataset_name="input_data")
...
return data, profile
Parameters:
Name | Type | Description | Default |
---|---|---|---|
project |
Optional[str] |
optional project name to use for the whylogs session |
None |
pipeline |
Optional[str] |
optional pipeline name to use for the whylogs session |
None |
tags |
Optional[Dict[str, str]] |
optional dictionary of tags to apply to all profiles generated by this step |
None |
Returns:
Type | Description |
---|---|
Callable[[~F], ~F] |
the input function enhanced with whylogs profiling functionality |
Source code in zenml/integrations/whylogs/whylogs_step_decorator.py
def whylogs_entrypoint(
    project: Optional[str] = None,
    pipeline: Optional[str] = None,
    tags: Optional[Dict[str, str]] = None,
) -> Callable[[F], F]:
    """Decorator for a step entrypoint to enable whylogs.

    Apply this decorator to a ZenML pipeline step to enable whylogs profiling.
    The decorated function will be given access to a StepContext `whylogs`
    field that facilitates access to the whylogs dataset profiling API,
    like so:

    ```python
    @step(enable_cache=True)
    @whylogs_entrypoint()
    def data_loader(
        context: StepContext,
    ) -> Output(data=pd.DataFrame, profile=DatasetProfile,):
        ...
        data = pd.DataFrame(...)
        profile = context.whylogs.profile_dataframe(data, dataset_name="input_data")
        ...
        return data, profile
    ```

    Args:
        project: optional project name to use for the whylogs session
        pipeline: optional pipeline name to use for the whylogs session
        tags: optional dictionary of tags to apply to all profiles generated
            by this step

    Returns:
        the input function enhanced with whylogs profiling functionality
    """

    def inner_decorator(func: F) -> F:
        """Wrap `func` so its StepContext argument gains a `whylogs` field."""

        @functools.wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:  # noqa
            for arg in args + tuple(kwargs.values()):
                if isinstance(arg, StepContext):
                    # Write straight into the instance __dict__, bypassing
                    # __setattr__ (presumably because StepContext restricts
                    # normal attribute assignment; confirm).
                    arg.__dict__["whylogs"] = WhylogsContext(
                        arg, project, pipeline, tags
                    )
                    # only the first StepContext argument found is extended
                    break
            return func(*args, **kwargs)

        return cast(F, wrapper)

    return inner_decorator