Integrations
zenml.integrations
special
ZenML integrations module.
The ZenML integrations module contains sub-modules for each integration that we
support. This includes orchestrators like Apache Airflow, visualization tools
like the facets
library, as well as deep learning libraries like PyTorch.
airflow
special
Airflow integration for ZenML.
The Airflow integration sub-module powers an alternative to the local
orchestrator. You can enable it by registering the Airflow orchestrator with
the CLI tool, and then bootstrap it using the zenml stack up
command.
AirflowIntegration (Integration)
Definition of Airflow Integration for ZenML.
Source code in zenml/integrations/airflow/__init__.py
class AirflowIntegration(Integration):
    """Airflow integration definition for ZenML.

    Declares the integration name, its pinned pip requirement and the stack
    component flavors it contributes.
    """

    NAME = AIRFLOW
    REQUIREMENTS = ["apache-airflow==2.2.0"]

    @classmethod
    def flavors(cls) -> List[FlavorWrapper]:
        """List the stack component flavors provided by this integration.

        Returns:
            A single-element list holding the Airflow orchestrator flavor.
        """
        orchestrator_flavor = FlavorWrapper(
            name=AIRFLOW_ORCHESTRATOR_FLAVOR,
            source="zenml.integrations.airflow.orchestrators.AirflowOrchestrator",
            type=StackComponentType.ORCHESTRATOR,
            integration=cls.NAME,
        )
        return [orchestrator_flavor]
flavors()
classmethod
Declare the stack component flavors for the Airflow integration.
Returns:
Type | Description |
---|---|
List[zenml.zen_stores.models.flavor_wrapper.FlavorWrapper] |
List of stack component flavors for this integration. |
Source code in zenml/integrations/airflow/__init__.py
@classmethod
def flavors(cls) -> List[FlavorWrapper]:
    """List the stack component flavors provided by the Airflow integration.

    Returns:
        A single-element list holding the Airflow orchestrator flavor.
    """
    orchestrator_flavor = FlavorWrapper(
        name=AIRFLOW_ORCHESTRATOR_FLAVOR,
        source="zenml.integrations.airflow.orchestrators.AirflowOrchestrator",
        type=StackComponentType.ORCHESTRATOR,
        integration=cls.NAME,
    )
    return [orchestrator_flavor]
orchestrators
special
The Airflow integration enables the use of Airflow as a pipeline orchestrator.
airflow_orchestrator
Implementation of Airflow orchestrator integration.
AirflowOrchestrator (BaseOrchestrator)
pydantic-model
Orchestrator responsible for running pipelines using Airflow.
Source code in zenml/integrations/airflow/orchestrators/airflow_orchestrator.py
class AirflowOrchestrator(BaseOrchestrator):
    """Orchestrator responsible for running pipelines using Airflow."""

    # Overwritten by the `set_airflow_home` root validator with a directory
    # derived from the orchestrator UUID.
    airflow_home: str = ""

    # Class Configuration
    FLAVOR: ClassVar[str] = AIRFLOW_ORCHESTRATOR_FLAVOR

    def __init__(self, **values: Any):
        """Sets environment variables to configure airflow.

        Args:
            **values: Values to set in the orchestrator.
        """
        super().__init__(**values)
        self._set_env()

    @staticmethod
    def _translate_schedule(
        schedule: Optional[Schedule] = None,
    ) -> Dict[str, Any]:
        """Convert ZenML schedule into Airflow schedule.

        The Airflow schedule uses slightly different naming and needs some
        default entries for execution without a schedule.

        Args:
            schedule: Containing the interval, start and end date and
                a boolean flag that defines if past runs should be caught up
                on

        Returns:
            Airflow configuration dict.
        """
        if schedule:
            if schedule.cron_expression:
                return {
                    "schedule_interval": schedule.cron_expression,
                }
            else:
                return {
                    "schedule_interval": schedule.interval_second,
                    "start_date": schedule.start_time,
                    "end_date": schedule.end_time,
                    "catchup": schedule.catchup,
                }

        return {
            "schedule_interval": "@once",
            # set a start time in the past and disable catchup so airflow
            # runs the dag immediately
            "start_date": datetime.datetime.now() - datetime.timedelta(7),
            "catchup": False,
        }

    def prepare_or_run_pipeline(
        self,
        sorted_steps: List[BaseStep],
        pipeline: "BasePipeline",
        pb2_pipeline: Pb2Pipeline,
        stack: "Stack",
        runtime_configuration: "RuntimeConfiguration",
    ) -> Any:
        """Creates an Airflow DAG as the intermediate representation for the pipeline.

        This DAG will be loaded by airflow in the target environment
        and used for orchestration of the pipeline.

        How it works:
        -------------
        A new airflow_dag is instantiated with the pipeline name and among
        other things the run schedule.

        For each step of the pipeline a callable is created. This callable
        uses the run_step() method to execute the step. The parameters of
        this callable are pre-filled and an airflow step_operator is created
        within the dag. The dependencies to upstream steps are then
        configured.

        Finally, the dag is fully complete and can be returned.

        Args:
            sorted_steps: List of steps in the pipeline.
            pipeline: The pipeline to be executed.
            pb2_pipeline: The pipeline as a protobuf message.
            stack: The stack on which the pipeline will be deployed.
            runtime_configuration: The runtime configuration.

        Returns:
            The Airflow DAG.
        """
        import airflow
        from airflow.operators import python as airflow_python

        # Instantiate and configure airflow Dag with name and schedule
        airflow_dag = airflow.DAG(
            dag_id=pipeline.name,
            is_paused_upon_creation=False,
            **self._translate_schedule(runtime_configuration.schedule),
        )

        def _step_callable(step_instance: "BaseStep", **kwargs: Any) -> None:
            """Execute a single step inside the orchestrated environment.

            Only `step_instance` is used here. The callable is invoked after
            the loop below has finished, so reading the loop variable would
            always see the last step of the pipeline.
            """
            if self.requires_resources_in_orchestration_environment(
                step_instance
            ):
                logger.warning(
                    "Specifying step resources is not yet supported for "
                    "the Airflow orchestrator, ignoring resource "
                    "configuration for step %s.",
                    step_instance.name,
                )
            # Extract run name for the kwargs that will be passed to the
            # callable
            run_name = kwargs["ti"].get_dagrun().run_id
            self.run_step(
                step=step_instance,
                run_name=run_name,
                pb2_pipeline=pb2_pipeline,
            )

        # Dictionary mapping step names to airflow_operators. This will be
        # needed to configure airflow operator dependencies
        step_name_to_airflow_operator = {}

        for step in sorted_steps:
            # Create airflow python operator that contains the step callable
            # with the step pre-bound as keyword argument
            airflow_operator = airflow_python.PythonOperator(
                dag=airflow_dag,
                task_id=step.name,
                provide_context=True,
                python_callable=functools.partial(
                    _step_callable, step_instance=step
                ),
            )

            # Configure the current airflow operator to run after all upstream
            # operators finished executing
            step_name_to_airflow_operator[step.name] = airflow_operator
            upstream_step_names = self.get_upstream_step_names(
                step=step, pb2_pipeline=pb2_pipeline
            )
            for upstream_step_name in upstream_step_names:
                airflow_operator.set_upstream(
                    step_name_to_airflow_operator[upstream_step_name]
                )

        # Return the finished airflow dag
        return airflow_dag

    @root_validator(skip_on_failure=True)
    def set_airflow_home(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Sets Airflow home according to orchestrator UUID.

        Args:
            values: Dictionary containing all orchestrator attributes values.

        Returns:
            Dictionary containing all orchestrator attributes values and the airflow home.

        Raises:
            ValueError: If the orchestrator UUID is not set.
        """
        if "uuid" not in values:
            raise ValueError("`uuid` needs to exist for AirflowOrchestrator.")
        values["airflow_home"] = os.path.join(
            io_utils.get_global_config_directory(),
            AIRFLOW_ROOT_DIR,
            str(values["uuid"]),
        )
        return values

    @property
    def dags_directory(self) -> str:
        """Returns path to the airflow dags directory.

        Returns:
            Path to the airflow dags directory.
        """
        return os.path.join(self.airflow_home, "dags")

    @property
    def pid_file(self) -> str:
        """Returns path to the daemon PID file.

        Returns:
            Path to the daemon PID file.
        """
        return os.path.join(self.airflow_home, "airflow_daemon.pid")

    @property
    def log_file(self) -> str:
        """Returns path to the airflow log file.

        Returns:
            Path to the airflow log file.
        """
        return os.path.join(self.airflow_home, "airflow_orchestrator.log")

    @property
    def password_file(self) -> str:
        """Returns path to the webserver password file.

        Returns:
            Path to the webserver password file.
        """
        return os.path.join(self.airflow_home, "standalone_admin_password.txt")

    def _set_env(self) -> None:
        """Sets environment variables to configure airflow."""
        os.environ["AIRFLOW_HOME"] = self.airflow_home
        os.environ["AIRFLOW__CORE__DAGS_FOLDER"] = self.dags_directory
        os.environ["AIRFLOW__CORE__DAG_DISCOVERY_SAFE_MODE"] = "false"
        os.environ["AIRFLOW__CORE__LOAD_EXAMPLES"] = "false"
        # check the DAG folder every 10 seconds for new files
        os.environ["AIRFLOW__SCHEDULER__DAG_DIR_LIST_INTERVAL"] = "10"

    def _copy_to_dag_directory_if_necessary(self, dag_filepath: str) -> None:
        """Copies DAG module to the Airflow DAGs directory if not already present.

        Args:
            dag_filepath: Path to the file in which the DAG is defined.
        """
        dags_directory = io_utils.resolve_relative_path(self.dags_directory)

        if dags_directory == os.path.dirname(dag_filepath):
            logger.debug("File is already in airflow DAGs directory.")
        else:
            logger.debug(
                "Copying dag file '%s' to DAGs directory.", dag_filepath
            )
            destination_path = os.path.join(
                dags_directory, os.path.basename(dag_filepath)
            )
            if fileio.exists(destination_path):
                logger.info(
                    "File '%s' already exists, overwriting with new DAG file",
                    destination_path,
                )
            fileio.copy(dag_filepath, destination_path, overwrite=True)

    def _log_webserver_credentials(self) -> None:
        """Logs URL and credentials to log in to the airflow webserver.

        Raises:
            FileNotFoundError: If the password file does not exist.
        """
        if fileio.exists(self.password_file):
            with open(self.password_file) as file:
                password = file.read().strip()
        else:
            raise FileNotFoundError(
                f"Can't find password file '{self.password_file}'"
            )
        logger.info(
            "To inspect your DAGs, login to http://0.0.0.0:8080 "
            "with username: admin password: %s",
            password,
        )

    def runtime_options(self) -> Dict[str, Any]:
        """Runtime options for the airflow orchestrator.

        Returns:
            Runtime options dictionary.
        """
        return {DAG_FILEPATH_OPTION_KEY: None}

    def prepare_pipeline_deployment(
        self,
        pipeline: "BasePipeline",
        stack: "Stack",
        runtime_configuration: "RuntimeConfiguration",
    ) -> None:
        """Checks Airflow is running and copies DAG file to the Airflow DAGs directory.

        Args:
            pipeline: Pipeline to be deployed.
            stack: Stack to be deployed.
            runtime_configuration: Runtime configuration for the pipeline.

        Raises:
            RuntimeError: If Airflow is not running, if called from within a
                notebook, or if no DAG filepath runtime option is provided.
        """
        if not self.is_running:
            raise RuntimeError(
                "Airflow orchestrator is currently not running. Run `zenml "
                "stack up` to provision resources for the active stack."
            )

        if Environment.in_notebook():
            raise RuntimeError(
                "Unable to run the Airflow orchestrator from within a "
                "notebook. Airflow requires a python file which contains a "
                "global Airflow DAG object and therefore does not work with "
                "notebooks. Please copy your ZenML pipeline code in a python "
                "file and try again."
            )

        missing_filepath_message = (
            f"No DAG filepath found in runtime configuration. Make sure "
            f"to add the filepath to your airflow DAG file as a runtime "
            f"option (key: '{DAG_FILEPATH_OPTION_KEY}')."
        )
        try:
            dag_filepath = runtime_configuration[DAG_FILEPATH_OPTION_KEY]
        except KeyError as e:
            # Keep the original lookup failure attached to the raised error.
            raise RuntimeError(missing_filepath_message) from e

        # `runtime_options` declares `None` as the default for this key; a
        # missing value must not reach the path handling below.
        if not dag_filepath:
            raise RuntimeError(missing_filepath_message)

        self._copy_to_dag_directory_if_necessary(dag_filepath=dag_filepath)

    @property
    def is_running(self) -> bool:
        """Returns whether the airflow daemon is currently running.

        Returns:
            True if the daemon is running, False otherwise.

        Raises:
            RuntimeError: If port 8080 is occupied.
        """
        from airflow.cli.commands.standalone_command import StandaloneCommand
        from airflow.jobs.triggerer_job import TriggererJob

        daemon_running = daemon.check_if_daemon_is_running(self.pid_file)

        command = StandaloneCommand()
        webserver_port_open = command.port_open(8080)

        if not daemon_running:
            if webserver_port_open:
                raise RuntimeError(
                    "The airflow daemon does not seem to be running but "
                    "local port 8080 is occupied. Make sure the port is "
                    "available and try again."
                )

            # exit early so we don't check non-existing airflow databases
            return False

        # we can't use StandaloneCommand().is_ready() here as the
        # Airflow SequentialExecutor apparently does not send a heartbeat
        # while running a task which would result in this returning `False`
        # even if Airflow is running.
        airflow_running = webserver_port_open and command.job_running(
            TriggererJob
        )
        return airflow_running

    @property
    def is_provisioned(self) -> bool:
        """Returns whether the airflow daemon is currently running.

        Returns:
            True if the airflow daemon is running, False otherwise.
        """
        return self.is_running

    def provision(self) -> None:
        """Ensures that Airflow is running."""
        if self.is_running:
            logger.info("Airflow is already running.")
            self._log_webserver_credentials()
            return

        if not fileio.exists(self.dags_directory):
            io_utils.create_dir_recursive_if_not_exists(self.dags_directory)

        from airflow.cli.commands.standalone_command import StandaloneCommand

        try:
            command = StandaloneCommand()
            # Run the daemon with a working directory inside the current
            # zenml repo so the same repo will be used to run the DAGs
            daemon.run_as_daemon(
                command.run,
                pid_file=self.pid_file,
                log_file=self.log_file,
                working_directory=get_source_root_path(),
            )
            # NOTE(review): this wait has no timeout; if the daemon never
            # becomes ready the loop does not terminate.
            while not self.is_running:
                # Wait until the daemon started all the relevant airflow
                # processes
                time.sleep(0.1)
            self._log_webserver_credentials()
        except Exception as e:
            logger.error(e)
            logger.error(
                "An error occurred while starting the Airflow daemon. If you "
                "want to start it manually, use the commands described in the "
                "official Airflow quickstart guide for running Airflow locally."
            )
            self.deprovision()

    def deprovision(self) -> None:
        """Stops the airflow daemon if necessary and tears down resources."""
        if self.is_running:
            daemon.stop_daemon(self.pid_file)

        fileio.rmtree(self.airflow_home)
        logger.info("Airflow spun down.")
dags_directory: str
property
readonly
Returns path to the airflow dags directory.
Returns:
Type | Description |
---|---|
str |
Path to the airflow dags directory. |
is_provisioned: bool
property
readonly
Returns whether the airflow daemon is currently running.
Returns:
Type | Description |
---|---|
bool |
True if the airflow daemon is running, False otherwise. |
is_running: bool
property
readonly
Returns whether the airflow daemon is currently running.
Returns:
Type | Description |
---|---|
bool |
True if the daemon is running, False otherwise. |
Exceptions:
Type | Description |
---|---|
RuntimeError |
If port 8080 is occupied. |
log_file: str
property
readonly
Returns path to the airflow log file.
Returns:
Type | Description |
---|---|
str |
Path to the airflow log file. |
password_file: str
property
readonly
Returns path to the webserver password file.
Returns:
Type | Description |
---|---|
str |
Path to the webserver password file. |
pid_file: str
property
readonly
Returns path to the daemon PID file.
Returns:
Type | Description |
---|---|
str |
Path to the daemon PID file. |
__init__(self, **values)
special
Sets environment variables to configure airflow.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
**values |
Any |
Values to set in the orchestrator. |
{} |
Source code in zenml/integrations/airflow/orchestrators/airflow_orchestrator.py
def __init__(self, **values: Any):
    """Initialize the orchestrator and export the Airflow environment.

    Args:
        **values: Values to set in the orchestrator.
    """
    super().__init__(**values)
    # The environment is configured only after the attributes have been set
    # by the parent initializer.
    self._set_env()
deprovision(self)
Stops the airflow daemon if necessary and tears down resources.
Source code in zenml/integrations/airflow/orchestrators/airflow_orchestrator.py
def deprovision(self) -> None:
    """Stop the Airflow daemon when it is up, then delete the Airflow home."""
    daemon_is_up = self.is_running
    if daemon_is_up:
        daemon.stop_daemon(self.pid_file)

    # The home directory is removed whether or not the daemon was running.
    fileio.rmtree(self.airflow_home)
    logger.info("Airflow spun down.")
prepare_or_run_pipeline(self, sorted_steps, pipeline, pb2_pipeline, stack, runtime_configuration)
Creates an Airflow DAG as the intermediate representation for the pipeline.
This DAG will be loaded by airflow in the target environment and used for orchestration of the pipeline.
How it works:
A new airflow_dag is instantiated with the pipeline name and, among other things, the run schedule.
For each step of the pipeline a callable is created. This callable uses the run_step() method to execute the step. The parameters of this callable are pre-filled and an airflow step_operator is created within the dag. The dependencies to upstream steps are then configured.
Finally, the dag is fully complete and can be returned.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
sorted_steps |
List[zenml.steps.base_step.BaseStep] |
List of steps in the pipeline. |
required |
pipeline |
BasePipeline |
The pipeline to be executed. |
required |
pb2_pipeline |
Pipeline |
The pipeline as a protobuf message. |
required |
stack |
Stack |
The stack on which the pipeline will be deployed. |
required |
runtime_configuration |
RuntimeConfiguration |
The runtime configuration. |
required |
Returns:
Type | Description |
---|---|
Any |
The Airflow DAG. |
Source code in zenml/integrations/airflow/orchestrators/airflow_orchestrator.py
def prepare_or_run_pipeline(
    self,
    sorted_steps: List[BaseStep],
    pipeline: "BasePipeline",
    pb2_pipeline: Pb2Pipeline,
    stack: "Stack",
    runtime_configuration: "RuntimeConfiguration",
) -> Any:
    """Creates an Airflow DAG as the intermediate representation for the pipeline.

    This DAG will be loaded by airflow in the target environment
    and used for orchestration of the pipeline.

    How it works:
    -------------
    A new airflow_dag is instantiated with the pipeline name and among
    other things the run schedule.

    For each step of the pipeline a callable is created. This callable
    uses the run_step() method to execute the step. The parameters of
    this callable are pre-filled and an airflow step_operator is created
    within the dag. The dependencies to upstream steps are then
    configured.

    Finally, the dag is fully complete and can be returned.

    Args:
        sorted_steps: List of steps in the pipeline.
        pipeline: The pipeline to be executed.
        pb2_pipeline: The pipeline as a protobuf message.
        stack: The stack on which the pipeline will be deployed.
        runtime_configuration: The runtime configuration.

    Returns:
        The Airflow DAG.
    """
    import airflow
    from airflow.operators import python as airflow_python

    # Instantiate and configure airflow Dag with name and schedule
    airflow_dag = airflow.DAG(
        dag_id=pipeline.name,
        is_paused_upon_creation=False,
        **self._translate_schedule(runtime_configuration.schedule),
    )

    def _step_callable(step_instance: "BaseStep", **kwargs: Any) -> None:
        """Execute a single step inside the orchestrated environment.

        Only `step_instance` is used here. The callable is invoked after
        the loop below has finished, so reading the loop variable would
        always see the last step of the pipeline.
        """
        if self.requires_resources_in_orchestration_environment(
            step_instance
        ):
            logger.warning(
                "Specifying step resources is not yet supported for "
                "the Airflow orchestrator, ignoring resource "
                "configuration for step %s.",
                step_instance.name,
            )
        # Extract run name for the kwargs that will be passed to the
        # callable
        run_name = kwargs["ti"].get_dagrun().run_id
        self.run_step(
            step=step_instance,
            run_name=run_name,
            pb2_pipeline=pb2_pipeline,
        )

    # Dictionary mapping step names to airflow_operators. This will be needed
    # to configure airflow operator dependencies
    step_name_to_airflow_operator = {}

    for step in sorted_steps:
        # Create airflow python operator that contains the step callable
        # with the step pre-bound as keyword argument
        airflow_operator = airflow_python.PythonOperator(
            dag=airflow_dag,
            task_id=step.name,
            provide_context=True,
            python_callable=functools.partial(
                _step_callable, step_instance=step
            ),
        )

        # Configure the current airflow operator to run after all upstream
        # operators finished executing
        step_name_to_airflow_operator[step.name] = airflow_operator
        upstream_step_names = self.get_upstream_step_names(
            step=step, pb2_pipeline=pb2_pipeline
        )
        for upstream_step_name in upstream_step_names:
            airflow_operator.set_upstream(
                step_name_to_airflow_operator[upstream_step_name]
            )

    # Return the finished airflow dag
    return airflow_dag
prepare_pipeline_deployment(self, pipeline, stack, runtime_configuration)
Checks Airflow is running and copies DAG file to the Airflow DAGs directory.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
pipeline |
BasePipeline |
Pipeline to be deployed. |
required |
stack |
Stack |
Stack to be deployed. |
required |
runtime_configuration |
RuntimeConfiguration |
Runtime configuration for the pipeline. |
required |
Exceptions:
Type | Description |
---|---|
RuntimeError |
If Airflow is not running or no DAG filepath runtime option is provided. |
Source code in zenml/integrations/airflow/orchestrators/airflow_orchestrator.py
def prepare_pipeline_deployment(
    self,
    pipeline: "BasePipeline",
    stack: "Stack",
    runtime_configuration: "RuntimeConfiguration",
) -> None:
    """Checks Airflow is running and copies DAG file to the Airflow DAGs directory.

    Args:
        pipeline: Pipeline to be deployed.
        stack: Stack to be deployed.
        runtime_configuration: Runtime configuration for the pipeline.

    Raises:
        RuntimeError: If Airflow is not running, if called from within a
            notebook, or if no DAG filepath runtime option is provided.
    """
    if not self.is_running:
        raise RuntimeError(
            "Airflow orchestrator is currently not running. Run `zenml "
            "stack up` to provision resources for the active stack."
        )

    if Environment.in_notebook():
        raise RuntimeError(
            "Unable to run the Airflow orchestrator from within a "
            "notebook. Airflow requires a python file which contains a "
            "global Airflow DAG object and therefore does not work with "
            "notebooks. Please copy your ZenML pipeline code in a python "
            "file and try again."
        )

    missing_filepath_message = (
        f"No DAG filepath found in runtime configuration. Make sure "
        f"to add the filepath to your airflow DAG file as a runtime "
        f"option (key: '{DAG_FILEPATH_OPTION_KEY}')."
    )
    try:
        dag_filepath = runtime_configuration[DAG_FILEPATH_OPTION_KEY]
    except KeyError as e:
        # Keep the original lookup failure attached to the raised error.
        raise RuntimeError(missing_filepath_message) from e

    # A key that is present but empty/None counts as "not provided" and must
    # not reach the path handling in the copy helper.
    if not dag_filepath:
        raise RuntimeError(missing_filepath_message)

    self._copy_to_dag_directory_if_necessary(dag_filepath=dag_filepath)
provision(self)
Ensures that Airflow is running.
Source code in zenml/integrations/airflow/orchestrators/airflow_orchestrator.py
def provision(self) -> None:
    """Start the Airflow daemon unless it is already running.

    On success the webserver credentials are logged. Any exception raised
    while starting the daemon is logged and followed by a call to
    `deprovision()`.
    """
    if self.is_running:
        logger.info("Airflow is already running.")
        self._log_webserver_credentials()
        return

    if not fileio.exists(self.dags_directory):
        io_utils.create_dir_recursive_if_not_exists(self.dags_directory)

    from airflow.cli.commands.standalone_command import StandaloneCommand

    try:
        standalone = StandaloneCommand()
        # The daemon's working directory is the source root so that the
        # same repo is used to run the DAGs.
        daemon.run_as_daemon(
            standalone.run,
            pid_file=self.pid_file,
            log_file=self.log_file,
            working_directory=get_source_root_path(),
        )
        # Poll until the daemon has started all relevant airflow processes.
        while True:
            if self.is_running:
                break
            time.sleep(0.1)
        self._log_webserver_credentials()
    except Exception as error:
        logger.error(error)
        logger.error(
            "An error occurred while starting the Airflow daemon. If you "
            "want to start it manually, use the commands described in the "
            "official Airflow quickstart guide for running Airflow locally."
        )
        self.deprovision()
runtime_options(self)
Runtime options for the airflow orchestrator.
Returns:
Type | Description |
---|---|
Dict[str, Any] |
Runtime options dictionary. |
Source code in zenml/integrations/airflow/orchestrators/airflow_orchestrator.py
def runtime_options(self) -> Dict[str, Any]:
    """Declare the runtime options understood by the airflow orchestrator.

    Returns:
        Dictionary mapping the DAG filepath option key to `None`.
    """
    options: Dict[str, Any] = {}
    options[DAG_FILEPATH_OPTION_KEY] = None
    return options
set_airflow_home(values)
classmethod
Sets Airflow home according to orchestrator UUID.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
values |
Dict[str, Any] |
Dictionary containing all orchestrator attributes values. |
required |
Returns:
Type | Description |
---|---|
Dict[str, Any] |
Dictionary containing all orchestrator attributes values and the airflow home. |
Exceptions:
Type | Description |
---|---|
ValueError |
If the orchestrator UUID is not set. |
Source code in zenml/integrations/airflow/orchestrators/airflow_orchestrator.py
@root_validator(skip_on_failure=True)
def set_airflow_home(cls, values: Dict[str, Any]) -> Dict[str, Any]:
    """Derive the Airflow home directory from the orchestrator UUID.

    Args:
        values: Dictionary containing all orchestrator attributes values.

    Returns:
        The same dictionary with `airflow_home` set to a UUID-specific
        directory below the global config directory.

    Raises:
        ValueError: If the orchestrator UUID is not set.
    """
    uuid_missing = "uuid" not in values
    if uuid_missing:
        raise ValueError("`uuid` needs to exist for AirflowOrchestrator.")

    config_root = io_utils.get_global_config_directory()
    home_directory = os.path.join(
        config_root, AIRFLOW_ROOT_DIR, str(values["uuid"])
    )
    values["airflow_home"] = home_directory
    return values
aws
special
Integrates multiple AWS Tools as Stack Components.
The AWS integration provides a way for our users to manage their secrets through AWS and to use the AWS container registry. Additionally, the Sagemaker integration submodule provides a way to run ZenML steps in Sagemaker.
AWSIntegration (Integration)
Definition of AWS integration for ZenML.
Source code in zenml/integrations/aws/__init__.py
class AWSIntegration(Integration):
    """AWS integration definition for ZenML.

    Declares the integration name, its pinned pip requirements and the stack
    component flavors it contributes.
    """

    NAME = AWS
    REQUIREMENTS = ["boto3==1.21.0", "sagemaker==2.82.2"]

    @classmethod
    def flavors(cls) -> List[FlavorWrapper]:
        """List the stack component flavors provided by the AWS integration.

        Returns:
            The secrets manager, container registry and Sagemaker step
            operator flavors, in that order.
        """
        # (flavor name, import path of the implementation, component type)
        flavor_specs = (
            (
                AWS_SECRET_MANAGER_FLAVOR,
                "zenml.integrations.aws.secrets_managers.AWSSecretsManager",
                StackComponentType.SECRETS_MANAGER,
            ),
            (
                AWS_CONTAINER_REGISTRY_FLAVOR,
                "zenml.integrations.aws.container_registries"
                ".AWSContainerRegistry",
                StackComponentType.CONTAINER_REGISTRY,
            ),
            (
                AWS_SAGEMAKER_STEP_OPERATOR_FLAVOR,
                "zenml.integrations.aws.step_operators.SagemakerStepOperator",
                StackComponentType.STEP_OPERATOR,
            ),
        )
        return [
            FlavorWrapper(
                name=flavor_name,
                source=source_path,
                type=component_type,
                integration=cls.NAME,
            )
            for flavor_name, source_path, component_type in flavor_specs
        ]
flavors()
classmethod
Declare the stack component flavors for the AWS integration.
Returns:
Type | Description |
---|---|
List[zenml.zen_stores.models.flavor_wrapper.FlavorWrapper] |
List of stack component flavors for this integration. |
Source code in zenml/integrations/aws/__init__.py
@classmethod
def flavors(cls) -> List[FlavorWrapper]:
    """List the stack component flavors provided by the AWS integration.

    Returns:
        The secrets manager, container registry and Sagemaker step operator
        flavors, in that order.
    """
    # (flavor name, import path of the implementation, component type)
    flavor_specs = (
        (
            AWS_SECRET_MANAGER_FLAVOR,
            "zenml.integrations.aws.secrets_managers.AWSSecretsManager",
            StackComponentType.SECRETS_MANAGER,
        ),
        (
            AWS_CONTAINER_REGISTRY_FLAVOR,
            "zenml.integrations.aws.container_registries"
            ".AWSContainerRegistry",
            StackComponentType.CONTAINER_REGISTRY,
        ),
        (
            AWS_SAGEMAKER_STEP_OPERATOR_FLAVOR,
            "zenml.integrations.aws.step_operators.SagemakerStepOperator",
            StackComponentType.STEP_OPERATOR,
        ),
    )
    return [
        FlavorWrapper(
            name=flavor_name,
            source=source_path,
            type=component_type,
            integration=cls.NAME,
        )
        for flavor_name, source_path, component_type in flavor_specs
    ]
container_registries
special
Initialization of AWS Container Registry integration.
aws_container_registry
Implementation of the AWS container registry integration.
AWSContainerRegistry (BaseContainerRegistry)
pydantic-model
Class for AWS Container Registry.
Source code in zenml/integrations/aws/container_registries/aws_container_registry.py
class AWSContainerRegistry(BaseContainerRegistry):
    """Class for AWS Container Registry."""

    # Class Configuration
    FLAVOR: ClassVar[str] = AWS_CONTAINER_REGISTRY_FLAVOR

    @validator("uri")
    def validate_aws_uri(cls, uri: str) -> str:
        """Validates that the URI is in the correct format.

        Args:
            uri: URI to validate.

        Returns:
            URI in the correct format.

        Raises:
            ValueError: If the URI contains a slash character.
        """
        if "/" in uri:
            raise ValueError(
                "Property `uri` can not contain a `/`. An example of a valid "
                "URI is: `715803424592.dkr.ecr.us-east-1.amazonaws.com`"
            )

        return uri

    def _get_region(self) -> str:
        """Parses the AWS region from the registry URI.

        Raises:
            RuntimeError: If the region parsing fails due to an invalid URI.

        Returns:
            The region string.
        """
        match = re.fullmatch(r".*\.dkr\.ecr\.(.*)\.amazonaws\.com", self.uri)
        if not match:
            raise RuntimeError(
                f"Unable to parse region from ECR URI {self.uri}."
            )

        return match.group(1)

    def prepare_image_push(self, image_name: str) -> None:
        """Logs warning message if trying to push an image for which no repository exists.

        Fetching the repository list is best effort: if the ECR request
        fails or its response is malformed, the check is skipped and the
        push is attempted anyway.

        Args:
            image_name: Name of the docker image that will be pushed.

        Raises:
            ValueError: If the docker image name is invalid.
        """
        try:
            # The API call lives inside the `try` so that a `ClientError`
            # raised by the request itself is handled as well.
            response = boto3.client(
                "ecr", region_name=self._get_region()
            ).describe_repositories()
            repo_uris: List[str] = [
                repository["repositoryUri"]
                for repository in response["repositories"]
            ]
        except (KeyError, ClientError) as e:
            # invalid boto response, let's hope for the best and just push
            logger.debug("Error while trying to fetch ECR repositories: %s", e)
            return

        repo_exists = any(image_name.startswith(f"{uri}:") for uri in repo_uris)
        if not repo_exists:
            # Escape the URI: its dots must match literally, not any character.
            match = re.search(
                f"{re.escape(self.uri)}/(.*):.*", image_name
            )
            if not match:
                raise ValueError(f"Invalid docker image name '{image_name}'.")

            repo_name = match.group(1)
            logger.warning(
                "Amazon ECR requires you to create a repository before you can "
                f"push an image to it. ZenML is trying to push the image "
                f"{image_name} but could only detect the following "
                f"repositories: {repo_uris}. We will try to push anyway, but "
                f"in case it fails you need to create a repository named "
                f"`{repo_name}`."
            )

    @property
    def post_registration_message(self) -> Optional[str]:
        """Optional message printed after the stack component is registered.

        Returns:
            Info message regarding docker repositories in AWS.
        """
        return (
            "Amazon ECR requires you to create a repository before you can "
            "push an image to it. If you want to for example run a pipeline "
            "using our Kubeflow orchestrator, ZenML will automatically build a "
            f"docker image called `{self.uri}/zenml-kubeflow:<PIPELINE_NAME>` "
            f"and try to push it. This will fail unless you create the "
            f"repository `zenml-kubeflow` inside your amazon registry."
        )
post_registration_message: Optional[str]
property
readonly
Optional message printed after the stack component is registered.
Returns:
Type | Description |
---|---|
Optional[str] |
Info message regarding docker repositories in AWS. |
prepare_image_push(self, image_name)
Logs warning message if trying to push an image for which no repository exists.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
image_name |
str |
Name of the docker image that will be pushed. |
required |
Exceptions:
Type | Description |
---|---|
ValueError |
If the docker image name is invalid. |
Source code in zenml/integrations/aws/container_registries/aws_container_registry.py
def prepare_image_push(self, image_name: str) -> None:
    """Logs warning message if trying to push an image for which no repository exists.

    Fetching the repository list is best effort: if the ECR request fails
    or its response is malformed, the check is skipped and the push is
    attempted anyway.

    Args:
        image_name: Name of the docker image that will be pushed.

    Raises:
        ValueError: If the docker image name is invalid.
    """
    try:
        # The API call lives inside the `try` so that a `ClientError` raised
        # by the request itself is handled as well.
        response = boto3.client(
            "ecr", region_name=self._get_region()
        ).describe_repositories()
        repo_uris: List[str] = [
            repository["repositoryUri"]
            for repository in response["repositories"]
        ]
    except (KeyError, ClientError) as e:
        # invalid boto response, let's hope for the best and just push
        logger.debug("Error while trying to fetch ECR repositories: %s", e)
        return

    repo_exists = any(image_name.startswith(f"{uri}:") for uri in repo_uris)
    if not repo_exists:
        # Escape the URI: its dots must match literally, not any character.
        match = re.search(f"{re.escape(self.uri)}/(.*):.*", image_name)
        if not match:
            raise ValueError(f"Invalid docker image name '{image_name}'.")

        repo_name = match.group(1)
        logger.warning(
            "Amazon ECR requires you to create a repository before you can "
            f"push an image to it. ZenML is trying to push the image "
            f"{image_name} but could only detect the following "
            f"repositories: {repo_uris}. We will try to push anyway, but "
            f"in case it fails you need to create a repository named "
            f"`{repo_name}`."
        )
validate_aws_uri(uri)
classmethod
Validates that the URI is in the correct format.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
uri |
str |
URI to validate. |
required |
Returns:
Type | Description |
---|---|
str |
URI in the correct format. |
Exceptions:
Type | Description |
---|---|
ValueError |
If the URI contains a slash character. |
Source code in zenml/integrations/aws/container_registries/aws_container_registry.py
@validator("uri")
def validate_aws_uri(cls, uri: str) -> str:
    """Checks that the registry URI has the expected format.

    Args:
        uri: URI to validate.

    Returns:
        The unchanged URI if it is valid.

    Raises:
        ValueError: If the URI contains a slash character.
    """
    # Guard clause: a URI without any slash is accepted as-is.
    if "/" not in uri:
        return uri

    raise ValueError(
        "Property `uri` can not contain a `/`. An example of a valid "
        "URI is: `715803424592.dkr.ecr.us-east-1.amazonaws.com`"
    )
secrets_managers
special
AWS Secrets Manager.
aws_secrets_manager
Implementation of the AWS Secrets Manager integration.
AWSSecretsManager (BaseSecretsManager)
pydantic-model
Class to interact with the AWS secrets manager.
Source code in zenml/integrations/aws/secrets_managers/aws_secrets_manager.py
class AWSSecretsManager(BaseSecretsManager):
    """Class to interact with the AWS secrets manager.

    Attributes:
        region_name: Name of the AWS region passed to the boto3 client.
    """

    region_name: str

    # Class configuration
    FLAVOR: ClassVar[str] = AWS_SECRET_MANAGER_FLAVOR
    SUPPORTS_SCOPING: ClassVar[bool] = True
    # boto3 client, created on first use and stored on the class
    CLIENT: ClassVar[Any] = None

    @classmethod
    def _validate_scope(
        cls,
        scope: SecretsManagerScope,
        namespace: Optional[str],
    ) -> None:
        """Validate the scope and namespace value.

        Only the namespace is checked here; the scope value itself is not
        inspected.

        Args:
            scope: Scope value.
            namespace: Optional namespace value.
        """
        if namespace:
            cls.validate_secret_name_or_namespace(namespace)

    @classmethod
    def _ensure_client_connected(cls, region_name: str) -> None:
        """Ensure that the client is connected to the AWS secrets manager.

        The client is created once and then reused; later calls do not
        re-create it, even if a different region name is passed.

        Args:
            region_name: the AWS region name
        """
        if cls.CLIENT is None:
            # Create a Secrets Manager client
            session = boto3.session.Session()
            cls.CLIENT = session.client(
                service_name="secretsmanager", region_name=region_name
            )

    @classmethod
    def validate_secret_name_or_namespace(cls, name: str) -> None:
        """Validate a secret name or namespace.

        AWS secret names must contain only alphanumeric characters and the
        characters /_+=.@-. The `/` character is only used internally to delimit
        scopes.

        Args:
            name: the secret name or namespace

        Raises:
            ValueError: if the secret name or namespace is invalid
        """
        if not re.fullmatch(r"[a-zA-Z0-9_+=\.@\-]*", name):
            raise ValueError(
                f"Invalid secret name or namespace '{name}'. Must contain "
                f"only alphanumeric characters and the characters _+=.@-."
            )

    def _get_secret_tags(
        self, secret: BaseSecretSchema
    ) -> List[Dict[str, str]]:
        """Return a list of AWS secret tag values for a given secret.

        Args:
            secret: the secret object

        Returns:
            A list of AWS secret tag values
        """
        metadata = self._get_secret_metadata(secret)
        return [{"Key": k, "Value": v} for k, v in metadata.items()]

    def _get_secret_scope_filters(
        self,
        secret_name: Optional[str] = None,
    ) -> List[Dict[str, Any]]:
        """Return a list of AWS filters for the entire scope or just a scoped secret.

        These filters can be used when querying the AWS Secrets Manager
        for all secrets or for a single secret available in the configured
        scope. For more information see: https://docs.aws.amazon.com/secretsmanager/latest/userguide/manage_search-secret.html

        Example AWS filters for all secrets in the current (namespace) scope:

        ```python
        [
            {
                "Key": "tag-key",
                "Values": ["zenml_scope"],
            },
            {
                "Key": "tag-value",
                "Values": ["namespace"],
            },
            {
                "Key": "tag-key",
                "Values": ["zenml_namespace"],
            },
            {
                "Key": "tag-value",
                "Values": ["my_namespace"],
            },
        ]
        ```

        Example AWS filters for a particular secret in the current (namespace)
        scope:

        ```python
        [
            {
                "Key": "tag-key",
                "Values": ["zenml_secret_name"],
            },
            {
                "Key": "tag-value",
                "Values": ["my_secret"],
            },
            {
                "Key": "tag-key",
                "Values": ["zenml_scope"],
            },
            {
                "Key": "tag-value",
                "Values": ["namespace"],
            },
            {
                "Key": "tag-key",
                "Values": ["zenml_namespace"],
            },
            {
                "Key": "tag-value",
                "Values": ["my_namespace"],
            },
        ]
        ```

        Args:
            secret_name: Optional secret name to filter for.

        Returns:
            A list of AWS filters uniquely identifying all secrets
            or a named secret within the configured scope.
        """
        metadata = self._get_secret_scope_metadata(secret_name)
        filters: List[Dict[str, Any]] = []
        # Every metadata entry becomes a pair of filters: one on the tag key
        # and one on the (stringified) tag value.
        for k, v in metadata.items():
            filters.append({"Key": "tag-key", "Values": [k]})
            filters.append({"Key": "tag-value", "Values": [str(v)]})

        return filters

    def _list_secrets(self, secret_name: Optional[str] = None) -> List[str]:
        """List all secrets matching a name.

        This method lists all the secrets in the current scope without loading
        their contents. An optional secret name can be supplied to filter out
        all but a single secret identified by name. All result pages returned
        by AWS are consumed, not just the first one.

        Args:
            secret_name: Optional secret name to filter for.

        Returns:
            A list of secret names in the current scope and the optional
            secret name.
        """
        self._ensure_client_connected(self.region_name)

        filters: List[Dict[str, Any]] = []
        prefix: Optional[str] = None
        if self.scope == SecretsManagerScope.NONE:
            # unscoped (legacy) secrets don't have tags. We want to filter out
            # non-legacy secrets
            filters = [
                {
                    "Key": "tag-key",
                    "Values": [
                        "!zenml_scope",
                    ],
                },
            ]
            if secret_name:
                prefix = secret_name
        else:
            filters = self._get_secret_scope_filters()
            if secret_name:
                prefix = self._get_scoped_secret_name(secret_name)
            else:
                # add the name prefix to the filters to account for the fact
                # that AWS does not do exact matching but prefix-matching on the
                # filters
                prefix = self._get_scoped_secret_name_prefix()

        if prefix:
            filters.append(
                {
                    "Key": "name",
                    "Values": [
                        f"{prefix}",
                    ],
                }
            )

        # TODO [ENG-721]: take out this magic maxresults number
        request: Dict[str, Any] = {"MaxResults": 100, "Filters": filters}
        results: List[str] = []
        # `MaxResults` only limits the size of one page. Keep requesting
        # pages for as long as AWS hands back a `NextToken`.
        while True:
            response = self.CLIENT.list_secrets(**request)
            for secret in response["SecretList"]:
                name = self._get_unscoped_secret_name(secret["Name"])
                # keep only the names that are in scope and filter by secret
                # name, if one was given
                if name and (not secret_name or secret_name == name):
                    results.append(name)

            next_token = response.get("NextToken")
            if not next_token:
                return results
            request["NextToken"] = next_token

    def register_secret(self, secret: BaseSecretSchema) -> None:
        """Registers a new secret.

        Args:
            secret: the secret to register

        Raises:
            SecretExistsError: if the secret already exists
        """
        self.validate_secret_name_or_namespace(secret.name)
        self._ensure_client_connected(self.region_name)

        if self._list_secrets(secret.name):
            raise SecretExistsError(
                f"A Secret with the name {secret.name} already exists"
            )

        secret_value = json.dumps(secret_to_dict(secret, encode=False))
        kwargs: Dict[str, Any] = {
            "Name": self._get_scoped_secret_name(secret.name),
            "SecretString": secret_value,
            "Tags": self._get_secret_tags(secret),
        }

        self.CLIENT.create_secret(**kwargs)

        logger.debug("Created AWS secret: %s", kwargs["Name"])

    def get_secret(self, secret_name: str) -> BaseSecretSchema:
        """Gets a secret.

        Args:
            secret_name: the name of the secret to get

        Returns:
            The secret.

        Raises:
            KeyError: if the secret does not exist or if the stored secret
                has no string value.
        """
        self.validate_secret_name_or_namespace(secret_name)
        self._ensure_client_connected(self.region_name)

        if not self._list_secrets(secret_name):
            raise KeyError(f"Can't find the specified secret '{secret_name}'")

        get_secret_value_response = self.CLIENT.get_secret_value(
            SecretId=self._get_scoped_secret_name(secret_name)
        )
        # Fail with a clear error instead of subscripting `None` when the
        # response carries no `SecretString` entry.
        if "SecretString" not in get_secret_value_response:
            raise KeyError(
                f"The secret '{secret_name}' does not contain a string "
                f"value that can be loaded."
            )

        return secret_from_dict(
            json.loads(get_secret_value_response["SecretString"]),
            secret_name=secret_name,
            decode=False,
        )

    def get_all_secret_keys(self) -> List[str]:
        """Get all secret keys.

        Returns:
            A list of all secret keys
        """
        return self._list_secrets()

    def update_secret(self, secret: BaseSecretSchema) -> None:
        """Update an existing secret.

        Args:
            secret: the secret to update

        Raises:
            KeyError: if the secret does not exist
        """
        self.validate_secret_name_or_namespace(secret.name)
        self._ensure_client_connected(self.region_name)

        if not self._list_secrets(secret.name):
            raise KeyError(f"Can't find the specified secret '{secret.name}'")

        # Serialize exactly like `register_secret` does: `get_secret` reads
        # the value back with `decode=False`.
        secret_value = json.dumps(secret_to_dict(secret, encode=False))

        kwargs = {
            "SecretId": self._get_scoped_secret_name(secret.name),
            "SecretString": secret_value,
        }

        self.CLIENT.put_secret_value(**kwargs)

    def delete_secret(self, secret_name: str) -> None:
        """Delete an existing secret.

        The secret is force deleted without a recovery window.

        Args:
            secret_name: the name of the secret to delete

        Raises:
            KeyError: if the secret does not exist
        """
        self._ensure_client_connected(self.region_name)

        if not self._list_secrets(secret_name):
            raise KeyError(f"Can't find the specified secret '{secret_name}'")

        self.CLIENT.delete_secret(
            SecretId=self._get_scoped_secret_name(secret_name),
            ForceDeleteWithoutRecovery=True,
        )

    def delete_all_secrets(self) -> None:
        """Delete all existing secrets.

        This method will force delete all your secrets. You will not be able to
        recover them once this method is called.
        """
        self._ensure_client_connected(self.region_name)
        for secret_name in self._list_secrets():
            self.CLIENT.delete_secret(
                SecretId=self._get_scoped_secret_name(secret_name),
                ForceDeleteWithoutRecovery=True,
            )
delete_all_secrets(self)
Delete all existing secrets.
This method will force delete all your secrets. You will not be able to recover them once this method is called.
Source code in zenml/integrations/aws/secrets_managers/aws_secrets_manager.py
def delete_all_secrets(self) -> None:
    """Force delete every secret returned by `_list_secrets()`.

    The deletion is requested with `ForceDeleteWithoutRecovery=True`, so the
    secrets cannot be recovered once this method has been called.
    """
    self._ensure_client_connected(self.region_name)

    secret_names = self._list_secrets()
    for name in secret_names:
        scoped_id = self._get_scoped_secret_name(name)
        self.CLIENT.delete_secret(
            SecretId=scoped_id,
            ForceDeleteWithoutRecovery=True,
        )
delete_secret(self, secret_name)
Delete an existing secret.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
secret_name |
str |
the name of the secret to delete |
required |
Exceptions:
Type | Description |
---|---|
KeyError |
if the secret does not exist |
Source code in zenml/integrations/aws/secrets_managers/aws_secrets_manager.py
def delete_secret(self, secret_name: str) -> None:
    """Force delete a single secret.

    Args:
        secret_name: the name of the secret to delete

    Raises:
        KeyError: if the secret does not exist
    """
    self._ensure_client_connected(self.region_name)

    matching_secrets = self._list_secrets(secret_name)
    if not matching_secrets:
        raise KeyError(f"Can't find the specified secret '{secret_name}'")

    scoped_id = self._get_scoped_secret_name(secret_name)
    # No recovery window: the secret is gone as soon as this call succeeds.
    self.CLIENT.delete_secret(
        SecretId=scoped_id,
        ForceDeleteWithoutRecovery=True,
    )
get_all_secret_keys(self)
Get all secret keys.
Returns:
Type | Description |
---|---|
List[str] |
A list of all secret keys |
Source code in zenml/integrations/aws/secrets_managers/aws_secrets_manager.py
def get_all_secret_keys(self) -> List[str]:
    """Return the names of all secrets found by `_list_secrets()`.

    Returns:
        A list of all secret keys
    """
    secret_keys = self._list_secrets()
    return secret_keys
get_secret(self, secret_name)
Gets a secret.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
secret_name |
str |
the name of the secret to get |
required |
Returns:
Type | Description |
---|---|
BaseSecretSchema |
The secret. |
Exceptions:
Type | Description |
---|---|
KeyError |
if the secret does not exist |
Source code in zenml/integrations/aws/secrets_managers/aws_secrets_manager.py
def get_secret(self, secret_name: str) -> BaseSecretSchema:
    """Gets a secret.

    Args:
        secret_name: the name of the secret to get

    Returns:
        The secret.

    Raises:
        KeyError: if the secret does not exist or if the stored secret has
            no string value.
    """
    self.validate_secret_name_or_namespace(secret_name)
    self._ensure_client_connected(self.region_name)

    if not self._list_secrets(secret_name):
        raise KeyError(f"Can't find the specified secret '{secret_name}'")

    get_secret_value_response = self.CLIENT.get_secret_value(
        SecretId=self._get_scoped_secret_name(secret_name)
    )
    # Fail with a clear error instead of subscripting `None` when the
    # response carries no `SecretString` entry.
    if "SecretString" not in get_secret_value_response:
        raise KeyError(
            f"The secret '{secret_name}' does not contain a string "
            f"value that can be loaded."
        )

    return secret_from_dict(
        json.loads(get_secret_value_response["SecretString"]),
        secret_name=secret_name,
        decode=False,
    )
register_secret(self, secret)
Registers a new secret.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
secret |
BaseSecretSchema |
the secret to register |
required |
Exceptions:
Type | Description |
---|---|
SecretExistsError |
if the secret already exists |
Source code in zenml/integrations/aws/secrets_managers/aws_secrets_manager.py
def register_secret(self, secret: BaseSecretSchema) -> None:
    """Create a new secret in the AWS Secrets Manager.

    Args:
        secret: the secret to register

    Raises:
        SecretExistsError: if the secret already exists
    """
    name = secret.name
    self.validate_secret_name_or_namespace(name)
    self._ensure_client_connected(self.region_name)

    already_registered = self._list_secrets(name)
    if already_registered:
        raise SecretExistsError(
            f"A Secret with the name {secret.name} already exists"
        )

    # The secret contents are stored as a JSON string, without encoding.
    serialized = json.dumps(secret_to_dict(secret, encode=False))
    scoped_name = self._get_scoped_secret_name(name)
    tags = self._get_secret_tags(secret)

    request: Dict[str, Any] = {
        "Name": scoped_name,
        "SecretString": serialized,
        "Tags": tags,
    }
    self.CLIENT.create_secret(**request)

    logger.debug("Created AWS secret: %s", request["Name"])
update_secret(self, secret)
Update an existing secret.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
secret |
BaseSecretSchema |
the secret to update |
required |
Exceptions:
Type | Description |
---|---|
KeyError |
if the secret does not exist |
Source code in zenml/integrations/aws/secrets_managers/aws_secrets_manager.py
def update_secret(self, secret: BaseSecretSchema) -> None:
    """Update an existing secret.

    Args:
        secret: the secret to update

    Raises:
        KeyError: if the secret does not exist
    """
    self.validate_secret_name_or_namespace(secret.name)
    self._ensure_client_connected(self.region_name)

    if not self._list_secrets(secret.name):
        raise KeyError(f"Can't find the specified secret '{secret.name}'")

    # Serialize exactly like `register_secret` does: `get_secret` reads the
    # value back with `decode=False`.
    secret_value = json.dumps(secret_to_dict(secret, encode=False))

    kwargs = {
        "SecretId": self._get_scoped_secret_name(secret.name),
        "SecretString": secret_value,
    }

    self.CLIENT.put_secret_value(**kwargs)
validate_secret_name_or_namespace(name)
classmethod
Validate a secret name or namespace.
AWS secret names must contain only alphanumeric characters and the
characters /_+=.@-. The /
character is only used internally to delimit
scopes.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
name |
str |
the secret name or namespace |
required |
Exceptions:
Type | Description |
---|---|
ValueError |
if the secret name or namespace is invalid |
Source code in zenml/integrations/aws/secrets_managers/aws_secrets_manager.py
@classmethod
def validate_secret_name_or_namespace(cls, name: str) -> None:
    """Check that a secret name or namespace uses only allowed characters.

    AWS secret names must contain only alphanumeric characters and the
    characters /_+=.@-. The `/` character is only used internally to delimit
    scopes, which is why it is rejected here.

    Args:
        name: the secret name or namespace

    Raises:
        ValueError: if the secret name or namespace is invalid
    """
    allowed = re.fullmatch(r"[a-zA-Z0-9_+=\.@\-]*", name)
    if allowed is None:
        raise ValueError(
            f"Invalid secret name or namespace '{name}'. Must contain "
            f"only alphanumeric characters and the characters _+=.@-."
        )
step_operators
special
Initialization of the Sagemaker Step Operator.
sagemaker_step_operator
Implementation of the Sagemaker Step Operator.
SagemakerStepOperator (BaseStepOperator)
pydantic-model
Step operator to run a step on Sagemaker.
This class defines code that builds an image with the ZenML entrypoint to run using Sagemaker's Estimator.
Attributes:
Name | Type | Description |
---|---|---|
role |
str |
The role that has to be assigned to the jobs which are running in Sagemaker. |
instance_type |
str |
The type of the compute instance where jobs will run. |
base_image |
Optional[str] |
The base image to use for building the docker image that will be executed. |
bucket |
Optional[str] |
Name of the S3 bucket to use for storing artifacts from the job run. If not provided, a default bucket will be created based on the following format: "sagemaker-{region}-{aws-account-id}". |
experiment_name |
Optional[str] |
The name for the experiment to which the job will be associated. If not provided, the job runs would be independent. |
Source code in zenml/integrations/aws/step_operators/sagemaker_step_operator.py
class SagemakerStepOperator(BaseStepOperator):
    """Step operator to run a step on Sagemaker.

    A docker image containing the ZenML entrypoint is built and pushed to the
    container registry of the active stack, then executed through a
    Sagemaker Estimator.

    Attributes:
        role: The role that has to be assigned to the jobs which are
            running in Sagemaker.
        instance_type: The type of the compute instance where jobs will run.
        base_image: The base image to use for building the docker
            image that will be executed.
        bucket: Name of the S3 bucket to use for storing artifacts
            from the job run. If not provided, a default bucket will be created
            based on the following format: "sagemaker-{region}-{aws-account-id}".
        experiment_name: The name for the experiment to which the job
            will be associated. If not provided, the job runs would be
            independent.
    """

    role: str
    instance_type: str
    base_image: Optional[str] = None
    bucket: Optional[str] = None
    experiment_name: Optional[str] = None

    # Class Configuration
    FLAVOR: ClassVar[str] = AWS_SAGEMAKER_STEP_OPERATOR_FLAVOR

    @property
    def validator(self) -> Optional[StackValidator]:
        """Validates the stack this step operator is part of.

        Returns:
            A validator that checks that the stack contains a container
            registry and that its orchestrator has the "local" flavor.
        """

        def _orchestrator_is_local(stack: Stack) -> Tuple[bool, str]:
            is_local = stack.orchestrator.FLAVOR == "local"
            return is_local, "Local orchestrator is required"

        return StackValidator(
            required_components={StackComponentType.CONTAINER_REGISTRY},
            custom_validation_function=_orchestrator_is_local,
        )

    def _build_docker_image(
        self,
        pipeline_name: str,
        requirements: List[str],
        entrypoint_command: List[str],
    ) -> str:
        """Builds the step image and pushes it to the container registry.

        Args:
            pipeline_name: Name of the pipeline, used as the image tag.
            requirements: Pip requirements to install in the image.
            entrypoint_command: Command used as the image entrypoint.

        Returns:
            The image digest if one is available, the image name otherwise.

        Raises:
            RuntimeError: If the active stack has no container registry.
        """
        container_registry = Repository().active_stack.container_registry
        if not container_registry:
            raise RuntimeError("Missing container registry")

        registry_uri = container_registry.uri.rstrip("/")
        image_name = f"{registry_uri}/zenml-sagemaker:{pipeline_name}"

        entrypoint = " ".join(entrypoint_command)
        unique_requirements = set(requirements)
        docker_utils.build_docker_image(
            build_context_path=get_source_root_path(),
            image_name=image_name,
            entrypoint=entrypoint,
            requirements=unique_requirements,
            base_image=self.base_image,
        )
        container_registry.push_image(image_name)

        digest = docker_utils.get_image_digest(image_name)
        return digest if digest else image_name

    def launch(
        self,
        pipeline_name: str,
        run_name: str,
        requirements: List[str],
        entrypoint_command: List[str],
        resource_configuration: "ResourceConfiguration",
    ) -> None:
        """Launches a step on Sagemaker.

        Args:
            pipeline_name: Name of the pipeline which the step to be executed
                is part of.
            run_name: Name of the pipeline run which the step to be executed
                is part of.
            entrypoint_command: Command that executes the step.
            requirements: List of pip requirements that must be installed
                inside the step operator environment.
            resource_configuration: The resource configuration for this step.
        """
        image_name = self._build_docker_image(
            pipeline_name=pipeline_name,
            requirements=requirements,
            entrypoint_command=entrypoint_command,
        )

        # Custom resources are only reported, never applied.
        if not resource_configuration.empty:
            logger.warning(
                "Specifying custom step resources is not supported for "
                "the SageMaker step operator. If you want to run this step "
                "operator on specific resources, you can do so by configuring "
                "a different instance type like this: "
                "`zenml step-operator update %s "
                "--instance_type=<INSTANCE_TYPE>`",
                self.name,
            )

        estimator = sagemaker.estimator.Estimator(
            image_name,
            self.role,
            instance_count=1,
            instance_type=self.instance_type,
            sagemaker_session=sagemaker.Session(default_bucket=self.bucket),
        )

        # Sagemaker doesn't allow any underscores in job/experiment/trial names
        sanitized_run_name = run_name.replace("_", "-")

        # The run is only attached to an experiment when one is configured.
        experiment_config = (
            {
                "ExperimentName": self.experiment_name,
                "TrialName": sanitized_run_name,
            }
            if self.experiment_name
            else {}
        )

        estimator.fit(
            wait=True,
            experiment_config=experiment_config,
            job_name=sanitized_run_name,
        )
validator: Optional[zenml.stack.stack_validator.StackValidator]
property
readonly
Validates that the stack contains a container registry.
Returns:
Type | Description |
---|---|
Optional[zenml.stack.stack_validator.StackValidator] |
A validator that checks that the stack contains a container registry. |
launch(self, pipeline_name, run_name, requirements, entrypoint_command, resource_configuration)
Launches a step on Sagemaker.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
pipeline_name |
str |
Name of the pipeline which the step to be executed is part of. |
required |
run_name |
str |
Name of the pipeline run which the step to be executed is part of. |
required |
entrypoint_command |
List[str] |
Command that executes the step. |
required |
requirements |
List[str] |
List of pip requirements that must be installed inside the step operator environment. |
required |
resource_configuration |
ResourceConfiguration |
The resource configuration for this step. |
required |
Source code in zenml/integrations/aws/step_operators/sagemaker_step_operator.py
def launch(
    self,
    pipeline_name: str,
    run_name: str,
    requirements: List[str],
    entrypoint_command: List[str],
    resource_configuration: "ResourceConfiguration",
) -> None:
    """Builds the step image and runs it as a Sagemaker job.

    Args:
        pipeline_name: Name of the pipeline which the step to be executed
            is part of.
        run_name: Name of the pipeline run which the step to be executed
            is part of.
        entrypoint_command: Command that executes the step.
        requirements: List of pip requirements that must be installed
            inside the step operator environment.
        resource_configuration: The resource configuration for this step.
    """
    image_name = self._build_docker_image(
        pipeline_name=pipeline_name,
        requirements=requirements,
        entrypoint_command=entrypoint_command,
    )

    # Custom resources are only reported, never applied.
    if not resource_configuration.empty:
        logger.warning(
            "Specifying custom step resources is not supported for "
            "the SageMaker step operator. If you want to run this step "
            "operator on specific resources, you can do so by configuring "
            "a different instance type like this: "
            "`zenml step-operator update %s "
            "--instance_type=<INSTANCE_TYPE>`",
            self.name,
        )

    estimator = sagemaker.estimator.Estimator(
        image_name,
        self.role,
        instance_count=1,
        instance_type=self.instance_type,
        sagemaker_session=sagemaker.Session(default_bucket=self.bucket),
    )

    # Sagemaker doesn't allow any underscores in job/experiment/trial names
    sanitized_run_name = run_name.replace("_", "-")

    # The run is only attached to an experiment when one is configured.
    experiment_config = (
        {
            "ExperimentName": self.experiment_name,
            "TrialName": sanitized_run_name,
        }
        if self.experiment_name
        else {}
    )

    estimator.fit(
        wait=True,
        experiment_config=experiment_config,
        job_name=sanitized_run_name,
    )
azure
special
Initialization of the ZenML Azure integration.
The Azure integration submodule provides a way to run ZenML pipelines in a cloud
environment. Specifically, it allows the use of cloud artifact stores,
and an io
module to handle file operations on Azure Blob Storage.
The Azure Step Operator integration submodule provides a way to run ZenML steps
in AzureML.
AzureIntegration (Integration)
Definition of Azure integration for ZenML.
Source code in zenml/integrations/azure/__init__.py
class AzureIntegration(Integration):
    """Definition of Azure integration for ZenML."""

    NAME = AZURE
    REQUIREMENTS = [
        "adlfs==2021.10.0",
        "azure-keyvault-keys",
        "azure-keyvault-secrets",
        "azure-identity",
        "azureml-core==1.42.0.post1",
    ]

    @classmethod
    def flavors(cls) -> List[FlavorWrapper]:
        """Declares the flavors for the integration.

        Returns:
            One flavor wrapper each for the Azure artifact store, the Azure
            secrets manager and the AzureML step operator.
        """
        # (flavor name, implementation source, stack component type)
        flavor_specs = (
            (
                AZURE_ARTIFACT_STORE_FLAVOR,
                "zenml.integrations.azure.artifact_stores.AzureArtifactStore",
                StackComponentType.ARTIFACT_STORE,
            ),
            (
                AZURE_SECRETS_MANAGER_FLAVOR,
                "zenml.integrations.azure.secrets_managers.AzureSecretsManager",
                StackComponentType.SECRETS_MANAGER,
            ),
            (
                AZUREML_STEP_OPERATOR_FLAVOR,
                "zenml.integrations.azure.step_operators.AzureMLStepOperator",
                StackComponentType.STEP_OPERATOR,
            ),
        )
        return [
            FlavorWrapper(
                name=flavor_name,
                source=source_path,
                type=component_type,
                integration=cls.NAME,
            )
            for flavor_name, source_path, component_type in flavor_specs
        ]
flavors()
classmethod
Declares the flavors for the integration.
Returns:
Type | Description |
---|---|
List[zenml.zen_stores.models.flavor_wrapper.FlavorWrapper] |
List of stack component flavors for this integration. |
Source code in zenml/integrations/azure/__init__.py
@classmethod
def flavors(cls) -> List[FlavorWrapper]:
    """Declares the flavors for the integration.

    Returns:
        One flavor wrapper each for the Azure artifact store, the Azure
        secrets manager and the AzureML step operator.
    """
    # (flavor name, implementation source, stack component type)
    flavor_specs = (
        (
            AZURE_ARTIFACT_STORE_FLAVOR,
            "zenml.integrations.azure.artifact_stores.AzureArtifactStore",
            StackComponentType.ARTIFACT_STORE,
        ),
        (
            AZURE_SECRETS_MANAGER_FLAVOR,
            "zenml.integrations.azure.secrets_managers.AzureSecretsManager",
            StackComponentType.SECRETS_MANAGER,
        ),
        (
            AZUREML_STEP_OPERATOR_FLAVOR,
            "zenml.integrations.azure.step_operators.AzureMLStepOperator",
            StackComponentType.STEP_OPERATOR,
        ),
    )
    return [
        FlavorWrapper(
            name=flavor_name,
            source=source_path,
            type=component_type,
            integration=cls.NAME,
        )
        for flavor_name, source_path, component_type in flavor_specs
    ]
artifact_stores
special
Initialization of the Azure Artifact Store integration.
azure_artifact_store
Implementation of the Azure Artifact Store integration.
AzureArtifactStore (BaseArtifactStore, AuthenticationMixin)
pydantic-model
Artifact Store for Microsoft Azure based artifacts.
Source code in zenml/integrations/azure/artifact_stores/azure_artifact_store.py
class AzureArtifactStore(BaseArtifactStore, AuthenticationMixin):
    """Artifact Store for Microsoft Azure based artifacts."""

    # adlfs filesystem, created on first access of `filesystem`
    _filesystem: Optional[adlfs.AzureBlobFileSystem] = None

    # Class Configuration
    FLAVOR: ClassVar[str] = AZURE_ARTIFACT_STORE_FLAVOR
    SUPPORTED_SCHEMES: ClassVar[Set[str]] = {"abfs://", "az://"}

    @property
    def filesystem(self) -> adlfs.AzureBlobFileSystem:
        """The adlfs filesystem to access this artifact store.

        The filesystem is created on first access, using the content of the
        authentication secret as credentials when such a secret is available.

        Returns:
            The adlfs filesystem to access this artifact store.
        """
        if self._filesystem:
            return self._filesystem

        secret = self.get_authentication_secret(
            expected_schema_type=AzureSecretSchema
        )
        credentials = secret.content if secret else {}

        self._filesystem = adlfs.AzureBlobFileSystem(
            **credentials,
            anon=False,
            use_listings_cache=False,
        )
        return self._filesystem

    @classmethod
    def _split_path(cls, path: PathType) -> Tuple[str, str]:
        """Splits a path into the filesystem prefix and remainder.

        Example:
        ```python
        prefix, remainder = AzureArtifactStore._split_path("az://my_container/test.txt")
        print(prefix, remainder)  # "az://" "my_container/test.txt"
        ```

        Args:
            path: The path to split.

        Returns:
            A tuple of the filesystem prefix and the remainder. The prefix is
            an empty string if the path starts with none of the supported
            schemes.
        """
        path_str = convert_to_str(path)
        scheme = next(
            (
                candidate
                for candidate in cls.SUPPORTED_SCHEMES
                if path_str.startswith(candidate)
            ),
            "",
        )
        return scheme, path_str[len(scheme) :]

    def open(self, path: PathType, mode: str = "r") -> Any:
        """Open a file at the given path.

        Args:
            path: Path of the file to open.
            mode: Mode in which to open the file. Currently, only
                'rb' and 'wb' to read and write binary files are supported.

        Returns:
            A file-like object.
        """
        file_handle = self.filesystem.open(path=path, mode=mode)
        return file_handle

    def copyfile(
        self, src: PathType, dst: PathType, overwrite: bool = False
    ) -> None:
        """Copy a file.

        Args:
            src: The path to copy from.
            dst: The path to copy to.
            overwrite: If a file already exists at the destination, this
                method will overwrite it if overwrite=`True` and
                raise a FileExistsError otherwise.

        Raises:
            FileExistsError: If a file already exists at the destination
                and overwrite is not set to `True`.
        """
        # The existence check is skipped entirely when overwriting is allowed.
        destination_taken = not overwrite and self.filesystem.exists(dst)
        if destination_taken:
            raise FileExistsError(
                f"Unable to copy to destination '{convert_to_str(dst)}', "
                f"file already exists. Set `overwrite=True` to copy anyway."
            )

        # TODO [ENG-151]: Check if it works with overwrite=True or if we need to
        #  manually remove it first
        self.filesystem.copy(path1=src, path2=dst)

    def exists(self, path: PathType) -> bool:
        """Check whether a path exists.

        Args:
            path: The path to check.

        Returns:
            True if the path exists, False otherwise.
        """
        return self.filesystem.exists(path=path)  # type: ignore[no-any-return]

    def glob(self, pattern: PathType) -> List[PathType]:
        """Return all paths that match the given glob pattern.

        The glob pattern may include:
        - '*' to match any number of characters
        - '?' to match a single character
        - '[...]' to match one of the characters inside the brackets
        - '**' as the full name of a path component to match to search
            in subdirectories of any depth (e.g. '/some_dir/**/some_file')

        Args:
            pattern: The glob pattern to match, see details above.

        Returns:
            A list of paths that match the given glob pattern, each prefixed
            with the scheme of the pattern.
        """
        scheme, _ = self._split_path(pattern)
        matches = self.filesystem.glob(path=pattern)
        return [f"{scheme}{match}" for match in matches]

    def isdir(self, path: PathType) -> bool:
        """Check whether a path is a directory.

        Args:
            path: The path to check.

        Returns:
            True if the path is a directory, False otherwise.
        """
        return self.filesystem.isdir(path=path)  # type: ignore[no-any-return]

    def listdir(self, path: PathType) -> List[PathType]:
        """Return a list of files in a directory.

        Args:
            path: The path to list.

        Returns:
            A list of files in the given directory.
        """
        _, path = self._split_path(path)

        # Each entry returned by the filesystem is a dictionary whose "name"
        # holds the entry path; cut off the directory part and the separator
        # to keep only the basename.
        return [
            cast(str, entry["name"])[len(path) :].lstrip("/")
            for entry in self.filesystem.listdir(path=path)
        ]

    def makedirs(self, path: PathType) -> None:
        """Create a directory at the given path.

        If needed also create missing parent directories.

        Args:
            path: The path to create.
        """
        filesystem = self.filesystem
        filesystem.makedirs(path=path, exist_ok=True)

    def mkdir(self, path: PathType) -> None:
        """Create a directory at the given path.

        Args:
            path: The path to create.
        """
        filesystem = self.filesystem
        filesystem.makedir(path=path)

    def remove(self, path: PathType) -> None:
        """Remove the file at the given path.

        Args:
            path: The path to remove.
        """
        filesystem = self.filesystem
        filesystem.rm_file(path=path)

    def rename(
        self, src: PathType, dst: PathType, overwrite: bool = False
    ) -> None:
        """Rename source file to destination file.

        Args:
            src: The path of the file to rename.
            dst: The path to rename the source file to.
            overwrite: If a file already exists at the destination, this
                method will overwrite it if overwrite=`True` and
                raise a FileExistsError otherwise.

        Raises:
            FileExistsError: If a file already exists at the destination
                and overwrite is not set to `True`.
        """
        # The existence check is skipped entirely when overwriting is allowed.
        destination_taken = not overwrite and self.filesystem.exists(dst)
        if destination_taken:
            raise FileExistsError(
                f"Unable to rename file to '{convert_to_str(dst)}', "
                f"file already exists. Set `overwrite=True` to rename anyway."
            )

        # TODO [ENG-152]: Check if it works with overwrite=True or if we need
        #  to manually remove it first
        self.filesystem.rename(path1=src, path2=dst)

    def rmtree(self, path: PathType) -> None:
        """Remove the given directory.

        Args:
            path: The path of the directory to remove.
        """
        filesystem = self.filesystem
        filesystem.delete(path=path, recursive=True)

    def stat(self, path: PathType) -> Dict[str, Any]:
        """Return stat info for the given path.

        Args:
            path: The path to get stat info for.

        Returns:
            Stat info.
        """
        return self.filesystem.stat(path=path)  # type: ignore[no-any-return]

    def walk(
        self,
        top: PathType,
        topdown: bool = True,
        onerror: Optional[Callable[..., None]] = None,
    ) -> Iterable[Tuple[PathType, List[PathType], List[PathType]]]:
        """Return an iterator that walks the contents of the given directory.

        Args:
            top: Path of directory to walk.
            topdown: Unused argument to conform to interface.
            onerror: Unused argument to conform to interface.

        Yields:
            An Iterable of Tuples, each of which contain the path of the current
            directory path, a list of directories inside the current directory
            and a list of files inside the current directory.
        """
        # TODO [ENG-153]: Additional params
        scheme, _ = self._split_path(top)
        for entry in self.filesystem.walk(path=top):
            directory, subdirectories, files = entry
            # Re-attach the scheme of `top` to the directory path.
            yield f"{scheme}{directory}", subdirectories, files
filesystem: AzureBlobFileSystem
property
readonly
The adlfs filesystem to access this artifact store.
Returns:
Type | Description |
---|---|
AzureBlobFileSystem |
The adlfs filesystem to access this artifact store. |
copyfile(self, src, dst, overwrite=False)
Copy a file.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
src |
Union[bytes, str] |
The path to copy from. |
required |
dst |
Union[bytes, str] |
The path to copy to. |
required |
overwrite |
bool |
If a file already exists at the destination, this
method will overwrite it if overwrite=`True` and raise a FileExistsError otherwise. |
False |
Exceptions:
Type | Description |
---|---|
FileExistsError |
If a file already exists at the destination
and overwrite is not set to `True`. |
Source code in zenml/integrations/azure/artifact_stores/azure_artifact_store.py
def copyfile(
    self, src: PathType, dst: PathType, overwrite: bool = False
) -> None:
    """Copy a file from one path to another.

    Args:
        src: The path to copy from.
        dst: The path to copy to.
        overwrite: When `False` (the default) an existing destination
            causes a FileExistsError; when `True` the existence check is
            skipped and the copy is attempted.

    Raises:
        FileExistsError: If a file already exists at the destination
            and overwrite is not set to `True`.
    """
    # Only probe the destination when overwriting is not allowed.
    if not overwrite:
        if self.filesystem.exists(dst):
            target = convert_to_str(dst)
            message = (
                f"Unable to copy to destination '{target}', "
                "file already exists. Set `overwrite=True` to copy anyway."
            )
            raise FileExistsError(message)

    # TODO [ENG-151]: Check if it works with overwrite=True or if we need to
    #  manually remove it first
    self.filesystem.copy(path1=src, path2=dst)
exists(self, path)
Check whether a path exists.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
path |
Union[bytes, str] |
The path to check. |
required |
Returns:
Type | Description |
---|---|
bool |
True if the path exists, False otherwise. |
Source code in zenml/integrations/azure/artifact_stores/azure_artifact_store.py
def exists(self, path: PathType) -> bool:
    """Ask the underlying filesystem whether a path exists.

    Args:
        path: The path to check.

    Returns:
        True if the path exists, False otherwise.
    """
    found = self.filesystem.exists(path=path)
    return found  # type: ignore[no-any-return]
glob(self, pattern)
Return all paths that match the given glob pattern.
The glob pattern may include: - '*' to match any number of characters - '?' to match a single character - '[...]' to match one of the characters inside the brackets - '**' as the full name of a path component to match to search in subdirectories of any depth (e.g. '/some_dir/**/some_file')
Parameters:
Name | Type | Description | Default |
---|---|---|---|
pattern |
Union[bytes, str] |
The glob pattern to match, see details above. |
required |
Returns:
Type | Description |
---|---|
List[Union[bytes, str]] |
A list of paths that match the given glob pattern. |
Source code in zenml/integrations/azure/artifact_stores/azure_artifact_store.py
def glob(self, pattern: PathType) -> List[PathType]:
    """Return all paths that match the given glob pattern.

    The glob pattern may include:
    - '*' to match any number of characters
    - '?' to match a single character
    - '[...]' to match one of the characters inside the brackets
    - '**' as the full name of a path component to match to search
      in subdirectories of any depth (e.g. '/some_dir/**/some_file')

    Args:
        pattern: The glob pattern to match, see details above.

    Returns:
        The matching paths, each with the prefix obtained from splitting
        `pattern` put back in front.
    """
    path_prefix, _ = self._split_path(pattern)
    matches = []
    for match in self.filesystem.glob(path=pattern):
        matches.append(f"{path_prefix}{match}")
    return matches
isdir(self, path)
Check whether a path is a directory.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
path |
Union[bytes, str] |
The path to check. |
required |
Returns:
Type | Description |
---|---|
bool |
True if the path is a directory, False otherwise. |
Source code in zenml/integrations/azure/artifact_stores/azure_artifact_store.py
def isdir(self, path: PathType) -> bool:
    """Ask the underlying filesystem whether a path is a directory.

    Args:
        path: The path to check.

    Returns:
        True if the path is a directory, False otherwise.
    """
    is_directory = self.filesystem.isdir(path=path)
    return is_directory  # type: ignore[no-any-return]
listdir(self, path)
Return a list of files in a directory.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
path |
Union[bytes, str] |
The path to list. |
required |
Returns:
Type | Description |
---|---|
List[Union[bytes, str]] |
A list of files in the given directory. |
Source code in zenml/integrations/azure/artifact_stores/azure_artifact_store.py
def listdir(self, path: PathType) -> List[PathType]:
    """List the entries of a directory as names relative to it.

    Args:
        path: The path to list.

    Returns:
        One name per entry reported by the filesystem, with the listed
        directory's path and any leading slashes stripped off.
    """
    _, directory = self._split_path(path)

    names = []
    for entry in self.filesystem.listdir(path=directory):
        full_name = cast(str, entry["name"])
        # Cut off as many leading characters as the listed directory path
        # has, then drop the separator left at the front.
        relative_name = full_name[len(directory) :]
        names.append(relative_name.lstrip("/"))
    return names
makedirs(self, path)
Create a directory at the given path.
If needed also create missing parent directories.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
path |
Union[bytes, str] |
The path to create. |
required |
Source code in zenml/integrations/azure/artifact_stores/azure_artifact_store.py
def makedirs(self, path: PathType) -> None:
    """Create a directory, including any missing parent directories.

    The call is made with ``exist_ok=True``.

    Args:
        path: The path to create.
    """
    fs = self.filesystem
    fs.makedirs(path=path, exist_ok=True)
mkdir(self, path)
Create a directory at the given path.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
path |
Union[bytes, str] |
The path to create. |
required |
Source code in zenml/integrations/azure/artifact_stores/azure_artifact_store.py
def mkdir(self, path: PathType) -> None:
    """Create a directory at the given path.

    Args:
        path: The path to create.
    """
    fs = self.filesystem
    fs.makedir(path=path)
open(self, path, mode='r')
Open a file at the given path.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
path |
Union[bytes, str] |
Path of the file to open. |
required |
mode |
str |
Mode in which to open the file. Currently, only 'rb' and 'wb' to read and write binary files are supported. |
'r' |
Returns:
Type | Description |
---|---|
Any |
A file-like object. |
Source code in zenml/integrations/azure/artifact_stores/azure_artifact_store.py
def open(self, path: PathType, mode: str = "r") -> Any:
    """Open a file through the underlying filesystem.

    Args:
        path: Path of the file to open.
        mode: Mode in which to open the file. Currently, only
            'rb' and 'wb' to read and write binary files are supported.

    Returns:
        A file-like object.
    """
    file_handle = self.filesystem.open(path=path, mode=mode)
    return file_handle
remove(self, path)
Remove the file at the given path.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
path |
Union[bytes, str] |
The path to remove. |
required |
Source code in zenml/integrations/azure/artifact_stores/azure_artifact_store.py
def remove(self, path: PathType) -> None:
    """Delete a single file through the underlying filesystem.

    Args:
        path: The path to remove.
    """
    fs = self.filesystem
    fs.rm_file(path=path)
rename(self, src, dst, overwrite=False)
Rename source file to destination file.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
src |
Union[bytes, str] |
The path of the file to rename. |
required |
dst |
Union[bytes, str] |
The path to rename the source file to. |
required |
overwrite |
bool |
If a file already exists at the destination, this
method will overwrite it if overwrite=`True` and raise a FileExistsError otherwise. |
False |
Exceptions:
Type | Description |
---|---|
FileExistsError |
If a file already exists at the destination
and overwrite is not set to `True`. |
Source code in zenml/integrations/azure/artifact_stores/azure_artifact_store.py
def rename(
    self, src: PathType, dst: PathType, overwrite: bool = False
) -> None:
    """Rename a source file to a destination file.

    Args:
        src: The path of the file to rename.
        dst: The path to rename the source file to.
        overwrite: When `False` (the default) an existing destination
            causes a FileExistsError; when `True` the existence check is
            skipped and the rename is attempted.

    Raises:
        FileExistsError: If a file already exists at the destination
            and overwrite is not set to `True`.
    """
    # Only probe the destination when overwriting is not allowed.
    if not overwrite:
        if self.filesystem.exists(dst):
            target = convert_to_str(dst)
            message = (
                f"Unable to rename file to '{target}', "
                "file already exists. Set `overwrite=True` to rename anyway."
            )
            raise FileExistsError(message)

    # TODO [ENG-152]: Check if it works with overwrite=True or if we need
    #  to manually remove it first
    self.filesystem.rename(path1=src, path2=dst)
rmtree(self, path)
Remove the given directory.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
path |
Union[bytes, str] |
The path of the directory to remove. |
required |
Source code in zenml/integrations/azure/artifact_stores/azure_artifact_store.py
def rmtree(self, path: PathType) -> None:
    """Delete a directory through the underlying filesystem.

    The delete is issued with ``recursive=True``.

    Args:
        path: The path of the directory to remove.
    """
    fs = self.filesystem
    fs.delete(path=path, recursive=True)
stat(self, path)
Return stat info for the given path.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
path |
Union[bytes, str] |
The path to get stat info for. |
required |
Returns:
Type | Description |
---|---|
Dict[str, Any] |
Stat info. |
Source code in zenml/integrations/azure/artifact_stores/azure_artifact_store.py
def stat(self, path: PathType) -> Dict[str, Any]:
    """Fetch stat information for a path from the underlying filesystem.

    Args:
        path: The path to get stat info for.

    Returns:
        The stat info as reported by the filesystem.
    """
    stat_info = self.filesystem.stat(path=path)
    return stat_info  # type: ignore[no-any-return]
walk(self, top, topdown=True, onerror=None)
Return an iterator that walks the contents of the given directory.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
top |
Union[bytes, str] |
Path of directory to walk. |
required |
topdown |
bool |
Unused argument to conform to interface. |
True |
onerror |
Optional[Callable[..., NoneType]] |
Unused argument to conform to interface. |
None |
Yields:
Type | Description |
---|---|
Iterable[Tuple[Union[bytes, str], List[Union[bytes, str]], List[Union[bytes, str]]]] |
An Iterable of Tuples, each of which contain the path of the current directory path, a list of directories inside the current directory and a list of files inside the current directory. |
Source code in zenml/integrations/azure/artifact_stores/azure_artifact_store.py
def walk(
    self,
    top: PathType,
    topdown: bool = True,
    onerror: Optional[Callable[..., None]] = None,
) -> Iterable[Tuple[PathType, List[PathType], List[PathType]]]:
    """Lazily walk the contents of a directory.

    Args:
        top: Path of directory to walk.
        topdown: Unused argument to conform to interface.
        onerror: Unused argument to conform to interface.

    Yields:
        One tuple per visited directory: the directory path (with the
        prefix obtained from splitting `top` put back in front), the list
        of directories inside it and the list of files inside it.
    """
    # TODO [ENG-153]: Additional params
    path_prefix, _ = self._split_path(top)
    for entry in self.filesystem.walk(path=top):
        current_dir, dir_names, file_names = entry
        # The filesystem reports directories without the prefix; re-attach it.
        yield f"{path_prefix}{current_dir}", dir_names, file_names
secrets_managers
special
Initialization of the Azure Secrets Manager integration.
azure_secrets_manager
Implementation of the Azure Secrets Manager integration.
AzureSecretsManager (BaseSecretsManager)
pydantic-model
Class to interact with the Azure secrets manager.
Attributes:
Name | Type | Description |
---|---|---|
key_vault_name |
str |
Name of an Azure Key Vault that this secrets manager will use to store secrets. |
Source code in zenml/integrations/azure/secrets_managers/azure_secrets_manager.py
class AzureSecretsManager(BaseSecretsManager):
    """Class to interact with the Azure secrets manager.

    A ZenML secret is a collection of key-value pairs, while an Azure Key
    Vault secret is a single value. Each pair is therefore stored as its own
    Azure secret and tagged with the ZenML secret name, the key name and the
    schema name so that the pairs can be grouped again when reading.

    Attributes:
        key_vault_name: Name of an Azure Key Vault that this secrets manager
            will use to store secrets.
    """

    key_vault_name: str

    # Class configuration
    FLAVOR: ClassVar[str] = AZURE_SECRETS_MANAGER_FLAVOR
    CLIENT: ClassVar[Any] = None

    @classmethod
    def _ensure_client_connected(cls, vault_name: str) -> None:
        """Create the Key Vault client on first use.

        Args:
            vault_name: Name of the Azure Key Vault to connect to.
        """
        # NOTE(review): the client is stored on the class and only built
        # while it is still None, so a later call with a different vault
        # name keeps using the first client - confirm this is intended.
        if cls.CLIENT is None:
            KVUri = f"https://{vault_name}.vault.azure.net"

            credential = DefaultAzureCredential()
            cls.CLIENT = SecretClient(vault_url=KVUri, credential=credential)

    def register_secret(self, secret: BaseSecretSchema) -> None:
        """Registers a new secret.

        Args:
            secret: the secret to register

        Raises:
            SecretExistsError: if the secret already exists
        """
        self._ensure_client_connected(self.key_vault_name)

        if secret.name in self.get_all_secret_keys():
            raise SecretExistsError(
                f"A Secret with the name '{secret.name}' already exists."
            )

        # Writing a new secret and updating one go through the same path.
        self.update_secret(secret)

    def get_secret(self, secret_name: str) -> BaseSecretSchema:
        """Get a secret by its name.

        Args:
            secret_name: the name of the secret to get

        Returns:
            The secret.

        Raises:
            RuntimeError: if no Azure secret is tagged with the given name
            ValueError: if a matching Azure secret has no key tag, or its
                key is 'name'
        """
        self._ensure_client_connected(self.key_vault_name)

        secret_contents = {}
        zenml_schema_name = ""

        for secret_property in self.CLIENT.list_properties_of_secrets():
            tags = secret_property.tags

            if tags and tags.get(ZENML_GROUP_KEY) == secret_name:
                secret_key = tags.get(ZENML_KEY_NAME)
                if not secret_key:
                    raise ValueError("Missing secret key tag.")

                # 'name' is reserved: it is filled in with the secret name
                # below before the schema is instantiated.
                if secret_key == "name":
                    raise ValueError("The secret's key cannot be 'name'.")

                response = self.CLIENT.get_secret(secret_property.name)
                secret_contents[secret_key] = response.value

                zenml_schema_name = tags.get(ZENML_SCHEMA_NAME)

        if not secret_contents:
            raise RuntimeError(f"No secrets found within the {secret_name}")

        secret_contents["name"] = secret_name

        secret_schema = SecretSchemaClassRegistry.get_class(
            secret_schema=zenml_schema_name
        )
        return secret_schema(**secret_contents)

    def get_all_secret_keys(self) -> List[str]:
        """Get all secret keys.

        Returns:
            A list of all secret keys
        """
        self._ensure_client_connected(self.key_vault_name)

        # A set de-duplicates the group name shared by all pairs of a secret.
        set_of_secrets = set()

        for secret_property in self.CLIENT.list_properties_of_secrets():
            tags = secret_property.tags
            if tags and ZENML_GROUP_KEY in tags:
                set_of_secrets.add(tags.get(ZENML_GROUP_KEY))

        return list(set_of_secrets)

    def update_secret(self, secret: BaseSecretSchema) -> None:
        """Update an existing secret by creating new versions of the existing secrets.

        Args:
            secret: the secret to update
        """
        self._ensure_client_connected(self.key_vault_name)

        for key, value in secret.content.items():
            # The Azure secret name is derived from "<secret name>-<key>",
            # base64-encoded and then hex-encoded.
            encoded_key = base64.b64encode(
                f"{secret.name}-{key}".encode()
            ).hex()
            azure_secret_name = f"zenml-{encoded_key}"

            self.CLIENT.set_secret(azure_secret_name, value)
            self.CLIENT.update_secret_properties(
                azure_secret_name,
                tags={
                    ZENML_GROUP_KEY: secret.name,
                    ZENML_KEY_NAME: key,
                    ZENML_SCHEMA_NAME: secret.TYPE,
                },
            )

            logger.debug("Wrote secret: %s", azure_secret_name)

    def delete_secret(self, secret_name: str) -> None:
        """Delete an existing secret by name.

        In Azure a secret is a single k-v pair. Within ZenML a secret is a
        collection of k-v pairs. As such, deleting a secret will iterate through
        all secrets and delete the ones with the secret_name as label.

        Args:
            secret_name: the name of the secret to delete
        """
        self._ensure_client_connected(self.key_vault_name)

        # Go through all Azure secrets and delete the ones with the secret_name
        # as label. The tags come with the listed properties, so the secret
        # values are never fetched here.
        for secret_property in self.CLIENT.list_properties_of_secrets():
            tags = secret_property.tags
            if tags and tags.get(ZENML_GROUP_KEY) == secret_name:
                self.CLIENT.begin_delete_secret(secret_property.name).result()

    def delete_all_secrets(self) -> None:
        """Delete all existing secrets that carry ZenML tags."""
        self._ensure_client_connected(self.key_vault_name)

        # The tags come with the listed properties, so the secret values are
        # never fetched here.
        for secret_property in self.CLIENT.list_properties_of_secrets():
            tags = secret_property.tags
            if tags and (ZENML_GROUP_KEY in tags or ZENML_SCHEMA_NAME in tags):
                self.CLIENT.begin_delete_secret(secret_property.name).result()
                # Log only once the deletion has completed.
                logger.info(
                    "Deleted key-value pair {`%s`, `***`} from secret " "`%s`",
                    secret_property.name,
                    tags.get(ZENML_GROUP_KEY),
                )
delete_all_secrets(self)
Delete all existing secrets.
Source code in zenml/integrations/azure/secrets_managers/azure_secrets_manager.py
def delete_all_secrets(self) -> None:
    """Delete all existing secrets that carry ZenML tags.

    A secret is deleted when its tags contain the ZenML group key or the
    ZenML schema name; secrets without those tags are left alone.
    """
    self._ensure_client_connected(self.key_vault_name)

    # The tags come with the listed properties, so the secret values are
    # never fetched here.
    for secret_property in self.CLIENT.list_properties_of_secrets():
        tags = secret_property.tags
        if tags and (ZENML_GROUP_KEY in tags or ZENML_SCHEMA_NAME in tags):
            self.CLIENT.begin_delete_secret(secret_property.name).result()
            # Log only once the deletion has completed.
            logger.info(
                "Deleted key-value pair {`%s`, `***`} from secret " "`%s`",
                secret_property.name,
                tags.get(ZENML_GROUP_KEY),
            )
delete_secret(self, secret_name)
Delete an existing secret by name.
In Azure a secret is a single k-v pair. Within ZenML a secret is a collection of k-v pairs. As such, deleting a secret will iterate through all secrets and delete the ones with the secret_name as label.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
secret_name |
str |
the name of the secret to delete |
required |
Source code in zenml/integrations/azure/secrets_managers/azure_secrets_manager.py
def delete_secret(self, secret_name: str) -> None:
    """Delete an existing secret by name.

    In Azure a secret is a single k-v pair. Within ZenML a secret is a
    collection of k-v pairs. As such, deleting a secret will iterate through
    all secrets and delete the ones with the secret_name as label.

    Args:
        secret_name: the name of the secret to delete
    """
    self._ensure_client_connected(self.key_vault_name)

    # Go through all Azure secrets and delete the ones with the secret_name
    # as label. The tags come with the listed properties, so the secret
    # values are never fetched here.
    for secret_property in self.CLIENT.list_properties_of_secrets():
        tags = secret_property.tags
        if tags and tags.get(ZENML_GROUP_KEY) == secret_name:
            self.CLIENT.begin_delete_secret(secret_property.name).result()
get_all_secret_keys(self)
Get all secret keys.
Returns:
Type | Description |
---|---|
List[str] |
A list of all secret keys |
Source code in zenml/integrations/azure/secrets_managers/azure_secrets_manager.py
def get_all_secret_keys(self) -> List[str]:
    """Collect the names of all ZenML secrets stored in the vault.

    Returns:
        The distinct values of the ZenML group tag found on the listed
        secrets, in no particular order.
    """
    self._ensure_client_connected(self.key_vault_name)

    # Several Azure secrets can share a group tag; the set keeps each
    # name once.
    group_names = {
        prop.tags.get(ZENML_GROUP_KEY)
        for prop in self.CLIENT.list_properties_of_secrets()
        if prop.tags and ZENML_GROUP_KEY in prop.tags
    }
    return list(group_names)
get_secret(self, secret_name)
Get a secret by its name.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
secret_name |
str |
the name of the secret to get |
required |
Returns:
Type | Description |
---|---|
BaseSecretSchema |
The secret. |
Exceptions:
Type | Description |
---|---|
RuntimeError |
if the secret does not exist |
ValueError |
if a matching secret has no key tag, or its key is 'name' |
Source code in zenml/integrations/azure/secrets_managers/azure_secrets_manager.py
def get_secret(self, secret_name: str) -> BaseSecretSchema:
    """Assemble a ZenML secret from the Azure secrets tagged with its name.

    Args:
        secret_name: the name of the secret to get

    Returns:
        The secret, instantiated with the schema class looked up from the
        schema tag.

    Raises:
        RuntimeError: if no Azure secret is tagged with the given name
        ValueError: if a matching Azure secret has no key tag, or its key
            is 'name'
    """
    self._ensure_client_connected(self.key_vault_name)

    contents = {}
    schema_name = ""

    for prop in self.CLIENT.list_properties_of_secrets():
        prop_tags = prop.tags

        # Skip Azure secrets that do not belong to the requested secret.
        if not prop_tags or prop_tags.get(ZENML_GROUP_KEY) != secret_name:
            continue

        key_name = prop_tags.get(ZENML_KEY_NAME)
        if not key_name:
            raise ValueError("Missing secret key tag.")
        # 'name' is reserved: it is filled in with the secret name below.
        if key_name == "name":
            raise ValueError("The secret's key cannot be 'name'.")

        contents[key_name] = self.CLIENT.get_secret(prop.name).value
        schema_name = prop_tags.get(ZENML_SCHEMA_NAME)

    if not contents:
        raise RuntimeError(f"No secrets found within the {secret_name}")

    contents["name"] = secret_name

    schema_class = SecretSchemaClassRegistry.get_class(
        secret_schema=schema_name
    )
    return schema_class(**contents)
register_secret(self, secret)
Registers a new secret.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
secret |
BaseSecretSchema |
the secret to register |
required |
Exceptions:
Type | Description |
---|---|
SecretExistsError |
if the secret already exists |
Source code in zenml/integrations/azure/secrets_managers/azure_secrets_manager.py
def register_secret(self, secret: BaseSecretSchema) -> None:
    """Store a secret that does not exist yet.

    Args:
        secret: the secret to register

    Raises:
        SecretExistsError: if the secret already exists
    """
    self._ensure_client_connected(self.key_vault_name)

    new_name = secret.name
    known_names = self.get_all_secret_keys()
    if new_name in known_names:
        raise SecretExistsError(
            f"A Secret with the name '{secret.name}' already exists."
        )

    # Writing a new secret and updating one go through the same path.
    self.update_secret(secret)
update_secret(self, secret)
Update an existing secret by creating new versions of the existing secrets.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
secret |
BaseSecretSchema |
the secret to update |
required |
Source code in zenml/integrations/azure/secrets_managers/azure_secrets_manager.py
def update_secret(self, secret: BaseSecretSchema) -> None:
    """Write every key-value pair of a secret as its own Azure secret.

    Each Azure secret is then tagged with the ZenML secret name, the key
    name and the secret's TYPE.

    Args:
        secret: the secret to update
    """
    self._ensure_client_connected(self.key_vault_name)

    for key, value in secret.content.items():
        # The Azure secret name is derived from "<secret name>-<key>",
        # base64-encoded and then hex-encoded.
        raw_name = f"{secret.name}-{key}".encode()
        encoded_key = base64.b64encode(raw_name).hex()
        azure_secret_name = f"zenml-{encoded_key}"

        self.CLIENT.set_secret(azure_secret_name, value)

        secret_tags = {
            ZENML_GROUP_KEY: secret.name,
            ZENML_KEY_NAME: key,
            ZENML_SCHEMA_NAME: secret.TYPE,
        }
        self.CLIENT.update_secret_properties(
            azure_secret_name, tags=secret_tags
        )

        logger.debug("Wrote secret: %s", azure_secret_name)
step_operators
special
Initialization of AzureML Step Operator integration.
azureml_step_operator
Implementation of the ZenML AzureML Step Operator.
AzureMLStepOperator (BaseStepOperator)
pydantic-model
Step operator to run a step on AzureML.
This class defines code that can set up an AzureML environment and run the ZenML entrypoint command in it.
Attributes:
Name | Type | Description |
---|---|---|
subscription_id |
str |
The Azure account's subscription ID |
resource_group |
str |
The resource group to which the AzureML workspace is deployed. |
workspace_name |
str |
The name of the AzureML Workspace. |
compute_target_name |
str |
The name of the configured ComputeTarget. An instance of it has to be created on the portal if it doesn't exist already. |
environment_name |
Optional[str] |
The name of the environment if there already exists one. |
docker_base_image |
Optional[str] |
The custom docker base image that the environment should use. |
tenant_id |
Optional[str] |
The Azure Tenant ID. |
service_principal_id |
Optional[str] |
The ID for the service principal that is created to allow apps to access secure resources. |
service_principal_password |
Optional[str] |
Password for the service principal. |
Source code in zenml/integrations/azure/step_operators/azureml_step_operator.py
class AzureMLStepOperator(BaseStepOperator):
    """Step operator to run a step on AzureML.

    This class defines code that can set up an AzureML environment and run the
    ZenML entrypoint command in it.

    Attributes:
        subscription_id: The Azure account's subscription ID
        resource_group: The resource group to which the AzureML workspace
            is deployed.
        workspace_name: The name of the AzureML Workspace.
        compute_target_name: The name of the configured ComputeTarget.
            An instance of it has to be created on the portal if it doesn't
            exist already.
        environment_name: The name of the environment if there
            already exists one.
        docker_base_image: The custom docker base image that the
            environment should use.
        tenant_id: The Azure Tenant ID.
        service_principal_id: The ID for the service principal that is created
            to allow apps to access secure resources.
        service_principal_password: Password for the service principal.
    """

    subscription_id: str
    resource_group: str
    workspace_name: str
    compute_target_name: str

    # Environment
    environment_name: Optional[str] = None
    docker_base_image: Optional[str] = None

    # Service principal authentication
    # https://docs.microsoft.com/en-us/azure/machine-learning/how-to-setup-authentication#configure-a-service-principal
    tenant_id: Optional[str] = None
    service_principal_id: Optional[str] = None
    service_principal_password: Optional[str] = None

    # Class Configuration
    FLAVOR: ClassVar[str] = AZUREML_STEP_OPERATOR_FLAVOR

    def _get_authentication(self) -> Optional[AbstractAuthentication]:
        """Build the authentication object for the AzureML workspace.

        Returns:
            A service principal authentication when tenant ID, service
            principal ID and service principal password are all set,
            otherwise None.
        """
        principal_configured = (
            self.tenant_id
            and self.service_principal_id
            and self.service_principal_password
        )
        if not principal_configured:
            return None

        return ServicePrincipalAuthentication(
            tenant_id=self.tenant_id,
            service_principal_id=self.service_principal_id,
            service_principal_password=self.service_principal_password,
        )

    def _prepare_environment(
        self, workspace: Workspace, requirements: List[str], run_name: str
    ) -> Environment:
        """Prepares the environment in which Azure will run all jobs.

        Args:
            workspace: The AzureML Workspace that has configuration
                for a storage account, container registry among other
                things.
            requirements: The list of requirements to be installed
                in the environment.
            run_name: The name of the pipeline run that can be used
                for naming environments and runs.

        Returns:
            The AzureML Environment object.
        """
        if self.environment_name:
            # Reuse the named environment and add the requirements to it.
            environment = Environment.get(
                workspace=workspace, name=self.environment_name
            )
            if not environment.python.conda_dependencies:
                environment.python.conda_dependencies = (
                    CondaDependencies.create(
                        python_version=ZenMLEnvironment.python_version()
                    )
                )

            conda_dependencies = environment.python.conda_dependencies
            for requirement in requirements:
                conda_dependencies.add_pip_package(requirement)
        else:
            # No environment configured: create one named after the run.
            environment = Environment(name=f"zenml-{run_name}")
            environment.python.conda_dependencies = CondaDependencies.create(
                pip_packages=requirements,
                python_version=ZenMLEnvironment.python_version(),
            )

            if self.docker_base_image:
                # replace the default azure base image
                environment.docker.base_image = self.docker_base_image

        env_vars = {"ENV_ZENML_PREVENT_PIPELINE_EXECUTION": "True"}

        # set credentials to access azure storage; only variables that are
        # set to a non-empty value locally are passed on
        storage_credential_keys = (
            "AZURE_STORAGE_ACCOUNT_KEY",
            "AZURE_STORAGE_ACCOUNT_NAME",
            "AZURE_STORAGE_CONNECTION_STRING",
            "AZURE_STORAGE_SAS_TOKEN",
        )
        for credential_key in storage_credential_keys:
            credential_value = os.getenv(credential_key)
            if credential_value:
                env_vars[credential_key] = credential_value

        env_vars[ENV_ZENML_CONFIG_PATH] = f"./{CONTAINER_ZENML_CONFIG_DIR}"
        environment.environment_variables = env_vars
        return environment

    def launch(
        self,
        pipeline_name: str,
        run_name: str,
        requirements: List[str],
        entrypoint_command: List[str],
        resource_configuration: "ResourceConfiguration",
    ) -> None:
        """Launches a step on AzureML.

        Args:
            pipeline_name: Name of the pipeline which the step to be executed
                is part of.
            run_name: Name of the pipeline run which the step to be executed
                is part of.
            requirements: List of pip requirements that must be installed
                inside the step operator environment.
            entrypoint_command: Command that executes the step.
            resource_configuration: The resource configuration for this step.
        """
        if not resource_configuration.empty:
            logger.warning(
                "Specifying custom step resources is not supported for "
                "the AzureML step operator. If you want to run this step "
                "operator on specific resources, you can do so by creating an "
                "Azure compute target (https://docs.microsoft.com/en-us/azure/machine-learning/concept-compute-target) "
                "with a specific machine type and then updating this step "
                "operator: `zenml step-operator update %s "
                "--compute_target_name=<COMPUTE_TARGET_NAME>`",
                self.name,
            )

        azure_workspace = Workspace.get(
            subscription_id=self.subscription_id,
            resource_group=self.resource_group,
            name=self.workspace_name,
            auth=self._get_authentication(),
        )

        source_root = get_source_root_path()
        config_dir = os.path.join(source_root, CONTAINER_ZENML_CONFIG_DIR)
        try:
            # Save a copy of the current global configuration with the
            # active profile contents into the build context, to have
            # the configured stacks accessible from within the Azure ML
            # environment.
            container_config_path = PurePosixPath(
                f"./{CONTAINER_ZENML_CONFIG_DIR}"
            )
            GlobalConfiguration().copy_active_configuration(
                config_dir,
                load_config_path=container_config_path,
            )

            azure_environment = self._prepare_environment(
                workspace=azure_workspace,
                requirements=requirements,
                run_name=run_name,
            )
            script_config = ScriptRunConfig(
                source_directory=source_root,
                environment=azure_environment,
                compute_target=ComputeTarget(
                    workspace=azure_workspace, name=self.compute_target_name
                ),
                command=entrypoint_command,
            )

            run = Experiment(
                workspace=azure_workspace, name=pipeline_name
            ).submit(config=script_config)
        finally:
            # Clean up the temporary build files
            fileio.rmtree(config_dir)

        run.display_name = run_name
        run.wait_for_completion(show_output=True)
launch(self, pipeline_name, run_name, requirements, entrypoint_command, resource_configuration)
Launches a step on AzureML.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
pipeline_name |
str |
Name of the pipeline which the step to be executed is part of. |
required |
run_name |
str |
Name of the pipeline run which the step to be executed is part of. |
required |
entrypoint_command |
List[str] |
Command that executes the step. |
required |
requirements |
List[str] |
List of pip requirements that must be installed inside the step operator environment. |
required |
resource_configuration |
ResourceConfiguration |
The resource configuration for this step. |
required |
Source code in zenml/integrations/azure/step_operators/azureml_step_operator.py
def launch(
    self,
    pipeline_name: str,
    run_name: str,
    requirements: List[str],
    entrypoint_command: List[str],
    resource_configuration: "ResourceConfiguration",
) -> None:
    """Launches a step on AzureML.

    The active ZenML configuration is copied into the source directory
    before the run is submitted and removed again afterwards, whether or
    not the submission succeeded.

    Args:
        pipeline_name: Name of the pipeline which the step to be executed
            is part of.
        run_name: Name of the pipeline run which the step to be executed
            is part of.
        requirements: List of pip requirements that must be installed
            inside the step operator environment.
        entrypoint_command: Command that executes the step.
        resource_configuration: The resource configuration for this step.
    """
    if not resource_configuration.empty:
        logger.warning(
            "Specifying custom step resources is not supported for "
            "the AzureML step operator. If you want to run this step "
            "operator on specific resources, you can do so by creating an "
            "Azure compute target (https://docs.microsoft.com/en-us/azure/machine-learning/concept-compute-target) "
            "with a specific machine type and then updating this step "
            "operator: `zenml step-operator update %s "
            "--compute_target_name=<COMPUTE_TARGET_NAME>`",
            self.name,
        )

    azure_workspace = Workspace.get(
        subscription_id=self.subscription_id,
        resource_group=self.resource_group,
        name=self.workspace_name,
        auth=self._get_authentication(),
    )

    source_root = get_source_root_path()
    config_dir = os.path.join(source_root, CONTAINER_ZENML_CONFIG_DIR)
    try:
        # Save a copy of the current global configuration with the
        # active profile contents into the build context, to have
        # the configured stacks accessible from within the Azure ML
        # environment.
        container_config_path = PurePosixPath(
            f"./{CONTAINER_ZENML_CONFIG_DIR}"
        )
        GlobalConfiguration().copy_active_configuration(
            config_dir,
            load_config_path=container_config_path,
        )

        azure_environment = self._prepare_environment(
            workspace=azure_workspace,
            requirements=requirements,
            run_name=run_name,
        )
        script_config = ScriptRunConfig(
            source_directory=source_root,
            environment=azure_environment,
            compute_target=ComputeTarget(
                workspace=azure_workspace, name=self.compute_target_name
            ),
            command=entrypoint_command,
        )

        run = Experiment(
            workspace=azure_workspace, name=pipeline_name
        ).submit(config=script_config)
    finally:
        # Clean up the temporary build files
        fileio.rmtree(config_dir)

    run.display_name = run_name
    run.wait_for_completion(show_output=True)
constants
Constants for ZenML integrations.
dash
special
Initialization of the Dash integration.
DashIntegration (Integration)
Definition of Dash integration for ZenML.
Source code in zenml/integrations/dash/__init__.py
class DashIntegration(Integration):
    """Dash integration for ZenML.

    Declares the integration name and the pip requirements that must be
    installed for the Dash-based visualizers to work.
    """

    # Identifier under which this integration is registered.
    NAME = DASH

    # Packages needed by the Dash visualizers (including notebook support).
    REQUIREMENTS = [
        "dash>=2.0.0",
        "dash-cytoscape>=0.3.0",
        "dash-bootstrap-components>=1.0.1",
        "jupyter-dash>=0.4.2",
    ]
visualizers
special
Initialization of the Pipeline Run Visualizer.
pipeline_run_lineage_visualizer
Implementation of the pipeline run lineage visualizer.
PipelineRunLineageVisualizer (BasePipelineRunVisualizer)
Implementation of a lineage diagram via the dash and dash-cytoscape libraries.
Source code in zenml/integrations/dash/visualizers/pipeline_run_lineage_visualizer.py
class PipelineRunLineageVisualizer(BasePipelineRunVisualizer):
    """Implementation of a lineage diagram via the dash and dash-cytoscape libraries."""

    # Prefixes keep step and artifact node ids from colliding in the graph.
    ARTIFACT_PREFIX = "artifact_"
    STEP_PREFIX = "step_"

    # CSS class (color) used to render a node/edge for each execution status.
    STATUS_CLASS_MAPPING = {
        ExecutionStatus.CACHED: "green",
        ExecutionStatus.FAILED: "red",
        ExecutionStatus.RUNNING: "yellow",
        ExecutionStatus.COMPLETED: "blue",
    }

    def visualize(
        self,
        object: PipelineRunView,
        magic: bool = False,
        *args: Any,
        **kwargs: Any,
    ) -> dash.Dash:
        """Method to visualize pipeline runs via the Dash library.

        The layout puts every layer of the dag in a column.

        Args:
            object: The pipeline run to visualize.
            magic: If True, the visualization is rendered inline in a
                notebook through `JupyterDash`. Outside of a notebook a
                warning is emitted and a regular Dash app is served instead.
            *args: Additional positional arguments (unused).
            **kwargs: Additional keyword arguments (unused).

        Returns:
            The Dash application.
        """
        external_stylesheets = [
            dbc.themes.BOOTSTRAP,
            dbc.icons.BOOTSTRAP,
        ]

        # Decide which flavor of Dash app to build. Every path below binds
        # both `app` and `mode`, so the non-notebook "magic" case no longer
        # crashes on an unbound `app`.
        use_jupyter = False
        if magic:
            # NOTE(review): `in_notebook` is assumed to be a static method of
            # ZenML's `Environment`; referencing it without calling it is
            # always truthy. Confirm against the Environment class.
            if Environment.in_notebook():
                use_jupyter = True
            else:
                cli_utils.warning(
                    "Cannot set magic flag in non-notebook environments."
                )

        if use_jupyter:
            # Only import jupyter_dash in this case
            from jupyter_dash import JupyterDash  # noqa

            JupyterDash.infer_jupyter_proxy_config()
            app = JupyterDash(
                __name__,
                external_stylesheets=external_stylesheets,
            )
            mode = "inline"
        else:
            app = dash.Dash(
                __name__,
                external_stylesheets=external_stylesheets,
            )
            mode = None

        nodes, edges = [], []
        # The first step encountered is used as the root of the layout.
        first_step_id = None
        for step in object.steps:
            step_output_artifacts = list(step.outputs.values())
            execution_id = (
                step_output_artifacts[0].producer_step.id
                if step_output_artifacts
                else step.id
            )
            step_id = self.STEP_PREFIX + str(step.id)
            if first_step_id is None:
                first_step_id = step_id
            status_class = self.STATUS_CLASS_MAPPING[step.status]

            nodes.append(
                {
                    "data": {
                        "id": step_id,
                        "execution_id": execution_id,
                        "label": f"{execution_id} / {step.entrypoint_name}",
                        "entrypoint_name": step.entrypoint_name,  # redundant for consistency
                        "name": step.name,  # redundant for consistency
                        "type": "step",
                        "parameters": step.parameters,
                        "inputs": {k: v.uri for k, v in step.inputs.items()},
                        "outputs": {k: v.uri for k, v in step.outputs.items()},
                    },
                    "classes": status_class,
                }
            )

            # One node per output artifact plus an edge step -> artifact.
            for artifact_name, artifact in step.outputs.items():
                artifact_id = self.ARTIFACT_PREFIX + str(artifact.id)
                nodes.append(
                    {
                        "data": {
                            "id": artifact_id,
                            "execution_id": artifact.id,
                            "label": f"{artifact.id} / {artifact_name} ("
                            f"{artifact.data_type})",
                            "type": "artifact",
                            "name": artifact_name,
                            "is_cached": artifact.is_cached,
                            "artifact_type": artifact.type,
                            "artifact_data_type": artifact.data_type,
                            "parent_step_id": artifact.parent_step_id,
                            "producer_step_id": artifact.producer_step.id,
                            "uri": artifact.uri,
                        },
                        "classes": f"rectangle {status_class}",
                    }
                )
                edges.append(
                    {
                        "data": {"source": step_id, "target": artifact_id},
                        "classes": f"edge-arrow {status_class}"
                        + (" dashed" if artifact.is_cached else " solid"),
                    }
                )

            # Edges artifact -> step for every input; cached inputs are
            # drawn dashed in the "cached" color.
            cached_class = self.STATUS_CLASS_MAPPING[ExecutionStatus.CACHED]
            for artifact_name, artifact in step.inputs.items():
                edges.append(
                    {
                        "data": {
                            "source": self.ARTIFACT_PREFIX + str(artifact.id),
                            "target": step_id,
                        },
                        "classes": "edge-arrow "
                        + (
                            f"{cached_class} dashed"
                            if artifact.is_cached
                            else f"{status_class} solid"
                        ),
                    }
                )

        # Legend explaining node shapes and status colors.
        legend = html.Span(
            [
                html.Span(
                    [html.I(className="bi bi-circle-fill me-1"), "Step"],
                    className="me-2",
                ),
                html.Span(
                    [html.I(className="bi bi-square-fill me-1"), "Artifact"],
                    className="me-4",
                ),
                dbc.Badge("Completed", color=COLOR_BLUE, className="me-1"),
                dbc.Badge("Cached", color=COLOR_GREEN, className="me-1"),
                dbc.Badge("Running", color=COLOR_YELLOW, className="me-1"),
                dbc.Badge("Failed", color=COLOR_RED, className="me-1"),
            ]
        )
        graph = cyto.Cytoscape(
            id="cytoscape",
            layout={
                "name": "breadthfirst",
                "roots": f'[id = "{first_step_id}"]',
            },
            elements=edges + nodes,
            stylesheet=STYLESHEET,
            style={"width": "100%", "height": "800px"},
            zoom=1,
        )
        reset_button = dbc.Button(
            "Reset",
            id="bt-reset",
            color="primary",
            className="me-1",
        )

        app.layout = dbc.Row(
            [
                dbc.Container(f"Run: {object.name}", class_name="h1"),
                dbc.Row(
                    [
                        dbc.Col(
                            [
                                dbc.Row([legend]),
                                dbc.Row([graph]),
                                dbc.Row([reset_button]),
                            ]
                        ),
                        dbc.Col(
                            [
                                dcc.Markdown(id="markdown-selected-node-data"),
                            ]
                        ),
                    ]
                ),
            ],
            className="p-5",
        )

        @app.callback(  # type: ignore[misc]
            Output("markdown-selected-node-data", "children"),
            Input("cytoscape", "selectedNodeData"),
        )
        def display_data(data_list: List[Dict[str, Any]]) -> str:
            """Callback for the text area below the graph.

            Args:
                data_list: The selected node data.

            Returns:
                str: Markdown describing the selected nodes.
            """
            if data_list is None:
                return "Click on a node in the diagram."

            parts = []
            for data in data_list:
                parts.append(
                    f'## {data["execution_id"]} / {data["name"]}' + "\n\n"
                )
                if data["type"] == "artifact":
                    for item in [
                        "artifact_data_type",
                        "is_cached",
                        "producer_step_id",
                        "parent_step_id",
                        "uri",
                    ]:
                        parts.append(f"**{item}**: {data[item]}" + "\n\n")
                elif data["type"] == "step":
                    parts.append("### Inputs:" + "\n\n")
                    for k, v in data["inputs"].items():
                        parts.append(f"**{k}**: {v}" + "\n\n")
                    parts.append("### Outputs:" + "\n\n")
                    for k, v in data["outputs"].items():
                        parts.append(f"**{k}**: {v}" + "\n\n")
                    # The heading needs its own blank line, otherwise the
                    # first parameter is rendered as part of the heading.
                    parts.append("### Params:" + "\n\n")
                    for k, v in data["parameters"].items():
                        parts.append(f"**{k}**: {v}" + "\n\n")
            return "".join(parts)

        @app.callback(  # type: ignore[misc]
            [Output("cytoscape", "zoom"), Output("cytoscape", "elements")],
            [Input("bt-reset", "n_clicks")],
        )
        def reset_layout(
            n_clicks: int,
        ) -> List[Union[int, List[Dict[str, Collection[str]]]]]:
            """Resets the layout.

            Args:
                n_clicks: The number of clicks on the reset button.

            Returns:
                The zoom and the elements.
            """
            # The message must be the format string; the count is its argument.
            logger.debug("%s clicked in reset button.", n_clicks)
            return [1, edges + nodes]

        # Start the server exactly once, in the mode matching the app type.
        if mode is not None:
            app.run_server(mode=mode)
        else:
            app.run_server()
        return app
visualize(self, object, magic=False, *args, **kwargs)
Method to visualize pipeline runs via the Dash library.
The layout puts every layer of the dag in a column.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
object |
PipelineRunView |
The pipeline run to visualize. |
required |
magic |
bool |
If True, the visualization is rendered in a magic mode. |
False |
*args |
Any |
Additional positional arguments. |
() |
**kwargs |
Any |
Additional keyword arguments. |
{} |
Returns:
Type | Description |
---|---|
Dash |
The Dash application. |
Source code in zenml/integrations/dash/visualizers/pipeline_run_lineage_visualizer.py
def visualize(
    self,
    object: PipelineRunView,
    magic: bool = False,
    *args: Any,
    **kwargs: Any,
) -> dash.Dash:
    """Method to visualize pipeline runs via the Dash library.

    The layout puts every layer of the dag in a column.

    Args:
        object: The pipeline run to visualize.
        magic: If True, the visualization is rendered inline in a
            notebook through `JupyterDash`. Outside of a notebook a
            warning is emitted and a regular Dash app is served instead.
        *args: Additional positional arguments (unused).
        **kwargs: Additional keyword arguments (unused).

    Returns:
        The Dash application.
    """
    external_stylesheets = [
        dbc.themes.BOOTSTRAP,
        dbc.icons.BOOTSTRAP,
    ]

    # Decide which flavor of Dash app to build. Every path below binds
    # both `app` and `mode`, so the non-notebook "magic" case no longer
    # crashes on an unbound `app`.
    use_jupyter = False
    if magic:
        # NOTE(review): `in_notebook` is assumed to be a static method of
        # ZenML's `Environment`; referencing it without calling it is
        # always truthy. Confirm against the Environment class.
        if Environment.in_notebook():
            use_jupyter = True
        else:
            cli_utils.warning(
                "Cannot set magic flag in non-notebook environments."
            )

    if use_jupyter:
        # Only import jupyter_dash in this case
        from jupyter_dash import JupyterDash  # noqa

        JupyterDash.infer_jupyter_proxy_config()
        app = JupyterDash(
            __name__,
            external_stylesheets=external_stylesheets,
        )
        mode = "inline"
    else:
        app = dash.Dash(
            __name__,
            external_stylesheets=external_stylesheets,
        )
        mode = None

    nodes, edges = [], []
    # The first step encountered is used as the root of the layout.
    first_step_id = None
    for step in object.steps:
        step_output_artifacts = list(step.outputs.values())
        execution_id = (
            step_output_artifacts[0].producer_step.id
            if step_output_artifacts
            else step.id
        )
        step_id = self.STEP_PREFIX + str(step.id)
        if first_step_id is None:
            first_step_id = step_id
        status_class = self.STATUS_CLASS_MAPPING[step.status]

        nodes.append(
            {
                "data": {
                    "id": step_id,
                    "execution_id": execution_id,
                    "label": f"{execution_id} / {step.entrypoint_name}",
                    "entrypoint_name": step.entrypoint_name,  # redundant for consistency
                    "name": step.name,  # redundant for consistency
                    "type": "step",
                    "parameters": step.parameters,
                    "inputs": {k: v.uri for k, v in step.inputs.items()},
                    "outputs": {k: v.uri for k, v in step.outputs.items()},
                },
                "classes": status_class,
            }
        )

        # One node per output artifact plus an edge step -> artifact.
        for artifact_name, artifact in step.outputs.items():
            artifact_id = self.ARTIFACT_PREFIX + str(artifact.id)
            nodes.append(
                {
                    "data": {
                        "id": artifact_id,
                        "execution_id": artifact.id,
                        "label": f"{artifact.id} / {artifact_name} ("
                        f"{artifact.data_type})",
                        "type": "artifact",
                        "name": artifact_name,
                        "is_cached": artifact.is_cached,
                        "artifact_type": artifact.type,
                        "artifact_data_type": artifact.data_type,
                        "parent_step_id": artifact.parent_step_id,
                        "producer_step_id": artifact.producer_step.id,
                        "uri": artifact.uri,
                    },
                    "classes": f"rectangle {status_class}",
                }
            )
            edges.append(
                {
                    "data": {"source": step_id, "target": artifact_id},
                    "classes": f"edge-arrow {status_class}"
                    + (" dashed" if artifact.is_cached else " solid"),
                }
            )

        # Edges artifact -> step for every input; cached inputs are
        # drawn dashed in the "cached" color.
        cached_class = self.STATUS_CLASS_MAPPING[ExecutionStatus.CACHED]
        for artifact_name, artifact in step.inputs.items():
            edges.append(
                {
                    "data": {
                        "source": self.ARTIFACT_PREFIX + str(artifact.id),
                        "target": step_id,
                    },
                    "classes": "edge-arrow "
                    + (
                        f"{cached_class} dashed"
                        if artifact.is_cached
                        else f"{status_class} solid"
                    ),
                }
            )

    # Legend explaining node shapes and status colors.
    legend = html.Span(
        [
            html.Span(
                [html.I(className="bi bi-circle-fill me-1"), "Step"],
                className="me-2",
            ),
            html.Span(
                [html.I(className="bi bi-square-fill me-1"), "Artifact"],
                className="me-4",
            ),
            dbc.Badge("Completed", color=COLOR_BLUE, className="me-1"),
            dbc.Badge("Cached", color=COLOR_GREEN, className="me-1"),
            dbc.Badge("Running", color=COLOR_YELLOW, className="me-1"),
            dbc.Badge("Failed", color=COLOR_RED, className="me-1"),
        ]
    )
    graph = cyto.Cytoscape(
        id="cytoscape",
        layout={
            "name": "breadthfirst",
            "roots": f'[id = "{first_step_id}"]',
        },
        elements=edges + nodes,
        stylesheet=STYLESHEET,
        style={"width": "100%", "height": "800px"},
        zoom=1,
    )
    reset_button = dbc.Button(
        "Reset",
        id="bt-reset",
        color="primary",
        className="me-1",
    )

    app.layout = dbc.Row(
        [
            dbc.Container(f"Run: {object.name}", class_name="h1"),
            dbc.Row(
                [
                    dbc.Col(
                        [
                            dbc.Row([legend]),
                            dbc.Row([graph]),
                            dbc.Row([reset_button]),
                        ]
                    ),
                    dbc.Col(
                        [
                            dcc.Markdown(id="markdown-selected-node-data"),
                        ]
                    ),
                ]
            ),
        ],
        className="p-5",
    )

    @app.callback(  # type: ignore[misc]
        Output("markdown-selected-node-data", "children"),
        Input("cytoscape", "selectedNodeData"),
    )
    def display_data(data_list: List[Dict[str, Any]]) -> str:
        """Callback for the text area below the graph.

        Args:
            data_list: The selected node data.

        Returns:
            str: Markdown describing the selected nodes.
        """
        if data_list is None:
            return "Click on a node in the diagram."

        parts = []
        for data in data_list:
            parts.append(
                f'## {data["execution_id"]} / {data["name"]}' + "\n\n"
            )
            if data["type"] == "artifact":
                for item in [
                    "artifact_data_type",
                    "is_cached",
                    "producer_step_id",
                    "parent_step_id",
                    "uri",
                ]:
                    parts.append(f"**{item}**: {data[item]}" + "\n\n")
            elif data["type"] == "step":
                parts.append("### Inputs:" + "\n\n")
                for k, v in data["inputs"].items():
                    parts.append(f"**{k}**: {v}" + "\n\n")
                parts.append("### Outputs:" + "\n\n")
                for k, v in data["outputs"].items():
                    parts.append(f"**{k}**: {v}" + "\n\n")
                # The heading needs its own blank line, otherwise the
                # first parameter is rendered as part of the heading.
                parts.append("### Params:" + "\n\n")
                for k, v in data["parameters"].items():
                    parts.append(f"**{k}**: {v}" + "\n\n")
        return "".join(parts)

    @app.callback(  # type: ignore[misc]
        [Output("cytoscape", "zoom"), Output("cytoscape", "elements")],
        [Input("bt-reset", "n_clicks")],
    )
    def reset_layout(
        n_clicks: int,
    ) -> List[Union[int, List[Dict[str, Collection[str]]]]]:
        """Resets the layout.

        Args:
            n_clicks: The number of clicks on the reset button.

        Returns:
            The zoom and the elements.
        """
        # The message must be the format string; the count is its argument.
        logger.debug("%s clicked in reset button.", n_clicks)
        return [1, edges + nodes]

    # Start the server exactly once, in the mode matching the app type.
    if mode is not None:
        app.run_server(mode=mode)
    else:
        app.run_server()
    return app
deepchecks
special
Deepchecks integration for ZenML.
The Deepchecks integration provides a way to validate your data in your pipelines. It includes a way to detect data anomalies and define checks to ensure quality of data.
The integration includes custom materializers to store Deepchecks SuiteResults
and
a visualizer to visualize the results in an easy way on a notebook and in your
browser.
DeepchecksIntegration (Integration)
Definition of Deepchecks integration for ZenML.
Source code in zenml/integrations/deepchecks/__init__.py
class DeepchecksIntegration(Integration):
    """Definition of [Deepchecks](https://github.com/deepchecks/deepchecks) integration for ZenML."""

    # Identifier under which this integration is registered.
    NAME = DEEPCHECKS

    # Pinned packages required by the Deepchecks integration.
    REQUIREMENTS = ["deepchecks[vision]==0.8.0", "torchvision==0.11.2"]

    @staticmethod
    def activate() -> None:
        """Activate the Deepchecks integration.

        Imports the materializers and visualizers sub-modules, in that
        order, purely for their import-time side effects.
        """
        from zenml.integrations.deepchecks import (  # noqa
            materializers,
            visualizers,
        )

    @classmethod
    def flavors(cls) -> List[FlavorWrapper]:
        """Declare the stack component flavors for the Deepchecks integration.

        Returns:
            List of stack component flavors for this integration.
        """
        data_validator_flavor = FlavorWrapper(
            name=DEEPCHECKS_DATA_VALIDATOR_FLAVOR,
            source="zenml.integrations.deepchecks.data_validators.DeepchecksDataValidator",
            type=StackComponentType.DATA_VALIDATOR,
            integration=cls.NAME,
        )
        return [data_validator_flavor]
activate()
staticmethod
Activate the Deepchecks integration.
Source code in zenml/integrations/deepchecks/__init__.py
@staticmethod
def activate() -> None:
    """Activate the Deepchecks integration.

    Imports the materializers and visualizers sub-modules, in that
    order, purely for their import-time side effects.
    """
    from zenml.integrations.deepchecks import (  # noqa
        materializers,
        visualizers,
    )
flavors()
classmethod
Declare the stack component flavors for the Deepchecks integration.
Returns:
Type | Description |
---|---|
List[zenml.zen_stores.models.flavor_wrapper.FlavorWrapper] |
List of stack component flavors for this integration. |
Source code in zenml/integrations/deepchecks/__init__.py
@classmethod
def flavors(cls) -> List[FlavorWrapper]:
    """Declare the stack component flavors for the Deepchecks integration.

    Returns:
        List of stack component flavors for this integration.
    """
    data_validator_flavor = FlavorWrapper(
        name=DEEPCHECKS_DATA_VALIDATOR_FLAVOR,
        source="zenml.integrations.deepchecks.data_validators.DeepchecksDataValidator",
        type=StackComponentType.DATA_VALIDATOR,
        integration=cls.NAME,
    )
    return [data_validator_flavor]
data_validators
special
Initialization of the Deepchecks data validator for ZenML.
deepchecks_data_validator
Implementation of the Deepchecks data validator.
DeepchecksDataValidator (BaseDataValidator)
pydantic-model
Deepchecks data validator stack component.
Source code in zenml/integrations/deepchecks/data_validators/deepchecks_data_validator.py
class DeepchecksDataValidator(BaseDataValidator):
    """Deepchecks data validator stack component."""

    # Class Configuration
    FLAVOR: ClassVar[str] = DEEPCHECKS_DATA_VALIDATOR_FLAVOR
    NAME: ClassVar[str] = "Deepchecks"

    @staticmethod
    def _split_checks(
        check_list: Sequence[str],
    ) -> Tuple[Sequence[str], Sequence[str]]:
        """Split a list of check identifiers in two lists, one for tabular and one for computer vision checks.

        Args:
            check_list: A list of check identifiers.

        Returns:
            List of tabular check identifiers and list of computer vision
            check identifiers.
        """
        tabular_checks = [
            check
            for check in check_list
            if DeepchecksValidationCheck.is_tabular_check(check)
        ]
        vision_checks = [
            check
            for check in check_list
            if DeepchecksValidationCheck.is_vision_check(check)
        ]
        return tabular_checks, vision_checks

    # flake8: noqa: C901
    @classmethod
    def _create_and_run_check_suite(
        cls,
        check_enum: Type[DeepchecksValidationCheck],
        reference_dataset: Union[pd.DataFrame, DataLoader[Any]],
        comparison_dataset: Optional[
            Union[pd.DataFrame, DataLoader[Any]]
        ] = None,
        model: Optional[Union[ClassifierMixin, Module]] = None,
        check_list: Optional[Sequence[str]] = None,
        dataset_kwargs: Dict[str, Any] = {},
        check_kwargs: Dict[str, Dict[str, Any]] = {},
        run_kwargs: Dict[str, Any] = {},
    ) -> SuiteResult:
        """Create and run a Deepchecks check suite corresponding to the input parameters.

        This method contains generic logic common to all Deepchecks data
        validator methods that validates the input arguments and uses them to
        generate and run a Deepchecks check suite.

        Args:
            check_enum: ZenML enum type grouping together Deepchecks checks with
                the same characteristics. This is used to generate a default
                list of checks, if a custom list isn't provided via the
                `check_list` argument.
            reference_dataset: Primary (reference) dataset argument used during
                validation.
            comparison_dataset: Optional secondary (comparison) dataset argument
                used during comparison checks.
            model: Optional model argument used during validation.
            check_list: Optional list of ZenML Deepchecks check identifiers
                specifying the list of Deepchecks checks to be performed.
            dataset_kwargs: Additional keyword arguments to be passed to the
                Deepchecks tabular.Dataset or vision.VisionData constructor.
            check_kwargs: Additional keyword arguments to be passed to the
                Deepchecks check object constructors. Arguments are grouped for
                each check and indexed using the full check class name or
                check enum value as dictionary keys. Entries whose name starts
                with `condition_` and whose value is a dictionary are not
                passed to the constructor; they are forwarded to the check's
                matching `add_condition_...` method instead.
            run_kwargs: Additional keyword arguments to be passed to the
                Deepchecks Suite `run` method.

        Returns:
            Deepchecks SuiteResult object with the Suite run results.

        Raises:
            TypeError: If the datasets, model and check list arguments combine
                data types and/or checks from different categories (tabular and
                computer vision).
        """
        # Detect what type of check to perform (tabular or computer vision) from
        # the dataset/model datatypes and the check list. At the same time,
        # validate the combination of data types used for dataset and model
        # arguments and the check list.
        is_tabular = False
        is_vision = False
        for dataset in [reference_dataset, comparison_dataset]:
            if dataset is None:
                continue
            if isinstance(dataset, pd.DataFrame):
                is_tabular = True
            elif isinstance(dataset, DataLoader):
                is_vision = True
            else:
                raise TypeError(
                    f"Unsupported dataset data type found: {type(dataset)}. "
                    f"Supported data types are {str(pd.DataFrame)} for tabular "
                    f"data and {str(DataLoader)} for computer vision data."
                )

        # Compare against None explicitly: a model object that defines
        # `__len__` and is empty would be falsy and silently skip validation.
        if model is not None:
            if isinstance(model, ClassifierMixin):
                is_tabular = True
            elif isinstance(model, Module):
                is_vision = True
            else:
                raise TypeError(
                    f"Unsupported model data type found: {type(model)}. "
                    f"Supported data types are {str(ClassifierMixin)} for "
                    f"tabular data and {str(Module)} for computer vision "
                    f"data."
                )

        if is_tabular and is_vision:
            raise TypeError(
                f"Tabular and computer vision data types used for datasets and "
                f"models cannot be mixed. They must all belong to the same "
                f"category. Supported data types for tabular data are "
                f"{str(pd.DataFrame)} for datasets and {str(ClassifierMixin)} "
                f"for models. Supported data types for computer vision data "
                f"are {str(DataLoader)} for datasets and {str(Module)} "
                f"for models."
            )

        if not check_list:
            # default to executing all the checks listed in the supplied
            # checks enum type if a custom check list is not supplied
            tabular_checks, vision_checks = cls._split_checks(
                check_enum.values()
            )
            if is_tabular:
                check_list = tabular_checks
                vision_checks = []
            else:
                check_list = vision_checks
                tabular_checks = []
        else:
            tabular_checks, vision_checks = cls._split_checks(check_list)

        if tabular_checks and vision_checks:
            raise TypeError(
                f"The check list cannot mix tabular checks "
                f"({tabular_checks}) and computer vision checks ("
                f"{vision_checks})."
            )
        if is_tabular and vision_checks:
            raise TypeError(
                f"Tabular data types used for datasets and models can only "
                f"be used with tabular validation checks. The following "
                f"computer vision checks included in the check list are "
                f"not valid: {vision_checks}."
            )
        if is_vision and tabular_checks:
            raise TypeError(
                f"Computer vision data types used for datasets and models "
                f"can only be used with computer vision validation checks. "
                f"The following tabular checks included in the check list "
                f"are not valid: {tabular_checks}."
            )

        # Pair every check identifier with the check class it resolves to.
        check_classes = [
            (check, check_enum.get_check_class(check)) for check in check_list
        ]

        # use the pipeline name and the step name to generate a unique suite
        # name
        try:
            # get pipeline name and step name
            step_env = cast(
                StepEnvironment, Environment()[STEP_ENVIRONMENT_NAME]
            )
            suite_name = f"{step_env.pipeline_name}_{step_env.step_name}"
        except KeyError:
            # if not running inside a pipeline step, use random values
            suite_name = f"suite_{random_str(5)}"

        if is_tabular:
            dataset_class = TabularData
            suite_class = TabularSuite
            full_suite = full_tabular_suite()
        else:
            dataset_class = VisionData
            suite_class = VisionSuite
            full_suite = full_vision_suite()

        train_dataset = dataset_class(reference_dataset, **dataset_kwargs)
        test_dataset = None
        if comparison_dataset is not None:
            test_dataset = dataset_class(comparison_dataset, **dataset_kwargs)
        suite = suite_class(name=suite_name)

        # Some Deepchecks checks require a minimum configuration such as
        # conditions to be configured (see https://docs.deepchecks.com/stable/user-guide/general/customizations/examples/plot_configure_check_conditions.html#sphx-glr-user-guide-general-customizations-examples-plot-configure-check-conditions-py)
        # for their execution to have meaning. For checks that don't have
        # custom configuration attributes explicitly specified in the
        # `check_kwargs` input parameter, we use the default check
        # instances extracted from the full suite shipped with Deepchecks.
        default_checks = {
            check.__class__: check for check in full_suite.checks.values()
        }
        for check_name, check_class in check_classes:
            extra_kwargs = check_kwargs.get(check_name, {})

            # extract the condition kwargs from the check kwargs; everything
            # else is a constructor argument for this particular check
            conditions = {
                arg_name: arg_value
                for arg_name, arg_value in extra_kwargs.items()
                if arg_name.startswith("condition_")
                and isinstance(arg_value, dict)
            }
            init_kwargs = {
                arg_name: arg_value
                for arg_name, arg_value in extra_kwargs.items()
                if arg_name not in conditions
            }

            default_check = default_checks.get(check_class)
            check: BaseCheck
            if extra_kwargs or not default_check:
                # Pass only this check's own arguments; the full
                # `check_kwargs` mapping is keyed by check name and is not a
                # valid set of constructor arguments.
                check = check_class(**init_kwargs)
            else:
                check = default_check

            for arg_name, condition_kwargs in conditions.items():
                condition_method = getattr(check, f"add_{arg_name}", None)
                if not condition_method or not callable(condition_method):
                    logger.warning(
                        f"Deepchecks check type {check.__class__} has no "
                        f"condition named {arg_name}. Ignoring the check "
                        f"argument."
                    )
                    continue
                condition_method(**condition_kwargs)

            suite.add(check)

        return suite.run(
            train_dataset=train_dataset,
            test_dataset=test_dataset,
            model=model,
            **run_kwargs,
        )

    def data_validation(
        self,
        dataset: Union[pd.DataFrame, DataLoader[Any]],
        comparison_dataset: Optional[Any] = None,
        check_list: Optional[Sequence[str]] = None,
        dataset_kwargs: Dict[str, Any] = {},
        check_kwargs: Dict[str, Dict[str, Any]] = {},
        run_kwargs: Dict[str, Any] = {},
        **kwargs: Any,
    ) -> SuiteResult:
        """Run one or more Deepchecks data validation checks on a dataset.

        Call this method to analyze and identify potential integrity problems
        with a single dataset (e.g. missing values, conflicting labels, mixed
        data types etc.) and dataset comparison checks (e.g. data drift
        checks). Dataset comparison checks require that a second dataset be
        supplied via the `comparison_dataset` argument.

        The `check_list` argument may be used to specify a custom set of
        Deepchecks data integrity checks to perform, identified by
        `DeepchecksDataIntegrityCheck` and `DeepchecksDataDriftCheck` enum
        values. If omitted:

        * if the `comparison_dataset` is omitted, a suite with all available
        data integrity checks will be performed on the input data. See
        `DeepchecksDataIntegrityCheck` for a list of Deepchecks builtin
        checks that are compatible with this method.

        * if the `comparison_dataset` is supplied, a suite with all
        available data drift checks will be performed on the input
        data. See `DeepchecksDataDriftCheck` for a list of Deepchecks
        builtin checks that are compatible with this method.

        Args:
            dataset: Target dataset to be validated.
            comparison_dataset: Optional second dataset to be used for data
                comparison checks (e.g data drift checks).
            check_list: Optional list of ZenML Deepchecks check identifiers
                specifying the data validation checks to be performed.
                `DeepchecksDataIntegrityCheck` enum values should be used for
                single data validation checks and `DeepchecksDataDriftCheck`
                enum values for data comparison checks. If not supplied, the
                entire set of checks applicable to the input dataset(s)
                will be performed.
            dataset_kwargs: Additional keyword arguments to be passed to the
                Deepchecks `tabular.Dataset` or `vision.VisionData` constructor.
            check_kwargs: Additional keyword arguments to be passed to the
                Deepchecks check object constructors. Arguments are grouped for
                each check and indexed using the full check class name or
                check enum value as dictionary keys.
            run_kwargs: Additional keyword arguments to be passed to the
                Deepchecks Suite `run` method.
            kwargs: Additional keyword arguments (unused).

        Returns:
            A Deepchecks SuiteResult with the results of the validation.
        """
        check_enum: Type[DeepchecksValidationCheck]
        if comparison_dataset is None:
            check_enum = DeepchecksDataIntegrityCheck
        else:
            check_enum = DeepchecksDataDriftCheck

        return self._create_and_run_check_suite(
            check_enum=check_enum,
            reference_dataset=dataset,
            comparison_dataset=comparison_dataset,
            check_list=check_list,
            dataset_kwargs=dataset_kwargs,
            check_kwargs=check_kwargs,
            run_kwargs=run_kwargs,
        )

    def model_validation(
        self,
        dataset: Union[pd.DataFrame, DataLoader[Any]],
        model: Union[ClassifierMixin, Module],
        comparison_dataset: Optional[Any] = None,
        check_list: Optional[Sequence[str]] = None,
        dataset_kwargs: Dict[str, Any] = {},
        check_kwargs: Dict[str, Dict[str, Any]] = {},
        run_kwargs: Dict[str, Any] = {},
        **kwargs: Any,
    ) -> Any:
        """Run one or more Deepchecks model validation checks.

        Call this method to perform model validation checks (e.g. confusion
        matrix validation, performance reports, model error analyses, etc).
        A second dataset is required for model performance comparison tests
        (i.e. tests that identify changes in a model behavior by comparing how
        it performs on two different datasets).

        The `check_list` argument may be used to specify a custom set of
        Deepchecks model validation checks to perform, identified by
        `DeepchecksModelValidationCheck` and `DeepchecksModelDriftCheck` enum
        values. If omitted:

        * if the `comparison_dataset` is omitted, a suite with all available
        model validation checks will be performed on the input data. See
        `DeepchecksModelValidationCheck` for a list of Deepchecks builtin
        checks that are compatible with this method.

        * if the `comparison_dataset` is supplied, a suite with all
        available model comparison checks will be performed on the input
        data. See `DeepchecksModelDriftCheck` for a list of Deepchecks
        builtin checks that are compatible with this method.

        Args:
            dataset: Target dataset to be validated.
            model: Target model to be validated.
            comparison_dataset: Optional second dataset to be used for model
                comparison checks.
            check_list: Optional list of ZenML Deepchecks check identifiers
                specifying the model validation checks to be performed.
                `DeepchecksModelValidationCheck` enum values should be used for
                model validation checks and `DeepchecksModelDriftCheck` enum
                values for model comparison checks. If not supplied, the
                entire set of checks applicable to the input dataset(s)
                will be performed.
            dataset_kwargs: Additional keyword arguments to be passed to the
                Deepchecks tabular.Dataset or vision.VisionData constructor.
            check_kwargs: Additional keyword arguments to be passed to the
                Deepchecks check object constructors. Arguments are grouped for
                each check and indexed using the full check class name or
                check enum value as dictionary keys.
            run_kwargs: Additional keyword arguments to be passed to the
                Deepchecks Suite `run` method.
            kwargs: Additional keyword arguments (unused).

        Returns:
            A Deepchecks SuiteResult with the results of the validation.
        """
        check_enum: Type[DeepchecksValidationCheck]
        if comparison_dataset is None:
            check_enum = DeepchecksModelValidationCheck
        else:
            check_enum = DeepchecksModelDriftCheck

        return self._create_and_run_check_suite(
            check_enum=check_enum,
            reference_dataset=dataset,
            comparison_dataset=comparison_dataset,
            model=model,
            check_list=check_list,
            dataset_kwargs=dataset_kwargs,
            check_kwargs=check_kwargs,
            run_kwargs=run_kwargs,
        )
data_validation(self, dataset, comparison_dataset=None, check_list=None, dataset_kwargs={}, check_kwargs={}, run_kwargs={}, **kwargs)
Run one or more Deepchecks data validation checks on a dataset.
Call this method to analyze and identify potential integrity problems
with a single dataset (e.g. missing values, conflicting labels, mixed
data types etc.) and dataset comparison checks (e.g. data drift
checks). Dataset comparison checks require that a second dataset be
supplied via the comparison_dataset
argument.
The check_list
argument may be used to specify a custom set of
Deepchecks data integrity checks to perform, identified by
DeepchecksDataIntegrityCheck
and DeepchecksDataDriftCheck
enum
values. If omitted:
-
if the
comparison_dataset
is omitted, a suite with all available data integrity checks will be performed on the input data. See DeepchecksDataIntegrityCheck
for a list of Deepchecks builtin checks that are compatible with this method. -
if the
comparison_dataset
is supplied, a suite with all available data drift checks will be performed on the input data. See DeepchecksDataDriftCheck
for a list of Deepchecks builtin checks that are compatible with this method.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
dataset |
Union[pandas.core.frame.DataFrame, torch.utils.data.dataloader.DataLoader[Any]] |
Target dataset to be validated. |
required |
comparison_dataset |
Optional[Any] |
Optional second dataset to be used for data comparison checks (e.g data drift checks). |
None |
check_list |
Optional[Sequence[str]] |
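Optional list of ZenML Deepchecks check identifiers
specifying the data validation checks to be performed.
DeepchecksDataIntegrityCheck enum values should be used for single data validation checks and DeepchecksDataDriftCheck enum values for data comparison checks. If not supplied, the entire set of checks applicable to the input dataset(s) will be performed.
|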
None |
dataset_kwargs |
Dict[str, Any] |
Additional keyword arguments to be passed to the
Deepchecks tabular.Dataset or vision.VisionData constructor. |
{} |
check_kwargs |
Dict[str, Dict[str, Any]] |
Additional keyword arguments to be passed to the Deepchecks check object constructors. Arguments are grouped for each check and indexed using the full check class name or check enum value as dictionary keys. |
{} |
run_kwargs |
Dict[str, Any] |
Additional keyword arguments to be passed to the
Deepchecks Suite run method. |
{} |
kwargs |
Any |
Additional keyword arguments (unused). |
{} |
Returns:
Type | Description |
---|---|
SuiteResult |
A Deepchecks SuiteResult with the results of the validation. |
Source code in zenml/integrations/deepchecks/data_validators/deepchecks_data_validator.py
def data_validation(
    self,
    dataset: Union[pd.DataFrame, DataLoader[Any]],
    comparison_dataset: Optional[Any] = None,
    check_list: Optional[Sequence[str]] = None,
    dataset_kwargs: Dict[str, Any] = {},
    check_kwargs: Dict[str, Dict[str, Any]] = {},
    run_kwargs: Dict[str, Any] = {},
    **kwargs: Any,
) -> SuiteResult:
    """Run one or more Deepchecks data validation checks on a dataset.

    Use this method to look for integrity problems in a single dataset
    (e.g. missing values, conflicting labels, mixed data types etc.) or to
    run dataset comparison checks (e.g. data drift checks). Comparison
    checks need a second dataset, supplied via `comparison_dataset`.

    A custom set of checks can be requested through `check_list`, using
    `DeepchecksDataIntegrityCheck` and `DeepchecksDataDriftCheck` enum
    values. When `check_list` is omitted:

    * without a `comparison_dataset`, the checks are taken from
    `DeepchecksDataIntegrityCheck` (data integrity checks).

    * with a `comparison_dataset`, the checks are taken from
    `DeepchecksDataDriftCheck` (data drift checks).

    Args:
        dataset: Target dataset to be validated.
        comparison_dataset: Optional second dataset to be used for data
            comparison checks (e.g data drift checks).
        check_list: Optional list of ZenML Deepchecks check identifiers
            specifying the data validation checks to be performed.
            `DeepchecksDataIntegrityCheck` enum values should be used for
            single data validation checks and `DeepchecksDataDriftCheck`
            enum values for data comparison checks. If not supplied, the
            entire set of checks applicable to the input dataset(s)
            will be performed.
        dataset_kwargs: Additional keyword arguments to be passed to the
            Deepchecks `tabular.Dataset` or `vision.VisionData` constructor.
        check_kwargs: Additional keyword arguments to be passed to the
            Deepchecks check object constructors. Arguments are grouped for
            each check and indexed using the full check class name or
            check enum value as dictionary keys.
        run_kwargs: Additional keyword arguments to be passed to the
            Deepchecks Suite `run` method.
        kwargs: Additional keyword arguments (unused).

    Returns:
        A Deepchecks SuiteResult with the results of the validation.
    """
    # A comparison dataset switches from integrity checks to drift checks.
    check_enum: Type[DeepchecksValidationCheck] = (
        DeepchecksDataIntegrityCheck
        if comparison_dataset is None
        else DeepchecksDataDriftCheck
    )
    suite_arguments = dict(
        check_enum=check_enum,
        reference_dataset=dataset,
        comparison_dataset=comparison_dataset,
        check_list=check_list,
        dataset_kwargs=dataset_kwargs,
        check_kwargs=check_kwargs,
        run_kwargs=run_kwargs,
    )
    return self._create_and_run_check_suite(**suite_arguments)
model_validation(self, dataset, model, comparison_dataset=None, check_list=None, dataset_kwargs={}, check_kwargs={}, run_kwargs={}, **kwargs)
Run one or more Deepchecks model validation checks.
Call this method to perform model validation checks (e.g. confusion matrix validation, performance reports, model error analyses, etc). A second dataset is required for model performance comparison tests (i.e. tests that identify changes in a model behavior by comparing how it performs on two different datasets).
The check_list
argument may be used to specify a custom set of
Deepchecks model validation checks to perform, identified by
DeepchecksModelValidationCheck
and DeepchecksModelDriftCheck
enum
values. If omitted:
* if the `comparison_dataset` is omitted, a suite with all available
model validation checks will be performed on the input data. See
`DeepchecksModelValidationCheck` for a list of Deepchecks builtin
checks that are compatible with this method.
* if the `comparison_dataset` is supplied, a suite with all
available model comparison checks will be performed on the input
data. See `DeepchecksModelDriftCheck` for a list of Deepchecks
builtin checks that are compatible with this method.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
dataset |
Union[pandas.core.frame.DataFrame, torch.utils.data.dataloader.DataLoader[Any]] |
Target dataset to be validated. |
required |
model |
Union[sklearn.base.ClassifierMixin, torch.nn.modules.module.Module] |
Target model to be validated. |
required |
comparison_dataset |
Optional[Any] |
Optional second dataset to be used for model comparison checks. |
None |
check_list |
Optional[Sequence[str]] |
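Optional list of ZenML Deepchecks check identifiers specifying the model validation checks to be performed. `DeepchecksModelValidationCheck` enum values should be used for model validation checks and `DeepchecksModelDriftCheck` enum values for model comparison checks. If not supplied, the entire set of checks applicable to the input dataset(s) will be performed. |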
None |
dataset_kwargs |
Dict[str, Any] |
Additional keyword arguments to be passed to the Deepchecks tabular.Dataset or vision.VisionData constructor. |
{} |
check_kwargs |
Dict[str, Dict[str, Any]] |
Additional keyword arguments to be passed to the Deepchecks check object constructors. Arguments are grouped for each check and indexed using the full check class name or check enum value as dictionary keys. |
{} |
run_kwargs |
Dict[str, Any] |
Additional keyword arguments to be passed to the
Deepchecks Suite `run` method. |
{} |
kwargs |
Any |
Additional keyword arguments (unused). |
{} |
Returns:
Type | Description |
---|---|
Any |
A Deepchecks SuiteResult with the results of the validation. |
Source code in zenml/integrations/deepchecks/data_validators/deepchecks_data_validator.py
def model_validation(
    self,
    dataset: Union[pd.DataFrame, DataLoader[Any]],
    model: Union[ClassifierMixin, Module],
    comparison_dataset: Optional[Any] = None,
    check_list: Optional[Sequence[str]] = None,
    dataset_kwargs: Dict[str, Any] = {},
    check_kwargs: Dict[str, Dict[str, Any]] = {},
    run_kwargs: Dict[str, Any] = {},
    **kwargs: Any,
) -> Any:
    """Validate a model with Deepchecks model validation checks.

    Two modes are supported, selected by the `comparison_dataset` argument:

    * `comparison_dataset` omitted: model validation checks on a single
      dataset (e.g. confusion matrix validation, performance reports,
      model error analyses). The checks are taken from the
      `DeepchecksModelValidationCheck` enum.
    * `comparison_dataset` supplied: model comparison checks, i.e. tests
      that identify changes in a model's behavior by comparing how it
      performs on two different datasets. The checks are taken from the
      `DeepchecksModelDriftCheck` enum.

    When `check_list` is omitted, the entire set of checks available for
    the selected mode is performed.

    Args:
        dataset: Target dataset to be validated.
        model: Target model to be validated.
        comparison_dataset: Optional second dataset to be used for model
            comparison checks.
        check_list: Optional list of ZenML Deepchecks check identifiers
            specifying the model validation checks to be performed.
            `DeepchecksModelValidationCheck` enum values should be used
            for model validation checks and `DeepchecksModelDriftCheck`
            enum values for model comparison checks.
        dataset_kwargs: Additional keyword arguments to be passed to the
            Deepchecks `tabular.Dataset` or `vision.VisionData`
            constructor.
        check_kwargs: Additional keyword arguments to be passed to the
            Deepchecks check object constructors. Arguments are grouped
            for each check and indexed using the full check class name or
            check enum value as dictionary keys.
        run_kwargs: Additional keyword arguments to be passed to the
            Deepchecks Suite `run` method.
        kwargs: Additional keyword arguments (unused).

    Returns:
        A Deepchecks SuiteResult with the results of the validation.
    """
    # NOTE(review): the `{}` defaults above are shared between calls; this
    # is only safe as long as nothing downstream mutates them - confirm.

    # One dataset selects validation checks, two datasets select drift checks.
    check_enum: Type[DeepchecksValidationCheck] = (
        DeepchecksModelValidationCheck
        if comparison_dataset is None
        else DeepchecksModelDriftCheck
    )

    return self._create_and_run_check_suite(
        check_enum=check_enum,
        reference_dataset=dataset,
        comparison_dataset=comparison_dataset,
        model=model,
        check_list=check_list,
        dataset_kwargs=dataset_kwargs,
        check_kwargs=check_kwargs,
        run_kwargs=run_kwargs,
    )
materializers
special
Deepchecks materializers.
deepchecks_dataset_materializer
Implementation of Deepchecks dataset materializer.
DeepchecksDatasetMaterializer (BaseMaterializer)
Materializer to read data to and from Deepchecks dataset.
Source code in zenml/integrations/deepchecks/materializers/deepchecks_dataset_materializer.py
class DeepchecksDatasetMaterializer(BaseMaterializer):
    """Materializer that reads and writes Deepchecks `Dataset` objects.

    Persistence is delegated to the pandas materializer: only the
    dataframe held by the dataset (`Dataset.data`) is written, and reading
    rebuilds a `Dataset` from the loaded dataframe alone.
    """

    ASSOCIATED_TYPES = (Dataset,)
    ASSOCIATED_ARTIFACT_TYPES = (DataArtifact,)

    def handle_input(self, data_type: Type[Any]) -> Dataset:
        """Load the stored dataframe and wrap it in a Deepchecks Dataset.

        Args:
            data_type: The type of the data to read.

        Returns:
            A Deepchecks Dataset.
        """
        super().handle_input(data_type)

        # Loading is delegated to a pandas materializer on the same artifact.
        dataframe = PandasMaterializer(self.artifact).handle_input(data_type)
        return Dataset(dataframe)

    def handle_return(self, df: Dataset) -> None:
        """Store the pandas dataframe contained in a Dataset object.

        Args:
            df: A deepchecks.Dataset object.
        """
        super().handle_return(df)

        # Writing is delegated to a pandas materializer on the same artifact.
        PandasMaterializer(self.artifact).handle_return(df.data)
handle_input(self, data_type)
Reads pandas dataframes and creates deepchecks.Dataset from it.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
data_type |
Type[Any] |
The type of the data to read. |
required |
Returns:
Type | Description |
---|---|
Dataset |
A Deepchecks Dataset. |
Source code in zenml/integrations/deepchecks/materializers/deepchecks_dataset_materializer.py
def handle_input(self, data_type: Type[Any]) -> Dataset:
    """Load the stored dataframe and wrap it in a Deepchecks Dataset.

    Args:
        data_type: The type of the data to read.

    Returns:
        A Deepchecks Dataset.
    """
    super().handle_input(data_type)

    # Loading is delegated to a pandas materializer on the same artifact.
    dataframe = PandasMaterializer(self.artifact).handle_input(data_type)
    return Dataset(dataframe)
handle_return(self, df)
Serializes pandas dataframe within a Dataset object.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
df |
Dataset |
A deepchecks.Dataset object. |
required |
Source code in zenml/integrations/deepchecks/materializers/deepchecks_dataset_materializer.py
def handle_return(self, df: Dataset) -> None:
    """Store the pandas dataframe contained in a Dataset object.

    Args:
        df: A deepchecks.Dataset object.
    """
    super().handle_return(df)

    # Writing is delegated to a pandas materializer on the same artifact;
    # only the dataframe held by the dataset is passed on.
    PandasMaterializer(self.artifact).handle_return(df.data)
deepchecks_results_materializer
Implementation of Deepchecks suite results materializer.
DeepchecksResultMaterializer (BaseMaterializer)
Materializer to read data to and from CheckResult and SuiteResult objects.
Source code in zenml/integrations/deepchecks/materializers/deepchecks_results_materializer.py
class DeepchecksResultMaterializer(BaseMaterializer):
    """Materializer for Deepchecks CheckResult and SuiteResult objects.

    Results are stored as a JSON string in the `RESULTS_FILENAME` file
    under the artifact URI.
    """

    ASSOCIATED_TYPES = (
        CheckResult,
        SuiteResult,
    )
    ASSOCIATED_ARTIFACT_TYPES = (DataAnalysisArtifact,)

    def handle_input(
        self, data_type: Type[Any]
    ) -> Union[CheckResult, SuiteResult]:
        """Read a Deepchecks check or suite result from its JSON file.

        Args:
            data_type: The type of the data to read.

        Returns:
            A Deepchecks CheckResult or SuiteResult.

        Raises:
            RuntimeError: if the input data type is not supported.
        """
        super().handle_input(data_type)

        # The file is read before the requested type is inspected.
        results_path = os.path.join(self.artifact.uri, RESULTS_FILENAME)
        serialized = io_utils.read_file_contents_as_string(results_path)

        if data_type == SuiteResult:
            return SuiteResult.from_json(serialized)
        if data_type == CheckResult:
            return CheckResult.from_json(serialized)
        raise RuntimeError(f"Unknown data type: {data_type}")

    def handle_return(self, result: Union[CheckResult, SuiteResult]) -> None:
        """Write a CheckResult or SuiteResult to its JSON file.

        Args:
            result: A Deepchecks CheckResult or SuiteResult.
        """
        super().handle_return(result)

        results_path = os.path.join(self.artifact.uri, RESULTS_FILENAME)
        io_utils.write_file_contents_as_string(
            results_path, result.to_json(True)
        )
handle_input(self, data_type)
Reads a Deepchecks check or suite result from a serialized JSON file.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
data_type |
Type[Any] |
The type of the data to read. |
required |
Returns:
Type | Description |
---|---|
Union[deepchecks.core.check_result.CheckResult, deepchecks.core.suite.SuiteResult] |
A Deepchecks CheckResult or SuiteResult. |
Exceptions:
Type | Description |
---|---|
RuntimeError |
if the input data type is not supported. |
Source code in zenml/integrations/deepchecks/materializers/deepchecks_results_materializer.py
def handle_input(
    self, data_type: Type[Any]
) -> Union[CheckResult, SuiteResult]:
    """Read a Deepchecks check or suite result from its JSON file.

    Args:
        data_type: The type of the data to read.

    Returns:
        A Deepchecks CheckResult or SuiteResult.

    Raises:
        RuntimeError: if the input data type is not supported.
    """
    super().handle_input(data_type)

    # The file is read before the requested type is inspected.
    results_path = os.path.join(self.artifact.uri, RESULTS_FILENAME)
    serialized = io_utils.read_file_contents_as_string(results_path)

    if data_type == SuiteResult:
        return SuiteResult.from_json(serialized)
    if data_type == CheckResult:
        return CheckResult.from_json(serialized)
    raise RuntimeError(f"Unknown data type: {data_type}")
handle_return(self, result)
Creates a JSON serialization for a CheckResult or SuiteResult.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
result |
Union[deepchecks.core.check_result.CheckResult, deepchecks.core.suite.SuiteResult] |
A Deepchecks CheckResult or SuiteResult. |
required |
Source code in zenml/integrations/deepchecks/materializers/deepchecks_results_materializer.py
def handle_return(self, result: Union[CheckResult, SuiteResult]) -> None:
    """Write a CheckResult or SuiteResult to its JSON file.

    Args:
        result: A Deepchecks CheckResult or SuiteResult.
    """
    super().handle_return(result)

    # The JSON string goes to RESULTS_FILENAME under the artifact URI.
    results_path = os.path.join(self.artifact.uri, RESULTS_FILENAME)
    io_utils.write_file_contents_as_string(
        results_path, result.to_json(True)
    )
steps
special
Initialization of the Deepchecks Standard Steps.
deepchecks_data_drift
Implementation of the Deepchecks data drift validation step.
DeepchecksDataDriftCheckStep (BaseStep)
Deepchecks data drift validator step.
Source code in zenml/integrations/deepchecks/steps/deepchecks_data_drift.py
class DeepchecksDataDriftCheckStep(BaseStep):
    """Step that runs Deepchecks data drift checks on two datasets."""

    def entrypoint(  # type: ignore[override]
        self,
        reference_dataset: pd.DataFrame,
        target_dataset: pd.DataFrame,
        config: DeepchecksDataDriftCheckStepConfig,
    ) -> SuiteResult:
        """Run the data drift checks through the active data validator.

        Args:
            reference_dataset: Reference dataset for the data drift check.
            target_dataset: Target dataset to be used for the data drift
                check.
            config: the configuration for the step

        Returns:
            A Deepchecks suite result with the validation results.
        """
        validator = cast(
            DeepchecksDataValidator,
            DeepchecksDataValidator.get_active_data_validator(),
        )
        # The validator takes plain string identifiers; the cast only
        # adjusts the static type of the configured check list.
        requested_checks = cast(Optional[Sequence[str]], config.check_list)

        # Supplying a comparison dataset selects the drift checks.
        return validator.data_validation(
            dataset=reference_dataset,
            comparison_dataset=target_dataset,
            check_list=requested_checks,
            dataset_kwargs=config.dataset_kwargs,
            check_kwargs=config.check_kwargs,
            run_kwargs=config.run_kwargs,
        )
CONFIG_CLASS (BaseStepConfig)
pydantic-model
Config class for the Deepchecks data drift validator step.
Attributes:
Name | Type | Description |
---|---|---|
check_list |
Optional[Sequence[zenml.integrations.deepchecks.validation_checks.DeepchecksDataDriftCheck]] |
Optional list of DeepchecksDataDriftCheck identifiers specifying the subset of Deepchecks data drift checks to be performed. If not supplied, the entire set of data drift checks will be performed. |
dataset_kwargs |
Dict[str, Any] |
Additional keyword arguments to be passed to the
Deepchecks `tabular.Dataset` or `vision.VisionData` constructor. |
check_kwargs |
Dict[str, Dict[str, Any]] |
Additional keyword arguments to be passed to the Deepchecks check object constructors. Arguments are grouped for each check and indexed using the full check class name or check enum value as dictionary keys. |
run_kwargs |
Dict[str, Any] |
Additional keyword arguments to be passed to the
Deepchecks Suite `run` method. |
Source code in zenml/integrations/deepchecks/steps/deepchecks_data_drift.py
class DeepchecksDataDriftCheckStepConfig(BaseStepConfig):
    """Config class for the Deepchecks data drift validator step.

    Attributes:
        check_list: Optional list of DeepchecksDataDriftCheck identifiers
            specifying the subset of Deepchecks data drift checks to be
            performed. If not supplied, the entire set of data drift checks
            will be performed. Defaults to `None`.
        dataset_kwargs: Additional keyword arguments to be passed to the
            Deepchecks `tabular.Dataset` or `vision.VisionData` constructor.
            Defaults to an empty dictionary.
        check_kwargs: Additional keyword arguments to be passed to the
            Deepchecks check object constructors. Arguments are grouped for
            each check and indexed using the full check class name or
            check enum value as dictionary keys. Defaults to an empty
            dictionary.
        run_kwargs: Additional keyword arguments to be passed to the
            Deepchecks Suite `run` method. Defaults to an empty dictionary.
    """

    check_list: Optional[Sequence[DeepchecksDataDriftCheck]] = None
    # `default_factory=dict` makes the default a new dict per instance.
    dataset_kwargs: Dict[str, Any] = Field(default_factory=dict)
    check_kwargs: Dict[str, Dict[str, Any]] = Field(default_factory=dict)
    run_kwargs: Dict[str, Any] = Field(default_factory=dict)
entrypoint(self, reference_dataset, target_dataset, config)
Main entrypoint for the Deepchecks data drift validator step.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
reference_dataset |
DataFrame |
Reference dataset for the data drift check. |
required |
target_dataset |
DataFrame |
Target dataset to be used for the data drift check. |
required |
config |
DeepchecksDataDriftCheckStepConfig |
the configuration for the step |
required |
Returns:
Type | Description |
---|---|
SuiteResult |
A Deepchecks suite result with the validation results. |
Source code in zenml/integrations/deepchecks/steps/deepchecks_data_drift.py
def entrypoint(  # type: ignore[override]
    self,
    reference_dataset: pd.DataFrame,
    target_dataset: pd.DataFrame,
    config: DeepchecksDataDriftCheckStepConfig,
) -> SuiteResult:
    """Run the data drift checks through the active data validator.

    Args:
        reference_dataset: Reference dataset for the data drift check.
        target_dataset: Target dataset to be used for the data drift check.
        config: the configuration for the step

    Returns:
        A Deepchecks suite result with the validation results.
    """
    validator = cast(
        DeepchecksDataValidator,
        DeepchecksDataValidator.get_active_data_validator(),
    )
    # The validator takes plain string identifiers; the cast only adjusts
    # the static type of the configured check list.
    requested_checks = cast(Optional[Sequence[str]], config.check_list)

    # Supplying a comparison dataset selects the drift checks.
    return validator.data_validation(
        dataset=reference_dataset,
        comparison_dataset=target_dataset,
        check_list=requested_checks,
        dataset_kwargs=config.dataset_kwargs,
        check_kwargs=config.check_kwargs,
        run_kwargs=config.run_kwargs,
    )
DeepchecksDataDriftCheckStepConfig (BaseStepConfig)
pydantic-model
Config class for the Deepchecks data drift validator step.
Attributes:
Name | Type | Description |
---|---|---|
check_list |
Optional[Sequence[zenml.integrations.deepchecks.validation_checks.DeepchecksDataDriftCheck]] |
Optional list of DeepchecksDataDriftCheck identifiers specifying the subset of Deepchecks data drift checks to be performed. If not supplied, the entire set of data drift checks will be performed. |
dataset_kwargs |
Dict[str, Any] |
Additional keyword arguments to be passed to the
Deepchecks `tabular.Dataset` or `vision.VisionData` constructor. |
check_kwargs |
Dict[str, Dict[str, Any]] |
Additional keyword arguments to be passed to the Deepchecks check object constructors. Arguments are grouped for each check and indexed using the full check class name or check enum value as dictionary keys. |
run_kwargs |
Dict[str, Any] |
Additional keyword arguments to be passed to the
Deepchecks Suite `run` method. |
Source code in zenml/integrations/deepchecks/steps/deepchecks_data_drift.py
class DeepchecksDataDriftCheckStepConfig(BaseStepConfig):
    """Config class for the Deepchecks data drift validator step.

    Attributes:
        check_list: Optional list of DeepchecksDataDriftCheck identifiers
            specifying the subset of Deepchecks data drift checks to be
            performed. If not supplied, the entire set of data drift checks
            will be performed. Defaults to `None`.
        dataset_kwargs: Additional keyword arguments to be passed to the
            Deepchecks `tabular.Dataset` or `vision.VisionData` constructor.
            Defaults to an empty dictionary.
        check_kwargs: Additional keyword arguments to be passed to the
            Deepchecks check object constructors. Arguments are grouped for
            each check and indexed using the full check class name or
            check enum value as dictionary keys. Defaults to an empty
            dictionary.
        run_kwargs: Additional keyword arguments to be passed to the
            Deepchecks Suite `run` method. Defaults to an empty dictionary.
    """

    check_list: Optional[Sequence[DeepchecksDataDriftCheck]] = None
    # `default_factory=dict` makes the default a new dict per instance.
    dataset_kwargs: Dict[str, Any] = Field(default_factory=dict)
    check_kwargs: Dict[str, Dict[str, Any]] = Field(default_factory=dict)
    run_kwargs: Dict[str, Any] = Field(default_factory=dict)
deepchecks_data_drift_check_step(step_name, config)
Shortcut function to create a new instance of the DeepchecksDataDriftCheckStep step.
The returned DeepchecksDataDriftCheckStep can be used in a pipeline to run data drift checks on two input pd.DataFrame and return the results as a Deepchecks SuiteResult object.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
step_name |
str |
The name of the step |
required |
config |
DeepchecksDataDriftCheckStepConfig |
The configuration for the step |
required |
Returns:
Type | Description |
---|---|
BaseStep |
a DeepchecksDataDriftCheckStep step instance |
Source code in zenml/integrations/deepchecks/steps/deepchecks_data_drift.py
def deepchecks_data_drift_check_step(
    step_name: str,
    config: DeepchecksDataDriftCheckStepConfig,
) -> BaseStep:
    """Create a new, named instance of the DeepchecksDataDriftCheckStep step.

    The returned DeepchecksDataDriftCheckStep can be used in a pipeline to
    run data drift checks on two input pd.DataFrame objects and return the
    results as a Deepchecks SuiteResult object.

    Args:
        step_name: The name of the step
        config: The configuration for the step

    Returns:
        a DeepchecksDataDriftCheckStep step instance
    """
    # Clone the step class under the requested name, then instantiate it.
    named_step_class = clone_step(DeepchecksDataDriftCheckStep, step_name)
    return named_step_class(config=config)
deepchecks_data_integrity
Implementation of the Deepchecks data integrity validation step.
DeepchecksDataIntegrityCheckStep (BaseStep)
Deepchecks data integrity validator step.
Source code in zenml/integrations/deepchecks/steps/deepchecks_data_integrity.py
class DeepchecksDataIntegrityCheckStep(BaseStep):
    """Step that runs Deepchecks data integrity checks on one dataset."""

    def entrypoint(  # type: ignore[override]
        self,
        dataset: pd.DataFrame,
        config: DeepchecksDataIntegrityCheckStepConfig,
    ) -> SuiteResult:
        """Run the data integrity checks through the active data validator.

        Args:
            dataset: a Pandas DataFrame to validate
            config: the configuration for the step

        Returns:
            A Deepchecks suite result with the validation results.
        """
        validator = cast(
            DeepchecksDataValidator,
            DeepchecksDataValidator.get_active_data_validator(),
        )
        # The validator takes plain string identifiers; the cast only
        # adjusts the static type of the configured check list.
        requested_checks = cast(Optional[Sequence[str]], config.check_list)

        # No comparison dataset is passed, which selects integrity checks.
        return validator.data_validation(
            dataset=dataset,
            check_list=requested_checks,
            dataset_kwargs=config.dataset_kwargs,
            check_kwargs=config.check_kwargs,
            run_kwargs=config.run_kwargs,
        )
CONFIG_CLASS (BaseStepConfig)
pydantic-model
Config class for the Deepchecks data integrity validator step.
Attributes:
Name | Type | Description |
---|---|---|
check_list |
Optional[Sequence[zenml.integrations.deepchecks.validation_checks.DeepchecksDataIntegrityCheck]] |
Optional list of DeepchecksDataIntegrityCheck identifiers specifying the subset of Deepchecks data integrity checks to be performed. If not supplied, the entire set of data integrity checks will be performed. |
dataset_kwargs |
Dict[str, Any] |
Additional keyword arguments to be passed to the
Deepchecks `tabular.Dataset` or `vision.VisionData` constructor. |
check_kwargs |
Dict[str, Dict[str, Any]] |
Additional keyword arguments to be passed to the Deepchecks check object constructors. Arguments are grouped for each check and indexed using the full check class name or check enum value as dictionary keys. |
run_kwargs |
Dict[str, Any] |
Additional keyword arguments to be passed to the
Deepchecks Suite `run` method. |
Source code in zenml/integrations/deepchecks/steps/deepchecks_data_integrity.py
class DeepchecksDataIntegrityCheckStepConfig(BaseStepConfig):
    """Config class for the Deepchecks data integrity validator step.

    Attributes:
        check_list: Optional list of DeepchecksDataIntegrityCheck identifiers
            specifying the subset of Deepchecks data integrity checks to be
            performed. If not supplied, the entire set of data integrity
            checks will be performed. Defaults to `None`.
        dataset_kwargs: Additional keyword arguments to be passed to the
            Deepchecks `tabular.Dataset` or `vision.VisionData` constructor.
            Defaults to an empty dictionary.
        check_kwargs: Additional keyword arguments to be passed to the
            Deepchecks check object constructors. Arguments are grouped for
            each check and indexed using the full check class name or
            check enum value as dictionary keys. Defaults to an empty
            dictionary.
        run_kwargs: Additional keyword arguments to be passed to the
            Deepchecks Suite `run` method. Defaults to an empty dictionary.
    """

    check_list: Optional[Sequence[DeepchecksDataIntegrityCheck]] = None
    # `default_factory=dict` makes the default a new dict per instance.
    dataset_kwargs: Dict[str, Any] = Field(default_factory=dict)
    check_kwargs: Dict[str, Dict[str, Any]] = Field(default_factory=dict)
    run_kwargs: Dict[str, Any] = Field(default_factory=dict)
entrypoint(self, dataset, config)
Main entrypoint for the Deepchecks data integrity validator step.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
dataset |
DataFrame |
a Pandas DataFrame to validate |
required |
config |
DeepchecksDataIntegrityCheckStepConfig |
the configuration for the step |
required |
Returns:
Type | Description |
---|---|
SuiteResult |
A Deepchecks suite result with the validation results. |
Source code in zenml/integrations/deepchecks/steps/deepchecks_data_integrity.py
def entrypoint(  # type: ignore[override]
    self,
    dataset: pd.DataFrame,
    config: DeepchecksDataIntegrityCheckStepConfig,
) -> SuiteResult:
    """Run the data integrity checks through the active data validator.

    Args:
        dataset: a Pandas DataFrame to validate
        config: the configuration for the step

    Returns:
        A Deepchecks suite result with the validation results.
    """
    validator = cast(
        DeepchecksDataValidator,
        DeepchecksDataValidator.get_active_data_validator(),
    )
    # The validator takes plain string identifiers; the cast only adjusts
    # the static type of the configured check list.
    requested_checks = cast(Optional[Sequence[str]], config.check_list)

    # No comparison dataset is passed, which selects integrity checks.
    return validator.data_validation(
        dataset=dataset,
        check_list=requested_checks,
        dataset_kwargs=config.dataset_kwargs,
        check_kwargs=config.check_kwargs,
        run_kwargs=config.run_kwargs,
    )
DeepchecksDataIntegrityCheckStepConfig (BaseStepConfig)
pydantic-model
Config class for the Deepchecks data integrity validator step.
Attributes:
Name | Type | Description |
---|---|---|
check_list |
Optional[Sequence[zenml.integrations.deepchecks.validation_checks.DeepchecksDataIntegrityCheck]] |
Optional list of DeepchecksDataIntegrityCheck identifiers specifying the subset of Deepchecks data integrity checks to be performed. If not supplied, the entire set of data integrity checks will be performed. |
dataset_kwargs |
Dict[str, Any] |
Additional keyword arguments to be passed to the
Deepchecks `tabular.Dataset` or `vision.VisionData` constructor. |
check_kwargs |
Dict[str, Dict[str, Any]] |
Additional keyword arguments to be passed to the Deepchecks check object constructors. Arguments are grouped for each check and indexed using the full check class name or check enum value as dictionary keys. |
run_kwargs |
Dict[str, Any] |
Additional keyword arguments to be passed to the
Deepchecks Suite `run` method. |
Source code in zenml/integrations/deepchecks/steps/deepchecks_data_integrity.py
class DeepchecksDataIntegrityCheckStepConfig(BaseStepConfig):
    """Config class for the Deepchecks data integrity validator step.

    Attributes:
        check_list: Optional list of DeepchecksDataIntegrityCheck identifiers
            specifying the subset of Deepchecks data integrity checks to be
            performed. If not supplied, the entire set of data integrity
            checks will be performed. Defaults to `None`.
        dataset_kwargs: Additional keyword arguments to be passed to the
            Deepchecks `tabular.Dataset` or `vision.VisionData` constructor.
            Defaults to an empty dictionary.
        check_kwargs: Additional keyword arguments to be passed to the
            Deepchecks check object constructors. Arguments are grouped for
            each check and indexed using the full check class name or
            check enum value as dictionary keys. Defaults to an empty
            dictionary.
        run_kwargs: Additional keyword arguments to be passed to the
            Deepchecks Suite `run` method. Defaults to an empty dictionary.
    """

    check_list: Optional[Sequence[DeepchecksDataIntegrityCheck]] = None
    # `default_factory=dict` makes the default a new dict per instance.
    dataset_kwargs: Dict[str, Any] = Field(default_factory=dict)
    check_kwargs: Dict[str, Dict[str, Any]] = Field(default_factory=dict)
    run_kwargs: Dict[str, Any] = Field(default_factory=dict)
deepchecks_data_integrity_check_step(step_name, config)
Shortcut function to create a new instance of the DeepchecksDataIntegrityCheckStep step.
The returned DeepchecksDataIntegrityCheckStep can be used in a pipeline to run data integrity checks on an input pd.DataFrame and return the results as a Deepchecks SuiteResult object.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
step_name |
str |
The name of the step |
required |
config |
DeepchecksDataIntegrityCheckStepConfig |
The configuration for the step |
required |
Returns:
Type | Description |
---|---|
BaseStep |
a DeepchecksDataIntegrityCheckStep step instance |
Source code in zenml/integrations/deepchecks/steps/deepchecks_data_integrity.py
def deepchecks_data_integrity_check_step(
    step_name: str,
    config: DeepchecksDataIntegrityCheckStepConfig,
) -> BaseStep:
    """Create a new, named instance of the DeepchecksDataIntegrityCheckStep step.

    The returned DeepchecksDataIntegrityCheckStep can be used in a pipeline
    to run data integrity checks on an input pd.DataFrame and return the
    results as a Deepchecks SuiteResult object.

    Args:
        step_name: The name of the step
        config: The configuration for the step

    Returns:
        a DeepchecksDataIntegrityCheckStep step instance
    """
    # Clone the step class under the requested name, then instantiate it.
    named_step_class = clone_step(DeepchecksDataIntegrityCheckStep, step_name)
    return named_step_class(config=config)
deepchecks_model_drift
Implementation of the Deepchecks model drift validation step.
DeepchecksModelDriftCheckStep (BaseStep)
Deepchecks model drift step.
Source code in zenml/integrations/deepchecks/steps/deepchecks_model_drift.py
class DeepchecksModelDriftCheckStep(BaseStep):
    """Step that runs Deepchecks model drift checks on two datasets."""

    def entrypoint(  # type: ignore[override]
        self,
        reference_dataset: pd.DataFrame,
        target_dataset: pd.DataFrame,
        model: ClassifierMixin,
        config: DeepchecksModelDriftCheckStepConfig,
    ) -> SuiteResult:
        """Run the model drift checks through the active data validator.

        Args:
            reference_dataset: Reference dataset for the model drift check.
            target_dataset: Target dataset to be used for the model drift
                check.
            model: a scikit-learn model to validate
            config: the configuration for the step

        Returns:
            A Deepchecks suite result with the validation results.
        """
        validator = cast(
            DeepchecksDataValidator,
            DeepchecksDataValidator.get_active_data_validator(),
        )
        # The validator takes plain string identifiers; the cast only
        # adjusts the static type of the configured check list.
        requested_checks = cast(Optional[Sequence[str]], config.check_list)

        # Supplying a comparison dataset selects the model drift checks.
        return validator.model_validation(
            dataset=reference_dataset,
            comparison_dataset=target_dataset,
            model=model,
            check_list=requested_checks,
            dataset_kwargs=config.dataset_kwargs,
            check_kwargs=config.check_kwargs,
            run_kwargs=config.run_kwargs,
        )
CONFIG_CLASS (BaseStepConfig)
pydantic-model
Config class for the Deepchecks model drift validator step.
Attributes:
Name | Type | Description |
---|---|---|
check_list |
Optional[Sequence[zenml.integrations.deepchecks.validation_checks.DeepchecksModelDriftCheck]] |
Optional list of DeepchecksModelDriftCheck identifiers specifying the subset of Deepchecks model drift checks to be performed. If not supplied, the entire set of model drift checks will be performed. |
dataset_kwargs |
Dict[str, Any] |
Additional keyword arguments to be passed to the
Deepchecks `tabular.Dataset` or `vision.VisionData` constructor. |
check_kwargs |
Dict[str, Dict[str, Any]] |
Additional keyword arguments to be passed to the Deepchecks check object constructors. Arguments are grouped for each check and indexed using the full check class name or check enum value as dictionary keys. |
run_kwargs |
Dict[str, Any] |
Additional keyword arguments to be passed to the
Deepchecks Suite `run` method. |
Source code in zenml/integrations/deepchecks/steps/deepchecks_model_drift.py
class DeepchecksModelDriftCheckStepConfig(BaseStepConfig):
"""Config class for the Deepchecks model drift validator step.
Attributes:
check_list: Optional list of DeepchecksModelDriftCheck identifiers
specifying the subset of Deepchecks model drift checks to be
performed. If not supplied, the entire set of model drift checks
will be performed.
dataset_kwargs: Additional keyword arguments to be passed to the
Deepchecks `tabular.Dataset` or `vision.VisionData` constructor.
check_kwargs: Additional keyword arguments to be passed to the
Deepchecks check object constructors. Arguments are grouped for
each check and indexed using the full check class name or
check enum value as dictionary keys.
run_kwargs: Additional keyword arguments to be passed to the
Deepchecks Suite `run` method.
"""
# No explicit check selection by default (None).
check_list: Optional[Sequence[DeepchecksModelDriftCheck]] = None
# The three keyword-argument mappings below each default to an empty dict.
dataset_kwargs: Dict[str, Any] = Field(default_factory=dict)
check_kwargs: Dict[str, Dict[str, Any]] = Field(default_factory=dict)
run_kwargs: Dict[str, Any] = Field(default_factory=dict)
entrypoint(self, reference_dataset, target_dataset, model, config)
Main entrypoint for the Deepchecks model drift step.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
reference_dataset |
DataFrame |
Reference dataset for the model drift check. |
required |
target_dataset |
DataFrame |
Target dataset to be used for the model drift check. |
required |
model |
ClassifierMixin |
a scikit-learn model to validate |
required |
config |
DeepchecksModelDriftCheckStepConfig |
the configuration for the step |
required |
Returns:
Type | Description |
---|---|
SuiteResult |
A Deepchecks suite result with the validation results. |
Source code in zenml/integrations/deepchecks/steps/deepchecks_model_drift.py
def entrypoint(  # type: ignore[override]
    self,
    reference_dataset: pd.DataFrame,
    target_dataset: pd.DataFrame,
    model: ClassifierMixin,
    config: DeepchecksModelDriftCheckStepConfig,
) -> SuiteResult:
    """Run the Deepchecks model drift checks.

    Args:
        reference_dataset: Reference dataset for the model drift check.
        target_dataset: Target dataset to be used for the model drift check.
        model: a scikit-learn model to validate
        config: the configuration for the step

    Returns:
        A Deepchecks suite result with the validation results.
    """
    active_validator = DeepchecksDataValidator.get_active_data_validator()
    validator = cast(DeepchecksDataValidator, active_validator)
    # `cast` only affects static typing; the value of `config.check_list`
    # is passed through unchanged.
    selected_checks = cast(Optional[Sequence[str]], config.check_list)
    return validator.model_validation(
        dataset=reference_dataset,
        comparison_dataset=target_dataset,
        model=model,
        check_list=selected_checks,
        dataset_kwargs=config.dataset_kwargs,
        check_kwargs=config.check_kwargs,
        run_kwargs=config.run_kwargs,
    )
DeepchecksModelDriftCheckStepConfig (BaseStepConfig)
pydantic-model
Config class for the Deepchecks model drift validator step.
Attributes:
Name | Type | Description |
---|---|---|
check_list |
Optional[Sequence[zenml.integrations.deepchecks.validation_checks.DeepchecksModelDriftCheck]] |
Optional list of DeepchecksModelDriftCheck identifiers specifying the subset of Deepchecks model drift checks to be performed. If not supplied, the entire set of model drift checks will be performed. |
dataset_kwargs |
Dict[str, Any] |
Additional keyword arguments to be passed to the
Deepchecks tabular.Dataset or vision.VisionData constructor. |
check_kwargs |
Dict[str, Dict[str, Any]] |
Additional keyword arguments to be passed to the Deepchecks check object constructors. Arguments are grouped for each check and indexed using the full check class name or check enum value as dictionary keys. |
run_kwargs |
Dict[str, Any] |
Additional keyword arguments to be passed to the
Deepchecks Suite run method. |
Source code in zenml/integrations/deepchecks/steps/deepchecks_model_drift.py
class DeepchecksModelDriftCheckStepConfig(BaseStepConfig):
"""Config class for the Deepchecks model drift validator step.
Attributes:
check_list: Optional list of DeepchecksModelDriftCheck identifiers
specifying the subset of Deepchecks model drift checks to be
performed. If not supplied, the entire set of model drift checks
will be performed.
dataset_kwargs: Additional keyword arguments to be passed to the
Deepchecks `tabular.Dataset` or `vision.VisionData` constructor.
check_kwargs: Additional keyword arguments to be passed to the
Deepchecks check object constructors. Arguments are grouped for
each check and indexed using the full check class name or
check enum value as dictionary keys.
run_kwargs: Additional keyword arguments to be passed to the
Deepchecks Suite `run` method.
"""
# No explicit check selection by default (None).
check_list: Optional[Sequence[DeepchecksModelDriftCheck]] = None
# The three keyword-argument mappings below each default to an empty dict.
dataset_kwargs: Dict[str, Any] = Field(default_factory=dict)
check_kwargs: Dict[str, Dict[str, Any]] = Field(default_factory=dict)
run_kwargs: Dict[str, Any] = Field(default_factory=dict)
deepchecks_model_drift_check_step(step_name, config)
Shortcut function to create a new instance of the DeepchecksModelDriftCheckStep step.
The returned DeepchecksModelDriftCheckStep can be used in a pipeline to run model drift checks on two input pd.DataFrame datasets and an input scikit-learn ClassifierMixin model and return the results as a Deepchecks SuiteResult object.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
step_name |
str |
The name of the step |
required |
config |
DeepchecksModelDriftCheckStepConfig |
The configuration for the step |
required |
Returns:
Type | Description |
---|---|
BaseStep |
a DeepchecksModelDriftCheckStep step instance |
Source code in zenml/integrations/deepchecks/steps/deepchecks_model_drift.py
def deepchecks_model_drift_check_step(
    step_name: str,
    config: DeepchecksModelDriftCheckStepConfig,
) -> BaseStep:
    """Shortcut function to create a new instance of the DeepchecksModelDriftCheckStep step.

    The returned DeepchecksModelDriftCheckStep can be used in a pipeline to
    run model drift checks on two input pd.DataFrame datasets and an input
    scikit-learn ClassifierMixin model and return the results as a Deepchecks
    SuiteResult object.

    Args:
        step_name: The name of the step
        config: The configuration for the step

    Returns:
        a DeepchecksModelDriftCheckStep step instance
    """
    # `clone_step` is given the step class and the requested name; whatever
    # it returns is then called with the config.
    named_step_class = clone_step(DeepchecksModelDriftCheckStep, step_name)
    return named_step_class(config=config)
deepchecks_model_validation
Implementation of the Deepchecks model validation step.
DeepchecksModelValidationCheckStep (BaseStep)
Deepchecks model validation step.
Source code in zenml/integrations/deepchecks/steps/deepchecks_model_validation.py
class DeepchecksModelValidationCheckStep(BaseStep):
    """Deepchecks model validation step.

    Hands a single dataset and a model over to the `model_validation` method
    of the active Deepchecks data validator.
    """

    def entrypoint(  # type: ignore[override]
        self,
        dataset: pd.DataFrame,
        model: ClassifierMixin,
        config: DeepchecksModelValidationCheckStepConfig,
    ) -> SuiteResult:
        """Run the Deepchecks model validation checks.

        Args:
            dataset: a Pandas DataFrame to use for the validation
            model: a scikit-learn model to validate
            config: the configuration for the step

        Returns:
            A Deepchecks suite result with the validation results.
        """
        active_validator = DeepchecksDataValidator.get_active_data_validator()
        validator = cast(DeepchecksDataValidator, active_validator)
        # `cast` only affects static typing; the value of `config.check_list`
        # is passed through unchanged.
        selected_checks = cast(Optional[Sequence[str]], config.check_list)
        return validator.model_validation(
            dataset=dataset,
            model=model,
            check_list=selected_checks,
            dataset_kwargs=config.dataset_kwargs,
            check_kwargs=config.check_kwargs,
            run_kwargs=config.run_kwargs,
        )
CONFIG_CLASS (BaseStepConfig)
pydantic-model
Config class for the Deepchecks model validation validator step.
Attributes:
Name | Type | Description |
---|---|---|
check_list |
Optional[Sequence[zenml.integrations.deepchecks.validation_checks.DeepchecksModelValidationCheck]] |
Optional list of DeepchecksModelValidationCheck identifiers specifying the subset of Deepchecks model validation checks to be performed. If not supplied, the entire set of model validation checks will be performed. |
dataset_kwargs |
Dict[str, Any] |
Additional keyword arguments to be passed to the
Deepchecks tabular.Dataset or vision.VisionData constructor. |
check_kwargs |
Dict[str, Dict[str, Any]] |
Additional keyword arguments to be passed to the Deepchecks check object constructors. Arguments are grouped for each check and indexed using the full check class name or check enum value as dictionary keys. |
run_kwargs |
Dict[str, Any] |
Additional keyword arguments to be passed to the
Deepchecks Suite run method. |
Source code in zenml/integrations/deepchecks/steps/deepchecks_model_validation.py
class DeepchecksModelValidationCheckStepConfig(BaseStepConfig):
"""Config class for the Deepchecks model validation validator step.
Attributes:
check_list: Optional list of DeepchecksModelValidationCheck identifiers
specifying the subset of Deepchecks model validation checks to be
performed. If not supplied, the entire set of model validation checks
will be performed.
dataset_kwargs: Additional keyword arguments to be passed to the
Deepchecks `tabular.Dataset` or `vision.VisionData` constructor.
check_kwargs: Additional keyword arguments to be passed to the
Deepchecks check object constructors. Arguments are grouped for
each check and indexed using the full check class name or
check enum value as dictionary keys.
run_kwargs: Additional keyword arguments to be passed to the
Deepchecks Suite `run` method.
"""
# No explicit check selection by default (None).
check_list: Optional[Sequence[DeepchecksModelValidationCheck]] = None
# The three keyword-argument mappings below each default to an empty dict.
dataset_kwargs: Dict[str, Any] = Field(default_factory=dict)
check_kwargs: Dict[str, Dict[str, Any]] = Field(default_factory=dict)
run_kwargs: Dict[str, Any] = Field(default_factory=dict)
entrypoint(self, dataset, model, config)
Main entrypoint for the Deepchecks model validation step.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
dataset |
DataFrame |
a Pandas DataFrame to use for the validation |
required |
model |
ClassifierMixin |
a scikit-learn model to validate |
required |
config |
DeepchecksModelValidationCheckStepConfig |
the configuration for the step |
required |
Returns:
Type | Description |
---|---|
SuiteResult |
A Deepchecks suite result with the validation results. |
Source code in zenml/integrations/deepchecks/steps/deepchecks_model_validation.py
def entrypoint(  # type: ignore[override]
    self,
    dataset: pd.DataFrame,
    model: ClassifierMixin,
    config: DeepchecksModelValidationCheckStepConfig,
) -> SuiteResult:
    """Run the Deepchecks model validation checks.

    Args:
        dataset: a Pandas DataFrame to use for the validation
        model: a scikit-learn model to validate
        config: the configuration for the step

    Returns:
        A Deepchecks suite result with the validation results.
    """
    active_validator = DeepchecksDataValidator.get_active_data_validator()
    validator = cast(DeepchecksDataValidator, active_validator)
    # `cast` only affects static typing; the value of `config.check_list`
    # is passed through unchanged.
    selected_checks = cast(Optional[Sequence[str]], config.check_list)
    return validator.model_validation(
        dataset=dataset,
        model=model,
        check_list=selected_checks,
        dataset_kwargs=config.dataset_kwargs,
        check_kwargs=config.check_kwargs,
        run_kwargs=config.run_kwargs,
    )
DeepchecksModelValidationCheckStepConfig (BaseStepConfig)
pydantic-model
Config class for the Deepchecks model validation validator step.
Attributes:
Name | Type | Description |
---|---|---|
check_list |
Optional[Sequence[zenml.integrations.deepchecks.validation_checks.DeepchecksModelValidationCheck]] |
Optional list of DeepchecksModelValidationCheck identifiers specifying the subset of Deepchecks model validation checks to be performed. If not supplied, the entire set of model validation checks will be performed. |
dataset_kwargs |
Dict[str, Any] |
Additional keyword arguments to be passed to the
Deepchecks tabular.Dataset or vision.VisionData constructor. |
check_kwargs |
Dict[str, Dict[str, Any]] |
Additional keyword arguments to be passed to the Deepchecks check object constructors. Arguments are grouped for each check and indexed using the full check class name or check enum value as dictionary keys. |
run_kwargs |
Dict[str, Any] |
Additional keyword arguments to be passed to the
Deepchecks Suite run method. |
Source code in zenml/integrations/deepchecks/steps/deepchecks_model_validation.py
class DeepchecksModelValidationCheckStepConfig(BaseStepConfig):
"""Config class for the Deepchecks model validation validator step.
Attributes:
check_list: Optional list of DeepchecksModelValidationCheck identifiers
specifying the subset of Deepchecks model validation checks to be
performed. If not supplied, the entire set of model validation checks
will be performed.
dataset_kwargs: Additional keyword arguments to be passed to the
Deepchecks `tabular.Dataset` or `vision.VisionData` constructor.
check_kwargs: Additional keyword arguments to be passed to the
Deepchecks check object constructors. Arguments are grouped for
each check and indexed using the full check class name or
check enum value as dictionary keys.
run_kwargs: Additional keyword arguments to be passed to the
Deepchecks Suite `run` method.
"""
# No explicit check selection by default (None).
check_list: Optional[Sequence[DeepchecksModelValidationCheck]] = None
# The three keyword-argument mappings below each default to an empty dict.
dataset_kwargs: Dict[str, Any] = Field(default_factory=dict)
check_kwargs: Dict[str, Dict[str, Any]] = Field(default_factory=dict)
run_kwargs: Dict[str, Any] = Field(default_factory=dict)
deepchecks_model_validation_check_step(step_name, config)
Shortcut function to create a new instance of the DeepchecksModelValidationCheckStep step.
The returned DeepchecksModelValidationCheckStep can be used in a pipeline to run model validation checks on an input pd.DataFrame dataset and an input scikit-learn ClassifierMixin model and return the results as a Deepchecks SuiteResult object.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
step_name |
str |
The name of the step |
required |
config |
DeepchecksModelValidationCheckStepConfig |
The configuration for the step |
required |
Returns:
Type | Description |
---|---|
BaseStep |
a DeepchecksModelValidationCheckStep step instance |
Source code in zenml/integrations/deepchecks/steps/deepchecks_model_validation.py
def deepchecks_model_validation_check_step(
    step_name: str,
    config: DeepchecksModelValidationCheckStepConfig,
) -> BaseStep:
    """Shortcut function to create a new instance of the DeepchecksModelValidationCheckStep step.

    The returned DeepchecksModelValidationCheckStep can be used in a pipeline to
    run model validation checks on an input pd.DataFrame dataset and an input
    scikit-learn ClassifierMixin model and return the results as a Deepchecks
    SuiteResult object.

    Args:
        step_name: The name of the step
        config: The configuration for the step

    Returns:
        a DeepchecksModelValidationCheckStep step instance
    """
    # `clone_step` is given the step class and the requested name; whatever
    # it returns is then called with the config.
    named_step_class = clone_step(DeepchecksModelValidationCheckStep, step_name)
    return named_step_class(config=config)
validation_checks
Definition of the Deepchecks validation check types.
DeepchecksDataDriftCheck (DeepchecksValidationCheck)
Categories of Deepchecks data drift checks.
This list reflects the set of train-test validation checks provided by Deepchecks:
All these checks inherit from deepchecks.tabular.TrainTestCheck
or
deepchecks.vision.TrainTestCheck
and require two datasets as input.
Source code in zenml/integrations/deepchecks/validation_checks.py
class DeepchecksDataDriftCheck(DeepchecksValidationCheck):
"""Categories of Deepchecks data drift checks.
This list reflects the set of train-test validation checks provided by
Deepchecks:
* [for tabular data](https://docs.deepchecks.com/stable/checks_gallery/tabular.html#train-test-validation)
* [for computer vision](https://docs.deepchecks.com/stable/checks_gallery/vision.html#train-test-validation)
All these checks inherit from `deepchecks.tabular.TrainTestCheck` or
`deepchecks.vision.TrainTestCheck` and require two datasets as input.
Each member value is whatever `resolve_class` returns for the referenced
Deepchecks check class.
"""
# Members resolved from `tabular_checks`.
TABULAR_CATEGORY_MISMATCH_TRAIN_TEST = resolve_class(
tabular_checks.CategoryMismatchTrainTest
)
TABULAR_DATASET_SIZE_COMPARISON = resolve_class(
tabular_checks.DatasetsSizeComparison
)
TABULAR_DATE_TRAIN_TEST_LEAKAGE_DUPLICATES = resolve_class(
tabular_checks.DateTrainTestLeakageDuplicates
)
TABULAR_DATE_TRAIN_TEST_LEAKAGE_OVERLAP = resolve_class(
tabular_checks.DateTrainTestLeakageOverlap
)
TABULAR_DOMINANT_FREQUENCY_CHANGE = resolve_class(
tabular_checks.DominantFrequencyChange
)
TABULAR_FEATURE_LABEL_CORRELATION_CHANGE = resolve_class(
tabular_checks.FeatureLabelCorrelationChange
)
TABULAR_INDEX_LEAKAGE = resolve_class(tabular_checks.IndexTrainTestLeakage)
TABULAR_NEW_LABEL_TRAIN_TEST = resolve_class(
tabular_checks.NewLabelTrainTest
)
TABULAR_STRING_MISMATCH_COMPARISON = resolve_class(
tabular_checks.StringMismatchComparison
)
TABULAR_TRAIN_TEST_FEATURE_DRIFT = resolve_class(
tabular_checks.TrainTestFeatureDrift
)
TABULAR_TRAIN_TEST_LABEL_DRIFT = resolve_class(
tabular_checks.TrainTestLabelDrift
)
TABULAR_TRAIN_TEST_SAMPLES_MIX = resolve_class(
tabular_checks.TrainTestSamplesMix
)
TABULAR_WHOLE_DATASET_DRIFT = resolve_class(
tabular_checks.WholeDatasetDrift
)
# Members resolved from `vision_checks`.
VISION_FEATURE_LABEL_CORRELATION_CHANGE = resolve_class(
vision_checks.FeatureLabelCorrelationChange
)
VISION_HEATMAP_COMPARISON = resolve_class(vision_checks.HeatmapComparison)
VISION_IMAGE_DATASET_DRIFT = resolve_class(vision_checks.ImageDatasetDrift)
VISION_IMAGE_PROPERTY_DRIFT = resolve_class(
vision_checks.ImagePropertyDrift
)
VISION_NEW_LABELS = resolve_class(vision_checks.NewLabels)
VISION_SIMILAR_IMAGE_LEAKAGE = resolve_class(
vision_checks.SimilarImageLeakage
)
VISION_TRAIN_TEST_LABEL_DRIFT = resolve_class(
vision_checks.TrainTestLabelDrift
)
DeepchecksDataIntegrityCheck (DeepchecksValidationCheck)
Categories of Deepchecks data integrity checks.
This list reflects the set of data integrity checks provided by Deepchecks:
All these checks inherit from deepchecks.tabular.SingleDatasetCheck
or
deepchecks.vision.SingleDatasetCheck
and require a single dataset as input.
Source code in zenml/integrations/deepchecks/validation_checks.py
class DeepchecksDataIntegrityCheck(DeepchecksValidationCheck):
"""Categories of Deepchecks data integrity checks.
This list reflects the set of data integrity checks provided by Deepchecks:
* [for tabular data](https://docs.deepchecks.com/en/stable/checks_gallery/tabular.html#data-integrity)
* [for computer vision](https://docs.deepchecks.com/en/stable/checks_gallery/vision.html#data-integrity)
All these checks inherit from `deepchecks.tabular.SingleDatasetCheck` or
`deepchecks.vision.SingleDatasetCheck` and require a single dataset as input.
Each member value is whatever `resolve_class` returns for the referenced
Deepchecks check class.
"""
# Members resolved from `tabular_checks`.
TABULAR_COLUMNS_INFO = resolve_class(tabular_checks.ColumnsInfo)
TABULAR_CONFLICTING_LABELS = resolve_class(tabular_checks.ConflictingLabels)
TABULAR_DATA_DUPLICATES = resolve_class(tabular_checks.DataDuplicates)
# NOTE(review): unlike its siblings, `FeatureFeatureCorrelation` is
# referenced directly rather than through `tabular_checks` - presumably it
# is imported separately at the top of the file; confirm.
TABULAR_FEATURE_FEATURE_CORRELATION = resolve_class(
FeatureFeatureCorrelation
)
TABULAR_FEATURE_LABEL_CORRELATION = resolve_class(
tabular_checks.FeatureLabelCorrelation
)
TABULAR_IDENTIFIER_LEAKAGE = resolve_class(tabular_checks.IdentifierLeakage)
TABULAR_IS_SINGLE_VALUE = resolve_class(tabular_checks.IsSingleValue)
TABULAR_MIXED_DATA_TYPES = resolve_class(tabular_checks.MixedDataTypes)
TABULAR_MIXED_NULLS = resolve_class(tabular_checks.MixedNulls)
TABULAR_OUTLIER_SAMPLE_DETECTION = resolve_class(
tabular_checks.OutlierSampleDetection
)
TABULAR_SPECIAL_CHARS = resolve_class(tabular_checks.SpecialCharacters)
TABULAR_STRING_LENGTH_OUT_OF_BOUNDS = resolve_class(
tabular_checks.StringLengthOutOfBounds
)
TABULAR_STRING_MISMATCH = resolve_class(tabular_checks.StringMismatch)
# Members resolved from `vision_checks`.
VISION_IMAGE_PROPERTY_OUTLIERS = resolve_class(
vision_checks.ImagePropertyOutliers
)
VISION_LABEL_PROPERTY_OUTLIERS = resolve_class(
vision_checks.LabelPropertyOutliers
)
DeepchecksModelDriftCheck (DeepchecksValidationCheck)
Categories of Deepchecks model drift checks.
This list includes a subset of the model evaluation checks provided by Deepchecks that require two datasets and a mandatory model as input:
All these checks inherit from deepchecks.tabular.TrainTestCheck
or
deepchecks.vision.TrainTestCheck
and require two datasets and a mandatory
model as input.
Source code in zenml/integrations/deepchecks/validation_checks.py
class DeepchecksModelDriftCheck(DeepchecksValidationCheck):
"""Categories of Deepchecks model drift checks.
This list includes a subset of the model evaluation checks provided by
Deepchecks that require two datasets and a mandatory model as input:
* [for tabular data](https://docs.deepchecks.com/en/stable/checks_gallery/tabular.html#model-evaluation)
* [for computer vision](https://docs.deepchecks.com/stable/checks_gallery/vision.html#model-evaluation)
All these checks inherit from `deepchecks.tabular.TrainTestCheck` or
`deepchecks.vision.TrainTestCheck` and require two datasets and a mandatory
model as input.
Each member value is whatever `resolve_class` returns for the referenced
Deepchecks check class.
"""
# Members resolved from `tabular_checks`.
TABULAR_BOOSTING_OVERFIT = resolve_class(tabular_checks.BoostingOverfit)
TABULAR_MODEL_ERROR_ANALYSIS = resolve_class(
tabular_checks.ModelErrorAnalysis
)
TABULAR_PERFORMANCE_REPORT = resolve_class(tabular_checks.PerformanceReport)
TABULAR_SIMPLE_MODEL_COMPARISON = resolve_class(
tabular_checks.SimpleModelComparison
)
TABULAR_TRAIN_TEST_PREDICTION_DRIFT = resolve_class(
tabular_checks.TrainTestPredictionDrift
)
TABULAR_UNUSED_FEATURES = resolve_class(tabular_checks.UnusedFeatures)
# Members resolved from `vision_checks`.
VISION_CLASS_PERFORMANCE = resolve_class(vision_checks.ClassPerformance)
VISION_MODEL_ERROR_ANALYSIS = resolve_class(
vision_checks.ModelErrorAnalysis
)
VISION_SIMPLE_MODEL_COMPARISON = resolve_class(
vision_checks.SimpleModelComparison
)
VISION_TRAIN_TEST_PREDICTION_DRIFT = resolve_class(
vision_checks.TrainTestPredictionDrift
)
DeepchecksModelValidationCheck (DeepchecksValidationCheck)
Categories of Deepchecks model validation checks.
This list includes a subset of the model evaluation checks provided by Deepchecks that require a single dataset and a mandatory model as input:
All these checks inherit from deepchecks.tabular.SingleDatasetCheck
or
deepchecks.vision.SingleDatasetCheck and require a dataset and a mandatory
model as input.
Source code in zenml/integrations/deepchecks/validation_checks.py
class DeepchecksModelValidationCheck(DeepchecksValidationCheck):
"""Categories of Deepchecks model validation checks.
This list includes a subset of the model evaluation checks provided by
Deepchecks that require a single dataset and a mandatory model as input:
* [for tabular data](https://docs.deepchecks.com/en/stable/checks_gallery/tabular.html#model-evaluation)
* [for computer vision](https://docs.deepchecks.com/stable/checks_gallery/vision.html#model-evaluation)
All these checks inherit from `deepchecks.tabular.SingleDatasetCheck` or
`deepchecks.vision.SingleDatasetCheck` and require a dataset and a mandatory
model as input.
Each member value is whatever `resolve_class` returns for the referenced
Deepchecks check class.
"""
# Members resolved from `tabular_checks`.
TABULAR_CALIBRATION_SCORE = resolve_class(tabular_checks.CalibrationScore)
TABULAR_CONFUSION_MATRIX_REPORT = resolve_class(
tabular_checks.ConfusionMatrixReport
)
TABULAR_MODEL_INFERENCE_TIME = resolve_class(
tabular_checks.ModelInferenceTime
)
TABULAR_REGRESSION_ERROR_DISTRIBUTION = resolve_class(
tabular_checks.RegressionErrorDistribution
)
TABULAR_REGRESSION_SYSTEMATIC_ERROR = resolve_class(
tabular_checks.RegressionSystematicError
)
TABULAR_ROC_REPORT = resolve_class(tabular_checks.RocReport)
TABULAR_SEGMENT_PERFORMANCE = resolve_class(
tabular_checks.SegmentPerformance
)
# Members resolved from `vision_checks`.
VISION_CONFUSION_MATRIX_REPORT = resolve_class(
vision_checks.ConfusionMatrixReport
)
VISION_IMAGE_SEGMENT_PERFORMANCE = resolve_class(
vision_checks.ImageSegmentPerformance
)
VISION_MEAN_AVERAGE_PRECISION_REPORT = resolve_class(
vision_checks.MeanAveragePrecisionReport
)
VISION_MEAN_AVERAGE_RECALL_REPORT = resolve_class(
vision_checks.MeanAverageRecallReport
)
VISION_ROBUSTNESS_REPORT = resolve_class(vision_checks.RobustnessReport)
VISION_SINGLE_DATASET_SCALAR_PERFORMANCE = resolve_class(
vision_checks.SingleDatasetScalarPerformance
)
DeepchecksValidationCheck (StrEnum)
Base class for all Deepchecks categories of validation checks.
This base class defines some conventions used for all enum values used to identify the various validation checks that can be performed with Deepchecks:
- enum values represent fully formed class paths pointing to Deepchecks BaseCheck subclasses
- all tabular data checks are located under the deepchecks.tabular.checks module sub-tree
- all computer vision data checks are located under the deepchecks.vision.checks module sub-tree
Source code in zenml/integrations/deepchecks/validation_checks.py
class DeepchecksValidationCheck(StrEnum):
    """Base class for all Deepchecks categories of validation checks.

    This base class defines some conventions used for all enum values used to
    identify the various validation checks that can be performed with
    Deepchecks:

      * enum values represent fully formed class paths pointing to Deepchecks
      BaseCheck subclasses
      * all tabular data checks are located under the
      `deepchecks.tabular.checks` module sub-tree
      * all computer vision data checks are located under the
      `deepchecks.vision.checks` module sub-tree
    """

    @classmethod
    def validate_check_name(cls, check_name: str) -> None:
        """Validate a Deepchecks check identifier.

        Only the prefix of the identifier is checked here; whether it can
        actually be resolved to a class is verified by `get_check_class`.

        Args:
            check_name: Identifies a builtin Deepchecks check. The identifier
                must be formatted as `deepchecks.{tabular|vision}.checks.<...>.<class-name>`.

        Raises:
            ValueError: If the check identifier does not follow the convention
                used by ZenML to identify Deepchecks builtin checks.
        """
        if not re.match(
            r"^deepchecks\.(tabular|vision)\.checks\.",
            check_name,
        ):
            raise ValueError(
                f"The supplied Deepcheck check identifier does not follow the "
                f"convention used by ZenML: `{check_name}`. The identifier "
                f"must be formatted as `deepchecks.<tabular|vision>.checks...` "
                f"and must be resolvable to a valid Deepchecks BaseCheck "
                f"subclass."
            )

    @classmethod
    def is_tabular_check(cls, check_name: str) -> bool:
        """Check if a validation check is applicable to tabular data.

        Args:
            check_name: Identifies a builtin Deepchecks check.

        Returns:
            True if the check is applicable to tabular data, otherwise False.

        Raises:
            ValueError: If the check identifier does not follow the convention
                enforced by `validate_check_name`.
        """
        cls.validate_check_name(check_name)
        return check_name.startswith("deepchecks.tabular.")

    @classmethod
    def is_vision_check(cls, check_name: str) -> bool:
        """Check if a validation check is applicable to computer vision data.

        Args:
            check_name: Identifies a builtin Deepchecks check.

        Returns:
            True if the check is applicable to computer vision data, otherwise
            False.

        Raises:
            ValueError: If the check identifier does not follow the convention
                enforced by `validate_check_name`.
        """
        cls.validate_check_name(check_name)
        return check_name.startswith("deepchecks.vision.")

    @classmethod
    def get_check_class(cls, check_name: str) -> Type[BaseCheck]:
        """Get the Deepchecks check class associated with an enum value or a custom check name.

        Args:
            check_name: Identifies a builtin Deepchecks check. The identifier
                must be formatted as `deepchecks.{tabular|vision}.checks.<...>.<class-name>`
                and must be resolvable to a valid Deepchecks BaseCheck class.

        Returns:
            The Deepchecks check class associated with this enum value.

        Raises:
            ValueError: If the check name could not be converted to a valid
                Deepchecks check class. This can happen for example if the enum
                values fall out of sync with the Deepchecks code base or if a
                custom check name is supplied that cannot be resolved to a valid
                Deepchecks BaseCheck class.
        """
        cls.validate_check_name(check_name)

        try:
            check_class = import_class_by_path(check_name)
        except (ImportError, AttributeError) as e:
            # Presumably `import_class_by_path` imports the module part of the
            # path and then looks the class up on it (TODO confirm): a missing
            # module then surfaces as ImportError and a missing attribute as
            # AttributeError. Both mean "cannot be resolved", which this
            # method documents as ValueError.
            raise ValueError(
                f"Could not map the `{check_name}` check identifier to a valid "
                f"Deepchecks check class."
            ) from e

        # `issubclass` raises TypeError when its first argument is not a
        # class, so make sure a class was resolved before calling it.
        if not (
            isinstance(check_class, type) and issubclass(check_class, BaseCheck)
        ):
            raise ValueError(
                f"The `{check_name}` check identifier is mapped to an invalid "
                f"data type. Expected a {str(BaseCheck)} subclass, but instead "
                f"got: {str(check_class)}."
            )

        # Identifiers outside of this enum are accepted, but flagged.
        if check_name not in cls.values():
            logger.warning(
                f"You are using a custom Deepchecks check identifier that is "
                f"not listed in the `{str(cls)}` enum type. This could lead "
                f"to unexpected behavior."
            )

        return check_class

    @property
    def check_class(self) -> Type[BaseCheck]:
        """Convert the enum value to a valid Deepchecks check class.

        Returns:
            The Deepchecks check class associated with the enum value.
        """
        return self.get_check_class(self.value)
visualizers
special
Deepchecks visualizer.
deepchecks_visualizer
Implementation of the Deepchecks visualizer.
DeepchecksVisualizer (BaseStepVisualizer)
The implementation of a Deepchecks Visualizer.
Source code in zenml/integrations/deepchecks/visualizers/deepchecks_visualizer.py
class DeepchecksVisualizer(BaseStepVisualizer):
    """The implementation of a Deepchecks Visualizer."""

    # `visualize` is a complete implementation, so it is not marked with
    # `@abstractmethod`: under an ABC-based hierarchy that marker would make
    # this concrete class impossible to instantiate.
    def visualize(self, object: StepView, *args: Any, **kwargs: Any) -> None:
        """Method to visualize components.

        Reads every output artifact of the step whose type is
        `DataAnalysisArtifact` and generates a report for it.

        Args:
            object: StepView fetched from run.get_step().
            *args: Additional arguments (unused).
            **kwargs: Additional keyword arguments (unused).
        """
        for artifact_view in object.outputs.values():
            # filter out anything but data analysis artifacts
            if artifact_view.type == DataAnalysisArtifact.__name__:
                artifact = artifact_view.read()
                self.generate_report(artifact)

    def generate_report(self, result: Union[CheckResult, SuiteResult]) -> None:
        """Generate a Deepchecks Report.

        The result is always printed. Inside a notebook it is then shown
        inline; otherwise it is saved to a temporary HTML file (which is not
        deleted afterwards) and opened in a web browser.

        Args:
            result: A Deepchecks CheckResult or SuiteResult.
        """
        print(result)

        if Environment.in_notebook():
            result.show()
        else:
            logger.warning(
                "The magic functions are only usable in a Jupyter notebook."
            )
            with tempfile.NamedTemporaryFile(
                mode="w", delete=False, suffix=".html", encoding="utf-8"
            ) as f:
                result.save_as_html(f)
                url = f"file:///{f.name}"
            # Open the report only after leaving the `with` block, so the
            # file has been flushed and closed before the browser reads it.
            logger.info("Opening %s in a new browser..", f.name)
            webbrowser.open(url, new=2)
generate_report(self, result)
Generate a Deepchecks Report.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
result |
Union[deepchecks.core.check_result.CheckResult, deepchecks.core.suite.SuiteResult] |
A Deepchecks CheckResult or SuiteResult. |
required |
Source code in zenml/integrations/deepchecks/visualizers/deepchecks_visualizer.py
def generate_report(self, result: Union[CheckResult, SuiteResult]) -> None:
    """Generate a Deepchecks Report.

    The result is always printed. Inside a notebook it is then shown inline;
    otherwise it is saved to a temporary HTML file (which is not deleted
    afterwards) and opened in a web browser.

    Args:
        result: A Deepchecks CheckResult or SuiteResult.
    """
    print(result)

    if Environment.in_notebook():
        result.show()
    else:
        logger.warning(
            "The magic functions are only usable in a Jupyter notebook."
        )
        with tempfile.NamedTemporaryFile(
            mode="w", delete=False, suffix=".html", encoding="utf-8"
        ) as f:
            result.save_as_html(f)
            url = f"file:///{f.name}"
        # Open the report only after leaving the `with` block, so the file
        # has been flushed and closed before the browser reads it.
        logger.info("Opening %s in a new browser..", f.name)
        webbrowser.open(url, new=2)
visualize(self, object, *args, **kwargs)
Method to visualize components.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
object |
StepView |
StepView fetched from run.get_step(). |
required |
*args |
Any |
Additional arguments (unused). |
() |
**kwargs |
Any |
Additional keyword arguments (unused). |
{} |
Source code in zenml/integrations/deepchecks/visualizers/deepchecks_visualizer.py
# Not marked with `@abstractmethod`: this is a complete implementation, and
# under an ABC-based hierarchy that marker would make the concrete visualizer
# impossible to instantiate.
def visualize(self, object: StepView, *args: Any, **kwargs: Any) -> None:
    """Method to visualize components.

    Reads every output artifact of the step whose type is
    `DataAnalysisArtifact` and generates a report for it.

    Args:
        object: StepView fetched from run.get_step().
        *args: Additional arguments (unused).
        **kwargs: Additional keyword arguments (unused).
    """
    for artifact_view in object.outputs.values():
        # filter out anything but data analysis artifacts
        if artifact_view.type == DataAnalysisArtifact.__name__:
            artifact = artifact_view.read()
            self.generate_report(artifact)
evidently
special
Initialization of the Evidently integration.
The Evidently integration provides a way to monitor your models in production. It includes a way to detect data drift and different kinds of model performance issues.
The results of Evidently calculations can either be exported as an interactive dashboard (visualized as an html file or in your Jupyter notebook), or as a JSON file.
EvidentlyIntegration (Integration)
Evidently integration for ZenML.
Source code in zenml/integrations/evidently/__init__.py
class EvidentlyIntegration(Integration):
"""[Evidently](https://github.com/evidentlyai/evidently) integration for ZenML."""
NAME = EVIDENTLY
REQUIREMENTS = ["evidently==0.1.52dev0"]
@staticmethod
def activate() -> None:
"""Activate the Evidently integration.
Imports the Evidently `materializers` and `visualizers` sub-modules.
"""
from zenml.integrations.evidently import materializers # noqa
from zenml.integrations.evidently import visualizers # noqa
@classmethod
def flavors(cls) -> List[FlavorWrapper]:
"""Declare the stack component flavors for the Evidently integration.
Returns:
List of stack component flavors for this integration.
"""
return [
FlavorWrapper(
name=EVIDENTLY_DATA_VALIDATOR_FLAVOR,
source="zenml.integrations.evidently.data_validators.EvidentlyDataValidator",
type=StackComponentType.DATA_VALIDATOR,
integration=cls.NAME,
),
]
activate()
staticmethod
Activate the Evidently integration.
Source code in zenml/integrations/evidently/__init__.py
@staticmethod
def activate() -> None:
    """Activate the Evidently integration.

    The sub-modules are imported only for their import side effects; the
    imported names themselves are not used here.
    """
    from zenml.integrations.evidently import materializers  # noqa
    from zenml.integrations.evidently import visualizers  # noqa
flavors()
classmethod
Declare the stack component flavors for the Evidently integration.
Returns:
Type | Description |
---|---|
List[zenml.zen_stores.models.flavor_wrapper.FlavorWrapper] |
List of stack component flavors for this integration. |
Source code in zenml/integrations/evidently/__init__.py
@classmethod
def flavors(cls) -> List[FlavorWrapper]:
    """Declare the stack component flavors for the Evidently integration.

    Returns:
        List of stack component flavors for this integration.
    """
    return [
        FlavorWrapper(
            name=EVIDENTLY_DATA_VALIDATOR_FLAVOR,
            source="zenml.integrations.evidently.data_validators.EvidentlyDataValidator",
            type=StackComponentType.DATA_VALIDATOR,
            integration=cls.NAME,
        ),
    ]
data_validators
special
Initialization of the Evidently data validator for ZenML.
evidently_data_validator
Implementation of the Evidently data validator.
EvidentlyDataValidator (BaseDataValidator)
pydantic-model
Evidently data validator stack component.
Source code in zenml/integrations/evidently/data_validators/evidently_data_validator.py
class EvidentlyDataValidator(BaseDataValidator):
    """Evidently data validator stack component."""

    # Class Configuration
    FLAVOR: ClassVar[str] = EVIDENTLY_DATA_VALIDATOR_FLAVOR
    NAME: ClassVar[str] = "Evidently"

    @classmethod
    def _unpack_options(
        cls, option_list: Sequence[Tuple[str, Dict[str, Any]]]
    ) -> Sequence[Any]:
        """Unpack Evidently options.

        Implements de-serialization for [Evidently options](https://docs.evidentlyai.com/user-guide/customization)
        that can be passed as constructor arguments when creating Profile and
        Dashboard objects. The convention used is that each item in the list
        consists of two elements:

        * a string containing the full class path of a `dataclass` based
        class with Evidently options
        * a dictionary with kwargs used as parameters for the option instance

        Example:

        ```python
        options = [
            (
                "evidently.options.ColorOptions",{
                    "primary_color": "#5a86ad",
                    "fill_color": "#fff4f2",
                    "zero_line_color": "#016795",
                    "current_data_color": "#c292a1",
                    "reference_data_color": "#017b92",
                }
            ),
        ]
        ```

        This is the same as saying:

        ```python
        from evidently.options import ColorOptions

        color_scheme = ColorOptions()
        color_scheme.primary_color = "#5a86ad"
        color_scheme.fill_color = "#fff4f2"
        color_scheme.zero_line_color = "#016795"
        color_scheme.current_data_color = "#c292a1"
        color_scheme.reference_data_color = "#017b92"
        ```

        Args:
            option_list: list of packed Evidently options

        Returns:
            A list of unpacked Evidently options

        Raises:
            ValueError: if one of the passed Evidently class paths cannot be
                resolved to an actual class.
        """
        options = []
        for option_clspath, option_args in option_list:
            try:
                option_cls = load_source_path_class(option_clspath)
            except AttributeError as e:
                # chain the original error so the failing attribute lookup
                # stays visible in the traceback
                raise ValueError(
                    f"Could not map the `{option_clspath}` Evidently option "
                    f"class path to a valid class."
                ) from e
            option = option_cls(**option_args)
            options.append(option)
        return options

    def data_profiling(
        self,
        dataset: pd.DataFrame,
        comparison_dataset: Optional[pd.DataFrame] = None,
        profile_list: Optional[Sequence[str]] = None,
        column_mapping: Optional[ColumnMapping] = None,
        verbose_level: int = 1,
        profile_options: Sequence[Tuple[str, Dict[str, Any]]] = [],
        dashboard_options: Sequence[Tuple[str, Dict[str, Any]]] = [],
        **kwargs: Any,
    ) -> Tuple[Profile, Dashboard]:
        """Analyze a dataset and generate a data profile with Evidently.

        The method takes in an optional list of Evidently options to be passed
        to the profile constructor (`profile_options`) and the dashboard
        constructor (`dashboard_options`). Each element in the list must be
        composed of two items: the first is a full class path of an Evidently
        option `dataclass`, the second is a dictionary of kwargs with the actual
        option parameters, e.g.:

        ```python
        options = [
            (
                "evidently.options.ColorOptions",{
                    "primary_color": "#5a86ad",
                    "fill_color": "#fff4f2",
                    "zero_line_color": "#016795",
                    "current_data_color": "#c292a1",
                    "reference_data_color": "#017b92",
                }
            ),
        ]
        ```

        Args:
            dataset: Target dataset to be profiled.
            comparison_dataset: Optional dataset to be used for data profiles
                that require a baseline for comparison (e.g data drift profiles).
            profile_list: Optional list identifying the categories of Evidently
                data profiles to be generated.
            column_mapping: Properties of the DataFrame columns used
            verbose_level: Level of verbosity for the Evidently dashboards. Use
                0 for a brief dashboard, 1 for a detailed dashboard.
            profile_options: Optional list of options to pass to the
                profile constructor.
            dashboard_options: Optional list of options to pass to the
                dashboard constructor.
            **kwargs: Extra keyword arguments (unused).

        Returns:
            The Evidently Profile and Dashboard objects corresponding to the set
            of generated profiles.
        """
        sections, tabs = get_profile_sections_and_tabs(
            profile_list, verbose_level
        )
        # the default option lists are only iterated, never mutated
        unpacked_profile_options = self._unpack_options(profile_options)
        unpacked_dashboard_options = self._unpack_options(dashboard_options)

        dashboard = Dashboard(tabs=tabs, options=unpacked_dashboard_options)
        dashboard.calculate(
            reference_data=dataset,
            current_data=comparison_dataset,
            column_mapping=column_mapping,
        )
        profile = Profile(sections=sections, options=unpacked_profile_options)
        profile.calculate(
            reference_data=dataset,
            current_data=comparison_dataset,
            column_mapping=column_mapping,
        )
        return profile, dashboard
data_profiling(self, dataset, comparison_dataset=None, profile_list=None, column_mapping=None, verbose_level=1, profile_options=[], dashboard_options=[], **kwargs)
Analyze a dataset and generate a data profile with Evidently.
The method takes in an optional list of Evidently options to be passed
to the profile constructor (`profile_options`) and the dashboard
constructor (`dashboard_options`). Each element in the list must be
composed of two items: the first is a full class path of an Evidently
option `dataclass`, the second is a dictionary of kwargs with the actual
option parameters, e.g.:
options = [
(
"evidently.options.ColorOptions",{
"primary_color": "#5a86ad",
"fill_color": "#fff4f2",
"zero_line_color": "#016795",
"current_data_color": "#c292a1",
"reference_data_color": "#017b92",
}
),
]
Parameters:
Name | Type | Description | Default |
---|---|---|---|
dataset |
DataFrame |
Target dataset to be profiled. |
required |
comparison_dataset |
Optional[pandas.core.frame.DataFrame] |
Optional dataset to be used for data profiles that require a baseline for comparison (e.g data drift profiles). |
None |
profile_list |
Optional[Sequence[str]] |
Optional list identifying the categories of Evidently data profiles to be generated. |
None |
column_mapping |
Optional[evidently.pipeline.column_mapping.ColumnMapping] |
Properties of the DataFrame columns used |
None |
verbose_level |
int |
Level of verbosity for the Evidently dashboards. Use 0 for a brief dashboard, 1 for a detailed dashboard. |
1 |
profile_options |
Sequence[Tuple[str, Dict[str, Any]]] |
Optional list of options to pass to the profile constructor. |
[] |
dashboard_options |
Sequence[Tuple[str, Dict[str, Any]]] |
Optional list of options to pass to the dashboard constructor. |
[] |
**kwargs |
Any |
Extra keyword arguments (unused). |
{} |
Returns:
Type | Description |
---|---|
Tuple[evidently.model_profile.model_profile.Profile, evidently.dashboard.dashboard.Dashboard] |
The Evidently Profile and Dashboard objects corresponding to the set of generated profiles. |
Source code in zenml/integrations/evidently/data_validators/evidently_data_validator.py
def data_profiling(
    self,
    dataset: pd.DataFrame,
    comparison_dataset: Optional[pd.DataFrame] = None,
    profile_list: Optional[Sequence[str]] = None,
    column_mapping: Optional[ColumnMapping] = None,
    verbose_level: int = 1,
    profile_options: Sequence[Tuple[str, Dict[str, Any]]] = [],
    dashboard_options: Sequence[Tuple[str, Dict[str, Any]]] = [],
    **kwargs: Any,
) -> Tuple[Profile, Dashboard]:
    """Build an Evidently profile and dashboard for a dataset.

    `profile_options` and `dashboard_options` are packed Evidently options
    forwarded to the Profile and Dashboard constructors respectively. Every
    entry is a pair: the full class path of an Evidently option `dataclass`
    and a dictionary of kwargs used to instantiate it, e.g.:

    ```python
    options = [
        (
            "evidently.options.ColorOptions",{
                "primary_color": "#5a86ad",
                "fill_color": "#fff4f2",
                "zero_line_color": "#016795",
                "current_data_color": "#c292a1",
                "reference_data_color": "#017b92",
            }
        ),
    ]
    ```

    Args:
        dataset: Target dataset to be profiled.
        comparison_dataset: Optional dataset to be used for data profiles
            that require a baseline for comparison (e.g data drift profiles).
        profile_list: Optional list identifying the categories of Evidently
            data profiles to be generated.
        column_mapping: Properties of the DataFrame columns used
        verbose_level: Level of verbosity for the Evidently dashboards. Use
            0 for a brief dashboard, 1 for a detailed dashboard.
        profile_options: Optional list of options to pass to the
            profile constructor.
        dashboard_options: Optional list of options to pass to the
            dashboard constructor.
        **kwargs: Extra keyword arguments (unused).

    Returns:
        The Evidently Profile and Dashboard objects corresponding to the set
        of generated profiles.
    """
    sections, tabs = get_profile_sections_and_tabs(profile_list, verbose_level)
    profile_opts = self._unpack_options(profile_options)
    dashboard_opts = self._unpack_options(dashboard_options)

    # both objects are calculated from exactly the same inputs
    calculate_kwargs = {
        "reference_data": dataset,
        "current_data": comparison_dataset,
        "column_mapping": column_mapping,
    }

    dashboard = Dashboard(tabs=tabs, options=dashboard_opts)
    dashboard.calculate(**calculate_kwargs)

    profile = Profile(sections=sections, options=profile_opts)
    profile.calculate(**calculate_kwargs)

    return profile, dashboard
get_profile_sections_and_tabs(profile_list, verbose_level=1)
Get the profile sections and dashboard tabs for a profile list.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
profile_list |
Optional[Sequence[str]] |
List of identifiers for Evidently profiles. |
required |
verbose_level |
int |
Verbosity level for the rendered dashboard. Use 0 for a brief dashboard, 1 for a detailed dashboard. |
1 |
Returns:
Type | Description |
---|---|
Tuple[List[evidently.model_profile.sections.base_profile_section.ProfileSection], List[evidently.dashboard.tabs.base_tab.Tab]] |
A tuple of two lists of profile sections and tabs. |
Exceptions:
Type | Description |
---|---|
ValueError |
if the profile_section is not supported. |
Source code in zenml/integrations/evidently/data_validators/evidently_data_validator.py
def get_profile_sections_and_tabs(
    profile_list: Optional[Sequence[str]],
    verbose_level: int = 1,
) -> Tuple[List[ProfileSection], List[Tab]]:
    """Instantiate the profile sections and dashboard tabs for a profile list.

    Args:
        profile_list: List of identifiers for Evidently profiles. A missing
            or empty list selects every key of `profile_mapper`.
        verbose_level: Verbosity level for the rendered dashboard. Use
            0 for a brief dashboard, 1 for a detailed dashboard.

    Returns:
        A tuple of two lists of profile sections and tabs.

    Raises:
        ValueError: if the profile_section is not supported.
    """
    profile_list = profile_list or list(profile_mapper.keys())
    try:
        sections = [profile_mapper[name]() for name in profile_list]
        tabs = [
            dashboard_mapper[name](verbose_level=verbose_level)
            for name in profile_list
        ]
    except KeyError as e:
        nl = "\n"
        raise ValueError(
            f"Invalid profile sections: {profile_list} \n\n"
            f"Valid and supported options are: {nl}- "
            f'{f"{nl}- ".join(list(profile_mapper.keys()))}'
        ) from e
    return sections, tabs
materializers
special
Evidently materializers.
evidently_profile_materializer
Implementation of Evidently profile materializer.
EvidentlyProfileMaterializer (BaseMaterializer)
Materializer to read data to and from an Evidently Profile.
Source code in zenml/integrations/evidently/materializers/evidently_profile_materializer.py
class EvidentlyProfileMaterializer(BaseMaterializer):
    """Materializer to read data to and from an Evidently Profile."""

    ASSOCIATED_TYPES = (Profile,)
    ASSOCIATED_ARTIFACT_TYPES = (DataAnalysisArtifact,)

    def handle_input(self, data_type: Type[Any]) -> Profile:
        """Reads an Evidently Profile object from a json file.

        Args:
            data_type: The type of the data to read.

        Returns:
            The Evidently Profile

        Raises:
            TypeError: if the json file contains an invalid data type.
        """
        super().handle_input(data_type)
        filepath = os.path.join(self.artifact.uri, DEFAULT_FILENAME)
        contents = yaml_utils.read_json(filepath)
        if not isinstance(contents, dict):
            raise TypeError(
                f"Contents {contents} was type {type(contents)} but expected "
                f"dictionary"
            )

        # "section_types" is bookkeeping written by handle_return; remove it
        # so only the per-section results remain in the dictionary
        section_types = contents.pop("section_types", [])
        sections = []
        for section_type in section_types:
            section_cls = import_class_by_path(section_type)
            section = section_cls()
            # restore the stored result under the section's own part id
            section._result = contents[section.part_id()]
            sections.append(section)

        return Profile(sections=sections)

    def handle_return(self, data: Profile) -> None:
        """Serialize an Evidently Profile to a json file.

        Args:
            data: The Evidently Profile to be serialized.
        """
        super().handle_return(data)

        contents = data.object()
        # include the list of profile sections in the serialized dictionary,
        # so we'll be able to re-create them during de-serialization
        contents["section_types"] = [
            resolve_class(stage.__class__) for stage in data.stages
        ]

        filepath = os.path.join(self.artifact.uri, DEFAULT_FILENAME)
        yaml_utils.write_json(filepath, contents, encoder=NumpyEncoder)
handle_input(self, data_type)
Reads an Evidently Profile object from a json file.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
data_type |
Type[Any] |
The type of the data to read. |
required |
Returns:
Type | Description |
---|---|
Profile |
The Evidently Profile |
Exceptions:
Type | Description |
---|---|
TypeError |
if the json file contains an invalid data type. |
Source code in zenml/integrations/evidently/materializers/evidently_profile_materializer.py
def handle_input(self, data_type: Type[Any]) -> Profile:
    """Reads an Evidently Profile object from a json file.

    Args:
        data_type: The type of the data to read.

    Returns:
        The Evidently Profile

    Raises:
        TypeError: if the json file contains an invalid data type.
    """
    super().handle_input(data_type)
    filepath = os.path.join(self.artifact.uri, DEFAULT_FILENAME)
    contents = yaml_utils.read_json(filepath)
    if not isinstance(contents, dict):
        raise TypeError(
            f"Contents {contents} was type {type(contents)} but expected "
            f"dictionary"
        )

    # "section_types" lists the class paths of the serialized sections;
    # remove it so only the per-section results remain in the dictionary
    section_types = contents.pop("section_types", [])
    sections = []
    for section_type in section_types:
        section_cls = import_class_by_path(section_type)
        section = section_cls()
        # restore the stored result under the section's own part id
        section._result = contents[section.part_id()]
        sections.append(section)

    return Profile(sections=sections)
handle_return(self, data)
Serialize an Evidently Profile to a json file.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
data |
Profile |
The Evidently Profile to be serialized. |
required |
Source code in zenml/integrations/evidently/materializers/evidently_profile_materializer.py
def handle_return(self, data: Profile) -> None:
    """Write an Evidently Profile to a json file in the artifact location.

    Args:
        data: The Evidently Profile to be serialized.
    """
    super().handle_return(data)

    serialized = data.object()

    # record the class path of every profile section next to the results,
    # so that the sections can be re-created during de-serialization
    section_paths = []
    for stage in data.stages:
        section_paths.append(resolve_class(stage.__class__))
    serialized["section_types"] = section_paths

    target_path = os.path.join(self.artifact.uri, DEFAULT_FILENAME)
    yaml_utils.write_json(target_path, serialized, encoder=NumpyEncoder)
steps
special
Initialization of the Evidently Standard Steps.
evidently_profile
Implementation of the Evidently Profile Step.
EvidentlyColumnMapping (BaseModel)
pydantic-model
Column mapping configuration for Evidently.
This class is a 1-to-1 serializable analogue of Evidently's ColumnMapping data type that can be used as a step configuration field (see https://docs.evidentlyai.com/features/dashboards/column_mapping).
Attributes:
Name | Type | Description |
---|---|---|
target |
Optional[str] |
target column |
prediction |
Union[str, Sequence[str]] |
prediction column |
datetime |
Optional[str] |
datetime column |
id |
Optional[str] |
id column |
numerical_features |
Optional[List[str]] |
numerical features |
categorical_features |
Optional[List[str]] |
categorical features |
datetime_features |
Optional[List[str]] |
datetime features |
target_names |
Optional[List[str]] |
target column names |
task |
Optional[Literal['classification', 'regression']] |
model task (regression or classification) |
Source code in zenml/integrations/evidently/steps/evidently_profile.py
class EvidentlyColumnMapping(BaseModel):
    """Column mapping configuration for Evidently.

    This class is a 1-to-1 serializable analogue of Evidently's
    ColumnMapping data type that can be used as a step configuration field
    (see https://docs.evidentlyai.com/features/dashboards/column_mapping).

    Attributes:
        target: target column
        prediction: prediction column
        datetime: datetime column
        id: id column
        numerical_features: numerical features
        categorical_features: categorical features
        datetime_features: datetime features
        target_names: target column names
        task: model task (regression or classification)
    """

    target: Optional[str] = None
    prediction: Optional[Union[str, Sequence[str]]] = None
    datetime: Optional[str] = None
    id: Optional[str] = None
    numerical_features: Optional[List[str]] = None
    categorical_features: Optional[List[str]] = None
    datetime_features: Optional[List[str]] = None
    target_names: Optional[List[str]] = None
    task: Optional[Literal["classification", "regression"]] = None

    def to_evidently_column_mapping(self) -> ColumnMapping:
        """Convert this Pydantic object to an Evidently ColumnMapping object.

        Every field that is unset (or otherwise falsy) on this object keeps
        the value that a freshly constructed ColumnMapping already has.

        Returns:
            An Evidently column mapping converted from this Pydantic object.
        """
        column_mapping = ColumnMapping()

        # preserve the Evidently defaults where possible
        column_mapping.target = self.target or column_mapping.target
        column_mapping.prediction = self.prediction or column_mapping.prediction
        column_mapping.datetime = self.datetime or column_mapping.datetime
        column_mapping.id = self.id or column_mapping.id
        column_mapping.numerical_features = (
            self.numerical_features or column_mapping.numerical_features
        )
        # previously omitted, so configured categorical features were lost
        column_mapping.categorical_features = (
            self.categorical_features or column_mapping.categorical_features
        )
        column_mapping.datetime_features = (
            self.datetime_features or column_mapping.datetime_features
        )
        column_mapping.target_names = (
            self.target_names or column_mapping.target_names
        )
        column_mapping.task = self.task or column_mapping.task
        return column_mapping
to_evidently_column_mapping(self)
Convert this Pydantic object to an Evidently ColumnMapping object.
Returns:
Type | Description |
---|---|
ColumnMapping |
An Evidently column mapping converted from this Pydantic object. |
Source code in zenml/integrations/evidently/steps/evidently_profile.py
def to_evidently_column_mapping(self) -> ColumnMapping:
    """Convert this Pydantic object to an Evidently ColumnMapping object.

    Every field that is unset (or otherwise falsy) on this object keeps
    the value that a freshly constructed ColumnMapping already has.

    Returns:
        An Evidently column mapping converted from this Pydantic object.
    """
    column_mapping = ColumnMapping()

    # preserve the Evidently defaults where possible
    column_mapping.target = self.target or column_mapping.target
    column_mapping.prediction = self.prediction or column_mapping.prediction
    column_mapping.datetime = self.datetime or column_mapping.datetime
    column_mapping.id = self.id or column_mapping.id
    column_mapping.numerical_features = (
        self.numerical_features or column_mapping.numerical_features
    )
    # previously omitted, so configured categorical features were lost
    column_mapping.categorical_features = (
        self.categorical_features or column_mapping.categorical_features
    )
    column_mapping.datetime_features = (
        self.datetime_features or column_mapping.datetime_features
    )
    column_mapping.target_names = (
        self.target_names or column_mapping.target_names
    )
    column_mapping.task = self.task or column_mapping.task
    return column_mapping
EvidentlyProfileConfig (BaseDriftDetectionConfig)
pydantic-model
Config class for Evidently profile steps.
Attributes:
Name | Type | Description |
---|---|---|
column_mapping |
Optional[zenml.integrations.evidently.steps.evidently_profile.EvidentlyColumnMapping] |
properties of the DataFrame columns used |
profile_sections |
Optional[Sequence[str]] |
a list identifying the Evidently profile sections to be used. The following are valid options supported by Evidently: - "datadrift" - "categoricaltargetdrift" - "numericaltargetdrift" - "classificationmodelperformance" - "regressionmodelperformance" - "probabilisticmodelperformance" |
verbose_level |
int |
Verbosity level for the Evidently dashboards. Use 0 for a brief dashboard, 1 for a detailed dashboard. |
profile_options |
Sequence[Tuple[str, Dict[str, Any]]] |
Optional list of options to pass to the profile constructor. See `EvidentlyDataValidator._unpack_options`. |
dashboard_options |
Sequence[Tuple[str, Dict[str, Any]]] |
Optional list of options to pass to the dashboard constructor. See `EvidentlyDataValidator._unpack_options`. |
Source code in zenml/integrations/evidently/steps/evidently_profile.py
class EvidentlyProfileConfig(BaseDriftDetectionConfig):
    """Configuration consumed by the Evidently profile step.

    Attributes:
        column_mapping: properties of the DataFrame columns used
        profile_sections: a list identifying the Evidently profile sections
            to be used. The following are valid options supported by
            Evidently:

            - "datadrift"
            - "categoricaltargetdrift"
            - "numericaltargetdrift"
            - "classificationmodelperformance"
            - "regressionmodelperformance"
            - "probabilisticmodelperformance"
        verbose_level: Verbosity level for the Evidently dashboards. Use
            0 for a brief dashboard, 1 for a detailed dashboard.
        profile_options: Optional list of options to pass to the
            profile constructor. See
            `EvidentlyDataValidator._unpack_options`.
        dashboard_options: Optional list of options to pass to the
            dashboard constructor. See
            `EvidentlyDataValidator._unpack_options`.
    """

    column_mapping: Optional[EvidentlyColumnMapping] = None
    profile_sections: Optional[Sequence[str]] = None
    verbose_level: int = 1

    # packed (class path, kwargs) pairs; each instance gets its own list
    profile_options: Sequence[Tuple[str, Dict[str, Any]]] = Field(
        default_factory=list,
    )
    dashboard_options: Sequence[Tuple[str, Dict[str, Any]]] = Field(
        default_factory=list,
    )
EvidentlyProfileStep (BaseDriftDetectionStep)
Step implementation implementing an Evidently Profile Step.
Source code in zenml/integrations/evidently/steps/evidently_profile.py
class EvidentlyProfileStep(BaseDriftDetectionStep):
    """Drift detection step that produces an Evidently profile and dashboard."""

    def entrypoint(  # type: ignore[override]
        self,
        reference_dataset: pd.DataFrame,
        comparison_dataset: pd.DataFrame,
        config: EvidentlyProfileConfig,
    ) -> Output(  # type:ignore[valid-type]
        profile=Profile, dashboard=str
    ):
        """Run Evidently profiling through the active data validator.

        Args:
            reference_dataset: a Pandas DataFrame
            comparison_dataset: a Pandas DataFrame of new data you wish to
                compare against the reference data
            config: the configuration for the step

        Returns:
            profile: Evidently Profile generated for the data drift
            dashboard: HTML report extracted from an Evidently Dashboard
                generated for the data drift
        """
        validator = cast(
            EvidentlyDataValidator,
            EvidentlyDataValidator.get_active_data_validator(),
        )

        # convert the serializable mapping only when one was configured
        if config.column_mapping:
            evidently_mapping = (
                config.column_mapping.to_evidently_column_mapping()
            )
        else:
            evidently_mapping = None

        profile, dashboard = validator.data_profiling(
            dataset=reference_dataset,
            comparison_dataset=comparison_dataset,
            profile_list=config.profile_sections,
            column_mapping=evidently_mapping,
            verbose_level=config.verbose_level,
            profile_options=config.profile_options,
            dashboard_options=config.dashboard_options,
        )
        return [profile, dashboard.html()]
CONFIG_CLASS (BaseDriftDetectionConfig)
pydantic-model
Config class for Evidently profile steps.
Attributes:
Name | Type | Description |
---|---|---|
column_mapping |
Optional[zenml.integrations.evidently.steps.evidently_profile.EvidentlyColumnMapping] |
properties of the DataFrame columns used |
profile_sections |
Optional[Sequence[str]] |
a list identifying the Evidently profile sections to be used. The following are valid options supported by Evidently: - "datadrift" - "categoricaltargetdrift" - "numericaltargetdrift" - "classificationmodelperformance" - "regressionmodelperformance" - "probabilisticmodelperformance" |
verbose_level |
int |
Verbosity level for the Evidently dashboards. Use 0 for a brief dashboard, 1 for a detailed dashboard. |
profile_options |
Sequence[Tuple[str, Dict[str, Any]]] |
Optional list of options to pass to the profile constructor. See `EvidentlyDataValidator._unpack_options`. |
dashboard_options |
Sequence[Tuple[str, Dict[str, Any]]] |
Optional list of options to pass to the dashboard constructor. See `EvidentlyDataValidator._unpack_options`. |
Source code in zenml/integrations/evidently/steps/evidently_profile.py
class EvidentlyProfileConfig(BaseDriftDetectionConfig):
    """Configuration consumed by the Evidently profile step.

    Attributes:
        column_mapping: properties of the DataFrame columns used
        profile_sections: a list identifying the Evidently profile sections
            to be used. The following are valid options supported by
            Evidently:

            - "datadrift"
            - "categoricaltargetdrift"
            - "numericaltargetdrift"
            - "classificationmodelperformance"
            - "regressionmodelperformance"
            - "probabilisticmodelperformance"
        verbose_level: Verbosity level for the Evidently dashboards. Use
            0 for a brief dashboard, 1 for a detailed dashboard.
        profile_options: Optional list of options to pass to the
            profile constructor. See
            `EvidentlyDataValidator._unpack_options`.
        dashboard_options: Optional list of options to pass to the
            dashboard constructor. See
            `EvidentlyDataValidator._unpack_options`.
    """

    column_mapping: Optional[EvidentlyColumnMapping] = None
    profile_sections: Optional[Sequence[str]] = None
    verbose_level: int = 1

    # packed (class path, kwargs) pairs; each instance gets its own list
    profile_options: Sequence[Tuple[str, Dict[str, Any]]] = Field(
        default_factory=list,
    )
    dashboard_options: Sequence[Tuple[str, Dict[str, Any]]] = Field(
        default_factory=list,
    )
entrypoint(self, reference_dataset, comparison_dataset, config)
Main entrypoint for the Evidently categorical target drift detection step.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
reference_dataset |
DataFrame |
a Pandas DataFrame |
required |
comparison_dataset |
DataFrame |
a Pandas DataFrame of new data you wish to compare against the reference data |
required |
config |
EvidentlyProfileConfig |
the configuration for the step |
required |
Returns:
Type | Description |
---|---|
profile |
Evidently Profile generated for the data drift dashboard: HTML report extracted from an Evidently Dashboard generated for the data drift |
Source code in zenml/integrations/evidently/steps/evidently_profile.py
def entrypoint(  # type: ignore[override]
    self,
    reference_dataset: pd.DataFrame,
    comparison_dataset: pd.DataFrame,
    config: EvidentlyProfileConfig,
) -> Output(  # type:ignore[valid-type]
    profile=Profile, dashboard=str
):
    """Run Evidently profiling through the active data validator.

    Args:
        reference_dataset: a Pandas DataFrame
        comparison_dataset: a Pandas DataFrame of new data you wish to
            compare against the reference data
        config: the configuration for the step

    Returns:
        profile: Evidently Profile generated for the data drift
        dashboard: HTML report extracted from an Evidently Dashboard
            generated for the data drift
    """
    validator = cast(
        EvidentlyDataValidator,
        EvidentlyDataValidator.get_active_data_validator(),
    )

    # convert the serializable mapping only when one was configured
    if config.column_mapping:
        evidently_mapping = config.column_mapping.to_evidently_column_mapping()
    else:
        evidently_mapping = None

    profile, dashboard = validator.data_profiling(
        dataset=reference_dataset,
        comparison_dataset=comparison_dataset,
        profile_list=config.profile_sections,
        column_mapping=evidently_mapping,
        verbose_level=config.verbose_level,
        profile_options=config.profile_options,
        dashboard_options=config.dashboard_options,
    )
    return [profile, dashboard.html()]
evidently_profile_step(step_name, config)
Shortcut function to create a new instance of the EvidentlyProfileStep step.
The returned EvidentlyProfileStep can be used in a pipeline to run model drift analyses on two input pd.DataFrame datasets and return the results as an Evidently profile object and a rendered dashboard object.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
step_name |
str |
The name of the step |
required |
config |
EvidentlyProfileConfig |
The configuration for the step |
required |
Returns:
Type | Description |
---|---|
BaseStep |
a EvidentlyProfileStep step instance |
Source code in zenml/integrations/evidently/steps/evidently_profile.py
def evidently_profile_step(
    step_name: str,
    config: EvidentlyProfileConfig,
) -> BaseStep:
    """Shortcut function to create a new instance of the EvidentlyProfileStep step.

    The returned EvidentlyProfileStep can be used in a pipeline to
    run model drift analyses on two input pd.DataFrame datasets and return the
    results as an Evidently profile object and a rendered dashboard object.

    Args:
        step_name: The name of the step
        config: The configuration for the step

    Returns:
        a EvidentlyProfileStep step instance
    """
    # clone the step class under the requested name, then instantiate the
    # clone with the supplied configuration
    return clone_step(EvidentlyProfileStep, step_name)(config=config)
visualizers
special
Initialization for Evidently visualizer.
evidently_visualizer
Implementation of the Evidently visualizer.
EvidentlyVisualizer (BaseStepVisualizer)
The implementation of an Evidently Visualizer.
Source code in zenml/integrations/evidently/visualizers/evidently_visualizer.py
class EvidentlyVisualizer(BaseStepVisualizer):
    """The implementation of an Evidently Visualizer."""

    @abstractmethod
    def visualize(self, object: StepView, *args: Any, **kwargs: Any) -> None:
        """Method to visualize components.

        Args:
            object: StepView fetched from run.get_step().
            *args: Additional arguments.
            **kwargs: Additional keyword arguments.
        """
        for artifact_view in object.outputs.values():
            # filter out anything but data artifacts
            if (
                artifact_view.type == DataArtifact.__name__
                and artifact_view.data_type == "builtins.str"
            ):
                artifact = artifact_view.read()
                self.generate_facet(artifact)

    def generate_facet(self, html_: str) -> None:
        """Display an HTML report inline or in a web browser.

        In a notebook or Google Colab the HTML is displayed inline. Otherwise
        it is written to a temporary `.html` file (which is not deleted) and
        opened in a web browser.

        Args:
            html_: HTML represented as a string.
        """
        if Environment.in_notebook() or Environment.in_google_colab():
            from IPython.core.display import HTML, display

            display(HTML(html_))
            return

        from pathlib import Path

        logger.warning(
            "The magic functions are only usable in a Jupyter notebook."
        )
        with tempfile.NamedTemporaryFile(
            mode="w", delete=False, suffix=".html", encoding="utf-8"
        ) as f:
            f.write(html_)
        # the file is closed (and therefore flushed) before the browser is
        # asked to read it; as_uri() builds a well-formed file:// URL
        logger.info("Opening %s in a new browser..", f.name)
        webbrowser.open(Path(f.name).as_uri(), new=2)
generate_facet(self, html_)
Generate a Facet Overview.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
html_ |
str |
HTML represented as a string. |
required |
Source code in zenml/integrations/evidently/visualizers/evidently_visualizer.py
def generate_facet(self, html_: str) -> None:
    """Display an HTML report inline or in a web browser.

    In a notebook or Google Colab the HTML is displayed inline. Otherwise
    it is written to a temporary `.html` file (which is not deleted) and
    opened in a web browser.

    Args:
        html_: HTML represented as a string.
    """
    if Environment.in_notebook() or Environment.in_google_colab():
        from IPython.core.display import HTML, display

        display(HTML(html_))
        return

    from pathlib import Path

    logger.warning(
        "The magic functions are only usable in a Jupyter notebook."
    )
    with tempfile.NamedTemporaryFile(
        mode="w", delete=False, suffix=".html", encoding="utf-8"
    ) as f:
        f.write(html_)
    # the file is closed (and therefore flushed) before the browser is
    # asked to read it; as_uri() builds a well-formed file:// URL
    logger.info("Opening %s in a new browser..", f.name)
    webbrowser.open(Path(f.name).as_uri(), new=2)
visualize(self, object, *args, **kwargs)
Method to visualize components.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
object |
StepView |
StepView fetched from run.get_step(). |
required |
*args |
Any |
Additional arguments. |
() |
**kwargs |
Any |
Additional keyword arguments. |
{} |
Source code in zenml/integrations/evidently/visualizers/evidently_visualizer.py
@abstractmethod
def visualize(self, object: StepView, *args: Any, **kwargs: Any) -> None:
    """Render every string-typed data artifact among the step outputs.

    Args:
        object: StepView fetched from run.get_step().
        *args: Additional arguments.
        **kwargs: Additional keyword arguments.
    """
    for output_view in object.outputs.values():
        # only data artifacts are of interest ...
        if output_view.type != DataArtifact.__name__:
            continue
        # ... and among those only the ones holding a plain string
        if output_view.data_type != "builtins.str":
            continue
        self.generate_facet(output_view.read())
facets
special
Facets integration for ZenML.
The Facets integration provides a simple way to visualize post-execution objects
like PipelineView
, PipelineRunView
and StepView
. These objects can be
extended using the BaseVisualization
class. This integration requires
facets-overview
to be installed in your Python environment.
FacetsIntegration (Integration)
Definition of Facet integration for ZenML.
Source code in zenml/integrations/facets/__init__.py
class FacetsIntegration(Integration):
    """ZenML integration definition for [Facets](https://pair-code.github.io/facets/)."""

    # Identifier of this integration.
    NAME = FACETS
    # Python packages declared as requirements of this integration.
    REQUIREMENTS = [
        "facets-overview>=1.0.0",
        "IPython",
    ]
visualizers
special
Initialization of the Facet Visualizer.
facet_statistics_visualizer
Implementation of the Facet Statistics Visualizer.
FacetStatisticsVisualizer (BaseStepVisualizer)
A ZenML step visualizer that renders Facets statistics for the pandas DataFrame outputs of a step.
Source code in zenml/integrations/facets/visualizers/facet_statistics_visualizer.py
class FacetStatisticsVisualizer(BaseStepVisualizer):
    """Step visualizer that renders Facets statistics for DataFrame outputs."""

    def visualize(
        self, object: StepView, magic: bool = False, *args: Any, **kwargs: Any
    ) -> None:
        """Visualize statistics of all pandas DataFrame outputs of a step.

        Outputs that are not DataFrames are skipped with a warning. This is
        a concrete implementation, so it is deliberately not marked with
        ``@abstractmethod``.

        Args:
            object: StepView fetched from run.get_step().
            magic: Whether to render in a Jupyter notebook or not.
            *args: Additional arguments (unused).
            **kwargs: Additional keyword arguments (unused).
        """
        datasets = []
        for output_name, artifact_view in object.outputs.items():
            df = artifact_view.read()
            # isinstance (rather than an exact type comparison) also accepts
            # DataFrame subclasses.
            if isinstance(df, pd.DataFrame):
                datasets.append({"name": output_name, "table": df})
            else:
                logger.warning(
                    "`%s` is not a pd.DataFrame. You can only visualize "
                    "statistics of steps that output pandas DataFrames. "
                    "Skipping this output..",
                    output_name,
                )
        h = self.generate_html(datasets)
        self.generate_facet(h, magic)

    def generate_html(self, datasets: List[Dict[Text, pd.DataFrame]]) -> str:
        """Generates html for facet.

        Args:
            datasets: List of dicts of DataFrames to be visualized as stats.

        Returns:
            HTML template with proto string embedded.
        """
        proto = GenericFeatureStatisticsGenerator().ProtoFromDataFrames(
            datasets
        )
        protostr = base64.b64encode(proto.SerializeToString()).decode("utf-8")
        # The template is read from `stats.html` next to this module.
        template = os.path.join(
            os.path.abspath(os.path.dirname(__file__)),
            "stats.html",
        )
        html_template = io_utils.read_file_contents_as_string(template)
        # The literal placeholder "protostr" is replaced by the payload.
        html_ = html_template.replace("protostr", protostr)
        return html_

    def generate_facet(self, html_: str, magic: bool = False) -> None:
        """Generate a Facet Overview.

        Args:
            html_: HTML represented as a string.
            magic: Whether to magically materialize facet in a notebook.

        Raises:
            EnvironmentError: If magic is True and not in a notebook.
        """
        if magic:
            if not (Environment.in_notebook() or Environment.in_google_colab()):
                raise EnvironmentError(
                    "The magic functions are only usable in a Jupyter notebook."
                )
            display(HTML(html_))
            return

        from pathlib import Path

        # Reserve a persistent temp file name, then close the handle before
        # writing by name: re-opening a still-open NamedTemporaryFile fails
        # on Windows.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".html") as f:
            html_path = f.name
        io_utils.write_file_contents_as_string(html_path, html_)
        # `Path.as_uri()` yields a well-formed file URL on every platform;
        # prefixing an absolute POSIX path with "file:///" gave four slashes.
        url = Path(html_path).as_uri()
        logger.info("Opening %s in a new browser..", html_path)
        webbrowser.open(url, new=2)
generate_facet(self, html_, magic=False)
Generate a Facet Overview.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
html_ |
str |
HTML represented as a string. |
required |
magic |
bool |
Whether to magically materialize facet in a notebook. |
False |
Exceptions:
Type | Description |
---|---|
EnvironmentError |
If magic is True and not in a notebook. |
Source code in zenml/integrations/facets/visualizers/facet_statistics_visualizer.py
def generate_facet(self, html_: str, magic: bool = False) -> None:
    """Generate a Facet Overview.

    Args:
        html_: HTML represented as a string.
        magic: Whether to magically materialize facet in a notebook.

    Raises:
        EnvironmentError: If magic is True and not in a notebook.
    """
    if magic:
        if not (Environment.in_notebook() or Environment.in_google_colab()):
            raise EnvironmentError(
                "The magic functions are only usable in a Jupyter notebook."
            )
        display(HTML(html_))
        return

    from pathlib import Path

    # Reserve a persistent temp file name, then close the handle before
    # writing by name: re-opening a still-open NamedTemporaryFile fails on
    # Windows.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".html") as f:
        html_path = f.name
    io_utils.write_file_contents_as_string(html_path, html_)
    # `Path.as_uri()` yields a well-formed file URL on every platform;
    # prefixing an absolute POSIX path with "file:///" gave four slashes.
    url = Path(html_path).as_uri()
    logger.info("Opening %s in a new browser..", html_path)
    webbrowser.open(url, new=2)
generate_html(self, datasets)
Generates html for facet.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
datasets |
List[Dict[str, pandas.core.frame.DataFrame]] |
List of dicts of DataFrames to be visualized as stats. |
required |
Returns:
Type | Description |
---|---|
str |
HTML template with proto string embedded. |
Source code in zenml/integrations/facets/visualizers/facet_statistics_visualizer.py
def generate_html(self, datasets: List[Dict[Text, pd.DataFrame]]) -> str:
    """Build the Facets HTML page for the given datasets.

    Args:
        datasets: List of dicts of DataFrames to be visualized as stats.

    Returns:
        The contents of the `stats.html` template with the base64-encoded
        statistics proto substituted for the "protostr" placeholder.
    """
    generator = GenericFeatureStatisticsGenerator()
    stats_proto = generator.ProtoFromDataFrames(datasets)
    encoded_proto = base64.b64encode(stats_proto.SerializeToString()).decode(
        "utf-8"
    )

    # The template lives next to this module.
    module_dir = os.path.abspath(os.path.dirname(__file__))
    template_path = os.path.join(module_dir, "stats.html")

    template_text = io_utils.read_file_contents_as_string(template_path)
    return template_text.replace("protostr", encoded_proto)
visualize(self, object, magic=False, *args, **kwargs)
Method to visualize components.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
object |
StepView |
StepView fetched from run.get_step(). |
required |
magic |
bool |
Whether to render in a Jupyter notebook or not. |
False |
*args |
Any |
Additional arguments. |
() |
**kwargs |
Any |
Additional keyword arguments. |
{} |
Source code in zenml/integrations/facets/visualizers/facet_statistics_visualizer.py
def visualize(
    self, object: StepView, magic: bool = False, *args: Any, **kwargs: Any
) -> None:
    """Visualize statistics of all pandas DataFrame outputs of a step.

    Outputs that are not DataFrames are skipped with a warning. This is a
    concrete implementation, so it is deliberately not marked with
    ``@abstractmethod``.

    Args:
        object: StepView fetched from run.get_step().
        magic: Whether to render in a Jupyter notebook or not.
        *args: Additional arguments (unused).
        **kwargs: Additional keyword arguments (unused).
    """
    datasets = []
    for output_name, artifact_view in object.outputs.items():
        df = artifact_view.read()
        # isinstance (rather than an exact type comparison) also accepts
        # DataFrame subclasses.
        if isinstance(df, pd.DataFrame):
            datasets.append({"name": output_name, "table": df})
        else:
            logger.warning(
                "`%s` is not a pd.DataFrame. You can only visualize "
                "statistics of steps that output pandas DataFrames. "
                "Skipping this output..",
                output_name,
            )
    h = self.generate_html(datasets)
    self.generate_facet(h, magic)
feast
special
Initialization for Feast integration.
The Feast integration offers a way to connect to a Feast Feature Store. ZenML implements a dedicated stack component that you can access as part of your ZenML steps in the usual ways.
FeastIntegration (Integration)
Definition of Feast integration for ZenML.
Source code in zenml/integrations/feast/__init__.py
class FeastIntegration(Integration):
    """ZenML integration definition for Feast."""

    # Identifier of this integration.
    NAME = FEAST
    # Python packages declared as requirements of this integration.
    REQUIREMENTS = [
        "feast[redis]>=0.19.4",
        "redis-server",
    ]

    @classmethod
    def flavors(cls) -> List[FlavorWrapper]:
        """List the stack component flavors shipped with the Feast integration.

        Returns:
            A single-element list holding the Feast feature store flavor.
        """
        feature_store_flavor = FlavorWrapper(
            name=FEAST_FEATURE_STORE_FLAVOR,
            source="zenml.integrations.feast.feature_stores.FeastFeatureStore",
            type=StackComponentType.FEATURE_STORE,
            integration=cls.NAME,
        )
        return [feature_store_flavor]
flavors()
classmethod
Declare the stack component flavors for the Feast integration.
Returns:
Type | Description |
---|---|
List[zenml.zen_stores.models.flavor_wrapper.FlavorWrapper] |
List of stack component flavors for this integration. |
Source code in zenml/integrations/feast/__init__.py
@classmethod
def flavors(cls) -> List[FlavorWrapper]:
    """List the stack component flavors shipped with the Feast integration.

    Returns:
        A single-element list holding the Feast feature store flavor.
    """
    feature_store_flavor = FlavorWrapper(
        name=FEAST_FEATURE_STORE_FLAVOR,
        source="zenml.integrations.feast.feature_stores.FeastFeatureStore",
        type=StackComponentType.FEATURE_STORE,
        integration=cls.NAME,
    )
    return [feature_store_flavor]
feature_stores
special
Feast Feature Store integration for ZenML.
Feature stores allow data teams to serve data via an offline store and an online low-latency store where data is kept in sync between the two. It also offers a centralized registry where features (and feature schemas) are stored for use within a team or wider organization. Feature stores are a relatively recent addition to commonly-used machine learning stacks. Feast is a leading open-source feature store, first developed by Gojek in collaboration with Google.
feast_feature_store
Implementation of the Feast Feature Store for ZenML.
FeastFeatureStore (BaseFeatureStore)
pydantic-model
Class to interact with the Feast feature store.
Source code in zenml/integrations/feast/feature_stores/feast_feature_store.py
class FeastFeatureStore(BaseFeatureStore):
    """Feature store stack component backed by Feast."""

    FLAVOR: ClassVar[str] = FEAST_FEATURE_STORE_FLAVOR

    # Host and port of the Redis instance pinged by `_validate_connection`.
    online_host: str = "localhost"
    online_port: int = 6379
    # Handed to Feast as the `repo_path` of the feature repository.
    feast_repo: str

    def _validate_connection(self) -> None:
        """Check that the online component (Redis) answers a ping.

        Raises:
            redis.exceptions.ConnectionError: If Redis cannot be reached at
                `online_host`:`online_port`.
        """
        redis_client = redis.Redis(
            host=self.online_host, port=self.online_port
        )
        try:
            redis_client.ping()
        except redis.exceptions.ConnectionError as error:
            message = (
                "Could not connect to feature store's online component. "
                "Please make sure that Redis is running."
            )
            raise redis.exceptions.ConnectionError(message) from error

    def get_historical_features(
        self,
        entity_df: Union[pd.DataFrame, str],
        features: List[str],
        full_feature_names: bool = False,
    ) -> pd.DataFrame:
        """Fetch historical features for training or batch scoring.

        The Redis connection is not checked by this method.

        Args:
            entity_df: The entity DataFrame or entity name.
            features: The features to retrieve.
            full_feature_names: Whether to return the full feature names.

        Returns:
            The historical features as a Pandas DataFrame.
        """
        store = FeatureStore(repo_path=self.feast_repo)
        retrieval = store.get_historical_features(
            entity_df=entity_df,
            features=features,
            full_feature_names=full_feature_names,
        )
        return retrieval.to_df()

    def get_online_features(
        self,
        entity_rows: List[Dict[str, Any]],
        features: List[str],
        full_feature_names: bool = False,
    ) -> Dict[str, Any]:
        """Fetch the latest online feature data.

        Args:
            entity_rows: The entity rows to retrieve.
            features: The features to retrieve.
            full_feature_names: Whether to return the full feature names.

        Raises:
            redis.exceptions.ConnectionError: If the online component
                (Redis) is not available.

        Returns:
            The latest online feature data as a dictionary.
        """
        self._validate_connection()
        store = FeatureStore(repo_path=self.feast_repo)
        online_response = store.get_online_features(
            entity_rows=entity_rows,
            features=features,
            full_feature_names=full_feature_names,
        )
        return online_response.to_dict()  # type: ignore[no-any-return]

    def get_data_sources(self) -> List[str]:
        """List the names of the registered data sources.

        Raises:
            redis.exceptions.ConnectionError: If the online component
                (Redis) is not available.

        Returns:
            The data sources' names.
        """
        self._validate_connection()
        store = FeatureStore(repo_path=self.feast_repo)
        return [source.name for source in store.list_data_sources()]

    def get_entities(self) -> List[str]:
        """List the names of the registered entities.

        Raises:
            redis.exceptions.ConnectionError: If the online component
                (Redis) is not available.

        Returns:
            The entity names.
        """
        self._validate_connection()
        store = FeatureStore(repo_path=self.feast_repo)
        return [entity.name for entity in store.list_entities()]

    def get_feature_services(self) -> List[str]:
        """List the names of the registered feature services.

        Raises:
            redis.exceptions.ConnectionError: If the online component
                (Redis) is not available.

        Returns:
            The feature service names.
        """
        self._validate_connection()
        store = FeatureStore(repo_path=self.feast_repo)
        return [service.name for service in store.list_feature_services()]

    def get_feature_views(self) -> List[str]:
        """List the names of the registered feature views.

        Raises:
            redis.exceptions.ConnectionError: If the online component
                (Redis) is not available.

        Returns:
            The feature view names.
        """
        self._validate_connection()
        store = FeatureStore(repo_path=self.feast_repo)
        return [view.name for view in store.list_feature_views()]

    def get_project(self) -> str:
        """Return the Feast project name as a string.

        The Redis connection is not checked by this method.

        Returns:
            The project name.
        """
        store = FeatureStore(repo_path=self.feast_repo)
        return str(store.project)

    def get_registry(self) -> Registry:
        """Return the feature store registry.

        The Redis connection is not checked by this method.

        Returns:
            The registry.
        """
        store: FeatureStore = FeatureStore(repo_path=self.feast_repo)
        return store.registry

    def get_feast_version(self) -> str:
        """Return the Feast version reported by the feature store.

        The Redis connection is not checked by this method.

        Returns:
            The version of Feast currently being used.
        """
        store = FeatureStore(repo_path=self.feast_repo)
        return str(store.version())
get_data_sources(self)
Returns the data sources' names.
Exceptions:
Type | Description |
---|---|
ConnectionError |
If the online component (Redis) is not available. |
Returns:
Type | Description |
---|---|
List[str] |
The data sources' names. |
Source code in zenml/integrations/feast/feature_stores/feast_feature_store.py
def get_data_sources(self) -> List[str]:
    """List the names of the registered data sources.

    Raises:
        redis.exceptions.ConnectionError: If the online component (Redis)
            is not available.

    Returns:
        The data sources' names.
    """
    self._validate_connection()
    store = FeatureStore(repo_path=self.feast_repo)
    return [source.name for source in store.list_data_sources()]
get_entities(self)
Returns the entity names.
Exceptions:
Type | Description |
---|---|
ConnectionError |
If the online component (Redis) is not available. |
Returns:
Type | Description |
---|---|
List[str] |
The entity names. |
Source code in zenml/integrations/feast/feature_stores/feast_feature_store.py
def get_entities(self) -> List[str]:
    """List the names of the registered entities.

    Raises:
        redis.exceptions.ConnectionError: If the online component (Redis)
            is not available.

    Returns:
        The entity names.
    """
    self._validate_connection()
    store = FeatureStore(repo_path=self.feast_repo)
    return [entity.name for entity in store.list_entities()]
get_feast_version(self)
Returns the version of Feast used.
Exceptions:
Type | Description |
---|---|
ConnectionError |
If the online component (Redis) is not available. |
Returns:
Type | Description |
---|---|
str |
The version of Feast currently being used. |
Source code in zenml/integrations/feast/feature_stores/feast_feature_store.py
def get_feast_version(self) -> str:
    """Return the Feast version reported by the feature store.

    The Redis connection is not checked by this method.

    Returns:
        The version of Feast currently being used.
    """
    store = FeatureStore(repo_path=self.feast_repo)
    feast_version = store.version()
    return str(feast_version)
get_feature_services(self)
Returns the feature service names.
Exceptions:
Type | Description |
---|---|
ConnectionError |
If the online component (Redis) is not available. |
Returns:
Type | Description |
---|---|
List[str] |
The feature service names. |
Source code in zenml/integrations/feast/feature_stores/feast_feature_store.py
def get_feature_services(self) -> List[str]:
    """List the names of the registered feature services.

    Raises:
        redis.exceptions.ConnectionError: If the online component (Redis)
            is not available.

    Returns:
        The feature service names.
    """
    self._validate_connection()
    store = FeatureStore(repo_path=self.feast_repo)
    return [service.name for service in store.list_feature_services()]
get_feature_views(self)
Returns the feature view names.
Exceptions:
Type | Description |
---|---|
ConnectionError |
If the online component (Redis) is not available. |
Returns:
Type | Description |
---|---|
List[str] |
The feature view names. |
Source code in zenml/integrations/feast/feature_stores/feast_feature_store.py
def get_feature_views(self) -> List[str]:
    """List the names of the registered feature views.

    Raises:
        redis.exceptions.ConnectionError: If the online component (Redis)
            is not available.

    Returns:
        The feature view names.
    """
    self._validate_connection()
    store = FeatureStore(repo_path=self.feast_repo)
    return [view.name for view in store.list_feature_views()]
get_historical_features(self, entity_df, features, full_feature_names=False)
Returns the historical features for training or batch scoring.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
entity_df |
Union[pandas.core.frame.DataFrame, str] |
The entity DataFrame or entity name. |
required |
features |
List[str] |
The features to retrieve. |
required |
full_feature_names |
bool |
Whether to return the full feature names. |
False |
Exceptions:
Type | Description |
---|---|
ConnectionError |
If the online component (Redis) is not available. |
Returns:
Type | Description |
---|---|
DataFrame |
The historical features as a Pandas DataFrame. |
Source code in zenml/integrations/feast/feature_stores/feast_feature_store.py
def get_historical_features(
    self,
    entity_df: Union[pd.DataFrame, str],
    features: List[str],
    full_feature_names: bool = False,
) -> pd.DataFrame:
    """Fetch historical features for training or batch scoring.

    The Redis connection is not checked by this method.

    Args:
        entity_df: The entity DataFrame or entity name.
        features: The features to retrieve.
        full_feature_names: Whether to return the full feature names.

    Returns:
        The historical features as a Pandas DataFrame.
    """
    store = FeatureStore(repo_path=self.feast_repo)
    retrieval = store.get_historical_features(
        entity_df=entity_df,
        features=features,
        full_feature_names=full_feature_names,
    )
    return retrieval.to_df()
get_online_features(self, entity_rows, features, full_feature_names=False)
Returns the latest online feature data.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
entity_rows |
List[Dict[str, Any]] |
The entity rows to retrieve. |
required |
features |
List[str] |
The features to retrieve. |
required |
full_feature_names |
bool |
Whether to return the full feature names. |
False |
Exceptions:
Type | Description |
---|---|
ConnectionError |
If the online component (Redis) is not available. |
Returns:
Type | Description |
---|---|
Dict[str, Any] |
The latest online feature data as a dictionary. |
Source code in zenml/integrations/feast/feature_stores/feast_feature_store.py
def get_online_features(
    self,
    entity_rows: List[Dict[str, Any]],
    features: List[str],
    full_feature_names: bool = False,
) -> Dict[str, Any]:
    """Fetch the latest online feature data.

    Args:
        entity_rows: The entity rows to retrieve.
        features: The features to retrieve.
        full_feature_names: Whether to return the full feature names.

    Raises:
        redis.exceptions.ConnectionError: If the online component (Redis)
            is not available.

    Returns:
        The latest online feature data as a dictionary.
    """
    self._validate_connection()
    store = FeatureStore(repo_path=self.feast_repo)
    online_response = store.get_online_features(
        entity_rows=entity_rows,
        features=features,
        full_feature_names=full_feature_names,
    )
    return online_response.to_dict()  # type: ignore[no-any-return]
get_project(self)
Returns the project name.
Exceptions:
Type | Description |
---|---|
ConnectionError |
If the online component (Redis) is not available. |
Returns:
Type | Description |
---|---|
str |
The project name. |
Source code in zenml/integrations/feast/feature_stores/feast_feature_store.py
def get_project(self) -> str:
    """Return the Feast project name as a string.

    The Redis connection is not checked by this method.

    Returns:
        The project name.
    """
    store = FeatureStore(repo_path=self.feast_repo)
    project_name = store.project
    return str(project_name)
get_registry(self)
Returns the feature store registry.
Exceptions:
Type | Description |
---|---|
ConnectionError |
If the online component (Redis) is not available. |
Returns:
Type | Description |
---|---|
Registry |
The registry. |
Source code in zenml/integrations/feast/feature_stores/feast_feature_store.py
def get_registry(self) -> Registry:
    """Return the feature store registry.

    The Redis connection is not checked by this method.

    Returns:
        The registry.
    """
    store: FeatureStore = FeatureStore(repo_path=self.feast_repo)
    return store.registry
gcp
special
Initialization of the GCP ZenML integration.
The GCP integration submodule provides a way to run ZenML pipelines in a cloud
environment. Specifically, it allows the use of cloud artifact stores, metadata
stores, and an io
module to handle file operations on Google Cloud Storage
(GCS).
Additionally, the GCP secrets manager integration submodule provides a way to access the GCP secrets manager from within your ZenML Pipeline runs.
The Vertex AI integration submodule provides a way to run ZenML pipelines in a Vertex AI environment.
GcpIntegration (Integration)
Definition of Google Cloud Platform integration for ZenML.
Source code in zenml/integrations/gcp/__init__.py
class GcpIntegration(Integration):
    """ZenML integration definition for Google Cloud Platform."""

    # Identifier of this integration.
    NAME = GCP
    # Python packages declared as requirements of this integration.
    REQUIREMENTS = [
        "kfp==1.8.9",
        "gcsfs",
        "google-cloud-secret-manager",
        "google-cloud-aiplatform>=1.11.0",
    ]

    @classmethod
    def flavors(cls) -> List[FlavorWrapper]:
        """List the stack component flavors shipped with the GCP integration.

        Returns:
            The artifact store, secrets manager, Vertex orchestrator and
            Vertex step operator flavors, in that order.
        """
        # (flavor name, import path of the implementation, component type)
        flavor_specs = (
            (
                GCP_ARTIFACT_STORE_FLAVOR,
                "zenml.integrations.gcp.artifact_stores.GCPArtifactStore",
                StackComponentType.ARTIFACT_STORE,
            ),
            (
                GCP_SECRETS_MANAGER_FLAVOR,
                "zenml.integrations.gcp.secrets_manager.GCPSecretsManager",
                StackComponentType.SECRETS_MANAGER,
            ),
            (
                GCP_VERTEX_ORCHESTRATOR_FLAVOR,
                "zenml.integrations.gcp.orchestrators.VertexOrchestrator",
                StackComponentType.ORCHESTRATOR,
            ),
            (
                GCP_VERTEX_STEP_OPERATOR_FLAVOR,
                "zenml.integrations.gcp.step_operators.VertexStepOperator",
                StackComponentType.STEP_OPERATOR,
            ),
        )
        return [
            FlavorWrapper(
                name=flavor_name,
                source=source_path,
                type=component_type,
                integration=cls.NAME,
            )
            for flavor_name, source_path, component_type in flavor_specs
        ]
flavors()
classmethod
Declare the stack component flavors for the GCP integration.
Returns:
Type | Description |
---|---|
List[zenml.zen_stores.models.flavor_wrapper.FlavorWrapper] |
List of stack component flavors for this integration. |
Source code in zenml/integrations/gcp/__init__.py
@classmethod
def flavors(cls) -> List[FlavorWrapper]:
    """List the stack component flavors shipped with the GCP integration.

    Returns:
        The artifact store, secrets manager, Vertex orchestrator and Vertex
        step operator flavors, in that order.
    """
    # (flavor name, import path of the implementation, component type)
    flavor_specs = (
        (
            GCP_ARTIFACT_STORE_FLAVOR,
            "zenml.integrations.gcp.artifact_stores.GCPArtifactStore",
            StackComponentType.ARTIFACT_STORE,
        ),
        (
            GCP_SECRETS_MANAGER_FLAVOR,
            "zenml.integrations.gcp.secrets_manager.GCPSecretsManager",
            StackComponentType.SECRETS_MANAGER,
        ),
        (
            GCP_VERTEX_ORCHESTRATOR_FLAVOR,
            "zenml.integrations.gcp.orchestrators.VertexOrchestrator",
            StackComponentType.ORCHESTRATOR,
        ),
        (
            GCP_VERTEX_STEP_OPERATOR_FLAVOR,
            "zenml.integrations.gcp.step_operators.VertexStepOperator",
            StackComponentType.STEP_OPERATOR,
        ),
    )
    return [
        FlavorWrapper(
            name=flavor_name,
            source=source_path,
            type=component_type,
            integration=cls.NAME,
        )
        for flavor_name, source_path, component_type in flavor_specs
    ]
artifact_stores
special
Initialization of the GCP Artifact Store.
gcp_artifact_store
Implementation of the GCP Artifact Store.
GCPArtifactStore (BaseArtifactStore, AuthenticationMixin)
pydantic-model
Artifact Store for Google Cloud Storage based artifacts.
Source code in zenml/integrations/gcp/artifact_stores/gcp_artifact_store.py
class GCPArtifactStore(BaseArtifactStore, AuthenticationMixin):
    """Artifact store that accesses Google Cloud Storage through gcsfs."""

    # Cached gcsfs client; created on first access of `filesystem`.
    _filesystem: Optional[gcsfs.GCSFileSystem] = None

    # Class Configuration
    FLAVOR: ClassVar[str] = GCP_ARTIFACT_STORE_FLAVOR
    SUPPORTED_SCHEMES: ClassVar[Set[str]] = {GCP_PATH_PREFIX}

    @property
    def filesystem(self) -> gcsfs.GCSFileSystem:
        """The gcsfs filesystem to access this artifact store.

        The filesystem is built once and cached. If an authentication
        secret is available, its credential dict is used as the token;
        otherwise `token=None` is passed to gcsfs.

        Returns:
            The gcsfs filesystem to access this artifact store.
        """
        if self._filesystem:
            return self._filesystem

        auth_secret = self.get_authentication_secret(
            expected_schema_type=GCPSecretSchema
        )
        credentials = auth_secret.get_credential_dict() if auth_secret else None
        self._filesystem = gcsfs.GCSFileSystem(token=credentials)
        return self._filesystem

    def open(self, path: PathType, mode: str = "r") -> Any:
        """Open a file at the given path.

        Args:
            path: Path of the file to open.
            mode: Mode in which to open the file. Currently, only
                'rb' and 'wb' to read and write binary files are supported.

        Returns:
            A file-like object that can be used to read or write to the file.
        """
        return self.filesystem.open(path=path, mode=mode)

    def copyfile(
        self, src: PathType, dst: PathType, overwrite: bool = False
    ) -> None:
        """Copy a file from `src` to `dst`.

        Args:
            src: The path to copy from.
            dst: The path to copy to.
            overwrite: If a file already exists at the destination, this
                method will overwrite it if overwrite=`True` and
                raise a FileExistsError otherwise.

        Raises:
            FileExistsError: If a file already exists at the destination
                and overwrite is not set to `True`.
        """
        # The existence check is only performed when overwriting is disabled.
        if not overwrite:
            if self.filesystem.exists(dst):
                raise FileExistsError(
                    f"Unable to copy to destination '{convert_to_str(dst)}', "
                    "file already exists. Set `overwrite=True` to copy anyway."
                )

        # TODO [ENG-151]: Check if it works with overwrite=True or if we need to
        #  manually remove it first
        self.filesystem.copy(path1=src, path2=dst)

    def exists(self, path: PathType) -> bool:
        """Check whether a path exists.

        Args:
            path: The path to check.

        Returns:
            True if the path exists, False otherwise.
        """
        path_exists = self.filesystem.exists(path=path)
        return path_exists  # type: ignore[no-any-return]

    def glob(self, pattern: PathType) -> List[PathType]:
        """Return all paths that match the given glob pattern.

        The glob pattern may include:
        - '*' to match any number of characters
        - '?' to match a single character
        - '[...]' to match one of the characters inside the brackets
        - '**' as the full name of a path component to match to search
            in subdirectories of any depth (e.g. '/some_dir/**/some_file)

        Args:
            pattern: The glob pattern to match, see details above.

        Returns:
            A list of paths that match the given glob pattern, each with
            the GCP path prefix prepended.
        """
        prefixed_matches: List[PathType] = []
        for match in self.filesystem.glob(path=pattern):
            prefixed_matches.append(f"{GCP_PATH_PREFIX}{match}")
        return prefixed_matches

    def isdir(self, path: PathType) -> bool:
        """Check whether a path is a directory.

        Args:
            path: The path to check.

        Returns:
            True if the path is a directory, False otherwise.
        """
        is_directory = self.filesystem.isdir(path=path)
        return is_directory  # type: ignore[no-any-return]

    def listdir(self, path: PathType) -> List[PathType]:
        """Return a list of files in a directory.

        Args:
            path: The path of the directory to list.

        Returns:
            The entries of the directory, relative to `path` (the listed
            directory's own path and any leading slashes are stripped).
        """
        directory = convert_to_str(path)
        if directory.startswith(GCP_PATH_PREFIX):
            directory = directory[len(GCP_PATH_PREFIX) :]
        skip = len(directory)

        # Each file info dict carries the entry's path under "name"; cutting
        # off the directory part and leading slashes leaves the basename.
        return [
            cast(str, file_info["name"])[skip:].lstrip("/")
            for file_info in self.filesystem.listdir(path=path)
        ]

    def makedirs(self, path: PathType) -> None:
        """Create a directory at the given path.

        If needed also create missing parent directories.

        Args:
            path: The path of the directory to create.
        """
        self.filesystem.makedirs(path=path, exist_ok=True)

    def mkdir(self, path: PathType) -> None:
        """Create a directory at the given path.

        Args:
            path: The path of the directory to create.
        """
        self.filesystem.makedir(path=path)

    def remove(self, path: PathType) -> None:
        """Remove the file at the given path.

        Args:
            path: The path of the file to remove.
        """
        self.filesystem.rm_file(path=path)

    def rename(
        self, src: PathType, dst: PathType, overwrite: bool = False
    ) -> None:
        """Rename source file to destination file.

        Args:
            src: The path of the file to rename.
            dst: The path to rename the source file to.
            overwrite: If a file already exists at the destination, this
                method will overwrite it if overwrite=`True` and
                raise a FileExistsError otherwise.

        Raises:
            FileExistsError: If a file already exists at the destination
                and overwrite is not set to `True`.
        """
        # The existence check is only performed when overwriting is disabled.
        if not overwrite:
            if self.filesystem.exists(dst):
                raise FileExistsError(
                    f"Unable to rename file to '{convert_to_str(dst)}', "
                    "file already exists. Set `overwrite=True` to rename anyway."
                )

        # TODO [ENG-152]: Check if it works with overwrite=True or if we need
        #  to manually remove it first
        self.filesystem.rename(path1=src, path2=dst)

    def rmtree(self, path: PathType) -> None:
        """Remove the given directory.

        Args:
            path: The path of the directory to remove.
        """
        self.filesystem.delete(path=path, recursive=True)

    def stat(self, path: PathType) -> Dict[str, Any]:
        """Return stat info for the given path.

        Args:
            path: the path to get stat info for.

        Returns:
            A dictionary with the stat info.
        """
        stat_info = self.filesystem.stat(path=path)
        return stat_info  # type: ignore[no-any-return]

    def walk(
        self,
        top: PathType,
        topdown: bool = True,
        onerror: Optional[Callable[..., None]] = None,
    ) -> Iterable[Tuple[PathType, List[PathType], List[PathType]]]:
        """Return an iterator that walks the contents of the given directory.

        Args:
            top: Path of directory to walk.
            topdown: Unused argument to conform to interface.
            onerror: Unused argument to conform to interface.

        Yields:
            An Iterable of Tuples, each of which contain the path of the current
            directory path, a list of directories inside the current directory
            and a list of files inside the current directory.
        """
        # TODO [ENG-153]: Additional params
        for current_dir, dir_names, file_names in self.filesystem.walk(
            path=top
        ):
            # Only the directory path gets the GCP prefix prepended.
            yield f"{GCP_PATH_PREFIX}{current_dir}", dir_names, file_names
filesystem: GCSFileSystem
property
readonly
The gcsfs filesystem to access this artifact store.
Returns:
Type | Description |
---|---|
GCSFileSystem |
The gcsfs filesystem to access this artifact store. |
copyfile(self, src, dst, overwrite=False)
Copy a file.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
src |
Union[bytes, str] |
The path to copy from. |
required |
dst |
Union[bytes, str] |
The path to copy to. |
required |
overwrite |
bool |
If a file already exists at the destination, this
method will overwrite it if overwrite=`True` and raise a FileExistsError otherwise. |
False |
Exceptions:
Type | Description |
---|---|
FileExistsError |
If a file already exists at the destination
and overwrite is not set to `True`. |
Source code in zenml/integrations/gcp/artifact_stores/gcp_artifact_store.py
def copyfile(
    self, src: PathType, dst: PathType, overwrite: bool = False
) -> None:
    """Copy a file from `src` to `dst`.

    Args:
        src: The path to copy from.
        dst: The path to copy to.
        overwrite: If a file already exists at the destination, this
            method will overwrite it if overwrite=`True` and
            raise a FileExistsError otherwise.

    Raises:
        FileExistsError: If a file already exists at the destination
            and overwrite is not set to `True`.
    """
    # The existence check is only performed when overwriting is disabled.
    if not overwrite:
        if self.filesystem.exists(dst):
            raise FileExistsError(
                f"Unable to copy to destination '{convert_to_str(dst)}', "
                "file already exists. Set `overwrite=True` to copy anyway."
            )

    # TODO [ENG-151]: Check if it works with overwrite=True or if we need to
    #  manually remove it first
    self.filesystem.copy(path1=src, path2=dst)
exists(self, path)
Check whether a path exists.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
path |
Union[bytes, str] |
The path to check. |
required |
Returns:
Type | Description |
---|---|
bool |
True if the path exists, False otherwise. |
Source code in zenml/integrations/gcp/artifact_stores/gcp_artifact_store.py
def exists(self, path: PathType) -> bool:
    """Report whether something exists at `path`.

    Args:
        path: The path to probe.

    Returns:
        The result of the filesystem's own `exists` check for `path`.
    """
    found = self.filesystem.exists(path=path)
    return found  # type: ignore[no-any-return]
glob(self, pattern)
Return all paths that match the given glob pattern.
The glob pattern may include: - '*' to match any number of characters - '?' to match a single character - '[...]' to match one of the characters inside the brackets - '**' as the full name of a path component to match to search in subdirectories of any depth (e.g. '/some_dir/**/some_file')
Parameters:
Name | Type | Description | Default |
---|---|---|---|
pattern |
Union[bytes, str] |
The glob pattern to match, see details above. |
required |
Returns:
Type | Description |
---|---|
List[Union[bytes, str]] |
A list of paths that match the given glob pattern. |
Source code in zenml/integrations/gcp/artifact_stores/gcp_artifact_store.py
def glob(self, pattern: PathType) -> List[PathType]:
    """Collect every path matching a glob pattern.

    Supported pattern syntax:
    - '*' to match any number of characters
    - '?' to match a single character
    - '[...]' to match one of the characters inside the brackets
    - '**' as the full name of a path component to match to search
      in subdirectories of any depth (e.g. '/some_dir/**/some_file)

    Args:
        pattern: The glob pattern to match, see details above.

    Returns:
        The matching paths, each one prefixed with `GCP_PATH_PREFIX`.
    """
    matches = self.filesystem.glob(path=pattern)
    # Each match is re-prefixed before it is handed back to the caller.
    return [f"{GCP_PATH_PREFIX}{match}" for match in matches]
isdir(self, path)
Check whether a path is a directory.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
path |
Union[bytes, str] |
The path to check. |
required |
Returns:
Type | Description |
---|---|
bool |
True if the path is a directory, False otherwise. |
Source code in zenml/integrations/gcp/artifact_stores/gcp_artifact_store.py
def isdir(self, path: PathType) -> bool:
    """Report whether `path` refers to a directory.

    Args:
        path: The path to probe.

    Returns:
        The result of the filesystem's own `isdir` check for `path`.
    """
    is_directory = self.filesystem.isdir(path=path)
    return is_directory  # type: ignore[no-any-return]
listdir(self, path)
Return a list of files in a directory.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
path |
Union[bytes, str] |
The path of the directory to list. |
required |
Returns:
Type | Description |
---|---|
List[Union[bytes, str]] |
A list of paths of files in the directory. |
Source code in zenml/integrations/gcp/artifact_stores/gcp_artifact_store.py
def listdir(self, path: PathType) -> List[PathType]:
    """List the entries of a directory by their base names.

    Args:
        path: The path of the directory to list.

    Returns:
        One name per entry reported by the filesystem, with the directory
        path and any leading slashes stripped off.
    """
    directory = convert_to_str(path)
    if directory.startswith(GCP_PATH_PREFIX):
        directory = directory[len(GCP_PATH_PREFIX) :]

    # Entry names are cut by the length of the un-prefixed directory path,
    # which leaves only the part below the listed directory.
    cut = len(directory)
    entries = self.filesystem.listdir(path=path)
    return [
        cast(str, entry["name"])[cut:].lstrip("/") for entry in entries
    ]
makedirs(self, path)
Create a directory at the given path.
If needed also create missing parent directories.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
path |
Union[bytes, str] |
The path of the directory to create. |
required |
Source code in zenml/integrations/gcp/artifact_stores/gcp_artifact_store.py
def makedirs(self, path: PathType) -> None:
    """Create the directory at `path`, including missing parents.

    The call is forwarded to the filesystem with `exist_ok=True`.

    Args:
        path: The path of the directory to create.
    """
    gcs = self.filesystem
    gcs.makedirs(path=path, exist_ok=True)
mkdir(self, path)
Create a directory at the given path.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
path |
Union[bytes, str] |
The path of the directory to create. |
required |
Source code in zenml/integrations/gcp/artifact_stores/gcp_artifact_store.py
def mkdir(self, path: PathType) -> None:
    """Create a single directory at `path`.

    The call is forwarded to the filesystem's `makedir` method.

    Args:
        path: The path of the directory to create.
    """
    gcs = self.filesystem
    gcs.makedir(path=path)
open(self, path, mode='r')
Open a file at the given path.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
path |
Union[bytes, str] |
Path of the file to open. |
required |
mode |
str |
Mode in which to open the file. Currently, only 'rb' and 'wb' to read and write binary files are supported. |
'r' |
Returns:
Type | Description |
---|---|
Any |
A file-like object that can be used to read or write to the file. |
Source code in zenml/integrations/gcp/artifact_stores/gcp_artifact_store.py
def open(self, path: PathType, mode: str = "r") -> Any:
    """Open the file at `path` through the GCS filesystem.

    Args:
        path: Path of the file to open.
        mode: Mode in which to open the file; it is passed to the
            filesystem unchanged. Documented as supporting only 'rb' and
            'wb' (binary read/write), although the default here is 'r'
            -- TODO confirm.

    Returns:
        The file-like object returned by the filesystem, usable for
        reading from or writing to the file.
    """
    handle = self.filesystem.open(path=path, mode=mode)
    return handle
remove(self, path)
Remove the file at the given path.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
path |
Union[bytes, str] |
The path of the file to remove. |
required |
Source code in zenml/integrations/gcp/artifact_stores/gcp_artifact_store.py
def remove(self, path: PathType) -> None:
    """Delete the single file at `path`.

    The call is forwarded to the filesystem's `rm_file` method.

    Args:
        path: The path of the file to remove.
    """
    gcs = self.filesystem
    gcs.rm_file(path=path)
rename(self, src, dst, overwrite=False)
Rename source file to destination file.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
src |
Union[bytes, str] |
The path of the file to rename. |
required |
dst |
Union[bytes, str] |
The path to rename the source file to. |
required |
overwrite |
bool |
If a file already exists at the destination, this method will overwrite it if overwrite=`True` and raise a FileExistsError otherwise. |
False |
Exceptions:
Type | Description |
---|---|
FileExistsError |
If a file already exists at the destination and overwrite is not set to `True`. |
Source code in zenml/integrations/gcp/artifact_stores/gcp_artifact_store.py
def rename(
    self, src: PathType, dst: PathType, overwrite: bool = False
) -> None:
    """Rename the file at `src` to `dst`.

    Args:
        src: Path of the file to rename.
        dst: Path the file is renamed to.
        overwrite: Whether an existing file at `dst` may be replaced. When
            this is falsy and `dst` already exists, a FileExistsError is
            raised instead of renaming.

    Raises:
        FileExistsError: If `dst` already exists and `overwrite` is not
            set to `True`.
    """
    # The existence check is skipped entirely when overwriting is allowed.
    if overwrite or not self.filesystem.exists(dst):
        # TODO [ENG-152]: Check if it works with overwrite=True or if we
        #  need to manually remove it first
        self.filesystem.rename(path1=src, path2=dst)
        return

    raise FileExistsError(
        f"Unable to rename file to '{convert_to_str(dst)}', "
        f"file already exists. Set `overwrite=True` to rename anyway."
    )
rmtree(self, path)
Remove the given directory.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
path |
Union[bytes, str] |
The path of the directory to remove. |
required |
Source code in zenml/integrations/gcp/artifact_stores/gcp_artifact_store.py
def rmtree(self, path: PathType) -> None:
    """Delete the directory at `path` together with its contents.

    The call is forwarded to the filesystem's `delete` method with
    `recursive=True`.

    Args:
        path: The path of the directory to remove.
    """
    gcs = self.filesystem
    gcs.delete(path=path, recursive=True)
stat(self, path)
Return stat info for the given path.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
path |
Union[bytes, str] |
the path to get stat info for. |
required |
Returns:
Type | Description |
---|---|
Dict[str, Any] |
A dictionary with the stat info. |
Source code in zenml/integrations/gcp/artifact_stores/gcp_artifact_store.py
def stat(self, path: PathType) -> Dict[str, Any]:
    """Fetch stat information for `path`.

    Args:
        path: The path to get stat info for.

    Returns:
        The stat dictionary reported by the filesystem for `path`.
    """
    info = self.filesystem.stat(path=path)
    return info  # type: ignore[no-any-return]
walk(self, top, topdown=True, onerror=None)
Return an iterator that walks the contents of the given directory.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
top |
Union[bytes, str] |
Path of directory to walk. |
required |
topdown |
bool |
Unused argument to conform to interface. |
True |
onerror |
Optional[Callable[..., NoneType]] |
Unused argument to conform to interface. |
None |
Yields:
Type | Description |
---|---|
Iterable[Tuple[Union[bytes, str], List[Union[bytes, str]], List[Union[bytes, str]]]] |
An Iterable of Tuples, each of which contain the path of the current directory path, a list of directories inside the current directory and a list of files inside the current directory. |
Source code in zenml/integrations/gcp/artifact_stores/gcp_artifact_store.py
def walk(
    self,
    top: PathType,
    topdown: bool = True,
    onerror: Optional[Callable[..., None]] = None,
) -> Iterable[Tuple[PathType, List[PathType], List[PathType]]]:
    """Walk the directory tree rooted at `top`.

    Args:
        top: Path of directory to walk.
        topdown: Unused argument to conform to interface.
        onerror: Unused argument to conform to interface.

    Yields:
        One tuple per visited directory: the directory path prefixed with
        `GCP_PATH_PREFIX`, the subdirectories inside it and the files
        inside it, the latter two exactly as reported by the filesystem.
    """
    # TODO [ENG-153]: Additional params
    for root, dir_names, file_names in self.filesystem.walk(path=top):
        yield f"{GCP_PATH_PREFIX}{root}", dir_names, file_names
constants
Constants for the VertexAI integration.
google_credentials_mixin
Implementation of the Google credentials mixin.
GoogleCredentialsMixin (BaseModel)
pydantic-model
Mixin for Google Cloud Platform credentials.
Attributes:
Name | Type | Description |
---|---|---|
service_account_path |
Optional[str] |
path to the service account credentials file to be used for authentication. If not provided, the default credentials will be used. |
Source code in zenml/integrations/gcp/google_credentials_mixin.py
class GoogleCredentialsMixin(BaseModel):
    """Mixin that supplies Google Cloud Platform credentials.

    Attributes:
        service_account_path: Location of the service account credentials
            file to authenticate with. When unset, the default credentials
            are used instead.
    """

    service_account_path: Optional[str] = None

    def _get_authentication(self) -> Tuple["Credentials", str]:
        """Resolve GCP credentials together with their project ID.

        A configured `service_account_path` is loaded with
        `load_credentials_from_file`; otherwise `default()` provides the
        credentials.

        Returns:
            A tuple containing the credentials and the project ID
            associated to the credentials.
        """
        if self.service_account_path:
            resolved = load_credentials_from_file(self.service_account_path)
        else:
            resolved = default()

        # Unpack explicitly so the result is always a (credentials,
        # project_id) pair.
        credentials, project_id = resolved
        return credentials, project_id
orchestrators
special
Initialization for the VertexAI orchestrator.
vertex_entrypoint_configuration
Implementation of the VertexAI entrypoint configuration.
VertexEntrypointConfiguration (StepEntrypointConfiguration)
Entrypoint configuration for running steps on Vertex AI Pipelines.
Source code in zenml/integrations/gcp/orchestrators/vertex_entrypoint_configuration.py
class VertexEntrypointConfiguration(StepEntrypointConfiguration):
    """Entrypoint configuration for steps executed on Vertex AI Pipelines."""

    @classmethod
    def get_custom_entrypoint_options(cls) -> Set[str]:
        """Name the entrypoint options specific to Vertex AI Pipelines.

        The single option, `VERTEX_JOB_ID_OPTION`, carries the job id of
        the Vertex AI Pipeline so that it can be read back during step
        execution through the `get_run_name` method.

        Returns:
            The set of custom entrypoint options.
        """
        vertex_options = {VERTEX_JOB_ID_OPTION}
        return vertex_options

    @classmethod
    def get_custom_entrypoint_arguments(
        cls, step: "BaseStep", *args: Any, **kwargs: Any
    ) -> List[str]:
        """Build the command line arguments for `VERTEX_JOB_ID_OPTION`.

        Args:
            step: The step to be executed.
            *args: Additional arguments.
            **kwargs: Additional keyword arguments; the value stored under
                the `VERTEX_JOB_ID_OPTION` key is looked up here.

        Returns:
            A two-element list: the `--`-prefixed option name followed by
            its value.
        """
        option_flag = f"--{VERTEX_JOB_ID_OPTION}"
        option_value = kwargs[VERTEX_JOB_ID_OPTION]
        return [option_flag, option_value]

    def get_run_name(self, pipeline_name: str) -> str:
        """Return the Vertex AI Pipeline job id as the run name.

        Args:
            pipeline_name: The name of the pipeline.

        Returns:
            The value stored under `VERTEX_JOB_ID_OPTION` in the
            entrypoint arguments.
        """
        run_name: str = self.entrypoint_args[VERTEX_JOB_ID_OPTION]
        return run_name
get_custom_entrypoint_arguments(step, *args, **kwargs)
classmethod
Sets the value for the VERTEX_JOB_ID_OPTION
argument.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
step |
BaseStep |
The step to be executed. |
required |
*args |
Any |
Additional arguments. |
() |
**kwargs |
Any |
Additional keyword arguments. |
{} |
Returns:
Type | Description |
---|---|
List[str] |
A list of arguments for the entrypoint. |
Source code in zenml/integrations/gcp/orchestrators/vertex_entrypoint_configuration.py
@classmethod
def get_custom_entrypoint_arguments(
    cls, step: "BaseStep", *args: Any, **kwargs: Any
) -> List[str]:
    """Build the command line arguments for `VERTEX_JOB_ID_OPTION`.

    Args:
        step: The step to be executed.
        *args: Additional arguments.
        **kwargs: Additional keyword arguments; the value stored under the
            `VERTEX_JOB_ID_OPTION` key is looked up here.

    Returns:
        A two-element list: the `--`-prefixed option name followed by its
        value.
    """
    option_flag = f"--{VERTEX_JOB_ID_OPTION}"
    option_value = kwargs[VERTEX_JOB_ID_OPTION]
    return [option_flag, option_value]
get_custom_entrypoint_options()
classmethod
Vertex AI Pipelines specific entrypoint options.
The argument VERTEX_JOB_ID_OPTION
allows to specify the job id of the
Vertex AI Pipeline and get it in the execution of the step, via the get_run_name
method.
Returns:
Type | Description |
---|---|
Set[str] |
The set of custom entrypoint options. |
Source code in zenml/integrations/gcp/orchestrators/vertex_entrypoint_configuration.py
@classmethod
def get_custom_entrypoint_options(cls) -> Set[str]:
    """Name the entrypoint options specific to Vertex AI Pipelines.

    The single option, `VERTEX_JOB_ID_OPTION`, carries the job id of the
    Vertex AI Pipeline so that it can be read back during step execution
    through the `get_run_name` method.

    Returns:
        The set of custom entrypoint options.
    """
    vertex_options = {VERTEX_JOB_ID_OPTION}
    return vertex_options
get_run_name(self, pipeline_name)
Returns the Vertex AI Pipeline job id.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
pipeline_name |
str |
The name of the pipeline. |
required |
Returns:
Type | Description |
---|---|
str |
The Vertex AI Pipeline job id. |
Source code in zenml/integrations/gcp/orchestrators/vertex_entrypoint_configuration.py
def get_run_name(self, pipeline_name: str) -> str:
    """Return the Vertex AI Pipeline job id as the run name.

    Args:
        pipeline_name: The name of the pipeline.

    Returns:
        The value stored under `VERTEX_JOB_ID_OPTION` in the entrypoint
        arguments.
    """
    run_name: str = self.entrypoint_args[VERTEX_JOB_ID_OPTION]
    return run_name
vertex_orchestrator
Implementation of the VertexAI orchestrator.
VertexOrchestrator (BaseOrchestrator, GoogleCredentialsMixin)
pydantic-model
Orchestrator responsible for running pipelines on Vertex AI.
Attributes:
Name | Type | Description |
---|---|---|
custom_docker_base_image_name |
Optional[str] |
Name of the Docker image that should be used as the base for the image that will be used to execute each of the steps. If no custom base image is given, a basic image of the active ZenML version will be used. Note: This image needs to have ZenML installed, otherwise the pipeline execution will fail. For that reason, you might want to extend the ZenML Docker images found here: https://hub.docker.com/r/zenmldocker/zenml/ |
project |
Optional[str] |
GCP project name. If `None`, the project will be inferred from the environment. |
location |
str |
Name of GCP region where the pipeline job will be executed. Vertex AI Pipelines is available in the following regions: https://cloud.google.com/vertex-ai/docs/general/locations#feature-availability |
pipeline_root |
Optional[str] |
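a Cloud Storage URI that will be used by the Vertex AI Pipelines. If not provided but the artifact store in the stack used to execute the pipeline is a `zenml.integrations.gcp.artifact_stores.GCPArtifactStore`, then a subdirectory of the artifact store will be used. |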
encryption_spec_key_name |
Optional[str] |
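The Cloud KMS resource identifier of the customer managed encryption key used to protect the job. Has the form: `projects/<PRJCT>/locations/<REGION>/keyRings/<KR>/cryptoKeys/<KEY>`. The key needs to be in the same region as where the compute resource is created. |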
workload_service_account |
Optional[str] |
the service account for workload run-as account. Users submitting jobs must have act-as permission on this run-as account. If not provided, the default service account will be used. |
network |
Optional[str] |
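the full name of the Compute Engine Network to which the job should be peered. For example, `projects/12345/global/networks/myVPC`. If not provided, the job will not be peered with any network. |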
synchronous |
bool |
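If `True`, running a pipeline using this orchestrator will block until all steps finished running on Vertex AI Pipelines service. |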
cpu_limit |
Optional[str] |
The maximum CPU limit for this operator. This string value can be a number (integer value for number of CPUs) as string, or a number followed by "m", which means 1/1000. You can specify at most 96 CPUs. (see. https://cloud.google.com/vertex-ai/docs/pipelines/machine-types) |
memory_limit |
Optional[str] |
The maximum memory limit for this operator. This string value can be a number, or a number followed by "K" (kilobyte), "M" (megabyte), or "G" (gigabyte). At most 624GB is supported. |
node_selector_constraint |
Optional[Tuple[str, str]] |
Each constraint is a key-value pair label. For the container to be eligible to run on a node, the node must have each of the constraints appear as labels. For example, a GPU type can be provided by one of the following tuples: - ("cloud.google.com/gke-accelerator", "NVIDIA_TESLA_A100") - ("cloud.google.com/gke-accelerator", "NVIDIA_TESLA_K80") - ("cloud.google.com/gke-accelerator", "NVIDIA_TESLA_P4") - ("cloud.google.com/gke-accelerator", "NVIDIA_TESLA_P100") - ("cloud.google.com/gke-accelerator", "NVIDIA_TESLA_T4") - ("cloud.google.com/gke-accelerator", "NVIDIA_TESLA_V100") Hint: the selected region (location) must provide the requested accelerator (see https://cloud.google.com/compute/docs/gpus/gpu-regions-zones). |
gpu_limit |
Optional[int] |
The GPU limit (positive number) for the operator. For more information about GPU resources, see: https://cloud.google.com/vertex-ai/docs/training/configure-compute#specifying_gpus |
Source code in zenml/integrations/gcp/orchestrators/vertex_orchestrator.py
class VertexOrchestrator(BaseOrchestrator, GoogleCredentialsMixin):
"""Orchestrator responsible for running pipelines on Vertex AI.
Attributes:
custom_docker_base_image_name: Name of the Docker image that should be
used as the base for the image that will be used to execute each of
the steps. If no custom base image is given, a basic image of the
active ZenML version will be used. **Note**: This image needs to
have ZenML installed, otherwise the pipeline execution will fail.
For that reason, you might want to extend the ZenML Docker images found
here: https://hub.docker.com/r/zenmldocker/zenml/
project: GCP project name. If `None`, the project will be inferred from
the environment.
location: Name of GCP region where the pipeline job will be executed.
Vertex AI Pipelines is available in the following regions:
https://cloud.google.com/vertex-ai/docs/general/locations#feature
-availability
pipeline_root: a Cloud Storage URI that will be used by the Vertex AI
Pipelines.
If not provided but the artifact store in the stack used to execute
the pipeline is a
`zenml.integrations.gcp.artifact_stores.GCPArtifactStore`,
then a subdirectory of the artifact store will be used.
encryption_spec_key_name: The Cloud KMS resource identifier of the
customer
managed encryption key used to protect the job. Has the form:
`projects/<PRJCT>/locations/<REGION>/keyRings/<KR>/cryptoKeys/<KEY>`
. The key needs to be in the same region as where the compute
resource is created.
workload_service_account: the service account for workload run-as
account. Users submitting jobs must have act-as permission on this
run-as account.
If not provided, the default service account will be used.
network: the full name of the Compute Engine Network to which the job
should
be peered. For example, `projects/12345/global/networks/myVPC`
If not provided, the job will not be peered with any network.
synchronous: If `True`, running a pipeline using this orchestrator will
block until all steps finished running on Vertex AI Pipelines
service.
cpu_limit: The maximum CPU limit for this operator. This string value
can be a number (integer value for number of CPUs) as string,
or a number followed by "m", which means 1/1000. You can specify
at most 96 CPUs.
(see. https://cloud.google.com/vertex-ai/docs/pipelines/machine-types)
memory_limit: The maximum memory limit for this operator. This string
value can be a number, or a number followed by "K" (kilobyte),
"M" (megabyte), or "G" (gigabyte). At most 624GB is supported.
node_selector_constraint: Each constraint is a key-value pair label.
For the container to be eligible to run on a node, the node must have
each of the constraints appeared as labels.
For example a GPU type can be providing by one of the following tuples:
- ("cloud.google.com/gke-accelerator", "NVIDIA_TESLA_A100")
- ("cloud.google.com/gke-accelerator", "NVIDIA_TESLA_K80")
- ("cloud.google.com/gke-accelerator", "NVIDIA_TESLA_P4")
- ("cloud.google.com/gke-accelerator", "NVIDIA_TESLA_P100")
- ("cloud.google.com/gke-accelerator", "NVIDIA_TESLA_T4")
- ("cloud.google.com/gke-accelerator", "NVIDIA_TESLA_V100")
Hint: the selected region (location) must provide the requested accelerator
(see https://cloud.google.com/compute/docs/gpus/gpu-regions-zones).
gpu_limit: The GPU limit (positive number) for the operator.
For more information about GPU resources, see:
https://cloud.google.com/vertex-ai/docs/training/configure-compute#specifying_gpus
"""
custom_docker_base_image_name: Optional[str] = None
project: Optional[str] = None
location: str
pipeline_root: Optional[str] = None
labels: Dict[str, str] = {}
encryption_spec_key_name: Optional[str] = None
workload_service_account: Optional[str] = None
network: Optional[str] = None
synchronous: bool = False
cpu_limit: Optional[str] = None
memory_limit: Optional[str] = None
node_selector_constraint: Optional[Tuple[str, str]] = None
gpu_limit: Optional[int] = None
_pipeline_root: str
FLAVOR: ClassVar[str] = GCP_VERTEX_ORCHESTRATOR_FLAVOR
@property
def validator(self) -> Optional[StackValidator]:
    """Build the validator that checks a stack against this orchestrator.

    The returned validator requires a container registry component and
    runs a custom check that rejects local components, as well as a missing
    `pipeline_root` when the artifact store is not a GCP artifact store.

    Returns:
        A StackValidator instance.
    """

    def _validate_stack_requirements(stack: "Stack") -> Tuple[bool, str]:
        """Check for local components and for a usable pipeline root.

        Args:
            stack: The stack to validate.

        Returns:
            A tuple of (is_valid, error_message).
        """
        # A local container registry is rejected first.
        registry = stack.container_registry
        if registry and registry.is_local:
            return False, (
                f"The Vertex orchestrator does not support local "
                f"container registries. You should replace the component '"
                f"{registry.name}' "
                f"{registry.TYPE.value} to a remote one."
            )

        # The first component that exposes a local path fails validation.
        for component in stack.components.values():
            if component.local_path:
                return False, (
                    f"The '{component.name}' {component.TYPE.value} is a "
                    f"local stack component. The Vertex AI Pipelines "
                    f"orchestrator requires that all the components in the "
                    f"stack used to execute the pipeline have to be not local, "
                    f"because there is no way for Vertex to connect to your "
                    f"local machine. You should use a flavor of "
                    f"{component.TYPE.value} other than '"
                    f"{component.FLAVOR}'."
                )

        # Without an explicit `pipeline_root`, the root can only be derived
        # from a GCP artifact store; any other flavor is an error.
        if not (
            self.pipeline_root
            or stack.artifact_store.FLAVOR == GCP_ARTIFACT_STORE_FLAVOR
        ):
            return False, (
                f"The attribute `pipeline_root` has not been set and it "
                f"cannot be generated using the path of the artifact store "
                f"because it is not a "
                f"`zenml.integrations.gcp.artifact_store.GCPArtifactStore`."
                f" To solve this issue, set the `pipeline_root` attribute "
                f"manually executing the following command: "
                f"`zenml orchestrator update {stack.orchestrator.name} "
                f'--pipeline_root="<Cloud Storage URI>"`.'
            )

        return True, ""

    return StackValidator(
        required_components={StackComponentType.CONTAINER_REGISTRY},
        custom_validation_function=_validate_stack_requirements,
    )
def get_docker_image_name(self, pipeline_name: str) -> str:
    """Compose the Docker image name used for a pipeline.

    Args:
        pipeline_name: The name of the pipeline.

    Returns:
        `zenml-vertex:<pipeline_name>`, prefixed with the URI of the active
        stack's container registry when the stack has one.
    """
    image_tag = f"zenml-vertex:{pipeline_name}"
    registry = Repository().active_stack.container_registry
    if not registry:
        return image_tag

    # Trailing slashes are trimmed so the joined name has a single "/".
    registry_uri = registry.uri.rstrip("/")
    return f"{registry_uri}/{image_tag}"
@property
def root_directory(self) -> str:
    """Locate the root directory holding this orchestrator's files.

    Returns:
        The `vertex/<uuid>` directory below the global config directory,
        where `<uuid>` is this orchestrator's `uuid`.
    """
    config_root = get_global_config_directory()
    return os.path.join(config_root, "vertex", str(self.uuid))
@property
def pipeline_directory(self) -> str:
    """Locate the directory in which pipeline files are stored.

    Returns:
        The `pipelines` directory below `root_directory`.
    """
    orchestrator_root = self.root_directory
    return os.path.join(orchestrator_root, "pipelines")
def prepare_pipeline_deployment(
    self,
    pipeline: "BasePipeline",
    stack: "Stack",
    runtime_configuration: "RuntimeConfiguration",
) -> None:
    """Build the pipeline's Docker image and push it to the registry.

    The image is built from the source root with the combined stack and
    pipeline requirements, then pushed through the active stack's
    container registry.

    Args:
        pipeline: The pipeline to be deployed.
        stack: The stack that will be used to deploy the pipeline.
        runtime_configuration: The runtime configuration for the pipeline.

    Raises:
        RuntimeError: If the container registry is missing.
    """
    from zenml.utils import docker_utils

    registry = Repository().active_stack.container_registry
    if not registry:
        raise RuntimeError("Missing container registry")

    docker_image = self.get_docker_image_name(pipeline.name)

    # Union of what the stack and the pipeline each ask to have installed.
    all_requirements = set(stack.requirements()).union(pipeline.requirements)
    logger.debug(
        "Vertex AI Pipelines service docker container requirements %s",
        all_requirements,
    )

    docker_utils.build_docker_image(
        build_context_path=get_source_root_path(),
        image_name=docker_image,
        dockerignore_path=pipeline.dockerignore_file,
        requirements=all_requirements,
        base_image=self.custom_docker_base_image_name,
    )
    registry.push_image(docker_image)
def _configure_container_resources(
    self,
    container_op: dsl.ContainerOp,
    resource_configuration: "ResourceConfiguration",
) -> None:
    """Apply CPU, memory, node selector and GPU settings to a container.

    Values from `resource_configuration` are used when they are truthy;
    otherwise the corresponding orchestrator attribute is used. A setting
    that ends up as `None` is not applied at all.

    Args:
        container_op: The kubeflow container operation to configure.
        resource_configuration: The resource configuration to use for this
            container.
    """
    cpus = resource_configuration.cpu_count or self.cpu_limit
    if cpus is not None:
        container_op = container_op.set_cpu_limit(str(cpus))

    if resource_configuration.memory:
        # The last character of the configured memory string is dropped --
        # presumably turning a unit such as "GB" into "G"; confirm against
        # ResourceConfiguration.
        memory = resource_configuration.memory[:-1]
    else:
        memory = self.memory_limit
    if memory is not None:
        container_op = container_op.set_memory_limit(memory)

    # The node selector only exists on the orchestrator configuration.
    if self.node_selector_constraint is not None:
        container_op = container_op.add_node_selector_constraint(
            label_name=self.node_selector_constraint[0],
            value=self.node_selector_constraint[1],
        )

    gpus = resource_configuration.gpu_count or self.gpu_limit
    if gpus is not None:
        container_op = container_op.set_gpu_limit(gpus)
def prepare_or_run_pipeline(
self,
sorted_steps: List["BaseStep"],
pipeline: "BasePipeline",
pb2_pipeline: "Pb2Pipeline",
stack: "Stack",
runtime_configuration: "RuntimeConfiguration",
) -> Any:
"""Creates a KFP JSON pipeline and submits it to Vertex AI Pipelines.
# noqa: DAR402
The JSON file is an intermediary representation of the pipeline which is
then deployed to the Vertex AI Pipelines service.
How it works:
-------------
Before this method is called the `prepare_pipeline_deployment()` method
builds a Docker image that contains the code for the pipeline, all steps,
and the context around these files.
Based on this Docker image a callable is created which builds
container_ops for each step (`_construct_kfp_pipeline`). The function
`kfp.components.load_component_from_text` is used to create the
`ContainerOp`, because using the `dsl.ContainerOp` class directly is
deprecated when using the Kubeflow SDK v2. The step entrypoint command
with the entrypoint arguments is the command that will be executed by
the container created using the previously created Docker image.
This callable is then compiled into a JSON file that is used as the
intermediary representation of the Kubeflow pipeline.
This file is then submitted to the Vertex AI Pipelines service for
execution.
Args:
sorted_steps: List of sorted steps.
pipeline: ZenML pipeline instance.
pb2_pipeline: Protobuf Pipeline instance.
stack: The stack the pipeline was run on.
runtime_configuration: The Runtime configuration of the current run.
Raises:
ValueError: If the attribute `pipeline_root` is not set and it
cannot be generated using the path of the artifact store in the
stack because it is not a
`zenml.integrations.gcp.artifact_store.GCPArtifactStore`.
"""
# If the `pipeline_root` has not been defined in the orchestrator
# configuration, derive it from the path of the stack's artifact store:
# `<artifact store path>/vertex_pipeline_root/<pipeline>/<run name>`.
# Otherwise use the configured value as is.
if not self.pipeline_root:
artifact_store = stack.artifact_store
self._pipeline_root = f"{artifact_store.path.rstrip('/')}/vertex_pipeline_root/{pipeline.name}/{runtime_configuration.run_name}"
logger.info(
"The attribute `pipeline_root` has not been set in the "
"orchestrator configuration. One has been generated "
"automatically based on the path of the `GCPArtifactStore` "
"artifact store in the stack used to execute the pipeline. "
"The generated `pipeline_root` is `%s`.",
self._pipeline_root,
)
else:
self._pipeline_root = self.pipeline_root
# Resolve the name of the Docker image that will be used to run the steps
# of the pipeline, preferring the value returned by `get_image_digest`
# when it is truthy.
image_name = self.get_docker_image_name(pipeline.name)
image_name = get_image_digest(image_name) or image_name
def _construct_kfp_pipeline() -> None:
"""Create a `ContainerOp` for each step.
Each `ContainerOp` is given the name of the Docker image and the
entrypoint command that runs the step.
Additionally, each `ContainerOp` is ordered after the container ops of
its upstream steps.
If this callable is passed to the `compile()` method of
`KFPV2Compiler` all `dsl.ContainerOp` instances will be
automatically added to a singular `dsl.Pipeline` instance.
"""
# Container ops created so far, keyed by step name, so that later steps
# can look up their upstream dependencies.
step_name_to_container_op: Dict[str, dsl.ContainerOp] = {}
for step in sorted_steps:
# The command will be needed to eventually call the python step
# within the docker container
command = VertexEntrypointConfiguration.get_entrypoint_command()
# The arguments are passed to configure the entrypoint of the
# docker container when the step is called.
arguments = VertexEntrypointConfiguration.get_entrypoint_arguments(
step=step,
pb2_pipeline=pb2_pipeline,
**{VERTEX_JOB_ID_OPTION: dslv2.PIPELINE_JOB_ID_PLACEHOLDER},
)
# Create the `ContainerOp` for the step. Using the
# `dsl.ContainerOp`
# class directly is deprecated when using the Kubeflow SDK v2.
# NOTE(review): the component YAML below carries no nesting indentation
# in this listing; YAML structure is whitespace-sensitive, so confirm
# the exact text against the original source file.
container_op = kfp.components.load_component_from_text(
f"""
name: {step.name}
implementation:
container:
image: {image_name}
command: {command + arguments}"""
)()
# Set upstream tasks as a dependency of the current step
upstream_step_names = self.get_upstream_step_names(
step=step, pb2_pipeline=pb2_pipeline
)
for upstream_step_name in upstream_step_names:
upstream_container_op = step_name_to_container_op[
upstream_step_name
]
container_op.after(upstream_container_op)
self._configure_container_resources(
container_op=container_op,
resource_configuration=step.resource_configuration,
)
step_name_to_container_op[step.name] = container_op
# Save the generated pipeline to a file named after the run inside the
# orchestrator's pipeline directory; the run name must be set for that.
assert runtime_configuration.run_name
fileio.makedirs(self.pipeline_directory)
pipeline_file_path = os.path.join(
self.pipeline_directory,
f"{runtime_configuration.run_name}.json",
)
# Compile the pipeline using the Kubeflow SDK V2 compiler, which generates
# a JSON representation of the pipeline that can later be uploaded to the
# Vertex AI Pipelines service.
logger.debug(
"Compiling pipeline using Kubeflow SDK V2 compiler and saving it "
"to `%s`",
pipeline_file_path,
)
KFPV2Compiler().compile(
pipeline_func=_construct_kfp_pipeline,
package_path=pipeline_file_path,
pipeline_name=_clean_pipeline_name(pipeline.name),
)
# Using the Google Cloud AIPlatform client, upload and execute the
# pipeline on the Vertex AI Pipelines service.
self._upload_and_run_pipeline(
pipeline_name=pipeline.name,
pipeline_file_path=pipeline_file_path,
runtime_configuration=runtime_configuration,
enable_cache=pipeline.enable_cache,
)
def _upload_and_run_pipeline(
    self,
    pipeline_name: str,
    pipeline_file_path: str,
    runtime_configuration: "RuntimeConfiguration",
    enable_cache: bool,
) -> None:
    """Uploads and runs the pipeline on the Vertex AI Pipelines service.

    Args:
        pipeline_name: Name of the pipeline.
        pipeline_file_path: Path of the JSON file containing the compiled
            Kubeflow pipeline (compiled with Kubeflow SDK v2).
        runtime_configuration: Runtime configuration of the pipeline run.
        enable_cache: Whether caching is enabled for this pipeline run.
    """
    # The run name is normalized with `_clean_pipeline_name` before it is
    # used as the job id of the Vertex AI Pipelines job.
    assert runtime_configuration.run_name
    job_id = _clean_pipeline_name(runtime_configuration.run_name)

    # Warn the user that scheduling is not available with the Vertex
    # orchestrator.
    if runtime_configuration.schedule:
        logger.warning(
            "Pipeline scheduling configuration was provided, but Vertex "
            "AI Pipelines "
            "do not have capabilities for scheduling yet."
        )

    # Get the credentials that are used to create the Vertex AI Pipelines
    # job, together with the project they were issued for.
    credentials, project_id = self._get_authentication()
    if self.project and self.project != project_id:
        logger.warning(
            "Authenticated with project `%s`, but this orchestrator is "
            "configured to use the project `%s`.",
            project_id,
            self.project,
        )

    # If the project was set in the configuration, use it. Otherwise, fall
    # back to the project that was used to authenticate.
    project_id = self.project or project_id

    # Instantiate the Vertex AI Pipelines job. The resolved `project_id` is
    # passed (not `self.project`) so that the fallback above takes effect
    # when no project is configured on the orchestrator.
    run = aiplatform.PipelineJob(
        display_name=pipeline_name,
        template_path=pipeline_file_path,
        job_id=job_id,
        pipeline_root=self._pipeline_root,
        parameter_values=None,
        enable_caching=enable_cache,
        encryption_spec_key_name=self.encryption_spec_key_name,
        labels=self.labels,
        credentials=credentials,
        project=project_id,
        location=self.location,
    )

    logger.info(
        "Submitting pipeline job with job_id `%s` to Vertex AI Pipelines "
        "service.",
        job_id,
    )

    # Submit the job to the Vertex AI Pipelines service. Failures are
    # logged rather than re-raised.
    try:
        if self.workload_service_account:
            logger.info(
                "The Vertex AI Pipelines job workload will be executed "
                "using `%s` "
                "service account.",
                self.workload_service_account,
            )

        if self.network:
            logger.info(
                "The Vertex AI Pipelines job will be peered with `%s` "
                "network.",
                self.network,
            )

        run.submit(
            service_account=self.workload_service_account,
            network=self.network,
        )
        logger.info(
            "View the Vertex AI Pipelines job at %s", run._dashboard_uri()
        )

        # Only block until the job finishes when configured to do so.
        if self.synchronous:
            logger.info(
                "Waiting for the Vertex AI Pipelines job to finish..."
            )
            run.wait()

    except google_exceptions.ClientError as e:
        logger.warning(
            "Failed to create the Vertex AI Pipelines job: %s", e
        )

    except RuntimeError as e:
        logger.error(
            "The Vertex AI Pipelines job execution has failed: %s", e
        )
pipeline_directory: str
property
readonly
Returns path to directory where kubeflow pipelines files are stored.
Returns:
Type | Description |
---|---|
str |
Path to the pipeline directory. |
root_directory: str
property
readonly
Returns path to the root directory for files for this orchestrator.
Returns:
Type | Description |
---|---|
str |
The path to the root directory for all files concerning this orchestrator. |
validator: Optional[zenml.stack.stack_validator.StackValidator]
property
readonly
Validates that the stack contains a container registry.
Also validates that the artifact store and metadata store used are not local.
Returns:
Type | Description |
---|---|
Optional[zenml.stack.stack_validator.StackValidator] |
A StackValidator instance. |
get_docker_image_name(self, pipeline_name)
Returns the full docker image name including registry and tag.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
pipeline_name |
str |
The name of the pipeline. |
required |
Returns:
Type | Description |
---|---|
str |
The full docker image name including registry and tag. |
Source code in zenml/integrations/gcp/orchestrators/vertex_orchestrator.py
def get_docker_image_name(self, pipeline_name: str) -> str:
    """Build the full docker image name for a pipeline.

    Args:
        pipeline_name: The name of the pipeline; it is used as the image
            tag.

    Returns:
        The image name `zenml-vertex:<pipeline_name>`, prefixed with the
        URI of the active stack's container registry when one is set.
    """
    image = f"zenml-vertex:{pipeline_name}"
    registry = Repository().active_stack.container_registry

    # Without a registry the bare image name is all we can return.
    if not registry:
        return image

    # Drop trailing slashes so the joined name has exactly one separator.
    return "/".join((registry.uri.rstrip("/"), image))
prepare_or_run_pipeline(self, sorted_steps, pipeline, pb2_pipeline, stack, runtime_configuration)
Creates a KFP JSON pipeline.
noqa: DAR402
This is an intermediary representation of the pipeline which is then deployed to Vertex AI Pipelines service.
How it works:
Before this method is called the prepare_pipeline_deployment()
method
builds a Docker image that contains the code for the pipeline, all steps
and the context around these files.
Based on this Docker image a callable is created which builds
container_ops for each step (_construct_kfp_pipeline
). The function
kfp.components.load_component_from_text
is used to create the
ContainerOp
, because using the dsl.ContainerOp
class directly is
deprecated when using the Kubeflow SDK v2. The step entrypoint command
with the entrypoint arguments is the command that will be executed by
the container created using the previously created Docker image.
This callable is then compiled into a JSON file that is used as the intermediary representation of the Kubeflow pipeline.
This file then is submitted to the Vertex AI Pipelines service for execution.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
sorted_steps |
List[BaseStep] |
List of sorted steps. |
required |
pipeline |
BasePipeline |
Zenml Pipeline instance. |
required |
pb2_pipeline |
Pb2Pipeline |
Protobuf Pipeline instance. |
required |
stack |
Stack |
The stack the pipeline was run on. |
required |
runtime_configuration |
RuntimeConfiguration |
The Runtime configuration of the current run. |
required |
Exceptions:
Type | Description |
---|---|
ValueError |
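If the attribute `pipeline_root` is not set and it cannot be generated using the path of the artifact store in the stack because it is not a `zenml.integrations.gcp.artifact_store.GCPArtifactStore`. |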
Source code in zenml/integrations/gcp/orchestrators/vertex_orchestrator.py
def prepare_or_run_pipeline(
self,
sorted_steps: List["BaseStep"],
pipeline: "BasePipeline",
pb2_pipeline: "Pb2Pipeline",
stack: "Stack",
runtime_configuration: "RuntimeConfiguration",
) -> Any:
"""Creates a KFP JSON pipeline and submits it to Vertex AI Pipelines.
# noqa: DAR402
This is an intermediary representation of the pipeline which is then
deployed to the Vertex AI Pipelines service.
How it works:
-------------
Before this method is called the `prepare_pipeline_deployment()` method
builds a Docker image that contains the code for the pipeline, all steps
and the context around these files.
Based on this Docker image a callable is created which builds
container_ops for each step (`_construct_kfp_pipeline`). The function
`kfp.components.load_component_from_text` is used to create the
`ContainerOp`, because using the `dsl.ContainerOp` class directly is
deprecated when using the Kubeflow SDK v2. The step entrypoint command
with the entrypoint arguments is the command that will be executed by
the container created using the previously created Docker image.
This callable is then compiled into a JSON file that is used as the
intermediary representation of the Kubeflow pipeline.
This file is then submitted to the Vertex AI Pipelines service for
execution.
Args:
sorted_steps: List of sorted steps.
pipeline: Zenml Pipeline instance.
pb2_pipeline: Protobuf Pipeline instance.
stack: The stack the pipeline was run on.
runtime_configuration: The Runtime configuration of the current run.
Raises:
ValueError: If the attribute `pipeline_root` is not set and it
cannot be generated using the path of the artifact store in the
stack because it is not a
`zenml.integrations.gcp.artifact_store.GCPArtifactStore`.
"""
# If the `pipeline_root` has not been defined in the orchestrator
# configuration, derive it from the path of the stack's artifact store.
# NOTE(review): the artifact store type is not checked here; presumably
# the `ValueError` documented above is raised by validation elsewhere
# (hence the `noqa: DAR402`) - confirm.
if not self.pipeline_root:
artifact_store = stack.artifact_store
self._pipeline_root = f"{artifact_store.path.rstrip('/')}/vertex_pipeline_root/{pipeline.name}/{runtime_configuration.run_name}"
logger.info(
"The attribute `pipeline_root` has not been set in the "
"orchestrator configuration. One has been generated "
"automatically based on the path of the `GCPArtifactStore` "
"artifact store in the stack used to execute the pipeline. "
"The generated `pipeline_root` is `%s`.",
self._pipeline_root,
)
else:
self._pipeline_root = self.pipeline_root
# Resolve the name of the Docker image that will be used to run the
# steps of the pipeline, preferring its digest and falling back to the
# plain image name when `get_image_digest` returns nothing.
image_name = self.get_docker_image_name(pipeline.name)
image_name = get_image_digest(image_name) or image_name
def _construct_kfp_pipeline() -> None:
"""Create a `ContainerOp` for each step.
This should contain the name of the Docker image and configures the
entrypoint of the Docker image to run the step.
Additionally, this makes each `ContainerOp` run after the
`ContainerOp` instances of its upstream steps.
If this callable is passed to the `compile()` method of
`KFPV2Compiler` all `dsl.ContainerOp` instances will be
automatically added to a singular `dsl.Pipeline` instance.
"""
# Maps each processed step name to its container op so that later
# steps can look up their upstream dependencies.
step_name_to_container_op: Dict[str, dsl.ContainerOp] = {}
for step in sorted_steps:
# The command will be needed to eventually call the python step
# within the docker container
command = VertexEntrypointConfiguration.get_entrypoint_command()
# The arguments are passed to configure the entrypoint of the
# docker container when the step is called.
arguments = VertexEntrypointConfiguration.get_entrypoint_arguments(
step=step,
pb2_pipeline=pb2_pipeline,
**{VERTEX_JOB_ID_OPTION: dslv2.PIPELINE_JOB_ID_PLACEHOLDER},
)
# Create the `ContainerOp` for the step from a component
# specification, because using the `dsl.ContainerOp` class directly
# is deprecated when using the Kubeflow SDK v2.
container_op = kfp.components.load_component_from_text(
f"""
name: {step.name}
implementation:
container:
image: {image_name}
command: {command + arguments}"""
)()
# Set upstream tasks as a dependency of the current step. The lookup
# below requires every upstream step to have been processed already.
upstream_step_names = self.get_upstream_step_names(
step=step, pb2_pipeline=pb2_pipeline
)
for upstream_step_name in upstream_step_names:
upstream_container_op = step_name_to_container_op[
upstream_step_name
]
container_op.after(upstream_container_op)
self._configure_container_resources(
container_op=container_op,
resource_configuration=step.resource_configuration,
)
step_name_to_container_op[step.name] = container_op
# Save the generated pipeline to a file named after the run.
assert runtime_configuration.run_name
fileio.makedirs(self.pipeline_directory)
pipeline_file_path = os.path.join(
self.pipeline_directory,
f"{runtime_configuration.run_name}.json",
)
# Compile the pipeline using the Kubeflow SDK V2 compiler, which
# generates a JSON representation of the pipeline that can later be
# uploaded to the Vertex AI Pipelines service.
logger.debug(
"Compiling pipeline using Kubeflow SDK V2 compiler and saving it "
"to `%s`",
pipeline_file_path,
)
KFPV2Compiler().compile(
pipeline_func=_construct_kfp_pipeline,
package_path=pipeline_file_path,
pipeline_name=_clean_pipeline_name(pipeline.name),
)
# Using the Google Cloud AIPlatform client, upload and execute the
# pipeline on the Vertex AI Pipelines service.
self._upload_and_run_pipeline(
pipeline_name=pipeline.name,
pipeline_file_path=pipeline_file_path,
runtime_configuration=runtime_configuration,
enable_cache=pipeline.enable_cache,
)
prepare_pipeline_deployment(self, pipeline, stack, runtime_configuration)
Build a Docker image for the current environment.
This uploads it to a container registry if configured.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
pipeline |
BasePipeline |
The pipeline to be deployed. |
required |
stack |
Stack |
The stack that will be used to deploy the pipeline. |
required |
runtime_configuration |
RuntimeConfiguration |
The runtime configuration for the pipeline. |
required |
Exceptions:
Type | Description |
---|---|
RuntimeError |
If the container registry is missing. |
Source code in zenml/integrations/gcp/orchestrators/vertex_orchestrator.py
def prepare_pipeline_deployment(
    self,
    pipeline: "BasePipeline",
    stack: "Stack",
    runtime_configuration: "RuntimeConfiguration",
) -> None:
    """Build the pipeline's Docker image and push it to the registry.

    The image is built from the source root and pushed through the
    container registry of the active stack.

    Args:
        pipeline: The pipeline to be deployed.
        stack: The stack that will be used to deploy the pipeline.
        runtime_configuration: The runtime configuration for the pipeline.

    Raises:
        RuntimeError: If the container registry is missing.
    """
    from zenml.utils import docker_utils

    registry = Repository().active_stack.container_registry
    if not registry:
        raise RuntimeError("Missing container registry")

    target_image = self.get_docker_image_name(pipeline.name)

    # Merge stack-level and pipeline-level requirements without duplicates.
    all_requirements = set(stack.requirements()).union(pipeline.requirements)
    logger.debug(
        "Vertex AI Pipelines service docker container requirements %s",
        all_requirements,
    )

    docker_utils.build_docker_image(
        build_context_path=get_source_root_path(),
        image_name=target_image,
        dockerignore_path=pipeline.dockerignore_file,
        requirements=all_requirements,
        base_image=self.custom_docker_base_image_name,
    )
    registry.push_image(target_image)
secrets_manager
special
ZenML integration for GCP Secrets Manager.
The GCP Secrets Manager allows your pipeline to directly access the GCP Secret Manager and use the secrets stored in it at runtime.
gcp_secrets_manager
Implementation of the GCP Secrets Manager.
GCPSecretsManager (BaseSecretsManager)
pydantic-model
Class to interact with the GCP secrets manager.
Attributes:
Name | Type | Description |
---|---|---|
project_id |
str |
This is necessary to access the correct GCP project. The project_id of your GCP project space that contains the Secret Manager. |
Source code in zenml/integrations/gcp/secrets_manager/gcp_secrets_manager.py
class GCPSecretsManager(BaseSecretsManager):
    """Class to interact with the GCP secrets manager.

    Attributes:
        project_id: This is necessary to access the correct GCP project.
            The project_id of your GCP project space that contains
            the Secret Manager.
    """

    project_id: str

    # Class configuration
    FLAVOR: ClassVar[str] = GCP_SECRETS_MANAGER_FLAVOR
    SUPPORTS_SCOPING: ClassVar[bool] = True
    # Client stored on the class and created on first use by
    # `_ensure_client_connected`.
    CLIENT: ClassVar[Any] = None

    @classmethod
    def _ensure_client_connected(cls) -> None:
        """Create the class-level Secret Manager client if it is not set."""
        if cls.CLIENT is None:
            cls.CLIENT = secretmanager.SecretManagerServiceClient()

    @classmethod
    def _validate_scope(
        cls,
        scope: SecretsManagerScope,
        namespace: Optional[str],
    ) -> None:
        """Validate the scope and namespace value.

        Only the namespace, when given, is checked here.

        Args:
            scope: Scope value.
            namespace: Optional namespace value.
        """
        if namespace:
            cls.validate_secret_name_or_namespace(namespace)

    @classmethod
    def validate_secret_name_or_namespace(
        cls,
        name: str,
    ) -> None:
        """Validate a secret name or namespace.

        A Google secret ID is a string with a maximum length of 255 characters
        and can contain uppercase and lowercase letters, numerals, and the
        hyphen (-) and underscore (_) characters. For scoped secrets, we have
        to limit the size of the name and namespace even further to allow
        space for both in the Google secret ID.

        Given that we also save secret names and namespaces as labels, we are
        also limited by the limitation that Google imposes on label values:
        max 63 characters and must only contain lowercase letters, numerals
        and the hyphen (-) and underscore (_) characters.

        Args:
            name: the secret name or namespace

        Raises:
            ValueError: if the secret name or namespace is invalid
        """
        # An empty name does not match the pattern and is rejected here too.
        if not re.fullmatch(r"[a-z0-9_\-]+", name):
            raise ValueError(
                f"Invalid secret name or namespace '{name}'. Must contain "
                f"only lowercase alphanumeric characters and the hyphen (-) and "
                f"underscore (_) characters."
            )

        if len(name) > 63:
            raise ValueError(
                f"Invalid secret name or namespace '{name}'. The length is "
                f"limited to maximum 63 characters."
            )

    @property
    def parent_name(self) -> str:
        """Construct the GCP parent path to the secret manager.

        Returns:
            The parent path to the secret manager
        """
        return f"projects/{self.project_id}"

    def _convert_secret_content(
        self, secret: BaseSecretSchema
    ) -> Dict[str, str]:
        """Convert the secret content into a Google compatible representation.

        This method implements two currently supported modes of adapting
        between the naming schemas used for ZenML secrets and Google secrets:

        * for a scoped Secrets Manager, a Google secret is created for each
        ZenML secret with a name that reflects the ZenML secret name and scope
        and a value that contains all its key-value pairs in JSON format.

        * for an unscoped (i.e. legacy) Secrets Manager, this method creates
        multiple Google secret entries for a single ZenML secret by adding the
        secret name to the key name of each secret key-value pair. This allows
        using the same key across multiple secrets. This is only kept for
        backwards compatibility and will be removed some time in the future.

        Args:
            secret: The ZenML secret

        Returns:
            A dictionary with the Google secret name as key and the secret
            contents as value.
        """
        if self.scope == SecretsManagerScope.NONE:
            # legacy per-key secret mapping
            return {f"{secret.name}_{k}": v for k, v in secret.content.items()}

        return {
            self._get_scoped_secret_name(
                secret.name, separator=ZENML_GCP_SECRET_SCOPE_PATH_SEPARATOR
            ): json.dumps(secret_to_dict(secret)),
        }

    def _get_secret_labels(
        self, secret: BaseSecretSchema
    ) -> List[Tuple[str, str]]:
        """Return a list of Google secret label values for a given secret.

        Args:
            secret: the secret object

        Returns:
            A list of Google secret label values
        """
        if self.scope == SecretsManagerScope.NONE:
            # legacy per-key secret labels
            return [
                (ZENML_GROUP_KEY, secret.name),
                (ZENML_SCHEMA_NAME, secret.TYPE),
            ]

        metadata = self._get_secret_metadata(secret)
        return list(metadata.items())

    def _get_secret_scope_filters(
        self,
        secret_name: Optional[str] = None,
    ) -> str:
        """Return a Google filter expression for the scope or a scoped secret.

        These filters can be used when querying the Google Secrets Manager
        for all secrets or for a single secret available in the configured
        scope (see https://cloud.google.com/secret-manager/docs/filtering).

        Args:
            secret_name: Optional secret name to include in the scope metadata.

        Returns:
            Google filter expression uniquely identifying all secrets
            or a named secret within the configured scope.
        """
        if self.scope == SecretsManagerScope.NONE:
            # legacy per-key secret label filters
            if secret_name:
                return f"labels.{ZENML_GROUP_KEY}={secret_name}"
            return f"labels.{ZENML_GROUP_KEY}:*"

        metadata = self._get_secret_scope_metadata(secret_name)
        filters = [
            f"labels.{label}={value}" for label, value in metadata.items()
        ]
        if secret_name:
            filters.append(f"name:{secret_name}")

        return " AND ".join(filters)

    def _list_secrets(self, secret_name: Optional[str] = None) -> List[str]:
        """List all secrets matching a name.

        This method lists all the secrets in the current scope without loading
        their contents. An optional secret name can be supplied to filter out
        all but a single secret identified by name.

        Args:
            secret_name: Optional secret name to filter for.

        Returns:
            A list of secret names in the current scope and the optional
            secret name.
        """
        self._ensure_client_connected()

        # A set is used because, in legacy mode, several Google secrets carry
        # the same ZenML secret name in their group label.
        set_of_secrets = set()

        # List all secrets.
        for secret in self.CLIENT.list_secrets(
            request={
                "parent": self.parent_name,
                "filter": self._get_secret_scope_filters(secret_name),
            }
        ):
            if self.scope == SecretsManagerScope.NONE:
                name = secret.labels[ZENML_GROUP_KEY]
            else:
                name = secret.labels[ZENML_SECRET_NAME_LABEL]

            # filter by secret name, if one was given
            if name and (not secret_name or name == secret_name):
                set_of_secrets.add(name)

        return list(set_of_secrets)

    def register_secret(self, secret: BaseSecretSchema) -> None:
        """Registers a new secret.

        Args:
            secret: the secret to register

        Raises:
            SecretExistsError: if the secret already exists
        """
        self.validate_secret_name_or_namespace(secret.name)
        self._ensure_client_connected()

        if self._list_secrets(secret.name):
            raise SecretExistsError(
                f"A Secret with the name {secret.name} already exists"
            )

        adjusted_content = self._convert_secret_content(secret)
        for k, v in adjusted_content.items():
            # Create the secret, this only creates an empty secret with the
            # supplied name.
            gcp_secret = self.CLIENT.create_secret(
                request={
                    "parent": self.parent_name,
                    "secret_id": k,
                    "secret": {
                        "replication": {"automatic": {}},
                        "labels": self._get_secret_labels(secret),
                    },
                }
            )

            logger.debug("Created empty secret: %s", gcp_secret.name)

            # The value itself is stored as a version of the secret.
            self.CLIENT.add_secret_version(
                request={
                    "parent": gcp_secret.name,
                    "payload": {"data": str(v).encode()},
                }
            )

            logger.debug("Added value to secret.")

    def get_secret(self, secret_name: str) -> BaseSecretSchema:
        """Get a secret by its name.

        Args:
            secret_name: the name of the secret to get

        Returns:
            The secret.

        Raises:
            KeyError: if the secret does not exist
        """
        self.validate_secret_name_or_namespace(secret_name)
        self._ensure_client_connected()

        zenml_secret: Optional[BaseSecretSchema] = None

        if self.scope == SecretsManagerScope.NONE:
            # Legacy secrets are mapped to multiple Google secrets, one for
            # each secret key
            secret_contents = {}
            zenml_schema_name = ""

            # List all secrets.
            for google_secret in self.CLIENT.list_secrets(
                request={
                    "parent": self.parent_name,
                    "filter": self._get_secret_scope_filters(secret_name),
                }
            ):
                secret_version_name = google_secret.name + "/versions/latest"

                response = self.CLIENT.access_secret_version(
                    request={"name": secret_version_name}
                )

                secret_value = response.payload.data.decode("UTF-8")

                # The last path segment is the Google secret ID, from which
                # the ZenML key is recovered.
                secret_key = remove_group_name_from_key(
                    google_secret.name.split("/")[-1], secret_name
                )

                secret_contents[secret_key] = secret_value

                zenml_schema_name = google_secret.labels[ZENML_SCHEMA_NAME]

            if not secret_contents:
                raise KeyError(
                    f"Can't find the specified secret '{secret_name}'"
                )

            secret_contents["name"] = secret_name

            secret_schema = SecretSchemaClassRegistry.get_class(
                secret_schema=zenml_schema_name
            )
            zenml_secret = secret_schema(**secret_contents)

        else:
            # Scoped secrets are mapped 1-to-1 with Google secrets
            google_secret_name = self.CLIENT.secret_path(
                self.project_id,
                self._get_scoped_secret_name(
                    secret_name, separator=ZENML_GCP_SECRET_SCOPE_PATH_SEPARATOR
                ),
            )

            try:
                # fetch the secret itself; its labels are checked below
                google_secret = self.CLIENT.get_secret(name=google_secret_name)
            except google_exceptions.NotFound as e:
                raise KeyError(
                    f"Can't find the specified secret '{secret_name}'"
                ) from e

            # make sure the secret has the correct scope labels to filter out
            # unscoped secrets with similar names
            scope_labels = self._get_secret_scope_metadata(secret_name)
            # all scope labels need to be included in the google secret labels,
            # otherwise the secret does not belong to the current scope
            if not scope_labels.items() <= google_secret.labels.items():
                raise KeyError(
                    f"Can't find the specified secret '{secret_name}'"
                )

            try:
                # fetch the latest secret version
                response = self.CLIENT.access_secret_version(
                    name=f"{google_secret_name}/versions/latest"
                )
            except google_exceptions.NotFound as e:
                raise KeyError(
                    f"Can't find the specified secret '{secret_name}'"
                ) from e

            secret_value = response.payload.data.decode("UTF-8")
            zenml_secret = secret_from_dict(
                json.loads(secret_value), secret_name=secret_name
            )

        return zenml_secret

    def get_all_secret_keys(self) -> List[str]:
        """Get all secret keys.

        Returns:
            A list of all secret keys
        """
        return self._list_secrets()

    def update_secret(self, secret: BaseSecretSchema) -> None:
        """Update an existing secret by adding new versions to its secrets.

        Args:
            secret: the secret to update

        Raises:
            KeyError: if the secret does not exist
        """
        self.validate_secret_name_or_namespace(secret.name)
        self._ensure_client_connected()

        if not self._list_secrets(secret.name):
            raise KeyError(f"Can't find the specified secret '{secret.name}'")

        adjusted_content = self._convert_secret_content(secret)

        for k, v in adjusted_content.items():
            # Add the new value as a new version under the path of the
            # corresponding Google secret.
            google_secret_name = self.CLIENT.secret_path(self.project_id, k)
            payload = {"data": str(v).encode()}

            self.CLIENT.add_secret_version(
                request={"parent": google_secret_name, "payload": payload}
            )

    def delete_secret(self, secret_name: str) -> None:
        """Delete an existing secret by name.

        Args:
            secret_name: the name of the secret to delete

        Raises:
            KeyError: if the secret no longer exists
        """
        self.validate_secret_name_or_namespace(secret_name)
        self._ensure_client_connected()

        if not self._list_secrets(secret_name):
            raise KeyError(f"Can't find the specified secret '{secret_name}'")

        # Delete every Google secret returned by the scope filter built for
        # this secret name.
        for secret in self.CLIENT.list_secrets(
            request={
                "parent": self.parent_name,
                "filter": self._get_secret_scope_filters(secret_name),
            }
        ):
            self.CLIENT.delete_secret(request={"name": secret.name})

    def delete_all_secrets(self) -> None:
        """Delete all existing secrets."""
        self._ensure_client_connected()

        # List all secrets matched by the scope filter and delete each one.
        for secret in self.CLIENT.list_secrets(
            request={
                "parent": self.parent_name,
                "filter": self._get_secret_scope_filters(),
            }
        ):
            logger.info(f"Deleting Google secret {secret.name}")
            self.CLIENT.delete_secret(request={"name": secret.name})
parent_name: str
property
readonly
Construct the GCP parent path to the secret manager.
Returns:
Type | Description |
---|---|
str |
The parent path to the secret manager |
delete_all_secrets(self)
Delete all existing secrets.
Source code in zenml/integrations/gcp/secrets_manager/gcp_secrets_manager.py
def delete_all_secrets(self) -> None:
    """Delete every Google secret matched by the scope filter."""
    self._ensure_client_connected()

    # No secret name is given, so the filter is built for the whole scope.
    list_request = {
        "parent": self.parent_name,
        "filter": self._get_secret_scope_filters(),
    }
    for google_secret in self.CLIENT.list_secrets(request=list_request):
        logger.info(f"Deleting Google secret {google_secret.name}")
        self.CLIENT.delete_secret(request={"name": google_secret.name})
delete_secret(self, secret_name)
Delete an existing secret by name.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
secret_name |
str |
the name of the secret to delete |
required |
Exceptions:
Type | Description |
---|---|
KeyError |
if the secret no longer exists |
Source code in zenml/integrations/gcp/secrets_manager/gcp_secrets_manager.py
def delete_secret(self, secret_name: str) -> None:
    """Delete the Google secrets that back a named ZenML secret.

    Args:
        secret_name: the name of the secret to delete

    Raises:
        KeyError: if the secret no longer exists
    """
    self.validate_secret_name_or_namespace(secret_name)
    self._ensure_client_connected()

    known_names = self._list_secrets(secret_name)
    if not known_names:
        raise KeyError(f"Can't find the specified secret '{secret_name}'")

    # Every Google secret returned for this name's filter is removed.
    list_request = {
        "parent": self.parent_name,
        "filter": self._get_secret_scope_filters(secret_name),
    }
    for google_secret in self.CLIENT.list_secrets(request=list_request):
        self.CLIENT.delete_secret(request={"name": google_secret.name})
get_all_secret_keys(self)
Get all secret keys.
Returns:
Type | Description |
---|---|
List[str] |
A list of all secret keys |
Source code in zenml/integrations/gcp/secrets_manager/gcp_secrets_manager.py
def get_all_secret_keys(self) -> List[str]:
    """List the names of all secrets in the configured scope.

    Returns:
        The names returned by `_list_secrets` when no name filter is given.
    """
    secret_names = self._list_secrets()
    return secret_names
get_secret(self, secret_name)
Get a secret by its name.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
secret_name |
str |
the name of the secret to get |
required |
Returns:
Type | Description |
---|---|
BaseSecretSchema |
The secret. |
Exceptions:
Type | Description |
---|---|
KeyError |
if the secret does not exist |
Source code in zenml/integrations/gcp/secrets_manager/gcp_secrets_manager.py
def get_secret(self, secret_name: str) -> BaseSecretSchema:
    """Get a secret by its name.

    In legacy (unscoped) mode the secret is reassembled from all Google
    secrets matched by the name filter, one per key. In scoped mode a single
    Google secret holds the whole ZenML secret as JSON.

    Args:
        secret_name: the name of the secret to get

    Returns:
        The secret.

    Raises:
        KeyError: if the secret does not exist
    """
    self.validate_secret_name_or_namespace(secret_name)
    self._ensure_client_connected()

    zenml_secret: Optional[BaseSecretSchema] = None

    if self.scope == SecretsManagerScope.NONE:
        # Legacy secrets are mapped to multiple Google secrets, one for
        # each secret key
        secret_contents = {}
        zenml_schema_name = ""

        # List all secrets.
        for google_secret in self.CLIENT.list_secrets(
            request={
                "parent": self.parent_name,
                "filter": self._get_secret_scope_filters(secret_name),
            }
        ):
            secret_version_name = google_secret.name + "/versions/latest"

            response = self.CLIENT.access_secret_version(
                request={"name": secret_version_name}
            )

            secret_value = response.payload.data.decode("UTF-8")

            # The last path segment is the Google secret ID, from which the
            # ZenML key is recovered.
            secret_key = remove_group_name_from_key(
                google_secret.name.split("/")[-1], secret_name
            )

            secret_contents[secret_key] = secret_value

            zenml_schema_name = google_secret.labels[ZENML_SCHEMA_NAME]

        if not secret_contents:
            raise KeyError(
                f"Can't find the specified secret '{secret_name}'"
            )

        secret_contents["name"] = secret_name

        secret_schema = SecretSchemaClassRegistry.get_class(
            secret_schema=zenml_schema_name
        )
        zenml_secret = secret_schema(**secret_contents)

    else:
        # Scoped secrets are mapped 1-to-1 with Google secrets
        google_secret_name = self.CLIENT.secret_path(
            self.project_id,
            self._get_scoped_secret_name(
                secret_name, separator=ZENML_GCP_SECRET_SCOPE_PATH_SEPARATOR
            ),
        )

        try:
            # fetch the secret itself; its labels are checked below
            google_secret = self.CLIENT.get_secret(name=google_secret_name)
        except google_exceptions.NotFound as e:
            # Chain the original error so the cause is not lost.
            raise KeyError(
                f"Can't find the specified secret '{secret_name}'"
            ) from e

        # make sure the secret has the correct scope labels to filter out
        # unscoped secrets with similar names
        scope_labels = self._get_secret_scope_metadata(secret_name)
        # all scope labels need to be included in the google secret labels,
        # otherwise the secret does not belong to the current scope
        if not scope_labels.items() <= google_secret.labels.items():
            raise KeyError(
                f"Can't find the specified secret '{secret_name}'"
            )

        try:
            # fetch the latest secret version
            response = self.CLIENT.access_secret_version(
                name=f"{google_secret_name}/versions/latest"
            )
        except google_exceptions.NotFound as e:
            raise KeyError(
                f"Can't find the specified secret '{secret_name}'"
            ) from e

        secret_value = response.payload.data.decode("UTF-8")
        zenml_secret = secret_from_dict(
            json.loads(secret_value), secret_name=secret_name
        )

    return zenml_secret
register_secret(self, secret)
Registers a new secret.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
secret |
BaseSecretSchema |
the secret to register |
required |
Exceptions:
Type | Description |
---|---|
SecretExistsError |
if the secret already exists |
Source code in zenml/integrations/gcp/secrets_manager/gcp_secrets_manager.py
def register_secret(self, secret: BaseSecretSchema) -> None:
    """Register a new secret in the Google Secret Manager.

    Args:
        secret: the secret to register

    Raises:
        SecretExistsError: if the secret already exists
    """
    self.validate_secret_name_or_namespace(secret.name)
    self._ensure_client_connected()

    already_registered = self._list_secrets(secret.name)
    if already_registered:
        raise SecretExistsError(
            f"A Secret with the name {secret.name} already exists"
        )

    google_secrets = self._convert_secret_content(secret)
    for google_secret_id, value in google_secrets.items():
        # First create an empty, labelled Google secret with the given ID.
        create_request = {
            "parent": self.parent_name,
            "secret_id": google_secret_id,
            "secret": {
                "replication": {"automatic": {}},
                "labels": self._get_secret_labels(secret),
            },
        }
        created = self.CLIENT.create_secret(request=create_request)
        logger.debug("Created empty secret: %s", created.name)

        # Then store the value as a version of that secret.
        version_request = {
            "parent": created.name,
            "payload": {"data": str(value).encode()},
        }
        self.CLIENT.add_secret_version(request=version_request)
        logger.debug("Added value to secret.")
update_secret(self, secret)
Update an existing secret by creating new versions of the existing secrets.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
secret |
BaseSecretSchema |
the secret to update |
required |
Exceptions:
Type | Description |
---|---|
KeyError |
if the secret does not exist |
Source code in zenml/integrations/gcp/secrets_manager/gcp_secrets_manager.py
def update_secret(self, secret: BaseSecretSchema) -> None:
    """Update an existing secret by creating new versions of the existing secrets.

    Args:
        secret: the secret to update

    Raises:
        KeyError: if the secret does not exist
    """
    self.validate_secret_name_or_namespace(secret.name)
    self._ensure_client_connected()

    existing = self._list_secrets(secret.name)
    if not existing:
        raise KeyError(f"Can't find the specified secret '{secret.name}'")

    entries = self._convert_secret_content(secret)
    for secret_id, value in entries.items():
        # No secret is created here: the new value is added as a fresh
        # version of the GCP secret addressed by the computed path.
        parent_path = self.CLIENT.secret_path(self.project_id, secret_id)
        version_request = {
            "parent": parent_path,
            "payload": {"data": str(value).encode()},
        }
        self.CLIENT.add_secret_version(request=version_request)
validate_secret_name_or_namespace(name)
classmethod
Validate a secret name or namespace.
A Google secret ID is a string with a maximum length of 255 characters and can contain uppercase and lowercase letters, numerals, and the hyphen (-) and underscore (_) characters. For scoped secrets, we have to limit the size of the name and namespace even further to allow space for both in the Google secret ID.
Given that we also save secret names and namespaces as labels, we are also bound by the restrictions that Google imposes on label values: a maximum of 63 characters, containing only lowercase letters, numerals, and the hyphen (-) and underscore (_) characters.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
name |
str |
the secret name or namespace |
required |
Exceptions:
Type | Description |
---|---|
ValueError |
if the secret name or namespace is invalid |
Source code in zenml/integrations/gcp/secrets_manager/gcp_secrets_manager.py
@classmethod
def validate_secret_name_or_namespace(
    cls,
    name: str,
) -> None:
    """Validate a secret name or namespace.

    A Google secret ID is a string with a maximum length of 255 characters
    and can contain uppercase and lowercase letters, numerals, and the
    hyphen (-) and underscore (_) characters. For scoped secrets, we have to
    limit the size of the name and namespace even further to allow space for
    both in the Google secret ID.

    Given that we also save secret names and namespaces as labels, we are
    also bound by the restrictions that Google imposes on label values: a
    maximum of 63 characters, containing only lowercase letters, numerals
    and the hyphen (-) and underscore (_) characters.

    Args:
        name: the secret name or namespace

    Raises:
        ValueError: if the secret name or namespace is invalid
    """
    allowed_pattern = r"[a-z0-9_\-]+"
    max_length = 63

    # The character check runs first, so an empty name is rejected here
    # (the pattern requires at least one character).
    if re.fullmatch(allowed_pattern, name) is None:
        raise ValueError(
            f"Invalid secret name or namespace '{name}'. Must contain "
            f"only lowercase alphanumeric characters and the hyphen (-) and "
            f"underscore (_) characters."
        )

    if len(name) > max_length:
        raise ValueError(
            f"Invalid secret name or namespace '{name}'. The length is "
            f"limited to maximum 63 characters."
        )
remove_group_name_from_key(combined_key_name, group_name)
Removes the secret group name from the secret key.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
combined_key_name |
str |
Full name as it is within the gcp secrets manager |
required |
group_name |
str |
Group name (the ZenML Secret name) |
required |
Returns:
Type | Description |
---|---|
str |
The cleaned key |
Exceptions:
Type | Description |
---|---|
RuntimeError |
If the group name is not found in the key |
Source code in zenml/integrations/gcp/secrets_manager/gcp_secrets_manager.py
def remove_group_name_from_key(combined_key_name: str, group_name: str) -> str:
    """Removes the secret group name from the secret key.

    The combined key is expected to have the form `<group_name>_<key>`.

    Args:
        combined_key_name: Full name as it is within the gcp secrets manager
        group_name: Group name (the ZenML Secret name)

    Returns:
        The part of the key that follows the `<group_name>_` prefix

    Raises:
        RuntimeError: If the key does not start with `<group_name>_`
    """
    prefix = group_name + "_"
    if not combined_key_name.startswith(prefix):
        raise RuntimeError(
            f"Key-name `{combined_key_name}` does not have the "
            f"prefix `{group_name}`. Key could not be "
            f"extracted."
        )
    return combined_key_name[len(prefix) :]
step_operators
special
Initialization for the VertexAI Step Operator.
vertex_step_operator
Implementation of a VertexAI step operator.
Code heavily inspired by TFX Implementation: https://github.com/tensorflow/tfx/blob/master/tfx/extensions/google_cloud_ai_platform/training_clients.py
VertexStepOperator (BaseStepOperator, GoogleCredentialsMixin)
pydantic-model
Step operator to run a step on Vertex AI.
This class defines code that can set up a Vertex AI environment and run the ZenML entrypoint command in it.
Attributes:
Name | Type | Description |
---|---|---|
region |
str |
Region name, e.g., europe-west1. |
project |
Optional[str] |
GCP project name. If left None, inferred from the environment. |
accelerator_type |
Optional[str] |
Accelerator type from list: https://cloud.google.com/vertex-ai/docs/reference/rest/v1/MachineSpec#AcceleratorType |
accelerator_count |
int |
Defines number of accelerators to be used for the job. |
machine_type |
str |
Machine type specified here: https://cloud.google.com/vertex-ai/docs/training/configure-compute#machine-types |
base_image |
Optional[str] |
Base image for building the custom job container. |
encryption_spec_key_name |
Optional[str] |
Encryption spec key name. |
Source code in zenml/integrations/gcp/step_operators/vertex_step_operator.py
class VertexStepOperator(BaseStepOperator, GoogleCredentialsMixin):
    """Step operator to run a step on Vertex AI.

    This class defines code that can set up a Vertex AI environment and run the
    ZenML entrypoint command in it.

    Attributes:
        region: Region name, e.g., `europe-west1`.
        project: GCP project name. If left None, inferred from the
            environment.
        accelerator_type: Accelerator type from list: https://cloud.google.com/vertex-ai/docs/reference/rest/v1/MachineSpec#AcceleratorType
        accelerator_count: Defines number of accelerators to be
            used for the job.
        machine_type: Machine type specified here: https://cloud.google.com/vertex-ai/docs/training/configure-compute#machine-types
        base_image: Base image for building the custom job container.
        encryption_spec_key_name: Encryption spec key name.
    """

    region: str
    project: Optional[str] = None
    accelerator_type: Optional[str] = None
    accelerator_count: int = 0
    machine_type: str = "n1-standard-4"
    base_image: Optional[str] = None

    # customer managed encryption key resource name
    # will be applied to all Vertex AI resources if set
    encryption_spec_key_name: Optional[str] = None

    # Class configuration
    FLAVOR: ClassVar[str] = GCP_VERTEX_STEP_OPERATOR_FLAVOR

    @property
    def validator(self) -> Optional[StackValidator]:
        """Validates that the stack is compatible with this step operator.

        The stack must contain a container registry, and its orchestrator
        and artifact store must have the `local` and `gcp` flavors
        respectively.

        Returns:
            StackValidator: Validator for the stack.
        """

        def _ensure_local_orchestrator(stack: Stack) -> Tuple[bool, str]:
            # For now this only works on local orchestrator and GCP artifact
            # store
            return (
                (
                    stack.orchestrator.FLAVOR == "local"
                    and stack.artifact_store.FLAVOR == "gcp"
                ),
                "Only local orchestrator and GCP artifact store are currently "
                "supported",
            )

        return StackValidator(
            required_components={StackComponentType.CONTAINER_REGISTRY},
            custom_validation_function=_ensure_local_orchestrator,
        )

    @property_validator("accelerator_type")
    def validate_accelerator_enum(cls, accelerator_type: Optional[str]) -> None:
        """Validates that the accelerator type is valid.

        An unset (falsy) accelerator type is accepted; any other value is
        compared case-insensitively against the member names of
        `aiplatform.gapic.AcceleratorType`.

        Args:
            accelerator_type: Accelerator type

        Raises:
            ValueError: If the accelerator type is not valid.
        """
        accepted_vals = list(
            aiplatform.gapic.AcceleratorType.__members__.keys()
        )
        if accelerator_type and accelerator_type.upper() not in accepted_vals:
            raise ValueError(
                f"Accelerator must be one of the following: {accepted_vals}"
            )

    def _build_and_push_docker_image(
        self,
        pipeline_name: str,
        requirements: List[str],
        entrypoint_command: List[str],
    ) -> str:
        """Builds and pushes a docker image.

        Args:
            pipeline_name: Pipeline name
            requirements: Requirements
            entrypoint_command: Entrypoint command

        Returns:
            The digest of the pushed image, or the image name if no digest
            is available.

        Raises:
            RuntimeError: If no container registry is found in the stack.
        """
        repo = Repository()
        container_registry = repo.active_stack.container_registry

        if not container_registry:
            raise RuntimeError("Missing container registry")

        # Strip trailing slashes so the image name never contains `//`.
        registry_uri = container_registry.uri.rstrip("/")
        image_name = f"{registry_uri}/zenml-vertex:{pipeline_name}"

        docker_utils.build_docker_image(
            build_context_path=get_source_root_path(),
            image_name=image_name,
            entrypoint=" ".join(entrypoint_command),
            requirements=set(requirements),
            base_image=self.base_image,
        )
        container_registry.push_image(image_name)
        return docker_utils.get_image_digest(image_name) or image_name

    def launch(
        self,
        pipeline_name: str,
        run_name: str,
        requirements: List[str],
        entrypoint_command: List[str],
        resource_configuration: "ResourceConfiguration",
    ) -> None:
        """Launches a step on Vertex AI.

        Builds and pushes a docker image for the step, submits it as a
        Vertex AI custom job and polls the job until it reaches a completed
        state.

        Args:
            pipeline_name: Name of the pipeline which the step to be executed
                is part of.
            run_name: Name of the pipeline run which the step to be executed
                is part of.
            entrypoint_command: Command that executes the step.
            requirements: List of pip requirements that must be installed
                inside the step operator environment.
            resource_configuration: The resource configuration for this step.

        Raises:
            RuntimeError: If the run fails.
            ConnectionError: If the run fails due to a connection error.
        """
        if resource_configuration.cpu_count or resource_configuration.memory:
            logger.warning(
                "Specifying cpus or memory is not supported for "
                "the Vertex step operator. If you want to run this step "
                "operator on specific resources, you can do so by configuring "
                "a different machine_type type like this: "
                "`zenml step-operator update %s "
                "--machine_type=<MACHINE_TYPE>`",
                self.name,
            )

        job_labels = {"source": f"zenml-{__version__.replace('.', '_')}"}

        # Step 1: Authenticate with Google
        credentials, project_id = self._get_authentication()
        if self.project:
            if self.project != project_id:
                logger.warning(
                    "Authenticated with project `%s`, but this orchestrator is "
                    "configured to use the project `%s`.",
                    project_id,
                    self.project,
                )
        else:
            self.project = project_id

        # Step 2: Build and push image
        image_name = self._build_and_push_docker_image(
            pipeline_name=pipeline_name,
            requirements=requirements,
            entrypoint_command=entrypoint_command,
        )

        # Step 3: Launch the job
        # The AI Platform services require regional API endpoints.
        client_options = {"api_endpoint": self.region + VERTEX_ENDPOINT_SUFFIX}
        # Initialize client that will be used to create and send requests.
        # This client only needs to be created once, and can be reused for
        # multiple requests.
        client = aiplatform.gapic.JobServiceClient(
            credentials=credentials, client_options=client_options
        )
        # A GPU count on the step's resource configuration takes precedence
        # over the operator-level default.
        accelerator_count = (
            resource_configuration.gpu_count or self.accelerator_count
        )
        custom_job = {
            "display_name": run_name,
            "job_spec": {
                "worker_pool_specs": [
                    {
                        "machine_spec": {
                            "machine_type": self.machine_type,
                            "accelerator_type": self.accelerator_type,
                            # Accelerators are only requested when a type
                            # is configured.
                            "accelerator_count": accelerator_count
                            if self.accelerator_type
                            else 0,
                        },
                        "replica_count": 1,
                        "container_spec": {
                            "image_uri": image_name,
                            "command": [],
                            "args": [],
                        },
                    }
                ]
            },
            "labels": job_labels,
            "encryption_spec": {"kmsKeyName": self.encryption_spec_key_name}
            if self.encryption_spec_key_name
            else {},
        }
        logger.debug("Vertex AI Job=%s", custom_job)

        parent = f"projects/{self.project}/locations/{self.region}"
        logger.info(
            "Submitting custom job='%s', path='%s' to Vertex AI Training.",
            custom_job["display_name"],
            parent,
        )
        response = client.create_custom_job(
            parent=parent, custom_job=custom_job
        )
        # The message needs a placeholder for the lazily formatted argument;
        # without one the logging module reports a formatting error.
        logger.debug("Vertex AI response: %s", response)

        # Step 4: Monitor the job

        # Monitors the long-running operation by polling the job state
        # periodically, and retries the polling when a transient connectivity
        # issue is encountered.
        #
        # Long-running operation monitoring:
        #   The possible states of "get job" response can be found at
        #   https://cloud.google.com/ai-platform/training/docs/reference/rest/v1/projects.jobs#State
        #   where SUCCEEDED/FAILED/CANCELED are considered to be final states.
        #   The following logic will keep polling the state of the job until
        #   the job enters a final state.
        #
        # During the polling, if a connection error was encountered, the GET
        # request will be retried by recreating the Python API client to
        # refresh the lifecycle of the connection being used. See
        # https://github.com/googleapis/google-api-python-client/issues/218
        # for a detailed description of the problem. If the error persists for
        # CONNECTION_ERROR_RETRY_LIMIT consecutive attempts, the function
        # will raise ConnectionError.
        retry_count = 0
        job_id = response.name

        while response.state not in VERTEX_JOB_STATES_COMPLETED:
            time.sleep(POLLING_INTERVAL_IN_SECONDS)
            try:
                response = client.get_custom_job(name=job_id)
                retry_count = 0
            # Handle transient connection error.
            except ConnectionError as err:
                if retry_count < CONNECTION_ERROR_RETRY_LIMIT:
                    retry_count += 1
                    logger.warning(
                        "ConnectionError (%s) encountered when polling job: "
                        "%s. Trying to recreate the API client.",
                        err,
                        job_id,
                    )
                    # Recreate the Python API client with the same
                    # credentials as the initial client, so the retry does
                    # not fall back to different default credentials.
                    client = aiplatform.gapic.JobServiceClient(
                        credentials=credentials, client_options=client_options
                    )
                else:
                    logger.error(
                        "Request failed after %s retries.",
                        CONNECTION_ERROR_RETRY_LIMIT,
                    )
                    raise

        if response.state in VERTEX_JOB_STATES_FAILED:
            err_msg = (
                "Job '{}' did not succeed. Detailed response {}.".format(
                    job_id, response
                )
            )
            logger.error(err_msg)
            raise RuntimeError(err_msg)

        # Cloud training complete
        logger.info("Job '%s' successful.", job_id)
validator: Optional[zenml.stack.stack_validator.StackValidator]
property
readonly
Validates that the stack contains a container registry.
Returns:
Type | Description |
---|---|
StackValidator |
Validator for the stack. |
launch(self, pipeline_name, run_name, requirements, entrypoint_command, resource_configuration)
Launches a step on Vertex AI.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
pipeline_name |
str |
Name of the pipeline which the step to be executed is part of. |
required |
run_name |
str |
Name of the pipeline run which the step to be executed is part of. |
required |
entrypoint_command |
List[str] |
Command that executes the step. |
required |
requirements |
List[str] |
List of pip requirements that must be installed inside the step operator environment. |
required |
resource_configuration |
ResourceConfiguration |
The resource configuration for this step. |
required |
Exceptions:
Type | Description |
---|---|
RuntimeError |
If the run fails. |
ConnectionError |
If the run fails due to a connection error. |
Source code in zenml/integrations/gcp/step_operators/vertex_step_operator.py
def launch(
    self,
    pipeline_name: str,
    run_name: str,
    requirements: List[str],
    entrypoint_command: List[str],
    resource_configuration: "ResourceConfiguration",
) -> None:
    """Launches a step on Vertex AI.

    Builds and pushes a docker image for the step, submits it as a Vertex AI
    custom job and polls the job until it reaches a completed state.

    Args:
        pipeline_name: Name of the pipeline which the step to be executed
            is part of.
        run_name: Name of the pipeline run which the step to be executed
            is part of.
        entrypoint_command: Command that executes the step.
        requirements: List of pip requirements that must be installed
            inside the step operator environment.
        resource_configuration: The resource configuration for this step.

    Raises:
        RuntimeError: If the run fails.
        ConnectionError: If the run fails due to a connection error.
    """
    if resource_configuration.cpu_count or resource_configuration.memory:
        logger.warning(
            "Specifying cpus or memory is not supported for "
            "the Vertex step operator. If you want to run this step "
            "operator on specific resources, you can do so by configuring "
            "a different machine_type type like this: "
            "`zenml step-operator update %s "
            "--machine_type=<MACHINE_TYPE>`",
            self.name,
        )

    job_labels = {"source": f"zenml-{__version__.replace('.', '_')}"}

    # Step 1: Authenticate with Google
    credentials, project_id = self._get_authentication()
    if self.project:
        if self.project != project_id:
            logger.warning(
                "Authenticated with project `%s`, but this orchestrator is "
                "configured to use the project `%s`.",
                project_id,
                self.project,
            )
    else:
        self.project = project_id

    # Step 2: Build and push image
    image_name = self._build_and_push_docker_image(
        pipeline_name=pipeline_name,
        requirements=requirements,
        entrypoint_command=entrypoint_command,
    )

    # Step 3: Launch the job
    # The AI Platform services require regional API endpoints.
    client_options = {"api_endpoint": self.region + VERTEX_ENDPOINT_SUFFIX}
    # Initialize client that will be used to create and send requests.
    # This client only needs to be created once, and can be reused for
    # multiple requests.
    client = aiplatform.gapic.JobServiceClient(
        credentials=credentials, client_options=client_options
    )
    # A GPU count on the step's resource configuration takes precedence over
    # the operator-level default.
    accelerator_count = (
        resource_configuration.gpu_count or self.accelerator_count
    )
    custom_job = {
        "display_name": run_name,
        "job_spec": {
            "worker_pool_specs": [
                {
                    "machine_spec": {
                        "machine_type": self.machine_type,
                        "accelerator_type": self.accelerator_type,
                        # Accelerators are only requested when a type is
                        # configured.
                        "accelerator_count": accelerator_count
                        if self.accelerator_type
                        else 0,
                    },
                    "replica_count": 1,
                    "container_spec": {
                        "image_uri": image_name,
                        "command": [],
                        "args": [],
                    },
                }
            ]
        },
        "labels": job_labels,
        "encryption_spec": {"kmsKeyName": self.encryption_spec_key_name}
        if self.encryption_spec_key_name
        else {},
    }
    logger.debug("Vertex AI Job=%s", custom_job)

    parent = f"projects/{self.project}/locations/{self.region}"
    logger.info(
        "Submitting custom job='%s', path='%s' to Vertex AI Training.",
        custom_job["display_name"],
        parent,
    )
    response = client.create_custom_job(
        parent=parent, custom_job=custom_job
    )
    # The message needs a placeholder for the lazily formatted argument;
    # without one the logging module reports a formatting error.
    logger.debug("Vertex AI response: %s", response)

    # Step 4: Monitor the job

    # Monitors the long-running operation by polling the job state
    # periodically, and retries the polling when a transient connectivity
    # issue is encountered.
    #
    # Long-running operation monitoring:
    #   The possible states of "get job" response can be found at
    #   https://cloud.google.com/ai-platform/training/docs/reference/rest/v1/projects.jobs#State
    #   where SUCCEEDED/FAILED/CANCELED are considered to be final states.
    #   The following logic will keep polling the state of the job until
    #   the job enters a final state.
    #
    # During the polling, if a connection error was encountered, the GET
    # request will be retried by recreating the Python API client to
    # refresh the lifecycle of the connection being used. See
    # https://github.com/googleapis/google-api-python-client/issues/218
    # for a detailed description of the problem. If the error persists for
    # CONNECTION_ERROR_RETRY_LIMIT consecutive attempts, the function
    # will raise ConnectionError.
    retry_count = 0
    job_id = response.name

    while response.state not in VERTEX_JOB_STATES_COMPLETED:
        time.sleep(POLLING_INTERVAL_IN_SECONDS)
        try:
            response = client.get_custom_job(name=job_id)
            retry_count = 0
        # Handle transient connection error.
        except ConnectionError as err:
            if retry_count < CONNECTION_ERROR_RETRY_LIMIT:
                retry_count += 1
                logger.warning(
                    "ConnectionError (%s) encountered when polling job: "
                    "%s. Trying to recreate the API client.",
                    err,
                    job_id,
                )
                # Recreate the Python API client with the same credentials
                # as the initial client, so the retry does not fall back to
                # different default credentials.
                client = aiplatform.gapic.JobServiceClient(
                    credentials=credentials, client_options=client_options
                )
            else:
                logger.error(
                    "Request failed after %s retries.",
                    CONNECTION_ERROR_RETRY_LIMIT,
                )
                raise

    if response.state in VERTEX_JOB_STATES_FAILED:
        err_msg = (
            "Job '{}' did not succeed. Detailed response {}.".format(
                job_id, response
            )
        )
        logger.error(err_msg)
        raise RuntimeError(err_msg)

    # Cloud training complete
    logger.info("Job '%s' successful.", job_id)
validate_accelerator_enum(accelerator_type)
classmethod
Validates that the accelerator type is valid.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
accelerator_type |
Optional[str] |
Accelerator type |
required |
Exceptions:
Type | Description |
---|---|
ValueError |
If the accelerator type is not valid. |
Source code in zenml/integrations/gcp/step_operators/vertex_step_operator.py
@property_validator("accelerator_type")
def validate_accelerator_enum(cls, accelerator_type: Optional[str]) -> None:
    """Validates that the accelerator type is valid.

    An unset (falsy) accelerator type is accepted; any other value is
    compared case-insensitively against the member names of
    `aiplatform.gapic.AcceleratorType`.

    Args:
        accelerator_type: Accelerator type

    Raises:
        ValueError: If the accelerator type is not valid.
    """
    accepted_vals = list(aiplatform.gapic.AcceleratorType.__members__)

    # Nothing to check when no accelerator was requested.
    if not accelerator_type:
        return

    if accelerator_type.upper() in accepted_vals:
        return

    raise ValueError(
        f"Accelerator must be one of the following: {accepted_vals}"
    )
github
special
Initialization of the GitHub ZenML integration.
The GitHub integration provides a way to orchestrate pipelines using GitHub Actions.
GitHubIntegration (Integration)
Definition of GitHub integration for ZenML.
Source code in zenml/integrations/github/__init__.py
class GitHubIntegration(Integration):
    """Definition of GitHub integration for ZenML."""

    NAME = GITHUB
    REQUIREMENTS: List[str] = ["PyNaCl~=1.5.0"]

    @classmethod
    def flavors(cls) -> List[FlavorWrapper]:
        """Declare the stack component flavors for the GitHub integration.

        Two flavors are declared: an orchestrator and a secrets manager.

        Returns:
            List of stack component flavors for this integration.
        """
        orchestrator_flavor = FlavorWrapper(
            name=GITHUB_ORCHESTRATOR_FLAVOR,
            source="zenml.integrations.github.orchestrators.GitHubActionsOrchestrator",
            type=StackComponentType.ORCHESTRATOR,
            integration=cls.NAME,
        )
        secrets_manager_flavor = FlavorWrapper(
            name=GITHUB_SECRET_MANAGER_FLAVOR,
            source="zenml.integrations.github.secrets_managers.GitHubSecretsManager",
            type=StackComponentType.SECRETS_MANAGER,
            integration=cls.NAME,
        )
        return [orchestrator_flavor, secrets_manager_flavor]
flavors()
classmethod
Declare the stack component flavors for the GitHub integration.
Returns:
Type | Description |
---|---|
List[zenml.zen_stores.models.flavor_wrapper.FlavorWrapper] |
List of stack component flavors for this integration. |
Source code in zenml/integrations/github/__init__.py
@classmethod
def flavors(cls) -> List[FlavorWrapper]:
    """Declare the stack component flavors for the GitHub integration.

    Two flavors are declared: an orchestrator and a secrets manager.

    Returns:
        List of stack component flavors for this integration.
    """
    orchestrator_flavor = FlavorWrapper(
        name=GITHUB_ORCHESTRATOR_FLAVOR,
        source="zenml.integrations.github.orchestrators.GitHubActionsOrchestrator",
        type=StackComponentType.ORCHESTRATOR,
        integration=cls.NAME,
    )
    secrets_manager_flavor = FlavorWrapper(
        name=GITHUB_SECRET_MANAGER_FLAVOR,
        source="zenml.integrations.github.secrets_managers.GitHubSecretsManager",
        type=StackComponentType.SECRETS_MANAGER,
        integration=cls.NAME,
    )
    return [orchestrator_flavor, secrets_manager_flavor]
orchestrators
special
Initialization of the GitHub Actions Orchestrator.
github_actions_entrypoint_configuration
Implementation of the GitHub Actions Orchestrator entrypoint.
GitHubActionsEntrypointConfiguration (StepEntrypointConfiguration)
Entrypoint configuration for running steps on GitHub Actions runners.
Source code in zenml/integrations/github/orchestrators/github_actions_entrypoint_configuration.py
class GitHubActionsEntrypointConfiguration(StepEntrypointConfiguration):
    """Entrypoint configuration for running steps on GitHub Actions runners."""

    @classmethod
    def get_custom_entrypoint_options(cls) -> Set[str]:
        """GitHub Actions specific entrypoint options.

        Returns:
            Set with the custom run id option.
        """
        custom_options = {RUN_ID_OPTION}
        return custom_options

    @classmethod
    def get_custom_entrypoint_arguments(
        cls, step: BaseStep, **kwargs: Any
    ) -> List[str]:
        """Adds a run id argument for the entrypoint.

        Args:
            step: Step for which the arguments are passed.
            **kwargs: Additional keyword arguments.

        Returns:
            GitHub Actions placeholder for the run id option.
        """
        # The workflow file keeps these placeholders verbatim; the GitHub
        # Actions runner substitutes concrete values for them.
        placeholders = (
            "${{ github.run_id }}",
            "${{ github.run_number }}",
            "${{ github.run_attempt }}",
        )
        run_id = "_".join(placeholders)
        option_flag = f"--{RUN_ID_OPTION}"
        return [option_flag, run_id]

    def get_run_name(self, pipeline_name: str) -> str:
        """Returns the pipeline run name.

        The name is the pipeline name and the run id entrypoint argument
        joined by a hyphen.

        Args:
            pipeline_name: Name of the pipeline which will run.

        Returns:
            The run name.
        """
        run_id = cast(str, self.entrypoint_args[RUN_ID_OPTION])
        return "{}-{}".format(pipeline_name, run_id)
get_custom_entrypoint_arguments(step, **kwargs)
classmethod
Adds a run id argument for the entrypoint.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
step |
BaseStep |
Step for which the arguments are passed. |
required |
**kwargs |
Any |
Additional keyword arguments. |
{} |
Returns:
Type | Description |
---|---|
List[str] |
GitHub Actions placeholder for the run id option. |
Source code in zenml/integrations/github/orchestrators/github_actions_entrypoint_configuration.py
@classmethod
def get_custom_entrypoint_arguments(
    cls, step: BaseStep, **kwargs: Any
) -> List[str]:
    """Adds a run id argument for the entrypoint.

    Args:
        step: Step for which the arguments are passed.
        **kwargs: Additional keyword arguments.

    Returns:
        GitHub Actions placeholder for the run id option.
    """
    # The workflow file keeps these placeholders verbatim; the GitHub
    # Actions runner substitutes concrete values for them.
    placeholders = (
        "${{ github.run_id }}",
        "${{ github.run_number }}",
        "${{ github.run_attempt }}",
    )
    run_id = "_".join(placeholders)
    option_flag = f"--{RUN_ID_OPTION}"
    return [option_flag, run_id]
get_custom_entrypoint_options()
classmethod
GitHub Actions specific entrypoint options.
Returns:
Type | Description |
---|---|
Set[str] |
Set with the custom run id option. |
Source code in zenml/integrations/github/orchestrators/github_actions_entrypoint_configuration.py
@classmethod
def get_custom_entrypoint_options(cls) -> Set[str]:
    """GitHub Actions specific entrypoint options.

    Returns:
        Set with the custom run id option.
    """
    custom_options = {RUN_ID_OPTION}
    return custom_options
get_run_name(self, pipeline_name)
Returns the pipeline run name.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
pipeline_name |
str |
Name of the pipeline which will run. |
required |
Returns:
Type | Description |
---|---|
str |
The run name. |
Source code in zenml/integrations/github/orchestrators/github_actions_entrypoint_configuration.py
def get_run_name(self, pipeline_name: str) -> str:
    """Returns the pipeline run name.

    The name is the pipeline name and the run id entrypoint argument joined
    by a hyphen.

    Args:
        pipeline_name: Name of the pipeline which will run.

    Returns:
        The run name.
    """
    run_id = cast(str, self.entrypoint_args[RUN_ID_OPTION])
    return "{}-{}".format(pipeline_name, run_id)
github_actions_orchestrator
Implementation of the GitHub Actions Orchestrator.
GitHubActionsOrchestrator (BaseOrchestrator)
pydantic-model
Orchestrator responsible for running pipelines using GitHub Actions.
Attributes:
Name | Type | Description |
---|---|---|
custom_docker_base_image_name |
Optional[str] |
Name of a docker image that should be used as the base for the image that will be run on GitHub Action runners. If no custom image is given, a basic image of the active ZenML version will be used. Note: This image needs to have ZenML installed, otherwise the pipeline execution will fail. For that reason, you might want to extend the ZenML docker images found here: https://hub.docker.com/r/zenmldocker/zenml/ |
skip_dirty_repository_check |
bool |
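If True, this orchestrator will not raise an exception when trying to run a pipeline while there are still untracked/uncommitted files in the git repository. |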
skip_github_repository_check |
bool |
If True, the orchestrator will not check if your git repository is pointing to a GitHub remote. |
push |
bool |
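If True, this orchestrator will automatically commit and push the GitHub workflow file when running a pipeline. If False, the workflow file will be written to the correct location but needs to be committed and pushed manually. |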
Source code in zenml/integrations/github/orchestrators/github_actions_orchestrator.py
class GitHubActionsOrchestrator(BaseOrchestrator):
"""Orchestrator responsible for running pipelines using GitHub Actions.
Attributes:
custom_docker_base_image_name: Name of a docker image that should be
used as the base for the image that will be run on GitHub Action
runners. If no custom image is given, a basic image of the active
ZenML version will be used. **Note**: This image needs to have
ZenML installed, otherwise the pipeline execution will fail. For
that reason, you might want to extend the ZenML docker images
found here: https://hub.docker.com/r/zenmldocker/zenml/
skip_dirty_repository_check: If `True`, this orchestrator will not
raise an exception when trying to run a pipeline while there are
still untracked/uncommitted files in the git repository.
skip_github_repository_check: If `True`, the orchestrator will not check
if your git repository is pointing to a GitHub remote.
push: If `True`, this orchestrator will automatically commit and push
the GitHub workflow file when running a pipeline. If `False`, the
workflow file will be written to the correct location but needs to
be committed and pushed manually.
"""
custom_docker_base_image_name: Optional[str] = None
skip_dirty_repository_check: bool = False
skip_github_repository_check: bool = False
push: bool = False
_git_repo: Optional[Repo] = None
# Class configuration
FLAVOR: ClassVar[str] = GITHUB_ORCHESTRATOR_FLAVOR
@property
def git_repo(self) -> Repo:
    """Returns the git repository for the current working directory.

    The repository is looked up once and cached on the instance. It is only
    cached after the GitHub remote check has passed, so a failed check is
    repeated on the next access instead of being skipped.

    Returns:
        Git repository for the current working directory.

    Raises:
        RuntimeError: If there is no git repository for the current working
            directory or the repository remote is not pointing to GitHub.
    """
    if self._git_repo is not None:
        return self._git_repo

    try:
        repo = Repo(search_parent_directories=True)
    except InvalidGitRepositoryError as err:
        raise RuntimeError(
            "Unable to find git repository in current working "
            f"directory {os.getcwd()} or its parent directories."
        ) from err

    remote_url = repo.remote().url
    is_github_repo = any(
        remote_url.startswith(prefix)
        for prefix in GITHUB_REMOTE_URL_PREFIXES
    )
    if not (is_github_repo or self.skip_github_repository_check):
        raise RuntimeError(
            f"The remote URL '{remote_url}' of your git repo "
            f"({repo.git_dir}) is not pointing to a GitHub "
            "repository. The GitHub Actions orchestrator runs "
            "pipelines using GitHub Actions and therefore only works "
            "with GitHub repositories. If you want to skip this check "
            "and run this orchestrator anyway, run: \n"
            f"`zenml orchestrator update {self.name} "
            "--skip_github_repository_check=true`"
        )

    # Cache only a repository that passed validation.
    self._git_repo = repo
    return repo
@property
def workflow_directory(self) -> str:
    """Returns path to the GitHub workflows directory.

    The path is `.github/workflows` inside the working directory of the
    git repository.

    Returns:
        The GitHub workflows directory.
    """
    repo_root = self.git_repo.working_dir
    assert repo_root
    return os.path.join(repo_root, ".github", "workflows")
@property
def validator(self) -> Optional[StackValidator]:
    """Validator that ensures that the stack is compatible.

    The stack must contain a container registry that is not local and does
    not require authentication, and no stack component may have a local
    path.

    Returns:
        The stack validator.
    """

    def _validate_local_requirements(stack: "Stack") -> Tuple[bool, str]:
        registry = stack.container_registry
        assert registry is not None

        if registry.is_local:
            local_registry_message = (
                "The GitHub Actions orchestrator requires a remote "
                f"container registry, but the '{registry.name}' "
                "container registry of your active stack points to a local "
                f"URI '{registry.uri}'. Please make sure stacks "
                "with a GitHub Actions orchestrator always contain remote "
                "container registries."
            )
            return False, local_registry_message

        if registry.requires_authentication:
            authentication_message = (
                "The GitHub Actions orchestrator currently only works with "
                "GitHub container registries or public container "
                f"registries, but your {registry.FLAVOR} "
                f"container registry '{registry.name}' requires "
                "authentication."
            )
            return False, authentication_message

        # Report the first stack component that has a local path, if any.
        local_component = next(
            (
                candidate
                for candidate in stack.components.values()
                if candidate.local_path
            ),
            None,
        )
        if local_component is not None:
            return False, (
                "The GitHub Actions orchestrator runs pipelines on "
                "remote GitHub Actions runners, but the "
                f"'{local_component.name}' {local_component.TYPE.value} of your "
                "active stack is a local component. Please make sure "
                "to only use remote stack components in combination "
                "with the GitHub Actions orchestrator. "
            )

        return True, ""

    return StackValidator(
        required_components={StackComponentType.CONTAINER_REGISTRY},
        custom_validation_function=_validate_local_requirements,
    )
def get_docker_image_name(self, pipeline_name: str) -> str:
    """Returns the full docker image name including registry and tag.

    The image lives under the URI of the active stack's container registry
    and is tagged with the pipeline name.

    Args:
        pipeline_name: Name of the pipeline for which to generate a docker
            image name.

    Returns:
        The docker image name.
    """
    active_stack = Repository().active_stack
    registry = active_stack.container_registry
    assert registry  # should never happen due to validation
    image_repository = f"{registry.uri}/zenml-github-actions"
    return f"{image_repository}:{pipeline_name}"
def _docker_login_step(
    self,
    container_registry: BaseContainerRegistry,
) -> Optional[Dict[str, Any]]:
    """Create the GitHub Actions step that logs in to the registry.

    Args:
        container_registry: The container registry which (potentially)
            requires a step to authenticate.

    Returns:
        Dictionary specifying the GitHub Actions step for authenticating
        with the container registry if that is required, `None` otherwise.
    """
    uses_automatic_token = (
        isinstance(container_registry, GitHubContainerRegistry)
        and container_registry.automatic_token_authentication
    )
    # TODO: Once different private container registries are supported in
    #  GitHub Actions, registries with `requires_authentication` should log
    #  in with their own `username` and `password` instead of being skipped.
    if not uses_automatic_token:
        return None

    # GitHub Actions specific placeholders, used because the container
    # registry specifies automatic token authentication.
    login_arguments = {
        "registry": container_registry.uri,
        "username": "${{ github.actor }}",
        "password": "${{ secrets.GITHUB_TOKEN }}",
    }
    return {
        "name": "Authenticate with the container registry",
        "uses": DOCKER_LOGIN_ACTION,
        "with": login_arguments,
    }
def _write_environment_file_step(
    self,
    file_name: str,
    secrets_manager: Optional[BaseSecretsManager] = None,
) -> Optional[Dict[str, Any]]:
    """Create the GitHub Actions step that writes the environment file.

    Args:
        file_name: Name of the environment file that should be written.
        secrets_manager: Secrets manager that will be used to read secrets
            during pipeline execution.

    Returns:
        Dictionary specifying the GitHub Actions step for writing the
        environment file, or `None` if the secrets manager is not a
        `GitHubSecretsManager`.
    """
    if not isinstance(secrets_manager, GitHubSecretsManager):
        return None

    # The first shell command creates the file. It always records the
    # environment variable that specifies whether we're running in a GitHub
    # Action workflow so the secret manager knows how to query secret values.
    create_file_command = (
        f'echo {ENV_IN_GITHUB_ACTIONS}="${ENV_IN_GITHUB_ACTIONS}" '
        f"> {file_name}; "
    )

    # Every ZenML secret is appended as a `${{ secrets.<SECRET_NAME> }}`
    # placeholder. Explicitly writing these placeholders into the workflow
    # yaml is the only way for us to access the GitHub secrets in a GitHub
    # Actions workflow.
    placeholder_template = (
        "echo {secret_name}=${{{{ secrets.{secret_name} }}}} >> {file}; "
    )
    prefixed_names = secrets_manager.get_all_secret_keys(include_prefix=True)
    shell_commands = [create_file_command]
    shell_commands.extend(
        placeholder_template.format(secret_name=name, file=file_name)
        for name in prefixed_names
    )

    return {
        "name": "Write environment file",
        "run": "".join(shell_commands),
    }
def prepare_pipeline_deployment(
    self,
    pipeline: "BasePipeline",
    stack: "Stack",
    runtime_configuration: "RuntimeConfiguration",
) -> None:
    """Builds and uploads a docker image.

    The image digest (or the image name if no digest is available) is
    stored in the runtime configuration under the `docker_image` key.

    Args:
        pipeline: The pipeline for which the image is built.
        stack: The stack on which the pipeline will be executed.
        runtime_configuration: Runtime configuration for the pipeline run.

    Raises:
        RuntimeError: If the orchestrator should only run in a clean git
            repository and the repository is dirty.
    """
    if not self.skip_dirty_repository_check and self.git_repo.is_dirty(
        untracked_files=True
    ):
        # The adjacent literals are concatenated, so the sentence has to end
        # with a space to not run into the next one.
        raise RuntimeError(
            "Trying to run a pipeline from within a dirty (=containing "
            "untracked/uncommitted files) git repository. "
            "If you want this orchestrator to skip the dirty repo check in "
            f"the future, run\n `zenml orchestrator update {self.name} "
            "--skip_dirty_repository_check=true`"
        )

    image_name = self.get_docker_image_name(pipeline.name)

    # Union of the stack and pipeline requirements without duplicates
    requirements = {*stack.requirements(), *pipeline.requirements}
    logger.debug(
        "Github actions docker image requirements: %s", requirements
    )

    docker_utils.build_docker_image(
        build_context_path=source_utils.get_source_root_path(),
        image_name=image_name,
        dockerignore_path=pipeline.dockerignore_file,
        requirements=requirements,
        base_image=self.custom_docker_base_image_name,
    )

    assert stack.container_registry  # should never happen due to validation
    stack.container_registry.push_image(image_name)

    # Store the docker image digest in the runtime configuration so it gets
    # tracked in the ZenStore
    image_digest = docker_utils.get_image_digest(image_name) or image_name
    runtime_configuration["docker_image"] = image_digest
def prepare_or_run_pipeline(
    self,
    sorted_steps: List[BaseStep],
    pipeline: "BasePipeline",
    pb2_pipeline: Pb2Pipeline,
    stack: "Stack",
    runtime_configuration: "RuntimeConfiguration",
) -> None:
    """Writes a GitHub Action workflow yaml and optionally pushes it.

    The workflow contains one job per step. Each job runs the pipeline
    docker image with the entrypoint arguments of its step.

    Args:
        sorted_steps: List of sorted steps
        pipeline: ZenML Pipeline instance
        pb2_pipeline: Protobuf Pipeline instance
        stack: The stack the pipeline was run on
        runtime_configuration: The Runtime configuration of the current run

    Raises:
        ValueError: If a schedule without a cron expression or with an
            invalid cron expression is passed.
    """
    schedule = runtime_configuration.schedule

    workflow_name = pipeline.name
    if schedule:
        # Add a suffix to the workflow filename so we don't overwrite
        # scheduled pipeline by future schedules or single pipeline runs.
        datetime_string = datetime.now().strftime("%y_%m_%d_%H_%M_%S")
        workflow_name += f"-scheduled-{datetime_string}"

    workflow_path = os.path.join(
        self.workflow_directory,
        f"{workflow_name}.yaml",
    )

    # Store the encoded pb2 pipeline once as an environment variable.
    # We will replace the entrypoint argument later to reduce the size
    # of the workflow file.
    encoded_pb2_pipeline = string_utils.b64_encode(
        json_format.MessageToJson(pb2_pipeline)
    )
    workflow_dict: Dict[str, Any] = {
        "name": workflow_name,
        "env": {ENV_ENCODED_ZENML_PIPELINE: encoded_pb2_pipeline},
    }

    if schedule:
        if not schedule.cron_expression:
            raise ValueError(
                "GitHub Action workflows can only be scheduled using cron "
                "expressions and not using a periodic schedule. If you "
                "want to schedule pipelines using this GitHub Action "
                "orchestrator, please include a cron expression in your "
                "schedule object. For more information on GitHub workflow "
                "schedules check out https://docs.github.com/en/actions/using-workflows/events-that-trigger-workflows#schedule."
            )

        # GitHub workflows requires a schedule interval of at least 5
        # minutes. Invalid cron expressions would be something like
        # `*/3 * * * *` (all stars except the first part of the expression,
        # which will have the format `*/minute_interval`)
        if re.fullmatch(r"\*/[1-4]( \*){4,}", schedule.cron_expression):
            # The example has to pin the minute field: `0 * * * *` fires
            # once per hour, whereas `* 1 * * *` would fire every minute
            # between 01:00 and 01:59.
            raise ValueError(
                "GitHub workflows requires a schedule interval of at "
                "least 5 minutes which is incompatible with your cron "
                f"expression '{schedule.cron_expression}'. An example of a "
                "valid cron expression would be '0 * * * *' to run "
                "every hour. For more information on GitHub workflow "
                "schedules check out https://docs.github.com/en/actions/using-workflows/events-that-trigger-workflows#schedule."
            )

        logger.warning(
            "GitHub only runs scheduled workflows once the "
            "workflow file is merged to the default branch of the "
            "repository (https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/about-branches#about-the-default-branch). "
            "Please make sure to merge your current branch into the "
            "default branch for this scheduled pipeline to run."
        )
        workflow_dict["on"] = {
            "schedule": [{"cron": schedule.cron_expression}]
        }
    else:
        # The pipeline should only run once. The only fool-proof way to
        # only execute a workflow once seems to be running on specific tags.
        # We don't want to create tags for each pipeline run though, so
        # instead we only run this workflow if the workflow file is
        # modified. As long as users don't manually modify these files this
        # should be sufficient.
        workflow_path_in_repo = os.path.relpath(
            workflow_path, self.git_repo.working_dir
        )
        workflow_dict["on"] = {"push": {"paths": [workflow_path_in_repo]}}

    # Prefer the image digest over the image name if one is available
    image_name = self.get_docker_image_name(pipeline.name)
    image_name = docker_utils.get_image_digest(image_name) or image_name

    # Prepare the step that writes an environment file which will get
    # passed to the docker image
    env_file_name = ".zenml_docker_env"
    write_env_file_step = self._write_environment_file_step(
        file_name=env_file_name, secrets_manager=stack.secrets_manager
    )
    docker_run_args = (
        ["--env-file", env_file_name] if write_env_file_step else []
    )

    # Prepare the docker login step if necessary
    container_registry = stack.container_registry
    assert container_registry
    docker_login_step = self._docker_login_step(container_registry)

    # The base command that each job will execute with specific arguments
    base_command = [
        "docker",
        "run",
        *docker_run_args,
        image_name,
    ] + GitHubActionsEntrypointConfiguration.get_entrypoint_command()

    jobs = {}
    for step in sorted_steps:
        if self.requires_resources_in_orchestration_environment(step):
            logger.warning(
                "Specifying step resources is not supported for the "
                "GitHub Actions orchestrator, ignoring resource "
                "configuration for step %s.",
                step.name,
            )

        job_steps = []

        # Copy the shared dicts here to avoid creating yaml anchors (which
        # are currently not supported in GitHub workflow yaml files)
        if write_env_file_step:
            job_steps.append(copy.deepcopy(write_env_file_step))

        if docker_login_step:
            job_steps.append(copy.deepcopy(docker_login_step))

        entrypoint_args = (
            GitHubActionsEntrypointConfiguration.get_entrypoint_arguments(
                step=step,
                pb2_pipeline=pb2_pipeline,
            )
        )

        # Replace the encoded string by a global environment variable to
        # keep the workflow file small
        index = entrypoint_args.index(f"--{PIPELINE_JSON_OPTION}")
        entrypoint_args[index + 1] = f"${ENV_ENCODED_ZENML_PIPELINE}"

        command = base_command + entrypoint_args
        docker_run_step = {
            "name": "Run the docker image",
            "run": " ".join(command),
        }
        job_steps.append(docker_run_step)

        job_dict = {
            "runs-on": "ubuntu-latest",
            "needs": self.get_upstream_step_names(
                step=step, pb2_pipeline=pb2_pipeline
            ),
            "steps": job_steps,
        }
        jobs[step.name] = job_dict

    workflow_dict["jobs"] = jobs

    fileio.makedirs(self.workflow_directory)
    yaml_utils.write_yaml(workflow_path, workflow_dict, sort_keys=False)
    logger.info("Wrote GitHub workflow file to %s", workflow_path)

    if self.push:
        # Add, commit and push the pipeline workflow yaml
        self.git_repo.index.add(workflow_path)
        self.git_repo.index.commit(
            "[ZenML GitHub Actions Orchestrator] Add github workflow for "
            f"pipeline {pipeline.name}."
        )
        self.git_repo.remote().push()
        logger.info("Pushed workflow file '%s'", workflow_path)
    else:
        logger.info(
            "Automatically committing and pushing is disabled for this "
            "orchestrator. To run the pipeline, you'll have to commit and "
            "push the workflow file '%s' manually.\n"
            "If you want to update this orchestrator to automatically "
            "commit and push in the future, run "
            "`zenml orchestrator update %s --push=true`",
            workflow_path,
            self.name,
        )
git_repo: Repo
property
readonly
Returns the git repository for the current working directory.
Returns:
Type | Description |
---|---|
Repo |
Git repository for the current working directory. |
Exceptions:
Type | Description |
---|---|
RuntimeError |
If there is no git repository for the current working directory or the repository remote is not pointing to GitHub. |
validator: Optional[zenml.stack.stack_validator.StackValidator]
property
readonly
Validator that ensures that the stack is compatible.
Makes sure that the stack contains a container registry and only remote components.
Returns:
Type | Description |
---|---|
Optional[zenml.stack.stack_validator.StackValidator] |
The stack validator. |
workflow_directory: str
property
readonly
Returns path to the GitHub workflows directory.
Returns:
Type | Description |
---|---|
str |
The GitHub workflows directory. |
get_docker_image_name(self, pipeline_name)
Returns the full docker image name including registry and tag.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
pipeline_name |
str |
Name of the pipeline for which to generate a docker image name. |
required |
Returns:
Type | Description |
---|---|
str |
The docker image name. |
Source code in zenml/integrations/github/orchestrators/github_actions_orchestrator.py
def get_docker_image_name(self, pipeline_name: str) -> str:
    """Build the full docker image name for a pipeline.

    The name is made of the URI of the active stack's container registry,
    the fixed repository name `zenml-github-actions` and the pipeline name
    as tag.

    Args:
        pipeline_name: Name of the pipeline for which to generate a docker
            image name.

    Returns:
        The docker image name.
    """
    active_stack = Repository().active_stack
    registry = active_stack.container_registry
    assert registry  # should never happen due to validation
    registry_uri = registry.uri
    return f"{registry_uri}/zenml-github-actions:{pipeline_name}"
prepare_or_run_pipeline(self, sorted_steps, pipeline, pb2_pipeline, stack, runtime_configuration)
Writes a GitHub Action workflow yaml and optionally pushes it.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
sorted_steps |
List[zenml.steps.base_step.BaseStep] |
List of sorted steps |
required |
pipeline |
BasePipeline |
ZenML Pipeline instance |
required |
pb2_pipeline |
Pipeline |
Protobuf Pipeline instance |
required |
stack |
Stack |
The stack the pipeline was run on |
required |
runtime_configuration |
RuntimeConfiguration |
The Runtime configuration of the current run |
required |
Exceptions:
Type | Description |
---|---|
ValueError |
If a schedule without a cron expression or with an invalid cron expression is passed. |
Source code in zenml/integrations/github/orchestrators/github_actions_orchestrator.py
def prepare_or_run_pipeline(
    self,
    sorted_steps: List[BaseStep],
    pipeline: "BasePipeline",
    pb2_pipeline: Pb2Pipeline,
    stack: "Stack",
    runtime_configuration: "RuntimeConfiguration",
) -> None:
    """Writes a GitHub Action workflow yaml and optionally pushes it.

    The workflow contains one job per step. Each job runs the pipeline
    docker image with the entrypoint arguments of its step.

    Args:
        sorted_steps: List of sorted steps
        pipeline: ZenML Pipeline instance
        pb2_pipeline: Protobuf Pipeline instance
        stack: The stack the pipeline was run on
        runtime_configuration: The Runtime configuration of the current run

    Raises:
        ValueError: If a schedule without a cron expression or with an
            invalid cron expression is passed.
    """
    schedule = runtime_configuration.schedule

    workflow_name = pipeline.name
    if schedule:
        # Add a suffix to the workflow filename so we don't overwrite
        # scheduled pipeline by future schedules or single pipeline runs.
        datetime_string = datetime.now().strftime("%y_%m_%d_%H_%M_%S")
        workflow_name += f"-scheduled-{datetime_string}"

    workflow_path = os.path.join(
        self.workflow_directory,
        f"{workflow_name}.yaml",
    )

    # Store the encoded pb2 pipeline once as an environment variable.
    # We will replace the entrypoint argument later to reduce the size
    # of the workflow file.
    encoded_pb2_pipeline = string_utils.b64_encode(
        json_format.MessageToJson(pb2_pipeline)
    )
    workflow_dict: Dict[str, Any] = {
        "name": workflow_name,
        "env": {ENV_ENCODED_ZENML_PIPELINE: encoded_pb2_pipeline},
    }

    if schedule:
        if not schedule.cron_expression:
            raise ValueError(
                "GitHub Action workflows can only be scheduled using cron "
                "expressions and not using a periodic schedule. If you "
                "want to schedule pipelines using this GitHub Action "
                "orchestrator, please include a cron expression in your "
                "schedule object. For more information on GitHub workflow "
                "schedules check out https://docs.github.com/en/actions/using-workflows/events-that-trigger-workflows#schedule."
            )

        # GitHub workflows requires a schedule interval of at least 5
        # minutes. Invalid cron expressions would be something like
        # `*/3 * * * *` (all stars except the first part of the expression,
        # which will have the format `*/minute_interval`)
        if re.fullmatch(r"\*/[1-4]( \*){4,}", schedule.cron_expression):
            # The example has to pin the minute field: `0 * * * *` fires
            # once per hour, whereas `* 1 * * *` would fire every minute
            # between 01:00 and 01:59.
            raise ValueError(
                "GitHub workflows requires a schedule interval of at "
                "least 5 minutes which is incompatible with your cron "
                f"expression '{schedule.cron_expression}'. An example of a "
                "valid cron expression would be '0 * * * *' to run "
                "every hour. For more information on GitHub workflow "
                "schedules check out https://docs.github.com/en/actions/using-workflows/events-that-trigger-workflows#schedule."
            )

        logger.warning(
            "GitHub only runs scheduled workflows once the "
            "workflow file is merged to the default branch of the "
            "repository (https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/about-branches#about-the-default-branch). "
            "Please make sure to merge your current branch into the "
            "default branch for this scheduled pipeline to run."
        )
        workflow_dict["on"] = {
            "schedule": [{"cron": schedule.cron_expression}]
        }
    else:
        # The pipeline should only run once. The only fool-proof way to
        # only execute a workflow once seems to be running on specific tags.
        # We don't want to create tags for each pipeline run though, so
        # instead we only run this workflow if the workflow file is
        # modified. As long as users don't manually modify these files this
        # should be sufficient.
        workflow_path_in_repo = os.path.relpath(
            workflow_path, self.git_repo.working_dir
        )
        workflow_dict["on"] = {"push": {"paths": [workflow_path_in_repo]}}

    # Prefer the image digest over the image name if one is available
    image_name = self.get_docker_image_name(pipeline.name)
    image_name = docker_utils.get_image_digest(image_name) or image_name

    # Prepare the step that writes an environment file which will get
    # passed to the docker image
    env_file_name = ".zenml_docker_env"
    write_env_file_step = self._write_environment_file_step(
        file_name=env_file_name, secrets_manager=stack.secrets_manager
    )
    docker_run_args = (
        ["--env-file", env_file_name] if write_env_file_step else []
    )

    # Prepare the docker login step if necessary
    container_registry = stack.container_registry
    assert container_registry
    docker_login_step = self._docker_login_step(container_registry)

    # The base command that each job will execute with specific arguments
    base_command = [
        "docker",
        "run",
        *docker_run_args,
        image_name,
    ] + GitHubActionsEntrypointConfiguration.get_entrypoint_command()

    jobs = {}
    for step in sorted_steps:
        if self.requires_resources_in_orchestration_environment(step):
            logger.warning(
                "Specifying step resources is not supported for the "
                "GitHub Actions orchestrator, ignoring resource "
                "configuration for step %s.",
                step.name,
            )

        job_steps = []

        # Copy the shared dicts here to avoid creating yaml anchors (which
        # are currently not supported in GitHub workflow yaml files)
        if write_env_file_step:
            job_steps.append(copy.deepcopy(write_env_file_step))

        if docker_login_step:
            job_steps.append(copy.deepcopy(docker_login_step))

        entrypoint_args = (
            GitHubActionsEntrypointConfiguration.get_entrypoint_arguments(
                step=step,
                pb2_pipeline=pb2_pipeline,
            )
        )

        # Replace the encoded string by a global environment variable to
        # keep the workflow file small
        index = entrypoint_args.index(f"--{PIPELINE_JSON_OPTION}")
        entrypoint_args[index + 1] = f"${ENV_ENCODED_ZENML_PIPELINE}"

        command = base_command + entrypoint_args
        docker_run_step = {
            "name": "Run the docker image",
            "run": " ".join(command),
        }
        job_steps.append(docker_run_step)

        job_dict = {
            "runs-on": "ubuntu-latest",
            "needs": self.get_upstream_step_names(
                step=step, pb2_pipeline=pb2_pipeline
            ),
            "steps": job_steps,
        }
        jobs[step.name] = job_dict

    workflow_dict["jobs"] = jobs

    fileio.makedirs(self.workflow_directory)
    yaml_utils.write_yaml(workflow_path, workflow_dict, sort_keys=False)
    logger.info("Wrote GitHub workflow file to %s", workflow_path)

    if self.push:
        # Add, commit and push the pipeline workflow yaml
        self.git_repo.index.add(workflow_path)
        self.git_repo.index.commit(
            "[ZenML GitHub Actions Orchestrator] Add github workflow for "
            f"pipeline {pipeline.name}."
        )
        self.git_repo.remote().push()
        logger.info("Pushed workflow file '%s'", workflow_path)
    else:
        logger.info(
            "Automatically committing and pushing is disabled for this "
            "orchestrator. To run the pipeline, you'll have to commit and "
            "push the workflow file '%s' manually.\n"
            "If you want to update this orchestrator to automatically "
            "commit and push in the future, run "
            "`zenml orchestrator update %s --push=true`",
            workflow_path,
            self.name,
        )
prepare_pipeline_deployment(self, pipeline, stack, runtime_configuration)
Builds and uploads a docker image.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
pipeline |
BasePipeline |
The pipeline for which the image is built. |
required |
stack |
Stack |
The stack on which the pipeline will be executed. |
required |
runtime_configuration |
RuntimeConfiguration |
Runtime configuration for the pipeline run. |
required |
Exceptions:
Type | Description |
---|---|
RuntimeError |
If the orchestrator should only run in a clean git repository and the repository is dirty. |
Source code in zenml/integrations/github/orchestrators/github_actions_orchestrator.py
def prepare_pipeline_deployment(
    self,
    pipeline: "BasePipeline",
    stack: "Stack",
    runtime_configuration: "RuntimeConfiguration",
) -> None:
    """Builds and uploads a docker image.

    The image digest (or the image name if no digest is available) is
    stored in the runtime configuration under the `docker_image` key.

    Args:
        pipeline: The pipeline for which the image is built.
        stack: The stack on which the pipeline will be executed.
        runtime_configuration: Runtime configuration for the pipeline run.

    Raises:
        RuntimeError: If the orchestrator should only run in a clean git
            repository and the repository is dirty.
    """
    if not self.skip_dirty_repository_check and self.git_repo.is_dirty(
        untracked_files=True
    ):
        # The adjacent literals are concatenated, so the sentence has to end
        # with a space to not run into the next one.
        raise RuntimeError(
            "Trying to run a pipeline from within a dirty (=containing "
            "untracked/uncommitted files) git repository. "
            "If you want this orchestrator to skip the dirty repo check in "
            f"the future, run\n `zenml orchestrator update {self.name} "
            "--skip_dirty_repository_check=true`"
        )

    image_name = self.get_docker_image_name(pipeline.name)

    # Union of the stack and pipeline requirements without duplicates
    requirements = {*stack.requirements(), *pipeline.requirements}
    logger.debug(
        "Github actions docker image requirements: %s", requirements
    )

    docker_utils.build_docker_image(
        build_context_path=source_utils.get_source_root_path(),
        image_name=image_name,
        dockerignore_path=pipeline.dockerignore_file,
        requirements=requirements,
        base_image=self.custom_docker_base_image_name,
    )

    assert stack.container_registry  # should never happen due to validation
    stack.container_registry.push_image(image_name)

    # Store the docker image digest in the runtime configuration so it gets
    # tracked in the ZenStore
    image_digest = docker_utils.get_image_digest(image_name) or image_name
    runtime_configuration["docker_image"] = image_digest
secrets_managers
special
Initialization of the GitHub Secrets Manager.
github_secrets_manager
Implementation of the GitHub Secrets Manager.
GitHubSecretsManager (BaseSecretsManager)
pydantic-model
Class to interact with the GitHub secrets manager.
Attributes:
Name | Type | Description |
---|---|---|
owner |
str |
The owner (either individual or organization) of the repository. |
repository |
str |
Name of the GitHub repository. |
Source code in zenml/integrations/github/secrets_managers/github_secrets_manager.py
class GitHubSecretsManager(BaseSecretsManager):
    """Class to interact with the GitHub secrets manager.

    Attributes:
        owner: The owner (either individual or organization) of the repository.
        repository: Name of the GitHub repository.
    """

    owner: str
    repository: str

    # Created on first access of the `session` property
    _session: Optional[requests.Session] = None

    # Class configuration
    FLAVOR: ClassVar[str] = GITHUB_SECRET_MANAGER_FLAVOR

    @property
    def post_registration_message(self) -> Optional[str]:
        """Info message regarding GitHub API authentication env variables.

        Returns:
            The info message.
        """
        return AUTHENTICATION_CREDENTIALS_MESSAGE

    @property
    def session(self) -> requests.Session:
        """Session to send requests to the GitHub API.

        The session is created on first access and reused afterwards.

        Returns:
            Session to use for GitHub API calls.

        Raises:
            RuntimeError: If authentication credentials for the GitHub API are
                not set.
        """
        if not self._session:
            session = requests.Session()
            github_username = os.getenv(ENV_GITHUB_USERNAME)
            authentication_token = os.getenv(ENV_GITHUB_AUTHENTICATION_TOKEN)

            if not github_username or not authentication_token:
                raise RuntimeError(
                    "Missing authentication credentials for GitHub secrets "
                    "manager. " + AUTHENTICATION_CREDENTIALS_MESSAGE
                )

            session.auth = HTTPBasicAuth(github_username, authentication_token)
            session.headers["Accept"] = "application/vnd.github.v3+json"
            self._session = session

        return self._session

    def _send_request(
        self, method: str, resource: Optional[str] = None, **kwargs: Any
    ) -> requests.Response:
        """Sends an HTTP request to the GitHub API.

        Args:
            method: Method of the HTTP request that should be sent.
            resource: Optional resource to which the request should be sent. If
                none is given, the default GitHub API secrets endpoint will be
                used.
            **kwargs: Will be passed to the `requests` library.

        Returns:
            HTTP response.

        # noqa: DAR402

        Raises:
            HTTPError: If the request failed due to a client or server error.
        """
        url = (
            f"https://api.github.com/repos/{self.owner}/{self.repository}"
            f"/actions/secrets"
        )
        if resource:
            url += resource

        response = self.session.request(method=method, url=url, **kwargs)
        # Raise an exception in case of a client or server error
        response.raise_for_status()

        return response

    def _encrypt_secret(self, secret_value: str) -> Tuple[str, str]:
        """Encrypts a secret value.

        This method first fetches a public key from the GitHub API and then uses
        this key to encrypt the secret value. This is needed in order to
        register GitHub secrets using the API.

        Args:
            secret_value: Secret value to encrypt.

        Returns:
            The encrypted secret value and the key id of the GitHub public key.
        """
        from nacl.encoding import Base64Encoder
        from nacl.public import PublicKey, SealedBox

        response_json = self._send_request("GET", resource="/public-key").json()
        public_key = PublicKey(
            response_json["key"].encode("utf-8"), Base64Encoder
        )
        sealed_box = SealedBox(public_key)
        encrypted_bytes = sealed_box.encrypt(secret_value.encode("utf-8"))
        encrypted_string = base64.b64encode(encrypted_bytes).decode("utf-8")
        return encrypted_string, cast(str, response_json["key_id"])

    def _has_secret(self, secret_name: str) -> bool:
        """Checks whether a secret exists for the given name.

        Args:
            secret_name: Name of the secret which should be checked.

        Returns:
            `True` if a secret with the given name exists, `False` otherwise.
        """
        secret_name = _convert_secret_name(secret_name, remove_prefix=True)
        return secret_name in self.get_all_secret_keys(include_prefix=False)

    def get_secret(self, secret_name: str) -> BaseSecretSchema:
        """Gets the value of a secret.

        This method only works when called from within a GitHub Actions
        environment.

        Args:
            secret_name: The name of the secret to get.

        Returns:
            The secret.

        Raises:
            KeyError: If a secret with this name doesn't exist.
            RuntimeError: If not inside a GitHub Actions environment.
        """
        full_secret_name = _convert_secret_name(secret_name, add_prefix=True)
        # Raise a KeyError if the secret doesn't exist. We can do that even
        # if we're not inside a GitHub Actions environment
        if not self._has_secret(secret_name):
            raise KeyError(
                f"Unable to find secret '{secret_name}'. Please check the "
                "GitHub UI to see if a **Repository** secret called "
                f"'{full_secret_name}' exists. (ZenML uses the "
                f"'{GITHUB_SECRET_PREFIX}' to differentiate ZenML "
                "secrets from other GitHub secrets)"
            )

        if not inside_github_action_environment():
            stack_name = Repository().active_stack_name
            commands = [
                f"zenml stack copy {stack_name} <NEW_STACK_NAME>",
                "zenml secrets_manager register <NEW_SECRETS_MANAGER_NAME> "
                "--flavor=local",
                "zenml stack update <NEW_STACK_NAME> "
                "--secrets_manager=<NEW_SECRETS_MANAGER_NAME>",
                "zenml stack set <NEW_STACK_NAME>",
                f"zenml secret register {secret_name} ...",
            ]

            raise RuntimeError(
                "Getting GitHub secrets is only possible within a GitHub "
                "Actions workflow. If you need this secret to access "
                "stack components (e.g. your metadata store to fetch pipelines "
                "during the post-execution workflow) locally, you need to "
                "register this secret in a different secrets manager. "
                "You can do this by running the following commands: \n\n"
                + "\n".join(commands)
            )

        # If we're running inside a GitHub Actions environment using a
        # workflow generated by the GitHub Actions orchestrator, all ZenML
        # secrets stored in the GitHub secrets manager will be accessible as
        # environment variables
        secret_value = cast(str, os.getenv(full_secret_name))

        # The environment variable holds a base64 encoded json dictionary
        # with the secret schema type and the secret content
        secret_dict = json.loads(string_utils.b64_decode(secret_value))

        schema_class = SecretSchemaClassRegistry.get_class(
            secret_schema=secret_dict[SECRET_SCHEMA_DICT_KEY]
        )
        secret_content = secret_dict[SECRET_CONTENT_DICT_KEY]

        return schema_class(name=secret_name, **secret_content)

    def get_all_secret_keys(self, include_prefix: bool = False) -> List[str]:
        """Get all secret keys.

        If we're running inside a GitHub Actions environment, this will return
        the names of all environment variables starting with a ZenML internal
        prefix. Otherwise, this will return all GitHub **Repository** secrets
        created by ZenML. In the latter case the GitHub API is queried page
        by page until all secrets of the repository were listed.

        Args:
            include_prefix: Whether or not the internal prefix that is used to
                differentiate ZenML secrets from other GitHub secrets should be
                included in the returned names.

        Returns:
            List of all secret keys.
        """
        if inside_github_action_environment():
            potential_secret_keys = list(os.environ)
        else:
            logger.info(
                "Fetching list of secrets for repository %s/%s",
                self.owner,
                self.repository,
            )
            # A single request only returns up to `per_page` secrets, so keep
            # requesting pages until one comes back only partially filled.
            per_page = 100
            page = 1
            potential_secret_keys = []
            while True:
                response = self._send_request(
                    "GET", params={"per_page": per_page, "page": page}
                )
                secrets_on_page = response.json()["secrets"]
                potential_secret_keys.extend(
                    secret_dict["name"] for secret_dict in secrets_on_page
                )
                if len(secrets_on_page) < per_page:
                    break
                page += 1

        keys = [
            _convert_secret_name(key, remove_prefix=not include_prefix)
            for key in potential_secret_keys
            if key.startswith(GITHUB_SECRET_PREFIX)
        ]

        return keys

    def register_secret(self, secret: BaseSecretSchema) -> None:
        """Registers a new secret.

        Args:
            secret: The secret to register.

        Raises:
            SecretExistsError: If a secret with this name already exists.
        """
        if self._has_secret(secret.name):
            raise SecretExistsError(
                f"A secret with name '{secret.name}' already exists for this "
                "GitHub repository. If you want to register a new value for "
                f"this secret, please run `zenml secret delete {secret.name}` "
                f"followed by `zenml secret register {secret.name} ...`."
            )

        # Store the schema type next to the content so `get_secret` can
        # rebuild the correct schema class
        secret_dict = {
            SECRET_SCHEMA_DICT_KEY: secret.TYPE,
            SECRET_CONTENT_DICT_KEY: secret.content,
        }
        secret_value = string_utils.b64_encode(json.dumps(secret_dict))
        encrypted_secret, public_key_id = self._encrypt_secret(
            secret_value=secret_value
        )
        body = {
            "encrypted_value": encrypted_secret,
            "key_id": public_key_id,
        }

        full_secret_name = _convert_secret_name(secret.name, add_prefix=True)
        self._send_request("PUT", resource=f"/{full_secret_name}", json=body)

    def update_secret(self, secret: BaseSecretSchema) -> NoReturn:
        """Update an existing secret.

        Args:
            secret: The secret to update.

        Raises:
            NotImplementedError: Always, as this functionality is not possible
                using GitHub secrets which doesn't allow us to retrieve the
                secret values outside of a GitHub Actions environment.
        """
        raise NotImplementedError(
            "Updating secrets is not possible with the GitHub secrets manager "
            "as it is not possible to retrieve GitHub secrets values outside "
            "of a GitHub Actions environment."
        )

    def delete_secret(self, secret_name: str) -> None:
        """Delete an existing secret.

        Args:
            secret_name: The name of the secret to delete.
        """
        full_secret_name = _convert_secret_name(secret_name, add_prefix=True)
        self._send_request("DELETE", resource=f"/{full_secret_name}")

    def delete_all_secrets(self) -> None:
        """Delete all existing secrets."""
        for secret_name in self.get_all_secret_keys(include_prefix=False):
            self.delete_secret(secret_name=secret_name)
post_registration_message: Optional[str]
property
readonly
Info message regarding GitHub API authentication env variables.
Returns:
Type | Description |
---|---|
Optional[str] |
The info message. |
session: Session
property
readonly
Session to send requests to the GitHub API.
Returns:
Type | Description |
---|---|
Session |
Session to use for GitHub API calls. |
Exceptions:
Type | Description |
---|---|
RuntimeError |
If authentication credentials for the GitHub API are not set. |
delete_all_secrets(self)
Delete all existing secrets.
Source code in zenml/integrations/github/secrets_managers/github_secrets_manager.py
def delete_all_secrets(self) -> None:
    """Remove every secret that `get_all_secret_keys` reports."""
    # Names are requested without the prefix because `delete_secret`
    # expects the unprefixed form.
    unprefixed_names = self.get_all_secret_keys(include_prefix=False)
    for name in unprefixed_names:
        self.delete_secret(secret_name=name)
delete_secret(self, secret_name)
Delete an existing secret.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
secret_name |
str |
The name of the secret to delete. |
required |
Source code in zenml/integrations/github/secrets_managers/github_secrets_manager.py
def delete_secret(self, secret_name: str) -> None:
    """Remove a single secret from the GitHub repository.

    Args:
        secret_name: Name of the secret to delete, without the ZenML
            prefix.
    """
    # GitHub stores ZenML secrets under a prefixed name
    prefixed_name = _convert_secret_name(secret_name, add_prefix=True)
    resource_path = f"/{prefixed_name}"
    self._send_request("DELETE", resource=resource_path)
get_all_secret_keys(self, include_prefix=False)
Get all secret keys.
If we're running inside a GitHub Actions environment, this will return the names of all environment variables starting with a ZenML internal prefix. Otherwise, this will return all GitHub Repository secrets created by ZenML.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
include_prefix |
bool |
Whether or not the internal prefix that is used to differentiate ZenML secrets from other GitHub secrets should be included in the returned names. |
False |
Returns:
Type | Description |
---|---|
List[str] |
List of all secret keys. |
Source code in zenml/integrations/github/secrets_managers/github_secrets_manager.py
def get_all_secret_keys(self, include_prefix: bool = False) -> List[str]:
    """List the names of all ZenML secrets.

    Inside a GitHub Actions environment the candidates are the names of
    all environment variables. Everywhere else they are the names of the
    GitHub **Repository** secrets returned by the GitHub API. In both
    cases only names starting with the ZenML internal prefix are kept.

    Args:
        include_prefix: Whether the internal prefix that distinguishes
            ZenML secrets from other GitHub secrets should be kept in the
            returned names.

    Returns:
        List of all secret keys.
    """
    if not inside_github_action_environment():
        logger.info(
            "Fetching list of secrets for repository %s/%s",
            self.owner,
            self.repository,
        )
        # NOTE(review): a single request with `per_page=100` is issued, so
        # only one page of results is read here.
        response = self._send_request("GET", params={"per_page": 100})
        candidate_names = [
            entry["name"] for entry in response.json()["secrets"]
        ]
    else:
        candidate_names = list(os.environ)

    strip_prefix = not include_prefix
    keys = []
    for name in candidate_names:
        if name.startswith(GITHUB_SECRET_PREFIX):
            keys.append(
                _convert_secret_name(name, remove_prefix=strip_prefix)
            )
    return keys
get_secret(self, secret_name)
Gets the value of a secret.
This method only works when called from within a GitHub Actions environment.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
secret_name |
str |
The name of the secret to get. |
required |
Returns:
Type | Description |
---|---|
BaseSecretSchema |
The secret. |
Exceptions:
Type | Description |
---|---|
KeyError |
If a secret with this name doesn't exist. |
RuntimeError |
If not inside a GitHub Actions environment. |
Source code in zenml/integrations/github/secrets_managers/github_secrets_manager.py
def get_secret(self, secret_name: str) -> BaseSecretSchema:
    """Fetch a secret from the GitHub Actions environment.

    Secret values can only be read from within a GitHub Actions
    environment, where they are looked up as environment variables.

    Args:
        secret_name: The name of the secret to get.

    Returns:
        The secret.

    Raises:
        KeyError: If a secret with this name doesn't exist.
        RuntimeError: If not inside a GitHub Actions environment.
    """
    env_variable_name = _convert_secret_name(secret_name, add_prefix=True)

    # Existence can be checked from anywhere, so a missing secret is
    # reported before the environment check below.
    if not self._has_secret(secret_name):
        raise KeyError(
            f"Unable to find secret '{secret_name}'. Please check the "
            "GitHub UI to see if a **Repository** secret called "
            f"'{env_variable_name}' exists. (ZenML uses the "
            f"'{GITHUB_SECRET_PREFIX}' to differentiate ZenML "
            "secrets from other GitHub secrets)"
        )

    if not inside_github_action_environment():
        active_stack = Repository().active_stack_name
        # Commands that move the secret to a local secrets manager
        migration_steps = "\n".join(
            [
                f"zenml stack copy {active_stack} <NEW_STACK_NAME>",
                "zenml secrets_manager register <NEW_SECRETS_MANAGER_NAME> "
                "--flavor=local",
                "zenml stack update <NEW_STACK_NAME> "
                "--secrets_manager=<NEW_SECRETS_MANAGER_NAME>",
                "zenml stack set <NEW_STACK_NAME>",
                f"zenml secret register {secret_name} ...",
            ]
        )
        raise RuntimeError(
            "Getting GitHub secrets is only possible within a GitHub "
            "Actions workflow. If you need this secret to access "
            "stack components (e.g. your metadata store to fetch pipelines "
            "during the post-execution workflow) locally, you need to "
            "register this secret in a different secrets manager. "
            "You can do this by running the following commands: \n\n"
            + migration_steps
        )

    # Inside a GitHub Actions environment running a workflow generated by
    # the GitHub Actions orchestrator, all ZenML secrets stored in the
    # GitHub secrets manager are available as environment variables.
    encoded_value = cast(str, os.getenv(env_variable_name))
    decoded = json.loads(string_utils.b64_decode(encoded_value))
    schema_class = SecretSchemaClassRegistry.get_class(
        secret_schema=decoded[SECRET_SCHEMA_DICT_KEY]
    )
    content = decoded[SECRET_CONTENT_DICT_KEY]
    return schema_class(name=secret_name, **content)
register_secret(self, secret)
Registers a new secret.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
secret |
BaseSecretSchema |
The secret to register. |
required |
Exceptions:
Type | Description |
---|---|
SecretExistsError |
If a secret with this name already exists. |
Source code in zenml/integrations/github/secrets_managers/github_secrets_manager.py
def register_secret(self, secret: BaseSecretSchema) -> None:
    """Store a new secret in the GitHub repository.

    The schema type and content of the secret are serialized to JSON,
    base64 encoded, encrypted and then sent to the GitHub API under the
    prefixed secret name.

    Args:
        secret: The secret to register.

    Raises:
        SecretExistsError: If a secret with this name already exists.
    """
    if self._has_secret(secret.name):
        raise SecretExistsError(
            f"A secret with name '{secret.name}' already exists for this "
            "GitHub repository. If you want to register a new value for "
            f"this secret, please run `zenml secret delete {secret.name}` "
            f"followed by `zenml secret register {secret.name} ...`."
        )

    serialized = json.dumps(
        {
            SECRET_SCHEMA_DICT_KEY: secret.TYPE,
            SECRET_CONTENT_DICT_KEY: secret.content,
        }
    )
    encrypted_secret, public_key_id = self._encrypt_secret(
        secret_value=string_utils.b64_encode(serialized)
    )

    prefixed_name = _convert_secret_name(secret.name, add_prefix=True)
    self._send_request(
        "PUT",
        resource=f"/{prefixed_name}",
        json={
            "encrypted_value": encrypted_secret,
            "key_id": public_key_id,
        },
    )
update_secret(self, secret)
Update an existing secret.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
secret |
BaseSecretSchema |
The secret to update. |
required |
Exceptions:
Type | Description |
---|---|
NotImplementedError |
Always, as this functionality is not possible using GitHub secrets which doesn't allow us to retrieve the secret values outside of a GitHub Actions environment. |
Source code in zenml/integrations/github/secrets_managers/github_secrets_manager.py
def update_secret(self, secret: BaseSecretSchema) -> NoReturn:
    """Reject any attempt to update a secret.

    Args:
        secret: The secret that was requested to be updated (unused).

    Raises:
        NotImplementedError: Always. GitHub does not expose secret values
            outside of a GitHub Actions environment, so an existing secret
            cannot be read back and updated.
    """
    reason = (
        "Updating secrets is not possible with the GitHub secrets manager "
        "as it is not possible to retrieve GitHub secrets values outside "
        "of a GitHub Actions environment."
    )
    raise NotImplementedError(reason)
inside_github_action_environment()
Returns whether the current code is executing in a GitHub Actions environment.
Returns:
Type | Description |
---|---|
bool |
`True` if running in a GitHub Actions environment, `False` otherwise. |
Source code in zenml/integrations/github/secrets_managers/github_secrets_manager.py
def inside_github_action_environment() -> bool:
    """Tell whether the code is running in a GitHub Actions environment.

    Returns:
        `True` if the environment variable named by `ENV_IN_GITHUB_ACTIONS`
        is set to the string "true", `False` otherwise.
    """
    flag = os.environ.get(ENV_IN_GITHUB_ACTIONS)
    return flag == "true"
graphviz
special
Initialization of the Graphviz integration.
GraphvizIntegration (Integration)
Definition of Graphviz integration for ZenML.
Source code in zenml/integrations/graphviz/__init__.py
class GraphvizIntegration(Integration):
"""Definition of Graphviz integration for ZenML."""
# Identifier under which this integration is known.
NAME = GRAPHVIZ
# Package requirement specifiers declared by this integration.
REQUIREMENTS = ["graphviz>=0.17"]
# NOTE(review): presumably maps a system package name to the executable it
# must provide (`dot`) — confirm against the `Integration` base class.
SYSTEM_REQUIREMENTS = {"graphviz": "dot"}
visualizers
special
Initialization of Graphviz visualizers.
pipeline_run_dag_visualizer
Implementation of the Graphviz pipeline run DAG visualizer.
PipelineRunDagVisualizer (BasePipelineRunVisualizer)
Visualize the lineage of runs in a pipeline.
Source code in zenml/integrations/graphviz/visualizers/pipeline_run_dag_visualizer.py
class PipelineRunDagVisualizer(BasePipelineRunVisualizer):
    """Visualize the lineage of runs in a pipeline."""

    ARTIFACT_DEFAULT_COLOR = "blue"
    ARTIFACT_CACHED_COLOR = "green"
    ARTIFACT_SHAPE = "box"
    ARTIFACT_PREFIX = "artifact_"
    STEP_COLOR = "#431D93"
    STEP_SHAPE = "ellipse"
    STEP_PREFIX = "step_"
    FONT = "Roboto"

    # This is the concrete implementation, so it is not marked abstract.
    def visualize(
        self, object: PipelineRunView, *args: Any, **kwargs: Any
    ) -> graphviz.Digraph:
        """Creates a pipeline lineage diagram using graphviz.

        Every step becomes a node, every output artifact becomes a node
        with an edge from the producing step, and every input artifact
        gets an edge to the consuming step. The graph is rendered as a PNG
        and opened in a viewer.

        Args:
            object: The pipeline run view to visualize.
            *args: Additional arguments to pass to the visualization.
            **kwargs: Additional keyword arguments to pass to the visualization.

        Returns:
            A graphviz digraph object.
        """
        logger.warning(
            "This integration is not completed yet. Results might be unexpected."
        )
        dot = graphviz.Digraph(comment=object.name)

        # link the steps together
        for step in object.steps:
            step_node = self.STEP_PREFIX + str(step.id)

            # add each step as a node
            dot.node(
                step_node,
                step.entrypoint_name,
                shape=self.STEP_SHAPE,
            )

            # add every output artifact as a node, linked from the step
            for artifact_name, artifact in step.outputs.items():
                artifact_node = self.ARTIFACT_PREFIX + str(artifact.id)
                dot.node(
                    artifact_node,
                    f"{artifact_name} \n" f"({artifact._data_type})",
                    shape=self.ARTIFACT_SHAPE,
                )
                dot.edge(step_node, artifact_node)

            # link every input artifact to the step that consumes it; only
            # the artifact itself is needed here, not its name
            for artifact in step.inputs.values():
                dot.edge(
                    self.ARTIFACT_PREFIX + str(artifact.id),
                    step_node,
                )

        # NOTE(review): the temporary file only supplies a unique file name
        # for rendering; the ".html" suffix does not match the PNG output.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".html") as f:
            dot.render(filename=f.name, format="png", view=True, cleanup=True)
        return dot
visualize(self, object, *args, **kwargs)
Creates a pipeline lineage diagram using graphviz.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
object |
PipelineRunView |
The pipeline run view to visualize. |
required |
*args |
Any |
Additional arguments to pass to the visualization. |
() |
**kwargs |
Any |
Additional keyword arguments to pass to the visualization. |
{} |
Returns:
Type | Description |
---|---|
Digraph |
A graphviz digraph object. |
Source code in zenml/integrations/graphviz/visualizers/pipeline_run_dag_visualizer.py
# This is the concrete implementation, so it is not marked abstract.
def visualize(
    self, object: PipelineRunView, *args: Any, **kwargs: Any
) -> graphviz.Digraph:
    """Creates a pipeline lineage diagram using graphviz.

    Every step becomes a node, every output artifact becomes a node with
    an edge from the producing step, and every input artifact gets an edge
    to the consuming step. The graph is rendered as a PNG and opened in a
    viewer.

    Args:
        object: The pipeline run view to visualize.
        *args: Additional arguments to pass to the visualization.
        **kwargs: Additional keyword arguments to pass to the visualization.

    Returns:
        A graphviz digraph object.
    """
    logger.warning(
        "This integration is not completed yet. Results might be unexpected."
    )
    dot = graphviz.Digraph(comment=object.name)

    # link the steps together
    for step in object.steps:
        step_node = self.STEP_PREFIX + str(step.id)

        # add each step as a node
        dot.node(
            step_node,
            step.entrypoint_name,
            shape=self.STEP_SHAPE,
        )

        # add every output artifact as a node, linked from the step
        for artifact_name, artifact in step.outputs.items():
            artifact_node = self.ARTIFACT_PREFIX + str(artifact.id)
            dot.node(
                artifact_node,
                f"{artifact_name} \n" f"({artifact._data_type})",
                shape=self.ARTIFACT_SHAPE,
            )
            dot.edge(step_node, artifact_node)

        # link every input artifact to the step that consumes it; only the
        # artifact itself is needed here, not its name
        for artifact in step.inputs.values():
            dot.edge(
                self.ARTIFACT_PREFIX + str(artifact.id),
                step_node,
            )

    # NOTE(review): the temporary file only supplies a unique file name for
    # rendering; the ".html" suffix does not match the PNG output.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".html") as f:
        dot.render(filename=f.name, format="png", view=True, cleanup=True)
    return dot
great_expectations
special
Great Expectations integration for ZenML.
The Great Expectations integration enables you to use Great Expectations as a way of profiling and validating your data.
GreatExpectationsIntegration (Integration)
Definition of Great Expectations integration for ZenML.
Source code in zenml/integrations/great_expectations/__init__.py
class GreatExpectationsIntegration(Integration):
    """Definition of Great Expectations integration for ZenML."""

    NAME = GREAT_EXPECTATIONS
    REQUIREMENTS = ["great-expectations~=0.15.11"]

    @staticmethod
    def activate() -> None:
        """Activate the Great Expectations integration."""
        # imported only for the effects of loading the module
        from zenml.integrations.great_expectations import materializers  # noqa

    @classmethod
    def flavors(cls) -> List[FlavorWrapper]:
        """Declare the stack component flavors for the Great Expectations integration.

        Returns:
            List of stack component flavors for this integration.
        """
        data_validator_flavor = FlavorWrapper(
            name=GREAT_EXPECTATIONS_DATA_VALIDATOR_FLAVOR,
            source="zenml.integrations.great_expectations.data_validators.GreatExpectationsDataValidator",
            type=StackComponentType.DATA_VALIDATOR,
            integration=cls.NAME,
        )
        return [data_validator_flavor]
activate()
staticmethod
Activate the Great Expectations integration.
Source code in zenml/integrations/great_expectations/__init__.py
@staticmethod
def activate() -> None:
"""Activate the Great Expectations integration.

Activation consists solely of importing the integration's
`materializers` sub-module.
"""
# The imported name is never used; the import is executed only for the
# effects of loading the module (hence the lint suppression).
from zenml.integrations.great_expectations import materializers # noqa
flavors()
classmethod
Declare the stack component flavors for the Great Expectations integration.
Returns:
Type | Description |
---|---|
List[zenml.zen_stores.models.flavor_wrapper.FlavorWrapper] |
List of stack component flavors for this integration. |
Source code in zenml/integrations/great_expectations/__init__.py
@classmethod
def flavors(cls) -> List[FlavorWrapper]:
    """Declare the stack component flavors for the Great Expectations integration.

    Returns:
        List of stack component flavors for this integration.
    """
    data_validator_flavor = FlavorWrapper(
        name=GREAT_EXPECTATIONS_DATA_VALIDATOR_FLAVOR,
        source="zenml.integrations.great_expectations.data_validators.GreatExpectationsDataValidator",
        type=StackComponentType.DATA_VALIDATOR,
        integration=cls.NAME,
    )
    return [data_validator_flavor]
data_validators
special
Initialization of the Great Expectations data validator for ZenML.
ge_data_validator
Implementation of the Great Expectations data validator.
GreatExpectationsDataValidator (BaseDataValidator)
pydantic-model
Great Expectations data validator stack component.
Attributes:
Name | Type | Description |
---|---|---|
context_root_dir |
Optional[str] |
location of an already initialized Great Expectations data context. If configured, the data validator will only be usable with local orchestrators. |
context_config |
Optional[Dict[str, Any]] |
in-line Great Expectations data context configuration. |
configure_zenml_stores |
bool |
if set, ZenML will automatically configure stores that use the Artifact Store as a backend. If neither `context_root_dir` nor `context_config` are set, this is the default behavior. |
configure_local_docs |
bool |
configure a local data docs site where Great Expectations docs are generated and can be visualized locally. |
Source code in zenml/integrations/great_expectations/data_validators/ge_data_validator.py
class GreatExpectationsDataValidator(BaseDataValidator):
"""Great Expectations data validator stack component.
Attributes:
context_root_dir: location of an already initialized Great Expectations
data context. If configured, the data validator will only be usable
with local orchestrators.
context_config: in-line Great Expectations data context configuration.
configure_zenml_stores: if set, ZenML will automatically configure
stores that use the Artifact Store as a backend. If neither
`context_root_dir` nor `context_config` are set, this is the default
behavior.
configure_local_docs: configure a local data docs site where Great
Expectations docs are generated and can be visualized locally.
"""
# Path of an existing local data context; made absolute and checked for
# existence by the `context_root_dir` validator.
context_root_dir: Optional[str] = None
# In-line data context configuration; a JSON/YAML string passed here is
# converted to a dict by the root validator.
context_config: Optional[Dict[str, Any]] = None
configure_zenml_stores: bool = False
configure_local_docs: bool = True
# Cached data context, created on first access of `data_context`.
_context: BaseDataContext = None
# Class Configuration
FLAVOR: ClassVar[str] = GREAT_EXPECTATIONS_DATA_VALIDATOR_FLAVOR
@validator("context_root_dir")
def _ensure_valid_context_root_dir(
    cls, context_root_dir: Optional[str] = None
) -> Optional[str]:
    """Normalize and verify the configured data context root directory.

    Args:
        context_root_dir: The context_root_dir value to validate.

    Returns:
        The absolute form of context_root_dir if it points to an existing
        path, or the value unchanged if it is empty or unset.

    Raises:
        ValueError: If the context_root_dir is not valid.
    """
    # nothing to validate when no directory was configured
    if not context_root_dir:
        return context_root_dir

    absolute_dir = os.path.abspath(context_root_dir)
    if fileio.exists(absolute_dir):
        return absolute_dir

    raise ValueError(
        f"The Great Expectations context_root_dir value doesn't "
        f"point to an existing data context path: {absolute_dir}"
    )
@root_validator(pre=True)
def _convert_context_config(cls, values: Dict[str, Any]) -> Dict[str, Any]:
    """Converts context_config from JSON/YAML string format to a dict.

    A `context_config` value that is already a dict, or that is empty or
    unset, is left untouched. Any other value is parsed as YAML (of which
    JSON is a subset) and the result is checked by building a Great
    Expectations data context from it.

    Args:
        values: Values passed to the object constructor

    Returns:
        Values passed to the object constructor

    Raises:
        ValueError: If the context_config value is not a valid JSON/YAML or
            if the GE configuration extracted from it fails GE validation.
    """
    context_config = values.get("context_config")
    if context_config and not isinstance(context_config, dict):
        try:
            context_config_dict = yaml.safe_load(context_config)
        except yaml.YAMLError as e:
            # `yaml.YAMLError` is the base class of all PyYAML errors, so
            # scanner, composer and constructor failures are reported as a
            # `ValueError` too, not only parser failures.
            raise ValueError(
                f"Malformed `context_config` value. Only JSON and YAML formats "
                f"are supported: {str(e)}"
            ) from e
        try:
            # Building a throw-away data context validates the configuration
            context_config = DataContextConfig(**context_config_dict)
            BaseDataContext(project_config=context_config)
        except Exception as e:
            raise ValueError(
                f"Invalid `context_config` value: {str(e)}"
            ) from e
        values["context_config"] = context_config_dict
    return values
@classmethod
def get_data_context(cls) -> BaseDataContext:
    """Get the Great Expectations data context managed by ZenML.

    This is a shortcut for looking up the active Great Expectations data
    validator stack component and reading its `data_context`.

    Returns:
        A Great Expectations data context managed by ZenML as configured
        through the active data validator stack component.
    """
    active_validator = cls.get_active_data_validator()
    return cast(
        "GreatExpectationsDataValidator", active_validator
    ).data_context
@property
def local_path(self) -> Optional[str]:
    """Return a local path where this component stores information.

    When an existing local GE data context is used, its root directory is
    a local path that needs to be accessible in all runtime environments.

    Returns:
        The local path where this component stores information.
    """
    configured_dir = self.context_root_dir
    return configured_dir
def get_store_config(self, class_name: str, prefix: str) -> Dict[str, Any]:
    """Generate a Great Expectations store configuration.

    The store is backed by `ZenMLArtifactStoreBackend`, with a path prefix
    made of this component's UUID followed by `prefix`.

    Args:
        class_name: The store class name
        prefix: The path prefix for the ZenML store configuration

    Returns:
        A dictionary with the GE store configuration.
    """
    backend_config = {
        "module_name": ZenMLArtifactStoreBackend.__module__,
        "class_name": ZenMLArtifactStoreBackend.__name__,
        "prefix": f"{str(self.uuid)}/{prefix}",
    }
    return {"class_name": class_name, "store_backend": backend_config}
def get_data_docs_config(
    self, prefix: str, local: bool = False
) -> Dict[str, Any]:
    """Generate Great Expectations data docs configuration.

    A local site is stored on the filesystem below `root_directory`; a
    remote site is stored through `ZenMLArtifactStoreBackend` below this
    component's UUID.

    Args:
        prefix: The path prefix for the ZenML data docs configuration
        local: Whether the data docs site is local or remote.

    Returns:
        A dictionary with the GE data docs site configuration.
    """
    if not local:
        backend_config = {
            "module_name": ZenMLArtifactStoreBackend.__module__,
            "class_name": ZenMLArtifactStoreBackend.__name__,
            "prefix": f"{str(self.uuid)}/{prefix}",
        }
    else:
        backend_config = {
            "class_name": "TupleFilesystemStoreBackend",
            "base_directory": f"{self.root_directory}/{prefix}",
        }

    site_config = {
        "class_name": "SiteBuilder",
        "store_backend": backend_config,
        "site_index_builder": {
            "class_name": "DefaultSiteIndexBuilder",
        },
    }
    return site_config
@property
def data_context(self) -> BaseDataContext:
"""Returns the Great Expectations data context configured for this component.
The context is created the first time this property is read and is then
cached in `_context`; later reads return the cached object.
Returns:
The Great Expectations data context configured for this component.
"""
if not self._context:
# Names under which the ZenML managed stores are registered
expectations_store_name = "zenml_expectations_store"
validations_store_name = "zenml_validations_store"
checkpoint_store_name = "zenml_checkpoint_store"
profiler_store_name = "zenml_profiler_store"
evaluation_parameter_store_name = "evaluation_parameter_store"
# Data context configuration in which the expectations, validations,
# checkpoint and profiler stores as well as the data docs site are built
# with `get_store_config` / `get_data_docs_config`
zenml_context_config = dict(
stores={
expectations_store_name: self.get_store_config(
"ExpectationsStore", "expectations"
),
validations_store_name: self.get_store_config(
"ValidationsStore", "validations"
),
checkpoint_store_name: self.get_store_config(
"CheckpointStore", "checkpoints"
),
profiler_store_name: self.get_store_config(
"ProfilerStore", "profilers"
),
evaluation_parameter_store_name: {
"class_name": "EvaluationParameterStore"
},
},
expectations_store_name=expectations_store_name,
validations_store_name=validations_store_name,
checkpoint_store_name=checkpoint_store_name,
profiler_store_name=profiler_store_name,
evaluation_parameter_store_name=evaluation_parameter_store_name,
data_docs_sites={
"zenml_artifact_store": self.get_data_docs_config(
"data_docs"
)
},
)
# Local copy of the flag, because it may be switched off below without
# touching the configured attribute
configure_zenml_stores = self.configure_zenml_stores
if self.context_root_dir:
# initialize the local data context, if a local path was
# configured
self._context = DataContext(self.context_root_dir)
else:
# create an in-memory data context configuration that is not
# backed by a local YAML file (see https://docs.greatexpectations.io/docs/guides/setup/configuring_data_contexts/how_to_instantiate_a_data_context_without_a_yml_file/).
if self.context_config:
context_config = DataContextConfig(**self.context_config)
else:
context_config = DataContextConfig(**zenml_context_config)
# skip adding the stores after initialization, as they are
# already baked in the initial configuration
configure_zenml_stores = False
self._context = BaseDataContext(project_config=context_config)
if configure_zenml_stores:
# Point the context at the ZenML store names, then register the stores
# and data docs sites from the ZenML configuration built above
self._context.config.expectations_store_name = (
expectations_store_name
)
self._context.config.validations_store_name = (
validations_store_name
)
self._context.config.checkpoint_store_name = (
checkpoint_store_name
)
self._context.config.profiler_store_name = profiler_store_name
self._context.config.evaluation_parameter_store_name = (
evaluation_parameter_store_name
)
for store_name, store_config in zenml_context_config[ # type: ignore[attr-defined]
"stores"
].items():
self._context.add_store(
store_name=store_name,
store_config=store_config,
)
for site_name, site_config in zenml_context_config[ # type: ignore[attr-defined]
"data_docs_sites"
].items():
self._context.config.data_docs_sites[
site_name
] = site_config
if self.configure_local_docs:
repo = Repository(skip_repository_check=True) # type: ignore[call-arg]
artifact_store = repo.active_stack.artifact_store
# The extra local docs site is only added when the active artifact
# store is not of the "local" flavor
if artifact_store.FLAVOR != "local":
self._context.config.data_docs_sites[
"zenml_local"
] = self.get_data_docs_config("data_docs", local=True)
return self._context
@property
def root_directory(self) -> str:
    """Returns path to the root directory for all local files concerning this data validator.

    The directory lives below the global config directory, in a folder
    named after the flavor and this component's UUID. It is created if it
    does not exist yet.

    Returns:
        Path to the root directory.
    """
    validator_dir = os.path.join(
        io_utils.get_global_config_directory(),
        self.FLAVOR,
        str(self.uuid),
    )
    if os.path.exists(validator_dir):
        return validator_dir

    fileio.makedirs(validator_dir)
    return validator_dir
def data_profiling(
    self,
    dataset: pd.DataFrame,
    comparison_dataset: Optional[Any] = None,
    profile_list: Optional[Sequence[str]] = None,
    expectation_suite_name: Optional[str] = None,
    data_asset_name: Optional[str] = None,
    profiler_kwargs: Optional[Dict[str, Any]] = None,
    overwrite_existing_suite: bool = True,
    **kwargs: Any,
) -> ExpectationSuite:
    """Infer a Great Expectations Expectation Suite from a given dataset.

    This Great Expectations specific data profiling method implementation
    builds an Expectation Suite automatically by running a
    UserConfigurableProfiler on an input dataset [as covered in the official
    GE documentation](https://docs.greatexpectations.io/docs/guides/expectations/how_to_create_and_edit_expectations_with_a_profiler).

    Args:
        dataset: The dataset from which the expectation suite will be
            inferred.
        comparison_dataset: Optional dataset used to generate data
            comparison (i.e. data drift) profiles. Not supported by the
            Great Expectations data validator.
        profile_list: Optional list identifying the categories of data
            profiles to be generated. Not supported by the Great
            Expectations data validator.
        expectation_suite_name: The name of the expectation suite to create
            or update. If not supplied, a unique name will be generated from
            the current pipeline and step name, if running in the context of
            a pipeline step.
        data_asset_name: The name of the data asset to use to identify the
            dataset in the Great Expectations docs.
        profiler_kwargs: A dictionary of custom keyword arguments to pass to
            the profiler. `None` is treated like an empty dictionary.
        overwrite_existing_suite: Whether to overwrite an existing
            expectation suite, if one exists with that name.
        kwargs: Additional keyword arguments (unused).

    Returns:
        The inferred Expectation Suite.

    Raises:
        ValueError: if an `expectation_suite_name` value is not supplied and
            a name for the expectation suite cannot be generated from the
            current step name and pipeline name.
    """
    context = self.data_context

    if comparison_dataset is not None:
        logger.warning(
            "A comparison dataset is not required by Great Expectations "
            "to do data profiling. Silently ignoring the supplied dataset "
        )

    if not expectation_suite_name:
        try:
            # get pipeline name and step name
            step_env = cast(
                StepEnvironment, Environment()[STEP_ENVIRONMENT_NAME]
            )
            pipeline_name = step_env.pipeline_name
            step_name = step_env.step_name
            expectation_suite_name = f"{pipeline_name}_{step_name}"
        except KeyError:
            raise ValueError(
                "A expectation suite name is required when not running in "
                "the context of a pipeline step."
            )

    suite_exists = False
    if context.expectations_store.has_key(  # noqa
        ExpectationSuiteIdentifier(expectation_suite_name)
    ):
        suite_exists = True
        suite = context.get_expectation_suite(expectation_suite_name)
        if not overwrite_existing_suite:
            logger.info(
                f"Expectation Suite `{expectation_suite_name}` "
                f"already exists and `overwrite_existing_suite` is not set "
                f"in the step configuration. Skipping re-running the "
                f"profiler."
            )
            return suite

    batch_request = create_batch_request(context, dataset, data_asset_name)

    try:
        if suite_exists:
            validator = context.get_validator(
                batch_request=batch_request,
                expectation_suite_name=expectation_suite_name,
            )
        else:
            validator = context.get_validator(
                batch_request=batch_request,
                create_expectation_suite_with_name=expectation_suite_name,
            )

        # `profiler_kwargs` defaults to None, and unpacking None with `**`
        # raises a TypeError, so fall back to an empty mapping.
        profiler = UserConfigurableProfiler(
            profile_dataset=validator, **(profiler_kwargs or {})
        )

        suite = profiler.build_suite()
        context.save_expectation_suite(
            expectation_suite=suite,
            expectation_suite_name=expectation_suite_name,
        )

        context.build_data_docs()
    finally:
        # the datasource created for the batch request is always removed,
        # whether or not profiling succeeded
        context.delete_datasource(batch_request.datasource_name)

    return suite
def data_validation(
    self,
    dataset: pd.DataFrame,
    comparison_dataset: Optional[Any] = None,
    check_list: Optional[Sequence[str]] = None,
    expectation_suite_name: Optional[str] = None,
    data_asset_name: Optional[str] = None,
    action_list: Optional[List[Dict[str, Any]]] = None,
    **kwargs: Any,
) -> CheckpointResult:
    """Great Expectations data validation.

    Validates an input dataset against an Expectation Suite (the GE
    definition of a profile) by registering a temporary checkpoint and
    running it, [as covered in the official GE
    documentation](https://docs.greatexpectations.io/docs/guides/validation/how_to_validate_data_by_running_a_checkpoint).

    Args:
        dataset: The dataset to validate.
        comparison_dataset: Optional dataset used to run data
            comparison (i.e. data drift) checks. Not supported by the
            Great Expectations data validator.
        check_list: Optional list identifying the data validation checks to
            be performed. Not supported by the Great Expectations data
            validator.
        expectation_suite_name: The name of the expectation suite to use to
            validate the dataset. A value must be provided.
        data_asset_name: The name of the data asset to use to identify the
            dataset in the Great Expectations docs.
        action_list: A list of additional Great Expectations actions to run
            after the validation check.
        kwargs: Additional keyword arguments (unused).

    Returns:
        The Great Expectations validation (checkpoint) result.

    Raises:
        ValueError: if the `expectation_suite_name` argument is omitted.
    """
    if not expectation_suite_name:
        raise ValueError("Missing expectation_suite_name argument value.")

    if comparison_dataset is not None:
        logger.warning(
            "A comparison dataset is not required by Great Expectations "
            "to do data validation. Silently ignoring the supplied dataset "
        )

    try:
        # run id and step name come from the current step environment
        step_env = cast(
            StepEnvironment, Environment()[STEP_ENVIRONMENT_NAME]
        )
        run_id = step_env.pipeline_run_id
        step_name = step_env.step_name
    except KeyError:
        # outside of a pipeline step, random values are used instead
        run_id = f"pipeline_{random_str(5)}"
        step_name = f"step_{random_str(5)}"

    context = self.data_context
    checkpoint_name = f"{run_id}_{step_name}"
    batch_request = create_batch_request(context, dataset, data_asset_name)

    # actions that run when the caller does not supply any
    default_actions = [
        {
            "name": "store_validation_result",
            "action": {"class_name": "StoreValidationResultAction"},
        },
        {
            "name": "store_evaluation_params",
            "action": {"class_name": "StoreEvaluationParametersAction"},
        },
        {
            "name": "update_data_docs",
            "action": {"class_name": "UpdateDataDocsAction"},
        },
    ]
    action_list = action_list or default_actions

    context.add_checkpoint(
        name=checkpoint_name,
        run_name_template=f"{run_id}",
        config_version=1,
        class_name="Checkpoint",
        expectation_suite_name=expectation_suite_name,
        action_list=action_list,
    )

    try:
        results = context.run_checkpoint(
            checkpoint_name=checkpoint_name,
            validations=[{"batch_request": batch_request}],
        )
    finally:
        # the temporary datasource and checkpoint are removed even if the
        # checkpoint run fails
        context.delete_datasource(batch_request.datasource_name)
        context.delete_checkpoint(checkpoint_name)

    return results
data_context: BaseDataContext
property
readonly
Returns the Great Expectations data context configured for this component.
Returns:
Type | Description |
---|---|
BaseDataContext |
The Great Expectations data context configured for this component. |
local_path: Optional[str]
property
readonly
Return a local path where this component stores information.
If an existing local GE data context is used, it is interpreted as a local path that needs to be accessible in all runtime environments.
Returns:
Type | Description |
---|---|
Optional[str] |
The local path where this component stores information. |
root_directory: str
property
readonly
Returns path to the root directory for all local files concerning this data validator.
Returns:
Type | Description |
---|---|
str |
Path to the root directory. |
data_profiling(self, dataset, comparison_dataset=None, profile_list=None, expectation_suite_name=None, data_asset_name=None, profiler_kwargs=None, overwrite_existing_suite=True, **kwargs)
Infer a Great Expectations Expectation Suite from a given dataset.
This Great Expectations specific data profiling method implementation builds an Expectation Suite automatically by running a UserConfigurableProfiler on an input dataset as covered in the official GE documentation.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
dataset |
DataFrame |
The dataset from which the expectation suite will be inferred. |
required |
comparison_dataset |
Optional[Any] |
Optional dataset used to generate data comparison (i.e. data drift) profiles. Not supported by the Great Expectation data validator. |
None |
profile_list |
Optional[Sequence[str]] |
Optional list identifying the categories of data profiles to be generated. Not supported by the Great Expectation data validator. |
None |
expectation_suite_name |
Optional[str] |
The name of the expectation suite to create or update. If not supplied, a unique name will be generated from the current pipeline and step name, if running in the context of a pipeline step. |
None |
data_asset_name |
Optional[str] |
The name of the data asset to use to identify the dataset in the Great Expectations docs. |
None |
profiler_kwargs |
Optional[Dict[str, Any]] |
A dictionary of custom keyword arguments to pass to the profiler. |
None |
overwrite_existing_suite |
bool |
Whether to overwrite an existing expectation suite, if one exists with that name. |
True |
kwargs |
Any |
Additional keyword arguments (unused). |
{} |
Returns:
Type | Description |
---|---|
ExpectationSuite |
The inferred Expectation Suite. |
Exceptions:
Type | Description |
---|---|
ValueError |
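if an `expectation_suite_name` value is not supplied and a name for the expectation suite cannot be generated from the current step name and pipeline name. |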
Source code in zenml/integrations/great_expectations/data_validators/ge_data_validator.py
def data_profiling(
    self,
    dataset: pd.DataFrame,
    comparison_dataset: Optional[Any] = None,
    profile_list: Optional[Sequence[str]] = None,
    expectation_suite_name: Optional[str] = None,
    data_asset_name: Optional[str] = None,
    profiler_kwargs: Optional[Dict[str, Any]] = None,
    overwrite_existing_suite: bool = True,
    **kwargs: Any,
) -> ExpectationSuite:
    """Infer a Great Expectation Expectation Suite from a given dataset.

    This Great Expectations specific data profiling method implementation
    builds an Expectation Suite automatically by running a
    UserConfigurableProfiler on an input dataset [as covered in the official
    GE documentation](https://docs.greatexpectations.io/docs/guides/expectations/how_to_create_and_edit_expectations_with_a_profiler).

    Args:
        dataset: The dataset from which the expectation suite will be
            inferred.
        comparison_dataset: Optional dataset used to generate data
            comparison (i.e. data drift) profiles. Not supported by the
            Great Expectation data validator.
        profile_list: Optional list identifying the categories of data
            profiles to be generated. Not supported by the Great Expectation
            data validator.
        expectation_suite_name: The name of the expectation suite to create
            or update. If not supplied, a unique name will be generated from
            the current pipeline and step name, if running in the context of
            a pipeline step.
        data_asset_name: The name of the data asset to use to identify the
            dataset in the Great Expectations docs.
        profiler_kwargs: A dictionary of custom keyword arguments to pass to
            the profiler. `None` is treated as an empty dictionary.
        overwrite_existing_suite: Whether to overwrite an existing
            expectation suite, if one exists with that name.
        kwargs: Additional keyword arguments (unused).

    Returns:
        The inferred Expectation Suite.

    Raises:
        ValueError: if an `expectation_suite_name` value is not supplied and
            a name for the expectation suite cannot be generated from the
            current step name and pipeline name.
    """
    context = self.data_context

    if comparison_dataset is not None:
        logger.warning(
            "A comparison dataset is not required by Great Expectations "
            "to do data profiling. Silently ignoring the supplied dataset "
        )

    if not expectation_suite_name:
        try:
            # get pipeline name and step name
            step_env = cast(
                StepEnvironment, Environment()[STEP_ENVIRONMENT_NAME]
            )
            pipeline_name = step_env.pipeline_name
            step_name = step_env.step_name
            expectation_suite_name = f"{pipeline_name}_{step_name}"
        except KeyError as e:
            # Chain the lookup failure so the traceback shows why no name
            # could be derived.
            raise ValueError(
                "A expectation suite name is required when not running in "
                "the context of a pipeline step."
            ) from e

    suite_exists = False
    if context.expectations_store.has_key(  # noqa
        ExpectationSuiteIdentifier(expectation_suite_name)
    ):
        suite_exists = True
        suite = context.get_expectation_suite(expectation_suite_name)
        if not overwrite_existing_suite:
            logger.info(
                f"Expectation Suite `{expectation_suite_name}` "
                f"already exists and `overwrite_existing_suite` is not set "
                f"in the step configuration. Skipping re-running the "
                f"profiler."
            )
            return suite

    batch_request = create_batch_request(context, dataset, data_asset_name)

    try:
        if suite_exists:
            validator = context.get_validator(
                batch_request=batch_request,
                expectation_suite_name=expectation_suite_name,
            )
        else:
            validator = context.get_validator(
                batch_request=batch_request,
                create_expectation_suite_with_name=expectation_suite_name,
            )

        # `profiler_kwargs` defaults to None; unpacking None with `**`
        # raises a TypeError, so fall back to an empty mapping.
        profiler = UserConfigurableProfiler(
            profile_dataset=validator, **(profiler_kwargs or {})
        )

        suite = profiler.build_suite()
        context.save_expectation_suite(
            expectation_suite=suite,
            expectation_suite_name=expectation_suite_name,
        )

        context.build_data_docs()
    finally:
        # Always drop the datasource referenced by the batch request, even
        # when profiling fails.
        context.delete_datasource(batch_request.datasource_name)

    return suite
data_validation(self, dataset, comparison_dataset=None, check_list=None, expectation_suite_name=None, data_asset_name=None, action_list=None, **kwargs)
Great Expectations data validation.
This Great Expectations specific data validation method implementation validates an input dataset against an Expectation Suite (the GE definition of a profile) as covered in the official GE documentation.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
dataset |
DataFrame |
The dataset to validate. |
required |
comparison_dataset |
Optional[Any] |
Optional dataset used to run data comparison (i.e. data drift) checks. Not supported by the Great Expectation data validator. |
None |
check_list |
Optional[Sequence[str]] |
Optional list identifying the data validation checks to be performed. Not supported by the Great Expectations data validator. |
None |
expectation_suite_name |
Optional[str] |
The name of the expectation suite to use to validate the dataset. A value must be provided. |
None |
data_asset_name |
Optional[str] |
The name of the data asset to use to identify the dataset in the Great Expectations docs. |
None |
action_list |
Optional[List[Dict[str, Any]]] |
A list of additional Great Expectations actions to run after the validation check. |
None |
kwargs |
Any |
Additional keyword arguments (unused). |
{} |
Returns:
Type | Description |
---|---|
CheckpointResult |
The Great Expectations validation (checkpoint) result. |
Exceptions:
Type | Description |
---|---|
ValueError |
if the `expectation_suite_name` argument is omitted. |
Source code in zenml/integrations/great_expectations/data_validators/ge_data_validator.py
def data_validation(
    self,
    dataset: pd.DataFrame,
    comparison_dataset: Optional[Any] = None,
    check_list: Optional[Sequence[str]] = None,
    expectation_suite_name: Optional[str] = None,
    data_asset_name: Optional[str] = None,
    action_list: Optional[List[Dict[str, Any]]] = None,
    **kwargs: Any,
) -> CheckpointResult:
    """Great Expectations data validation.

    This Great Expectations specific data validation method
    implementation validates an input dataset against an Expectation Suite
    (the GE definition of a profile) [as covered in the official GE
    documentation](https://docs.greatexpectations.io/docs/guides/validation/how_to_validate_data_by_running_a_checkpoint).

    Args:
        dataset: The dataset to validate.
        comparison_dataset: Optional dataset used to run data
            comparison (i.e. data drift) checks. Not supported by the
            Great Expectation data validator.
        check_list: Optional list identifying the data validation checks to
            be performed. Not supported by the Great Expectations data
            validator.
        expectation_suite_name: The name of the expectation suite to use to
            validate the dataset. A value must be provided.
        data_asset_name: The name of the data asset to use to identify the
            dataset in the Great Expectations docs.
        action_list: A list of additional Great Expectations actions to run
            after the validation check. When omitted or empty, a default
            list (store validation result, store evaluation parameters,
            update data docs) is used.
        kwargs: Additional keyword arguments (unused).

    Returns:
        The Great Expectations validation (checkpoint) result.

    Raises:
        ValueError: if the `expectation_suite_name` argument is omitted.
    """
    if not expectation_suite_name:
        raise ValueError("Missing expectation_suite_name argument value.")

    if comparison_dataset is not None:
        logger.warning(
            "A comparison dataset is not required by Great Expectations "
            "to do data validation. Silently ignoring the supplied dataset "
        )

    try:
        # get pipeline name, step name and run id
        step_env = cast(
            StepEnvironment, Environment()[STEP_ENVIRONMENT_NAME]
        )
        run_id = step_env.pipeline_run_id
        step_name = step_env.step_name
    except KeyError:
        # if not running inside a pipeline step, use random values
        run_id = f"pipeline_{random_str(5)}"
        step_name = f"step_{random_str(5)}"

    context = self.data_context
    checkpoint_name = f"{run_id}_{step_name}"

    batch_request = create_batch_request(context, dataset, data_asset_name)

    # The outer try guarantees the datasource referenced by the batch
    # request is deleted even if registering the checkpoint fails; the inner
    # try guarantees the checkpoint is deleted once it has been added. The
    # two cleanups are independent, so one failing no longer skips the other
    # resource's `finally`.
    try:
        action_list = action_list or [
            {
                "name": "store_validation_result",
                "action": {"class_name": "StoreValidationResultAction"},
            },
            {
                "name": "store_evaluation_params",
                "action": {"class_name": "StoreEvaluationParametersAction"},
            },
            {
                "name": "update_data_docs",
                "action": {"class_name": "UpdateDataDocsAction"},
            },
        ]

        checkpoint_config = {
            "name": checkpoint_name,
            "run_name_template": f"{run_id}",
            "config_version": 1,
            "class_name": "Checkpoint",
            "expectation_suite_name": expectation_suite_name,
            "action_list": action_list,
        }
        context.add_checkpoint(**checkpoint_config)

        try:
            results = context.run_checkpoint(
                checkpoint_name=checkpoint_name,
                validations=[{"batch_request": batch_request}],
            )
        finally:
            context.delete_checkpoint(checkpoint_name)
    finally:
        context.delete_datasource(batch_request.datasource_name)

    return results
get_data_context()
classmethod
Get the Great Expectations data context managed by ZenML.
Call this method to retrieve the data context managed by ZenML through the active Great Expectations data validator stack component.
Returns:
Type | Description |
---|---|
BaseDataContext |
A Great Expectations data context managed by ZenML as configured through the active data validator stack component. |
Source code in zenml/integrations/great_expectations/data_validators/ge_data_validator.py
@classmethod
def get_data_context(cls) -> BaseDataContext:
    """Return the data context of the active data validator.

    The active data validator is looked up through
    `get_active_data_validator()` and treated as a
    `GreatExpectationsDataValidator`; its `data_context` attribute is
    returned.

    Returns:
        A Great Expectations data context managed by ZenML as configured
        through the active data validator stack component.
    """
    active_validator = cls.get_active_data_validator()
    ge_validator = cast("GreatExpectationsDataValidator", active_validator)
    return ge_validator.data_context
get_data_docs_config(self, prefix, local=False)
Generate Great Expectations data docs configuration.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
prefix |
str |
The path prefix for the ZenML data docs configuration |
required |
local |
bool |
Whether the data docs site is local or remote. |
False |
Returns:
Type | Description |
---|---|
Dict[str, Any] |
A dictionary with the GE data docs site configuration. |
Source code in zenml/integrations/great_expectations/data_validators/ge_data_validator.py
def get_data_docs_config(
    self, prefix: str, local: bool = False
) -> Dict[str, Any]:
    """Build the configuration dictionary for a GE data docs site.

    Args:
        prefix: The path prefix for the ZenML data docs configuration
        local: Whether the data docs site is local or remote. A local site
            is stored under `root_directory`; a remote site uses the
            `ZenMLArtifactStoreBackend`, prefixed with this component's
            UUID.

    Returns:
        A dictionary with the GE data docs site configuration.
    """
    backend_config: Dict[str, Any]
    if not local:
        backend_config = {
            "module_name": ZenMLArtifactStoreBackend.__module__,
            "class_name": ZenMLArtifactStoreBackend.__name__,
            "prefix": f"{str(self.uuid)}/{prefix}",
        }
    else:
        backend_config = {
            "class_name": "TupleFilesystemStoreBackend",
            "base_directory": f"{self.root_directory}/{prefix}",
        }

    index_builder = {"class_name": "DefaultSiteIndexBuilder"}
    site_config: Dict[str, Any] = {
        "class_name": "SiteBuilder",
        "store_backend": backend_config,
        "site_index_builder": index_builder,
    }
    return site_config
get_store_config(self, class_name, prefix)
Generate a Great Expectations store configuration.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
class_name |
str |
The store class name |
required |
prefix |
str |
The path prefix for the ZenML store configuration |
required |
Returns:
Type | Description |
---|---|
Dict[str, Any] |
A dictionary with the GE store configuration. |
Source code in zenml/integrations/great_expectations/data_validators/ge_data_validator.py
def get_store_config(self, class_name: str, prefix: str) -> Dict[str, Any]:
    """Build the configuration dictionary for a Great Expectations store.

    The store is always backed by the `ZenMLArtifactStoreBackend`, with a
    prefix made of this component's UUID followed by `prefix`.

    Args:
        class_name: The store class name
        prefix: The path prefix for the ZenML store configuration

    Returns:
        A dictionary with the GE store configuration.
    """
    backend_cls = ZenMLArtifactStoreBackend
    backend_config = {
        "module_name": backend_cls.__module__,
        "class_name": backend_cls.__name__,
        "prefix": f"{str(self.uuid)}/{prefix}",
    }
    return {"class_name": class_name, "store_backend": backend_config}
ge_store_backend
Great Expectations store plugin for ZenML.
ZenMLArtifactStoreBackend (TupleStoreBackend)
Great Expectations store backend that uses the active ZenML Artifact Store as a store.
Source code in zenml/integrations/great_expectations/ge_store_backend.py
class ZenMLArtifactStoreBackend(TupleStoreBackend):  # type: ignore[misc]
    """Great Expectations store backend that uses the active ZenML Artifact Store as a store."""

    def __init__(
        self,
        prefix: str = "",
        **kwargs: Any,
    ) -> None:
        """Create a Great Expectations ZenML store backend instance.

        Args:
            prefix: Subpath prefix to use for this store backend.
            kwargs: Additional keyword arguments passed by the Great Expectations
                core. These are transparently passed to the `TupleStoreBackend`
                constructor.
        """
        super().__init__(**kwargs)

        repo = Repository(skip_repository_check=True)  # type: ignore[call-arg]
        artifact_store = repo.active_stack.artifact_store
        self.root_path = os.path.join(artifact_store.path, "great_expectations")

        # extract the protocol used in the artifact store root path
        protocols = [
            scheme
            for scheme in artifact_store.SUPPORTED_SCHEMES
            if self.root_path.startswith(scheme)
        ]
        self.proto = protocols[0] if protocols else ""

        if prefix:
            if self.platform_specific_separator:
                prefix = prefix.strip(os.sep)
            prefix = prefix.strip("/")
        self.prefix = prefix

        # Initialize with store_backend_id if not part of an HTMLSiteStore
        if not self._suppress_store_backend_id:
            _ = self.store_backend_id

        self._config = {
            "prefix": prefix,
            "module_name": self.__class__.__module__,
            "class_name": self.__class__.__name__,
        }
        self._config.update(kwargs)
        filter_properties_dict(
            properties=self._config, clean_falsy=True, inplace=True
        )

    def _build_object_path(
        self, key: Tuple[str, ...], is_prefix: bool = False
    ) -> str:
        """Build a filepath corresponding to an object key.

        Args:
            key: Great Expectation object key.
            is_prefix: If True, the key will be interpreted as a prefix instead
                of a full key identifier.

        Returns:
            The file path pointing to where the object is stored.
        """
        if not isinstance(key, tuple):
            key = key.to_tuple()  # type: ignore[attr-defined]

        if not is_prefix:
            object_relative_path = self._convert_key_to_filepath(key)
        elif key:
            object_relative_path = os.path.join(*key)
        else:
            object_relative_path = ""

        if self.prefix:
            object_key = os.path.join(self.prefix, object_relative_path)
        else:
            object_key = object_relative_path
        return os.path.join(self.root_path, object_key)

    def _get(self, key: Tuple[str, ...]) -> str:
        """Get the value of an object from the store.

        Args:
            key: object key identifier.

        Raises:
            InvalidKeyError: if the key doesn't point to an existing object.

        Returns:
            str: the object's contents, with trailing newlines stripped.
        """
        filepath: str = self._build_object_path(key)
        if not fileio.exists(filepath):
            raise InvalidKeyError(
                f"Unable to retrieve object from {self.__class__.__name__} with "
                f"the following Key: {str(filepath)}"
            )
        contents: str = io_utils.read_file_contents_as_string(
            filepath
        ).rstrip("\n")
        return contents

    def _set(self, key: Tuple[str, ...], value: str, **kwargs: Any) -> str:
        """Set the value of an object in the store.

        Args:
            key: object key identifier.
            value: object value to set. String values are written UTF-8
                encoded; any other value is written as-is.
            kwargs: additional keyword arguments (ignored).

        Returns:
            The file path where the object was stored.
        """
        filepath: str = self._build_object_path(key)
        if not io_utils.is_remote(filepath):
            # Only local paths get their parent directory created here.
            parent_dir = str(Path(filepath).parent)
            os.makedirs(parent_dir, exist_ok=True)

        with fileio.open(filepath, "wb") as outfile:
            if isinstance(value, str):
                outfile.write(value.encode("utf-8"))
            else:
                outfile.write(value)
        return filepath

    def _move(
        self,
        source_key: Tuple[str, ...],
        dest_key: Tuple[str, ...],
        **kwargs: Any,
    ) -> None:
        """Associate an object with a different key in the store.

        Nothing happens if the source object does not exist.

        Args:
            source_key: current object key identifier.
            dest_key: new object key identifier.
            kwargs: additional keyword arguments (ignored).
        """
        source_path = self._build_object_path(source_key)
        dest_path = self._build_object_path(dest_key)

        if fileio.exists(source_path):
            if not io_utils.is_remote(dest_path):
                parent_dir = str(Path(dest_path).parent)
                os.makedirs(parent_dir, exist_ok=True)
            fileio.rename(source_path, dest_path, overwrite=True)

    def list_keys(self, prefix: Tuple[str, ...] = ()) -> List[Tuple[str, ...]]:
        """List the keys of all objects identified by a partial key.

        Args:
            prefix: partial object key identifier.

        Returns:
            List of keys identifying all objects present in the store that
            match the input partial key.
        """
        key_list = []
        list_path = self._build_object_path(prefix, is_prefix=True)
        root_path = self._build_object_path(tuple(), is_prefix=True)

        for root, _dirs, files in fileio.walk(list_path):
            for file_ in files:
                # Keys are derived from the path relative to the store root,
                # not relative to the listed prefix.
                filepath = os.path.relpath(
                    os.path.join(str(root), str(file_)), root_path
                )

                if self.filepath_prefix and not filepath.startswith(
                    self.filepath_prefix
                ):
                    continue
                if self.filepath_suffix and not filepath.endswith(
                    self.filepath_suffix
                ):
                    continue

                key = self._convert_filepath_to_key(filepath)
                if key and not self.is_ignored_key(key):
                    key_list.append(key)
        return key_list

    def remove_key(self, key: Tuple[str, ...]) -> bool:
        """Delete an object from the store.

        For local paths, parent directories left empty by the removal are
        pruned up to (but not including) the store root path.

        Args:
            key: object key identifier.

        Returns:
            True if the object existed in the store and was removed, otherwise
            False.
        """
        filepath: str = self._build_object_path(key)

        if fileio.exists(filepath):
            fileio.remove(filepath)
            if not io_utils.is_remote(filepath):
                parent_dir = str(Path(filepath).parent)
                self.rrmdir(self.root_path, str(parent_dir))
            return True
        return False

    def _has_key(self, key: Tuple[str, ...]) -> bool:
        """Check if an object is present in the store.

        Args:
            key: object key identifier.

        Returns:
            True if the object is present in the store, otherwise False.
        """
        filepath: str = self._build_object_path(key)
        result = fileio.exists(filepath)
        return result

    def get_url_for_key(
        self, key: Tuple[str, ...], protocol: Optional[str] = None
    ) -> str:
        """Get the URL of an object in the store.

        Args:
            key: object key identifier.
            protocol: optional protocol to use instead of the store protocol.
                Defaults to `file:` for local paths.

        Returns:
            The URL of the object in the store.
        """
        filepath = self._build_object_path(key)
        if not protocol and not io_utils.is_remote(filepath):
            protocol = "file:"
        if protocol:
            # Only the first occurrence of the store protocol is swapped;
            # with an empty store protocol this prepends the new one.
            filepath = filepath.replace(self.proto, f"{protocol}//", 1)
        return filepath

    def get_public_url_for_key(
        self, key: str, protocol: Optional[str] = None
    ) -> str:
        """Get the public URL of an object in the store.

        Args:
            key: object key identifier.
            protocol: optional protocol to use instead of the store protocol
                (currently not used by this implementation).

        Returns:
            The public URL where the object can be accessed.

        Raises:
            StoreBackendError: if a `base_public_path` attribute was not
                configured for the store.
        """
        if not self.base_public_path:
            raise StoreBackendError(
                f"Error: No base_public_path was configured! A public URL was "
                f"requested but `base_public_path` was not configured for the "
                f"{self.__class__.__name__}"
            )
        filepath = self._convert_key_to_filepath(key)
        public_url = self.base_public_path + filepath.replace(self.proto, "")
        return cast(str, public_url)

    @staticmethod
    def rrmdir(start_path: str, end_path: str) -> None:
        """Remove empty directories from end_path up towards start_path.

        Starting at `end_path`, each empty directory is removed and the walk
        continues with its parent. The walk stops at the first non-empty
        directory or when `start_path` is reached; `start_path` itself is
        never removed.

        Args:
            start_path: Directory at which the upward walk stops (kept).
            end_path: Deepest directory, where the upward walk begins.
        """
        # Normalize both paths so that e.g. a trailing separator on
        # `start_path` cannot make the loop miss it and keep deleting
        # ancestors above it.
        start_path = os.path.normpath(start_path)
        end_path = os.path.normpath(end_path)
        while end_path != start_path and not os.listdir(end_path):
            os.rmdir(end_path)
            end_path = os.path.dirname(end_path)

    @property
    def config(self) -> Dict[str, Any]:
        """Get the store configuration.

        Returns:
            The store configuration.
        """
        return self._config
config: Dict[str, Any]
property
readonly
Get the store configuration.
Returns:
Type | Description |
---|---|
Dict[str, Any] |
The store configuration. |
__init__(self, prefix='', **kwargs)
special
Create a Great Expectations ZenML store backend instance.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
prefix |
str |
Subpath prefix to use for this store backend. |
'' |
kwargs |
Any |
Additional keyword arguments passed by the Great Expectations
core. These are transparently passed to the `TupleStoreBackend` constructor. |
{} |
Source code in zenml/integrations/great_expectations/ge_store_backend.py
def __init__(
    self,
    prefix: str = "",
    **kwargs: Any,
) -> None:
    """Set up a store backend rooted in the active ZenML artifact store.

    Objects are stored under the `great_expectations` sub-path of the
    active stack's artifact store path.

    Args:
        prefix: Subpath prefix to use for this store backend.
        kwargs: Additional keyword arguments passed by the Great Expectations
            core. These are transparently passed to the `TupleStoreBackend`
            constructor.
    """
    super().__init__(**kwargs)

    active_store = Repository(  # type: ignore[call-arg]
        skip_repository_check=True
    ).active_stack.artifact_store
    self.root_path = os.path.join(active_store.path, "great_expectations")

    # The store protocol is the first supported scheme that the root path
    # starts with, or an empty string when none matches.
    self.proto = next(
        (
            scheme
            for scheme in active_store.SUPPORTED_SCHEMES
            if self.root_path.startswith(scheme)
        ),
        "",
    )

    if prefix:
        if self.platform_specific_separator:
            prefix = prefix.strip(os.sep)
        prefix = prefix.strip("/")
    self.prefix = prefix

    # Initialize with store_backend_id if not part of an HTMLSiteStore
    if not self._suppress_store_backend_id:
        _ = self.store_backend_id

    backend_cls = self.__class__
    self._config = {
        "prefix": prefix,
        "module_name": backend_cls.__module__,
        "class_name": backend_cls.__name__,
    }
    self._config.update(kwargs)
    filter_properties_dict(
        properties=self._config, clean_falsy=True, inplace=True
    )
get_public_url_for_key(self, key, protocol=None)
Get the public URL of an object in the store.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
key |
str |
object key identifier. |
required |
protocol |
Optional[str] |
optional protocol to use instead of the store protocol. |
None |
Returns:
Type | Description |
---|---|
str |
The public URL where the object can be accessed. |
Exceptions:
Type | Description |
---|---|
StoreBackendError |
if a `base_public_path` attribute was not configured for the store. |
Source code in zenml/integrations/great_expectations/ge_store_backend.py
def get_public_url_for_key(
    self, key: str, protocol: Optional[str] = None
) -> str:
    """Build the public URL under which a stored object can be reached.

    The URL is `base_public_path` followed by the key's file path with the
    store protocol removed from it.

    Args:
        key: object key identifier.
        protocol: optional protocol to use instead of the store protocol.

    Returns:
        The public URL where the object can be accessed.

    Raises:
        StoreBackendError: if a `base_public_path` attribute was not
            configured for the store.
    """
    if self.base_public_path:
        relative_path = self._convert_key_to_filepath(key)
        stripped_path = relative_path.replace(self.proto, "")
        return cast(str, self.base_public_path + stripped_path)

    raise StoreBackendError(
        f"Error: No base_public_path was configured! A public URL was "
        f"requested but `base_public_path` was not configured for the "
        f"{self.__class__.__name__}"
    )
get_url_for_key(self, key, protocol=None)
Get the URL of an object in the store.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
key |
Tuple[str, ...] |
object key identifier. |
required |
protocol |
Optional[str] |
optional protocol to use instead of the store protocol. |
None |
Returns:
Type | Description |
---|---|
str |
The URL of the object in the store. |
Source code in zenml/integrations/great_expectations/ge_store_backend.py
def get_url_for_key(
    self, key: Tuple[str, ...], protocol: Optional[str] = None
) -> str:
    """Build the URL of a stored object.

    When no protocol is given and the object path is not remote, the
    `file:` protocol is used. If a protocol applies, it replaces the first
    occurrence of the store protocol in the object path.

    Args:
        key: object key identifier.
        protocol: optional protocol to use instead of the store protocol.

    Returns:
        The URL of the object in the store.
    """
    object_path = self._build_object_path(key)

    scheme = protocol
    if not scheme and not io_utils.is_remote(object_path):
        scheme = "file:"

    if not scheme:
        return object_path
    return object_path.replace(self.proto, f"{scheme}//", 1)
list_keys(self, prefix=())
List the keys of all objects identified by a partial key.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
prefix |
Tuple[str, ...] |
partial object key identifier. |
() |
Returns:
Type | Description |
---|---|
List[Tuple[str, ...]] |
List of keys identifying all objects present in the store that match the input partial key. |
Source code in zenml/integrations/great_expectations/ge_store_backend.py
def list_keys(self, prefix: Tuple[str, ...] = ()) -> List[Tuple[str, ...]]:
    """Collect the keys of all stored objects below a partial key.

    Files are discovered by walking the path built from `prefix`. Files
    whose store-relative path does not match the configured filepath
    prefix/suffix, and keys that are empty or ignored, are skipped.

    Args:
        prefix: partial object key identifier.

    Returns:
        List of keys identifying all objects present in the store that
        match the input partial key.
    """
    search_dir = self._build_object_path(prefix, is_prefix=True)
    store_root = self._build_object_path(tuple(), is_prefix=True)

    found_keys = []
    for current_dir, _subdirs, filenames in fileio.walk(search_dir):
        for filename in filenames:
            absolute = os.path.join(str(current_dir), str(filename))
            relative = os.path.relpath(absolute, store_root)

            wrong_prefix = self.filepath_prefix and not relative.startswith(
                self.filepath_prefix
            )
            if wrong_prefix:
                continue
            wrong_suffix = self.filepath_suffix and not relative.endswith(
                self.filepath_suffix
            )
            if wrong_suffix:
                continue

            candidate = self._convert_filepath_to_key(relative)
            if candidate and not self.is_ignored_key(candidate):
                found_keys.append(candidate)
    return found_keys
remove_key(self, key)
Delete an object from the store.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
key |
Tuple[str, ...] |
object key identifier. |
required |
Returns:
Type | Description |
---|---|
bool |
True if the object existed in the store and was removed, otherwise False. |
Source code in zenml/integrations/great_expectations/ge_store_backend.py
def remove_key(self, key: Tuple[str, ...]) -> bool:
    """Remove a stored object, if it exists.

    For non-remote paths, `rrmdir` is additionally invoked with the store
    root path and the removed file's parent directory.

    Args:
        key: object key identifier.

    Returns:
        True if the object existed in the store and was removed, otherwise
        False.
    """
    target: str = self._build_object_path(key)

    if not fileio.exists(target):
        return False

    fileio.remove(target)
    if not io_utils.is_remote(target):
        containing_dir = str(Path(target).parent)
        self.rrmdir(self.root_path, str(containing_dir))
    return True
rrmdir(start_path, end_path)
staticmethod
Recursively removes empty dirs between start_path and end_path inclusive.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
start_path |
str |
Directory to use as a starting point. |
required |
end_path |
str |
Directory to use as a destination point. |
required |
Source code in zenml/integrations/great_expectations/ge_store_backend.py
@staticmethod
def rrmdir(start_path: str, end_path: str) -> None:
    """Remove empty directories from end_path up towards start_path.

    Starting at `end_path`, each empty directory is removed and the walk
    continues with its parent. The walk stops at the first non-empty
    directory or when `start_path` is reached; `start_path` itself is never
    removed.

    Args:
        start_path: Directory at which the upward walk stops (kept).
        end_path: Deepest directory, where the upward walk begins.
    """
    # Normalize both paths so that e.g. a trailing separator on `start_path`
    # cannot make the loop miss it and keep deleting ancestors above it.
    start_path = os.path.normpath(start_path)
    end_path = os.path.normpath(end_path)
    while end_path != start_path and not os.listdir(end_path):
        os.rmdir(end_path)
        end_path = os.path.dirname(end_path)
materializers
special
Materializers for Great Expectation serializable objects.
ge_materializer
Implementation of the Great Expectations materializers.
GreatExpectationsMaterializer (BaseMaterializer)
Materializer to read/write Great Expectation objects.
Source code in zenml/integrations/great_expectations/materializers/ge_materializer.py
class GreatExpectationsMaterializer(BaseMaterializer):
    """Materializer to read/write Great Expectation objects.

    Objects are persisted as a JSON dictionary that also records the
    object's fully qualified class path under the `data_type` key, so the
    matching class can be imported again on load.
    """

    ASSOCIATED_TYPES = (ExpectationSuite, CheckpointResult)
    ASSOCIATED_ARTIFACT_TYPES = (DataAnalysisArtifact,)

    @staticmethod
    def preprocess_checkpoint_result_dict(
        artifact_dict: Dict[str, Any]
    ) -> None:
        """Convert parts of a checkpoint result dict to GE objects in place.

        The GE CheckpointResult object is not fully de-serializable
        due to some missing code in the GE codebase. We need to compensate
        for this by manually converting some of the attributes to
        their correct data types.

        Args:
            artifact_dict: A dict containing the GE checkpoint result.
        """
        artifact_dict["checkpoint_config"] = CheckpointConfig(
            **artifact_dict["checkpoint_config"]
        )

        converted_results = {}
        for raw_ident, raw_results in artifact_dict["run_results"].items():
            # The identifier tuple is the "/"-separated part after "::".
            ident_parts = raw_ident.split("::")[1].split("/")
            ident = ValidationResultIdentifier.from_fixed_length_tuple(
                ident_parts
            )
            converted_results[ident] = {
                name: (
                    ExpectationSuiteValidationResult(**value)
                    if name == "validation_result"
                    else value
                )
                for name, value in raw_results.items()
            }
        artifact_dict["run_results"] = converted_results

    def handle_input(self, data_type: Type[Any]) -> SerializableDictDot:
        """Load a Great Expectations object from the artifact location.

        The class to instantiate is taken from the `data_type` entry stored
        in the artifact file, not from the `data_type` argument.

        Args:
            data_type: The type of the data to read.

        Returns:
            A loaded Great Expectations object.
        """
        super().handle_input(data_type)
        source_file = os.path.join(self.artifact.uri, ARTIFACT_FILENAME)
        payload = yaml_utils.read_json(source_file)

        stored_type = import_class_by_path(payload.pop("data_type"))
        if stored_type is CheckpointResult:
            self.preprocess_checkpoint_result_dict(payload)
        return stored_type(**payload)

    def handle_return(self, obj: SerializableDictDot) -> None:
        """Store a Great Expectations object at the artifact location.

        Args:
            obj: A Great Expectations object.
        """
        super().handle_return(obj)
        target_file = os.path.join(self.artifact.uri, ARTIFACT_FILENAME)

        payload = obj.to_json_dict()
        obj_type = type(obj)
        # Record the class path so that loading can re-import the class.
        payload["data_type"] = f"{obj_type.__module__}.{obj_type.__name__}"
        yaml_utils.write_json(target_file, payload)
handle_input(self, data_type)
Reads and returns a Great Expectations object.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
data_type |
Type[Any] |
The type of the data to read. |
required |
Returns:
Type | Description |
---|---|
SerializableDictDot |
A loaded Great Expectations object. |
Source code in zenml/integrations/great_expectations/materializers/ge_materializer.py
def handle_input(self, data_type: Type[Any]) -> SerializableDictDot:
    """Load a Great Expectations object from the artifact location.

    The class to instantiate is taken from the `data_type` entry stored in
    the artifact file, not from the `data_type` argument.

    Args:
        data_type: The type of the data to read.

    Returns:
        A loaded Great Expectations object.
    """
    super().handle_input(data_type)
    source_file = os.path.join(self.artifact.uri, ARTIFACT_FILENAME)
    payload = yaml_utils.read_json(source_file)

    stored_type = import_class_by_path(payload.pop("data_type"))
    if stored_type is CheckpointResult:
        self.preprocess_checkpoint_result_dict(payload)
    return stored_type(**payload)
handle_return(self, obj)
Writes a Great Expectations object.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
obj |
SerializableDictDot |
A Great Expectations object. |
required |
Source code in zenml/integrations/great_expectations/materializers/ge_materializer.py
def handle_return(self, obj: SerializableDictDot) -> None:
    """Store a Great Expectations object at the artifact location.

    The object's JSON dictionary is written together with a `data_type`
    entry holding the object's fully qualified class path.

    Args:
        obj: A Great Expectations object.
    """
    super().handle_return(obj)
    target_file = os.path.join(self.artifact.uri, ARTIFACT_FILENAME)

    payload = obj.to_json_dict()
    obj_type = type(obj)
    payload["data_type"] = f"{obj_type.__module__}.{obj_type.__name__}"
    yaml_utils.write_json(target_file, payload)
preprocess_checkpoint_result_dict(artifact_dict)
staticmethod
Pre-processes a GE checkpoint dict before it is used to de-serialize a GE CheckpointResult object.
The GE CheckpointResult object is not fully de-serializable due to some missing code in the GE codebase. We need to compensate for this by manually converting some of the attributes to their correct data types.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
artifact_dict |
Dict[str, Any] |
A dict containing the GE checkpoint result. |
required |
Source code in zenml/integrations/great_expectations/materializers/ge_materializer.py
@staticmethod
def preprocess_checkpoint_result_dict(
    artifact_dict: Dict[str, Any]
) -> None:
    """Convert parts of a checkpoint result dict to GE objects in place.

    The GE CheckpointResult object is not fully de-serializable
    due to some missing code in the GE codebase. We need to compensate
    for this by manually converting some of the attributes to
    their correct data types.

    Args:
        artifact_dict: A dict containing the GE checkpoint result.
    """
    artifact_dict["checkpoint_config"] = CheckpointConfig(
        **artifact_dict["checkpoint_config"]
    )

    converted_results = {}
    for raw_ident, raw_results in artifact_dict["run_results"].items():
        # The identifier tuple is the "/"-separated part after "::".
        ident_parts = raw_ident.split("::")[1].split("/")
        ident = ValidationResultIdentifier.from_fixed_length_tuple(ident_parts)
        converted_results[ident] = {
            name: (
                ExpectationSuiteValidationResult(**value)
                if name == "validation_result"
                else value
            )
            for name, value in raw_results.items()
        }
    artifact_dict["run_results"] = converted_results
steps
special
Great Expectations data profiling and validation standard steps.
ge_profiler
Great Expectations data profiling standard step.
GreatExpectationsProfilerConfig (BaseStepConfig)
pydantic-model
Config class for a Great Expectations profiler step.
Attributes:
Name | Type | Description |
---|---|---|
expectation_suite_name |
str |
The name of the expectation suite to create or update. |
data_asset_name |
Optional[str] |
The name of the data asset to run the expectation suite on. |
profiler_kwargs |
Optional[Dict[str, Any]] |
A dictionary of keyword arguments to pass to the profiler. |
overwrite_existing_suite |
bool |
Whether to overwrite an existing expectation suite. |
Source code in zenml/integrations/great_expectations/steps/ge_profiler.py
class GreatExpectationsProfilerConfig(BaseStepConfig):
    """Config class for a Great Expectations profiler step.

    Attributes:
        expectation_suite_name: The name of the expectation suite to create
            or update.
        data_asset_name: The name of the data asset to run the expectation
            suite on.
        profiler_kwargs: A dictionary of keyword arguments to pass to the
            profiler.
        overwrite_existing_suite: Whether to overwrite an existing
            expectation suite.
    """

    # Required: there is no default suite name.
    expectation_suite_name: str
    data_asset_name: Optional[str] = None
    # Defaults to an empty dict produced by the default factory.
    profiler_kwargs: Optional[Dict[str, Any]] = Field(default_factory=dict)
    overwrite_existing_suite: bool = True
GreatExpectationsProfilerStep (BaseStep)
Standard Great Expectations profiling step implementation.
Use this standard Great Expectations profiling step to build an Expectation Suite automatically by running a UserConfigurableProfiler on an input dataset as covered in the official GE documentation.
Source code in zenml/integrations/great_expectations/steps/ge_profiler.py
class GreatExpectationsProfilerStep(BaseStep):
    """Standard Great Expectations profiling step implementation.

    Builds an Expectation Suite automatically by running a
    UserConfigurableProfiler on an input dataset
    [as covered in the official GE documentation](https://docs.greatexpectations.io/docs/guides/expectations/how_to_create_and_edit_expectations_with_a_profiler).
    """

    def entrypoint(  # type: ignore[override]
        self,
        dataset: pd.DataFrame,
        config: GreatExpectationsProfilerConfig,
    ) -> ExpectationSuite:
        """Profile the dataset through the active GE data validator.

        Args:
            dataset: The dataset from which the expectation suite will be
                inferred.
            config: The configuration for the step.

        Returns:
            The generated Great Expectations suite.
        """
        validator = GreatExpectationsDataValidator.get_active_data_validator()
        # Forward every config field to the validator unchanged.
        profiling_options = {
            "expectation_suite_name": config.expectation_suite_name,
            "data_asset_name": config.data_asset_name,
            "profiler_kwargs": config.profiler_kwargs,
            "overwrite_existing_suite": config.overwrite_existing_suite,
        }
        return validator.data_profiling(dataset, **profiling_options)
CONFIG_CLASS (BaseStepConfig)
pydantic-model
Config class for a Great Expectations profiler step.
Attributes:
Name | Type | Description |
---|---|---|
expectation_suite_name |
str |
The name of the expectation suite to create or update. |
data_asset_name |
Optional[str] |
The name of the data asset to run the expectation suite on. |
profiler_kwargs |
Optional[Dict[str, Any]] |
A dictionary of keyword arguments to pass to the profiler. |
overwrite_existing_suite |
bool |
Whether to overwrite an existing expectation suite. |
Source code in zenml/integrations/great_expectations/steps/ge_profiler.py
class GreatExpectationsProfilerConfig(BaseStepConfig):
    """Config class for a Great Expectations profiler step.

    Attributes:
        expectation_suite_name: The name of the expectation suite to create
            or update.
        data_asset_name: The name of the data asset to run the expectation
            suite on.
        profiler_kwargs: A dictionary of keyword arguments to pass to the
            profiler.
        overwrite_existing_suite: Whether to overwrite an existing
            expectation suite.
    """

    # Required: there is no default suite name.
    expectation_suite_name: str
    data_asset_name: Optional[str] = None
    # Defaults to an empty dict produced by the default factory.
    profiler_kwargs: Optional[Dict[str, Any]] = Field(default_factory=dict)
    overwrite_existing_suite: bool = True
entrypoint(self, dataset, config)
Standard Great Expectations data profiling step entrypoint.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
dataset |
DataFrame |
The dataset from which the expectation suite will be inferred. |
required |
config |
GreatExpectationsProfilerConfig |
The configuration for the step. |
required |
Returns:
Type | Description |
---|---|
ExpectationSuite |
The generated Great Expectations suite. |
Source code in zenml/integrations/great_expectations/steps/ge_profiler.py
def entrypoint(  # type: ignore[override]
    self,
    dataset: pd.DataFrame,
    config: GreatExpectationsProfilerConfig,
) -> ExpectationSuite:
    """Profile the dataset through the active GE data validator.

    Args:
        dataset: The dataset from which the expectation suite will be
            inferred.
        config: The configuration for the step.

    Returns:
        The generated Great Expectations suite.
    """
    validator = GreatExpectationsDataValidator.get_active_data_validator()
    # Forward every config field to the validator unchanged.
    profiling_options = {
        "expectation_suite_name": config.expectation_suite_name,
        "data_asset_name": config.data_asset_name,
        "profiler_kwargs": config.profiler_kwargs,
        "overwrite_existing_suite": config.overwrite_existing_suite,
    }
    return validator.data_profiling(dataset, **profiling_options)
great_expectations_profiler_step(step_name, config)
Shortcut function to create a new instance of the GreatExpectationsProfilerStep step.
The returned GreatExpectationsProfilerStep can be used in a pipeline to infer data validation rules from an input pd.DataFrame dataset and return them as an Expectation Suite. The Expectation Suite is also persisted in the Great Expectations expectation store.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
step_name |
str |
The name of the step |
required |
config |
GreatExpectationsProfilerConfig |
The configuration for the step |
required |
Returns:
Type | Description |
---|---|
BaseStep |
a GreatExpectationsProfilerStep step instance |
Source code in zenml/integrations/great_expectations/steps/ge_profiler.py
def great_expectations_profiler_step(
    step_name: str,
    config: GreatExpectationsProfilerConfig,
) -> BaseStep:
    """Create a named instance of the GreatExpectationsProfilerStep step.

    The returned GreatExpectationsProfilerStep can be used in a pipeline to
    infer data validation rules from an input pd.DataFrame dataset and return
    them as an Expectation Suite. The Expectation Suite is also persisted in
    the Great Expectations expectation store.

    Args:
        step_name: The name of the step
        config: The configuration for the step

    Returns:
        a GreatExpectationsProfilerStep step instance
    """
    # Clone the step class under the requested name, then instantiate it.
    named_step_cls = clone_step(GreatExpectationsProfilerStep, step_name)
    return named_step_cls(config=config)
ge_validator
Great Expectations data validation standard step.
GreatExpectationsValidatorConfig (BaseStepConfig)
pydantic-model
Config class for a Great Expectations checkpoint step.
Attributes:
Name | Type | Description |
---|---|---|
expectation_suite_name |
str |
The name of the expectation suite to use to validate the dataset. |
data_asset_name |
Optional[str] |
The name of the data asset to use to identify the dataset in the Great Expectations docs. |
action_list |
Optional[List[Dict[str, Any]]] |
A list of additional Great Expectations actions to run after the validation check. |
exit_on_error |
bool |
Set this flag to raise an error and exit the pipeline early if the validation fails. |
Source code in zenml/integrations/great_expectations/steps/ge_validator.py
class GreatExpectationsValidatorConfig(BaseStepConfig):
    """Config class for a Great Expectations checkpoint step.

    Attributes:
        expectation_suite_name: The name of the expectation suite to use to
            validate the dataset.
        data_asset_name: The name of the data asset to use to identify the
            dataset in the Great Expectations docs.
        action_list: A list of additional Great Expectations actions to run
            after the validation check.
        exit_on_error: Set this flag to raise an error and exit the pipeline
            early if the validation fails.
    """

    # Required: there is no default suite name.
    expectation_suite_name: str
    data_asset_name: Optional[str] = None
    action_list: Optional[List[Dict[str, Any]]] = None
    # Off by default: a failed validation does not raise.
    exit_on_error: bool = False
GreatExpectationsValidatorStep (BaseStep)
Standard Great Expectations data validation step implementation.
Use this standard Great Expectations data validation step to run an existing Expectation Suite on an input dataset as covered in the official GE documentation.
Source code in zenml/integrations/great_expectations/steps/ge_validator.py
class GreatExpectationsValidatorStep(BaseStep):
    """Standard Great Expectations data validation step implementation.

    Use this standard Great Expectations data validation step to run an
    existing Expectation Suite on an input dataset
    [as covered in the official GE documentation](https://docs.greatexpectations.io/docs/guides/validation/how_to_validate_data_by_running_a_checkpoint).
    """

    def entrypoint(  # type: ignore[override]
        self,
        dataset: pd.DataFrame,
        condition: bool,
        config: GreatExpectationsValidatorConfig,
    ) -> CheckpointResult:
        """Standard Great Expectations data validation step entrypoint.

        Args:
            dataset: The dataset to run the expectation suite on.
            condition: This dummy argument can be used as a condition to
                enforce that this step is only run after another step has
                completed. This is useful for example if the Expectation
                Suite used to validate the data is computed in a
                `GreatExpectationsProfilerStep` that is part of the same
                pipeline.
            config: The configuration for the step.

        Returns:
            The Great Expectations validation (checkpoint) result.

        Raises:
            RuntimeError: if the step is configured to exit on error and the
                data validation failed.
        """
        data_validator = (
            GreatExpectationsDataValidator.get_active_data_validator()
        )
        results = data_validator.data_validation(
            dataset,
            expectation_suite_name=config.expectation_suite_name,
            data_asset_name=config.data_asset_name,
            action_list=config.action_list,
        )

        # `CheckpointResult.success` is a boolean property, not a method;
        # calling it would raise TypeError instead of the intended error.
        if config.exit_on_error and not results.success:
            raise RuntimeError(
                "The Great Expectations validation failed. Check "
                "the logs or the Great Expectations data docs for more "
                "information."
            )

        return results
CONFIG_CLASS (BaseStepConfig)
pydantic-model
Config class for a Great Expectations checkpoint step.
Attributes:
Name | Type | Description |
---|---|---|
expectation_suite_name |
str |
The name of the expectation suite to use to validate the dataset. |
data_asset_name |
Optional[str] |
The name of the data asset to use to identify the dataset in the Great Expectations docs. |
action_list |
Optional[List[Dict[str, Any]]] |
A list of additional Great Expectations actions to run after the validation check. |
exit_on_error |
bool |
Set this flag to raise an error and exit the pipeline early if the validation fails. |
Source code in zenml/integrations/great_expectations/steps/ge_validator.py
class GreatExpectationsValidatorConfig(BaseStepConfig):
    """Config class for a Great Expectations checkpoint step.

    Attributes:
        expectation_suite_name: The name of the expectation suite to use to
            validate the dataset.
        data_asset_name: The name of the data asset to use to identify the
            dataset in the Great Expectations docs.
        action_list: A list of additional Great Expectations actions to run
            after the validation check.
        exit_on_error: Set this flag to raise an error and exit the pipeline
            early if the validation fails.
    """

    # Required: there is no default suite name.
    expectation_suite_name: str
    data_asset_name: Optional[str] = None
    action_list: Optional[List[Dict[str, Any]]] = None
    # Off by default: a failed validation does not raise.
    exit_on_error: bool = False
entrypoint(self, dataset, condition, config)
Standard Great Expectations data validation step entrypoint.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
dataset |
DataFrame |
The dataset to run the expectation suite on. |
required |
condition |
bool |
This dummy argument can be used as a condition to enforce that this step is only run after another step has completed. This is useful for example if the Expectation Suite used to validate the data is computed in a GreatExpectationsProfilerStep that is part of the same pipeline. |
required |
config |
GreatExpectationsValidatorConfig |
The configuration for the step. |
required |
Returns:
Type | Description |
---|---|
CheckpointResult |
The Great Expectations validation (checkpoint) result. |
Exceptions:
Type | Description |
---|---|
RuntimeError |
if the step is configured to exit on error and the data validation failed. |
Source code in zenml/integrations/great_expectations/steps/ge_validator.py
def entrypoint(  # type: ignore[override]
    self,
    dataset: pd.DataFrame,
    condition: bool,
    config: GreatExpectationsValidatorConfig,
) -> CheckpointResult:
    """Standard Great Expectations data validation step entrypoint.

    Args:
        dataset: The dataset to run the expectation suite on.
        condition: This dummy argument can be used as a condition to enforce
            that this step is only run after another step has completed.
            This is useful for example if the Expectation Suite used to
            validate the data is computed in a
            `GreatExpectationsProfilerStep` that is part of the same
            pipeline.
        config: The configuration for the step.

    Returns:
        The Great Expectations validation (checkpoint) result.

    Raises:
        RuntimeError: if the step is configured to exit on error and the
            data validation failed.
    """
    data_validator = (
        GreatExpectationsDataValidator.get_active_data_validator()
    )
    results = data_validator.data_validation(
        dataset,
        expectation_suite_name=config.expectation_suite_name,
        data_asset_name=config.data_asset_name,
        action_list=config.action_list,
    )

    # `CheckpointResult.success` is a boolean property, not a method;
    # calling it would raise TypeError instead of the intended error.
    if config.exit_on_error and not results.success:
        raise RuntimeError(
            "The Great Expectations validation failed. Check "
            "the logs or the Great Expectations data docs for more "
            "information."
        )

    return results
great_expectations_validator_step(step_name, config)
Shortcut function to create a new instance of the GreatExpectationsValidatorStep step.
The returned GreatExpectationsValidatorStep can be used in a pipeline to validate an input pd.DataFrame dataset and return the result as a Great Expectations CheckpointResult object. The validation results are also persisted in the Great Expectations validation store.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
step_name |
str |
The name of the step |
required |
config |
GreatExpectationsValidatorConfig |
The configuration for the step |
required |
Returns:
Type | Description |
---|---|
BaseStep |
a GreatExpectationsValidatorStep step instance |
Source code in zenml/integrations/great_expectations/steps/ge_validator.py
def great_expectations_validator_step(
    step_name: str,
    config: GreatExpectationsValidatorConfig,
) -> BaseStep:
    """Shortcut function to create a new instance of the GreatExpectationsValidatorStep step.

    The returned GreatExpectationsValidatorStep can be used in a pipeline to
    validate an input pd.DataFrame dataset and return the result as a Great
    Expectations CheckpointResult object. The validation results are also
    persisted in the Great Expectations validation store.

    Args:
        step_name: The name of the step
        config: The configuration for the step

    Returns:
        a GreatExpectationsValidatorStep step instance
    """
    # Clone the step class under the requested name, then instantiate it.
    return clone_step(GreatExpectationsValidatorStep, step_name)(config=config)
utils
Great Expectations utility functions.
create_batch_request(context, dataset, data_asset_name)
Create a temporary runtime GE batch request from a dataset step artifact.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
context |
BaseDataContext |
Great Expectations data context. |
required |
dataset |
DataFrame |
Input dataset. |
required |
data_asset_name |
Optional[str] |
Optional custom name for the data asset. |
required |
Returns:
Type | Description |
---|---|
RuntimeBatchRequest |
A Great Expectations runtime batch request. |
Source code in zenml/integrations/great_expectations/utils.py
def create_batch_request(
    context: BaseDataContext,
    dataset: pd.DataFrame,
    data_asset_name: Optional[str],
) -> RuntimeBatchRequest:
    """Create a temporary runtime GE batch request from a dataset step artifact.

    A pandas-backed datasource with a single runtime data connector is
    registered on the data context, and a batch request wrapping the
    in-memory dataset is built against it.

    Args:
        context: Great Expectations data context.
        dataset: Input dataset.
        data_asset_name: Optional custom name for the data asset. When falsy,
            a name is derived from the pipeline and step names.

    Returns:
        A Great Expectations runtime batch request.
    """
    try:
        # Names come from the step environment when running inside a step.
        step_env = cast(StepEnvironment, Environment()[STEP_ENVIRONMENT_NAME])
        pipeline_name, run_id, step_name = (
            step_env.pipeline_name,
            step_env.pipeline_run_id,
            step_env.step_name,
        )
    except KeyError:
        # Outside of a pipeline step there is no step environment, so fall
        # back to random values.
        # NOTE(review): the run id fallback reuses the "pipeline_" prefix,
        # which looks copy-pasted; kept unchanged to preserve behavior.
        pipeline_name = f"pipeline_{random_str(5)}"
        run_id = f"pipeline_{random_str(5)}"
        step_name = f"step_{random_str(5)}"

    # The datasource and its only data connector share one name.
    source_name = f"{run_id}_{step_name}"
    connector_name = source_name
    asset_name = data_asset_name or f"{pipeline_name}_{step_name}"
    batch_key = "default"

    execution_engine = {
        "module_name": "great_expectations.execution_engine",
        "class_name": "PandasExecutionEngine",
    }
    runtime_connector = {
        "class_name": "RuntimeDataConnector",
        "batch_identifiers": [batch_key],
    }
    context.add_datasource(
        **{
            "name": source_name,
            "class_name": "Datasource",
            "module_name": "great_expectations.datasource",
            "execution_engine": execution_engine,
            "data_connectors": {connector_name: runtime_connector},
        }
    )

    return RuntimeBatchRequest(
        datasource_name=source_name,
        data_connector_name=connector_name,
        data_asset_name=asset_name,
        runtime_parameters={"batch_data": dataset},
        batch_identifiers={batch_key: batch_key},
    )
visualizers
special
Great Expectations visualizers for expectation suites and validation results.
ge_visualizer
Great Expectations visualizers for expectation suites and validation results.
GreatExpectationsVisualizer (BaseStepVisualizer)
The implementation of a Great Expectations Visualizer.
Source code in zenml/integrations/great_expectations/visualizers/ge_visualizer.py
class GreatExpectationsVisualizer(BaseStepVisualizer):
    """The implementation of a Great Expectations Visualizer."""

    def visualize(self, object: StepView, *args: Any, **kwargs: Any) -> None:
        """Open the GE data docs for each Great Expectations step output.

        Args:
            object: StepView fetched from run.get_step().
            *args: Additional arguments.
            **kwargs: Additional keyword arguments.
        """
        for output_view in object.outputs.values():
            # Skip anything but Great Expectations data analysis artifacts.
            is_ge_analysis = (
                output_view.type == DataAnalysisArtifact.__name__
                and output_view.data_type.startswith("great_expectations.")
            )
            if not is_ge_analysis:
                continue

            artifact = output_view.read()
            if isinstance(artifact, CheckpointResult):
                # A checkpoint result is located through the identifier of
                # its first run result.
                checkpoint = cast(CheckpointResult, artifact)
                identifier = next(iter(checkpoint.run_results.keys()))
            else:
                # Anything else is treated as an expectation suite.
                expectation_suite = cast(ExpectationSuite, artifact)
                identifier = ExpectationSuiteIdentifier(
                    expectation_suite.expectation_suite_name
                )

            data_context = GreatExpectationsDataValidator.get_data_context()
            data_context.open_data_docs(identifier)
visualize(self, object, *args, **kwargs)
Method to visualize a Great Expectations resource.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
object |
StepView |
StepView fetched from run.get_step(). |
required |
*args |
Any |
Additional arguments. |
() |
**kwargs |
Any |
Additional keyword arguments. |
{} |
Source code in zenml/integrations/great_expectations/visualizers/ge_visualizer.py
def visualize(self, object: StepView, *args: Any, **kwargs: Any) -> None:
    """Open the GE data docs for each Great Expectations step output.

    Args:
        object: StepView fetched from run.get_step().
        *args: Additional arguments.
        **kwargs: Additional keyword arguments.
    """
    for output_view in object.outputs.values():
        # Skip anything but Great Expectations data analysis artifacts.
        is_ge_analysis = (
            output_view.type == DataAnalysisArtifact.__name__
            and output_view.data_type.startswith("great_expectations.")
        )
        if not is_ge_analysis:
            continue

        artifact = output_view.read()
        if isinstance(artifact, CheckpointResult):
            # A checkpoint result is located through the identifier of its
            # first run result.
            checkpoint = cast(CheckpointResult, artifact)
            identifier = next(iter(checkpoint.run_results.keys()))
        else:
            # Anything else is treated as an expectation suite.
            expectation_suite = cast(ExpectationSuite, artifact)
            identifier = ExpectationSuiteIdentifier(
                expectation_suite.expectation_suite_name
            )

        data_context = GreatExpectationsDataValidator.get_data_context()
        data_context.open_data_docs(identifier)
huggingface
special
Initialization of the Huggingface integration.
HuggingfaceIntegration (Integration)
Definition of Huggingface integration for ZenML.
Source code in zenml/integrations/huggingface/__init__.py
class HuggingfaceIntegration(Integration):
    """Definition of Huggingface integration for ZenML."""

    NAME = HUGGINGFACE
    # Python packages this integration depends on.
    REQUIREMENTS = ["transformers", "datasets"]

    @classmethod
    def activate(cls) -> None:
        """Activates the integration.

        Imports the Huggingface materializers module. The imported name is
        not used, so the import is done only for its side effects
        (presumably registering the materializers; confirm against the
        materializers package).
        """
        from zenml.integrations.huggingface import materializers  # noqa
activate()
classmethod
Activates the integration.
Source code in zenml/integrations/huggingface/__init__.py
@classmethod
def activate(cls) -> None:
    """Activates the integration.

    Imports the Huggingface materializers module. The imported name is not
    used, so the import is done only for its side effects (presumably
    registering the materializers; confirm against the materializers
    package).
    """
    from zenml.integrations.huggingface import materializers  # noqa
materializers
special
Initialization of Huggingface materializers.
huggingface_datasets_materializer
Implementation of the Huggingface datasets materializer.
HFDatasetMaterializer (BaseMaterializer)
Materializer to read and write Huggingface datasets.
Source code in zenml/integrations/huggingface/materializers/huggingface_datasets_materializer.py
class HFDatasetMaterializer(BaseMaterializer):
    """Materializer to read and write Huggingface datasets."""

    ASSOCIATED_TYPES = (Dataset, DatasetDict)
    ASSOCIATED_ARTIFACT_TYPES = (DataArtifact,)

    def handle_input(self, data_type: Type[Any]) -> Dataset:
        """Reads Dataset.

        Args:
            data_type: The type of the dataset to read.

        Returns:
            The dataset read from the specified dir.
        """
        super().handle_input(data_type)

        return load_from_disk(
            os.path.join(self.artifact.uri, DEFAULT_DATASET_DIR)
        )

    def handle_return(self, ds: Type[Any]) -> None:
        """Writes a Dataset to the specified dir.

        The dataset is first saved to a local scratch directory and then
        copied to the artifact location.

        Args:
            ds: The Dataset to write.
        """
        super().handle_return(ds)
        # The context manager removes the scratch directory as soon as the
        # copy has finished (or failed) instead of leaving it on disk until
        # the TemporaryDirectory object happens to be garbage collected.
        with TemporaryDirectory() as temp_dir:
            ds.save_to_disk(temp_dir)
            io_utils.copy_dir(
                temp_dir, os.path.join(self.artifact.uri, DEFAULT_DATASET_DIR)
            )
handle_input(self, data_type)
Reads Dataset.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
data_type |
Type[Any] |
The type of the dataset to read. |
required |
Returns:
Type | Description |
---|---|
Dataset |
The dataset read from the specified dir. |
Source code in zenml/integrations/huggingface/materializers/huggingface_datasets_materializer.py
def handle_input(self, data_type: Type[Any]) -> Dataset:
    """Load a dataset from the artifact's dataset directory.

    Args:
        data_type: The type of the dataset to read.

    Returns:
        The dataset read from the specified dir.
    """
    super().handle_input(data_type)
    dataset_dir = os.path.join(self.artifact.uri, DEFAULT_DATASET_DIR)
    return load_from_disk(dataset_dir)
handle_return(self, ds)
Writes a Dataset to the specified dir.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
ds |
Type[Any] |
The Dataset to write. |
required |
Source code in zenml/integrations/huggingface/materializers/huggingface_datasets_materializer.py
def handle_return(self, ds: Type[Any]) -> None:
    """Writes a Dataset to the specified dir.

    The dataset is first saved to a local scratch directory and then copied
    to the artifact location.

    Args:
        ds: The Dataset to write.
    """
    super().handle_return(ds)
    # The context manager removes the scratch directory as soon as the copy
    # has finished (or failed) instead of leaving it on disk until the
    # TemporaryDirectory object happens to be garbage collected.
    with TemporaryDirectory() as temp_dir:
        ds.save_to_disk(temp_dir)
        io_utils.copy_dir(
            temp_dir, os.path.join(self.artifact.uri, DEFAULT_DATASET_DIR)
        )
huggingface_pt_model_materializer
Implementation of the Huggingface PyTorch model materializer.
HFPTModelMaterializer (BaseMaterializer)
Materializer to read and write Huggingface pretrained PyTorch models.
Source code in zenml/integrations/huggingface/materializers/huggingface_pt_model_materializer.py
class HFPTModelMaterializer(BaseMaterializer):
    """Materializer to read and write Huggingface pretrained PyTorch models."""

    ASSOCIATED_TYPES = (PreTrainedModel,)
    ASSOCIATED_ARTIFACT_TYPES = (ModelArtifact,)

    def handle_input(self, data_type: Type[Any]) -> PreTrainedModel:
        """Reads HFModel.

        The model class is looked up in the `transformers` package by the
        first architecture name listed in the saved model config.

        Args:
            data_type: The type of the model to read.

        Returns:
            The model read from the specified dir.
        """
        super().handle_input(data_type)

        config = AutoConfig.from_pretrained(
            os.path.join(self.artifact.uri, DEFAULT_PT_MODEL_DIR)
        )
        architecture = config.architectures[0]
        model_cls = getattr(
            importlib.import_module("transformers"), architecture
        )
        return model_cls.from_pretrained(
            os.path.join(self.artifact.uri, DEFAULT_PT_MODEL_DIR)
        )

    def handle_return(self, model: Type[Any]) -> None:
        """Writes a Model to the specified dir.

        The model is first saved to a local scratch directory and then
        copied to the artifact location.

        Args:
            model: The Torch Model to write.
        """
        super().handle_return(model)
        # The context manager removes the scratch directory as soon as the
        # copy has finished (or failed) instead of leaving it on disk until
        # the TemporaryDirectory object happens to be garbage collected.
        with TemporaryDirectory() as temp_dir:
            model.save_pretrained(temp_dir)
            io_utils.copy_dir(
                temp_dir,
                os.path.join(self.artifact.uri, DEFAULT_PT_MODEL_DIR),
            )
handle_input(self, data_type)
Reads HFModel.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
data_type |
Type[Any] |
The type of the model to read. |
required |
Returns:
Type | Description |
---|---|
PreTrainedModel |
The model read from the specified dir. |
Source code in zenml/integrations/huggingface/materializers/huggingface_pt_model_materializer.py
def handle_input(self, data_type: Type[Any]) -> PreTrainedModel:
    """Load a pretrained model from the artifact's model directory.

    The model class is looked up in the `transformers` package by the first
    architecture name listed in the saved model config.

    Args:
        data_type: The type of the model to read.

    Returns:
        The model read from the specified dir.
    """
    super().handle_input(data_type)
    model_dir = os.path.join(self.artifact.uri, DEFAULT_PT_MODEL_DIR)
    saved_config = AutoConfig.from_pretrained(model_dir)
    class_name = saved_config.architectures[0]
    transformers_module = importlib.import_module("transformers")
    model_class = getattr(transformers_module, class_name)
    return model_class.from_pretrained(model_dir)
handle_return(self, model)
Writes a Model to the specified dir.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model |
Type[Any] |
The Torch Model to write. |
required |
Source code in zenml/integrations/huggingface/materializers/huggingface_pt_model_materializer.py
def handle_return(self, model: Type[Any]) -> None:
    """Writes a Model to the specified dir.

    The model is first saved to a local scratch directory and then copied to
    the artifact location.

    Args:
        model: The Torch Model to write.
    """
    super().handle_return(model)
    # The context manager removes the scratch directory as soon as the copy
    # has finished (or failed) instead of leaving it on disk until the
    # TemporaryDirectory object happens to be garbage collected.
    with TemporaryDirectory() as temp_dir:
        model.save_pretrained(temp_dir)
        io_utils.copy_dir(
            temp_dir,
            os.path.join(self.artifact.uri, DEFAULT_PT_MODEL_DIR),
        )
huggingface_tf_model_materializer
Implementation of the Huggingface TF model materializer.
HFTFModelMaterializer (BaseMaterializer)
Materializer to read and write Huggingface pretrained Tensorflow models.
Source code in zenml/integrations/huggingface/materializers/huggingface_tf_model_materializer.py
class HFTFModelMaterializer(BaseMaterializer):
    """Materializer to read and write Huggingface pretrained Tensorflow models."""

    ASSOCIATED_TYPES = (TFPreTrainedModel,)
    ASSOCIATED_ARTIFACT_TYPES = (ModelArtifact,)

    def handle_input(self, data_type: Type[Any]) -> TFPreTrainedModel:
        """Reads HFModel.

        The model class is looked up in the `transformers` package under the
        first architecture name of the saved config, prefixed with "TF".

        Args:
            data_type: The type of the model to read.

        Returns:
            The model read from the specified dir.
        """
        super().handle_input(data_type)

        config = AutoConfig.from_pretrained(
            os.path.join(self.artifact.uri, DEFAULT_TF_MODEL_DIR)
        )
        # NOTE(review): assumes the saved config lists the architecture
        # without the "TF" prefix - confirm against transformers.
        architecture = "TF" + config.architectures[0]
        model_cls = getattr(
            importlib.import_module("transformers"), architecture
        )
        return model_cls.from_pretrained(
            os.path.join(self.artifact.uri, DEFAULT_TF_MODEL_DIR)
        )

    def handle_return(self, model: Type[Any]) -> None:
        """Writes a Model to the specified dir.

        The model is first saved to a local scratch directory and then
        copied to the artifact location.

        Args:
            model: The TF Model to write.
        """
        super().handle_return(model)
        # The context manager removes the scratch directory as soon as the
        # copy has finished (or failed) instead of leaving it on disk until
        # the TemporaryDirectory object happens to be garbage collected.
        with TemporaryDirectory() as temp_dir:
            model.save_pretrained(temp_dir)
            io_utils.copy_dir(
                temp_dir,
                os.path.join(self.artifact.uri, DEFAULT_TF_MODEL_DIR),
            )
handle_input(self, data_type)
Reads HFModel.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
data_type |
Type[Any] |
The type of the model to read. |
required |
Returns:
Type | Description |
---|---|
TFPreTrainedModel |
The model read from the specified dir. |
Source code in zenml/integrations/huggingface/materializers/huggingface_tf_model_materializer.py
def handle_input(self, data_type: Type[Any]) -> TFPreTrainedModel:
    """Load a pretrained TF model from the artifact's model directory.

    The model class is looked up in the `transformers` package under the
    first architecture name of the saved config, prefixed with "TF".

    Args:
        data_type: The type of the model to read.

    Returns:
        The model read from the specified dir.
    """
    super().handle_input(data_type)
    model_dir = os.path.join(self.artifact.uri, DEFAULT_TF_MODEL_DIR)
    saved_config = AutoConfig.from_pretrained(model_dir)
    # NOTE(review): assumes the saved config lists the architecture without
    # the "TF" prefix - confirm against transformers.
    class_name = "TF" + saved_config.architectures[0]
    transformers_module = importlib.import_module("transformers")
    model_class = getattr(transformers_module, class_name)
    return model_class.from_pretrained(model_dir)
handle_return(self, model)
Writes a Model to the specified dir.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model |
Type[Any] |
The TF Model to write. |
required |
Source code in zenml/integrations/huggingface/materializers/huggingface_tf_model_materializer.py
def handle_return(self, model: Type[Any]) -> None:
    """Writes a Model to the specified dir.

    The model is first saved to a local scratch directory and then copied to
    the artifact location.

    Args:
        model: The TF Model to write.
    """
    super().handle_return(model)
    # The context manager removes the scratch directory as soon as the copy
    # has finished (or failed) instead of leaving it on disk until the
    # TemporaryDirectory object happens to be garbage collected.
    with TemporaryDirectory() as temp_dir:
        model.save_pretrained(temp_dir)
        io_utils.copy_dir(
            temp_dir,
            os.path.join(self.artifact.uri, DEFAULT_TF_MODEL_DIR),
        )
huggingface_tokenizer_materializer
Implementation of the Huggingface tokenizer materializer.
HFTokenizerMaterializer (BaseMaterializer)
Materializer to read and write Huggingface tokenizers.
Source code in zenml/integrations/huggingface/materializers/huggingface_tokenizer_materializer.py
class HFTokenizerMaterializer(BaseMaterializer):
    """Materializer to read and write Huggingface tokenizers."""

    ASSOCIATED_TYPES = (PreTrainedTokenizerBase,)
    ASSOCIATED_ARTIFACT_TYPES = (ModelArtifact,)

    def handle_input(self, data_type: Type[Any]) -> PreTrainedTokenizerBase:
        """Reads Tokenizer.

        Args:
            data_type: The type of the tokenizer to read.

        Returns:
            The tokenizer read from the specified dir.
        """
        super().handle_input(data_type)

        return AutoTokenizer.from_pretrained(
            os.path.join(self.artifact.uri, DEFAULT_TOKENIZER_DIR)
        )

    def handle_return(self, tokenizer: Type[Any]) -> None:
        """Writes a Tokenizer to the specified dir.

        The tokenizer is first saved to a local scratch directory and then
        copied to the artifact location.

        Args:
            tokenizer: The HFTokenizer to write.
        """
        super().handle_return(tokenizer)
        # The context manager removes the scratch directory as soon as the
        # copy has finished (or failed) instead of leaving it on disk until
        # the TemporaryDirectory object happens to be garbage collected.
        with TemporaryDirectory() as temp_dir:
            tokenizer.save_pretrained(temp_dir)
            io_utils.copy_dir(
                temp_dir,
                os.path.join(self.artifact.uri, DEFAULT_TOKENIZER_DIR),
            )
handle_input(self, data_type)
Reads Tokenizer.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
data_type |
Type[Any] |
The type of the tokenizer to read. |
required |
Returns:
Type | Description |
---|---|
PreTrainedTokenizerBase |
The tokenizer read from the specified dir. |
Source code in zenml/integrations/huggingface/materializers/huggingface_tokenizer_materializer.py
def handle_input(self, data_type: Type[Any]) -> PreTrainedTokenizerBase:
    """Load a tokenizer from the artifact's tokenizer directory.

    Args:
        data_type: The type of the tokenizer to read.

    Returns:
        The tokenizer read from the specified dir.
    """
    super().handle_input(data_type)
    tokenizer_dir = os.path.join(self.artifact.uri, DEFAULT_TOKENIZER_DIR)
    return AutoTokenizer.from_pretrained(tokenizer_dir)
handle_return(self, tokenizer)
Writes a Tokenizer to the specified dir.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
tokenizer |
Type[Any] |
The HFTokenizer to write. |
required |
Source code in zenml/integrations/huggingface/materializers/huggingface_tokenizer_materializer.py
def handle_return(self, tokenizer: Type[Any]) -> None:
    """Writes a Tokenizer to the specified dir.

    The tokenizer is first saved into a local temporary directory, which is
    then copied into the tokenizer sub-directory of the artifact URI.

    Args:
        tokenizer: The HFTokenizer to write.
    """
    super().handle_return(tokenizer)
    # Use the temporary directory as a context manager so the staging files
    # are removed deterministically, even if the copy fails, instead of
    # relying on garbage collection of the TemporaryDirectory object.
    with TemporaryDirectory() as temp_dir:
        tokenizer.save_pretrained(temp_dir)
        io_utils.copy_dir(
            temp_dir,
            os.path.join(self.artifact.uri, DEFAULT_TOKENIZER_DIR),
        )
integration
Base and meta classes for ZenML integrations.
Integration
Base class for integration in ZenML.
Source code in zenml/integrations/integration.py
class Integration(metaclass=IntegrationMeta):
    """Base class for integration in ZenML."""

    NAME = "base_integration"

    # Python package requirement strings checked via pkg_resources.
    REQUIREMENTS: List[str] = []

    # Maps a requirement label (used only in log output) to the executable
    # that must be discoverable on the PATH.
    SYSTEM_REQUIREMENTS: Dict[str, str] = {}

    @classmethod
    def check_installation(cls) -> bool:
        """Method to check whether the required packages are installed.

        System requirements are checked first by looking up each command with
        `shutil.which`; Python requirements are then resolved through
        `pkg_resources.get_distribution`.

        Returns:
            True if all required packages are installed, False otherwise.
        """
        try:
            for requirement, command in cls.SYSTEM_REQUIREMENTS.items():
                result = shutil.which(command)

                if result is None:
                    logger.debug(
                        "Unable to find the required packages for %s on your "
                        "system. Please install the packages on your system "
                        "and try again.",
                        requirement,
                    )
                    return False

            for r in cls.REQUIREMENTS:
                pkg_resources.get_distribution(r)
            logger.debug(
                f"Integration {cls.NAME} is installed correctly with "
                f"requirements {cls.REQUIREMENTS}."
            )
            return True
        except pkg_resources.DistributionNotFound as e:
            logger.debug(
                f"Unable to find required package '{e.req}' for "
                f"integration {cls.NAME}."
            )
            return False
        except pkg_resources.VersionConflict as e:
            logger.debug(
                f"VersionConflict error when loading installation {cls.NAME}: "
                f"{str(e)}"
            )
            return False

    @classmethod
    def activate(cls) -> None:
        """Hook to activate the integration.

        The base implementation does nothing; subclasses override it.
        """

    @classmethod
    def flavors(cls) -> List[FlavorWrapper]:
        """Hook to declare new stack component flavors.

        Returns:
            The stack component flavors declared by the integration. The base
            implementation declares none and returns an empty list, so the
            result always matches the annotated return type.
        """
        return []
activate()
classmethod
Abstract method to activate the integration.
Source code in zenml/integrations/integration.py
@classmethod
def activate(cls) -> None:
    """Hook to activate the integration.

    The base implementation consists of this docstring only and therefore
    does nothing; integrations override it with their activation logic.
    """
check_installation()
classmethod
Method to check whether the required packages are installed.
Returns:
Type | Description |
---|---|
bool |
True if all required packages are installed, False otherwise. |
Source code in zenml/integrations/integration.py
@classmethod
def check_installation(cls) -> bool:
    """Report whether everything the integration requires is installed.

    Every command in `SYSTEM_REQUIREMENTS` must be resolvable with
    `shutil.which`, and every entry in `REQUIREMENTS` must be resolvable with
    `pkg_resources.get_distribution`.

    Returns:
        True if all required packages are installed, False otherwise.
    """
    try:
        for requirement, command in cls.SYSTEM_REQUIREMENTS.items():
            if shutil.which(command) is not None:
                continue
            logger.debug(
                "Unable to find the required packages for %s on your "
                "system. Please install the packages on your system "
                "and try again.",
                requirement,
            )
            return False

        for package in cls.REQUIREMENTS:
            pkg_resources.get_distribution(package)
    except pkg_resources.DistributionNotFound as missing:
        logger.debug(
            f"Unable to find required package '{missing.req}' for "
            f"integration {cls.NAME}."
        )
        return False
    except pkg_resources.VersionConflict as conflict:
        logger.debug(
            f"VersionConflict error when loading installation {cls.NAME}: "
            f"{str(conflict)}"
        )
        return False
    else:
        logger.debug(
            f"Integration {cls.NAME} is installed correctly with "
            f"requirements {cls.REQUIREMENTS}."
        )
        return True
flavors()
classmethod
Abstract method to declare new stack component flavors.
Source code in zenml/integrations/integration.py
@classmethod
def flavors(cls) -> List[FlavorWrapper]:
    """Hook to declare new stack component flavors.

    Returns:
        The stack component flavors declared by the integration. The base
        implementation declares none and returns an empty list, so the result
        always matches the annotated return type instead of being `None`.
    """
    return []
IntegrationMeta (type)
Metaclass responsible for registering different Integration subclasses.
Source code in zenml/integrations/integration.py
class IntegrationMeta(type):
    """Metaclass that registers every Integration subclass on creation."""

    def __new__(
        mcs, name: str, bases: Tuple[Type[Any], ...], dct: Dict[str, Any]
    ) -> "IntegrationMeta":
        """Create the class and register it unless it is the base class.

        Args:
            name: The name of the class being created.
            bases: The base classes of the class being created.
            dct: The dictionary of attributes of the class being created.

        Returns:
            The newly created class.
        """
        created = super().__new__(mcs, name, bases, dct)
        integration_cls = cast(Type["Integration"], created)
        # The class named "Integration" is the abstract base and is the only
        # one that is not added to the registry.
        is_base_class = name == "Integration"
        if not is_base_class:
            integration_registry.register_integration(
                integration_cls.NAME, integration_cls
            )
        return integration_cls
__new__(mcs, name, bases, dct)
special
staticmethod
Hook into creation of an Integration class.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
name |
str |
The name of the class being created. |
required |
bases |
Tuple[Type[Any], ...] |
The base classes of the class being created. |
required |
dct |
Dict[str, Any] |
The dictionary of attributes of the class being created. |
required |
Returns:
Type | Description |
---|---|
IntegrationMeta |
The newly created class. |
Source code in zenml/integrations/integration.py
def __new__(
    mcs, name: str, bases: Tuple[Type[Any], ...], dct: Dict[str, Any]
) -> "IntegrationMeta":
    """Create the class and register it unless it is the base class.

    Args:
        name: The name of the class being created.
        bases: The base classes of the class being created.
        dct: The dictionary of attributes of the class being created.

    Returns:
        The newly created class.
    """
    created = super().__new__(mcs, name, bases, dct)
    integration_cls = cast(Type["Integration"], created)
    # The class named "Integration" is the abstract base and is the only one
    # that is not added to the registry.
    is_base_class = name == "Integration"
    if not is_base_class:
        integration_registry.register_integration(
            integration_cls.NAME, integration_cls
        )
    return integration_cls
kserve
special
Initialization of the KServe integration for ZenML.
The KServe integration allows you to use the KServe model serving platform to implement continuous model deployment.
KServeIntegration (Integration)
Definition of KServe integration for ZenML.
Source code in zenml/integrations/kserve/__init__.py
class KServeIntegration(Integration):
    """Definition of KServe integration for ZenML."""

    NAME = KSERVE
    REQUIREMENTS = [
        "kserve==0.9.0",
        "torch-model-archiver",
    ]

    @classmethod
    def activate(cls) -> None:
        """Activate the KServe integration.

        Imports the KServe sub-modules. The imported names are not used here
        (hence the `# noqa` markers); the imports are the whole effect.
        """
        from zenml.integrations.kserve import model_deployers  # noqa
        from zenml.integrations.kserve import secret_schemas  # noqa
        from zenml.integrations.kserve import services  # noqa
        from zenml.integrations.kserve import steps  # noqa

    @classmethod
    def flavors(cls) -> List[FlavorWrapper]:
        """Declare the stack component flavors for KServe.

        Returns:
            List of stack component flavors for this integration.
        """
        return [
            FlavorWrapper(
                name=KSERVE_MODEL_DEPLOYER_FLAVOR,
                source="zenml.integrations.kserve.model_deployers.KServeModelDeployer",
                type=StackComponentType.MODEL_DEPLOYER,
                integration=cls.NAME,
            )
        ]
activate()
classmethod
Activate the KServe integration.
Source code in zenml/integrations/kserve/__init__.py
@classmethod
def activate(cls) -> None:
    """Activate the KServe integration.

    Imports the KServe sub-modules. The imported names are not used here
    (hence the `# noqa` markers); the imports are the whole effect.
    """
    from zenml.integrations.kserve import model_deployers  # noqa
    from zenml.integrations.kserve import secret_schemas  # noqa
    from zenml.integrations.kserve import services  # noqa
    from zenml.integrations.kserve import steps  # noqa
flavors()
classmethod
Declare the stack component flavors for KServe.
Returns:
Type | Description |
---|---|
List[zenml.zen_stores.models.flavor_wrapper.FlavorWrapper] |
List of stack component flavors for this integration. |
Source code in zenml/integrations/kserve/__init__.py
@classmethod
def flavors(cls) -> List[FlavorWrapper]:
    """List the stack component flavors contributed by KServe.

    Returns:
        A single-element list holding the KServe model deployer flavor.
    """
    model_deployer_flavor = FlavorWrapper(
        name=KSERVE_MODEL_DEPLOYER_FLAVOR,
        source="zenml.integrations.kserve.model_deployers.KServeModelDeployer",
        type=StackComponentType.MODEL_DEPLOYER,
        integration=cls.NAME,
    )
    return [model_deployer_flavor]
model_deployers
special
Initialization of the KServe Model Deployer.
kserve_model_deployer
Implementation of the KServe Model Deployer.
KServeModelDeployer (BaseModelDeployer)
pydantic-model
KServe model deployer stack component implementation.
Attributes:
Name | Type | Description |
---|---|---|
kubernetes_context |
Optional[str] |
the Kubernetes context to use to contact the remote KServe installation. If not specified, the current configuration is used. Depending on where the KServe model deployer is being used, this can be either a locally active context or an in-cluster Kubernetes configuration (if running inside a pod). |
kubernetes_namespace |
Optional[str] |
the Kubernetes namespace where the KServe inference service CRDs are provisioned and managed by ZenML. If not specified, the namespace set in the current configuration is used. Depending on where the KServe model deployer is being used, this can be either the current namespace configured in the locally active context or the namespace in the context of which the pod is running (if running inside a pod). |
base_url |
str |
the base URL of the Kubernetes ingress used to expose the KServe inference services. |
secret |
Optional[str] |
the name of the secret containing the credentials for the KServe inference services. |
Source code in zenml/integrations/kserve/model_deployers/kserve_model_deployer.py
class KServeModelDeployer(BaseModelDeployer):
    """KServe model deployer stack component implementation.

    Attributes:
        kubernetes_context: the Kubernetes context to use to contact the remote
            KServe installation. If not specified, the current
            configuration is used. Depending on where the KServe model deployer
            is being used, this can be either a locally active context or an
            in-cluster Kubernetes configuration (if running inside a pod).
        kubernetes_namespace: the Kubernetes namespace where the KServe
            inference service CRDs are provisioned and managed by ZenML. If not
            specified, the namespace set in the current configuration is used.
            Depending on where the KServe model deployer is being used, this can
            be either the current namespace configured in the locally active
            context or the namespace in the context of which the pod is running
            (if running inside a pod).
        base_url: the base URL of the Kubernetes ingress used to expose the
            KServe inference services.
        secret: the name of the secret containing the credentials for the
            KServe inference services.
        custom_domain: optional custom domain. It is not read anywhere in this
            class; presumably it is consumed by the deployment service when
            building hostnames -- TODO confirm.
    """

    # Class Configuration
    FLAVOR: ClassVar[str] = KSERVE_MODEL_DEPLOYER_FLAVOR

    kubernetes_context: Optional[str]
    kubernetes_namespace: Optional[str]
    base_url: str
    secret: Optional[str]
    custom_domain: Optional[str]

    # private attributes
    _client: Optional[KServeClient] = None

    @staticmethod
    def get_model_server_info(  # type: ignore[override]
        service_instance: "KServeDeploymentService",
    ) -> Dict[str, Optional[str]]:
        """Return implementation specific information on the model server.

        Args:
            service_instance: KServe deployment service object

        Returns:
            A dictionary containing the model server information.
        """
        return {
            "PREDICTION_URL": service_instance.prediction_url,
            "PREDICTION_HOSTNAME": service_instance.prediction_hostname,
            "MODEL_URI": service_instance.config.model_uri,
            "MODEL_NAME": service_instance.config.model_name,
            "KSERVE_INFERENCE_SERVICE": service_instance.crd_name,
        }

    @staticmethod
    def get_active_model_deployer() -> "KServeModelDeployer":
        """Get the KServe model deployer registered in the active stack.

        Returns:
            The KServe model deployer registered in the active stack.

        Raises:
            TypeError: if the KServe model deployer is not available.
        """
        model_deployer = Repository(  # type: ignore [call-arg]
            skip_repository_check=True
        ).active_stack.model_deployer
        if not model_deployer or not isinstance(
            model_deployer, KServeModelDeployer
        ):
            # The first sentence is terminated with ". " so it does not run
            # into the one that follows.
            raise TypeError(
                f"The active stack needs to have a KServe model deployer "
                f"component registered to be able to deploy models with KServe. "
                f"You can create a new stack with a KServe model "
                f"deployer component or update your existing stack to add this "
                f"component, e.g.:\n\n"
                f" 'zenml model-deployer register kserve --flavor={KSERVE_MODEL_DEPLOYER_FLAVOR} "
                f"--kubernetes_context=context-name --kubernetes_namespace="
                f"namespace-name --base_url=https://ingress.cluster.kubernetes'\n"
                f" 'zenml stack create stack-name -d kserve ...'\n"
            )
        return model_deployer

    @property
    def kserve_client(self) -> KServeClient:
        """Get the KServe client associated with this model deployer.

        The client is created on first access and cached in `_client`.

        Returns:
            The KServe client.
        """
        if not self._client:
            self._client = KServeClient(
                context=self.kubernetes_context,
            )
        return self._client

    def _set_credentials(self) -> None:
        """Pass the credentials from the configured secret to the KServe client.

        Does nothing when no secret is configured. A `credentials` entry of
        the secret is written to a temporary JSON file whose path is handed to
        the client; all other entries are passed through as keyword arguments.
        The temporary file is removed afterwards.

        Raises:
            RuntimeError: if the credentials are not available.
        """
        secret = self._get_kserve_secret()
        if not secret:
            return

        secret_folder = Path(
            GlobalConfiguration().config_directory,
            "kserve-storage",
            str(self.uuid),
        )
        kserve_credentials = {}
        # Stays None unless the secret carries a `credentials` entry, so the
        # cleanup in `finally` below never touches an unbound name.
        file_path: Optional[Path] = None
        # Handle the secrets attributes
        for key in secret.content.keys():
            content = getattr(secret, key)
            if key == "credentials" and content:
                fileio.makedirs(str(secret_folder))
                file_path = Path(secret_folder, f"{key}.json")
                kserve_credentials["credentials_file"] = str(file_path)
                with open(file_path, "w") as f:
                    f.write(content)
                file_path.chmod(0o600)
            # Handle additional params
            else:
                kserve_credentials[key] = content

        # We need to add the namespace to the kserve_credentials
        kserve_credentials["namespace"] = (
            self.kubernetes_namespace
            or utils.get_default_target_namespace()
        )

        try:
            self.kserve_client.set_credentials(**kserve_credentials)
        except Exception as e:
            raise RuntimeError(
                f"Failed to set credentials for KServe model deployer: {e}"
            ) from e
        finally:
            if file_path is not None and file_path.exists():
                file_path.unlink()

    def deploy_model(
        self,
        config: ServiceConfig,
        replace: bool = False,
        timeout: int = DEFAULT_KSERVE_DEPLOYMENT_START_STOP_TIMEOUT,
    ) -> BaseService:
        """Create a new KServe deployment or update an existing one.

        This method has two modes of operation, depending on the `replace`
        argument value:

          * if `replace` is False, calling this method will create a new KServe
            deployment server to reflect the model and other configuration
            parameters specified in the supplied KServe deployment `config`.

          * if `replace` is True, this method will first attempt to find an
            existing KServe deployment that is *equivalent* to the supplied
            configuration parameters. Two or more KServe deployments are
            considered equivalent if they have the same `pipeline_name`,
            `pipeline_step_name` and `model_name` configuration parameters. To
            put it differently, two KServe deployments are equivalent if
            they serve versions of the same model deployed by the same pipeline
            step. If an equivalent KServe deployment is found, it will be
            updated in place to reflect the new configuration parameters. This
            allows an existing KServe deployment to retain its prediction
            URL while performing a rolling update to serve a new model version.

        Callers should set `replace` to True if they want a continuous model
        deployment workflow that doesn't spin up a new KServe deployment
        server for each new model version. If multiple equivalent KServe
        deployments are found, the most recently created deployment is selected
        to be updated and the others are deleted.

        Args:
            config: the configuration of the model to be deployed with KServe.
            replace: set this flag to True to find and update an equivalent
                KServeDeployment server with the new model instead of
                starting a new deployment server.
            timeout: the timeout in seconds to wait for the KServe server
                to be provisioned and successfully started or updated. If set
                to 0, the method will return immediately after the KServe
                server is provisioned, without waiting for it to fully start.

        Returns:
            The ZenML KServe deployment service object that can be used to
            interact with the remote KServe server.

        Raises:
            RuntimeError: if the KServe deployment server could not be stopped.
        """
        config = cast(KServeDeploymentConfig, config)
        service = None

        # if the secret is passed in the config, use it to set the credentials
        if config.secret_name:
            self.secret = config.secret_name
        self._set_credentials()

        # if replace is True, find equivalent KServe deployments
        if replace is True:
            equivalent_services = self.find_model_server(
                running=False,
                pipeline_name=config.pipeline_name,
                pipeline_step_name=config.pipeline_step_name,
                model_name=config.model_name,
            )

            for equivalent_service in equivalent_services:
                if service is None:
                    # keep the most recently created service
                    service = equivalent_service
                else:
                    try:
                        # delete the older services and don't wait for them to
                        # be deprovisioned; the kept `service` must stay up
                        equivalent_service.stop()
                    except RuntimeError as e:
                        # Adjacent literals form ONE message string; stray
                        # commas here would turn the arguments into a tuple.
                        raise RuntimeError(
                            "Failed to stop the KServe deployment server:\n"
                            f"{e}\n"
                            "Please stop it manually and try again."
                        ) from e

        if service:
            # update an equivalent service in place
            service.update(config)
            logger.info(
                f"Updating an existing KServe deployment service: {service}"
            )
        else:
            # create a new service
            service = KServeDeploymentService(config=config)
            logger.info(f"Creating a new KServe deployment service: {service}")

        # start the service which in turn provisions the KServe
        # deployment server and waits for it to reach a ready state
        service.start(timeout=timeout)
        return service

    def get_kserve_deployments(
        self, labels: Dict[str, str]
    ) -> List[V1beta1InferenceService]:
        """Get a list of KServe deployments that match the supplied labels.

        Args:
            labels: a dictionary of labels to match against KServe deployments.

        Returns:
            A list of KServe deployments that match the supplied labels.

        Raises:
            RuntimeError: if an operational failure is encountered while
                retrieving the KServe inference services.
        """
        label_selector = (
            ",".join(f"{k}={v}" for k, v in labels.items()) if labels else None
        )

        namespace = (
            self.kubernetes_namespace or utils.get_default_target_namespace()
        )

        try:
            response = (
                self.kserve_client.api_instance.list_namespaced_custom_object(
                    constants.KSERVE_GROUP,
                    constants.KSERVE_V1BETA1_VERSION,
                    namespace,
                    constants.KSERVE_PLURAL,
                    label_selector=label_selector,
                )
            )
        except client.rest.ApiException as e:
            # Single literal: a backslash continuation inside the string
            # would embed the source indentation in the message.
            raise RuntimeError(
                f"Exception when retrieving KServe inference services: {e}\n"
            ) from e

        # TODO[CRITICAL]: de-serialize each item into a complete
        #   V1beta1InferenceService object recursively using the OpenApi
        #   schema (this doesn't work right now)
        inference_services: List[V1beta1InferenceService] = []
        for item in response.get("items", []):
            snake_case_item = self._camel_to_snake(item)
            inference_service = V1beta1InferenceService(**snake_case_item)
            inference_services.append(inference_service)
        return inference_services

    def _camel_to_snake(self, obj: Any) -> Any:
        """Recursively convert camelCase dictionary keys to snake_case.

        Dictionaries get converted keys and recursively converted values;
        lists, sets and tuples are rebuilt with converted elements; any other
        value is returned unchanged.

        Args:
            obj: a dictionary with camelCase keys, or any nested value
                reached while recursing through one

        Returns:
            the same structure with snake_case dictionary keys
        """
        if isinstance(obj, (str, int, float)):
            return obj
        if isinstance(obj, dict):
            new = obj.__class__()
            for k, v in obj.items():
                new[self._convert_to_snake(k)] = self._camel_to_snake(v)
            return new
        if isinstance(obj, (list, set, tuple)):
            return obj.__class__(self._camel_to_snake(v) for v in obj)
        return obj

    def _convert_to_snake(self, k: str) -> str:
        """Convert a single camelCase key to snake_case.

        Args:
            k: the key to convert

        Returns:
            the key lower-cased, with an underscore inserted before every
            uppercase letter that is not at the start of the string
        """
        return re.sub(r"(?<!^)(?=[A-Z])", "_", k).lower()

    def find_model_server(
        self,
        running: bool = False,
        service_uuid: Optional[UUID] = None,
        pipeline_name: Optional[str] = None,
        pipeline_run_id: Optional[str] = None,
        pipeline_step_name: Optional[str] = None,
        model_name: Optional[str] = None,
        model_uri: Optional[str] = None,
        predictor: Optional[str] = None,
    ) -> List[BaseService]:
        """Find one or more KServe model services that match the given criteria.

        Args:
            running: If true, only running services will be returned.
            service_uuid: The UUID of the service that was originally used
                to deploy the model.
            pipeline_name: name of the pipeline that the deployed model was part
                of.
            pipeline_run_id: ID of the pipeline run which the deployed model was
                part of.
            pipeline_step_name: the name of the pipeline model deployment step
                that deployed the model.
            model_name: the name of the deployed model.
            model_uri: URI of the deployed model.
            predictor: the name of the predictor that was used to deploy the model.

        Returns:
            One or more Service objects representing model servers that match
            the input search criteria.
        """
        config = KServeDeploymentConfig(
            pipeline_name=pipeline_name or "",
            pipeline_run_id=pipeline_run_id or "",
            pipeline_step_name=pipeline_step_name or "",
            model_uri=model_uri or "",
            model_name=model_name or "",
            predictor=predictor or "",
            resources={},
        )
        labels = config.get_kubernetes_labels()

        if service_uuid:
            labels["zenml.service_uuid"] = str(service_uuid)

        deployments = self.get_kserve_deployments(labels=labels)

        services: List[BaseService] = []
        for deployment in deployments:
            # recreate the KServe deployment service object from the KServe
            # deployment resource
            service = KServeDeploymentService.create_from_deployment(
                deployment=deployment
            )
            if running and not service.is_running:
                # skip non-running services
                continue
            services.append(service)

        return services

    def stop_model_server(
        self,
        uuid: UUID,
        timeout: int = DEFAULT_KSERVE_DEPLOYMENT_START_STOP_TIMEOUT,
        force: bool = False,
    ) -> None:
        """Stop a KServe model server.

        Args:
            uuid: UUID of the model server to stop.
            timeout: timeout in seconds to wait for the service to stop.
            force: if True, force the service to stop.

        Raises:
            NotImplementedError: stopping on KServe model servers is not
                supported.
        """
        raise NotImplementedError(
            "Stopping KServe model servers is not implemented. Try "
            "deleting the KServe model server instead."
        )

    def start_model_server(
        self,
        uuid: UUID,
        timeout: int = DEFAULT_KSERVE_DEPLOYMENT_START_STOP_TIMEOUT,
    ) -> None:
        """Start a KServe model deployment server.

        Args:
            uuid: UUID of the model server to start.
            timeout: timeout in seconds to wait for the service to become
                active. If set to 0, the method will return immediately after
                provisioning the service, without waiting for it to become
                active.

        Raises:
            NotImplementedError: since we don't support starting KServe
                model servers
        """
        raise NotImplementedError(
            "Starting KServe model servers is not implemented"
        )

    def delete_model_server(
        self,
        uuid: UUID,
        timeout: int = DEFAULT_KSERVE_DEPLOYMENT_START_STOP_TIMEOUT,
        force: bool = False,
    ) -> None:
        """Delete a KServe model deployment server.

        Does nothing when no model server with the given UUID is found.

        Args:
            uuid: UUID of the model server to delete.
            timeout: timeout in seconds to wait for the service to stop. If
                set to 0, the method will return immediately after
                deprovisioning the service, without waiting for it to stop.
            force: if True, force the service to stop.
        """
        services = self.find_model_server(service_uuid=uuid)
        if len(services) == 0:
            return
        services[0].stop(timeout=timeout, force=force)

    def _get_kserve_secret(self) -> Any:
        """Get the secret object for the KServe deployment.

        Returns:
            The secret object for the KServe deployment, or None when no
            secret name is configured.

        Raises:
            RuntimeError: if the secret object is not found or secrets_manager is not set.
        """
        if not self.secret:
            return None

        secret_manager = Repository(  # type: ignore [call-arg]
            skip_repository_check=True
        ).active_stack.secrets_manager

        if not secret_manager or not isinstance(
            secret_manager, BaseSecretsManager
        ):
            raise RuntimeError(
                f"The active stack doesn't have a secret manager component. "
                f"The ZenML secret specified in the KServe Model "
                f"Deployer configuration cannot be fetched: {self.secret}."
            )

        try:
            return secret_manager.get_secret(self.secret)
        except KeyError as e:
            # The trailing space after "Model" keeps the joined literals from
            # rendering as "ModelDeployer".
            raise RuntimeError(
                f"The secret `{self.secret}` used for your KServe Model "
                f"Deployer configuration does not exist in your secrets "
                f"manager `{secret_manager.name}`."
            ) from e
kserve_client: KServeClient
property
readonly
Get the KServe client associated with this model deployer.
Returns:
Type | Description |
---|---|
KServeClient |
The KServe client. |
delete_model_server(self, uuid, timeout=300, force=False)
Delete a KServe model deployment server.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
uuid |
UUID |
UUID of the model server to delete. |
required |
timeout |
int |
timeout in seconds to wait for the service to stop. If set to 0, the method will return immediately after deprovisioning the service, without waiting for it to stop. |
300 |
force |
bool |
if True, force the service to stop. |
False |
Source code in zenml/integrations/kserve/model_deployers/kserve_model_deployer.py
def delete_model_server(
    self,
    uuid: UUID,
    timeout: int = DEFAULT_KSERVE_DEPLOYMENT_START_STOP_TIMEOUT,
    force: bool = False,
) -> None:
    """Delete the KServe model deployment server with the given UUID.

    The server is looked up by its service UUID; when nothing matches, the
    call is a no-op. Otherwise the first match is stopped.

    Args:
        uuid: UUID of the model server to delete.
        timeout: timeout in seconds to wait for the service to stop. If
            set to 0, the method will return immediately after
            deprovisioning the service, without waiting for it to stop.
        force: if True, force the service to stop.
    """
    matches = self.find_model_server(service_uuid=uuid)
    if matches:
        matches[0].stop(timeout=timeout, force=force)
deploy_model(self, config, replace=False, timeout=300)
Create a new KServe deployment or update an existing one.
This method has two modes of operation, depending on the replace
argument value:
- if `replace` is False, calling this method will create a new KServe deployment server to reflect the model and other configuration parameters specified in the supplied KServe deployment `config`.
- if `replace` is True, this method will first attempt to find an existing KServe deployment that is equivalent to the supplied configuration parameters. Two or more KServe deployments are considered equivalent if they have the same `pipeline_name`, `pipeline_step_name` and `model_name` configuration parameters. To put it differently, two KServe deployments are equivalent if they serve versions of the same model deployed by the same pipeline step. If an equivalent KServe deployment is found, it will be updated in place to reflect the new configuration parameters. This allows an existing KServe deployment to retain its prediction URL while performing a rolling update to serve a new model version.
Callers should set replace
to True if they want a continuous model
deployment workflow that doesn't spin up a new KServe deployment
server for each new model version. If multiple equivalent KServe
deployments are found, the most recently created deployment is selected
to be updated and the others are deleted.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
config |
ServiceConfig |
the configuration of the model to be deployed with KServe. |
required |
replace |
bool |
set this flag to True to find and update an equivalent KServeDeployment server with the new model instead of starting a new deployment server. |
False |
timeout |
int |
the timeout in seconds to wait for the KServe server to be provisioned and successfully started or updated. If set to 0, the method will return immediately after the KServe server is provisioned, without waiting for it to fully start. |
300 |
Returns:
Type | Description |
---|---|
BaseService |
The ZenML KServe deployment service object that can be used to interact with the remote KServe server. |
Exceptions:
Type | Description |
---|---|
RuntimeError |
if the KServe deployment server could not be stopped. |
Source code in zenml/integrations/kserve/model_deployers/kserve_model_deployer.py
def deploy_model(
    self,
    config: ServiceConfig,
    replace: bool = False,
    timeout: int = DEFAULT_KSERVE_DEPLOYMENT_START_STOP_TIMEOUT,
) -> BaseService:
    """Create a new KServe deployment or update an existing one.

    This method has two modes of operation, depending on the `replace`
    argument value:

      * if `replace` is False, calling this method will create a new KServe
        deployment server to reflect the model and other configuration
        parameters specified in the supplied KServe deployment `config`.

      * if `replace` is True, this method will first attempt to find an
        existing KServe deployment that is *equivalent* to the supplied
        configuration parameters. Two or more KServe deployments are
        considered equivalent if they have the same `pipeline_name`,
        `pipeline_step_name` and `model_name` configuration parameters. To
        put it differently, two KServe deployments are equivalent if
        they serve versions of the same model deployed by the same pipeline
        step. If an equivalent KServe deployment is found, it will be
        updated in place to reflect the new configuration parameters. This
        allows an existing KServe deployment to retain its prediction
        URL while performing a rolling update to serve a new model version.

    Callers should set `replace` to True if they want a continuous model
    deployment workflow that doesn't spin up a new KServe deployment
    server for each new model version. If multiple equivalent KServe
    deployments are found, the most recently created deployment is selected
    to be updated and the others are deleted.

    Args:
        config: the configuration of the model to be deployed with KServe.
        replace: set this flag to True to find and update an equivalent
            KServeDeployment server with the new model instead of
            starting a new deployment server.
        timeout: the timeout in seconds to wait for the KServe server
            to be provisioned and successfully started or updated. If set
            to 0, the method will return immediately after the KServe
            server is provisioned, without waiting for it to fully start.

    Returns:
        The ZenML KServe deployment service object that can be used to
        interact with the remote KServe server.

    Raises:
        RuntimeError: if the KServe deployment server could not be stopped.
    """
    config = cast(KServeDeploymentConfig, config)
    service = None

    # if the secret is passed in the config, use it to set the credentials
    if config.secret_name:
        self.secret = config.secret_name
    self._set_credentials()

    # if replace is True, find equivalent KServe deployments
    if replace is True:
        equivalent_services = self.find_model_server(
            running=False,
            pipeline_name=config.pipeline_name,
            pipeline_step_name=config.pipeline_step_name,
            model_name=config.model_name,
        )

        for equivalent_service in equivalent_services:
            if service is None:
                # keep the most recently created service
                service = equivalent_service
            else:
                try:
                    # delete the older services and don't wait for them to
                    # be deprovisioned; the kept `service` must stay up
                    equivalent_service.stop()
                except RuntimeError as e:
                    # Adjacent literals form ONE message string; stray
                    # commas here would turn the arguments into a tuple.
                    raise RuntimeError(
                        "Failed to stop the KServe deployment server:\n"
                        f"{e}\n"
                        "Please stop it manually and try again."
                    ) from e

    if service:
        # update an equivalent service in place
        service.update(config)
        logger.info(
            f"Updating an existing KServe deployment service: {service}"
        )
    else:
        # create a new service
        service = KServeDeploymentService(config=config)
        logger.info(f"Creating a new KServe deployment service: {service}")

    # start the service which in turn provisions the KServe
    # deployment server and waits for it to reach a ready state
    service.start(timeout=timeout)
    return service
find_model_server(self, running=False, service_uuid=None, pipeline_name=None, pipeline_run_id=None, pipeline_step_name=None, model_name=None, model_uri=None, predictor=None)
Find one or more KServe model services that match the given criteria.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
running |
bool |
If true, only running services will be returned. |
False |
service_uuid |
Optional[uuid.UUID] |
The UUID of the service that was originally used to deploy the model. |
None |
pipeline_name |
Optional[str] |
name of the pipeline that the deployed model was part of. |
None |
pipeline_run_id |
Optional[str] |
ID of the pipeline run which the deployed model was part of. |
None |
pipeline_step_name |
Optional[str] |
the name of the pipeline model deployment step that deployed the model. |
None |
model_name |
Optional[str] |
the name of the deployed model. |
None |
model_uri |
Optional[str] |
URI of the deployed model. |
None |
predictor |
Optional[str] |
the name of the predictor that was used to deploy the model. |
None |
Returns:
Type | Description |
---|---|
List[zenml.services.service.BaseService] |
One or more Service objects representing model servers that match the input search criteria. |
Source code in zenml/integrations/kserve/model_deployers/kserve_model_deployer.py
def find_model_server(
    self,
    running: bool = False,
    service_uuid: Optional[UUID] = None,
    pipeline_name: Optional[str] = None,
    pipeline_run_id: Optional[str] = None,
    pipeline_step_name: Optional[str] = None,
    model_name: Optional[str] = None,
    model_uri: Optional[str] = None,
    predictor: Optional[str] = None,
) -> List[BaseService]:
    """Look up the KServe model services that match the given filters.

    The filters are turned into Kubernetes labels and used as a label
    selector; filters that are unset or empty do not contribute a label.

    Args:
        running: If true, only running services will be returned.
        service_uuid: The UUID of the service that was originally used
            to deploy the model.
        pipeline_name: name of the pipeline that the deployed model was
            part of.
        pipeline_run_id: ID of the pipeline run which the deployed model
            was part of.
        pipeline_step_name: the name of the pipeline model deployment step
            that deployed the model.
        model_name: the name of the deployed model.
        model_uri: URI of the deployed model.
        predictor: the name of the predictor that was used to deploy the
            model.

    Returns:
        One or more Service objects representing model servers that match
        the input search criteria.
    """
    # A throwaway configuration is built only to derive the label selector.
    selector = KServeDeploymentConfig(
        pipeline_name=pipeline_name or "",
        pipeline_run_id=pipeline_run_id or "",
        pipeline_step_name=pipeline_step_name or "",
        model_uri=model_uri or "",
        model_name=model_name or "",
        predictor=predictor or "",
        resources={},
    ).get_kubernetes_labels()
    if service_uuid:
        selector["zenml.service_uuid"] = str(service_uuid)

    # Lazily rebuild one service object per matching KServe resource.
    candidates = (
        KServeDeploymentService.create_from_deployment(deployment=resource)
        for resource in self.get_kserve_deployments(labels=selector)
    )
    # The running state is only inspected when the caller asked for it.
    matches: List[BaseService] = [
        candidate
        for candidate in candidates
        if not running or candidate.is_running
    ]
    return matches
get_active_model_deployer()
staticmethod
Get the KServe model deployer registered in the active stack.
Returns:
Type | Description |
---|---|
KServeModelDeployer |
The KServe model deployer registered in the active stack. |
Exceptions:
Type | Description |
---|---|
TypeError |
if the KServe model deployer is not available. |
Source code in zenml/integrations/kserve/model_deployers/kserve_model_deployer.py
@staticmethod
def get_active_model_deployer() -> "KServeModelDeployer":
    """Get the KServe model deployer registered in the active stack.

    Returns:
        The KServe model deployer registered in the active stack.

    Raises:
        TypeError: if the active stack has no model deployer or its model
            deployer is not a KServe model deployer.
    """
    model_deployer = Repository(  # type: ignore [call-arg]
        skip_repository_check=True
    ).active_stack.model_deployer
    if not model_deployer or not isinstance(
        model_deployer, KServeModelDeployer
    ):
        # The first sentence is now terminated with ". " so that it no
        # longer runs into "You can create ...".
        raise TypeError(
            "The active stack needs to have a KServe model deployer "
            "component registered to be able to deploy models with KServe. "
            "You can create a new stack with a KServe model "
            "deployer component or update your existing stack to add this "
            "component, e.g.:\n\n"
            f" 'zenml model-deployer register kserve --flavor={KSERVE_MODEL_DEPLOYER_FLAVOR} "
            "--kubernetes_context=context-name --kubernetes_namespace="
            "namespace-name --base_url=https://ingress.cluster.kubernetes'\n"
            " 'zenml stack create stack-name -d kserve ...'\n"
        )
    return model_deployer
get_kserve_deployments(self, labels)
Get a list of KServe deployments that match the supplied labels.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
labels |
Dict[str, str] |
a dictionary of labels to match against KServe deployments. |
required |
Returns:
Type | Description |
---|---|
List[kserve.models.v1beta1_inference_service.V1beta1InferenceService] |
A list of KServe deployments that match the supplied labels. |
Exceptions:
Type | Description |
---|---|
RuntimeError |
if an operational failure is encountered while retrieving the KServe inference services. |
Source code in zenml/integrations/kserve/model_deployers/kserve_model_deployer.py
def get_kserve_deployments(
    self, labels: Dict[str, str]
) -> List[V1beta1InferenceService]:
    """Get a list of KServe deployments that match the supplied labels.

    Args:
        labels: a dictionary of labels to match against KServe deployments.
            An empty dictionary disables label filtering.

    Returns:
        A list of KServe deployments that match the supplied labels.

    Raises:
        RuntimeError: if an operational failure is encountered while
            retrieving the KServe inference services.
    """
    label_selector = (
        ",".join(f"{k}={v}" for k, v in labels.items()) if labels else None
    )
    namespace = (
        self.kubernetes_namespace or utils.get_default_target_namespace()
    )
    try:
        response = (
            self.kserve_client.api_instance.list_namespaced_custom_object(
                constants.KSERVE_GROUP,
                constants.KSERVE_V1BETA1_VERSION,
                namespace,
                constants.KSERVE_PLURAL,
                label_selector=label_selector,
            )
        )
    except client.rest.ApiException as e:
        # Build the message without a backslash line continuation inside
        # the literal, and keep the API error as the explicit cause.
        raise RuntimeError(
            f"Exception when retrieving KServe inference services: {e}\n"
        ) from e

    # TODO[CRITICAL]: de-serialize each item into a complete
    #   V1beta1InferenceService object recursively using the OpenApi
    #   schema (this doesn't work right now)
    inference_services: List[V1beta1InferenceService] = [
        V1beta1InferenceService(**self._camel_to_snake(item))
        for item in response.get("items", [])
    ]
    return inference_services
get_model_server_info(service_instance)
staticmethod
Return implementation specific information on the model server.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
service_instance |
KServeDeploymentService |
KServe deployment service object |
required |
Returns:
Type | Description |
---|---|
Dict[str, Optional[str]] |
A dictionary containing the model server information. |
Source code in zenml/integrations/kserve/model_deployers/kserve_model_deployer.py
@staticmethod
def get_model_server_info(  # type: ignore[override]
    service_instance: "KServeDeploymentService",
) -> Dict[str, Optional[str]]:
    """Collect the KServe specific details of a model server.

    Args:
        service_instance: KServe deployment service object

    Returns:
        A dictionary containing the model server information.
    """
    # The service properties are read before the configuration fields and
    # the CRD name, in the same order as the keys of the result.
    url = service_instance.prediction_url
    hostname = service_instance.prediction_hostname
    service_config = service_instance.config
    info: Dict[str, Optional[str]] = {}
    info["PREDICTION_URL"] = url
    info["PREDICTION_HOSTNAME"] = hostname
    info["MODEL_URI"] = service_config.model_uri
    info["MODEL_NAME"] = service_config.model_name
    info["KSERVE_INFERENCE_SERVICE"] = service_instance.crd_name
    return info
start_model_server(self, uuid, timeout=300)
Start a KServe model deployment server.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
uuid |
UUID |
UUID of the model server to start. |
required |
timeout |
int |
timeout in seconds to wait for the service to become active. If set to 0, the method will return immediately after provisioning the service, without waiting for it to become active. |
300 |
Exceptions:
Type | Description |
---|---|
NotImplementedError |
since we don't support starting KServe model servers |
Source code in zenml/integrations/kserve/model_deployers/kserve_model_deployer.py
def start_model_server(
    self,
    uuid: UUID,
    timeout: int = DEFAULT_KSERVE_DEPLOYMENT_START_STOP_TIMEOUT,
) -> None:
    """Start a KServe model deployment server (unsupported).

    Args:
        uuid: UUID of the model server to start.
        timeout: timeout in seconds to wait for the service to become
            active. If set to 0, the method will return immediately after
            provisioning the service, without waiting for it to become
            active.

    Raises:
        NotImplementedError: always, since starting KServe model servers
            is not supported.
    """
    reason = "Starting KServe model servers is not implemented"
    raise NotImplementedError(reason)
stop_model_server(self, uuid, timeout=300, force=False)
Stop a KServe model server.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
uuid |
UUID |
UUID of the model server to stop. |
required |
timeout |
int |
timeout in seconds to wait for the service to stop. |
300 |
force |
bool |
if True, force the service to stop. |
False |
Exceptions:
Type | Description |
---|---|
NotImplementedError |
stopping on KServe model servers is not supported. |
Source code in zenml/integrations/kserve/model_deployers/kserve_model_deployer.py
def stop_model_server(
    self,
    uuid: UUID,
    timeout: int = DEFAULT_KSERVE_DEPLOYMENT_START_STOP_TIMEOUT,
    force: bool = False,
) -> None:
    """Stop a KServe model server (unsupported).

    Args:
        uuid: UUID of the model server to stop.
        timeout: timeout in seconds to wait for the service to stop.
        force: if True, force the service to stop.

    Raises:
        NotImplementedError: always, since stopping KServe model servers
            is not supported.
    """
    reason = (
        "Stopping KServe model servers is not implemented. Try "
        "deleting the KServe model server instead."
    )
    raise NotImplementedError(reason)
secret_schemas
special
Initialization of Kserve Secret Schemas.
These are secret schemas that can be used to authenticate Kserve to the Artifact Store used to store served ML models.
secret_schemas
Implementation for KServe secret schemas.
KServeAzureSecretSchema (BaseSecretSchema)
pydantic-model
KServe Azure Blob Storage credentials.
Attributes:
Name | Type | Description |
---|---|---|
storage_type |
Literal['Azure'] |
the storage type. Must be set to "Azure" for this schema. |
credentials |
Optional[str] |
the credentials to use. |
Source code in zenml/integrations/kserve/secret_schemas/secret_schemas.py
class KServeAzureSecretSchema(BaseSecretSchema):
"""KServe Azure Blob Storage credentials.
Attributes:
storage_type: the storage type. Must be set to "Azure" for this schema.
credentials: the credentials to use.
"""
# Identifier of this secret schema.
TYPE: ClassVar[str] = KSERVE_AZUREBLOB_SECRET_SCHEMA_TYPE
# Only the literal "Azure" is accepted; it is also the default value.
storage_type: Literal["Azure"] = "Azure"
credentials: Optional[str]
KServeGSSecretSchema (BaseSecretSchema)
pydantic-model
KServe GCS credentials.
Attributes:
Name | Type | Description |
---|---|---|
storage_type |
Literal['GCS'] |
the storage type. Must be set to "GCS" for this schema. |
credentials |
Optional[str] |
the credentials to use. |
service_account |
Optional[str] |
the service account. |
Source code in zenml/integrations/kserve/secret_schemas/secret_schemas.py
class KServeGSSecretSchema(BaseSecretSchema):
"""KServe GCS credentials.
Attributes:
storage_type: the storage type. Must be set to "GCS" for this schema.
credentials: the credentials to use.
service_account: the service account.
"""
# Identifier of this secret schema.
TYPE: ClassVar[str] = KSERVE_GS_SECRET_SCHEMA_TYPE
# Only the literal "GCS" is accepted; it is also the default value.
storage_type: Literal["GCS"] = "GCS"
credentials: Optional[str]
service_account: Optional[str]
KServeS3SecretSchema (BaseSecretSchema)
pydantic-model
KServe S3 credentials.
Attributes:
Name | Type | Description |
---|---|---|
storage_type |
Literal['S3'] |
the storage type. Must be set to "S3" for this schema. |
credentials |
Optional[str] |
the credentials to use. |
service_account |
Optional[str] |
the name of the service account. |
s3_endpoint |
Optional[str] |
the S3 endpoint. |
s3_region |
Optional[str] |
the S3 region. |
s3_use_https |
Optional[str] |
whether to use HTTPS. |
s3_verify_ssl |
Optional[str] |
whether to verify SSL. |
Source code in zenml/integrations/kserve/secret_schemas/secret_schemas.py
class KServeS3SecretSchema(BaseSecretSchema):
"""KServe S3 credentials.
Attributes:
storage_type: the storage type. Must be set to "S3" for this schema.
credentials: the credentials to use.
service_account: the name of the service account.
s3_endpoint: the S3 endpoint.
s3_region: the S3 region.
s3_use_https: whether to use HTTPS.
s3_verify_ssl: whether to verify SSL.
"""
# Identifier of this secret schema.
TYPE: ClassVar[str] = KSERVE_S3_SECRET_SCHEMA_TYPE
# Only the literal "S3" is accepted; it is also the default value.
storage_type: Literal["S3"] = "S3"
credentials: Optional[str]
service_account: Optional[str]
s3_endpoint: Optional[str]
s3_region: Optional[str]
# NOTE(review): the two flags below are typed as strings, not booleans.
s3_use_https: Optional[str]
s3_verify_ssl: Optional[str]
services
special
Initialization for KServe services.
kserve_deployment
Implementation for the KServe inference service.
KServeDeploymentConfig (ServiceConfig)
pydantic-model
KServe deployment service configuration.
Attributes:
Name | Type | Description |
---|---|---|
model_uri |
str |
URI of the model (or models) to serve. |
model_name |
str |
the name of the model. Multiple versions of the same model should use the same model name. |
predictor |
str |
the KServe predictor used to serve the model. |
replicas |
int |
number of replicas to use for the prediction service. |
resources |
Optional[Dict[str, Any]] |
the Kubernetes resources to allocate for the prediction service. |
Source code in zenml/integrations/kserve/services/kserve_deployment.py
class KServeDeploymentConfig(ServiceConfig):
"""KServe deployment service configuration.
Attributes:
model_uri: URI of the model (or models) to serve.
model_name: the name of the model. Multiple versions of the same model
should use the same model name.
secret_name: optional name of a secret. NOTE(review): presumably the
secret holding the storage credentials used to load the model; it is
not used in this class - confirm against the callers.
predictor: the KServe predictor used to serve the model.
replicas: number of replicas to use for the prediction service.
resources: the Kubernetes resources to allocate for the prediction service.
"""
model_uri: str = ""
# `model_name` and `predictor` have no default and must be supplied.
model_name: str
secret_name: Optional[str]
predictor: str
replicas: int = 1
resources: Optional[Dict[str, Any]]
@staticmethod
def sanitize_labels(labels: Dict[str, str]) -> None:
    """Rewrite the label values in place so they are valid Kubernetes labels.

    See: https://kubernetes.io/docs/concepts/overview/working-with-objects/labels/#syntax-and-character-set

    Args:
        labels: The labels to sanitize. The dictionary is modified in
            place; its keys are left untouched.
    """
    # TODO[MEDIUM]: Move k8s label sanitization to a common module with all K8s utils.
    for label_name in list(labels):
        # Every run of characters outside [0-9a-zA-Z-_.] collapses into a
        # single underscore.
        cleaned = re.sub(r"[^0-9a-zA-Z-_\.]+", "_", labels[label_name])
        # Values are limited to 63 characters and must begin and end with
        # an alphanumeric character ([a-z0-9A-Z]).
        labels[label_name] = cleaned[:63].strip("-_.")
def get_kubernetes_labels(self) -> Dict[str, str]:
    """Build the labels of the KServe inference CRD from this configuration.

    These labels are attached to the KServe inference service CRD
    and may be used as label selectors in lookup operations.

    Returns:
        The labels for the KServe inference service CRD.
    """
    # Label key -> name of the configuration attribute providing its value.
    label_sources = (
        ("zenml.pipeline_name", "pipeline_name"),
        ("zenml.pipeline_run_id", "pipeline_run_id"),
        ("zenml.pipeline_step_name", "pipeline_step_name"),
        ("zenml.model_name", "model_name"),
        ("zenml.model_uri", "model_uri"),
        ("zenml.model_type", "predictor"),
    )
    labels = {"app": "zenml"}
    for label_key, attribute in label_sources:
        value = getattr(self, attribute)
        # Unset or empty attributes do not produce a label.
        if value:
            labels[label_key] = value
    self.sanitize_labels(labels)
    return labels
def get_kubernetes_annotations(self) -> Dict[str, str]:
    """Build the annotations of the KServe inference CRD from this configuration.

    The annotations store additional information about the KServe ZenML
    service associated with the deployment that is not available on the
    labels. One particularly important annotation is the serialized
    service configuration itself, which is used to recreate the service
    configuration from a remote KServe inference service CRD.

    Returns:
        The annotations for the KServe inference service CRD.
    """
    serialized_config = self.json()
    return {
        "zenml.service_config": serialized_config,
        "zenml.version": __version__,
    }
@classmethod
def create_from_deployment(
    cls, deployment: V1beta1InferenceService
) -> "KServeDeploymentConfig":
    """Recreate a KServe service configuration from a deployment resource.

    Args:
        deployment: the KServe inference service CRD.

    Returns:
        The KServe ZenML service configuration corresponding to the given
        KServe inference service CRD.

    Raises:
        ValueError: if the given deployment resource does not contain
            the expected annotations or it contains an invalid or
            incompatible KServe ZenML service configuration.
    """
    # A resource without any annotations yields None here; fall back to an
    # empty mapping so that the documented ValueError is raised instead of
    # an AttributeError.
    annotations = deployment.metadata.get("annotations") or {}
    config_data = annotations.get("zenml.service_config")
    if not config_data:
        raise ValueError(
            f"The given deployment resource does not contain a "
            f"'zenml.service_config' annotation: {deployment}"
        )
    try:
        service_config = cls.parse_raw(config_data)
    except ValidationError as e:
        raise ValueError(
            f"The loaded KServe Inference Service resource contains an "
            f"invalid or incompatible KServe ZenML service configuration: "
            f"{config_data}"
        ) from e
    return service_config
create_from_deployment(deployment)
classmethod
Recreate a KServe service from a KServe deployment resource.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
deployment |
V1beta1InferenceService |
the KServe inference service CRD. |
required |
Returns:
Type | Description |
---|---|
KServeDeploymentConfig |
The KServe ZenML service configuration corresponding to the given KServe inference service CRD. |
Exceptions:
Type | Description |
---|---|
ValueError |
if the given deployment resource does not contain the expected annotations or it contains an invalid or incompatible KServe ZenML service configuration. |
Source code in zenml/integrations/kserve/services/kserve_deployment.py
@classmethod
def create_from_deployment(
    cls, deployment: V1beta1InferenceService
) -> "KServeDeploymentConfig":
    """Recreate a KServe service configuration from a deployment resource.

    Args:
        deployment: the KServe inference service CRD.

    Returns:
        The KServe ZenML service configuration corresponding to the given
        KServe inference service CRD.

    Raises:
        ValueError: if the given deployment resource does not contain
            the expected annotations or it contains an invalid or
            incompatible KServe ZenML service configuration.
    """
    # A resource without any annotations yields None here; fall back to an
    # empty mapping so that the documented ValueError is raised instead of
    # an AttributeError.
    annotations = deployment.metadata.get("annotations") or {}
    config_data = annotations.get("zenml.service_config")
    if not config_data:
        raise ValueError(
            f"The given deployment resource does not contain a "
            f"'zenml.service_config' annotation: {deployment}"
        )
    try:
        service_config = cls.parse_raw(config_data)
    except ValidationError as e:
        raise ValueError(
            f"The loaded KServe Inference Service resource contains an "
            f"invalid or incompatible KServe ZenML service configuration: "
            f"{config_data}"
        ) from e
    return service_config
get_kubernetes_annotations(self)
Generate the annotations for the KServe inference CRD from the service configuration.
The annotations are used to store additional information about the KServe ZenML service associated with the deployment that is not available on the labels. One particularly important annotation is the serialized Service configuration itself, which is used to recreate the service configuration from a remote KServe inference service CRD.
Returns:
Type | Description |
---|---|
Dict[str, str] |
The annotations for the KServe inference service CRD. |
Source code in zenml/integrations/kserve/services/kserve_deployment.py
def get_kubernetes_annotations(self) -> Dict[str, str]:
    """Build the annotations of the KServe inference CRD from this configuration.

    The annotations store additional information about the KServe ZenML
    service associated with the deployment that is not available on the
    labels. One particularly important annotation is the serialized
    service configuration itself, which is used to recreate the service
    configuration from a remote KServe inference service CRD.

    Returns:
        The annotations for the KServe inference service CRD.
    """
    serialized_config = self.json()
    return {
        "zenml.service_config": serialized_config,
        "zenml.version": __version__,
    }
get_kubernetes_labels(self)
Generate the labels for the KServe inference CRD from the service configuration.
These labels are attached to the KServe inference service CRD and may be used as label selectors in lookup operations.
Returns:
Type | Description |
---|---|
Dict[str, str] |
The labels for the KServe inference service CRD. |
Source code in zenml/integrations/kserve/services/kserve_deployment.py
def get_kubernetes_labels(self) -> Dict[str, str]:
    """Build the labels of the KServe inference CRD from this configuration.

    These labels are attached to the KServe inference service CRD
    and may be used as label selectors in lookup operations.

    Returns:
        The labels for the KServe inference service CRD.
    """
    # Label key -> name of the configuration attribute providing its value.
    label_sources = (
        ("zenml.pipeline_name", "pipeline_name"),
        ("zenml.pipeline_run_id", "pipeline_run_id"),
        ("zenml.pipeline_step_name", "pipeline_step_name"),
        ("zenml.model_name", "model_name"),
        ("zenml.model_uri", "model_uri"),
        ("zenml.model_type", "predictor"),
    )
    labels = {"app": "zenml"}
    for label_key, attribute in label_sources:
        value = getattr(self, attribute)
        # Unset or empty attributes do not produce a label.
        if value:
            labels[label_key] = value
    self.sanitize_labels(labels)
    return labels
sanitize_labels(labels)
staticmethod
Update the label values to be valid Kubernetes labels.
See: https://kubernetes.io/docs/concepts/overview/working-with-objects/labels/#syntax-and-character-set
Parameters:
Name | Type | Description | Default |
---|---|---|---|
labels |
Dict[str, str] |
The labels to sanitize. |
required |
Source code in zenml/integrations/kserve/services/kserve_deployment.py
@staticmethod
def sanitize_labels(labels: Dict[str, str]) -> None:
    """Rewrite the label values in place so they are valid Kubernetes labels.

    See: https://kubernetes.io/docs/concepts/overview/working-with-objects/labels/#syntax-and-character-set

    Args:
        labels: The labels to sanitize. The dictionary is modified in
            place; its keys are left untouched.
    """
    # TODO[MEDIUM]: Move k8s label sanitization to a common module with all K8s utils.
    for label_name in list(labels):
        # Every run of characters outside [0-9a-zA-Z-_.] collapses into a
        # single underscore.
        cleaned = re.sub(r"[^0-9a-zA-Z-_\.]+", "_", labels[label_name])
        # Values are limited to 63 characters and must begin and end with
        # an alphanumeric character ([a-z0-9A-Z]).
        labels[label_name] = cleaned[:63].strip("-_.")
KServeDeploymentService (BaseService)
pydantic-model
A ZenML service that represents a KServe inference service CRD.
Attributes:
Name | Type | Description |
---|---|---|
config |
KServeDeploymentConfig |
service configuration. |
status |
ServiceStatus |
service status. |
Source code in zenml/integrations/kserve/services/kserve_deployment.py
class KServeDeploymentService(BaseService):
"""A ZenML service that represents a KServe inference service CRD.
Attributes:
SERVICE_TYPE: static description of this kind of service (name, type,
flavor and description).
config: service configuration.
status: service status.
"""
SERVICE_TYPE = ServiceType(
name="kserve-deployment",
type="model-serving",
flavor="kserve",
description="KServe inference service",
)
# NOTE(review): KServeDeploymentConfig declares `model_name` and `predictor`
# without defaults, so this default factory presumably cannot build a valid
# configuration when no `config` is passed - confirm.
config: KServeDeploymentConfig = Field(
default_factory=KServeDeploymentConfig
)
status: ServiceStatus = Field(default_factory=ServiceStatus)
def _get_model_deployer(self) -> "KServeModelDeployer":
    """Return the KServe model deployer of the active stack.

    Returns:
        The active KServeModelDeployer.

    Raises:
        TypeError: if the current stack has no KServeModelDeployer.
    """
    # Imported inside the method rather than at module level
    # (presumably to avoid a circular import - confirm).
    from zenml.integrations.kserve.model_deployers.kserve_model_deployer import (
        KServeModelDeployer,
    )

    try:
        return KServeModelDeployer.get_active_model_deployer()
    except TypeError:
        # Replace the deployer's detailed message with a shorter one.
        raise TypeError(
            "No active KServe model deployer is present in the active "
            "stack. Please make sure that a KServe model deployer is "
            "present in the active stack."
        )
def _get_client(self) -> KServeClient:
    """Return the KServe client held by the active KServe model deployer.

    Returns:
        The KServe client.
    """
    deployer = self._get_model_deployer()
    return deployer.kserve_client
def _get_namespace(self) -> Optional[str]:
    """Return the Kubernetes namespace of the active KServe model deployer.

    Returns:
        The Kubernetes namespace, or None, if the default namespace is
        used.
    """
    deployer = self._get_model_deployer()
    return deployer.kubernetes_namespace
def check_status(self) -> Tuple[ServiceState, str]:
    """Check the state of the KServe inference service.

    This method checks the current operational state of the external
    KServe inference service and translates it into a `ServiceState`
    value and a printable message.

    Returns:
        The operational state of the external service and a message
        providing additional information about that state (e.g. a
        description of the error if one is encountered while checking the
        service status).
    """
    kserve = self._get_client()
    target_namespace = self._get_namespace()
    isvc_name = self.crd_name

    try:
        isvc = kserve.get(name=isvc_name, namespace=target_namespace)
    except RuntimeError:
        # A failed lookup is reported as an inactive service.
        return ServiceState.INACTIVE, ""

    # TODO[MEDIUM]: Implement better operational status checking that also
    #   cover errors
    if "status" not in isvc:
        return ServiceState.INACTIVE, "No operational status available"

    for condition in isvc["status"].get("conditions", {}):
        # Only the predictor readiness condition is taken into account.
        if condition.get("type", "") != "PredictorReady":
            continue
        readiness = condition.get("status", "Unknown").lower()
        if readiness == "true":
            return (
                ServiceState.ACTIVE,
                f"Inference service '{isvc_name}' is available",
            )
        if readiness == "false":
            return (
                ServiceState.PENDING_STARTUP,
                f"Inference service '{isvc_name}' is not available: {condition.get('message', 'Unknown')}",
            )

    # No conclusive readiness condition was reported.
    return (
        ServiceState.PENDING_STARTUP,
        f"Inference service '{isvc_name}' still starting up",
    )
@property
def crd_name(self) -> str:
    """Name of the KServe inference service CRD tied to this service instance.

    The sanitized model name label is used when available; otherwise the
    name is derived from the first 8 characters of the service UUID.

    Returns:
        The name of the KServe inference service CRD.
    """
    model_label = self._get_kubernetes_labels().get("zenml.model_name")
    if model_label:
        return model_label
    return f"zenml-{str(self.uuid)[:8]}"
def _get_kubernetes_labels(self) -> Dict[str, str]:
    """Build the labels of the KServe inference service CRD for this service.

    The labels of the service configuration are extended with the UUID of
    this service instance.

    Returns:
        The labels for the KServe inference service.
    """
    service_labels = self.config.get_kubernetes_labels()
    service_labels.update({"zenml.service_uuid": str(self.uuid)})
    KServeDeploymentConfig.sanitize_labels(service_labels)
    return service_labels
@classmethod
def create_from_deployment(
    cls, deployment: V1beta1InferenceService
) -> "KServeDeploymentService":
    """Recreate a KServe service from a deployed KServe resource.

    Args:
        deployment: the KServe deployment resource.

    Returns:
        The KServe service corresponding to the given KServe deployment
        resource, with its status refreshed.

    Raises:
        ValueError: if the given deployment resource does not contain
            the expected annotations or labels, or it contains an invalid
            or incompatible KServe service configuration.
    """
    config = KServeDeploymentConfig.create_from_deployment(deployment)
    # A resource without any labels yields None here; fall back to an
    # empty mapping so that the documented ValueError is raised instead of
    # an AttributeError.
    labels = deployment.metadata.get("labels") or {}
    uuid = labels.get("zenml.service_uuid")
    if not uuid:
        raise ValueError(
            f"The given deployment resource does not contain a valid "
            f"'zenml.service_uuid' label: {deployment}"
        )
    service = cls(uuid=UUID(uuid), config=config)
    service.update_status()
    return service
def provision(self) -> None:
    """Create or update the remote KServe inference service.

    The remote resource is built from the current service configuration;
    an existing resource is replaced, otherwise a new one is created.
    """
    kserve = self._get_client()
    target_namespace = self._get_namespace()
    isvc_name = self.crd_name

    # All supported model specs seem to have the same fields
    # so we can use any one of them (see https://kserve.github.io/website/0.8/reference/api/#serving.kserve.io/v1beta1.PredictorExtensionSpec)
    extension_spec = V1beta1PredictorExtensionSpec(
        storage_uri=self.config.model_uri,
        resources=self.config.resources,
    )
    # The predictor name from the configuration selects the keyword
    # argument under which the spec is passed.
    predictor_spec = V1beta1PredictorSpec(
        **{self.config.predictor: extension_spec}
    )
    metadata = k8s_client.V1ObjectMeta(
        name=isvc_name,
        namespace=target_namespace,
        labels=self._get_kubernetes_labels(),
        annotations=self.config.get_kubernetes_annotations(),
    )
    isvc = V1beta1InferenceService(
        api_version=f"{constants.KSERVE_GROUP}/v1beta1",
        kind=constants.KSERVE_KIND,
        metadata=metadata,
        spec=V1beta1InferenceServiceSpec(predictor=predictor_spec),
    )

    # TODO[HIGH]: better error handling when provisioning KServe instances
    try:
        # The lookup raises RuntimeError when the resource does not exist.
        kserve.get(name=isvc_name, namespace=target_namespace)
        # update the existing deployment
        kserve.replace(isvc_name, isvc, namespace=target_namespace)
    except RuntimeError:
        # Also reached when replacing the existing resource fails.
        kserve.create(isvc)
def deprovision(self, force: bool = False) -> None:
    """Deprovision all resources used by the service.

    Args:
        force: accepted for interface compatibility; it is currently not
            used when deleting the KServe inference service.

    Raises:
        ValueError: if the KServe inference service could not be deleted.
    """
    client = self._get_client()
    namespace = self._get_namespace()
    name = self.crd_name

    # TODO[HIGH]: catch errors if deleting a KServe instance that is no
    #   longer available
    try:
        client.delete(name=name, namespace=namespace)
    except RuntimeError as e:
        # Chain the client error so the root cause is not lost.
        raise ValueError(
            f"Could not delete KServe instance '{name}' from namespace: '{namespace}'."
        ) from e
def _get_deployment_logs(
    self,
    name: str,
    follow: bool = False,
    tail: Optional[int] = None,
) -> Generator[str, bool, None]:
    """Get the logs of a KServe deployment resource.

    The pods are looked up through the `zenml.service_uuid` label and the
    logs of the first matching pod are streamed line by line.

    Args:
        name: the name of the KServe deployment to get logs for.
        follow: if True, the logs will be streamed as they are written
        tail: only retrieve the last NUM lines of log output.

    Returns:
        A generator that can be accessed to get the service logs. Sending
        a true value into the generator stops the streaming.

    Raises:
        Exception: if no pod is associated with the deployment or if the
            Kubernetes API reports an error while fetching the logs.

    Yields:
        The log lines of the given deployment, without their trailing
        newline.
    """
    client = self._get_client()
    namespace = self._get_namespace()

    logger.debug(f"Retrieving logs for InferenceService resource: {name}")
    try:
        response = client.core_api.list_namespaced_pod(
            namespace=namespace,
            label_selector=f"zenml.service_uuid={self.uuid}",
        )
        logger.debug("Kubernetes API response: %s", response)
        pods = response.items
        if not pods:
            raise Exception(
                f"The KServe deployment {name} is not currently "
                f"running: no Kubernetes pods associated with it were found"
            )
        pod = pods[0]
        pod_name = pod.metadata.name

        containers = [c.name for c in pod.spec.containers]
        # The optional lists below may be reported as None by the
        # Kubernetes client (e.g. no init containers, statuses not yet
        # populated), so guard against iterating over None.
        init_containers = [c.name for c in pod.spec.init_containers or []]
        container_statuses = {
            c.name: c.started or c.restart_count
            for c in pod.status.container_statuses or []
        }

        container = "default"
        if container not in containers:
            container = containers[0]
        # Fall back to the first init container when the selected
        # container has not started yet, provided there is one.
        if not container_statuses.get(container) and init_containers:
            container = init_containers[0]

        logger.info(
            f"Retrieving logs for pod: `{pod_name}` and container "
            f"`{container}` in namespace `{namespace}`"
        )
        response = client.core_api.read_namespaced_pod_log(
            name=pod_name,
            namespace=namespace,
            container=container,
            follow=follow,
            tail_lines=tail,
            _preload_content=False,
        )
    except k8s_client.rest.ApiException as e:
        logger.error(
            "Exception when fetching logs for InferenceService resource "
            "%s: %s",
            name,
            str(e),
        )
        raise Exception(
            f"Unexpected exception when fetching logs for InferenceService "
            f"resource: {name}"
        ) from e

    try:
        while True:
            raw_line = response.readline()
            # Only an empty read marks the end of the stream; a blank log
            # line still arrives as a newline and must not stop streaming.
            if not raw_line:
                return
            stop = yield raw_line.decode("utf-8").rstrip("\n")
            if stop:
                return
    finally:
        response.release_conn()
def get_logs(
    self, follow: bool = False, tail: Optional[int] = None
) -> Generator[str, bool, None]:
    """Stream the logs of the remote KServe inference service instance.

    Args:
        follow: if True, the logs will be streamed as they are written.
        tail: only retrieve the last NUM lines of log output.

    Returns:
        A generator that can be accessed to get the service logs.
    """
    deployment_name = self.crd_name
    log_stream = self._get_deployment_logs(
        deployment_name, follow=follow, tail=tail
    )
    return log_stream
@property
def prediction_url(self) -> Optional[str]:
    """The prediction URI exposed by the prediction service.

    Returns:
        The prediction URI exposed by the prediction service, or None if
        the service is not yet ready.
    """
    if not self.is_running:
        return None

    model_deployer = self._get_model_deployer()
    # URLs always use forward slashes, so the parts are joined explicitly
    # instead of relying on the platform-specific path separator.
    base_url = model_deployer.base_url.rstrip("/")
    return f"{base_url}/v1/models/{self.crd_name}:predict"
@property
def prediction_hostname(self) -> Optional[str]:
    """The hostname under which the prediction service is addressed.

    Returns:
        The prediction hostname exposed by the prediction service status
        that will be used in the headers of the prediction request, or
        None if the service is not running.
    """
    if self.is_running:
        # NOTE(review): the namespace may be None when the default
        # namespace is used, which would end up verbatim in the hostname.
        namespace = self._get_namespace()
        deployer = self._get_model_deployer()
        domain = deployer.custom_domain or "example.com"
        return f"{self.crd_name}.{namespace}.{domain}"
    return None
def predict(self, request: str) -> Any:
    """Make a prediction using the service.

    Args:
        request: a JSON-encoded string holding the instances to send to
            the service.

    Returns:
        The content of the `predictions` field of the JSON response
        returned by the service.

    Raises:
        Exception: if the service is not running.
        ValueError: if the prediction URL or hostname is not set, or if
            the request is not a string.
    """
    if not self.is_running:
        raise Exception(
            "KServe prediction service is not running. "
            "Please start the service before making predictions."
        )

    # Read each property exactly once so the value that was checked for
    # None is the very same value that is used for the request.
    prediction_url = self.prediction_url
    if prediction_url is None:
        raise ValueError("`self.prediction_url` is not set, cannot post.")

    prediction_hostname = self.prediction_hostname
    if prediction_hostname is None:
        raise ValueError(
            "`self.prediction_hostname` is not set, cannot post."
        )

    if not isinstance(request, str):
        raise ValueError("Request must be a json string.")
    instances = json.loads(request)

    response = requests.post(
        prediction_url,
        headers={"Host": prediction_hostname},
        json={"instances": instances},
    )
    response.raise_for_status()
    return response.json()["predictions"]
crd_name: str
property
readonly
Get the name of the KServe inference service CRD that uniquely corresponds to this service instance.
Returns:
Type | Description |
---|---|
str |
The name of the KServe inference service CRD. |
prediction_hostname: Optional[str]
property
readonly
The prediction hostname exposed by the prediction service.
Returns:
Type | Description |
---|---|
Optional[str] |
The prediction hostname exposed by the prediction service status that will be used in the headers of the prediction request. |
prediction_url: Optional[str]
property
readonly
The prediction URI exposed by the prediction service.
Returns:
Type | Description |
---|---|
Optional[str] |
The prediction URI exposed by the prediction service, or None if the service is not yet ready. |
check_status(self)
Check the state of the KServe inference service.
This method checks the current operational state of the external KServe
inference service and translates it into a ServiceState
value and a printable message.
This method should be overridden by subclasses that implement concrete service tracking functionality.
Returns:
Type | Description |
---|---|
Tuple[zenml.services.service_status.ServiceState, str] |
The operational state of the external service and a message providing additional information about that state (e.g. a description of the error if one is encountered while checking the service status). |
Source code in zenml/integrations/kserve/services/kserve_deployment.py
def check_status(self) -> Tuple[ServiceState, str]:
    """Check the state of the KServe inference service.

    Looks up the inference service resource and translates its
    `PredictorReady` condition into a `ServiceState` value and a printable
    message.

    Returns:
        The operational state of the external service and a message
        providing additional information about that state (e.g. a
        description of the error if one is encountered while checking the
        service status).
    """
    kserve_client = self._get_client()
    namespace = self._get_namespace()
    name = self.crd_name

    try:
        isvc = kserve_client.get(name=name, namespace=namespace)
    except RuntimeError:
        return (ServiceState.INACTIVE, "")

    # TODO[MEDIUM]: Implement better operational status checking that also
    #   cover errors
    if "status" not in isvc:
        return (ServiceState.INACTIVE, "No operational status available")

    for condition in isvc["status"].get("conditions", {}):
        # Only the predictor readiness condition is taken into account.
        if condition.get("type", "") != "PredictorReady":
            continue

        readiness = condition.get("status", "Unknown").lower()
        if readiness == "true":
            return (
                ServiceState.ACTIVE,
                f"Inference service '{name}' is available",
            )
        if readiness == "false":
            return (
                ServiceState.PENDING_STARTUP,
                f"Inference service '{name}' is not available: {condition.get('message', 'Unknown')}",
            )

    # No conclusive readiness condition was reported.
    return (
        ServiceState.PENDING_STARTUP,
        f"Inference service '{name}' still starting up",
    )
create_from_deployment(deployment)
classmethod
Recreate the configuration of a KServe Service from a deployed instance.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
deployment |
V1beta1InferenceService |
the KServe deployment resource. |
required |
Returns:
Type | Description |
---|---|
KServeDeploymentService |
The KServe service configuration corresponding to the given KServe deployment resource. |
Exceptions:
Type | Description |
---|---|
ValueError |
if the given deployment resource does not contain the expected annotations or it contains an invalid or incompatible KServe service configuration. |
Source code in zenml/integrations/kserve/services/kserve_deployment.py
@classmethod
def create_from_deployment(
    cls, deployment: V1beta1InferenceService
) -> "KServeDeploymentService":
    """Recreate the configuration of a KServe Service from a deployed instance.

    Args:
        deployment: the KServe deployment resource.

    Returns:
        The KServe service configuration corresponding to the given
        KServe deployment resource.

    Raises:
        ValueError: if the given deployment resource does not contain
            the expected annotations or it contains an invalid or
            incompatible KServe service configuration.
    """
    config = KServeDeploymentConfig.create_from_deployment(deployment)

    # A resource without any labels must lead to the documented ValueError
    # below instead of an AttributeError on `None.get(...)`.
    labels = deployment.metadata.get("labels") or {}
    uuid = labels.get("zenml.service_uuid")
    if not uuid:
        raise ValueError(
            f"The given deployment resource does not contain a valid "
            f"'zenml.service_uuid' label: {deployment}"
        )

    service = cls(uuid=UUID(uuid), config=config)
    service.update_status()
    return service
deprovision(self, force=False)
Deprovisions all resources used by the service.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
force |
bool |
if True, the service will be deprovisioned even if it is still in use. |
False |
Exceptions:
Type | Description |
---|---|
ValueError |
if the KServe instance could not be deleted. |
Source code in zenml/integrations/kserve/services/kserve_deployment.py
def deprovision(self, force: bool = False) -> None:
    """Deprovisions all resources used by the service.

    Args:
        force: if True, the service will be deprovisioned even if it is
            still in use. NOTE: this implementation does not read the
            flag; deletion is always attempted.

    Raises:
        ValueError: if the KServe instance could not be deleted.
    """
    client = self._get_client()
    namespace = self._get_namespace()
    name = self.crd_name

    # TODO[HIGH]: catch errors if deleting a KServe instance that is no
    #   longer available
    try:
        client.delete(name=name, namespace=namespace)
    except RuntimeError as e:
        # Chain explicitly so the underlying client error stays attached
        # as the cause of the ValueError.
        raise ValueError(
            f"Could not delete KServe instance '{name}' from namespace: '{namespace}'."
        ) from e
get_logs(self, follow=False, tail=None)
Retrieve the logs from the remote KServe inference service instance.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
follow |
bool |
if True, the logs will be streamed as they are written. |
False |
tail |
Optional[int] |
only retrieve the last NUM lines of log output. |
None |
Returns:
Type | Description |
---|---|
Generator[str, bool, NoneType] |
A generator that can be accessed to get the service logs. |
Source code in zenml/integrations/kserve/services/kserve_deployment.py
def get_logs(
    self, follow: bool = False, tail: Optional[int] = None
) -> Generator[str, bool, None]:
    """Fetch the log output of the remote KServe inference service.

    Args:
        follow: when True, keep streaming log lines as they are written.
        tail: if set, only the last `tail` lines of log output are fetched.

    Returns:
        A generator that can be iterated to read the service logs.
    """
    deployment_name = self.crd_name
    log_stream = self._get_deployment_logs(
        deployment_name, follow=follow, tail=tail
    )
    return log_stream
predict(self, request)
Make a prediction using the service.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
request |
str |
a JSON string representing the request |
required |
Returns:
Type | Description |
---|---|
Any |
The predictions returned by the service. |
Exceptions:
Type | Description |
---|---|
Exception |
if the service is not yet ready. |
ValueError |
if the prediction_url is not set. |
Source code in zenml/integrations/kserve/services/kserve_deployment.py
def predict(self, request: str) -> Any:
    """Make a prediction using the service.

    Args:
        request: a JSON-encoded string holding the instances to send to
            the service.

    Returns:
        The content of the `predictions` field of the JSON response
        returned by the service.

    Raises:
        Exception: if the service is not running.
        ValueError: if the prediction URL or hostname is not set, or if
            the request is not a string.
    """
    if not self.is_running:
        raise Exception(
            "KServe prediction service is not running. "
            "Please start the service before making predictions."
        )

    # Read each property exactly once so the value that was checked for
    # None is the very same value that is used for the request.
    prediction_url = self.prediction_url
    if prediction_url is None:
        raise ValueError("`self.prediction_url` is not set, cannot post.")

    prediction_hostname = self.prediction_hostname
    if prediction_hostname is None:
        raise ValueError(
            "`self.prediction_hostname` is not set, cannot post."
        )

    if not isinstance(request, str):
        raise ValueError("Request must be a json string.")
    instances = json.loads(request)

    response = requests.post(
        prediction_url,
        headers={"Host": prediction_hostname},
        json={"instances": instances},
    )
    response.raise_for_status()
    return response.json()["predictions"]
provision(self)
Provision or update remote KServe deployment instance.
This should then match the current configuration.
Source code in zenml/integrations/kserve/services/kserve_deployment.py
def provision(self) -> None:
    """Provision or update remote KServe deployment instance.

    The inference service resource is built from the current configuration;
    an existing resource of the same name is replaced, otherwise a new one
    is created.
    """
    kserve_client = self._get_client()
    namespace = self._get_namespace()
    api_version = f"{constants.KSERVE_GROUP}/v1beta1"
    name = self.crd_name

    # All supported model specs seem to have the same fields
    # so we can use any one of them (see https://kserve.github.io/website/0.8/reference/api/#serving.kserve.io/v1beta1.PredictorExtensionSpec)
    extension_spec = V1beta1PredictorExtensionSpec(
        storage_uri=self.config.model_uri,
        resources=self.config.resources,
    )
    # The predictor name selects which field of the predictor spec is set.
    predictor_kwargs = {self.config.predictor: extension_spec}

    metadata = k8s_client.V1ObjectMeta(
        name=name,
        namespace=namespace,
        labels=self._get_kubernetes_labels(),
        annotations=self.config.get_kubernetes_annotations(),
    )
    predictor_spec = V1beta1PredictorSpec(**predictor_kwargs)
    isvc = V1beta1InferenceService(
        api_version=api_version,
        kind=constants.KSERVE_KIND,
        metadata=metadata,
        spec=V1beta1InferenceServiceSpec(predictor=predictor_spec),
    )

    # TODO[HIGH]: better error handling when provisioning KServe instances
    try:
        kserve_client.get(name=name, namespace=namespace)
        # update the existing deployment
        kserve_client.replace(name, isvc, namespace=namespace)
    except RuntimeError:
        kserve_client.create(isvc)
steps
special
Initialization for KServe steps.
kserve_deployer
Implementation of the KServe Deployer step.
KServeDeployerStepConfig (BaseStepConfig)
pydantic-model
KServe model deployer step configuration.
Attributes:
Name | Type | Description |
---|---|---|
service_config |
KServeDeploymentConfig |
KServe deployment service configuration. |
torch_serve_parameters |
Optional[TorchServeParameters] |
TorchServe set of parameters to deploy model. |
timeout |
int |
Timeout for model deployment. |
Source code in zenml/integrations/kserve/steps/kserve_deployer.py
class KServeDeployerStepConfig(BaseStepConfig):
    """KServe model deployer step configuration.

    Attributes:
        service_config: KServe deployment service configuration.
        torch_serve_parameters: TorchServe set of parameters to deploy
            model (optional, defaults to None).
        timeout: Timeout for model deployment.
    """

    service_config: KServeDeploymentConfig
    torch_serve_parameters: Optional[TorchServeParameters] = None
    timeout: int = DEFAULT_KSERVE_DEPLOYMENT_START_STOP_TIMEOUT
TorchServeParameters (BaseModel)
pydantic-model
KServe PyTorch model deployer step configuration.
Attributes:
Name | Type | Description |
---|---|---|
service_config |
KServe deployment service configuration. |
|
model_class |
str |
Path to Python file containing model architecture. |
handler |
str |
TorchServe's handler file to handle custom TorchServe inference logic. |
extra_files |
Optional[List[str]] |
Comma separated path to extra dependency files. |
model_version |
Optional[str] |
Model version. |
requirements_file |
Optional[str] |
Path to requirements file. |
torch_config |
Optional[str] |
TorchServe configuration file path. |
Source code in zenml/integrations/kserve/steps/kserve_deployer.py
class TorchServeParameters(BaseModel):
    """KServe PyTorch model deployer step configuration.

    Attributes:
        model_class: Path to Python file containing model architecture.
        handler: TorchServe's handler file to handle custom TorchServe
            inference logic.
        extra_files: List of paths to extra dependency files.
        requirements_file: Path to requirements file.
        model_version: Model version.
        torch_config: TorchServe configuration file path.
    """

    model_class: str
    handler: str
    extra_files: Optional[List[str]] = None
    requirements_file: Optional[str] = None
    model_version: Optional[str] = "1.0"
    torch_config: Optional[str] = None

    @validator("model_class")
    def model_class_validate(cls, v: str) -> str:
        """Validate model class file path.

        Args:
            v: model class file path

        Returns:
            model class file path

        Raises:
            ValueError: if model class file path is not valid
        """
        if not v:
            raise ValueError("Model class file path is required.")
        if not is_inside_repository(v):
            raise ValueError(
                "Model class file path must be inside the repository."
            )
        return v

    @validator("handler")
    def handler_validate(cls, v: str) -> str:
        """Validate handler.

        Args:
            v: handler file path

        Returns:
            handler file path

        Raises:
            ValueError: if handler file path is not valid
        """
        if not v:
            raise ValueError("Handler is required.")
        if v in TORCH_HANDLERS or is_inside_repository(v):
            return v
        # One message built by implicit concatenation: a comma between the
        # two literals would pass them as separate exception arguments and
        # render the error as a tuple.
        raise ValueError(
            "Handler must be one of the TorchServe handlers "
            "or a file that exists inside the repository."
        )

    @validator("extra_files")
    def extra_files_validate(
        cls, v: Optional[List[str]]
    ) -> Optional[List[str]]:
        """Validate extra files.

        Args:
            v: extra files path

        Returns:
            extra files path

        Raises:
            ValueError: if the extra files path is not valid
        """
        if v is None:
            return v
        for file_path in v:
            if not is_inside_repository(file_path):
                raise ValueError(
                    "Extra file path must be inside the repository."
                )
        # A new list is returned, leaving the caller's list untouched.
        return list(v)

    @validator("torch_config")
    def torch_config_validate(cls, v: Optional[str]) -> Optional[str]:
        """Validate torch config file.

        Args:
            v: torch config file path

        Returns:
            torch config file path

        Raises:
            ValueError: if torch config file path is not valid.
        """
        # Empty / None values are accepted unchanged: the field is optional.
        if v and not is_inside_repository(v):
            raise ValueError(
                "Torch config file path must be inside the repository."
            )
        return v
extra_files_validate(v)
classmethod
Validate extra files.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
v |
Optional[List[str]] |
extra files path |
required |
Returns:
Type | Description |
---|---|
Optional[List[str]] |
extra files path |
Exceptions:
Type | Description |
---|---|
ValueError |
if the extra files path is not valid |
Source code in zenml/integrations/kserve/steps/kserve_deployer.py
@validator("extra_files")
def extra_files_validate(
    cls, v: Optional[List[str]]
) -> Optional[List[str]]:
    """Check that every extra file lies inside the repository.

    Args:
        v: extra files path

    Returns:
        extra files path (a new list), or None when no files were given

    Raises:
        ValueError: if the extra files path is not valid
    """
    if v is None:
        return v

    checked_paths = []
    for candidate in v:
        if not is_inside_repository(candidate):
            raise ValueError(
                "Extra file path must be inside the repository."
            )
        checked_paths.append(candidate)
    return checked_paths
handler_validate(v)
classmethod
Validate handler.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
v |
str |
handler file path |
required |
Returns:
Type | Description |
---|---|
str |
handler file path |
Exceptions:
Type | Description |
---|---|
ValueError |
if handler file path is not valid |
Source code in zenml/integrations/kserve/steps/kserve_deployer.py
@validator("handler")
def handler_validate(cls, v: str) -> str:
    """Validate handler.

    Args:
        v: handler file path

    Returns:
        handler file path

    Raises:
        ValueError: if handler file path is not valid
    """
    if not v:
        raise ValueError("Handler is required.")
    if v in TORCH_HANDLERS or is_inside_repository(v):
        return v
    # One message built by implicit concatenation: a comma between the two
    # literals would pass them as separate exception arguments and render
    # the error as a tuple.
    raise ValueError(
        "Handler must be one of the TorchServe handlers "
        "or a file that exists inside the repository."
    )
model_class_validate(v)
classmethod
Validate model class file path.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
v |
str |
model class file path |
required |
Returns:
Type | Description |
---|---|
str |
model class file path |
Exceptions:
Type | Description |
---|---|
ValueError |
if model class file path is not valid |
Source code in zenml/integrations/kserve/steps/kserve_deployer.py
@validator("model_class")
def model_class_validate(cls, v: str) -> str:
    """Check that a model class file is given and lies in the repository.

    Args:
        v: model class file path

    Returns:
        model class file path

    Raises:
        ValueError: if model class file path is not valid
    """
    if v:
        if is_inside_repository(v):
            return v
        raise ValueError(
            "Model class file path must be inside the repository."
        )
    raise ValueError("Model class file path is required.")
torch_config_validate(v)
classmethod
Validate torch config file.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
v |
Optional[str] |
torch config file path |
required |
Returns:
Type | Description |
---|---|
Optional[str] |
torch config file path |
Exceptions:
Type | Description |
---|---|
ValueError |
if torch config file path is not valid. |
Source code in zenml/integrations/kserve/steps/kserve_deployer.py
@validator("torch_config")
def torch_config_validate(cls, v: Optional[str]) -> Optional[str]:
    """Check that a given torch config file lies inside the repository.

    Args:
        v: torch config file path

    Returns:
        torch config file path

    Raises:
        ValueError: if torch config file path is not valid.
    """
    # Empty / None values pass through unchanged: the field is optional.
    if v and not is_inside_repository(v):
        raise ValueError(
            "Torch config file path must be inside the repository."
        )
    return v
kserve_model_deployer_step (BaseStep)
KServe model deployer pipeline step.
This step can be used in a pipeline to implement continuous deployment for an ML model with KServe.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
deploy_decision |
whether to deploy the model or not |
required | |
config |
configuration for the deployer step |
required | |
model |
the model artifact to deploy |
required | |
context |
the step context |
required |
Returns:
Type | Description |
---|---|
KServe deployment service |
CONFIG_CLASS (BaseStepConfig)
pydantic-model
KServe model deployer step configuration.
Attributes:
Name | Type | Description |
---|---|---|
service_config |
KServeDeploymentConfig |
KServe deployment service configuration. |
torch_serve_parameters |
Optional[TorchServeParameters] |
TorchServe set of parameters to deploy model. |
timeout |
int |
Timeout for model deployment. |
Source code in zenml/integrations/kserve/steps/kserve_deployer.py
class KServeDeployerStepConfig(BaseStepConfig):
    """KServe model deployer step configuration.

    Attributes:
        service_config: KServe deployment service configuration.
        torch_serve_parameters: TorchServe set of parameters to deploy
            model (optional, defaults to None).
        timeout: Timeout for model deployment.
    """

    service_config: KServeDeploymentConfig
    torch_serve_parameters: Optional[TorchServeParameters] = None
    timeout: int = DEFAULT_KSERVE_DEPLOYMENT_START_STOP_TIMEOUT
entrypoint(deploy_decision, config, context, model)
staticmethod
KServe model deployer pipeline step.
This step can be used in a pipeline to implement continuous deployment for an ML model with KServe.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
deploy_decision |
bool |
whether to deploy the model or not |
required |
config |
KServeDeployerStepConfig |
configuration for the deployer step |
required |
model |
ModelArtifact |
the model artifact to deploy |
required |
context |
StepContext |
the step context |
required |
Returns:
Type | Description |
---|---|
KServeDeploymentService |
KServe deployment service |
Source code in zenml/integrations/kserve/steps/kserve_deployer.py
@step(enable_cache=False)
def kserve_model_deployer_step(
    deploy_decision: bool,
    config: KServeDeployerStepConfig,
    context: StepContext,
    model: ModelArtifact,
) -> KServeDeploymentService:
    """KServe model deployer pipeline step.

    This step can be used in a pipeline to implement continuous
    deployment for an ML model with KServe.

    Args:
        deploy_decision: whether to deploy the model or not
        config: configuration for the deployer step
        model: the model artifact to deploy
        context: the step context

    Returns:
        KServe deployment service
    """
    model_deployer = KServeModelDeployer.get_active_model_deployer()

    # get pipeline name, step name and run id
    step_env = cast(StepEnvironment, Environment()[STEP_ENVIRONMENT_NAME])
    pipeline_name = step_env.pipeline_name
    pipeline_run_id = step_env.pipeline_run_id
    step_name = step_env.step_name

    # update the step configuration with the real pipeline runtime information
    runtime_config = config.service_config
    runtime_config.pipeline_name = pipeline_name
    runtime_config.pipeline_run_id = pipeline_run_id
    runtime_config.pipeline_step_name = step_name

    # fetch existing services with same pipeline name, step name and
    # model name
    existing_services = model_deployer.find_model_server(
        pipeline_name=pipeline_name,
        pipeline_step_name=step_name,
        model_name=runtime_config.model_name,
    )

    # A negative deploy decision only skips the deployment when a model
    # server already exists for this pipeline/step; otherwise the current
    # model is still served so that a model server is available at all times.
    if existing_services and not deploy_decision:
        logger.info(
            f"Skipping model deployment because the model quality does not "
            f"meet the criteria. Reusing the last model server deployed by step "
            f"'{step_name}' and pipeline '{pipeline_name}' for model "
            f"'{config.service_config.model_name}'..."
        )
        previous_service = cast(KServeDeploymentService, existing_services[0])
        # restart the previous model server if it is no longer running
        if not previous_service.is_running:
            previous_service.start(timeout=config.timeout)
        return previous_service

    # Pick the helper that lays out the model files for the configured
    # predictor; the import stays local to the step, as in the original.
    if runtime_config.predictor == "pytorch":
        from zenml.integrations.kserve.steps.kserve_step_utils import (
            prepare_torch_service_config as prepare_config,
        )
    else:
        from zenml.integrations.kserve.steps.kserve_step_utils import (
            prepare_service_config as prepare_config,
        )

    # prepare the service config
    service_config = prepare_config(
        model_uri=model.uri,
        output_artifact_uri=context.get_output_artifact_uri(),
        config=config,
    )

    # invoke the KServe model deployer to create a new service
    # or update an existing one that was previously deployed for the same
    # model
    deployed = model_deployer.deploy_model(
        service_config, replace=True, timeout=config.timeout
    )
    service = cast(KServeDeploymentService, deployed)
    logger.info(
        f"KServe deployment service started and reachable at:\n"
        f" {service.prediction_url}\n"
        f" With the hostname: {service.prediction_hostname}."
    )
    return service
kserve_step_utils
This module contains the utility functions used by the KServe deployer step.
TorchModelArchiver (BaseModel)
pydantic-model
Model Archiver for PyTorch models.
Attributes:
Name | Type | Description |
---|---|---|
model_name |
str |
Model name. |
model_version |
Model version. |
|
serialized_file |
str |
Serialized model file. |
handler |
str |
TorchServe's handler file to handle custom TorchServe inference logic. |
extra_files |
Optional[List[str]] |
Comma separated path to extra dependency files. |
requirements_file |
Optional[str] |
Path to requirements file. |
export_path |
str |
Path to export model. |
runtime |
Optional[str] |
Runtime of the model. |
force |
Optional[bool] |
Force export of the model. |
archive_format |
Optional[str] |
Archive format. |
Source code in zenml/integrations/kserve/steps/kserve_step_utils.py
class TorchModelArchiver(BaseModel):
    """Model Archiver for PyTorch models.

    Attributes:
        model_name: Model name.
        serialized_file: Serialized model file.
        model_file: Path to Python file containing model architecture.
        handler: TorchServe's handler file to handle custom TorchServe
            inference logic.
        export_path: Path to export model.
        extra_files: List of paths to extra dependency files.
        version: Model version.
        requirements_file: Path to requirements file.
        runtime: Runtime of the model.
        force: Force export of the model.
        archive_format: Archive format.
    """

    model_name: str
    serialized_file: str
    model_file: str
    handler: str
    export_path: str
    extra_files: Optional[List[str]] = None
    version: Optional[str] = None
    requirements_file: Optional[str] = None
    runtime: Optional[str] = "python"
    force: Optional[bool] = None
    archive_format: Optional[str] = "default"
generate_model_deployer_config(model_name, directory)
Generate a model deployer config.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model_name |
str |
the name of the model |
required |
directory |
str |
the directory where the model is stored |
required |
Returns:
Type | Description |
---|---|
str |
The path of the generated configuration file. |
Source code in zenml/integrations/kserve/steps/kserve_step_utils.py
def generate_model_deployer_config(
    model_name: str,
    directory: str,
) -> str:
    """Generate a model deployer config.

    Writes a `.properties` file containing the fixed server settings
    followed by a `model_snapshot` entry that registers `model_name`
    (archive `<model_name>.mar`, version "1.0").

    Args:
        model_name: the name of the model
        directory: the directory in which the config file is created

    Returns:
        The path of the generated configuration file.
    """
    # Local import keeps this function self-contained.
    import json

    config_lines = [
        "inference_address=http://0.0.0.0:8085",
        "management_address=http://0.0.0.0:8085",
        "metrics_address=http://0.0.0.0:8082",
        "grpc_inference_port=7070",
        "grpc_management_port=7071",
        "enable_metrics_api=true",
        "metrics_format=prometheus",
        "number_of_netty_threads=4",
        "job_queue_size=10",
        "enable_envvars_config=true",
        "install_py_dep_per_model=true",
        "model_store=/mnt/models/model-store",
    ]
    model_snapshot = {
        "name": "startup.cfg",
        "modelCount": 1,
        "models": {
            model_name: {
                "1.0": {
                    "defaultVersion": True,
                    "marName": f"{model_name}.mar",
                    "minWorkers": 1,
                    "maxWorkers": 5,
                    "batchSize": 1,
                    "maxBatchDelay": 10,
                    "responseTimeout": 120,
                }
            }
        },
    }
    # Serialising with `json.dumps` escapes quotes/backslashes in the model
    # name; compact separators and `ensure_ascii=False` keep the output
    # byte-identical to the previous hand-written template for plain names.
    snapshot_json = json.dumps(
        model_snapshot, separators=(",", ":"), ensure_ascii=False
    )

    # `delete=False`: the file must outlive this function; the context
    # manager closes it on exit.
    with tempfile.NamedTemporaryFile(
        suffix=".properties", mode="w+", dir=directory, delete=False
    ) as f:
        f.writelines(f"{line}\n" for line in config_lines)
        f.write(f"model_snapshot={snapshot_json}")
    return f.name
prepare_service_config(model_uri, output_artifact_uri, config)
Prepare the model files for model serving.
This function ensures that the model files are in the correct format and file structure required by the KServe server implementation used for model serving.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model_uri |
str |
the URI of the model artifact being served |
required |
output_artifact_uri |
str |
the URI of the output artifact |
required |
config |
KServeDeployerStepConfig |
the KServe deployer step config |
required |
Returns:
Type | Description |
---|---|
KServeDeploymentConfig |
The KServe deployment configuration, with its model URI pointing to the prepared model files. |
Exceptions:
Type | Description |
---|---|
RuntimeError |
if the model files cannot be prepared. |
Source code in zenml/integrations/kserve/steps/kserve_step_utils.py
def prepare_service_config(
    model_uri: str, output_artifact_uri: str, config: KServeDeployerStepConfig
) -> KServeDeploymentConfig:
    """Prepare the model files for model serving.

    The model files are copied below `<output_artifact_uri>/kserve` in the
    layout expected by the configured KServe predictor.

    Args:
        model_uri: the URI of the model artifact being served
        output_artifact_uri: the URI of the output artifact
        config: the KServe deployer step config

    Returns:
        A copy of the step's service configuration whose `model_uri` points
        to the prepared model files.

    Raises:
        RuntimeError: if the model files cannot be prepared.
    """
    kserve_root_uri = os.path.join(output_artifact_uri, "kserve")
    fileio.makedirs(kserve_root_uri)

    def _create_target_dir() -> str:
        """Create and return `<kserve root>/<predictor>/<model name>`."""
        target_uri = os.path.join(
            kserve_root_uri,
            config.service_config.predictor,
            config.service_config.model_name,
        )
        fileio.makedirs(target_uri)
        return target_uri

    # TODO [ENG-773]: determine how to formalize how models are organized into
    #  folders and sub-folders depending on the model type/format and the
    #  KServe protocol used to serve the model.

    # TODO [ENG-791]: an auto-detect built-in KServe server implementation
    #  from the model artifact type

    # TODO [ENG-792]: validate the model artifact type against the
    #  supported built-in KServe server implementations
    predictor = config.service_config.predictor
    if predictor == "tensorflow":
        # the TensorFlow server expects model artifacts to be
        # stored in numbered subdirectories, each representing a model
        # version
        served_model_uri = _create_target_dir()
        io_utils.copy_dir(model_uri, os.path.join(served_model_uri, "1"))
    elif predictor == "sklearn":
        # the sklearn server expects model artifacts to be
        # stored in a file called model.joblib
        sklearn_model_uri = os.path.join(model_uri, "model")
        if not fileio.exists(sklearn_model_uri):
            raise RuntimeError(
                f"Expected sklearn model artifact was not found at "
                f"{sklearn_model_uri}"
            )
        served_model_uri = _create_target_dir()
        fileio.copy(
            sklearn_model_uri, os.path.join(served_model_uri, "model.joblib")
        )
    else:
        # default treatment for all other server implementations is to
        # simply reuse the model from the artifact store path where it
        # is originally stored
        served_model_uri = _create_target_dir()
        fileio.copy(model_uri, served_model_uri)

    # The step's own config is left untouched; only the copy is re-pointed.
    service_config = config.service_config.copy()
    service_config.model_uri = served_model_uri
    return service_config
prepare_torch_service_config(model_uri, output_artifact_uri, config)
Prepare the PyTorch model files for model serving.
This function ensures that the model files are in the correct format and file structure required by the KServe server implementation used for model serving.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model_uri |
str |
the URI of the model artifact being served |
required |
output_artifact_uri |
str |
the URI of the output artifact |
required |
config |
KServeDeployerStepConfig |
the KServe deployer step config |
required |
Returns:
Type | Description |
---|---|
KServeDeploymentConfig |
The URL to the model is ready for serving. |
Exceptions:
Type | Description |
---|---|
RuntimeError |
if the model files cannot be prepared. |
Source code in zenml/integrations/kserve/steps/kserve_step_utils.py
def prepare_torch_service_config(
    model_uri: str, output_artifact_uri: str, config: KServeDeployerStepConfig
) -> KServeDeploymentConfig:
    """Prepare the PyTorch model files for model serving.

    This function ensures that the model files are in the correct format
    and file structure required by the KServe server implementation
    used for model serving. The model checkpoint is packaged into a torch
    model archive (`.mar`) inside a local temporary directory; the archive
    is then copied to the `model-store` folder and a `config.properties`
    file to the `config` folder of the deployment folder.

    Args:
        model_uri: the URI of the model artifact being served
        output_artifact_uri: the URI of the output artifact
        config: the KServe deployer step config

    Returns:
        A copy of the service config whose `model_uri` points to the
        prepared deployment folder.

    Raises:
        RuntimeError: if the model files cannot be prepared.
    """
    # Local import: only needed to clean up the temporary working directory.
    import shutil

    # Fail fast, before anything is created in the artifact store.
    if config.torch_serve_parameters is None:
        raise RuntimeError("No torch serve parameters provided")

    model_name = config.service_config.model_name
    deployment_folder_uri = os.path.join(output_artifact_uri, "kserve")
    served_model_uri = os.path.join(deployment_folder_uri, "model-store")
    config_properties_uri = os.path.join(deployment_folder_uri, "config")
    fileio.makedirs(served_model_uri)
    fileio.makedirs(config_properties_uri)

    # Create a temporary folder; it is always removed again in `finally`
    temp_dir = tempfile.mkdtemp(prefix="zenml-pytorch-temp-")
    try:
        tmp_model_uri = os.path.join(temp_dir, f"{model_name}.pt")

        # Copy from artifact store to temporary file
        fileio.copy(f"{model_uri}/checkpoint.pt", tmp_model_uri)

        torch_archiver_args = TorchModelArchiver(
            model_name=model_name,
            serialized_file=tmp_model_uri,
            model_file=config.torch_serve_parameters.model_class,
            handler=config.torch_serve_parameters.handler,
            export_path=temp_dir,
            version=config.torch_serve_parameters.model_version,
        )
        manifest = ModelExportUtils.generate_manifest_json(torch_archiver_args)
        package_model(torch_archiver_args, manifest=manifest)

        # The archiver is expected to have written the `.mar` file into the
        # export path
        archived_model_uri = os.path.join(temp_dir, f"{model_name}.mar")
        if not fileio.exists(archived_model_uri):
            raise RuntimeError(
                f"Expected torch archived model artifact was not found at "
                f"{archived_model_uri}"
            )

        # Copy the torch model archive artifact to the model store
        fileio.copy(
            archived_model_uri,
            os.path.join(served_model_uri, f"{model_name}.mar"),
        )

        # Get or generate the config file
        if config.torch_serve_parameters.torch_config:
            config_file_uri = config.torch_serve_parameters.torch_config
        else:
            config_file_uri = generate_model_deployer_config(
                model_name=model_name,
                directory=temp_dir,
            )

        # Copy the torch model config to the config folder
        fileio.copy(
            config_file_uri,
            os.path.join(config_properties_uri, "config.properties"),
        )
    finally:
        # Everything needed has been copied to the artifact store (or an
        # error occurred), so the local working directory can go.
        shutil.rmtree(temp_dir, ignore_errors=True)

    service_config = config.service_config.copy()
    service_config.model_uri = deployment_folder_uri
    return service_config
kubeflow
special
Initialization of the Kubeflow integration for ZenML.
The Kubeflow integration sub-module powers an alternative to the local orchestrator. You can enable it by registering the Kubeflow orchestrator with the CLI tool.
KubeflowIntegration (Integration)
Definition of Kubeflow Integration for ZenML.
Source code in zenml/integrations/kubeflow/__init__.py
class KubeflowIntegration(Integration):
    """Kubeflow integration for ZenML: requirements and component flavors."""

    NAME = KUBEFLOW
    REQUIREMENTS = ["kfp==1.8.9"]

    @classmethod
    def flavors(cls) -> List[FlavorWrapper]:
        """Declare the stack component flavors for the Kubeflow integration.

        Returns:
            The metadata store flavor followed by the orchestrator flavor.
        """
        metadata_store_flavor = FlavorWrapper(
            name=KUBEFLOW_METADATA_STORE_FLAVOR,
            source="zenml.integrations.kubeflow.metadata_stores.KubeflowMetadataStore",
            type=StackComponentType.METADATA_STORE,
            integration=cls.NAME,
        )
        orchestrator_flavor = FlavorWrapper(
            name=KUBEFLOW_ORCHESTRATOR_FLAVOR,
            source="zenml.integrations.kubeflow.orchestrators.KubeflowOrchestrator",
            type=StackComponentType.ORCHESTRATOR,
            integration=cls.NAME,
        )
        return [metadata_store_flavor, orchestrator_flavor]
flavors()
classmethod
Declare the stack component flavors for the Kubeflow integration.
Returns:
Type | Description |
---|---|
List[zenml.zen_stores.models.flavor_wrapper.FlavorWrapper] |
List of stack component flavors for this integration. |
Source code in zenml/integrations/kubeflow/__init__.py
@classmethod
def flavors(cls) -> List[FlavorWrapper]:
    """Declare the stack component flavors for the Kubeflow integration.

    Returns:
        The metadata store flavor followed by the orchestrator flavor.
    """
    # (flavor name, implementation source, component type), in output order
    flavor_specs = (
        (
            KUBEFLOW_METADATA_STORE_FLAVOR,
            "zenml.integrations.kubeflow.metadata_stores.KubeflowMetadataStore",
            StackComponentType.METADATA_STORE,
        ),
        (
            KUBEFLOW_ORCHESTRATOR_FLAVOR,
            "zenml.integrations.kubeflow.orchestrators.KubeflowOrchestrator",
            StackComponentType.ORCHESTRATOR,
        ),
    )
    return [
        FlavorWrapper(
            name=flavor_name,
            source=flavor_source,
            type=component_type,
            integration=cls.NAME,
        )
        for flavor_name, flavor_source, component_type in flavor_specs
    ]
metadata_stores
special
Initialization of the Kubeflow metadata store for ZenML.
kubeflow_metadata_store
Implementation of the Kubeflow metadata store.
KubeflowMetadataStore (BaseMetadataStore)
pydantic-model
Kubeflow GRPC backend for ZenML metadata store.
Source code in zenml/integrations/kubeflow/metadata_stores/kubeflow_metadata_store.py
class KubeflowMetadataStore(BaseMetadataStore):
    """Kubeflow GRPC backend for ZenML metadata store.

    Outside of a KFP pod, the store is reached through a local daemon that
    runs `kubectl port-forward` against `svc/metadata-grpc-service`.
    """

    # Not read anywhere in this class; presumably consumed by the base class.
    upgrade_migration_enabled: bool = False
    # Host used for the connection config when not running inside a KFP pod.
    host: str = "127.0.0.1"
    # Local port the metadata gRPC service is forwarded to.
    port: int = DEFAULT_KFP_METADATA_GRPC_PORT

    # Class Configuration
    FLAVOR: ClassVar[str] = KUBEFLOW_METADATA_STORE_FLAVOR

    @property
    def validator(self) -> Optional[StackValidator]:
        """Validates that the stack contains a KFP orchestrator.

        Returns:
            The stack validator.
        """

        def _ensure_kfp_orchestrator(stack: Stack) -> Tuple[bool, str]:
            return (
                stack.orchestrator.FLAVOR == KUBEFLOW,
                "The Kubeflow metadata store can only be used with a Kubeflow "
                "orchestrator.",
            )

        return StackValidator(
            custom_validation_function=_ensure_kfp_orchestrator
        )

    def get_tfx_metadata_config(
        self,
    ) -> Union[
        metadata_store_pb2.ConnectionConfig,
        metadata_store_pb2.MetadataStoreClientConfig,
    ]:
        """Return tfx metadata config for the kubeflow metadata store.

        Inside a KFP pod, host and port are read from the
        `METADATA_GRPC_SERVICE_HOST` / `METADATA_GRPC_SERVICE_PORT`
        environment variables; otherwise the configured `host` and `port`
        are used, which requires the local daemon to be running.

        Returns:
            The tfx metadata config for the kubeflow metadata store.

        Raises:
            RuntimeError: If the metadata store is not running.
        """
        connection_config = metadata_store_pb2.MetadataStoreClientConfig()
        if inside_kfp_pod():
            connection_config.host = os.environ["METADATA_GRPC_SERVICE_HOST"]
            connection_config.port = int(
                os.environ["METADATA_GRPC_SERVICE_PORT"]
            )
        else:
            if not self.is_running:
                raise RuntimeError(
                    "The KFP metadata daemon is not running. Please run the "
                    "following command to start it first:\n\n"
                    "    'zenml metadata-store up'\n"
                )
            connection_config.host = self.host
            connection_config.port = self.port
        return connection_config

    @property
    def kfp_orchestrator(self) -> KubeflowOrchestrator:
        """Returns the Kubeflow orchestrator in the active stack.

        Returns:
            The Kubeflow orchestrator in the active stack.
        """
        repo = Repository(skip_repository_check=True)  # type: ignore[call-arg]
        return cast(KubeflowOrchestrator, repo.active_stack.orchestrator)

    @property
    def kubernetes_context(self) -> str:
        """Returns the kubernetes context.

        This is the context configured on the Kubeflow orchestrator of the
        active stack.

        Returns:
            The kubernetes context.
        """
        kubernetes_context = self.kfp_orchestrator.kubernetes_context

        # will never happen, but mypy doesn't know that
        assert kubernetes_context is not None
        return kubernetes_context

    @property
    def root_directory(self) -> str:
        """Returns path to the root directory.

        This is for all files concerning this KFP metadata store.

        Note: the root directory for the KFP metadata store is relative to the
        root directory of the KFP orchestrator, because it is a sub-component
        of it.

        Returns:
            Path to the root directory.
        """
        return os.path.join(
            self.kfp_orchestrator.root_directory,
            "metadata-store",
            str(self.uuid),
        )

    @property
    def _pid_file_path(self) -> str:
        """Returns path to the daemon PID file.

        Returns:
            Path to the daemon PID file.
        """
        return os.path.join(self.root_directory, "kubeflow_daemon.pid")

    @property
    def _log_file(self) -> str:
        """Path of the daemon log file.

        Returns:
            Path to the daemon log file.
        """
        return os.path.join(self.root_directory, "kubeflow_daemon.log")

    @property
    def is_provisioned(self) -> bool:
        """If the component provisioned resources to run locally.

        Returns:
            True if the root directory of the component exists.
        """
        return fileio.exists(self.root_directory)

    @property
    def is_running(self) -> bool:
        """If the component is running locally.

        Returns:
            False if the daemon check fails on a non-Windows platform, True
            otherwise (on Windows no check is performed).
        """
        if sys.platform != "win32":
            from zenml.utils.daemon import check_if_daemon_is_running

            if not check_if_daemon_is_running(self._pid_file_path):
                return False
        else:
            # Daemon functionality is not supported on Windows, so the PID
            # file won't exist. This if clause exists just for mypy to not
            # complain about missing functions
            pass

        return True

    def provision(self) -> None:
        """Provisions resources to run the component locally."""
        logger.info("Provisioning local Kubeflow Pipelines deployment...")
        fileio.makedirs(self.root_directory)

    def deprovision(self) -> None:
        """Deprovisions all local resources of the component."""
        if fileio.exists(self._log_file):
            fileio.remove(self._log_file)
        logger.info("Local kubeflow pipelines deployment deprovisioned.")

    def resume(self) -> None:
        """Starts the metadata daemon and waits for the store to be ready."""
        if self.is_running:
            logger.info("Local kubeflow pipelines deployment already running.")
            return

        self.start_kfp_metadata_daemon()
        self.wait_until_metadata_store_ready()

    def suspend(self) -> None:
        """Stops the metadata daemon if it is running."""
        if not self.is_running:
            logger.info("Local kubeflow pipelines deployment not running.")
            return

        self.stop_kfp_metadata_daemon()

    def start_kfp_metadata_daemon(self) -> None:
        """Starts a daemon process that forwards ports.

        The daemon forwards the local `port` to port 8080 of
        `svc/metadata-grpc-service` in the `kubeflow` namespace, so the
        metadata service is accessible on the localhost. On Windows no
        daemon is started; the command to run manually is logged instead.

        Raises:
            ProvisioningError: if the local port is already occupied.
        """
        command = [
            "kubectl",
            "--context",
            self.kubernetes_context,
            "--namespace",
            "kubeflow",
            "port-forward",
            "svc/metadata-grpc-service",
            f"{self.port}:8080",
        ]

        if sys.platform == "win32":
            # Exactly one argument for the single `%s` placeholder; passing
            # more makes the logging module fail to format the record.
            logger.warning(
                "Daemon functionality not supported on Windows. "
                "In order to access the Kubeflow Pipelines Metadata locally, "
                "please run '%s' in a separate command line shell.",
                " ".join(command),
            )
        elif not networking_utils.port_available(self.port):
            raise ProvisioningError(
                f"Unable to port-forward Kubeflow Pipelines Metadata to local "
                f"port {self.port} because the port is occupied. In order to "
                f"access the Kubeflow Pipelines Metadata locally, please "
                f"change the metadata store configuration to use an available "
                f"port or stop the other process currently using the port."
            )
        else:
            from zenml.utils import daemon

            def _daemon_function() -> None:
                """Forwards the port of the Kubeflow Pipelines Metadata pod."""
                subprocess.check_call(command)

            daemon.run_as_daemon(
                _daemon_function,
                pid_file=self._pid_file_path,
                log_file=self._log_file,
            )
            logger.info(
                "Started Kubeflow Pipelines Metadata daemon (check the daemon "
                "logs at %s in case you're not able to access the pipeline "
                "metadata).",
                self._log_file,
            )

    def stop_kfp_metadata_daemon(self) -> None:
        """Stops the KFP Metadata daemon process if it is running."""
        if fileio.exists(self._pid_file_path):
            if sys.platform == "win32":
                # Daemon functionality is not supported on Windows, so the PID
                # file won't exist. This if clause exists just for mypy to not
                # complain about missing functions
                pass
            else:
                from zenml.utils import daemon

                daemon.stop_daemon(self._pid_file_path)
                fileio.remove(self._pid_file_path)

    def wait_until_metadata_store_ready(
        self, timeout: int = DEFAULT_KFP_METADATA_DAEMON_TIMEOUT
    ) -> None:
        """Waits until the metadata store connection is ready.

        The connection is probed every 10 seconds. Once the timeout budget
        is used up, the last error is re-raised as a `RuntimeError`.

        Args:
            timeout: The maximum time to wait for the metadata store to be
                ready.

        Raises:
            RuntimeError: if the metadata store is not ready after the timeout
        """
        logger.info(
            "Waiting for the Kubeflow metadata store to be ready (this might "
            "take a few minutes)."
        )
        while True:
            try:
                # it doesn't matter what we call here as long as it exercises
                # the MLMD connection
                self.get_pipelines()
                break
            except Exception as e:
                # Give up before announcing another wait that won't happen.
                if timeout <= 0:
                    raise RuntimeError(
                        f"An unexpected error was encountered while waiting for the "
                        f"Kubeflow metadata store to be functional: {str(e)}"
                    ) from e
                logger.info(
                    "The Kubeflow metadata store is not ready yet. Waiting for "
                    "10 seconds..."
                )
                timeout -= 10
                time.sleep(10)

        logger.info("The Kubeflow metadata store is functional.")
is_provisioned: bool
property
readonly
If the component provisioned resources to run locally.
Returns:
Type | Description |
---|---|
bool |
True if the component provisioned resources to run locally. |
is_running: bool
property
readonly
If the component is running locally.
Returns:
Type | Description |
---|---|
bool |
True if the component is running locally, False otherwise. |
kfp_orchestrator: KubeflowOrchestrator
property
readonly
Returns the Kubeflow orchestrator in the active stack.
Returns:
Type | Description |
---|---|
KubeflowOrchestrator |
The Kubeflow orchestrator in the active stack. |
kubernetes_context: str
property
readonly
Returns the kubernetes context.
This is the context of the cluster where the Kubeflow Pipelines services are running.
Returns:
Type | Description |
---|---|
str |
The kubernetes context. |
root_directory: str
property
readonly
Returns path to the root directory.
This is for all files concerning this KFP metadata store.
Note: the root directory for the KFP metadata store is relative to the root directory of the KFP orchestrator, because it is a sub-component of it.
Returns:
Type | Description |
---|---|
str |
Path to the root directory. |
validator: Optional[zenml.stack.stack_validator.StackValidator]
property
readonly
Validates that the stack contains a KFP orchestrator.
Returns:
Type | Description |
---|---|
Optional[zenml.stack.stack_validator.StackValidator] |
The stack validator. |
deprovision(self)
Deprovisions all local resources of the component.
Source code in zenml/integrations/kubeflow/metadata_stores/kubeflow_metadata_store.py
def deprovision(self) -> None:
    """Remove the daemon log file, if present, and log the deprovisioning."""
    log_file_exists = fileio.exists(self._log_file)
    if log_file_exists:
        fileio.remove(self._log_file)

    logger.info("Local kubeflow pipelines deployment deprovisioned.")
get_tfx_metadata_config(self)
Return tfx metadata config for the kubeflow metadata store.
Returns:
Type | Description |
---|---|
Union[ml_metadata.proto.metadata_store_pb2.ConnectionConfig, ml_metadata.proto.metadata_store_pb2.MetadataStoreClientConfig] |
The tfx metadata config for the kubeflow metadata store. |
Exceptions:
Type | Description |
---|---|
RuntimeError |
If the metadata store is not running. |
Source code in zenml/integrations/kubeflow/metadata_stores/kubeflow_metadata_store.py
def get_tfx_metadata_config(
    self,
) -> Union[
    metadata_store_pb2.ConnectionConfig,
    metadata_store_pb2.MetadataStoreClientConfig,
]:
    """Build the client config used to connect to the metadata service.

    Returns:
        A `MetadataStoreClientConfig` with host and port filled in, either
        from the KFP pod environment or from this store's `host`/`port`.

    Raises:
        RuntimeError: If the metadata store is not running (only checked
            outside of a KFP pod).
    """
    client_config = metadata_store_pb2.MetadataStoreClientConfig()

    # Inside a KFP pod the service address comes from the environment.
    if inside_kfp_pod():
        client_config.host = os.environ["METADATA_GRPC_SERVICE_HOST"]
        grpc_port = os.environ["METADATA_GRPC_SERVICE_PORT"]
        client_config.port = int(grpc_port)
        return client_config

    # Otherwise the local daemon must be up before its address is usable.
    if not self.is_running:
        raise RuntimeError(
            "The KFP metadata daemon is not running. Please run the "
            "following command to start it first:\n\n"
            "    'zenml metadata-store up'\n"
        )

    client_config.host = self.host
    client_config.port = self.port
    return client_config
provision(self)
Provisions resources to run the component locally.
Source code in zenml/integrations/kubeflow/metadata_stores/kubeflow_metadata_store.py
def provision(self) -> None:
    """Create the root directory of this component after logging the start."""
    logger.info("Provisioning local Kubeflow Pipelines deployment...")
    store_root = self.root_directory
    fileio.makedirs(store_root)
resume(self)
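Resumes the metadata store: starts the KFP metadata daemon (unless it is already running) and waits until the metadata store is ready.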
Source code in zenml/integrations/kubeflow/metadata_stores/kubeflow_metadata_store.py
def resume(self) -> None:
    """Start the metadata daemon and wait for the store, unless running."""
    if not self.is_running:
        self.start_kfp_metadata_daemon()
        self.wait_until_metadata_store_ready()
    else:
        logger.info("Local kubeflow pipelines deployment already running.")
start_kfp_metadata_daemon(self)
Starts a daemon process that forwards ports.
This is so the Kubeflow Pipelines metadata gRPC service (`svc/metadata-grpc-service`) is accessible on the localhost.
Exceptions:
Type | Description |
---|---|
ProvisioningError |
if the daemon fails to start. |
Source code in zenml/integrations/kubeflow/metadata_stores/kubeflow_metadata_store.py
def start_kfp_metadata_daemon(self) -> None:
    """Starts a daemon process that forwards ports.

    The daemon forwards the local `port` to port 8080 of
    `svc/metadata-grpc-service` in the `kubeflow` namespace, so the
    metadata service is accessible on the localhost. On Windows no daemon
    is started; the command to run manually is logged instead.

    Raises:
        ProvisioningError: if the local port is already occupied.
    """
    command = [
        "kubectl",
        "--context",
        self.kubernetes_context,
        "--namespace",
        "kubeflow",
        "port-forward",
        "svc/metadata-grpc-service",
        f"{self.port}:8080",
    ]

    if sys.platform == "win32":
        # Exactly one argument for the single `%s` placeholder; passing
        # more makes the logging module fail to format the record.
        logger.warning(
            "Daemon functionality not supported on Windows. "
            "In order to access the Kubeflow Pipelines Metadata locally, "
            "please run '%s' in a separate command line shell.",
            " ".join(command),
        )
    elif not networking_utils.port_available(self.port):
        raise ProvisioningError(
            f"Unable to port-forward Kubeflow Pipelines Metadata to local "
            f"port {self.port} because the port is occupied. In order to "
            f"access the Kubeflow Pipelines Metadata locally, please "
            f"change the metadata store configuration to use an available "
            f"port or stop the other process currently using the port."
        )
    else:
        from zenml.utils import daemon

        def _daemon_function() -> None:
            """Forwards the port of the Kubeflow Pipelines Metadata pod."""
            subprocess.check_call(command)

        daemon.run_as_daemon(
            _daemon_function,
            pid_file=self._pid_file_path,
            log_file=self._log_file,
        )
        # Trailing spaces keep the concatenated literals from running
        # together ("daemonlogs", "pipelinemetadata").
        logger.info(
            "Started Kubeflow Pipelines Metadata daemon (check the daemon "
            "logs at %s in case you're not able to access the pipeline "
            "metadata).",
            self._log_file,
        )
stop_kfp_metadata_daemon(self)
Stops the KFP Metadata daemon process if it is running.
Source code in zenml/integrations/kubeflow/metadata_stores/kubeflow_metadata_store.py
def stop_kfp_metadata_daemon(self) -> None:
    """Stop the daemon and delete its PID file, if a PID file exists."""
    if not fileio.exists(self._pid_file_path):
        return

    # Daemon functionality is not supported on Windows, so nothing is done
    # there; the platform check also keeps mypy from complaining about
    # missing functions.
    if sys.platform != "win32":
        from zenml.utils import daemon

        daemon.stop_daemon(self._pid_file_path)
        fileio.remove(self._pid_file_path)
suspend(self)
Suspends the metadata store: stops the KFP metadata daemon if it is running.
Source code in zenml/integrations/kubeflow/metadata_stores/kubeflow_metadata_store.py
def suspend(self) -> None:
    """Stop the metadata daemon; only log a message if it is not running."""
    if self.is_running:
        self.stop_kfp_metadata_daemon()
    else:
        logger.info("Local kubeflow pipelines deployment not running.")
wait_until_metadata_store_ready(self, timeout=60)
Waits until the metadata store connection is ready.
Potentially an irrecoverable error could occur or the timeout could expire, so it checks for this.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
timeout |
int |
The maximum time to wait for the metadata store to be ready. |
60 |
Exceptions:
Type | Description |
---|---|
RuntimeError |
if the metadata store is not ready after the timeout |
Source code in zenml/integrations/kubeflow/metadata_stores/kubeflow_metadata_store.py
def wait_until_metadata_store_ready(
    self, timeout: int = DEFAULT_KFP_METADATA_DAEMON_TIMEOUT
) -> None:
    """Waits until the metadata store connection is ready.

    The connection is probed every 10 seconds. Once the timeout budget is
    used up, the last error is re-raised as a `RuntimeError`.

    Args:
        timeout: The maximum time to wait for the metadata store to be
            ready.

    Raises:
        RuntimeError: if the metadata store is not ready after the timeout
    """
    logger.info(
        "Waiting for the Kubeflow metadata store to be ready (this might "
        "take a few minutes)."
    )
    while True:
        try:
            # it doesn't matter what we call here as long as it exercises
            # the MLMD connection
            self.get_pipelines()
            break
        except Exception as e:
            # Give up before announcing another wait that won't happen.
            if timeout <= 0:
                raise RuntimeError(
                    f"An unexpected error was encountered while waiting for the "
                    f"Kubeflow metadata store to be functional: {str(e)}"
                ) from e
            logger.info(
                "The Kubeflow metadata store is not ready yet. Waiting for "
                "10 seconds..."
            )
            timeout -= 10
            time.sleep(10)

    logger.info("The Kubeflow metadata store is functional.")
inside_kfp_pod()
Returns if the current python process is running inside a KFP Pod.
Returns:
Type | Description |
---|---|
bool |
True if the current python process is running inside a KFP Pod, False otherwise. |
Source code in zenml/integrations/kubeflow/metadata_stores/kubeflow_metadata_store.py
def inside_kfp_pod() -> bool:
    """Tell whether the current python process runs inside a KFP Pod.

    Returns:
        True if the `KFP_POD_NAME` environment variable is set and the
        in-cluster Kubernetes configuration could be loaded, False otherwise.
    """
    if os.environ.get("KFP_POD_NAME") is None:
        return False

    try:
        k8s_config.load_incluster_config()
    except k8s_config.ConfigException:
        return False
    else:
        return True
orchestrators
special
Initialization of the Kubeflow ZenML orchestrator.
kubeflow_entrypoint_configuration
Implementation of the Kubeflow entrypoint configuration.
KubeflowEntrypointConfiguration (StepEntrypointConfiguration)
Entrypoint configuration for running steps on kubeflow.
This class writes a markdown file that will be displayed in the KFP UI.
Source code in zenml/integrations/kubeflow/orchestrators/kubeflow_entrypoint_configuration.py
class KubeflowEntrypointConfiguration(StepEntrypointConfiguration):
    """Step entrypoint configuration used when steps run on kubeflow.

    After a step ran, this class writes a markdown file that will be
    displayed in the KFP UI.
    """

    @classmethod
    def get_custom_entrypoint_options(cls) -> Set[str]:
        """Kubeflow specific entrypoint options.

        The metadata ui path option expects a path where the markdown file
        that will be displayed in the kubeflow UI should be written. The same
        path needs to be added as an output artifact called
        `mlpipeline-ui-metadata` for the corresponding `kfp.dsl.ContainerOp`.

        Returns:
            A set containing only the metadata ui path option.
        """
        custom_options = {METADATA_UI_PATH_OPTION}
        return custom_options

    @classmethod
    def get_custom_entrypoint_arguments(
        cls, step: BaseStep, *args: Any, **kwargs: Any
    ) -> List[str]:
        """Build the metadata ui path argument from the keyword args.

        Args:
            step: The step that is being executed.
            *args: The positional arguments passed to the step.
            **kwargs: The keyword arguments passed to the step; must contain
                the metadata ui path option.

        Returns:
            The option flag followed by its value.
        """
        option_flag = f"--{METADATA_UI_PATH_OPTION}"
        ui_path = kwargs[METADATA_UI_PATH_OPTION]
        return [option_flag, ui_path]

    def get_run_name(self, pipeline_name: str) -> str:
        """Returns the Kubeflow pipeline run name.

        The run is looked up through the KFP client using the ID stored in
        the `KFP_RUN_ID` environment variable.

        Args:
            pipeline_name: The name of the pipeline.

        Returns:
            The Kubeflow pipeline run name.
        """
        k8s_config.load_incluster_config()
        kfp_run_id = os.environ["KFP_RUN_ID"]
        run_details = kfp.Client().get_run(kfp_run_id)
        return run_details.run.name  # type: ignore[no-any-return]

    def post_run(
        self,
        pipeline_name: str,
        step_name: str,
        pipeline_node: Pb2PipelineNode,
        execution_info: Optional[data_types.ExecutionInfo] = None,
    ) -> None:
        """Writes a markdown file that will display information.

        This will be about the step execution and input/output artifacts in
        the KFP UI. Nothing is written when no execution info is given.

        Args:
            pipeline_name: The name of the pipeline.
            step_name: The name of the step.
            pipeline_node: The pipeline node that is being executed.
            execution_info: The execution info of the step.
        """
        if not execution_info:
            return

        utils.dump_ui_metadata(
            node=pipeline_node,
            execution_info=execution_info,
            metadata_ui_path=self.entrypoint_args[METADATA_UI_PATH_OPTION],
        )
get_custom_entrypoint_arguments(step, *args, **kwargs)
classmethod
Sets the metadata ui path argument to the value passed in via the keyword args.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
step |
BaseStep |
The step that is being executed. |
required |
*args |
Any |
The positional arguments passed to the step. |
() |
**kwargs |
Any |
The keyword arguments passed to the step. |
{} |
Returns:
Type | Description |
---|---|
List[str] |
A list of strings that will be used as arguments to the step. |
Source code in zenml/integrations/kubeflow/orchestrators/kubeflow_entrypoint_configuration.py
@classmethod
def get_custom_entrypoint_arguments(
    cls, step: BaseStep, *args: Any, **kwargs: Any
) -> List[str]:
    """Build the metadata ui path argument from the keyword args.

    Args:
        step: The step that is being executed.
        *args: The positional arguments passed to the step.
        **kwargs: The keyword arguments passed to the step; must contain the
            metadata ui path option.

    Returns:
        The option flag followed by its value.
    """
    option_flag = f"--{METADATA_UI_PATH_OPTION}"
    ui_path = kwargs[METADATA_UI_PATH_OPTION]
    return [option_flag, ui_path]
get_custom_entrypoint_options()
classmethod
Kubeflow specific entrypoint options.
The metadata ui path option expects a path where the markdown file that will be displayed in the kubeflow UI should be written. The same path needs to be added as an output artifact called `mlpipeline-ui-metadata` for the corresponding `kfp.dsl.ContainerOp`.
Returns:
Type | Description |
---|---|
Set[str] |
The set of custom entrypoint options. |
Source code in zenml/integrations/kubeflow/orchestrators/kubeflow_entrypoint_configuration.py
@classmethod
def get_custom_entrypoint_options(cls) -> Set[str]:
    """Kubeflow specific entrypoint options.

    The metadata ui path option expects a path where the markdown file
    that will be displayed in the kubeflow UI should be written. The same
    path needs to be added as an output artifact called
    `mlpipeline-ui-metadata` for the corresponding `kfp.dsl.ContainerOp`.

    Returns:
        A set containing only the metadata ui path option.
    """
    custom_options = {METADATA_UI_PATH_OPTION}
    return custom_options
get_run_name(self, pipeline_name)
Returns the Kubeflow pipeline run name.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
pipeline_name |
str |
The name of the pipeline. |
required |
Returns:
Type | Description |
---|---|
str |
The Kubeflow pipeline run name. |
Source code in zenml/integrations/kubeflow/orchestrators/kubeflow_entrypoint_configuration.py
def get_run_name(self, pipeline_name: str) -> str:
    """Returns the Kubeflow pipeline run name.

    The run is looked up through the KFP client using the ID stored in the
    `KFP_RUN_ID` environment variable.

    Args:
        pipeline_name: The name of the pipeline.

    Returns:
        The Kubeflow pipeline run name.
    """
    k8s_config.load_incluster_config()
    kfp_run_id = os.environ["KFP_RUN_ID"]
    run_details = kfp.Client().get_run(kfp_run_id)
    return run_details.run.name  # type: ignore[no-any-return]
post_run(self, pipeline_name, step_name, pipeline_node, execution_info=None)
Writes a markdown file that will display information.
This will be about the step execution and input/output artifacts in the KFP UI.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
pipeline_name |
str |
The name of the pipeline. |
required |
step_name |
str |
The name of the step. |
required |
pipeline_node |
PipelineNode |
The pipeline node that is being executed. |
required |
execution_info |
Optional[tfx.orchestration.portable.data_types.ExecutionInfo] |
The execution info of the step. |
None |
Source code in zenml/integrations/kubeflow/orchestrators/kubeflow_entrypoint_configuration.py
def post_run(
    self,
    pipeline_name: str,
    step_name: str,
    pipeline_node: Pb2PipelineNode,
    execution_info: Optional[data_types.ExecutionInfo] = None,
) -> None:
    """Writes a markdown file that will display information.

    This will be about the step execution and input/output artifacts in the
    KFP UI. Nothing is written when no execution info is given.

    Args:
        pipeline_name: The name of the pipeline.
        step_name: The name of the step.
        pipeline_node: The pipeline node that is being executed.
        execution_info: The execution info of the step.
    """
    if not execution_info:
        return

    utils.dump_ui_metadata(
        node=pipeline_node,
        execution_info=execution_info,
        metadata_ui_path=self.entrypoint_args[METADATA_UI_PATH_OPTION],
    )
kubeflow_orchestrator
Implementation of the Kubeflow orchestrator.
KubeflowOrchestrator (BaseOrchestrator)
pydantic-model
Orchestrator responsible for running pipelines using Kubeflow.
Attributes:
Name | Type | Description |
---|---|---|
custom_docker_base_image_name |
Optional[str] |
Name of a docker image that should be used as the base for the image that will be run on KFP pods. If no custom image is given, a basic image of the active ZenML version will be used. Note: This image needs to have ZenML installed, otherwise the pipeline execution will fail. For that reason, you might want to extend the ZenML docker images found here: https://hub.docker.com/r/zenmldocker/zenml/ |
kubeflow_pipelines_ui_port |
int |
A local port to which the KFP UI will be forwarded. |
kubeflow_hostname |
Optional[str] |
The hostname to use to talk to the Kubeflow Pipelines API. If not set, the hostname will be derived from the Kubernetes API proxy. |
kubernetes_context |
Optional[str] |
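Optional name of a kubernetes context to run pipelines in. If not set, the current active context will be used. You can find the active context by running `kubectl config current-context`. |
synchronous |
bool |
If `True`, running a pipeline using this orchestrator will block until all steps finished running on KFP. |
skip_local_validations |
bool |
If `True`, the local validations will be skipped. |
skip_cluster_provisioning |
bool |
If `True`, the k3d cluster provisioning will be skipped. |
skip_ui_daemon_provisioning |
bool |
If `True`, provisioning the KFP UI daemon will be skipped. |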
Source code in zenml/integrations/kubeflow/orchestrators/kubeflow_orchestrator.py
class KubeflowOrchestrator(BaseOrchestrator):
"""Orchestrator responsible for running pipelines using Kubeflow.
Attributes:
custom_docker_base_image_name: Name of a docker image that should be
used as the base for the image that will be run on KFP pods. If no
custom image is given, a basic image of the active ZenML version
will be used. **Note**: This image needs to have ZenML installed,
otherwise the pipeline execution will fail. For that reason, you
might want to extend the ZenML docker images found here:
https://hub.docker.com/r/zenmldocker/zenml/
kubeflow_pipelines_ui_port: A local port to which the KFP UI will be
forwarded.
kubeflow_hostname: The hostname to use to talk to the Kubeflow Pipelines
API. If not set, the hostname will be derived from the Kubernetes
API proxy.
kubernetes_context: Optional name of a kubernetes context to run
pipelines in. If not set, the current active context will be used.
You can find the active context by running `kubectl config
current-context`.
synchronous: If `True`, running a pipeline using this orchestrator will
block until all steps finished running on KFP.
skip_local_validations: If `True`, the local validations will be
skipped.
skip_cluster_provisioning: If `True`, the k3d cluster provisioning will
be skipped.
skip_ui_daemon_provisioning: If `True`, provisioning the KFP UI daemon
will be skipped.
"""
custom_docker_base_image_name: Optional[str] = None
kubeflow_pipelines_ui_port: int = DEFAULT_KFP_UI_PORT
kubeflow_hostname: Optional[str] = None
kubernetes_context: Optional[str] = None
synchronous: bool = False
skip_local_validations: bool = False
skip_cluster_provisioning: bool = False
skip_ui_daemon_provisioning: bool = False
# Class Configuration
FLAVOR: ClassVar[str] = KUBEFLOW_ORCHESTRATOR_FLAVOR
@staticmethod
def _get_k3d_cluster_name(uuid: UUID) -> str:
"""Returns the k3d cluster name corresponding to the orchestrator UUID.
Args:
uuid: The UUID of the orchestrator.
Returns:
The k3d cluster name.
"""
# k3d only allows cluster names with up to 32 characters; use the
# first 8 chars of the orchestrator UUID as identifier
return f"zenml-kubeflow-{str(uuid)[:8]}"
@staticmethod
def _get_k3d_kubernetes_context(uuid: UUID) -> str:
"""Gets the k3d kubernetes context.
Args:
uuid: The UUID of the orchestrator.
Returns:
The name of the kubernetes context associated with the k3d
cluster managed locally by ZenML corresponding to the orchestrator UUID.
"""
return f"k3d-{KubeflowOrchestrator._get_k3d_cluster_name(uuid)}"
@root_validator(skip_on_failure=True)
def set_default_kubernetes_context(
    cls, values: Dict[str, Any]
) -> Dict[str, Any]:
    """Pydantic root_validator that fills in a default kubernetes context.

    When no `kubernetes_context` was given explicitly, it is set to the
    context of the locally managed k3d cluster for this orchestrator.

    Args:
        values: Values passed to the object constructor

    Returns:
        Values passed to the Pydantic constructor
    """
    if values.get("kubernetes_context"):
        # an explicit context was configured, nothing to default
        return values

    # not likely, due to Pydantic validation, but mypy complains
    assert "uuid" in values
    default_context = cls._get_k3d_kubernetes_context(values["uuid"])
    values["kubernetes_context"] = default_context
    return values
def get_kubernetes_contexts(self) -> Tuple[List[str], Optional[str]]:
    """Reads the configured Kubernetes contexts and the active context.

    Returns:
        A tuple containing the list of configured Kubernetes context names
        and the name of the active context. If the kube config cannot be
        loaded, an empty list and `None` are returned instead.
    """
    try:
        kube_contexts, current = k8s_config.list_kube_config_contexts()
    except k8s_config.config_exception.ConfigException:
        return [], None

    names = []
    for entry in kube_contexts:
        names.append(entry["name"])
    return names, current["name"]
@property
def validator(self) -> Optional[StackValidator]:
    """Builds the stack validator for this orchestrator.

    The validator requires a container registry in the stack and runs a
    custom check on whether local and remote components are mixed.

    Returns:
        A `StackValidator` instance.
    """

    def _validate_local_requirements(stack: "Stack") -> Tuple[bool, str]:
        registry = stack.container_registry

        # should not happen, because the stack validation takes care of
        # this, but just in case
        assert registry is not None

        known_contexts, current_context = self.get_kubernetes_contexts()

        if self.kubernetes_context not in known_contexts:
            # a missing context is only fatal for remote orchestrators
            if not self.is_local:
                return False, (
                    f"Could not find a Kubernetes context named "
                    f"'{self.kubernetes_context}' in the local Kubernetes "
                    f"configuration. Please make sure that the Kubernetes "
                    f"cluster is running and that the kubeconfig file is "
                    f"configured correctly. To list all configured "
                    f"contexts, run:\n\n"
                    f"  `kubectl config get-contexts`\n"
                )
        elif current_context and self.kubernetes_context != current_context:
            logger.warning(
                f"The Kubernetes context '{self.kubernetes_context}' "
                f"configured for the Kubeflow orchestrator is not the "
                f"same as the active context in the local Kubernetes "
                f"configuration. If this is not deliberate, you should "
                f"update the orchestrator's `kubernetes_context` field by "
                f"running:\n\n"
                f"  `zenml orchestrator update {self.name} "
                f"--kubernetes_context={current_context}`\n"
                f"To list all configured contexts, run:\n\n"
                f"  `kubectl config get-contexts`\n"
                f"To set the active context to be the same as the one "
                f"configured in the Kubeflow orchestrator and silence "
                f"this warning, run:\n\n"
                f"  `kubectl config use-context "
                f"{self.kubernetes_context}`\n"
            )

        silence_hint = (
            f"To silence this warning, set the "
            f"`skip_local_validations` attribute to True in the "
            f"orchestrator configuration by running:\n\n"
            f"  'zenml orchestrator update {self.name} "
            f"--skip_local_validations=True'\n"
        )

        if self.skip_local_validations:
            return True, ""

        if self.is_local:
            # a local orchestrator needs a local container registry
            if not registry.is_local:
                return False, (
                    f"The Kubeflow orchestrator is configured to run "
                    f"pipelines in a local k3d Kubernetes cluster "
                    f"designated by the '{self.kubernetes_context}' "
                    f"configuration context, but the container registry "
                    f"URI '{registry.uri}' doesn't match the "
                    f"expected format 'localhost:$PORT'. "
                    f"The local Kubeflow orchestrator only works with a "
                    f"local container registry because it cannot "
                    f"currently authenticate to external container "
                    f"registries. You should use a flavor of container "
                    f"registry other than '{registry.FLAVOR}'.\n"
                    + silence_hint
                )
            return True, ""

        # The orchestrator does not run in a local k3d cluster, so local
        # paths cannot be mounted into the containers: any stack component
        # that advertises a local path would be unavailable inside the
        # Kubeflow pipeline steps.
        for component in stack.components.values():
            if not component.local_path:
                continue
            return False, (
                f"The Kubeflow orchestrator is configured to run "
                f"pipelines in a remote Kubernetes cluster designated "
                f"by the '{self.kubernetes_context}' configuration "
                f"context, but the '{component.name}' "
                f"{component.TYPE.value} is a local stack component "
                f"and will not be available in the Kubeflow pipeline "
                f"step.\nPlease ensure that you always use non-local "
                f"stack components with a remote Kubeflow orchestrator, "
                f"otherwise you may run into pipeline execution "
                f"problems. You should use a flavor of "
                f"{component.TYPE.value} other than "
                f"'{component.FLAVOR}'.\n"
                + silence_hint
            )

        # a remote orchestrator also needs a remote container registry
        if registry.is_local:
            return False, (
                f"The Kubeflow orchestrator is configured to run "
                f"pipelines in a remote Kubernetes cluster designated "
                f"by the '{self.kubernetes_context}' configuration "
                f"context, but the '{registry.name}' "
                f"container registry URI '{registry.uri}' "
                f"points to a local container registry. Please ensure "
                f"that you always use non-local stack components with "
                f"a remote Kubeflow orchestrator, otherwise you will "
                f"run into problems. You should use a flavor of "
                f"container registry other than "
                f"'{registry.FLAVOR}'.\n"
                + silence_hint
            )

        return True, ""

    return StackValidator(
        required_components={StackComponentType.CONTAINER_REGISTRY},
        custom_validation_function=_validate_local_requirements,
    )
def get_docker_image_name(self, pipeline_name: str) -> str:
    """Composes the full docker image name for a pipeline.

    Args:
        pipeline_name: The name of the pipeline.

    Returns:
        The full docker image name including registry and tag. When the
        active stack has no container registry, only the base image name
        with the tag is returned.
    """
    image = f"zenml-kubeflow:{pipeline_name}"
    registry = Repository().active_stack.container_registry
    if not registry:
        return image

    # avoid a double slash when the registry URI ends with one
    uri = registry.uri.rstrip("/")
    return f"{uri}/{image}"
@property
def is_local(self) -> bool:
    """Tells whether the KFP orchestrator is running locally.

    Returns:
        `True` if the configured kubernetes context is the one of the
        local k3d cluster managed by ZenML for this orchestrator.
    """
    managed_context = self._get_k3d_kubernetes_context(self.uuid)
    return self.kubernetes_context == managed_context
@property
def root_directory(self) -> str:
    """Path to the root directory for all files of this orchestrator.

    Returns:
        Path to the root directory.
    """
    config_dir = io_utils.get_global_config_directory()
    return os.path.join(config_dir, "kubeflow", str(self.uuid))
@property
def pipeline_directory(self) -> str:
    """Path to the directory in which the kubeflow pipeline files are stored.

    Returns:
        Path to the pipeline directory.
    """
    root = self.root_directory
    return os.path.join(root, "pipelines")
def prepare_pipeline_deployment(
    self,
    pipeline: "BasePipeline",
    stack: "Stack",
    runtime_configuration: "RuntimeConfiguration",
) -> None:
    """Builds a docker image for the current environment and pushes it.

    The image is pushed to the container registry of the stack and its
    digest (or its name, if no digest is found) is recorded in the runtime
    configuration.

    Args:
        pipeline: The pipeline to be deployed.
        stack: The stack to be deployed.
        runtime_configuration: The runtime configuration to be used.
    """
    from zenml.utils import docker_utils

    image_name = self.get_docker_image_name(pipeline.name)

    # union of the stack and pipeline requirements, without duplicates
    requirements = set(stack.requirements()) | set(pipeline.requirements)
    logger.debug("Kubeflow docker container requirements: %s", requirements)

    build_context = get_source_root_path()
    dockerignore = pipeline.dockerignore_file
    base_image = self.custom_docker_base_image_name
    secret_env_vars = self._get_environment_vars_from_secrets(
        pipeline.secrets
    )
    docker_utils.build_docker_image(
        build_context_path=build_context,
        image_name=image_name,
        dockerignore_path=dockerignore,
        requirements=requirements,
        base_image=base_image,
        environment_vars=secret_env_vars,
    )

    assert stack.container_registry  # should never happen due to validation
    stack.container_registry.push_image(image_name)

    # Store the docker image digest in the runtime configuration so it gets
    # tracked in the ZenStore
    digest = docker_utils.get_image_digest(image_name)
    runtime_configuration["docker_image"] = digest or image_name
@staticmethod
def _configure_container_op(container_op: dsl.ContainerOp) -> None:
    """Makes changes in place to the configuration of the container op.

    Configures persistent mounted volumes for each stack component that
    writes to a local path. Adds some labels to the container_op and applies
    some functions to it.

    Args:
        container_op: The kubeflow container operation to configure.

    Raises:
        ValueError: If the local path is not in the global config directory.
    """
    # Path to a metadata file that will be displayed in the KFP UI
    # This metadata file needs to be in a mounted emptyDir to avoid
    # sporadic failures with the (not mature) PNS executor
    # See these links for more information about limitations of PNS +
    # security context:
    # https://www.kubeflow.org/docs/components/pipelines/installation/localcluster-deployment/#deploying-kubeflow-pipelines
    # https://argoproj.github.io/argo-workflows/empty-dir/
    # KFP will switch to the Emissary executor (soon), when this emptyDir
    # mount will not be necessary anymore, but for now it's still in alpha
    # status (https://www.kubeflow.org/docs/components/pipelines/installation/choose-executor/#emissary-executor)
    outputs_volume = k8s_client.V1Volume(
        name="outputs", empty_dir=k8s_client.V1EmptyDirVolumeSource()
    )
    volumes: Dict[str, k8s_client.V1Volume] = {"/outputs": outputs_volume}

    stack = Repository().active_stack
    global_cfg_dir = io_utils.get_global_config_directory()

    # Mount the local path of every stack component that advertises one, so
    # the information persisted there is available inside the Kubeflow
    # container when running pipelines.
    mounts_local_paths = False
    for component in stack.components.values():
        component_path = component.local_path
        if not component_path:
            continue

        # double-check this convention, just in case it wasn't respected
        # as documented in `StackComponent.local_path`
        if not component_path.startswith(global_cfg_dir):
            raise ValueError(
                f"Local path {component_path} for component {component.name} "
                f"is not in the global config directory ({global_cfg_dir})."
            )

        mounts_local_paths = True

        host_path = k8s_client.V1HostPathVolumeSource(
            path=component_path, type="Directory"
        )

        # collapse every run of characters other than ASCII letters,
        # digits and "-" into a single "-" to build the volume name
        raw_name = f"{component.TYPE.value}-{component.name}"
        sanitized = re.sub(r"[^0-9a-zA-Z-]+", "-", raw_name)
        volumes[component_path] = k8s_client.V1Volume(
            name=sanitized.strip("-").lower(),
            host_path=host_path,
        )

        logger.debug(
            "Adding host path volume for %s %s (path: %s) "
            "in kubeflow pipelines container.",
            component.TYPE.value,
            component.name,
            component_path,
        )

    container_op.add_pvolumes(volumes)

    if mounts_local_paths:
        # File permissions are not checked on Windows, so the security
        # context is only set on other platforms.
        if sys.platform != "win32":
            # Run KFP containers in the context of the local UID/GID
            # to ensure that the artifact and metadata stores can be shared
            # with the local pipeline runs.
            container_op.container.security_context = (
                k8s_client.V1SecurityContext(
                    run_as_user=os.getuid(),
                    run_as_group=os.getgid(),
                )
            )
            logger.debug(
                "Setting security context UID and GID to local user/group "
                "in kubeflow pipelines container."
            )

    # Add environment variables for Azure Blob Storage to pod in case they
    # are set locally
    # TODO [ENG-699]: remove this as soon as we implement credential
    #  handling
    azure_env_keys = (
        "AZURE_STORAGE_ACCOUNT_KEY",
        "AZURE_STORAGE_ACCOUNT_NAME",
        "AZURE_STORAGE_CONNECTION_STRING",
        "AZURE_STORAGE_SAS_TOKEN",
    )
    for env_key in azure_env_keys:
        env_value = os.getenv(env_key)
        if not env_value:
            continue
        container_op.container.add_env_variable(
            k8s_client.V1EnvVar(name=env_key, value=env_value)
        )

    # Add some pod labels to the container_op
    for label_key, label_value in KFP_POD_LABELS.items():
        container_op.add_pod_label(label_key, label_value)

    # Mounts configmap containing Metadata gRPC server configuration.
    container_op.apply(utils.mount_config_map_op("metadata-grpc-configmap"))
@staticmethod
def _configure_container_resources(
    container_op: dsl.ContainerOp,
    resource_configuration: "ResourceConfiguration",
) -> None:
    """Applies the configured resource limits to the container op.

    Only the limits that are set (not `None`) in the resource
    configuration are applied.

    Args:
        container_op: The kubeflow container operation to configure.
        resource_configuration: The resource configuration to use for this
            container.
    """
    cpu_count = resource_configuration.cpu_count
    if cpu_count is not None:
        container_op = container_op.set_cpu_limit(str(cpu_count))

    gpu_count = resource_configuration.gpu_count
    if gpu_count is not None:
        container_op = container_op.set_gpu_limit(gpu_count)

    memory = resource_configuration.memory
    if memory is not None:
        # the last character of the configured memory string is dropped
        # before it is passed on as the limit
        container_op = container_op.set_memory_limit(memory[:-1])
def prepare_or_run_pipeline(
    self,
    sorted_steps: List["BaseStep"],
    pipeline: "BasePipeline",
    pb2_pipeline: Pb2Pipeline,
    stack: "Stack",
    runtime_configuration: "RuntimeConfiguration",
) -> Any:
    """Creates a kfp yaml file and runs it on Kubeflow Pipelines.

    The yaml file is an intermediary representation of the pipeline which
    is then deployed to the kubeflow pipelines instance.

    How it works:
    -------------
    Before this method is called, the `prepare_pipeline_deployment()`
    method builds a docker image that contains the code for the
    pipeline, all steps and the context around these files.

    Based on this docker image a callable is created which builds
    container_ops for each step (`_construct_kfp_pipeline`).
    To do this the entrypoint of the docker image is configured to
    run the correct step within the docker image. The dependencies
    between these container_ops are then also configured onto each
    container_op by pointing at the upstream steps it has to run after.

    This callable is then compiled into a kfp yaml file that is used as
    the intermediary representation of the kubeflow pipeline.

    This file, together with some metadata and runtime configurations, is
    then uploaded into the kubeflow pipelines cluster for execution.

    Args:
        sorted_steps: A list of steps sorted by their order in the
            pipeline.
        pipeline: The pipeline object.
        pb2_pipeline: The pipeline object in protobuf format.
        stack: The stack object.
        runtime_configuration: The runtime configuration object.

    Raises:
        RuntimeError: If you try to run the pipelines in a notebook
            environment.
    """
    # First check whether the code is running in a notebook
    if Environment.in_notebook():
        raise RuntimeError(
            "The Kubeflow orchestrator cannot run pipelines in a notebook "
            "environment. The reason is that it is non-trivial to create "
            "a Docker image of a notebook. Please consider refactoring "
            "your notebook cells into separate scripts in a Python module "
            "and run the code outside of a notebook when using this "
            "orchestrator."
        )

    # prefer the image digest, fall back to the plain image name
    tagged_image_name = self.get_docker_image_name(pipeline.name)
    image_name = get_image_digest(tagged_image_name) or tagged_image_name

    # Create a callable for future compilation into a dsl.Pipeline.
    def _construct_kfp_pipeline() -> None:
        """Create a container_op for each step.

        Each container_op carries the name of the docker image and the
        entrypoint configuration needed to run its step, and is set to run
        after the container_ops of its upstream steps.

        If this callable is passed to the `_create_and_write_workflow()`
        method of a KFPCompiler all dsl.ContainerOp instances will be
        automatically added to a singular dsl.Pipeline instance.
        """
        # container_ops created so far, indexed by the associated step name
        ops_by_step_name: Dict[str, dsl.ContainerOp] = {}
        metadata_ui_path = "/outputs/mlpipeline-ui-metadata.json"

        for step in sorted_steps:
            # The command will be needed to eventually call the python step
            # within the docker container
            command = KubeflowEntrypointConfiguration.get_entrypoint_command()

            # The arguments are passed to configure the entrypoint of the
            # docker container when the step is called.
            extra_options = {METADATA_UI_PATH_OPTION: metadata_ui_path}
            arguments = (
                KubeflowEntrypointConfiguration.get_entrypoint_arguments(
                    step=step,
                    pb2_pipeline=pb2_pipeline,
                    **extra_options,
                )
            )

            # Create a container_op - the kubeflow equivalent of a step. It
            # contains the name of the step, the name of the docker image,
            # the command to use to run the step entrypoint
            # (e.g. `python -m zenml.entrypoints.step_entrypoint`)
            # and the arguments to be passed along with the command. Find
            # out more about how these arguments are parsed and used
            # in the base entrypoint `run()` method.
            output_paths = {"mlpipeline-ui-metadata": metadata_ui_path}
            container_op = dsl.ContainerOp(
                name=step.name,
                image=image_name,
                command=command,
                arguments=arguments,
                output_artifact_paths=output_paths,
            )

            # Mounts persistent volumes, configmaps and adds labels to the
            # container op
            self._configure_container_op(container_op=container_op)

            if self.requires_resources_in_orchestration_environment(step):
                self._configure_container_resources(
                    container_op=container_op,
                    resource_configuration=step.resource_configuration,
                )

            # Configure the current container op to run after the
            # container ops of all its upstream steps
            upstream_step_names = self.get_upstream_step_names(
                step=step, pb2_pipeline=pb2_pipeline
            )
            for upstream_name in upstream_step_names:
                container_op.after(ops_by_step_name[upstream_name])

            # Remember the container op of the current step
            ops_by_step_name[step.name] = container_op

    # Get a filepath to use to save the finished yaml to
    assert runtime_configuration.run_name
    fileio.makedirs(self.pipeline_directory)
    yaml_file_name = f"{runtime_configuration.run_name}.yaml"
    pipeline_file_path = os.path.join(self.pipeline_directory, yaml_file_name)

    # write the argo pipeline yaml
    KFPCompiler()._create_and_write_workflow(
        pipeline_func=_construct_kfp_pipeline,
        pipeline_name=pipeline.name,
        package_path=pipeline_file_path,
    )

    # using the kfp client uploads the pipeline to kubeflow pipelines and
    # runs it there
    self._upload_and_run_pipeline(
        pipeline_name=pipeline.name,
        pipeline_file_path=pipeline_file_path,
        runtime_configuration=runtime_configuration,
        enable_cache=pipeline.enable_cache,
    )
def _upload_and_run_pipeline(
    self,
    pipeline_name: str,
    pipeline_file_path: str,
    runtime_configuration: "RuntimeConfiguration",
    enable_cache: bool,
) -> None:
    """Tries to upload and run a KFP pipeline.

    If the runtime configuration has a schedule, a recurring run is created
    in an experiment named after the pipeline. Otherwise a one-off run is
    started, which is waited for if the orchestrator is `synchronous`.
    HTTP errors raised while talking to Kubeflow are logged as a warning
    instead of being propagated.

    Args:
        pipeline_name: Name of the pipeline.
        pipeline_file_path: Path to the pipeline definition file.
        runtime_configuration: Runtime configuration of the pipeline run.
        enable_cache: Whether caching is enabled for this pipeline run.
    """
    try:
        logger.info(
            "Running in kubernetes context '%s'.",
            self.kubernetes_context,
        )

        # upload the pipeline to Kubeflow and start it
        client = kfp.Client(
            host=self.kubeflow_hostname,
            kube_context=self.kubernetes_context,
        )
        if runtime_configuration.schedule:
            try:
                # Look the experiment up by name: the first positional
                # parameter of `get_experiment` is the experiment ID, so
                # the name has to be passed by keyword.
                experiment = client.get_experiment(
                    experiment_name=pipeline_name
                )
                logger.info(
                    "A recurring run has already been created with this "
                    "pipeline. Creating new recurring run now.."
                )
            except (ValueError, ApiException):
                experiment = client.create_experiment(pipeline_name)
                logger.info(
                    "Creating a new recurring run for pipeline '%s'.. ",
                    pipeline_name,
                )
            logger.info(
                "You can see all recurring runs under the '%s' experiment.",
                pipeline_name,
            )

            schedule = runtime_configuration.schedule
            # `timedelta.seconds` only holds the sub-day remainder of the
            # interval; `total_seconds()` also accounts for whole days.
            interval_seconds = (
                int(schedule.interval_second.total_seconds())
                if schedule.interval_second
                else None
            )
            result = client.create_recurring_run(
                experiment_id=experiment.id,
                job_name=runtime_configuration.run_name,
                pipeline_package_path=pipeline_file_path,
                enable_caching=enable_cache,
                cron_expression=schedule.cron_expression,
                start_time=schedule.utc_start_time,
                end_time=schedule.utc_end_time,
                interval_second=interval_seconds,
                no_catchup=not schedule.catchup,
            )

            logger.info("Started recurring run with ID '%s'.", result.id)
        else:
            logger.info(
                "No schedule detected. Creating a one-off pipeline run.."
            )
            result = client.create_run_from_pipeline_package(
                pipeline_file_path,
                arguments={},
                run_name=runtime_configuration.run_name,
                enable_caching=enable_cache,
            )
            logger.info(
                "Started one-off pipeline run with ID '%s'.", result.run_id
            )

            if self.synchronous:
                # TODO [ENG-698]: Allow configuration of the timeout as a
                #  runtime option
                client.wait_for_run_completion(
                    run_id=result.run_id, timeout=1200
                )
    except urllib3.exceptions.HTTPError as error:
        # best effort: a failed upload is reported but does not raise
        logger.warning(
            f"Failed to upload Kubeflow pipeline: %s. "
            f"Please make sure your kubernetes config is present and the "
            f"{self.kubernetes_context} kubernetes context is configured "
            f"correctly.",
            error,
        )
@property
def _pid_file_path(self) -> str:
    """Path to the daemon PID file inside the orchestrator root directory.

    Returns:
        Path to the daemon PID file.
    """
    root = self.root_directory
    return os.path.join(root, "kubeflow_daemon.pid")
@property
def log_file(self) -> str:
    """Path of the daemon log file inside the orchestrator root directory.

    Returns:
        Path of the daemon log file.
    """
    root = self.root_directory
    return os.path.join(root, "kubeflow_daemon.log")
@property
def _k3d_cluster_name(self) -> str:
    """Name of the K3D cluster derived from the orchestrator UUID.

    Returns:
        The K3D cluster name.
    """
    orchestrator_id = self.uuid
    return self._get_k3d_cluster_name(orchestrator_id)
def _get_k3d_registry_name(self, port: int) -> str:
    """Builds the K3D registry name for a given port.

    Args:
        port: Port of the registry.

    Returns:
        The registry name.
    """
    registry_host = "k3d-zenml-kubeflow-registry.localhost"
    return f"{registry_host}:{port}"
@property
def _k3d_registry_config_path(self) -> str:
    """Path to the K3D registry config yaml in the root directory.

    Returns:
        str: Path to the K3D registry config yaml.
    """
    root = self.root_directory
    return os.path.join(root, "k3d_registry.yaml")
def _get_kfp_ui_daemon_port(self) -> int:
    """Picks the port to use for the KFP UI daemon.

    Returns:
        The configured port, or a random open port if the configured port
        is the default one and that port is not available.
    """
    configured_port = self.kubeflow_pipelines_ui_port
    if configured_port != DEFAULT_KFP_UI_PORT:
        # an explicitly chosen port is always respected
        return configured_port

    if networking_utils.port_available(configured_port):
        return configured_port

    # if the user didn't specify a specific port and the default
    # port is occupied, fallback to a random open port
    return networking_utils.find_available_port()
def list_manual_setup_steps(
    self, container_registry_name: str, container_registry_path: str
) -> None:
    """Logs manual steps needed to setup the Kubeflow local orchestrator.

    For non-local orchestrators only a warning is logged, because manually
    deploying Kubeflow Pipelines is only supported for local ones.

    Args:
        container_registry_name: Name of the container registry.
        container_registry_path: Path to the container registry.
    """
    if not self.is_local:
        # Make sure we're not telling users to deploy Kubeflow on their
        # remote clusters
        logger.warning(
            "This Kubeflow orchestrator is configured to use a non-local "
            f"Kubernetes context {self.kubernetes_context}. Manually "
            f"deploying Kubeflow Pipelines is only possible for local "
            f"Kubeflow orchestrators."
        )
        return

    global_config_dir_path = io_utils.get_global_config_directory()
    # The kustomize URLs are wrapped in double quotes: they contain `?` and
    # `&`, which a shell would otherwise interpret as a glob character and
    # as the background operator when the commands are pasted into it.
    kubeflow_commands = [
        f"> k3d cluster create {self._k3d_cluster_name} --image {local_deployment_utils.K3S_IMAGE_NAME} --registry-create {container_registry_name} --registry-config {container_registry_path} --volume {global_config_dir_path}:{global_config_dir_path}\n",
        f'> kubectl --context {self.kubernetes_context} apply -k "github.com/kubeflow/pipelines/manifests/kustomize/cluster-scoped-resources?ref={KFP_VERSION}&timeout=5m"',
        f"> kubectl --context {self.kubernetes_context} wait --timeout=60s --for condition=established crd/applications.app.k8s.io",
        f'> kubectl --context {self.kubernetes_context} apply -k "github.com/kubeflow/pipelines/manifests/kustomize/env/platform-agnostic-pns?ref={KFP_VERSION}&timeout=5m"',
        f"> kubectl --context {self.kubernetes_context} --namespace kubeflow port-forward svc/ml-pipeline-ui {self.kubeflow_pipelines_ui_port}:80",
    ]

    logger.info(
        "If you wish to spin up this Kubeflow local orchestrator manually, "
        "please enter the following commands:\n"
    )
    logger.info("\n".join(kubeflow_commands))
@property
def is_provisioned(self) -> bool:
    """Returns if a local k3d cluster for this orchestrator exists.

    Returns:
        True if a local k3d cluster exists, False otherwise.
    """
    skip_k3d = self.skip_cluster_provisioning or not self.is_local
    skip_kubectl = (
        self.skip_cluster_provisioning and self.skip_ui_daemon_provisioning
    )
    prerequisites_met = local_deployment_utils.check_prerequisites(
        skip_k3d=skip_k3d,
        skip_kubectl=skip_kubectl,
    )
    if not prerequisites_met:
        # if any prerequisites are missing there is certainly no
        # local deployment running
        return False

    return self.is_cluster_provisioned
@property
def is_running(self) -> bool:
    """Checks if the local k3d cluster and UI daemon are both running.

    Returns:
        True if the local k3d cluster and UI daemon for this orchestrator
        are both running.
    """
    if not self.is_provisioned:
        return False
    if not self.is_cluster_running:
        return False
    return self.is_daemon_running
@property
def is_suspended(self) -> bool:
    """Checks if the local k3d cluster and UI daemon are both stopped.

    A part whose provisioning is skipped counts as stopped.

    Returns:
        True if the cluster and daemon for this orchestrator are both
        stopped, False otherwise.
    """
    if not self.is_provisioned:
        return False

    cluster_stopped = (
        self.skip_cluster_provisioning or not self.is_cluster_running
    )
    if not cluster_stopped:
        return False

    return self.skip_ui_daemon_provisioning or not self.is_daemon_running
@property
def is_cluster_provisioned(self) -> bool:
    """Returns if the local k3d cluster for this orchestrator is provisioned.

    For remote (i.e. not managed by ZenML) Kubeflow Pipelines installations,
    or when cluster provisioning is skipped, this always returns True.

    Returns:
        True if the local k3d cluster is provisioned, False otherwise.
    """
    managed_by_zenml = not self.skip_cluster_provisioning and self.is_local
    if not managed_by_zenml:
        return True

    cluster_name = self._k3d_cluster_name
    return local_deployment_utils.k3d_cluster_exists(cluster_name=cluster_name)
@property
def is_cluster_running(self) -> bool:
    """Returns if the local k3d cluster for this orchestrator is running.

    For remote (i.e. not managed by ZenML) Kubeflow Pipelines installations,
    or when cluster provisioning is skipped, this always returns True.

    Returns:
        True if the local k3d cluster is running, False otherwise.
    """
    managed_by_zenml = not self.skip_cluster_provisioning and self.is_local
    if not managed_by_zenml:
        return True

    cluster_name = self._k3d_cluster_name
    return local_deployment_utils.k3d_cluster_running(cluster_name=cluster_name)
@property
def is_daemon_running(self) -> bool:
    """Returns if the local Kubeflow UI daemon for this orchestrator is running.

    When the UI daemon provisioning is skipped, and on Windows, this
    always returns True.

    Returns:
        True if the daemon is running, False otherwise.
    """
    if self.skip_ui_daemon_provisioning:
        return True

    if sys.platform == "win32":
        return True
    else:
        # imported here so the daemon module is only loaded off Windows
        from zenml.utils.daemon import check_if_daemon_is_running

        return check_if_daemon_is_running(self._pid_file_path)
def provision(self) -> None:
    """Provisions a local Kubeflow Pipelines deployment.

    Nothing is provisioned when cluster provisioning is skipped, when a
    deployment is already running, or when the orchestrator is not local.
    If spinning up the local deployment fails, the error is logged, the
    manual setup steps are listed and the deployment is deprovisioned.

    Raises:
        ProvisioningError: If the prerequisites for provisioning are
            missing.
    """
    if self.skip_cluster_provisioning:
        return

    if self.is_running:
        logger.info(
            "Found already existing local Kubeflow Pipelines deployment. "
            "If there are any issues with the existing deployment, please "
            "run 'zenml stack down --force' to delete it."
        )
        return

    prerequisites_met = local_deployment_utils.check_prerequisites()
    if not prerequisites_met:
        raise ProvisioningError(
            "Unable to provision local Kubeflow Pipelines deployment: "
            "Please install 'k3d' and 'kubectl' and try again."
        )

    container_registry = Repository().active_stack.container_registry

    # should not happen, because the stack validation takes care of this,
    # but just in case
    assert container_registry is not None

    fileio.makedirs(self.root_directory)

    if not self.is_local:
        # don't provision any resources if using a remote KFP installation
        return

    logger.info("Provisioning local Kubeflow Pipelines deployment...")

    # the port is whatever follows the last ":" of the registry URI
    uri_parts = container_registry.uri.split(":")
    registry_port = int(uri_parts[-1])
    container_registry_name = self._get_k3d_registry_name(port=registry_port)
    local_deployment_utils.write_local_registry_yaml(
        yaml_path=self._k3d_registry_config_path,
        registry_name=container_registry_name,
        registry_uri=container_registry.uri,
    )

    try:
        local_deployment_utils.create_k3d_cluster(
            cluster_name=self._k3d_cluster_name,
            registry_name=container_registry_name,
            registry_config_path=self._k3d_registry_config_path,
        )
        context = self.kubernetes_context

        # will never happen, but mypy doesn't know that
        assert context is not None

        local_deployment_utils.deploy_kubeflow_pipelines(
            kubernetes_context=context
        )

        artifact_store = Repository().active_stack.artifact_store
        if isinstance(artifact_store, LocalArtifactStore):
            local_deployment_utils.add_hostpath_to_kubeflow_pipelines(
                kubernetes_context=context,
                local_path=artifact_store.path,
            )
    except Exception as e:
        logger.error(e)
        logger.error(
            "Unable to spin up local Kubeflow Pipelines deployment."
        )

        self.list_manual_setup_steps(
            container_registry_name, self._k3d_registry_config_path
        )
        self.deprovision()
def deprovision(self) -> None:
    """Deprovisions a local Kubeflow Pipelines deployment.

    Stops the UI daemon if it is managed and running, deletes the local
    k3d cluster for local orchestrators and removes the daemon log file.
    Does nothing when cluster provisioning is skipped.
    """
    if self.skip_cluster_provisioning:
        return

    daemon_is_managed = not self.skip_ui_daemon_provisioning
    if daemon_is_managed and self.is_daemon_running:
        local_deployment_utils.stop_kfp_ui_daemon(
            pid_file_path=self._pid_file_path
        )

    if self.is_local:
        # don't deprovision any resources if using a remote KFP installation
        cluster_name = self._k3d_cluster_name
        local_deployment_utils.delete_k3d_cluster(cluster_name=cluster_name)

        logger.info("Local kubeflow pipelines deployment deprovisioned.")

    log_path_exists = fileio.exists(self.log_file)
    if log_path_exists:
        fileio.remove(self.log_file)
def resume(self) -> None:
    """Resumes the local k3d cluster and the KFP UI daemon.

    Raises:
        ProvisioningError: If the k3d cluster is not provisioned.
    """
    if self.is_running:
        logger.info("Local kubeflow pipelines deployment already running.")
        return

    if not self.is_provisioned:
        raise ProvisioningError(
            "Unable to resume local kubeflow pipelines deployment: No "
            "resources provisioned for local deployment."
        )

    context = self.kubernetes_context

    # will never happen, but mypy doesn't know that
    assert context is not None

    # don't resume any resources if using a remote KFP installation
    cluster_is_managed = not self.skip_cluster_provisioning and self.is_local
    if cluster_is_managed and not self.is_cluster_running:
        local_deployment_utils.start_k3d_cluster(
            cluster_name=self._k3d_cluster_name
        )

        local_deployment_utils.wait_until_kubeflow_pipelines_ready(
            kubernetes_context=context
        )

    if self.is_daemon_running:
        return

    local_deployment_utils.start_kfp_ui_daemon(
        pid_file_path=self._pid_file_path,
        log_file_path=self.log_file,
        port=self._get_kfp_ui_daemon_port(),
        kubernetes_context=context,
    )
def suspend(self) -> None:
    """Suspends the local k3d cluster and stops the KFP UI daemon."""
    if not self.is_provisioned:
        logger.info("Local kubeflow pipelines deployment not provisioned.")
        return

    daemon_is_managed = not self.skip_ui_daemon_provisioning
    if daemon_is_managed and self.is_daemon_running:
        local_deployment_utils.stop_kfp_ui_daemon(
            pid_file_path=self._pid_file_path
        )

    # don't suspend any resources if using a remote KFP installation
    cluster_is_managed = not self.skip_cluster_provisioning and self.is_local
    if cluster_is_managed and self.is_cluster_running:
        local_deployment_utils.stop_k3d_cluster(
            cluster_name=self._k3d_cluster_name
        )
def _get_environment_vars_from_secrets(
    self, secrets: List[str]
) -> Dict[str, str]:
    """Collects key-value pairs from the secrets provided by the user.

    Args:
        secrets: List of secrets provided by the user.

    Returns:
        A dictionary of key-value pairs, merged over all given secrets.
        Empty if no secrets were provided.

    Raises:
        ProvisioningError: If secrets were provided but the stack has no
            secrets manager.
    """
    environment_vars: Dict[str, str] = {}
    secret_manager = Repository().active_stack.secrets_manager

    if not secrets:
        # No secrets provided by the user.
        return environment_vars

    if not secret_manager:
        raise ProvisioningError(
            "Unable to provision local Kubeflow Pipelines deployment: "
            f"You passed in the following secrets: { ', '.join(secrets) }, "
            "however, no secrets manager is registered for the current "
            "stack."
        )

    for secret_name in secrets:
        secret_schema = secret_manager.get_secret(secret_name)
        environment_vars.update(secret_schema.content)

    return environment_vars
is_cluster_provisioned: bool
property
readonly
Returns if the local k3d cluster for this orchestrator is provisioned.
For remote (i.e. not managed by ZenML) Kubeflow Pipelines installations, this always returns True.
Returns:
Type | Description |
---|---|
bool |
True if the local k3d cluster is provisioned, False otherwise. |
is_cluster_running: bool
property
readonly
Returns if the local k3d cluster for this orchestrator is running.
For remote (i.e. not managed by ZenML) Kubeflow Pipelines installations, this always returns True.
Returns:
Type | Description |
---|---|
bool |
True if the local k3d cluster is running, False otherwise. |
is_daemon_running: bool
property
readonly
Returns if the local Kubeflow UI daemon for this orchestrator is running.
Returns:
Type | Description |
---|---|
bool |
True if the daemon is running, False otherwise. |
is_local: bool
property
readonly
Checks if the KFP orchestrator is running locally.
Returns:
Type | Description |
---|---|
bool |
True if the KFP orchestrator is running locally, False otherwise. |
is_provisioned: bool
property
readonly
Returns if a local k3d cluster for this orchestrator exists.
Returns:
Type | Description |
---|---|
bool |
True if a local k3d cluster exists, False otherwise. |
is_running: bool
property
readonly
Checks if the local k3d cluster and UI daemon are both running.
Returns:
Type | Description |
---|---|
bool |
True if the local k3d cluster and UI daemon for this orchestrator are both running. |
is_suspended: bool
property
readonly
Checks if the local k3d cluster and UI daemon are both stopped.
Returns:
Type | Description |
---|---|
bool |
True if the cluster and daemon for this orchestrator are both stopped, False otherwise. |
log_file: str
property
readonly
Path of the daemon log file.
Returns:
Type | Description |
---|---|
str |
Path of the daemon log file. |
pipeline_directory: str
property
readonly
Returns path to a directory in which the kubeflow pipeline files are stored.
Returns:
Type | Description |
---|---|
str |
Path to the pipeline directory. |
root_directory: str
property
readonly
Returns path to the root directory for all files concerning this orchestrator.
Returns:
Type | Description |
---|---|
str |
Path to the root directory. |
validator: Optional[zenml.stack.stack_validator.StackValidator]
property
readonly
Validates that the stack contains a container registry.
Also check that requirements are met for local components.
Returns:
Type | Description |
---|---|
Optional[zenml.stack.stack_validator.StackValidator] |
A StackValidator instance. |
deprovision(self)
Deprovisions a local Kubeflow Pipelines deployment.
Source code in zenml/integrations/kubeflow/orchestrators/kubeflow_orchestrator.py
def deprovision(self) -> None:
    """Tears down the local Kubeflow Pipelines deployment.

    Does nothing when cluster provisioning is skipped. Otherwise the UI
    daemon is stopped (if managed and running), the local k3d cluster is
    deleted, and the daemon log file is removed if it exists.
    """
    if self.skip_cluster_provisioning:
        return

    daemon_is_managed = not self.skip_ui_daemon_provisioning
    if daemon_is_managed and self.is_daemon_running:
        local_deployment_utils.stop_kfp_ui_daemon(
            pid_file_path=self._pid_file_path
        )

    # don't deprovision any resources if using a remote KFP installation
    if self.is_local:
        local_deployment_utils.delete_k3d_cluster(
            cluster_name=self._k3d_cluster_name
        )
        logger.info("Local kubeflow pipelines deployment deprovisioned.")

    # Clean up the daemon log left behind by a previous run.
    if fileio.exists(self.log_file):
        fileio.remove(self.log_file)
get_docker_image_name(self, pipeline_name)
Returns the full docker image name including registry and tag.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
pipeline_name |
str |
The name of the pipeline. |
required |
Returns:
Type | Description |
---|---|
str |
The full docker image name including registry and tag. |
Source code in zenml/integrations/kubeflow/orchestrators/kubeflow_orchestrator.py
def get_docker_image_name(self, pipeline_name: str) -> str:
    """Builds the docker image name used for a pipeline.

    The image is named `zenml-kubeflow:<pipeline_name>` and is prefixed with
    the URI of the active stack's container registry when one is configured.

    Args:
        pipeline_name: The name of the pipeline.

    Returns:
        The full docker image name including registry and tag.
    """
    image = f"zenml-kubeflow:{pipeline_name}"
    registry = Repository().active_stack.container_registry
    if not registry:
        return image

    # Drop trailing slashes so the joined name has exactly one separator.
    uri_without_trailing_slash = registry.uri.rstrip("/")
    return f"{uri_without_trailing_slash}/{image}"
get_kubernetes_contexts(self)
Get the list of configured Kubernetes contexts and the active context.
Returns:
Type | Description |
---|---|
Tuple[List[str], Optional[str]] |
A tuple containing the list of configured Kubernetes contexts and the active context. |
Source code in zenml/integrations/kubeflow/orchestrators/kubeflow_orchestrator.py
def get_kubernetes_contexts(self) -> Tuple[List[str], Optional[str]]:
    """Lists the configured Kubernetes contexts and the active one.

    Returns:
        A tuple of the configured context names and the name of the active
        context, or `([], None)` when loading the kube config raises a
        `ConfigException`.
    """
    try:
        all_contexts, current = k8s_config.list_kube_config_contexts()
    except k8s_config.config_exception.ConfigException:
        return [], None
    else:
        names = [entry["name"] for entry in all_contexts]
        return names, current["name"]
list_manual_setup_steps(self, container_registry_name, container_registry_path)
Logs manual steps needed to setup the Kubeflow local orchestrator.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
container_registry_name |
str |
Name of the container registry. |
required |
container_registry_path |
str |
Path to the container registry. |
required |
Source code in zenml/integrations/kubeflow/orchestrators/kubeflow_orchestrator.py
def list_manual_setup_steps(
    self, container_registry_name: str, container_registry_path: str
) -> None:
    """Logs manual steps needed to setup the Kubeflow local orchestrator.

    For non-local orchestrators only a warning is logged, because manual
    deployment is only supported for locally managed clusters.

    Args:
        container_registry_name: Name of the container registry.
        container_registry_path: Path to the container registry.
    """
    if not self.is_local:
        # Make sure we're not telling users to deploy Kubeflow on their
        # remote clusters
        logger.warning(
            "This Kubeflow orchestrator is configured to use a non-local "
            f"Kubernetes context {self.kubernetes_context}. Manually "
            "deploying Kubeflow Pipelines is only possible for local "
            "Kubeflow orchestrators."
        )
        return

    global_config_dir_path = io_utils.get_global_config_directory()
    # These commands are meant to be copy-pasted into a shell. The kustomize
    # URLs contain `?` and `&`, which a shell would treat as a glob character
    # and the "run in background" operator, so they must be quoted.
    kubeflow_commands = [
        f"> k3d cluster create {self._k3d_cluster_name} --image {local_deployment_utils.K3S_IMAGE_NAME} --registry-create {container_registry_name} --registry-config {container_registry_path} --volume {global_config_dir_path}:{global_config_dir_path}\n",
        f'> kubectl --context {self.kubernetes_context} apply -k "github.com/kubeflow/pipelines/manifests/kustomize/cluster-scoped-resources?ref={KFP_VERSION}&timeout=5m"',
        f"> kubectl --context {self.kubernetes_context} wait --timeout=60s --for condition=established crd/applications.app.k8s.io",
        f'> kubectl --context {self.kubernetes_context} apply -k "github.com/kubeflow/pipelines/manifests/kustomize/env/platform-agnostic-pns?ref={KFP_VERSION}&timeout=5m"',
        f"> kubectl --context {self.kubernetes_context} --namespace kubeflow port-forward svc/ml-pipeline-ui {self.kubeflow_pipelines_ui_port}:80",
    ]

    logger.info(
        "If you wish to spin up this Kubeflow local orchestrator manually, "
        "please enter the following commands:\n"
    )
    logger.info("\n".join(kubeflow_commands))
prepare_or_run_pipeline(self, sorted_steps, pipeline, pb2_pipeline, stack, runtime_configuration)
Creates a kfp yaml file.
This functions as an intermediary representation of the pipeline which is then deployed to the kubeflow pipelines instance.
How it works:
Before this method is called the prepare_pipeline_deployment()
method builds a docker image that contains the code for the
pipeline, all steps, and the context around these files.
Based on this docker image a callable is created which builds
container_ops for each step (_construct_kfp_pipeline
).
To do this the entrypoint of the docker image is configured to
run the correct step within the docker image. The dependencies
between these container_ops are then also configured onto each
container_op by pointing at the downstream steps.
This callable is then compiled into a kfp yaml file that is used as the intermediary representation of the kubeflow pipeline.
This file, together with some metadata and runtime configurations, is then uploaded into the kubeflow pipelines cluster for execution.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
sorted_steps |
List[BaseStep] |
A list of steps sorted by their order in the pipeline. |
required |
pipeline |
BasePipeline |
The pipeline object. |
required |
pb2_pipeline |
Pipeline |
The pipeline object in protobuf format. |
required |
stack |
Stack |
The stack object. |
required |
runtime_configuration |
RuntimeConfiguration |
The runtime configuration object. |
required |
Exceptions:
Type | Description |
---|---|
RuntimeError |
If you try to run the pipelines in a notebook environment. |
Source code in zenml/integrations/kubeflow/orchestrators/kubeflow_orchestrator.py
def prepare_or_run_pipeline(
    self,
    sorted_steps: List["BaseStep"],
    pipeline: "BasePipeline",
    pb2_pipeline: Pb2Pipeline,
    stack: "Stack",
    runtime_configuration: "RuntimeConfiguration",
) -> Any:
    """Creates a kfp yaml file and submits it to Kubeflow Pipelines.

    The yaml file functions as an intermediary representation of the
    pipeline which is then deployed to the kubeflow pipelines instance.

    How it works:
    -------------
    Before this method is called the `prepare_pipeline_deployment()`
    method builds a docker image that contains the code for the
    pipeline, all steps, and the context around these files.

    Based on this docker image a callable is created which builds
    container_ops for each step (`_construct_kfp_pipeline`).
    To do this the entrypoint of the docker image is configured to
    run the correct step within the docker image. The dependencies
    between these container_ops are then also configured onto each
    container_op by pointing at the downstream steps.

    This callable is then compiled into a kfp yaml file that is used as
    the intermediary representation of the kubeflow pipeline.

    This file, together with some metadata and runtime configurations, is
    then uploaded into the kubeflow pipelines cluster for execution.

    Args:
        sorted_steps: A list of steps sorted by their order in the
            pipeline.
        pipeline: The pipeline object.
        pb2_pipeline: The pipeline object in protobuf format.
        stack: The stack object. Not referenced directly in this method.
        runtime_configuration: The runtime configuration object. Its
            `run_name` must be set, as it names the generated yaml file.

    Raises:
        RuntimeError: If you try to run the pipelines in a notebook
            environment.
    """
    # First check whether the code is running in a notebook
    if Environment.in_notebook():
        raise RuntimeError(
            "The Kubeflow orchestrator cannot run pipelines in a notebook "
            "environment. The reason is that it is non-trivial to create "
            "a Docker image of a notebook. Please consider refactoring "
            "your notebook cells into separate scripts in a Python module "
            "and run the code outside of a notebook when using this "
            "orchestrator."
        )

    # Prefer the image digest when one can be resolved; otherwise fall back
    # to the plain image name.
    image_name = self.get_docker_image_name(pipeline.name)
    image_name = get_image_digest(image_name) or image_name

    # Create a callable for future compilation into a dsl.Pipeline.
    def _construct_kfp_pipeline() -> None:
        """Create a container_op for each step.

        This should contain the name of the docker image and configures the
        entrypoint of the docker image to run the step.

        Additionally, this gives each container_op information about its
        direct downstream steps.

        If this callable is passed to the `_create_and_write_workflow()`
        method of a KFPCompiler all dsl.ContainerOp instances will be
        automatically added to a singular dsl.Pipeline instance.
        """
        # Dictionary of container_ops index by the associated step name
        step_name_to_container_op: Dict[str, dsl.ContainerOp] = {}

        for step in sorted_steps:
            # The command will be needed to eventually call the python step
            # within the docker container
            command = (
                KubeflowEntrypointConfiguration.get_entrypoint_command()
            )

            # The arguments are passed to configure the entrypoint of the
            # docker container when the step is called.
            metadata_ui_path = "/outputs/mlpipeline-ui-metadata.json"
            arguments = (
                KubeflowEntrypointConfiguration.get_entrypoint_arguments(
                    step=step,
                    pb2_pipeline=pb2_pipeline,
                    **{METADATA_UI_PATH_OPTION: metadata_ui_path},
                )
            )

            # Create a container_op - the kubeflow equivalent of a step. It
            # contains the name of the step, the name of the docker image,
            # the command to use to run the step entrypoint
            # (e.g. `python -m zenml.entrypoints.step_entrypoint`)
            # and the arguments to be passed along with the command. Find
            # out more about how these arguments are parsed and used
            # in the base entrypoint `run()` method.
            container_op = dsl.ContainerOp(
                name=step.name,
                image=image_name,
                command=command,
                arguments=arguments,
                output_artifact_paths={
                    "mlpipeline-ui-metadata": metadata_ui_path,
                },
            )

            # Mounts persistent volumes, configmaps and adds labels to the
            # container op
            self._configure_container_op(container_op=container_op)

            # Only apply resource settings for steps that request them.
            if self.requires_resources_in_orchestration_environment(step):
                self._configure_container_resources(
                    container_op=container_op,
                    resource_configuration=step.resource_configuration,
                )

            # Find the upstream container ops of the current step and
            # configure the current container op to run after them.
            # The lookup below relies on every upstream step having already
            # been registered, i.e. on `sorted_steps` being ordered so that
            # upstream steps come first.
            upstream_step_names = self.get_upstream_step_names(
                step=step, pb2_pipeline=pb2_pipeline
            )
            for upstream_step_name in upstream_step_names:
                upstream_container_op = step_name_to_container_op[
                    upstream_step_name
                ]
                container_op.after(upstream_container_op)

            # Update dictionary of container ops with the current one
            step_name_to_container_op[step.name] = container_op

    # Get a filepath to use to save the finished yaml to
    assert runtime_configuration.run_name
    fileio.makedirs(self.pipeline_directory)
    pipeline_file_path = os.path.join(
        self.pipeline_directory, f"{runtime_configuration.run_name}.yaml"
    )

    # write the argo pipeline yaml
    KFPCompiler()._create_and_write_workflow(
        pipeline_func=_construct_kfp_pipeline,
        pipeline_name=pipeline.name,
        package_path=pipeline_file_path,
    )

    # using the kfp client uploads the pipeline to kubeflow pipelines and
    # runs it there
    self._upload_and_run_pipeline(
        pipeline_name=pipeline.name,
        pipeline_file_path=pipeline_file_path,
        runtime_configuration=runtime_configuration,
        enable_cache=pipeline.enable_cache,
    )
prepare_pipeline_deployment(self, pipeline, stack, runtime_configuration)
Builds a docker image for the current environment.
This function also uploads it to a container registry if configured.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
pipeline |
BasePipeline |
The pipeline to be deployed. |
required |
stack |
Stack |
The stack to be deployed. |
required |
runtime_configuration |
RuntimeConfiguration |
The runtime configuration to be used. |
required |
Source code in zenml/integrations/kubeflow/orchestrators/kubeflow_orchestrator.py
def prepare_pipeline_deployment(
    self,
    pipeline: "BasePipeline",
    stack: "Stack",
    runtime_configuration: "RuntimeConfiguration",
) -> None:
    """Builds and pushes the docker image for the pipeline.

    The image is built from the source root with the combined stack and
    pipeline requirements, pushed to the stack's container registry, and its
    digest (or name, if no digest is available) is recorded in the runtime
    configuration under `docker_image`.

    Args:
        pipeline: The pipeline to be deployed.
        stack: The stack to be deployed.
        runtime_configuration: The runtime configuration to be used.
    """
    from zenml.utils import docker_utils

    image_name = self.get_docker_image_name(pipeline.name)

    # De-duplicated union of what the stack and the pipeline require.
    requirements = set(stack.requirements()).union(pipeline.requirements)
    logger.debug("Kubeflow docker container requirements: %s", requirements)

    build_options = dict(
        build_context_path=get_source_root_path(),
        image_name=image_name,
        dockerignore_path=pipeline.dockerignore_file,
        requirements=requirements,
        base_image=self.custom_docker_base_image_name,
        environment_vars=self._get_environment_vars_from_secrets(
            pipeline.secrets
        ),
    )
    docker_utils.build_docker_image(**build_options)

    assert stack.container_registry  # should never happen due to validation
    stack.container_registry.push_image(image_name)

    # Store the docker image digest in the runtime configuration so it gets
    # tracked in the ZenStore
    runtime_configuration["docker_image"] = (
        docker_utils.get_image_digest(image_name) or image_name
    )
provision(self)
Provisions a local Kubeflow Pipelines deployment.
Exceptions:
Type | Description |
---|---|
ProvisioningError |
If the provisioning fails. |
Source code in zenml/integrations/kubeflow/orchestrators/kubeflow_orchestrator.py
def provision(self) -> None:
    """Provisions a local Kubeflow Pipelines deployment.

    Returns early without provisioning anything when cluster provisioning
    is skipped, when the deployment is already running, or when the
    orchestrator is not local. Failures while creating the cluster or
    deploying Kubeflow Pipelines are logged, the manual setup steps are
    printed, and the partial deployment is deprovisioned; they are not
    re-raised.

    Raises:
        ProvisioningError: If the prerequisites (`k3d` and `kubectl`) are
            not installed.
    """
    if self.skip_cluster_provisioning:
        return

    if self.is_running:
        logger.info(
            "Found already existing local Kubeflow Pipelines deployment. "
            "If there are any issues with the existing deployment, please "
            "run 'zenml stack down --force' to delete it."
        )
        return

    if not local_deployment_utils.check_prerequisites():
        raise ProvisioningError(
            "Unable to provision local Kubeflow Pipelines deployment: "
            "Please install 'k3d' and 'kubectl' and try again."
        )

    container_registry = Repository().active_stack.container_registry

    # should not happen, because the stack validation takes care of this,
    # but just in case
    assert container_registry is not None

    fileio.makedirs(self.root_directory)

    if not self.is_local:
        # don't provision any resources if using a remote KFP installation
        return

    logger.info("Provisioning local Kubeflow Pipelines deployment...")

    # The port is taken from the text after the last ":" of the registry
    # URI; `int()` raises a ValueError if that part is not numeric. This
    # happens outside the try block below, so such an error propagates.
    container_registry_port = int(container_registry.uri.split(":")[-1])
    container_registry_name = self._get_k3d_registry_name(
        port=container_registry_port
    )
    local_deployment_utils.write_local_registry_yaml(
        yaml_path=self._k3d_registry_config_path,
        registry_name=container_registry_name,
        registry_uri=container_registry.uri,
    )

    try:
        local_deployment_utils.create_k3d_cluster(
            cluster_name=self._k3d_cluster_name,
            registry_name=container_registry_name,
            registry_config_path=self._k3d_registry_config_path,
        )
        kubernetes_context = self.kubernetes_context

        # will never happen, but mypy doesn't know that
        assert kubernetes_context is not None

        local_deployment_utils.deploy_kubeflow_pipelines(
            kubernetes_context=kubernetes_context
        )

        # A local artifact store is additionally mounted into the Kubeflow
        # Pipelines deployment as a host path.
        artifact_store = Repository().active_stack.artifact_store
        if isinstance(artifact_store, LocalArtifactStore):
            local_deployment_utils.add_hostpath_to_kubeflow_pipelines(
                kubernetes_context=kubernetes_context,
                local_path=artifact_store.path,
            )
    except Exception as e:
        # Best-effort cleanup: log the failure, tell the user how to set
        # things up manually, then remove whatever was partially created.
        logger.error(e)
        logger.error(
            "Unable to spin up local Kubeflow Pipelines deployment."
        )

        self.list_manual_setup_steps(
            container_registry_name, self._k3d_registry_config_path
        )
        self.deprovision()
resume(self)
Resumes the local k3d cluster.
Exceptions:
Type | Description |
---|---|
ProvisioningError |
If the k3d cluster is not provisioned. |
Source code in zenml/integrations/kubeflow/orchestrators/kubeflow_orchestrator.py
def resume(self) -> None:
    """Resumes the local k3d cluster and the Kubeflow Pipelines UI daemon.

    The cluster is only started when it is locally managed, its
    provisioning is not skipped, and it is not already running. The UI
    daemon is only started when its provisioning is not skipped and it is
    not already running, mirroring the conditions used by `suspend()` and
    `deprovision()`.

    Raises:
        ProvisioningError: If the k3d cluster is not provisioned.
    """
    if self.is_running:
        logger.info("Local kubeflow pipelines deployment already running.")
        return

    if not self.is_provisioned:
        raise ProvisioningError(
            "Unable to resume local kubeflow pipelines deployment: No "
            "resources provisioned for local deployment."
        )

    kubernetes_context = self.kubernetes_context

    # will never happen, but mypy doesn't know that
    assert kubernetes_context is not None

    if (
        not self.skip_cluster_provisioning
        and self.is_local
        and not self.is_cluster_running
    ):
        # don't resume any resources if using a remote KFP installation
        local_deployment_utils.start_k3d_cluster(
            cluster_name=self._k3d_cluster_name
        )

        local_deployment_utils.wait_until_kubeflow_pipelines_ready(
            kubernetes_context=kubernetes_context
        )

    # Respect `skip_ui_daemon_provisioning`: a daemon the user opted out of
    # must not be started here, just as it is not stopped on suspend.
    if not self.skip_ui_daemon_provisioning and not self.is_daemon_running:
        local_deployment_utils.start_kfp_ui_daemon(
            pid_file_path=self._pid_file_path,
            log_file_path=self.log_file,
            port=self._get_kfp_ui_daemon_port(),
            kubernetes_context=kubernetes_context,
        )
set_default_kubernetes_context(values)
classmethod
Pydantic root_validator.
This sets the default kubernetes_context
value to the value that is
used to create the locally managed k3d cluster, if not explicitly set.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
values |
Dict[str, Any] |
Values passed to the object constructor |
required |
Returns:
Type | Description |
---|---|
Dict[str, Any] |
Values passed to the Pydantic constructor |
Source code in zenml/integrations/kubeflow/orchestrators/kubeflow_orchestrator.py
@root_validator(skip_on_failure=True)
def set_default_kubernetes_context(
    cls, values: Dict[str, Any]
) -> Dict[str, Any]:
    """Pydantic root_validator that fills in a default Kubernetes context.

    When `kubernetes_context` is missing or empty, it is set to the context
    derived from the orchestrator's `uuid`, i.e. the one used for the
    locally managed k3d cluster.

    Args:
        values: Values passed to the object constructor

    Returns:
        Values passed to the Pydantic constructor
    """
    if values.get("kubernetes_context"):
        # An explicit context was configured; leave it untouched.
        return values

    # not likely, due to Pydantic validation, but mypy complains
    assert "uuid" in values
    default_context = cls._get_k3d_kubernetes_context(values["uuid"])
    values["kubernetes_context"] = default_context
    return values
suspend(self)
Suspends the local k3d cluster.
Source code in zenml/integrations/kubeflow/orchestrators/kubeflow_orchestrator.py
def suspend(self) -> None:
    """Stops the UI daemon and the local k3d cluster of this orchestrator.

    Nothing is stopped when no local deployment is provisioned. The daemon
    is only stopped when its provisioning is not skipped, and the cluster is
    only stopped for locally managed, currently running clusters.
    """
    if not self.is_provisioned:
        logger.info("Local kubeflow pipelines deployment not provisioned.")
        return

    daemon_is_managed = not self.skip_ui_daemon_provisioning
    if daemon_is_managed and self.is_daemon_running:
        local_deployment_utils.stop_kfp_ui_daemon(
            pid_file_path=self._pid_file_path
        )

    if self.skip_cluster_provisioning:
        return

    # don't suspend any resources if using a remote KFP installation
    if self.is_local and self.is_cluster_running:
        local_deployment_utils.stop_k3d_cluster(
            cluster_name=self._k3d_cluster_name
        )
local_deployment_utils
Utils for the local Kubeflow deployment behaviors.
add_hostpath_to_kubeflow_pipelines(kubernetes_context, local_path)
Patches the Kubeflow Pipelines deployment to mount a local folder.
This folder serves as a hostpath for visualization purposes.
This function reconfigures the Kubeflow pipelines deployment to use a shared local folder to support loading the TensorBoard viewer and other pipeline visualization results from a local artifact store, as described here:
https://github.com/kubeflow/pipelines/blob/master/docs/config/volume-support.md
Parameters:
Name | Type | Description | Default |
---|---|---|---|
kubernetes_context |
str |
The kubernetes context on which Kubeflow Pipelines should be patched. |
required |
local_path |
str |
The path to the local folder to mount as a hostpath. |
required |
Source code in zenml/integrations/kubeflow/orchestrators/local_deployment_utils.py
def add_hostpath_to_kubeflow_pipelines(
    kubernetes_context: str, local_path: str
) -> None:
    """Patches the Kubeflow Pipelines deployment to mount a local folder.

    This folder serves as a hostpath for visualization purposes.

    This function reconfigures the Kubeflow pipelines deployment to use a
    shared local folder to support loading the TensorBoard viewer and other
    pipeline visualization results from a local artifact store, as described
    here:

    https://github.com/kubeflow/pipelines/blob/master/docs/config/volume-support.md

    Two resources in the `kubeflow` namespace are patched: the
    `ml-pipeline-ui-configmap` config map (viewer pod template) and the
    `ml-pipeline-ui` deployment. Afterwards the function waits until
    Kubeflow Pipelines is ready again.

    Args:
        kubernetes_context: The kubernetes context on which Kubeflow Pipelines
            should be patched.
        local_path: The path to the local folder to mount as a hostpath.
    """
    logger.info("Patching Kubeflow Pipelines to mount a local folder.")

    # The same volume and mount definitions are used by both patches.
    volume_name = "local-artifact-store"
    volume_mount = {"mountPath": local_path, "name": volume_name}
    host_volume = {
        "hostPath": {"path": local_path, "type": "Directory"},
        "name": volume_name,
    }
    kubectl_patch = [
        "kubectl",
        "--context",
        kubernetes_context,
        "-n",
        "kubeflow",
        "patch",
    ]

    # Patch 1: the pod template used for viewer pods, stored as a JSON
    # string inside the UI config map.
    pod_template = {
        "spec": {
            "serviceAccountName": "kubeflow-pipelines-viewer",
            "containers": [{"volumeMounts": [volume_mount]}],
            "volumes": [host_volume],
        }
    }
    pod_template_json = json.dumps(pod_template, indent=2)
    config_map_data = {"data": {"viewer-pod-template.json": pod_template_json}}
    config_map_data_json = json.dumps(config_map_data, indent=2)

    logger.debug(
        "Adding host path volume for local path `%s` to kubeflow pipeline "
        "viewer pod template configuration.",
        local_path,
    )
    subprocess.check_call(
        kubectl_patch
        + [
            "configmap/ml-pipeline-ui-configmap",
            "--type",
            "merge",
            "-p",
            config_map_data_json,
        ]
    )

    # Patch 2: the UI deployment itself, so the UI container sees the folder.
    deployment_patch = {
        "spec": {
            "template": {
                "spec": {
                    "containers": [
                        {
                            "name": "ml-pipeline-ui",
                            "volumeMounts": [volume_mount],
                        }
                    ],
                    "volumes": [host_volume],
                }
            }
        }
    }
    deployment_patch_json = json.dumps(deployment_patch, indent=2)

    logger.debug(
        "Adding host path volume for local path `%s` to the kubeflow UI",
        local_path,
    )
    subprocess.check_call(
        kubectl_patch
        + [
            "deployment/ml-pipeline-ui",
            "--type",
            "strategic",
            "-p",
            deployment_patch_json,
        ]
    )
    wait_until_kubeflow_pipelines_ready(kubernetes_context=kubernetes_context)
    logger.info("Finished patching Kubeflow Pipelines setup.")
check_prerequisites(skip_k3d=False, skip_kubectl=False)
Checks prerequisites for a local kubeflow pipelines deployment.
It makes sure they are installed.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
skip_k3d |
bool |
Whether to skip the check for the k3d command. |
False |
skip_kubectl |
bool |
Whether to skip the check for the kubectl command. |
False |
Returns:
Type | Description |
---|---|
bool |
Whether all prerequisites are installed. |
Source code in zenml/integrations/kubeflow/orchestrators/local_deployment_utils.py
def check_prerequisites(
    skip_k3d: bool = False, skip_kubectl: bool = False
) -> bool:
    """Verifies that the tools for a local Kubeflow deployment are installed.

    A tool counts as installed when its check is skipped or when its
    executable can be found via `shutil.which`.

    Args:
        skip_k3d: Whether to skip the check for the k3d command.
        skip_kubectl: Whether to skip the check for the kubectl command.

    Returns:
        Whether all prerequisites are installed.
    """
    available = {
        "k3d": skip_k3d or shutil.which("k3d") is not None,
        "kubectl": skip_kubectl or shutil.which("kubectl") is not None,
    }
    logger.debug(
        "Local kubeflow deployment prerequisites: K3D - %s, Kubectl - %s",
        available["k3d"],
        available["kubectl"],
    )
    return available["k3d"] and available["kubectl"]
create_k3d_cluster(cluster_name, registry_name, registry_config_path)
Creates a K3D cluster.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
cluster_name |
str |
Name of the cluster to create. |
required |
registry_name |
str |
Name of the registry to create for this cluster. |
required |
registry_config_path |
str |
Path to the registry config file. |
required |
Source code in zenml/integrations/kubeflow/orchestrators/local_deployment_utils.py
def create_k3d_cluster(
    cluster_name: str, registry_name: str, registry_config_path: str
) -> None:
    """Creates a K3D cluster via the `k3d` command line tool.

    The global config directory is mounted into the cluster at the same
    path it has on the host.

    Args:
        cluster_name: Name of the cluster to create.
        registry_name: Name of the registry to create for this cluster.
        registry_config_path: Path to the registry config file.
    """
    logger.info("Creating local K3D cluster '%s'.", cluster_name)

    config_dir = io_utils.get_global_config_directory()
    create_command = ["k3d", "cluster", "create", cluster_name]
    create_command += ["--image", K3S_IMAGE_NAME]
    create_command += ["--registry-create", registry_name]
    create_command += ["--registry-config", registry_config_path]
    create_command += ["--volume", f"{config_dir}:{config_dir}"]
    subprocess.check_call(create_command)

    logger.info("Finished K3D cluster creation.")
delete_k3d_cluster(cluster_name)
Deletes a K3D cluster with the given name.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
cluster_name |
str |
Name of the cluster to delete. |
required |
Source code in zenml/integrations/kubeflow/orchestrators/local_deployment_utils.py
def delete_k3d_cluster(cluster_name: str) -> None:
    """Removes the K3D cluster with the given name via the `k3d` CLI.

    Args:
        cluster_name: Name of the cluster to delete.
    """
    delete_command = ["k3d", "cluster", "delete", cluster_name]
    subprocess.check_call(delete_command)
    logger.info("Deleted local k3d cluster '%s'.", cluster_name)
deploy_kubeflow_pipelines(kubernetes_context)
Deploys Kubeflow Pipelines.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
kubernetes_context |
str |
The kubernetes context on which Kubeflow Pipelines should be deployed. |
required |
Source code in zenml/integrations/kubeflow/orchestrators/local_deployment_utils.py
def deploy_kubeflow_pipelines(kubernetes_context: str) -> None:
    """Deploys Kubeflow Pipelines using `kubectl`.

    Three steps run in order: apply the cluster-scoped resources, wait for
    the applications CRD to be established, then apply the
    platform-agnostic-pns environment. Finally the function waits until
    Kubeflow Pipelines is ready.

    Args:
        kubernetes_context: The kubernetes context on which Kubeflow Pipelines
            should be deployed.
    """
    logger.info("Deploying Kubeflow Pipelines.")

    kubectl = ["kubectl", "--context", kubernetes_context]
    deployment_steps = [
        [
            "apply",
            "-k",
            f"github.com/kubeflow/pipelines/manifests/kustomize/cluster-scoped-resources?ref={KFP_VERSION}&timeout=5m",
        ],
        [
            "wait",
            "--timeout=60s",
            "--for",
            "condition=established",
            "crd/applications.app.k8s.io",
        ],
        [
            "apply",
            "-k",
            f"github.com/kubeflow/pipelines/manifests/kustomize/env/platform-agnostic-pns?ref={KFP_VERSION}&timeout=5m",
        ],
    ]
    # Each step must succeed before the next one starts; `check_call`
    # raises on a non-zero exit code.
    for step_arguments in deployment_steps:
        subprocess.check_call(kubectl + step_arguments)

    wait_until_kubeflow_pipelines_ready(kubernetes_context=kubernetes_context)
    logger.info("Finished Kubeflow Pipelines setup.")
k3d_cluster_exists(cluster_name)
Checks whether there exists a K3D cluster with the given name.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
cluster_name |
str |
Name of the cluster to check. |
required |
Returns:
Type | Description |
---|---|
bool |
Whether the cluster exists. |
Source code in zenml/integrations/kubeflow/orchestrators/local_deployment_utils.py
def k3d_cluster_exists(cluster_name: str) -> bool:
    """Checks whether `k3d cluster list` reports a cluster with this name.

    Args:
        cluster_name: Name of the cluster to check.

    Returns:
        Whether the cluster exists.
    """
    listing = subprocess.check_output(
        ["k3d", "cluster", "list", "--output", "json"]
    )
    return any(
        entry["name"] == cluster_name for entry in json.loads(listing)
    )
k3d_cluster_running(cluster_name)
Checks whether the K3D cluster with the given name is running.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
cluster_name |
str |
Name of the cluster to check. |
required |
Returns:
Type | Description |
---|---|
bool |
Whether the cluster is running. |
Source code in zenml/integrations/kubeflow/orchestrators/local_deployment_utils.py
def k3d_cluster_running(cluster_name: str) -> bool:
    """Checks whether all servers of the named K3D cluster are running.

    Args:
        cluster_name: Name of the cluster to check.

    Returns:
        Whether the cluster is running; False if no cluster with the given
        name is listed.
    """
    listing = subprocess.check_output(
        ["k3d", "cluster", "list", "--output", "json"]
    )
    matching = next(
        (
            entry
            for entry in json.loads(listing)
            if entry["name"] == cluster_name
        ),
        None,
    )
    if matching is None:
        return False

    # The cluster counts as running only when every server is up.
    expected_servers: int = matching["serversCount"]
    running_servers: int = matching["serversRunning"]
    return running_servers == expected_servers
kubeflow_pipelines_ready(kubernetes_context)
Returns whether all Kubeflow Pipelines pods are ready.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
kubernetes_context |
str |
The kubernetes context in which the pods should be checked. |
required |
Returns:
Type | Description |
---|---|
bool |
Whether all Kubeflow Pipelines pods are ready. |
Source code in zenml/integrations/kubeflow/orchestrators/local_deployment_utils.py
def kubeflow_pipelines_ready(kubernetes_context: str) -> bool:
    """Reports whether all Kubeflow Pipelines pods are ready.

    Runs `kubectl wait` with a zero timeout on the pods labelled
    `application-crd-id=kubeflow-pipelines` in the `kubeflow` namespace and
    interprets a non-zero exit code as "not ready". Command output is
    discarded.

    Args:
        kubernetes_context: The kubernetes context in which the pods
            should be checked.

    Returns:
        Whether all Kubeflow Pipelines pods are ready.
    """
    readiness_check = [
        "kubectl",
        "--context",
        kubernetes_context,
        "--namespace",
        "kubeflow",
        "wait",
        "--for",
        "condition=ready",
        "--timeout=0s",
        "pods",
        "-l",
        "application-crd-id=kubeflow-pipelines",
    ]
    try:
        subprocess.check_call(
            readiness_check,
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
        )
    except subprocess.CalledProcessError:
        return False
    else:
        return True
start_k3d_cluster(cluster_name)
Starts a K3D cluster with the given name.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
cluster_name |
str |
Name of the cluster to start. |
required |
Source code in zenml/integrations/kubeflow/orchestrators/local_deployment_utils.py
def start_k3d_cluster(cluster_name: str) -> None:
    """Starts the K3D cluster with the given name via the `k3d` CLI.

    Args:
        cluster_name: Name of the cluster to start.
    """
    start_command = ["k3d", "cluster", "start", cluster_name]
    subprocess.check_call(start_command)
    logger.info("Started local k3d cluster '%s'.", cluster_name)
start_kfp_ui_daemon(pid_file_path, log_file_path, port, kubernetes_context)
Starts a daemon process that forwards ports.
This is so the Kubeflow Pipelines UI is accessible in the browser.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
pid_file_path |
str |
Path where the file with the daemons process ID should be written. |
required |
log_file_path |
str |
Path to a file where the daemon logs should be written. |
required |
port |
int |
Port on which the UI should be accessible. |
required |
kubernetes_context |
str |
The kubernetes context for the cluster where Kubeflow Pipelines is running. |
required |
Source code in zenml/integrations/kubeflow/orchestrators/local_deployment_utils.py
def start_kfp_ui_daemon(
    pid_file_path: str,
    log_file_path: str,
    port: int,
    kubernetes_context: str,
) -> None:
    """Launches a background port-forward to the Kubeflow Pipelines UI.

    The daemon forwards a local port to the `ml-pipeline-ui` service so the
    UI can be opened in a local browser. If the port is occupied or the
    platform is Windows, no daemon is started and a warning with the manual
    command is logged instead.

    Args:
        pid_file_path: Path where the file with the daemon's process ID
            should be written.
        log_file_path: Path to a file where the daemon logs should be written.
        port: Port on which the UI should be accessible.
        kubernetes_context: The kubernetes context for the cluster where
            Kubeflow Pipelines is running.
    """
    command = [
        "kubectl",
        "--context",
        kubernetes_context,
        "--namespace",
        "kubeflow",
        "port-forward",
        "svc/ml-pipeline-ui",
        f"{port}:80",
    ]

    if not networking_utils.port_available(port):
        # Show the user the same command with a placeholder instead of the
        # occupied port so they can pick a free one themselves.
        placeholder_command = [*command[:-1], "PORT:80"]
        logger.warning(
            "Unable to port-forward Kubeflow Pipelines UI to local port %d "
            "because the port is occupied. In order to access the Kubeflow "
            "Pipelines UI at http://localhost:PORT/, please run '%s' in a "
            "separate command line shell (replace PORT with a free port of "
            "your choice).",
            port,
            " ".join(placeholder_command),
        )
        return

    if sys.platform == "win32":
        logger.warning(
            "Daemon functionality not supported on Windows. "
            "In order to access the Kubeflow Pipelines UI at "
            "http://localhost:%d/, please run '%s' in a separate command "
            "line shell.",
            port,
            " ".join(command),
        )
        return

    # Imported lazily: this point is only reached on non-Windows platforms.
    from zenml.utils import daemon

    def _daemon_function() -> None:
        """Port-forwards the Kubeflow Pipelines UI pod."""
        subprocess.check_call(command)

    daemon.run_as_daemon(
        _daemon_function, pid_file=pid_file_path, log_file=log_file_path
    )
    logger.info(
        "Started Kubeflow Pipelines UI daemon (check the daemon logs at %s "
        "in case you're not able to view the UI). The Kubeflow Pipelines "
        "UI should now be accessible at http://localhost:%d/.",
        log_file_path,
        port,
    )
stop_k3d_cluster(cluster_name)
Stops a K3D cluster with the given name.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
cluster_name |
str |
Name of the cluster to stop. |
required |
Source code in zenml/integrations/kubeflow/orchestrators/local_deployment_utils.py
def stop_k3d_cluster(cluster_name: str) -> None:
    """Shuts down the local K3D cluster that has the given name.

    Args:
        cluster_name: Name of the K3D cluster that should be stopped.
    """
    stop_command = ["k3d", "cluster", "stop", cluster_name]
    # `check_call` raises if k3d exits with a non-zero code, so the success
    # message below is only logged when the cluster stop succeeded.
    subprocess.check_call(stop_command)
    logger.info("Stopped local k3d cluster '%s'.", cluster_name)
stop_kfp_ui_daemon(pid_file_path)
Stops the KFP UI daemon process if it is running.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
pid_file_path |
str |
Path to the file with the daemon's process ID. |
required |
Source code in zenml/integrations/kubeflow/orchestrators/local_deployment_utils.py
def stop_kfp_ui_daemon(pid_file_path: str) -> None:
    """Terminates the KFP UI port-forwarding daemon if its PID file exists.

    Args:
        pid_file_path: Path to the file with the daemon's process ID.
    """
    if not fileio.exists(pid_file_path):
        # No PID file means there is nothing to stop.
        return

    if sys.platform == "win32":
        # Daemon functionality is not supported on Windows, so the PID file
        # is not expected to exist there. This branch only exists so mypy
        # does not complain about missing functions.
        return

    from zenml.utils import daemon

    daemon.stop_daemon(pid_file_path)
    fileio.remove(pid_file_path)
    logger.info("Stopped Kubeflow Pipelines UI daemon.")
wait_until_kubeflow_pipelines_ready(kubernetes_context)
Waits until all Kubeflow Pipelines pods are ready.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
kubernetes_context |
str |
The kubernetes context in which the pods should be checked. |
required |
Source code in zenml/integrations/kubeflow/orchestrators/local_deployment_utils.py
def wait_until_kubeflow_pipelines_ready(kubernetes_context: str) -> None:
    """Blocks until every Kubeflow Pipelines pod reports as ready.

    The current pod status is printed before each readiness check. There is
    no upper bound on the waiting time.

    Args:
        kubernetes_context: The kubernetes context in which the pods
            should be checked.
    """
    pod_status_command = [
        "kubectl",
        "--context",
        kubernetes_context,
        "--namespace",
        "kubeflow",
        "get",
        "pods",
    ]

    logger.info(
        "Waiting for all Kubeflow Pipelines pods to be ready (this might "
        "take a few minutes)."
    )

    pods_ready = False
    while not pods_ready:
        logger.info("Current pod status:")
        # Output is not redirected, so the pod table is shown to the user.
        subprocess.check_call(pod_status_command)

        pods_ready = kubeflow_pipelines_ready(
            kubernetes_context=kubernetes_context
        )
        if not pods_ready:
            logger.info(
                "One or more pods not ready yet, waiting for 30 seconds..."
            )
            time.sleep(30)
write_local_registry_yaml(yaml_path, registry_name, registry_uri)
Writes a K3D registry config file.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
yaml_path |
str |
Path where the config file should be written to. |
required |
registry_name |
str |
Name of the registry. |
required |
registry_uri |
str |
URI of the registry. |
required |
Source code in zenml/integrations/kubeflow/orchestrators/local_deployment_utils.py
def write_local_registry_yaml(
    yaml_path: str, registry_name: str, registry_uri: str
) -> None:
    """Creates the K3D registry configuration file.

    The written config declares a single mirror for `registry_uri` whose
    only endpoint is `http://<registry_name>`.

    Args:
        yaml_path: Path where the config file should be written to.
        registry_name: Name of the registry.
        registry_uri: URI of the registry.
    """
    mirror_endpoints = [f"http://{registry_name}"]
    registry_config = {
        "mirrors": {
            registry_uri: {"endpoint": mirror_endpoints},
        },
    }
    yaml_utils.write_yaml(yaml_path, registry_config)
utils
Utils for ZenML Kubeflow orchestrators implementation.
dump_ui_metadata(node, execution_info, metadata_ui_path)
Dump KFP UI metadata json file for visualization purpose.
For general components we just render a simple Markdown file for exec_properties/inputs/outputs.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
node |
PipelineNode |
associated TFX node. |
required |
execution_info |
ExecutionInfo |
runtime execution info for this component, including materialized inputs/outputs/execution properties and id. |
required |
metadata_ui_path |
str |
path to dump ui metadata. |
required |
Source code in zenml/integrations/kubeflow/orchestrators/utils.py
def dump_ui_metadata(
    node: PipelineNode,
    execution_info: data_types.ExecutionInfo,
    metadata_ui_path: str,
) -> None:
    """Dump KFP UI metadata json file for visualization purpose.

    For general components we just render a simple Markdown file for
    exec_properties/inputs/outputs. For outputs whose artifact type is a
    `ModelRun` or a `ModelArtifact`, an additional TensorBoard view pointing
    at the first populated artifact is added.

    Args:
        node: associated TFX node.
        execution_info: runtime execution info for this component, including
            materialized inputs/outputs/execution properties and id.
        metadata_ui_path: path to dump ui metadata.
    """
    exec_properties_list = [
        "**{}**: {}".format(
            _sanitize_underscore(name), _sanitize_underscore(exec_property)
        )
        for name, exec_property in execution_info.exec_properties.items()
    ]
    src_str_exec_properties = "# Execution properties:\n{}".format(
        "\n\n".join(exec_properties_list) or "No execution property."
    )

    def _render_channel(
        name: str,
        artifact_type: str,
        artifacts: List[artifact.Artifact],
    ) -> str:
        """Render the markdown section of a single input/output channel.

        Args:
            name: name of the input or output.
            artifact_type: name of the artifact type of the channel.
            artifacts: populated artifacts of the channel (may be empty).

        Returns:
            The markdown string for the channel.
        """
        rendered_artifacts = "".join(
            _render_artifact_as_mdstr(single_artifact)
            for single_artifact in artifacts
        )
        return "## {name}\n\n**Type**: {channel_type}\n\n{artifacts}".format(
            name=_sanitize_underscore(name),
            channel_type=_sanitize_underscore(artifact_type),
            artifacts=rendered_artifacts,
        )

    def _dump_input_populated_artifacts(
        node_inputs: MutableMapping[str, InputSpec],
        name_to_artifacts: Dict[str, List[artifact.Artifact]],
    ) -> List[str]:
        """Dump artifacts markdown string for inputs.

        Args:
            node_inputs: maps from input name to input spec proto.
            name_to_artifacts: maps from input key to list of populated
                artifacts.

        Returns:
            A list of dumped markdown string, each of which represents a
            channel.
        """
        # There must be at least a channel in an input, and all channels in
        # an input share the same artifact type, so the first one is used.
        return [
            _render_channel(
                name=name,
                artifact_type=spec.channels[0].artifact_query.type.name,
                artifacts=name_to_artifacts.get(name, []),
            )
            for name, spec in node_inputs.items()
        ]

    def _dump_output_populated_artifacts(
        node_outputs: MutableMapping[str, OutputSpec],
        name_to_artifacts: Dict[str, List[artifact.Artifact]],
    ) -> List[str]:
        """Dump artifacts markdown string for outputs.

        Args:
            node_outputs: maps from output name to output spec proto.
            name_to_artifacts: maps from output key to list of populated
                artifacts.

        Returns:
            A list of dumped markdown string, each of which represents a
            channel.
        """
        # The artifact type of an output is read from its artifact spec.
        return [
            _render_channel(
                name=name,
                artifact_type=spec.artifact_spec.type.name,
                artifacts=name_to_artifacts.get(name, []),
            )
            for name, spec in node_outputs.items()
        ]

    # `input_dict`/`output_dict` are treated as possibly empty/unset
    # everywhere in this function, hence the `or {}` fallbacks.
    input_artifacts = execution_info.input_dict or {}
    output_artifacts = execution_info.output_dict or {}

    src_str_inputs = "# Inputs:\n{}".format(
        "".join(
            _dump_input_populated_artifacts(
                node_inputs=node.inputs.inputs,
                name_to_artifacts=input_artifacts,
            )
        )
        or "No input."
    )
    src_str_outputs = "# Outputs:\n{}".format(
        "".join(
            _dump_output_populated_artifacts(
                node_outputs=node.outputs.outputs,
                name_to_artifacts=output_artifacts,
            )
        )
        or "No output."
    )

    outputs = [
        {
            "storage": "inline",
            "source": "{exec_properties}\n\n{inputs}\n\n{outputs}".format(
                exec_properties=src_str_exec_properties,
                inputs=src_str_inputs,
                outputs=src_str_outputs,
            ),
            "type": "markdown",
        }
    ]

    # Add TensorBoard view for ModelRun outputs.
    tensorboard_type_names = (
        standard_artifacts.ModelRun.TYPE_NAME,
        ModelArtifact.TYPE_NAME,
    )
    for name, spec in node.outputs.outputs.items():
        if spec.artifact_spec.type.name not in tensorboard_type_names:
            continue

        # Without a populated artifact there is no URI a TensorBoard view
        # could point at, so the output is skipped instead of failing on a
        # missing key / empty list.
        model_artifacts = output_artifacts.get(name)
        if not model_artifacts:
            continue
        source = model_artifacts[0].uri

        # For local artifact repository, use a path that is relative to
        # the point where the local artifact folder is mounted as a volume
        artifact_store = Repository().active_stack.artifact_store
        if isinstance(artifact_store, LocalArtifactStore):
            source = os.path.relpath(source, artifact_store.path)
            source = f"volume://local-artifact-store/{source}"

        # Add TensorBoard view.
        outputs.append(
            {
                "type": "tensorboard",
                "source": source,
            }
        )

    metadata_dict = {"outputs": outputs}
    with open(metadata_ui_path, "w") as f:
        json.dump(metadata_dict, f)
mount_config_map_op(config_map_name)
Mounts all key-value pairs found in the named Kubernetes ConfigMap.
All key-value pairs in the ConfigMap are mounted as environment variables.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
config_map_name |
str |
The name of the ConfigMap resource. |
required |
Returns:
Type | Description |
---|---|
Callable[[kfp.dsl._container_op.ContainerOp], NoneType] |
An OpFunc for mounting the ConfigMap. |
Source code in zenml/integrations/kubeflow/orchestrators/utils.py
def mount_config_map_op(
    config_map_name: str,
) -> Callable[[dsl.ContainerOp], None]:
    """Builds an op modifier that exposes a ConfigMap as environment variables.

    The returned function attaches the named Kubernetes ConfigMap to a
    container op as an `envFrom` source. The reference is marked optional.

    Args:
        config_map_name: The name of the ConfigMap resource.

    Returns:
        An OpFunc for mounting the ConfigMap.
    """

    def mount_config_map(container_op: dsl.ContainerOp) -> None:
        """Adds the ConfigMap as an environment source of the container.

        Args:
            container_op: The container op to mount the ConfigMap.
        """
        env_source = k8s_client.V1EnvFromSource(
            config_map_ref=k8s_client.V1ConfigMapEnvSource(
                name=config_map_name, optional=True
            )
        )
        container_op.container.add_env_from(env_source)

    return mount_config_map
kubernetes
special
Kubernetes integration for Kubernetes-native orchestration.
The Kubernetes integration sub-module powers an alternative to the local orchestrator. You can enable it by registering the Kubernetes orchestrator with the CLI tool.
KubernetesIntegration (Integration)
Definition of Kubernetes integration for ZenML.
Source code in zenml/integrations/kubernetes/__init__.py
class KubernetesIntegration(Integration):
    """Definition of Kubernetes integration for ZenML."""

    NAME = KUBERNETES
    REQUIREMENTS = ["kubernetes==18.20.0"]

    @classmethod
    def flavors(cls) -> List[FlavorWrapper]:
        """Lists the stack component flavors this integration provides.

        Returns:
            The metadata store flavor followed by the orchestrator flavor.
        """
        metadata_store_flavor = FlavorWrapper(
            name=KUBERNETES_METADATA_STORE_FLAVOR,
            source="zenml.integrations.kubernetes.metadata_stores.KubernetesMetadataStore",
            type=StackComponentType.METADATA_STORE,
            integration=cls.NAME,
        )
        orchestrator_flavor = FlavorWrapper(
            name=KUBERNETES_ORCHESTRATOR_FLAVOR,
            source="zenml.integrations.kubernetes.orchestrators.KubernetesOrchestrator",
            type=StackComponentType.ORCHESTRATOR,
            integration=cls.NAME,
        )
        return [metadata_store_flavor, orchestrator_flavor]
flavors()
classmethod
Declare the stack component flavors for the Kubernetes integration.
Returns:
Type | Description |
---|---|
List[zenml.zen_stores.models.flavor_wrapper.FlavorWrapper] |
List of new stack component flavors. |
Source code in zenml/integrations/kubernetes/__init__.py
@classmethod
def flavors(cls) -> List[FlavorWrapper]:
    """Lists the stack component flavors of the Kubernetes integration.

    Returns:
        The metadata store flavor followed by the orchestrator flavor.
    """
    # (flavor name, implementation source path, component type)
    flavor_specs = (
        (
            KUBERNETES_METADATA_STORE_FLAVOR,
            "zenml.integrations.kubernetes.metadata_stores.KubernetesMetadataStore",
            StackComponentType.METADATA_STORE,
        ),
        (
            KUBERNETES_ORCHESTRATOR_FLAVOR,
            "zenml.integrations.kubernetes.orchestrators.KubernetesOrchestrator",
            StackComponentType.ORCHESTRATOR,
        ),
    )
    return [
        FlavorWrapper(
            name=flavor_name,
            source=flavor_source,
            type=component_type,
            integration=cls.NAME,
        )
        for flavor_name, flavor_source, component_type in flavor_specs
    ]
metadata_stores
special
Initialization of the Kubernetes metadata store for ZenML.
kubernetes_metadata_store
Implementation of Kubernetes metadata store.
KubernetesMetadataStore (BaseMetadataStore)
pydantic-model
Kubernetes metadata store (MySQL database deployed in the cluster).
Attributes:
Name | Type | Description |
---|---|---|
deployment_name |
str |
Name of the Kubernetes deployment and corresponding
service/pod that will be created when calling `provision()`. |
kubernetes_context |
str |
Name of the Kubernetes context in which to deploy and provision the MySQL database. |
kubernetes_namespace |
str |
Name of the Kubernetes namespace. Defaults to "zenml". |
storage_capacity |
str |
Storage capacity of the metadata store.
Defaults to `"10Gi"` (=10GB). |
Source code in zenml/integrations/kubernetes/metadata_stores/kubernetes_metadata_store.py
class KubernetesMetadataStore(BaseMetadataStore):
    """Kubernetes metadata store (MySQL database deployed in the cluster).

    Attributes:
        deployment_name: Name of the Kubernetes deployment and corresponding
            service/pod that will be created when calling `provision()`.
        kubernetes_context: Name of the Kubernetes context in which to deploy
            and provision the MySQL database.
        kubernetes_namespace: Name of the Kubernetes namespace.
            Defaults to "zenml".
        storage_capacity: Storage capacity of the metadata store.
            Defaults to `"10Gi"` (=10GB).
    """

    deployment_name: str
    kubernetes_context: str
    kubernetes_namespace: str = "zenml"
    storage_capacity: str = "10Gi"
    # Kubernetes API clients; populated by `_initialize_k8s_clients()`.
    _k8s_core_api: k8s_client.CoreV1Api = None
    _k8s_apps_api: k8s_client.AppsV1Api = None

    FLAVOR: ClassVar[str] = KUBERNETES_METADATA_STORE_FLAVOR

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Initiate the Pydantic object and initialize the Kubernetes clients.

        Args:
            *args: The positional arguments to pass to the Pydantic object.
            **kwargs: The keyword arguments to pass to the Pydantic object.
        """
        super().__init__(*args, **kwargs)
        self._initialize_k8s_clients()

    def _initialize_k8s_clients(self) -> None:
        """Initialize the Kubernetes clients."""
        kube_utils.load_kube_config(context=self.kubernetes_context)
        self._k8s_core_api = k8s_client.CoreV1Api()
        self._k8s_apps_api = k8s_client.AppsV1Api()

    @root_validator(skip_on_failure=False)
    def check_required_attributes(
        cls, values: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Pydantic root_validator.

        This ensures that both `deployment_name` and `kubernetes_context` are
        set and raises an error with a custom error message otherwise.

        Args:
            values: Values passed to the Pydantic constructor.

        Raises:
            StackComponentInterfaceError: if either `deployment_name` or
                `kubernetes_context` is not defined.

        Returns:
            Values passed to the Pydantic constructor.
        """
        usage_note = (
            "Note: the `kubernetes` metadata store flavor is a special "
            "subtype of the `mysql` metadata store that deploys a fresh "
            "MySQL database within your Kubernetes cluster when running "
            "`zenml stack up`. "
            "If you already have a MySQL database running in your cluster "
            "(or elsewhere), simply use the `mysql` metadata store flavor "
            "instead."
        )

        for required_field in ("deployment_name", "kubernetes_context"):
            if required_field not in values:
                raise StackComponentInterfaceError(
                    f"Required field `{required_field}` missing for "
                    "`KubernetesMetadataStore`. " + usage_note
                )

        return values

    @property
    def deployment_exists(self) -> bool:
        """Check whether a MySQL deployment exists in the cluster.

        Returns:
            Whether a deployment named `deployment_name` exists in the
            configured namespace.
        """
        resp = self._k8s_apps_api.list_namespaced_deployment(
            namespace=self.kubernetes_namespace
        )
        return any(
            deployment.metadata.name == self.deployment_name
            for deployment in resp.items
        )

    @property
    def is_provisioned(self) -> bool:
        """If the component provisioned resources to run.

        Checks whether the required MySQL deployment exists.

        Returns:
            True if the component provisioned resources to run.
        """
        return super().is_provisioned and self.deployment_exists

    @property
    def is_running(self) -> bool:
        """If the component is running.

        Returns:
            False if the port-forwarding daemon is not running (this check
            is skipped on Windows), otherwise the value of `is_provisioned`.
        """
        if sys.platform != "win32":
            from zenml.utils.daemon import check_if_daemon_is_running

            if not check_if_daemon_is_running(self._pid_file_path):
                return False
        else:
            # Daemon functionality is not supported on Windows, so the PID
            # file won't exist. This if clause exists just for mypy to not
            # complain about missing functions
            pass

        return self.is_provisioned

    def provision(self) -> None:
        """Provision the metadata store.

        Creates a deployment with a MySQL database running in it.
        """
        logger.info("Provisioning Kubernetes MySQL metadata store...")
        kube_utils.create_namespace(
            core_api=self._k8s_core_api, namespace=self.kubernetes_namespace
        )
        kube_utils.create_mysql_deployment(
            core_api=self._k8s_core_api,
            apps_api=self._k8s_apps_api,
            namespace=self.kubernetes_namespace,
            storage_capacity=self.storage_capacity,
            deployment_name=self.deployment_name,
        )

        # wait a bit, then make sure deployment pod is alive and running.
        logger.info("Trying to reach Kubernetes MySQL metadata store pod...")
        time.sleep(10)
        kube_utils.wait_pod(
            core_api=self._k8s_core_api,
            pod_name=self.pod_name,
            namespace=self.kubernetes_namespace,
            exit_condition_lambda=kube_utils.pod_is_not_pending,
        )
        logger.info("Kubernetes MySQL metadata store pod is up and running.")

    def deprovision(self) -> None:
        """Deprovision the metadata store by deleting the MySQL deployment."""
        logger.info("Deleting Kubernetes MySQL metadata store...")
        # Stop the local port-forwarding daemon before removing its target.
        self.suspend()
        kube_utils.delete_deployment(
            apps_api=self._k8s_apps_api,
            deployment_name=self.deployment_name,
            namespace=self.kubernetes_namespace,
        )

    # TODO: code duplication with kubeflow metadata store below.

    @property
    def root_directory(self) -> str:
        """Returns path to the root directory for all files concerning this metadata store.

        Returns:
            Path to the root directory.
        """
        return os.path.join(
            io_utils.get_global_config_directory(),
            self.FLAVOR,
            str(self.uuid),
        )

    @property
    def _pid_file_path(self) -> str:
        """Returns path to the daemon PID file.

        Returns:
            Path to the daemon PID file.
        """
        return os.path.join(
            self.root_directory, DEFAULT_KUBERNETES_METADATA_DAEMON_PID_FILE
        )

    @property
    def _log_file(self) -> str:
        """Path of the daemon log file.

        Returns:
            Path to the daemon log file.
        """
        return os.path.join(
            self.root_directory, DEFAULT_KUBERNETES_METADATA_DAEMON_LOG_FILE
        )

    def resume(self) -> None:
        """Resumes the metadata store."""
        self.start_metadata_daemon()
        self.wait_until_metadata_store_ready(
            timeout=DEFAULT_KUBERNETES_METADATA_DAEMON_TIMEOUT
        )

    def suspend(self) -> None:
        """Suspends the metadata store."""
        self.stop_metadata_daemon()

    @property
    def pod_name(self) -> str:
        """Name of the Kubernetes pod where the MySQL database is deployed.

        The first pod labelled `app=<deployment_name>` in the configured
        namespace is used.

        Returns:
            Name of the Kubernetes pod.
        """
        pod_list = self._k8s_core_api.list_namespaced_pod(
            namespace=self.kubernetes_namespace,
            label_selector=f"app={self.deployment_name}",
        )
        return pod_list.items[0].metadata.name  # type: ignore[no-any-return]

    @property
    def host(self) -> str:
        """Get the MySQL host required to access the metadata store.

        This overwrites the MySQL host to use local host when
        running outside of the cluster so we can access the metadata store
        locally for post execution.

        Raises:
            RuntimeError: If the metadata store is not running.

        Returns:
            MySQL host.
        """
        if kube_utils.is_inside_kubernetes():
            return DEFAULT_KUBERNETES_MYSQL_HOST

        if not self.is_running:
            raise RuntimeError(
                "The Kubernetes metadata daemon is not running. Please run the "
                "following command to start it first:\n\n"
                "    'zenml metadata-store up'\n"
            )

        return DEFAULT_KUBERNETES_MYSQL_LOCAL_HOST

    @property
    def port(self) -> int:
        """Get the MySQL port required to access the metadata store.

        Returns:
            MySQL port.
        """
        return DEFAULT_KUBERNETES_MYSQL_PORT

    def get_tfx_metadata_config(
        self,
    ) -> Union[
        metadata_store_pb2.ConnectionConfig,
        metadata_store_pb2.MetadataStoreClientConfig,
    ]:
        """Return tfx metadata config for the Kubernetes metadata store.

        Returns:
            The tfx metadata config.
        """
        config = MySQLDatabaseConfig(
            host=self.host,
            port=self.port,
            database=DEFAULT_KUBERNETES_MYSQL_DATABASE,
            user=DEFAULT_KUBERNETES_MYSQL_USERNAME,
            password=DEFAULT_KUBERNETES_MYSQL_PASSWORD,
        )
        connection_config = metadata_store_pb2.ConnectionConfig(mysql=config)
        return connection_config

    def start_metadata_daemon(self) -> None:
        """Starts a daemon process that forwards ports.

        This is so the MySQL database in the Kubernetes cluster is accessible
        on the localhost. On Windows no daemon is started; a warning with the
        command to run manually is logged instead.

        Raises:
            ProvisioningError: if the local port is already occupied.
        """
        command = [
            "kubectl",
            "--context",
            self.kubernetes_context,
            "--namespace",
            self.kubernetes_namespace,
            "port-forward",
            f"svc/{self.deployment_name}",
            f"{self.port}:{self.port}",
        ]

        if sys.platform == "win32":
            # The message has a single `%s` placeholder, so exactly one
            # argument (the command) must be passed to the logger.
            logger.warning(
                "Daemon functionality not supported on Windows. "
                "In order to access the Kubernetes Metadata locally, "
                "please run '%s' in a separate command line shell.",
                " ".join(command),
            )
        elif not networking_utils.port_available(self.port):
            raise ProvisioningError(
                f"Unable to port-forward Kubernetes Metadata to local "
                f"port {self.port} because the port is occupied. In order to "
                f"access the Kubernetes Metadata locally, please "
                f"change the metadata store configuration to use an available "
                f"port or stop the other process currently using the port."
            )
        else:
            from zenml.utils import daemon

            def _daemon_function() -> None:
                """Forwards the port of the Kubernetes metadata store pod."""
                subprocess.check_call(command)

            daemon.run_as_daemon(
                _daemon_function,
                pid_file=self._pid_file_path,
                log_file=self._log_file,
            )
            logger.info(
                "Started Kubernetes Metadata daemon (check the daemon "
                "logs at %s in case you're not able to access the pipeline "
                "metadata).",
                self._log_file,
            )

    def stop_metadata_daemon(self) -> None:
        """Stops the Kubernetes metadata daemon process if it is running."""
        if sys.platform != "win32" and fileio.exists(self._pid_file_path):
            from zenml.utils import daemon

            daemon.stop_daemon(self._pid_file_path)
            fileio.remove(self._pid_file_path)

    def wait_until_metadata_store_ready(self, timeout: int) -> None:
        """Waits until the metadata store connection is ready.

        Potentially an irrecoverable error could occur or the timeout could
        expire, so it checks for this.

        Args:
            timeout: The maximum time to wait for the metadata store to be
                ready. It is consumed in steps of 10 seconds.

        Raises:
            RuntimeError: if the metadata store is not ready after the timeout
        """
        logger.info(
            "Waiting for the Kubernetes metadata store to be ready (this "
            "might take a few minutes)."
        )
        while True:
            try:
                # it doesn't matter what we call here as long as it exercises
                # the MLMD connection
                self.get_pipelines()
                break
            except Exception as e:
                # Check the remaining budget first so that no "waiting"
                # message is logged right before giving up.
                if timeout <= 0:
                    raise RuntimeError(
                        f"An unexpected error was encountered while waiting "
                        f"for the Kubernetes metadata store to be functional: "
                        f"{str(e)}"
                    ) from e
                logger.info(
                    "The Kubernetes metadata store is not ready yet. Waiting "
                    "for 10 seconds..."
                )
                timeout -= 10
                time.sleep(10)

        logger.info("The Kubernetes metadata store is functional.")
deployment_exists: bool
property
readonly
Check whether a MySQL deployment exists in the cluster.
Returns:
Type | Description |
---|---|
bool |
Whether a MySQL deployment exists in the cluster. |
host: str
property
readonly
Get the MySQL host required to access the metadata store.
This overwrites the MySQL host to use local host when running outside of the cluster so we can access the metadata store locally for post execution.
Exceptions:
Type | Description |
---|---|
RuntimeError |
If the metadata store is not running. |
Returns:
Type | Description |
---|---|
str |
MySQL host. |
is_provisioned: bool
property
readonly
If the component provisioned resources to run.
Checks whether the required MySQL deployment exists.
Returns:
Type | Description |
---|---|
bool |
True if the component provisioned resources to run. |
is_running: bool
property
readonly
If the component is running.
Returns:
Type | Description |
---|---|
bool |
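False if the port-forwarding daemon is not running (this check is skipped on Windows), otherwise the value of `is_provisioned`. |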
pod_name: str
property
readonly
Name of the Kubernetes pod where the MySQL database is deployed.
Returns:
Type | Description |
---|---|
str |
Name of the Kubernetes pod. |
port: int
property
readonly
Get the MySQL port required to access the metadata store.
Returns:
Type | Description |
---|---|
int |
MySQL port. |
root_directory: str
property
readonly
Returns path to the root directory for all files concerning this metadata store.
Returns:
Type | Description |
---|---|
str |
Path to the root directory. |
__init__(self, *args, **kwargs)
special
Initiate the Pydantic object and initialize the Kubernetes clients.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
*args |
Any |
The positional arguments to pass to the Pydantic object. |
() |
**kwargs |
Any |
The keyword arguments to pass to the Pydantic object. |
{} |
Source code in zenml/integrations/kubernetes/metadata_stores/kubernetes_metadata_store.py
def __init__(self, *args: Any, **kwargs: Any) -> None:
    """Builds the Pydantic object, then sets up the Kubernetes clients.

    Args:
        *args: The positional arguments to pass to the Pydantic object.
        **kwargs: The keyword arguments to pass to the Pydantic object.
    """
    # The Pydantic initialization has to run first so that the configured
    # attributes are available when the clients are created.
    super().__init__(*args, **kwargs)
    self._initialize_k8s_clients()
check_required_attributes(values)
classmethod
Pydantic root_validator.
This ensures that both `deployment_name` and `kubernetes_context` are set and raises an error with a custom error message otherwise.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
values |
Dict[str, Any] |
Values passed to the Pydantic constructor. |
required |
Exceptions:
Type | Description |
---|---|
StackComponentInterfaceError |
if either `deployment_name` or `kubernetes_context` is not defined. |
Returns:
Type | Description |
---|---|
Dict[str, Any] |
Values passed to the Pydantic constructor. |
Source code in zenml/integrations/kubernetes/metadata_stores/kubernetes_metadata_store.py
@root_validator(skip_on_failure=False)
def check_required_attributes(
    cls, values: Dict[str, Any]
) -> Dict[str, Any]:
    """Validates that the mandatory attributes were provided.

    Both `deployment_name` and `kubernetes_context` have to be present in
    the given values; the first missing one is reported with a custom error
    message.

    Args:
        values: Values passed to the Pydantic constructor.

    Raises:
        StackComponentInterfaceError: if either `deployment_name` or
            `kubernetes_context` is not defined.

    Returns:
        Values passed to the Pydantic constructor.
    """
    mandatory_fields = ("deployment_name", "kubernetes_context")
    missing_field = next(
        (name for name in mandatory_fields if name not in values), None
    )
    if missing_field is None:
        return values

    usage_note = (
        "Note: the `kubernetes` metadata store flavor is a special "
        "subtype of the `mysql` metadata store that deploys a fresh "
        "MySQL database within your Kubernetes cluster when running "
        "`zenml stack up`. "
        "If you already have a MySQL database running in your cluster "
        "(or elsewhere), simply use the `mysql` metadata store flavor "
        "instead."
    )
    raise StackComponentInterfaceError(
        f"Required field `{missing_field}` missing for "
        "`KubernetesMetadataStore`. " + usage_note
    )
deprovision(self)
Deprovision the metadata store by deleting the MySQL deployment.
Source code in zenml/integrations/kubernetes/metadata_stores/kubernetes_metadata_store.py
def deprovision(self) -> None:
    """Tears down the metadata store by deleting its MySQL deployment."""
    logger.info("Deleting Kubernetes MySQL metadata store...")

    # Suspend the store first, then remove the deployment itself.
    self.suspend()

    deletion_arguments = {
        "apps_api": self._k8s_apps_api,
        "deployment_name": self.deployment_name,
        "namespace": self.kubernetes_namespace,
    }
    kube_utils.delete_deployment(**deletion_arguments)
get_tfx_metadata_config(self)
Return tfx metadata config for the Kubernetes metadata store.
Returns:
Type | Description |
---|---|
Union[ml_metadata.proto.metadata_store_pb2.ConnectionConfig, ml_metadata.proto.metadata_store_pb2.MetadataStoreClientConfig] |
The tfx metadata config. |
Source code in zenml/integrations/kubernetes/metadata_stores/kubernetes_metadata_store.py
def get_tfx_metadata_config(
    self,
) -> Union[
    metadata_store_pb2.ConnectionConfig,
    metadata_store_pb2.MetadataStoreClientConfig,
]:
    """Builds the tfx metadata config for the Kubernetes metadata store.

    Returns:
        A connection config that wraps the MySQL settings of this store
        (host and port of the store plus the default database name and
        credentials).
    """
    return metadata_store_pb2.ConnectionConfig(
        mysql=MySQLDatabaseConfig(
            host=self.host,
            port=self.port,
            database=DEFAULT_KUBERNETES_MYSQL_DATABASE,
            user=DEFAULT_KUBERNETES_MYSQL_USERNAME,
            password=DEFAULT_KUBERNETES_MYSQL_PASSWORD,
        )
    )
provision(self)
Provision the metadata store.
Creates a deployment with a MySQL database running in it.
Source code in zenml/integrations/kubernetes/metadata_stores/kubernetes_metadata_store.py
def provision(self) -> None:
    """Sets up the metadata store inside the cluster.

    Creates the namespace and a deployment with a MySQL database running
    in it, then waits until the deployment pod is no longer pending.
    """
    core_api = self._k8s_core_api
    namespace = self.kubernetes_namespace

    logger.info("Provisioning Kubernetes MySQL metadata store...")
    kube_utils.create_namespace(core_api=core_api, namespace=namespace)
    kube_utils.create_mysql_deployment(
        core_api=core_api,
        apps_api=self._k8s_apps_api,
        namespace=namespace,
        storage_capacity=self.storage_capacity,
        deployment_name=self.deployment_name,
    )

    # Give the deployment a moment before looking up and polling its pod.
    logger.info("Trying to reach Kubernetes MySQL metadata store pod...")
    time.sleep(10)
    mysql_pod_name = self.pod_name
    kube_utils.wait_pod(
        core_api=core_api,
        pod_name=mysql_pod_name,
        namespace=namespace,
        exit_condition_lambda=kube_utils.pod_is_not_pending,
    )
    logger.info("Kubernetes MySQL metadata store pod is up and running.")
resume(self)
Resumes the metadata store.
Source code in zenml/integrations/kubernetes/metadata_stores/kubernetes_metadata_store.py
def resume(self) -> None:
    """Brings the metadata store back up.

    Starts the port-forwarding daemon and then waits for the metadata store
    to become ready, using the default daemon timeout.
    """
    self.start_metadata_daemon()

    readiness_timeout = DEFAULT_KUBERNETES_METADATA_DAEMON_TIMEOUT
    self.wait_until_metadata_store_ready(timeout=readiness_timeout)
start_metadata_daemon(self)
Starts a daemon process that forwards ports.
This is so the MySQL database in the Kubernetes cluster is accessible on the localhost.
Exceptions:
Type | Description |
---|---|
ProvisioningError |
if the daemon fails to start. |
Source code in zenml/integrations/kubernetes/metadata_stores/kubernetes_metadata_store.py
def start_metadata_daemon(self) -> None:
    """Starts a daemon process that forwards ports.

    This is so the MySQL database in the Kubernetes cluster is accessible
    on the localhost.

    On Windows no daemon is started; a warning containing the `kubectl`
    command to run manually is logged instead.

    Raises:
        ProvisioningError: if the local port is already occupied, so the
            port-forwarding daemon cannot be started.
    """
    command = [
        "kubectl",
        "--context",
        self.kubernetes_context,
        "--namespace",
        self.kubernetes_namespace,
        "port-forward",
        f"svc/{self.deployment_name}",
        f"{self.port}:{self.port}",
    ]
    if sys.platform == "win32":
        # The message has a single `%s` placeholder, so exactly one argument
        # (the command) is passed; a surplus argument makes the logging
        # module fail to format the record.
        logger.warning(
            "Daemon functionality not supported on Windows. "
            "In order to access the Kubernetes Metadata locally, "
            "please run '%s' in a separate command line shell.",
            " ".join(command),
        )
    elif not networking_utils.port_available(self.port):
        raise ProvisioningError(
            f"Unable to port-forward Kubernetes Metadata to local "
            f"port {self.port} because the port is occupied. In order to "
            f"access the Kubernetes Metadata locally, please "
            f"change the metadata store configuration to use an available "
            f"port or stop the other process currently using the port."
        )
    else:
        from zenml.utils import daemon

        def _daemon_function() -> None:
            """Forwards the port of the Kubernetes metadata store pod."""
            subprocess.check_call(command)

        daemon.run_as_daemon(
            _daemon_function,
            pid_file=self._pid_file_path,
            log_file=self._log_file,
        )
        # Adjacent literals are concatenated verbatim, so each fragment
        # carries its own trailing space.
        logger.info(
            "Started Kubernetes Metadata daemon (check the daemon "
            "logs at %s in case you're not able to access the pipeline "
            "metadata).",
            self._log_file,
        )
stop_metadata_daemon(self)
Stops the Kubernetes metadata daemon process if it is running.
Source code in zenml/integrations/kubernetes/metadata_stores/kubernetes_metadata_store.py
def stop_metadata_daemon(self) -> None:
    """Terminate the Kubernetes metadata daemon process, if there is one.

    Nothing happens on Windows or when the PID file does not exist.
    Otherwise the daemon is stopped and its PID file is removed.
    """
    if sys.platform == "win32" or not fileio.exists(self._pid_file_path):
        return

    from zenml.utils import daemon

    daemon.stop_daemon(self._pid_file_path)
    fileio.remove(self._pid_file_path)
suspend(self)
Suspends the metadata store.
Source code in zenml/integrations/kubernetes/metadata_stores/kubernetes_metadata_store.py
def suspend(self) -> None:
    """Suspends the metadata store.

    Suspending consists of stopping the Kubernetes metadata daemon process
    via `stop_metadata_daemon`.
    """
    self.stop_metadata_daemon()
wait_until_metadata_store_ready(self, timeout)
Waits until the metadata store connection is ready.
Potentially an irrecoverable error could occur or the timeout could expire, so it checks for this.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
timeout |
int |
The maximum time to wait for the metadata store to be ready. |
required |
Exceptions:
Type | Description |
---|---|
RuntimeError |
if the metadata store is not ready after the timeout |
Source code in zenml/integrations/kubernetes/metadata_stores/kubernetes_metadata_store.py
def wait_until_metadata_store_ready(self, timeout: int) -> None:
    """Block until the metadata store answers a request.

    The store is probed repeatedly. After every failed probe the remaining
    timeout budget is reduced by 10 and the method sleeps for 10 seconds
    before trying again.

    Args:
        timeout: The maximum time to wait for the metadata store to be
            ready.

    Raises:
        RuntimeError: if a probe fails once the timeout budget is used up.
    """
    logger.info(
        "Waiting for the Kubernetes metadata store to be ready (this "
        "might take a few minutes)."
    )
    remaining = timeout
    ready = False
    while not ready:
        try:
            # Any call works here as long as it goes through the MLMD
            # connection.
            self.get_pipelines()
        except Exception as error:
            logger.info(
                "The Kubernetes metadata store is not ready yet. Waiting "
                "for 10 seconds..."
            )
            if remaining <= 0:
                raise RuntimeError(
                    f"An unexpected error was encountered while waiting "
                    f"for the Kubernetes metadata store to be functional: "
                    f"{str(error)}"
                ) from error
            remaining -= 10
            time.sleep(10)
        else:
            ready = True
    logger.info("The Kubernetes metadata store is functional.")
orchestrators
special
Kubernetes-native orchestration.
dag_runner
DAG (Directed Acyclic Graph) Runners.
NodeStatus (Enum)
Status of the execution of a node.
Source code in zenml/integrations/kubernetes/orchestrators/dag_runner.py
class NodeStatus(Enum):
    """Status of the execution of a node."""

    # The node has not been started yet.
    WAITING = "Waiting"
    # The node has been started and is not marked as completed yet.
    RUNNING = "Running"
    # The node's run function has returned.
    COMPLETED = "Completed"
ThreadedDagRunner
Multi-threaded DAG Runner.
This class expects a DAG of strings in adjacency list representation, as well as a custom `run_fn` as input, then calls `run_fn(node)` for each string node in the DAG.
Steps that can be executed in parallel will be started in separate threads.
Source code in zenml/integrations/kubernetes/orchestrators/dag_runner.py
class ThreadedDagRunner:
    """Multi-threaded DAG Runner.

    This class expects a DAG of strings in adjacency list representation, as
    well as a custom `run_fn` as input, then calls `run_fn(node)` for each
    string node in the DAG.

    Steps that can be executed in parallel will be started in separate threads.
    """

    def __init__(
        self, dag: Dict[str, List[str]], run_fn: Callable[[str], Any]
    ) -> None:
        """Define attributes and initialize all nodes in waiting state.

        Args:
            dag: Adjacency list representation of a DAG.
                E.g.: [(1->2), (1->3), (2->4), (3->4)] should be represented as
                `dag={2: [1], 3: [1], 4: [2, 3]}`
            run_fn: A function `run_fn(node)` that runs a single node
        """
        self.dag = dag
        self.reversed_dag = reverse_dag(dag)
        self.run_fn = run_fn
        self.nodes = dag.keys()
        self.node_states = {node: NodeStatus.WAITING for node in self.nodes}
        # Guards every check and update of `self.node_states`.
        self._lock = threading.Lock()

    def _can_run(self, node: str) -> bool:
        """Determine whether a node is ready to be run.

        This is the case if the node has not run yet and all of its upstream
        node have already completed.

        Callers that act on the result must hold `self._lock` so that the
        state cannot change between the check and the action.

        Args:
            node: The node.

        Returns:
            True if the node can run else False.
        """
        # Check that node has not run yet.
        if self.node_states[node] != NodeStatus.WAITING:
            return False

        # Check that all upstream nodes of this node have already completed.
        return all(
            self.node_states[upstream_node] == NodeStatus.COMPLETED
            for upstream_node in self.dag[node]
        )

    def _run_node(self, node: str) -> None:
        """Run a single node.

        Calls the user-defined run_fn, then calls `self._finish_node`.

        Args:
            node: The node.
        """
        self.run_fn(node)
        self._finish_node(node)

    def _start_thread(self, node: str) -> threading.Thread:
        """Start a thread that runs a node already marked as running.

        Args:
            node: The node.

        Returns:
            The started thread.
        """
        thread = threading.Thread(target=self._run_node, args=(node,))
        thread.start()
        return thread

    def _run_node_in_thread(self, node: str) -> threading.Thread:
        """Run a single node in a separate thread.

        First updates the node status to running.
        Then calls self._run_node() in a new thread and returns the thread.

        Args:
            node: The node.

        Returns:
            The thread in which the node was run.
        """
        # Check and update the status under the lock so that no other thread
        # can change it in between.
        with self._lock:
            assert self.node_states[node] == NodeStatus.WAITING
            self.node_states[node] = NodeStatus.RUNNING

        # Run node in new thread.
        return self._start_thread(node)

    def _start_runnable_nodes(self, nodes: List[str]) -> List[threading.Thread]:
        """Start every given node that is ready to run.

        The readiness check and the transition to the running state happen
        atomically under the lock. Without this, two upstream nodes that
        finish at the same time could both see a shared downstream node as
        runnable and start it twice.

        Args:
            nodes: Candidate nodes.

        Returns:
            The threads of all nodes that were started.
        """
        threads = []
        for node in nodes:
            with self._lock:
                runnable = self._can_run(node)
                if runnable:
                    self.node_states[node] = NodeStatus.RUNNING
            # Start the thread outside the lock to keep the critical section
            # short.
            if runnable:
                threads.append(self._start_thread(node))
        return threads

    def _finish_node(self, node: str) -> None:
        """Finish a node run.

        First updates the node status to completed.
        Then starts all other nodes that can now be run and waits for them.

        Args:
            node: The node.
        """
        # Update node status to completed.
        with self._lock:
            assert self.node_states[node] == NodeStatus.RUNNING
            self.node_states[node] = NodeStatus.COMPLETED

        # Run downstream nodes and wait for them to complete.
        for thread in self._start_runnable_nodes(self.reversed_dag[node]):
            thread.join()

    def run(self) -> None:
        """Call `self.run_fn` on all nodes in `self.dag`.

        A node is started as soon as all of its upstream nodes have
        completed. Each node is run in a separate thread to enable
        parallelism.
        """
        # Run all nodes that can be started immediately.
        # These will, in turn, start other nodes once all of their respective
        # upstream nodes have completed.
        threads = self._start_runnable_nodes(list(self.nodes))

        # Wait till all nodes have completed.
        for thread in threads:
            thread.join()

        # Make sure all nodes were run, otherwise print a warning.
        for node in self.nodes:
            if self.node_states[node] == NodeStatus.WAITING:
                upstream_nodes = self.dag[node]
                logger.warning(
                    f"Node `{node}` was never run, because it was still"
                    f" waiting for the following nodes: `{upstream_nodes}`."
                )
__init__(self, dag, run_fn)
special
Define attributes and initialize all nodes in waiting state.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
dag |
Dict[str, List[str]] |
Adjacency list representation of a DAG.
E.g.: [(1->2), (1->3), (2->4), (3->4)] should be represented as `dag={2: [1], 3: [1], 4: [2, 3]}`. |
required |
run_fn |
Callable[[str], Any] |
A function `run_fn(node)` that runs a single node. |
required |
Source code in zenml/integrations/kubernetes/orchestrators/dag_runner.py
def __init__(
    self, dag: Dict[str, List[str]], run_fn: Callable[[str], Any]
) -> None:
    """Store the DAG and run function and mark every node as waiting.

    Args:
        dag: Adjacency list representation of a DAG, mapping each node to
            the list of its upstream nodes.
            E.g.: [(1->2), (1->3), (2->4), (3->4)] should be represented as
            `dag={2: [1], 3: [1], 4: [2, 3]}`
        run_fn: A function `run_fn(node)` that runs a single node
    """
    self.run_fn = run_fn
    self.dag = dag
    self.nodes = dag.keys()
    # Maps each node to the nodes that list it as an upstream node.
    self.reversed_dag = reverse_dag(dag)
    self.node_states = dict.fromkeys(self.nodes, NodeStatus.WAITING)
    self._lock = threading.Lock()
run(self)
Call `self.run_fn` on all nodes in `self.dag`.
The order of execution is determined using topological sort. Each node is run in a separate thread to enable parallelism.
Source code in zenml/integrations/kubernetes/orchestrators/dag_runner.py
def run(self) -> None:
    """Execute `self.run_fn` for the nodes of `self.dag`.

    Every node that is immediately runnable is started in its own thread.
    Those nodes start their downstream nodes when they finish. After all
    started threads have been joined, a warning is logged for each node
    that is still in the waiting state.
    """
    # Kick off the nodes that are runnable right away; they start the rest.
    workers = [
        self._run_node_in_thread(node)
        for node in self.nodes
        if self._can_run(node)
    ]

    # Block until every started node (and what it started) is done.
    for worker in workers:
        worker.join()

    # Report nodes that never left the waiting state.
    for node in self.nodes:
        if self.node_states[node] != NodeStatus.WAITING:
            continue
        upstream_nodes = self.dag[node]
        logger.warning(
            f"Node `{node}` was never run, because it was still"
            f" waiting for the following nodes: `{upstream_nodes}`."
        )
reverse_dag(dag)
Reverse a DAG.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
dag |
Dict[str, List[str]] |
Adjacency list representation of a DAG. |
required |
Returns:
Type | Description |
---|---|
Dict[str, List[str]] |
Adjacency list representation of the reversed DAG. |
Source code in zenml/integrations/kubernetes/orchestrators/dag_runner.py
def reverse_dag(dag: Dict[str, List[str]]) -> Dict[str, List[str]]:
    """Flip the direction of every edge of a DAG.

    Args:
        dag: Adjacency list representation of a DAG.

    Returns:
        Adjacency list representation of the reversed DAG. Every key of
        `dag` is also a key of the result, mapped to an empty list if no
        other node lists it.
    """
    flipped: Dict[str, List[str]] = defaultdict(list)

    # Walk all (listed node -> key) pairs and record them the other way round.
    edges = (
        (listed, key) for key, listed_nodes in dag.items() for listed in listed_nodes
    )
    for listed, key in edges:
        flipped[listed].append(key)

    # Keys of `dag` that no other node lists have no entry yet; add them.
    for key in dag:
        flipped.setdefault(key, [])

    return flipped
kube_utils
Utilities for Kubernetes related functions.
Internal interface: no backwards compatibility guarantees. Adjusted from https://github.com/tensorflow/tfx/blob/master/tfx/utils/kube_utils.py.
PodPhase (Enum)
Phase of the Kubernetes pod.
Pod phases are defined in https://kubernetes.io/docs/concepts/workloads/pods/pod-lifecycle/#pod-phase.
Source code in zenml/integrations/kubernetes/orchestrators/kube_utils.py
class PodPhase(enum.Enum):
    """Phase of the Kubernetes pod.

    Pod phases are defined in
    https://kubernetes.io/docs/concepts/workloads/pods/pod-lifecycle/#pod-phase.
    """

    # The member values are the strings that `pod.status.phase` is compared
    # against by the `pod_*` helper functions of this module.
    PENDING = "Pending"
    RUNNING = "Running"
    SUCCEEDED = "Succeeded"
    FAILED = "Failed"
    UNKNOWN = "Unknown"
create_edit_service_account(core_api, rbac_api, service_account_name, namespace, cluster_role_binding_name='zenml-edit')
Create a new Kubernetes service account with "edit" rights.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
core_api |
CoreV1Api |
Client of Core V1 API of Kubernetes API. |
required |
rbac_api |
RbacAuthorizationV1Api |
Client of Rbac Authorization V1 API of Kubernetes API. |
required |
service_account_name |
str |
Name of the service account. |
required |
namespace |
str |
Kubernetes namespace in which the service account is created. |
required |
cluster_role_binding_name |
str |
Name of the cluster role binding. Defaults to "zenml-edit". |
'zenml-edit' |
Source code in zenml/integrations/kubernetes/orchestrators/kube_utils.py
def create_edit_service_account(
    core_api: k8s_client.CoreV1Api,
    rbac_api: k8s_client.RbacAuthorizationV1Api,
    service_account_name: str,
    namespace: str,
    cluster_role_binding_name: str = "zenml-edit",
) -> None:
    """Create a Kubernetes service account bound to the "edit" role.

    A cluster role binding for the "edit" role is created first, followed
    by the service account itself. Both creation calls are wrapped in
    `_if_not_exists`.

    Args:
        core_api: Client of Core V1 API of Kubernetes API.
        rbac_api: Client of Rbac Authorization V1 API of Kubernetes API.
        service_account_name: Name of the service account.
        namespace: Kubernetes namespace of the service account.
        cluster_role_binding_name: Name of the cluster role binding.
            Defaults to "zenml-edit".
    """
    # Cluster role binding that grants the "edit" role to the account.
    binding_manifest = build_cluster_role_binding_manifest_for_service_account(
        name=cluster_role_binding_name,
        role_name="edit",
        service_account_name=service_account_name,
        namespace=namespace,
    )
    create_binding = _if_not_exists(rbac_api.create_cluster_role_binding)
    create_binding(body=binding_manifest)

    # The service account itself.
    account_manifest = build_service_account_manifest(
        name=service_account_name, namespace=namespace
    )
    create_account = _if_not_exists(core_api.create_namespaced_service_account)
    create_account(namespace=namespace, body=account_manifest)
create_mysql_deployment(core_api, apps_api, deployment_name, namespace, storage_capacity='10Gi', volume_name='mysql-pv-volume', volume_claim_name='mysql-pv-claim')
Create a Kubernetes deployment with a MySQL database running on it.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
core_api |
CoreV1Api |
Client of Core V1 API of Kubernetes API. |
required |
apps_api |
AppsV1Api |
Client of Apps V1 API of Kubernetes API. |
required |
namespace |
str |
Kubernetes namespace in which the MySQL resources are created. |
required |
storage_capacity |
str |
Storage capacity of the database. Defaults to `"10Gi"`. |
'10Gi' |
deployment_name |
str |
Name of the deployment. |
required |
volume_name |
str |
Name of the persistent volume. Defaults to `"mysql-pv-volume"`. |
'mysql-pv-volume' |
volume_claim_name |
str |
Name of the persistent volume claim. Defaults to `"mysql-pv-claim"`. |
'mysql-pv-claim' |
Source code in zenml/integrations/kubernetes/orchestrators/kube_utils.py
def create_mysql_deployment(
    core_api: k8s_client.CoreV1Api,
    apps_api: k8s_client.AppsV1Api,
    deployment_name: str,
    namespace: str,
    storage_capacity: str = "10Gi",
    volume_name: str = "mysql-pv-volume",
    volume_claim_name: str = "mysql-pv-claim",
) -> None:
    """Create the Kubernetes resources for a MySQL database deployment.

    Four resources are created in this order, each creation call wrapped
    in `_if_not_exists`: a persistent volume claim, a persistent volume,
    the MySQL deployment and a service named like the deployment.

    Args:
        core_api: Client of Core V1 API of Kubernetes API.
        apps_api: Client of Apps V1 API of Kubernetes API.
        deployment_name: Name of the deployment; also used as the name of
            the service.
        namespace: Kubernetes namespace for the namespaced resources.
        storage_capacity: Storage capacity of the database.
            Defaults to `"10Gi"`.
        volume_name: Name of the persistent volume.
            Defaults to `"mysql-pv-volume"`.
        volume_claim_name: Name of the persistent volume claim.
            Defaults to `"mysql-pv-claim"`.
    """
    # 1. Persistent volume claim.
    claim_manifest = build_persistent_volume_claim_manifest(
        name=volume_claim_name,
        namespace=namespace,
        storage_request=storage_capacity,
    )
    create_claim = _if_not_exists(
        core_api.create_namespaced_persistent_volume_claim
    )
    create_claim(namespace=namespace, body=claim_manifest)

    # 2. Persistent volume.
    volume_manifest = build_persistent_volume_manifest(
        name=volume_name, storage_capacity=storage_capacity
    )
    create_volume = _if_not_exists(core_api.create_persistent_volume)
    create_volume(body=volume_manifest)

    # 3. MySQL deployment that uses the claim from step 1.
    mysql_manifest = build_mysql_deployment_manifest(
        name=deployment_name,
        namespace=namespace,
        pv_claim_name=volume_claim_name,
    )
    create_deployment = _if_not_exists(apps_api.create_namespaced_deployment)
    create_deployment(body=mysql_manifest, namespace=namespace)

    # 4. Service for the deployment.
    svc_manifest = build_mysql_service_manifest(
        name=deployment_name, namespace=namespace
    )
    create_service = _if_not_exists(core_api.create_namespaced_service)
    create_service(namespace=namespace, body=svc_manifest)
create_namespace(core_api, namespace)
Create a Kubernetes namespace.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
core_api |
CoreV1Api |
Client of Core V1 API of Kubernetes API. |
required |
namespace |
str |
Name of the Kubernetes namespace to create. |
required |
Source code in zenml/integrations/kubernetes/orchestrators/kube_utils.py
def create_namespace(core_api: k8s_client.CoreV1Api, namespace: str) -> None:
    """Create a Kubernetes namespace.

    The creation call is wrapped in `_if_not_exists`.

    Args:
        core_api: Client of Core V1 API of Kubernetes API.
        namespace: Name of the Kubernetes namespace to create.
    """
    namespace_manifest = build_namespace_manifest(namespace)
    create = _if_not_exists(core_api.create_namespace)
    create(body=namespace_manifest)
delete_deployment(apps_api, deployment_name, namespace)
Delete a Kubernetes deployment.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
apps_api |
AppsV1Api |
Client of Apps V1 API of Kubernetes API. |
required |
deployment_name |
str |
Name of the deployment to be deleted. |
required |
namespace |
str |
Kubernetes namespace containing the deployment. |
required |
Source code in zenml/integrations/kubernetes/orchestrators/kube_utils.py
def delete_deployment(
    apps_api: k8s_client.AppsV1Api, deployment_name: str, namespace: str
) -> None:
    """Remove a deployment from a Kubernetes namespace.

    The deletion is requested with the "Foreground" propagation policy.

    Args:
        apps_api: Client of Apps V1 API of Kubernetes API.
        deployment_name: Name of the deployment to be deleted.
        namespace: Kubernetes namespace containing the deployment.
    """
    delete_options = k8s_client.V1DeleteOptions()
    request = {
        "name": deployment_name,
        "namespace": namespace,
        "body": delete_options,
        "propagation_policy": "Foreground",
    }
    apps_api.delete_namespaced_deployment(**request)
get_pod(core_api, pod_name, namespace)
Get a pod from Kubernetes metadata API.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
core_api |
CoreV1Api |
Client of `CoreV1Api` of Kubernetes API. |
required |
pod_name |
str |
The name of the pod. |
required |
namespace |
str |
The namespace of the pod. |
required |
Exceptions:
Type | Description |
---|---|
RuntimeError |
When it sees unexpected errors from Kubernetes API. |
Returns:
Type | Description |
---|---|
Optional[kubernetes.client.models.v1_pod.V1Pod] |
The found pod object. None if it's not found. |
Source code in zenml/integrations/kubernetes/orchestrators/kube_utils.py
def get_pod(
    core_api: k8s_client.CoreV1Api, pod_name: str, namespace: str
) -> Optional[k8s_client.V1Pod]:
    """Look up a pod through the Kubernetes API.

    Args:
        core_api: Client of `CoreV1Api` of Kubernetes API.
        pod_name: The name of the pod.
        namespace: The namespace of the pod.

    Raises:
        RuntimeError: When the Kubernetes API reports an error other than
            404.

    Returns:
        The found pod object. None if the API answers with status 404.
    """
    try:
        pod = core_api.read_namespaced_pod(name=pod_name, namespace=namespace)
    except k8s_client.rest.ApiException as api_error:
        # Only "not found" is an expected outcome; anything else is an error.
        if api_error.status != 404:
            raise RuntimeError from api_error
        return None
    return pod
is_inside_kubernetes()
Check whether we are inside a Kubernetes cluster or on a remote host.
Returns:
Type | Description |
---|---|
bool |
True if inside a Kubernetes cluster, else False. |
Source code in zenml/integrations/kubernetes/orchestrators/kube_utils.py
def is_inside_kubernetes() -> bool:
    """Tell whether the code runs inside a Kubernetes cluster.

    The check is done by attempting to load the in-cluster configuration.

    Returns:
        True if the in-cluster configuration could be loaded, else False.
    """
    try:
        k8s_config.load_incluster_config()
    except k8s_config.ConfigException:
        return False
    return True
load_kube_config(context=None)
Load the Kubernetes client config.
Depending on the environment (whether it is inside the running Kubernetes cluster or remote host), different location will be searched for the config file.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
context |
Optional[str] |
Name of the Kubernetes context. If not provided, uses the currently active context. |
None |
Source code in zenml/integrations/kubernetes/orchestrators/kube_utils.py
def load_kube_config(context: Optional[str] = None) -> None:
    """Load the Kubernetes client config.

    Depending on the environment (whether it is inside the running Kubernetes
    cluster or remote host), different location will be searched for the config
    file.

    Args:
        context: Name of the Kubernetes context. If not provided, uses the
            currently active context. Only used when falling back to the
            kube config file.
    """
    try:
        # Try the in-cluster configuration first.
        k8s_config.load_incluster_config()
    except k8s_config.ConfigException:
        # In-cluster loading failed, so fall back to the kube config file.
        k8s_config.load_kube_config(context=context)
pod_failed(pod)
Check if pod status is 'Failed'.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
pod |
V1Pod |
Kubernetes pod. |
required |
Returns:
Type | Description |
---|---|
bool |
True if pod status is 'Failed' else False. |
Source code in zenml/integrations/kubernetes/orchestrators/kube_utils.py
def pod_failed(pod: k8s_client.V1Pod) -> bool:
    """Tell whether a pod is in the 'Failed' phase.

    Args:
        pod: Kubernetes pod.

    Returns:
        True if pod status is 'Failed' else False.
    """
    phase = pod.status.phase
    return phase == PodPhase.FAILED.value  # type: ignore[no-any-return]
pod_is_done(pod)
Check if pod status is 'Succeeded'.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
pod |
V1Pod |
Kubernetes pod. |
required |
Returns:
Type | Description |
---|---|
bool |
True if pod status is 'Succeeded' else False. |
Source code in zenml/integrations/kubernetes/orchestrators/kube_utils.py
def pod_is_done(pod: k8s_client.V1Pod) -> bool:
    """Tell whether a pod is in the 'Succeeded' phase.

    Args:
        pod: Kubernetes pod.

    Returns:
        True if pod status is 'Succeeded' else False.
    """
    phase = pod.status.phase
    return phase == PodPhase.SUCCEEDED.value  # type: ignore[no-any-return]
pod_is_not_pending(pod)
Check if pod status is not 'Pending'.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
pod |
V1Pod |
Kubernetes pod. |
required |
Returns:
Type | Description |
---|---|
bool |
False if the pod status is 'Pending' else True. |
Source code in zenml/integrations/kubernetes/orchestrators/kube_utils.py
def pod_is_not_pending(pod: k8s_client.V1Pod) -> bool:
    """Tell whether a pod has left the 'Pending' phase.

    Args:
        pod: Kubernetes pod.

    Returns:
        False if the pod status is 'Pending' else True.
    """
    phase = pod.status.phase
    return phase != PodPhase.PENDING.value  # type: ignore[no-any-return]
sanitize_pod_name(pod_name)
Sanitize pod names so they conform to Kubernetes pod naming convention.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
pod_name |
str |
Arbitrary input pod name. |
required |
Returns:
Type | Description |
---|---|
str |
Sanitized pod name. |
Source code in zenml/integrations/kubernetes/orchestrators/kube_utils.py
def sanitize_pod_name(pod_name: str) -> str:
    """Sanitize pod names so they conform to Kubernetes pod naming convention.

    The name is lower-cased, every character other than `a-z`, `0-9` and
    `-` is replaced by a dash, runs of dashes are collapsed into one and
    dashes at the start and the end are removed. Kubernetes names must
    begin and end with an alphanumeric character.

    Args:
        pod_name: Arbitrary input pod name.

    Returns:
        Sanitized pod name. May be empty if the input contains no
        alphanumeric ASCII character.
    """
    pod_name = re.sub(r"[^a-z0-9-]", "-", pod_name.lower())
    pod_name = re.sub(r"-+", "-", pod_name)
    # Strip both ends: a trailing dash is as invalid as a leading one.
    return pod_name.strip("-")
wait_pod(core_api, pod_name, namespace, exit_condition_lambda, timeout_sec=0, exponential_backoff=False, stream_logs=False)
Wait for a pod to meet an exit condition.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
core_api |
CoreV1Api |
Client of `CoreV1Api` of Kubernetes API. |
required |
pod_name |
str |
The name of the pod. |
required |
namespace |
str |
The namespace of the pod. |
required |
exit_condition_lambda |
Callable[[kubernetes.client.models.v1_pod.V1Pod], bool] |
A lambda which will be called periodically to wait for a pod to exit. The function returns True to exit. |
required |
timeout_sec |
int |
Timeout in seconds to wait for pod to reach exit condition, or 0 to wait for an unlimited duration. Defaults to unlimited. |
0 |
exponential_backoff |
bool |
Whether to use exponential back off for polling. Defaults to False. |
False |
stream_logs |
bool |
Whether to stream the pod logs to `zenml.logger.info()`. Defaults to False. |
False |
Exceptions:
Type | Description |
---|---|
RuntimeError |
when the function times out. |
Returns:
Type | Description |
---|---|
V1Pod |
The pod object which meets the exit condition. |
Source code in zenml/integrations/kubernetes/orchestrators/kube_utils.py
def wait_pod(
    core_api: k8s_client.CoreV1Api,
    pod_name: str,
    namespace: str,
    exit_condition_lambda: Callable[[k8s_client.V1Pod], bool],
    timeout_sec: int = 0,
    exponential_backoff: bool = False,
    stream_logs: bool = False,
) -> k8s_client.V1Pod:
    """Wait for a pod to meet an exit condition.

    Args:
        core_api: Client of `CoreV1Api` of Kubernetes API.
        pod_name: The name of the pod.
        namespace: The namespace of the pod.
        exit_condition_lambda: A lambda
            which will be called periodically to wait for a pod to exit. The
            function returns True to exit.
        timeout_sec: Timeout in seconds to wait for pod to reach exit
            condition, or 0 to wait for an unlimited duration.
            Defaults to unlimited.
        exponential_backoff: Whether to use exponential back off for polling.
            Defaults to False.
        stream_logs: Whether to stream the pod logs to
            `zenml.logger.info()`. Defaults to False.

    Raises:
        RuntimeError: when the pod cannot be found, when the pod failed or
            when the function times out.

    Returns:
        The pod object which meets the exit condition.
    """
    # A monotonic clock keeps the timeout correct even if the system clock
    # is adjusted while waiting.
    start_time = time.monotonic()

    # Link to exponential back-off algorithm used here:
    # https://cloud.google.com/storage/docs/exponential-backoff
    backoff_interval = 1
    maximum_backoff = 32

    logged_lines = 0

    while True:
        resp = get_pod(core_api, pod_name, namespace)

        # `get_pod` returns None for a missing pod; fail with a clear error
        # instead of an AttributeError further down.
        if resp is None:
            raise RuntimeError(
                f"Pod `{namespace}:{pod_name}` could not be found."
            )

        # Stream logs to `zenml.logger.info()`.
        # TODO: can we do this without parsing all logs every time?
        if stream_logs and pod_is_not_pending(resp):
            response = core_api.read_namespaced_pod_log(
                name=pod_name,
                namespace=namespace,
            )
            logs = response.splitlines()
            # Only emit the lines that were not logged in earlier iterations.
            if len(logs) > logged_lines:
                for line in logs[logged_lines:]:
                    logger.info(line)
                logged_lines = len(logs)

        # Raise an error if the pod failed.
        if pod_failed(resp):
            raise RuntimeError(f"Pod `{namespace}:{pod_name}` failed.")

        # Check if pod is in desired state (e.g. finished / running / ...).
        if exit_condition_lambda(resp):
            return resp

        # Check if wait timed out. The full elapsed time is compared;
        # `timedelta.seconds` would wrap around after one day.
        elapsed_sec = time.monotonic() - start_time
        if timeout_sec != 0 and elapsed_sec >= timeout_sec:
            raise RuntimeError(
                f"Waiting for pod `{namespace}:{pod_name}` timed out after "
                f"{timeout_sec} seconds."
            )

        # Wait (using exponential backoff).
        time.sleep(backoff_interval)
        if exponential_backoff and backoff_interval < maximum_backoff:
            backoff_interval *= 2
kubernetes_orchestrator
Kubernetes-native orchestrator.
KubernetesOrchestrator (BaseOrchestrator)
pydantic-model
Orchestrator for running ZenML pipelines using native Kubernetes.
Attributes:
Name | Type | Description |
---|---|---|
custom_docker_base_image_name |
Optional[str] |
Name of a Docker image that should be used as the base for the image that will be run on Kubernetes pods. If no custom image is given, a basic image of the active ZenML version will be used. Note: This image needs to have ZenML installed, otherwise the pipeline execution will fail. For that reason, you might want to extend the ZenML Docker images found here: https://hub.docker.com/r/zenmldocker/zenml/ |
kubernetes_context |
Optional[str] |
Optional name of a Kubernetes context to run pipelines in. If not set, the current active context will be used. You can find the active context by running `kubectl config current-context`. |
kubernetes_namespace |
str |
Name of the Kubernetes namespace to be used. If not provided, the `zenml` namespace will be used. |
synchronous |
bool |
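If `True`, running a pipeline using this orchestrator will block until all steps finished running on Kubernetes. |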
skip_config_loading |
bool |
If `True`, don't load the Kubernetes context and clients. This is only useful for unit testing. |
Source code in zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator.py
class KubernetesOrchestrator(BaseOrchestrator):
    """Orchestrator for running ZenML pipelines using native Kubernetes.

    Attributes:
        custom_docker_base_image_name: Name of a Docker image that should be
            used as the base for the image that will be run on Kubernetes pods.
            If no custom image is given, a basic image of the active ZenML
            version will be used.
            **Note**: This image needs to have ZenML installed,
            otherwise the pipeline execution will fail. For that reason, you
            might want to extend the ZenML Docker images found here:
            https://hub.docker.com/r/zenmldocker/zenml/
        kubernetes_context: Optional name of a Kubernetes context to run
            pipelines in. If not set, the current active context will be used.
            You can find the active context by running `kubectl config
            current-context`.
        kubernetes_namespace: Name of the Kubernetes namespace to be used.
            If not provided, the `zenml` namespace will be used.
        synchronous: If `True`, running a pipeline using this orchestrator will
            block until all steps finished running on Kubernetes.
        skip_config_loading: If `True`, don't load the Kubernetes context and
            clients. This is only useful for unit testing.
    """

    custom_docker_base_image_name: Optional[str] = None
    kubernetes_context: Optional[str] = None
    kubernetes_namespace: str = "zenml"
    synchronous: bool = False
    skip_config_loading: bool = False

    # Kubernetes API clients. They are set by `_initialize_k8s_clients` and
    # stay `None` when `skip_config_loading` is `True`.
    _k8s_core_api: k8s_client.CoreV1Api = None
    _k8s_batch_api: k8s_client.BatchV1beta1Api = None
    _k8s_rbac_api: k8s_client.RbacAuthorizationV1Api = None

    FLAVOR: ClassVar[str] = KUBERNETES_ORCHESTRATOR_FLAVOR
def __init__(self, *args: Any, **kwargs: Any) -> None:
    """Initialize the Pydantic object and the Kubernetes clients.

    Args:
        *args: The positional arguments to pass to the Pydantic object.
        **kwargs: The keyword arguments to pass to the Pydantic object.
    """
    super().__init__(*args, **kwargs)
    # The clients are created after the fields are set because
    # `_initialize_k8s_clients` reads `skip_config_loading` and
    # `kubernetes_context`.
    self._initialize_k8s_clients()
def _initialize_k8s_clients(self) -> None:
    """Load the Kubernetes config and create the API clients.

    Does nothing when `skip_config_loading` is set.
    """
    if not self.skip_config_loading:
        kube_utils.load_kube_config(context=self.kubernetes_context)
        self._k8s_core_api = k8s_client.CoreV1Api()
        self._k8s_batch_api = k8s_client.BatchV1beta1Api()
        self._k8s_rbac_api = k8s_client.RbacAuthorizationV1Api()
def get_kubernetes_contexts(self) -> Tuple[List[str], str]:
    """List the configured Kubernetes contexts and name the active one.

    Raises:
        RuntimeError: if the Kubernetes configuration cannot be loaded.

    Returns:
        context_name: List of configured Kubernetes contexts
        active_context_name: Name of the active Kubernetes context.
    """
    try:
        all_contexts, current = k8s_config.list_kube_config_contexts()
    except k8s_config.config_exception.ConfigException as config_error:
        raise RuntimeError(
            "Could not load the Kubernetes configuration"
        ) from config_error

    # Only the names of the context entries are of interest to callers.
    names = [entry["name"] for entry in all_contexts]
    return names, current["name"]
@property
def validator(self) -> Optional[StackValidator]:
    """Defines the validator that checks whether the stack is valid.

    The stack must contain a container registry. In addition, a custom
    check verifies the configured Kubernetes context (if one is set) and
    rejects local stack components and local container registries.

    Returns:
        Stack validator.
    """

    def _validate_local_requirements(stack: "Stack") -> Tuple[bool, str]:
        """Validates that the stack contains no local components.

        Args:
            stack: The stack.

        Returns:
            Whether the stack is valid or not.
            An explanation why the stack is invalid, if applicable.
        """
        container_registry = stack.container_registry

        # should not happen, because the stack validation takes care of
        # this, but just in case
        assert container_registry is not None

        if not self.skip_config_loading:
            contexts, active_context = self.get_kubernetes_contexts()
            # `kubernetes_context` is optional: when it is not set, the
            # active context is used, so there is nothing to look up or
            # compare. Without this guard an unset context would always be
            # reported as missing ("... context named 'None' ...").
            if self.kubernetes_context is None:
                pass
            elif self.kubernetes_context not in contexts:
                return False, (
                    f"Could not find a Kubernetes context named "
                    f"'{self.kubernetes_context}' in the local Kubernetes "
                    f"configuration. Please make sure that the Kubernetes "
                    f"cluster is running and that the kubeconfig file is "
                    f"configured correctly. To list all configured "
                    f"contexts, run:\n\n"
                    f"  `kubectl config get-contexts`\n"
                )
            elif self.kubernetes_context != active_context:
                logger.warning(
                    f"The Kubernetes context '{self.kubernetes_context}' "
                    f"configured for the Kubernetes orchestrator is not "
                    f"the same as the active context in the local "
                    f"Kubernetes configuration. If this is not deliberate,"
                    f" you should update the orchestrator's "
                    f"`kubernetes_context` field by running:\n\n"
                    f"  `zenml orchestrator update {self.name} "
                    f"--kubernetes_context={active_context}`\n"
                    f"To list all configured contexts, run:\n\n"
                    f"  `kubectl config get-contexts`\n"
                    f"To set the active context to be the same as the one "
                    f"configured in the Kubernetes orchestrator and "
                    f"silence this warning, run:\n\n"
                    f"  `kubectl config use-context "
                    f"{self.kubernetes_context}`\n"
                )

        # Check that all stack components are non-local.
        for stack_comp in stack.components.values():
            if stack_comp.local_path:
                return False, (
                    f"The Kubernetes orchestrator currently only supports "
                    f"remote stacks, but the '{stack_comp.name}' "
                    f"{stack_comp.TYPE.value} is a local component. "
                    f"Please make sure to only use non-local stack "
                    f"components with a Kubernetes orchestrator."
                )

        # if the orchestrator is remote, the container registry must
        # also be remote.
        if container_registry.is_local:
            return False, (
                f"The Kubernetes orchestrator requires a remote container "
                f"registry, but the '{container_registry.name}' container "
                f"registry of your active stack points to a local URI "
                f"'{container_registry.uri}'. Please make sure stacks "
                f"with a Kubernetes orchestrator always contain remote "
                f"container registries."
            )

        return True, ""

    return StackValidator(
        required_components={StackComponentType.CONTAINER_REGISTRY},
        custom_validation_function=_validate_local_requirements,
    )
def get_docker_image_name(self, pipeline_name: str) -> str:
    """Build the fully qualified Docker image name for a pipeline.

    The image is placed in the `zenml-kubernetes` repository of the active
    stack's container registry and tagged with the pipeline name.

    Args:
        pipeline_name: Name of a ZenML pipeline; used as the image tag.

    Returns:
        Docker image name including registry and tag.
    """
    registry = Repository().active_stack.container_registry
    assert registry
    # Strip trailing slashes so the joined name never contains `//`.
    registry_uri = registry.uri.rstrip("/")
    return f"{registry_uri}/zenml-kubernetes:{pipeline_name}"
def prepare_pipeline_deployment(
    self,
    pipeline: "BasePipeline",
    stack: "Stack",
    runtime_configuration: "RuntimeConfiguration",
) -> None:
    """Build the pipeline's Docker image and push it to the registry.

    The resulting image digest (or the image name, if no digest is
    available) is stored under `docker_image` in the runtime configuration.

    Args:
        pipeline: A ZenML pipeline.
        stack: A ZenML stack.
        runtime_configuration: The runtime configuration of the pipeline.
    """
    from zenml.utils import docker_utils

    image_name = self.get_docker_image_name(pipeline.name)

    # Union of the requirements of the stack and of the pipeline itself.
    requirements = set(stack.requirements()) | set(pipeline.requirements)
    logger.debug("Kubernetes container requirements: %s", requirements)

    docker_utils.build_docker_image(
        build_context_path=get_source_root_path(),
        image_name=image_name,
        dockerignore_path=pipeline.dockerignore_file,
        requirements=requirements,
        base_image=self.custom_docker_base_image_name,
    )

    registry = stack.container_registry
    assert registry  # should never happen due to validation
    registry.push_image(image_name)

    # Store the Docker image digest in the runtime configuration so it gets
    # tracked in the ZenStore; fall back to the plain name without a digest.
    digest = docker_utils.get_image_digest(image_name)
    runtime_configuration["docker_image"] = digest if digest else image_name
def prepare_or_run_pipeline(
    self,
    sorted_steps: List["BaseStep"],
    pipeline: "BasePipeline",
    pb2_pipeline: Pb2Pipeline,
    stack: "Stack",
    runtime_configuration: "RuntimeConfiguration",
) -> Any:
    """Run pipeline in Kubernetes.

    Creates an orchestrator pod in the configured namespace or, if the run
    has a schedule, a CRON job instead. All input validation happens before
    anything is created in the cluster.

    Args:
        sorted_steps: List of steps in execution order.
        pipeline: ZenML pipeline.
        pb2_pipeline: ZenML pipeline in TFX pb2 format.
        stack: ZenML stack.
        runtime_configuration: The runtime configuration of the pipeline.

    Raises:
        RuntimeError: If trying to run from a Jupyter notebook, or if the
            run is scheduled without a CRON expression.
    """
    # First check whether the code is running in a notebook.
    if Environment.in_notebook():
        raise RuntimeError(
            "The Kubernetes orchestrator cannot run pipelines in a notebook "
            "environment. The reason is that it is non-trivial to create "
            "a Docker image of a notebook. Please consider refactoring "
            "your notebook cells into separate scripts in a Python module "
            "and run the code outside of a notebook when using this "
            "orchestrator."
        )

    assert runtime_configuration.run_name, "Run name must be set"

    # Reject unsupported schedules up front, so that a run that can never
    # start does not leave a service account behind in the cluster.
    schedule = runtime_configuration.schedule
    if schedule and not schedule.cron_expression:
        raise RuntimeError(
            "The Kubernetes orchestrator only supports scheduling via "
            "CRON jobs, but the run was configured with a manual "
            "schedule. Use `Schedule(cron_expression=...)` instead."
        )

    for step in sorted_steps:
        if self.requires_resources_in_orchestration_environment(step):
            logger.warning(
                "Specifying step resources is not yet supported for "
                "the Kubernetes orchestrator, ignoring resource "
                "configuration for step %s.",
                step.name,
            )

    run_name = runtime_configuration.run_name
    pipeline_name = pipeline.name
    pod_name = kube_utils.sanitize_pod_name(run_name)

    # Get Docker image name (for all pods); prefer the digest if available.
    image_name = self.get_docker_image_name(pipeline.name)
    image_name = get_image_digest(image_name) or image_name

    # Get pipeline DAG as dict {"step": ["upstream_step_1", ...], ...}
    pipeline_dag: Dict[str, List[str]] = {
        step.name: self.get_upstream_step_names(step, pb2_pipeline)
        for step in sorted_steps
    }

    # Build entrypoint command and args for the orchestrator pod.
    # This will internally also build the command/args for all step pods.
    command = (
        KubernetesOrchestratorEntrypointConfiguration.get_entrypoint_command()
    )
    args = KubernetesOrchestratorEntrypointConfiguration.get_entrypoint_arguments(
        run_name=run_name,
        pipeline_name=pipeline_name,
        image_name=image_name,
        kubernetes_namespace=self.kubernetes_namespace,
        pb2_pipeline=pb2_pipeline,
        sorted_steps=sorted_steps,
        pipeline_dag=pipeline_dag,
    )

    # Authorize pod to run Kubernetes commands inside the cluster.
    service_account_name = "zenml-service-account"
    kube_utils.create_edit_service_account(
        core_api=self._k8s_core_api,
        rbac_api=self._k8s_rbac_api,
        service_account_name=service_account_name,
        namespace=self.kubernetes_namespace,
    )

    # Schedule as CRON job if CRON schedule is given.
    if schedule:
        cron_expression = schedule.cron_expression
        cron_job_manifest = build_cron_job_manifest(
            cron_expression=cron_expression,
            run_name=run_name,
            pod_name=pod_name,
            pipeline_name=pipeline_name,
            image_name=image_name,
            command=command,
            args=args,
            service_account_name=service_account_name,
        )
        self._k8s_batch_api.create_namespaced_cron_job(
            body=cron_job_manifest, namespace=self.kubernetes_namespace
        )
        logger.info(
            f"Scheduling Kubernetes run `{pod_name}` with CRON expression "
            f'`"{cron_expression}"`.'
        )
        return

    # Create and run the orchestrator pod.
    pod_manifest = build_pod_manifest(
        run_name=run_name,
        pod_name=pod_name,
        pipeline_name=pipeline_name,
        image_name=image_name,
        command=command,
        args=args,
        service_account_name=service_account_name,
    )
    self._k8s_core_api.create_namespaced_pod(
        namespace=self.kubernetes_namespace,
        body=pod_manifest,
    )

    # Wait for the orchestrator pod to finish and stream logs.
    if self.synchronous:
        logger.info("Waiting for Kubernetes orchestrator pod...")
        kube_utils.wait_pod(
            core_api=self._k8s_core_api,
            pod_name=pod_name,
            namespace=self.kubernetes_namespace,
            exit_condition_lambda=kube_utils.pod_is_done,
            stream_logs=True,
        )
    else:
        logger.info(
            f"Orchestration started asynchronously in pod "
            f"`{self.kubernetes_namespace}:{pod_name}`. "
            f"Run the following command to inspect the logs: "
            f"`kubectl logs {pod_name} -n {self.kubernetes_namespace}`."
        )
validator: Optional[zenml.stack.stack_validator.StackValidator]
property
readonly
Defines the validator that checks whether the stack is valid.
Returns:
Type | Description |
---|---|
Optional[zenml.stack.stack_validator.StackValidator] |
Stack validator. |
__init__(self, *args, **kwargs)
special
Initialize the Pydantic object and the Kubernetes clients.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
*args |
Any |
The positional arguments to pass to the Pydantic object. |
() |
**kwargs |
Any |
The keyword arguments to pass to the Pydantic object. |
{} |
Source code in zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator.py
def __init__(self, *args: Any, **kwargs: Any) -> None:
    """Initialize the Pydantic object and the Kubernetes clients.

    Args:
        *args: The positional arguments to pass to the Pydantic object.
        **kwargs: The keyword arguments to pass to the Pydantic object.
    """
    # Set up the Pydantic object first, then the Kubernetes API clients.
    super().__init__(*args, **kwargs)
    self._initialize_k8s_clients()
get_docker_image_name(self, pipeline_name)
Return the full Docker image name including registry and tag.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
pipeline_name |
str |
Name of a ZenML pipeline. |
required |
Returns:
Type | Description |
---|---|
str |
Docker image name. |
Source code in zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator.py
def get_docker_image_name(self, pipeline_name: str) -> str:
    """Build the fully qualified Docker image name for a pipeline.

    The image is placed in the `zenml-kubernetes` repository of the active
    stack's container registry and tagged with the pipeline name.

    Args:
        pipeline_name: Name of a ZenML pipeline; used as the image tag.

    Returns:
        Docker image name including registry and tag.
    """
    registry = Repository().active_stack.container_registry
    assert registry
    # Strip trailing slashes so the joined name never contains `//`.
    registry_uri = registry.uri.rstrip("/")
    return f"{registry_uri}/zenml-kubernetes:{pipeline_name}"
get_kubernetes_contexts(self)
Get list of configured Kubernetes contexts and the active context.
Exceptions:
Type | Description |
---|---|
RuntimeError |
if the Kubernetes configuration cannot be loaded. |
Returns:
Type | Description |
---|---|
Tuple[List[str], str] |
Names of all configured Kubernetes contexts, and the name of the active Kubernetes context. |
Source code in zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator.py
def get_kubernetes_contexts(self) -> Tuple[List[str], str]:
    """List the configured Kubernetes contexts and the active one.

    Raises:
        RuntimeError: if the Kubernetes configuration cannot be loaded.

    Returns:
        Tuple of (names of all configured Kubernetes contexts, name of the
        active Kubernetes context).
    """
    try:
        all_contexts, current = k8s_config.list_kube_config_contexts()
    except k8s_config.config_exception.ConfigException as err:
        raise RuntimeError(
            "Could not load the Kubernetes configuration"
        ) from err

    # Only the context names are of interest to callers.
    names = []
    for context in all_contexts:
        names.append(context["name"])
    return names, current["name"]
prepare_or_run_pipeline(self, sorted_steps, pipeline, pb2_pipeline, stack, runtime_configuration)
Run pipeline in Kubernetes.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
sorted_steps |
List[BaseStep] |
List of steps in execution order. |
required |
pipeline |
BasePipeline |
ZenML pipeline. |
required |
pb2_pipeline |
Pipeline |
ZenML pipeline in TFX pb2 format. |
required |
stack |
Stack |
ZenML stack. |
required |
runtime_configuration |
RuntimeConfiguration |
The runtime configuration of the pipeline. |
required |
Exceptions:
Type | Description |
---|---|
RuntimeError |
If trying to run from a Jupyter notebook. |
Source code in zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator.py
def prepare_or_run_pipeline(
    self,
    sorted_steps: List["BaseStep"],
    pipeline: "BasePipeline",
    pb2_pipeline: Pb2Pipeline,
    stack: "Stack",
    runtime_configuration: "RuntimeConfiguration",
) -> Any:
    """Run pipeline in Kubernetes.

    Creates an orchestrator pod in the configured namespace or, if the run
    has a schedule, a CRON job instead. All input validation happens before
    anything is created in the cluster.

    Args:
        sorted_steps: List of steps in execution order.
        pipeline: ZenML pipeline.
        pb2_pipeline: ZenML pipeline in TFX pb2 format.
        stack: ZenML stack.
        runtime_configuration: The runtime configuration of the pipeline.

    Raises:
        RuntimeError: If trying to run from a Jupyter notebook, or if the
            run is scheduled without a CRON expression.
    """
    # First check whether the code is running in a notebook.
    if Environment.in_notebook():
        raise RuntimeError(
            "The Kubernetes orchestrator cannot run pipelines in a notebook "
            "environment. The reason is that it is non-trivial to create "
            "a Docker image of a notebook. Please consider refactoring "
            "your notebook cells into separate scripts in a Python module "
            "and run the code outside of a notebook when using this "
            "orchestrator."
        )

    assert runtime_configuration.run_name, "Run name must be set"

    # Reject unsupported schedules up front, so that a run that can never
    # start does not leave a service account behind in the cluster.
    schedule = runtime_configuration.schedule
    if schedule and not schedule.cron_expression:
        raise RuntimeError(
            "The Kubernetes orchestrator only supports scheduling via "
            "CRON jobs, but the run was configured with a manual "
            "schedule. Use `Schedule(cron_expression=...)` instead."
        )

    for step in sorted_steps:
        if self.requires_resources_in_orchestration_environment(step):
            logger.warning(
                "Specifying step resources is not yet supported for "
                "the Kubernetes orchestrator, ignoring resource "
                "configuration for step %s.",
                step.name,
            )

    run_name = runtime_configuration.run_name
    pipeline_name = pipeline.name
    pod_name = kube_utils.sanitize_pod_name(run_name)

    # Get Docker image name (for all pods); prefer the digest if available.
    image_name = self.get_docker_image_name(pipeline.name)
    image_name = get_image_digest(image_name) or image_name

    # Get pipeline DAG as dict {"step": ["upstream_step_1", ...], ...}
    pipeline_dag: Dict[str, List[str]] = {
        step.name: self.get_upstream_step_names(step, pb2_pipeline)
        for step in sorted_steps
    }

    # Build entrypoint command and args for the orchestrator pod.
    # This will internally also build the command/args for all step pods.
    command = (
        KubernetesOrchestratorEntrypointConfiguration.get_entrypoint_command()
    )
    args = KubernetesOrchestratorEntrypointConfiguration.get_entrypoint_arguments(
        run_name=run_name,
        pipeline_name=pipeline_name,
        image_name=image_name,
        kubernetes_namespace=self.kubernetes_namespace,
        pb2_pipeline=pb2_pipeline,
        sorted_steps=sorted_steps,
        pipeline_dag=pipeline_dag,
    )

    # Authorize pod to run Kubernetes commands inside the cluster.
    service_account_name = "zenml-service-account"
    kube_utils.create_edit_service_account(
        core_api=self._k8s_core_api,
        rbac_api=self._k8s_rbac_api,
        service_account_name=service_account_name,
        namespace=self.kubernetes_namespace,
    )

    # Schedule as CRON job if CRON schedule is given.
    if schedule:
        cron_expression = schedule.cron_expression
        cron_job_manifest = build_cron_job_manifest(
            cron_expression=cron_expression,
            run_name=run_name,
            pod_name=pod_name,
            pipeline_name=pipeline_name,
            image_name=image_name,
            command=command,
            args=args,
            service_account_name=service_account_name,
        )
        self._k8s_batch_api.create_namespaced_cron_job(
            body=cron_job_manifest, namespace=self.kubernetes_namespace
        )
        logger.info(
            f"Scheduling Kubernetes run `{pod_name}` with CRON expression "
            f'`"{cron_expression}"`.'
        )
        return

    # Create and run the orchestrator pod.
    pod_manifest = build_pod_manifest(
        run_name=run_name,
        pod_name=pod_name,
        pipeline_name=pipeline_name,
        image_name=image_name,
        command=command,
        args=args,
        service_account_name=service_account_name,
    )
    self._k8s_core_api.create_namespaced_pod(
        namespace=self.kubernetes_namespace,
        body=pod_manifest,
    )

    # Wait for the orchestrator pod to finish and stream logs.
    if self.synchronous:
        logger.info("Waiting for Kubernetes orchestrator pod...")
        kube_utils.wait_pod(
            core_api=self._k8s_core_api,
            pod_name=pod_name,
            namespace=self.kubernetes_namespace,
            exit_condition_lambda=kube_utils.pod_is_done,
            stream_logs=True,
        )
    else:
        logger.info(
            f"Orchestration started asynchronously in pod "
            f"`{self.kubernetes_namespace}:{pod_name}`. "
            f"Run the following command to inspect the logs: "
            f"`kubectl logs {pod_name} -n {self.kubernetes_namespace}`."
        )
prepare_pipeline_deployment(self, pipeline, stack, runtime_configuration)
Build a Docker image and upload it to the container registry.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
pipeline |
BasePipeline |
A ZenML pipeline. |
required |
stack |
Stack |
A ZenML stack. |
required |
runtime_configuration |
RuntimeConfiguration |
The runtime configuration of the pipeline. |
required |
Source code in zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator.py
def prepare_pipeline_deployment(
    self,
    pipeline: "BasePipeline",
    stack: "Stack",
    runtime_configuration: "RuntimeConfiguration",
) -> None:
    """Build the pipeline's Docker image and push it to the registry.

    The resulting image digest (or the image name, if no digest is
    available) is stored under `docker_image` in the runtime configuration.

    Args:
        pipeline: A ZenML pipeline.
        stack: A ZenML stack.
        runtime_configuration: The runtime configuration of the pipeline.
    """
    from zenml.utils import docker_utils

    image_name = self.get_docker_image_name(pipeline.name)

    # Union of the requirements of the stack and of the pipeline itself.
    requirements = set(stack.requirements()) | set(pipeline.requirements)
    logger.debug("Kubernetes container requirements: %s", requirements)

    docker_utils.build_docker_image(
        build_context_path=get_source_root_path(),
        image_name=image_name,
        dockerignore_path=pipeline.dockerignore_file,
        requirements=requirements,
        base_image=self.custom_docker_base_image_name,
    )

    registry = stack.container_registry
    assert registry  # should never happen due to validation
    registry.push_image(image_name)

    # Store the Docker image digest in the runtime configuration so it gets
    # tracked in the ZenStore; fall back to the plain name without a digest.
    digest = docker_utils.get_image_digest(image_name)
    runtime_configuration["docker_image"] = digest if digest else image_name
kubernetes_orchestrator_entrypoint
Entrypoint of the Kubernetes master/orchestrator pod.
main()
Entrypoint of the k8s master/orchestrator pod.
Source code in zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator_entrypoint.py
def main() -> None:
    """Entrypoint of the k8s master/orchestrator pod.

    Parses the entrypoint arguments, then runs every step of the pipeline
    DAG in its own Kubernetes pod via a `ThreadedDagRunner`.
    """
    # Log to the container's stdout so it can be streamed by the client.
    logger.info("Kubernetes orchestrator pod started.")

    # Parse the CLI args and unpack the JSON-encoded pipeline config.
    parsed = parse_args()
    config = parsed.pipeline_config
    step_command = config["step_command"]
    fixed_step_args = config["fixed_step_args"]
    step_specific_args = config["step_specific_args"]
    pipeline_dag = config["pipeline_dag"]

    # Get Kubernetes Core API for creating and watching the step pods.
    kube_utils.load_kube_config()
    k8s_core = k8s_client.CoreV1Api()

    # Patch run name (only needed for CRON scheduling)
    run_name = patch_run_name_for_cron_scheduling(
        parsed.run_name, fixed_step_args
    )

    def run_step_on_kubernetes(step_name: str) -> None:
        """Run a pipeline step in a separate Kubernetes pod.

        Args:
            step_name: Name of the step.
        """
        namespace = parsed.kubernetes_namespace

        # Pod name is derived from run and step name.
        pod_name = kube_utils.sanitize_pod_name(f"{run_name}-{step_name}")

        # Shared args first, followed by the args specific to this step.
        step_args = [*fixed_step_args, *step_specific_args[step_name]]

        manifest = build_pod_manifest(
            pod_name=pod_name,
            run_name=run_name,
            pipeline_name=parsed.pipeline_name,
            image_name=parsed.image_name,
            command=step_command,
            args=step_args,
        )

        # Create the pod, then block until it is done while streaming logs.
        k8s_core.create_namespaced_pod(namespace=namespace, body=manifest)
        logger.info(f"Waiting for pod of step `{step_name}` to start...")
        kube_utils.wait_pod(
            core_api=k8s_core,
            pod_name=pod_name,
            namespace=namespace,
            exit_condition_lambda=kube_utils.pod_is_done,
            stream_logs=True,
        )
        logger.info(f"Pod of step `{step_name}` completed.")

    dag_runner = ThreadedDagRunner(
        dag=pipeline_dag, run_fn=run_step_on_kubernetes
    )
    dag_runner.run()
    logger.info("Orchestration pod completed.")
parse_args()
Parse entrypoint arguments.
Returns:
Type | Description |
---|---|
Namespace |
Parsed args. |
Source code in zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator_entrypoint.py
def parse_args() -> argparse.Namespace:
    """Parse the command line arguments of the orchestrator entrypoint.

    All options are required. `--pipeline_config` is decoded from JSON, the
    other options are kept as plain strings.

    Returns:
        Parsed args.
    """
    parser = argparse.ArgumentParser()
    string_flags = (
        "--run_name",
        "--pipeline_name",
        "--image_name",
        "--kubernetes_namespace",
    )
    for flag in string_flags:
        parser.add_argument(flag, type=str, required=True)
    parser.add_argument("--pipeline_config", type=json.loads, required=True)
    return parser.parse_args()
patch_run_name_for_cron_scheduling(run_name, fixed_step_args)
Adjust run name according to the Kubernetes orchestrator pod name.
This is required for scheduling via CRON jobs, since each job would otherwise have the same run name, which zenml does not support.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
run_name |
str |
Initial run name. |
required |
fixed_step_args |
List[str] |
Fixed entrypoint args for the step pods. We also need to patch the run name in there. |
required |
Returns:
Type | Description |
---|---|
str |
New unique run name. |
Source code in zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator_entrypoint.py
def patch_run_name_for_cron_scheduling(
    run_name: str, fixed_step_args: List[str]
) -> str:
    """Make the run name unique when running as a CRON job.

    This is required for scheduling via CRON jobs, since each job would
    otherwise have the same run name, which ZenML does not support.

    Args:
        run_name: Initial run name.
        fixed_step_args: Fixed entrypoint args for the step pods. The value
            following `--run_name` is updated in place.

    Returns:
        New unique run name, or the unchanged run name if the host name
        equals the sanitized run name.
    """
    # The host name is the name of the orchestrator pod.
    host_name = socket.gethostname()

    # Host name equal to the sanitized run name: not a CRON job, no patch.
    if host_name == kube_utils.sanitize_pod_name(run_name):
        return run_name

    # Otherwise append the last dash-separated part of the host name.
    job_id = host_name.rsplit("-", 1)[-1]
    run_name = f"{run_name}-{job_id}"

    # The step pods receive the run name through the fixed args as well.
    for position, value in enumerate(fixed_step_args):
        if value != "--run_name":
            continue
        fixed_step_args[position + 1] = run_name
    return run_name
kubernetes_orchestrator_entrypoint_configuration
Entrypoint configuration for the Kubernetes master/orchestrator pod.
KubernetesOrchestratorEntrypointConfiguration
Entrypoint configuration for the k8s master/orchestrator pod.
Source code in zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator_entrypoint_configuration.py
class KubernetesOrchestratorEntrypointConfiguration:
    """Entrypoint configuration for the k8s master/orchestrator pod."""

    @classmethod
    def get_entrypoint_options(cls) -> Set[str]:
        """Gets all the options required for running this entrypoint.

        Returns:
            Entrypoint options.
        """
        return {
            RUN_NAME_OPTION,
            PIPELINE_NAME_OPTION,
            IMAGE_NAME_OPTION,
            NAMESPACE_OPTION,
            PIPELINE_CONFIG_OPTION,
        }

    @classmethod
    def get_entrypoint_command(cls) -> List[str]:
        """Returns a command that runs the entrypoint module.

        Returns:
            Entrypoint command.
        """
        entrypoint_module = "zenml.integrations.kubernetes.orchestrators.kubernetes_orchestrator_entrypoint"
        return ["python", "-m", entrypoint_module]

    @classmethod
    def get_entrypoint_arguments(
        cls,
        run_name: str,
        pipeline_name: str,
        image_name: str,
        kubernetes_namespace: str,
        pb2_pipeline: Pb2Pipeline,
        sorted_steps: List[BaseStep],
        pipeline_dag: Dict[str, List[str]],
    ) -> List[str]:
        """Gets all arguments that the entrypoint command should be called with.

        Args:
            run_name: Name of the ZenML run.
            pipeline_name: Name of the ZenML pipeline.
            image_name: Name of the Docker image.
            kubernetes_namespace: Name of the Kubernetes namespace.
            pb2_pipeline: ZenML pipeline in TFX pb2 format.
            sorted_steps: List of steps in execution order.
            pipeline_dag: For each step, list of steps that need to run before.

        Returns:
            List of entrypoint arguments.
        """

        def _get_step_args(step: BaseStep) -> List[str]:
            """Get the entrypoint args for a specific step.

            Args:
                step: ZenML step for which to get entrypoint args.

            Returns:
                Entrypoint args of the step.
            """
            step_config = KubernetesStepEntrypointConfiguration
            return step_config.get_entrypoint_arguments(
                step=step,
                pb2_pipeline=pb2_pipeline,
                **{RUN_NAME_OPTION: run_name},
            )

        # Get name, command, and args for each step
        step_names = [step.name for step in sorted_steps]
        step_command = (
            KubernetesStepEntrypointConfiguration.get_entrypoint_command()
        )

        # The fixed args are taken from the first step and sent only once.
        if sorted_steps:
            fixed_step_args = split_step_args(
                _get_step_args(sorted_steps[0])
            )[0]
        else:
            fixed_step_args = []

        # e.g.: {"trainer": train_step_args, ...}
        step_specific_args = {}
        for step in sorted_steps:
            step_specific_args[step.name] = split_step_args(
                _get_step_args(step)
            )[1]

        # Serialize all complex datatype args into a single JSON string
        pipeline_config_json = json.dumps(
            {
                "sorted_steps": step_names,
                "step_command": step_command,
                "fixed_step_args": fixed_step_args,
                "step_specific_args": step_specific_args,
                "pipeline_dag": pipeline_dag,
            }
        )

        # Define entrypoint args as a flat list of option/value pairs.
        return [
            f"--{RUN_NAME_OPTION}",
            run_name,
            f"--{PIPELINE_NAME_OPTION}",
            pipeline_name,
            f"--{IMAGE_NAME_OPTION}",
            image_name,
            f"--{NAMESPACE_OPTION}",
            kubernetes_namespace,
            f"--{PIPELINE_CONFIG_OPTION}",
            pipeline_config_json,
        ]
get_entrypoint_arguments(run_name, pipeline_name, image_name, kubernetes_namespace, pb2_pipeline, sorted_steps, pipeline_dag)
classmethod
Gets all arguments that the entrypoint command should be called with.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
run_name |
str |
Name of the ZenML run. |
required |
pipeline_name |
str |
Name of the ZenML pipeline. |
required |
image_name |
str |
Name of the Docker image. |
required |
kubernetes_namespace |
str |
Name of the Kubernetes namespace. |
required |
pb2_pipeline |
Pipeline |
ZenML pipeline in TFX pb2 format. |
required |
sorted_steps |
List[zenml.steps.base_step.BaseStep] |
List of steps in execution order. |
required |
pipeline_dag |
Dict[str, List[str]] |
For each step, list of steps that need to run before. |
required |
Returns:
Type | Description |
---|---|
List[str] |
List of entrypoint arguments. |
Source code in zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator_entrypoint_configuration.py
@classmethod
def get_entrypoint_arguments(
    cls,
    run_name: str,
    pipeline_name: str,
    image_name: str,
    kubernetes_namespace: str,
    pb2_pipeline: Pb2Pipeline,
    sorted_steps: List[BaseStep],
    pipeline_dag: Dict[str, List[str]],
) -> List[str]:
    """Gets all arguments that the entrypoint command should be called with.

    Args:
        run_name: Name of the ZenML run.
        pipeline_name: Name of the ZenML pipeline.
        image_name: Name of the Docker image.
        kubernetes_namespace: Name of the Kubernetes namespace.
        pb2_pipeline: ZenML pipeline in TFX pb2 format.
        sorted_steps: List of steps in execution order.
        pipeline_dag: For each step, list of steps that need to run before.

    Returns:
        List of entrypoint arguments.
    """

    def _get_step_args(step: BaseStep) -> List[str]:
        """Get the entrypoint args for a specific step.

        Args:
            step: ZenML step for which to get entrypoint args.

        Returns:
            Entrypoint args of the step.
        """
        step_config = KubernetesStepEntrypointConfiguration
        return step_config.get_entrypoint_arguments(
            step=step,
            pb2_pipeline=pb2_pipeline,
            **{RUN_NAME_OPTION: run_name},
        )

    # Get name, command, and args for each step
    step_names = [step.name for step in sorted_steps]
    step_command = (
        KubernetesStepEntrypointConfiguration.get_entrypoint_command()
    )

    # The fixed args are taken from the first step and sent only once.
    if sorted_steps:
        fixed_step_args = split_step_args(
            _get_step_args(sorted_steps[0])
        )[0]
    else:
        fixed_step_args = []

    # e.g.: {"trainer": train_step_args, ...}
    step_specific_args = {}
    for step in sorted_steps:
        step_specific_args[step.name] = split_step_args(
            _get_step_args(step)
        )[1]

    # Serialize all complex datatype args into a single JSON string
    pipeline_config_json = json.dumps(
        {
            "sorted_steps": step_names,
            "step_command": step_command,
            "fixed_step_args": fixed_step_args,
            "step_specific_args": step_specific_args,
            "pipeline_dag": pipeline_dag,
        }
    )

    # Define entrypoint args as a flat list of option/value pairs.
    return [
        f"--{RUN_NAME_OPTION}",
        run_name,
        f"--{PIPELINE_NAME_OPTION}",
        pipeline_name,
        f"--{IMAGE_NAME_OPTION}",
        image_name,
        f"--{NAMESPACE_OPTION}",
        kubernetes_namespace,
        f"--{PIPELINE_CONFIG_OPTION}",
        pipeline_config_json,
    ]
get_entrypoint_command()
classmethod
Returns a command that runs the entrypoint module.
Returns:
Type | Description |
---|---|
List[str] |
Entrypoint command. |
Source code in zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator_entrypoint_configuration.py
@classmethod
def get_entrypoint_command(cls) -> List[str]:
"""Returns a command that runs the entrypoint module.
Returns:
Entrypoint command.
"""
command = [
"python",
"-m",
"zenml.integrations.kubernetes.orchestrators.kubernetes_orchestrator_entrypoint",
]
return command
get_entrypoint_options()
classmethod
Gets all the options required for running this entrypoint.
Returns:
Type | Description |
---|---|
Set[str] |
Entrypoint options. |
Source code in zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator_entrypoint_configuration.py
@classmethod
def get_entrypoint_options(cls) -> Set[str]:
    """Gets all the options required for running this entrypoint.

    Returns:
        Entrypoint options.
    """
    return {
        RUN_NAME_OPTION,
        PIPELINE_NAME_OPTION,
        IMAGE_NAME_OPTION,
        NAMESPACE_OPTION,
        PIPELINE_CONFIG_OPTION,
    }
split_step_args(step_args)
Split step args into fixed and step-specific.
We want to have them separate so we can send the fixed args to the orchestrator pod only once.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
step_args |
List[str] |
list of ALL step args. E.g. ["--arg1", "arg1_value", "--arg2", "arg2_value", ...]. |
required |
Returns:
Type | Description |
---|---|
Tuple[List[str], List[str]] |
Tuple (fixed step args, step-specific args). |
Source code in zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator_entrypoint_configuration.py
def split_step_args(step_args: List[str]) -> Tuple[List[str], List[str]]:
    """Partition step args into fixed and step-specific ones.

    We want to have them separate so we can send the fixed args to the
    orchestrator pod only once.

    Args:
        step_args: list of ALL step args.
            E.g. ["--arg1", "arg1_value", "--arg2", "arg2_value", ...].

    Returns:
        Tuple (fixed step args, step-specific args).
    """
    fixed_args: List[str] = []
    step_specific_args: List[str] = []
    for position, token in enumerate(step_args):
        # Only entries starting with `--` are options; values are picked up
        # together with the option that precedes them.
        if token.startswith("--"):
            pair = step_args[position : position + 2]  # ["--name", "Aria"]
            if token[2:] in STEP_SPECIFIC_STEP_ENTRYPOINT_OPTIONS:
                step_specific_args.extend(pair)
            else:
                fixed_args.extend(pair)
    return fixed_args, step_specific_args
kubernetes_step_entrypoint_configuration
Entrypoint configuration for the Kubernetes worker/step pods.
KubernetesStepEntrypointConfiguration (StepEntrypointConfiguration)
Entrypoint configuration for running steps on Kubernetes.
Source code in zenml/integrations/kubernetes/orchestrators/kubernetes_step_entrypoint_configuration.py
class KubernetesStepEntrypointConfiguration(StepEntrypointConfiguration):
    """Entrypoint configuration for running steps on Kubernetes."""

    @classmethod
    def get_custom_entrypoint_options(cls) -> Set[str]:
        """Kubernetes specific entrypoint options.

        The argument `RUN_NAME_OPTION` is needed for `get_run_name` to have
        consistent values between steps.

        Returns:
            Set of entrypoint options.
        """
        custom_options = {RUN_NAME_OPTION}
        return custom_options

    @classmethod
    def get_custom_entrypoint_arguments(
        cls, step: "BaseStep", *args: Any, **kwargs: Any
    ) -> List[str]:
        """Kubernetes specific entrypoint arguments.

        Sets the value for the `RUN_NAME_OPTION` argument.

        Args:
            step: ZenML step for which the entrypoint is built.
            args: additional (unused) arguments.
            kwargs: keyword args; needs to include `RUN_NAME_OPTION`.

        Returns:
            List of entrypoint arguments.
        """
        option = f"--{RUN_NAME_OPTION}"
        value = kwargs[RUN_NAME_OPTION]
        return [option, value]

    def get_run_name(self, pipeline_name: str) -> str:
        """Returns the ZenML run name.

        Args:
            pipeline_name: Name of the ZenML pipeline (unused).

        Returns:
            ZenML run name, read from the entrypoint args.
        """
        run_name: str = self.entrypoint_args[RUN_NAME_OPTION]
        return run_name
get_custom_entrypoint_arguments(step, *args, **kwargs)
classmethod
Kubernetes specific entrypoint arguments.
Sets the value for the `RUN_NAME_OPTION` argument.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
step |
BaseStep |
ZenML step for which the entrypoint is built. |
required |
args |
Any |
additional (unused) arguments. |
() |
kwargs |
Any |
keyword args; needs to include `RUN_NAME_OPTION`. |
{} |
Returns:
Type | Description |
---|---|
List[str] |
List of entrypoint arguments. |
Source code in zenml/integrations/kubernetes/orchestrators/kubernetes_step_entrypoint_configuration.py
@classmethod
def get_custom_entrypoint_arguments(
    cls, step: "BaseStep", *args: Any, **kwargs: Any
) -> List[str]:
    """Kubernetes specific entrypoint arguments.

    Sets the value for the `RUN_NAME_OPTION` argument.

    Args:
        step: ZenML step for which the entrypoint is built.
        args: additional (unused) arguments.
        kwargs: keyword args; needs to include `RUN_NAME_OPTION`.

    Returns:
        List of entrypoint arguments.
    """
    option = f"--{RUN_NAME_OPTION}"
    value = kwargs[RUN_NAME_OPTION]
    return [option, value]
get_custom_entrypoint_options()
classmethod
Kubernetes specific entrypoint options.
The argument `RUN_NAME_OPTION` is needed for `get_run_name` to have consistent values between steps.
Returns:
Type | Description |
---|---|
Set[str] |
Set of entrypoint options. |
Source code in zenml/integrations/kubernetes/orchestrators/kubernetes_step_entrypoint_configuration.py
@classmethod
def get_custom_entrypoint_options(cls) -> Set[str]:
    """Kubernetes specific entrypoint options.

    The argument `RUN_NAME_OPTION` is needed for `get_run_name` to have
    consistent values between steps.

    Returns:
        Set of entrypoint options.
    """
    custom_options = {RUN_NAME_OPTION}
    return custom_options
get_run_name(self, pipeline_name)
Returns the ZenML run name.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
pipeline_name |
str |
Name of the ZenML pipeline (unused). |
required |
Returns:
Type | Description |
---|---|
str |
ZenML run name. |
Source code in zenml/integrations/kubernetes/orchestrators/kubernetes_step_entrypoint_configuration.py
def get_run_name(self, pipeline_name: str) -> str:
    """Look up the ZenML run name from the entrypoint arguments.

    Args:
        pipeline_name: Name of the ZenML pipeline (unused).

    Returns:
        ZenML run name, i.e. the value stored under `RUN_NAME_OPTION`.
    """
    entrypoint_args = self.entrypoint_args
    run_name: str = entrypoint_args[RUN_NAME_OPTION]
    return run_name
manifest_utils
Utility functions for building manifests for k8s pods.
build_cluster_role_binding_manifest_for_service_account(name, role_name, service_account_name, namespace='default')
Build a manifest for a cluster role binding of a service account.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
name |
str |
Name of the cluster role binding. |
required |
role_name |
str |
Name of the role. |
required |
service_account_name |
str |
Name of the service account. |
required |
namespace |
str |
Kubernetes namespace. Defaults to "default". |
'default' |
Returns:
Type | Description |
---|---|
Dict[str, Any] |
Manifest for a cluster role binding of a service account. |
Source code in zenml/integrations/kubernetes/orchestrators/manifest_utils.py
def build_cluster_role_binding_manifest_for_service_account(
    name: str,
    role_name: str,
    service_account_name: str,
    namespace: str = "default",
) -> Dict[str, Any]:
    """Assemble a ClusterRoleBinding manifest for one service account.

    Args:
        name: Name of the cluster role binding.
        role_name: Name of the cluster role that gets bound.
        service_account_name: Name of the service account receiving the role.
        namespace: Kubernetes namespace of the service account.
            Defaults to "default".

    Returns:
        Manifest for a cluster role binding of a service account.
    """
    # The service account that is granted the role.
    subject = {
        "kind": "ServiceAccount",
        "name": service_account_name,
        "namespace": namespace,
    }
    # The cluster role that is being granted.
    role_reference = {
        "kind": "ClusterRole",
        "name": role_name,
        "apiGroup": "rbac.authorization.k8s.io",
    }
    manifest: Dict[str, Any] = {
        "apiVersion": "rbac.authorization.k8s.io/v1",
        "kind": "ClusterRoleBinding",
        "metadata": {"name": name},
        "subjects": [subject],
        "roleRef": role_reference,
    }
    return manifest
build_cron_job_manifest(cron_expression, pod_name, run_name, pipeline_name, image_name, command, args, service_account_name=None)
Create a manifest for launching a pod as scheduled CRON job.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
cron_expression |
str |
CRON job schedule expression, e.g. "* * * * *". |
required |
pod_name |
str |
Name of the pod. |
required |
run_name |
str |
Name of the ZenML run. |
required |
pipeline_name |
str |
Name of the ZenML pipeline. |
required |
image_name |
str |
Name of the Docker image. |
required |
command |
List[str] |
Command to execute the entrypoint in the pod. |
required |
args |
List[str] |
Arguments provided to the entrypoint command. |
required |
service_account_name |
Optional[str] |
Optional name of a service account. Can be used to assign certain roles to a pod, e.g., to allow it to run Kubernetes commands from within the cluster. |
None |
Returns:
Type | Description |
---|---|
Dict[str, Any] |
CRON job manifest. |
Source code in zenml/integrations/kubernetes/orchestrators/manifest_utils.py
def build_cron_job_manifest(
    cron_expression: str,
    pod_name: str,
    run_name: str,
    pipeline_name: str,
    image_name: str,
    command: List[str],
    args: List[str],
    service_account_name: Optional[str] = None,
) -> Dict[str, Any]:
    """Wrap a pod manifest into a scheduled CronJob manifest.

    The pod is described by `build_pod_manifest`; its metadata is reused for
    both the CronJob and the job template, and its spec becomes the job's
    pod template spec.

    Args:
        cron_expression: CRON job schedule expression, e.g. "* * * * *".
        pod_name: Name of the pod.
        run_name: Name of the ZenML run.
        pipeline_name: Name of the ZenML pipeline.
        image_name: Name of the Docker image.
        command: Command to execute the entrypoint in the pod.
        args: Arguments provided to the entrypoint command.
        service_account_name: Optional name of a service account.
            Can be used to assign certain roles to a pod, e.g., to allow it to
            run Kubernetes commands from within the cluster.

    Returns:
        CRON job manifest.
    """
    pod = build_pod_manifest(
        pod_name=pod_name,
        run_name=run_name,
        pipeline_name=pipeline_name,
        image_name=image_name,
        command=command,
        args=args,
        service_account_name=service_account_name,
    )
    shared_metadata = pod["metadata"]
    job_template = {
        "metadata": shared_metadata,
        "spec": {"template": {"spec": pod["spec"]}},
    }
    cron_job: Dict[str, Any] = {
        "apiVersion": "batch/v1beta1",
        "kind": "CronJob",
        "metadata": shared_metadata,
        "spec": {
            "schedule": cron_expression,
            "jobTemplate": job_template,
        },
    }
    return cron_job
build_mysql_deployment_manifest(name='mysql', namespace='default', port=3306, pv_claim_name='mysql-pv-claim')
Build a manifest for deploying a MySQL database.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
name |
str |
Name of the deployment. Defaults to "mysql". |
'mysql' |
namespace |
str |
Kubernetes namespace. Defaults to "default". |
'default' |
port |
int |
Port where MySQL is running. Defaults to 3306. |
3306 |
pv_claim_name |
str |
Name of the required persistent volume claim. Defaults to "mysql-pv-claim". |
'mysql-pv-claim' |
Returns:
Type | Description |
---|---|
Dict[str, Any] |
Manifest for deploying a MySQL database. |
Source code in zenml/integrations/kubernetes/orchestrators/manifest_utils.py
def build_mysql_deployment_manifest(
    name: str = "mysql",
    namespace: str = "default",
    port: int = 3306,
    pv_claim_name: str = "mysql-pv-claim",
) -> Dict[str, Any]:
    """Assemble a Deployment manifest for a single MySQL container.

    The deployment name is also used as the pod `app` label, the container
    name and the container port name.

    Args:
        name: Name of the deployment. Defaults to "mysql".
        namespace: Kubernetes namespace. Defaults to "default".
        port: Port where MySQL is running. Defaults to 3306.
        pv_claim_name: Name of the required persistent volume claim.
            Defaults to `"mysql-pv-claim"`.

    Returns:
        Manifest for deploying a MySQL database.
    """
    # Volume name shared by the container mount and the pod volume entry.
    storage_volume = "mysql-persistent-storage"

    mysql_container = {
        "image": "gcr.io/ml-pipeline/mysql:5.6",
        "name": name,
        "env": [
            {
                "name": "MYSQL_ALLOW_EMPTY_PASSWORD",
                "value": '"true"',
            }
        ],
        "ports": [{"containerPort": port, "name": name}],
        "volumeMounts": [
            {
                "name": storage_volume,
                "mountPath": "/var/lib/mysql",
            }
        ],
    }
    data_volume = {
        "name": storage_volume,
        "persistentVolumeClaim": {"claimName": pv_claim_name},
    }
    pod_template = {
        "metadata": {"labels": {"app": name}},
        "spec": {
            "containers": [mysql_container],
            "volumes": [data_volume],
        },
    }
    deployment: Dict[str, Any] = {
        "apiVersion": "apps/v1",
        "kind": "Deployment",
        "metadata": {"name": name, "namespace": namespace},
        "spec": {
            "selector": {"matchLabels": {"app": name}},
            "strategy": {"type": "Recreate"},
            "template": pod_template,
        },
    }
    return deployment
build_mysql_service_manifest(name='mysql', namespace='default', port=3306)
Build a manifest for a service relating to a deployed MySQL database.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
name |
str |
Name of the service. Defaults to "mysql". |
'mysql' |
namespace |
str |
Kubernetes namespace. Defaults to "default". |
'default' |
port |
int |
Port where MySQL is running. Defaults to 3306. |
3306 |
Returns:
Type | Description |
---|---|
Dict[str, Any] |
Manifest for the MySQL service. |
Source code in zenml/integrations/kubernetes/orchestrators/manifest_utils.py
def build_mysql_service_manifest(
    name: str = "mysql",
    namespace: str = "default",
    port: int = 3306,
    app_label: str = "mysql",
) -> Dict[str, Any]:
    """Build a manifest for a service relating to a deployed MySQL database.

    The service is headless (`clusterIP` is "None") and selects its backing
    pods through their `app` label.

    Args:
        name: Name of the service. Defaults to "mysql".
        namespace: Kubernetes namespace. Defaults to "default".
        port: Port where MySQL is running. Defaults to 3306.
        app_label: Value of the `app` label of the pods the service should
            route to. `build_mysql_deployment_manifest` labels its pods with
            the deployment name, so pass that name here whenever the
            deployment is not called "mysql". Defaults to "mysql".

    Returns:
        Manifest for the MySQL service.
    """
    return {
        "apiVersion": "v1",
        "kind": "Service",
        "metadata": {
            "name": name,
            "namespace": namespace,
        },
        "spec": {
            # Previously hard-coded to "mysql", which silently selected no
            # pods for deployments created under any other name.
            "selector": {"app": app_label},
            "clusterIP": "None",
            "ports": [{"port": port}],
        },
    }
build_namespace_manifest(namespace)
Build the manifest for a new namespace.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
namespace |
str |
Kubernetes namespace. |
required |
Returns:
Type | Description |
---|---|
Dict[str, Any] |
Manifest of the new namespace. |
Source code in zenml/integrations/kubernetes/orchestrators/manifest_utils.py
def build_namespace_manifest(namespace: str) -> Dict[str, Any]:
    """Assemble the manifest that creates a Kubernetes namespace.

    Args:
        namespace: Name of the Kubernetes namespace.

    Returns:
        Manifest of the new namespace.
    """
    metadata = {"name": namespace}
    manifest: Dict[str, Any] = {
        "apiVersion": "v1",
        "kind": "Namespace",
        "metadata": metadata,
    }
    return manifest
build_persistent_volume_claim_manifest(name, namespace='default', storage_request='10Gi')
Build a manifest for a persistent volume claim.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
name |
str |
Name of the persistent volume claim. |
required |
namespace |
str |
Kubernetes namespace. Defaults to "default". |
'default' |
storage_request |
str |
Size of the storage to request. Defaults to "10Gi". |
'10Gi' |
Returns:
Type | Description |
---|---|
Dict[str, Any] |
Manifest for a persistent volume claim. |
Source code in zenml/integrations/kubernetes/orchestrators/manifest_utils.py
def build_persistent_volume_claim_manifest(
    name: str,
    namespace: str = "default",
    storage_request: str = "10Gi",
) -> Dict[str, Any]:
    """Assemble a PersistentVolumeClaim manifest.

    The claim uses the "manual" storage class and the `ReadWriteOnce`
    access mode.

    Args:
        name: Name of the persistent volume claim.
        namespace: Kubernetes namespace. Defaults to "default".
        storage_request: Size of the storage to request. Defaults to `"10Gi"`.

    Returns:
        Manifest for a persistent volume claim.
    """
    claim_spec = {
        "storageClassName": "manual",
        "accessModes": ["ReadWriteOnce"],
        "resources": {"requests": {"storage": storage_request}},
    }
    manifest: Dict[str, Any] = {
        "apiVersion": "v1",
        "kind": "PersistentVolumeClaim",
        "metadata": {"name": name, "namespace": namespace},
        "spec": claim_spec,
    }
    return manifest
build_persistent_volume_manifest(name, namespace='default', storage_capacity='10Gi', path='/mnt/data')
Build a manifest for a persistent volume.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
name |
str |
Name of the persistent volume. |
required |
namespace |
str |
Kubernetes namespace. Defaults to "default". |
'default' |
storage_capacity |
str |
Storage capacity of the volume. Defaults to "10Gi". |
'10Gi' |
path |
str |
Path where the volume is mounted. Defaults to "/mnt/data". |
'/mnt/data' |
Returns:
Type | Description |
---|---|
Dict[str, Any] |
Manifest for a persistent volume. |
Source code in zenml/integrations/kubernetes/orchestrators/manifest_utils.py
def build_persistent_volume_manifest(
    name: str,
    namespace: str = "default",
    storage_capacity: str = "10Gi",
    path: str = "/mnt/data",
) -> Dict[str, Any]:
    """Assemble a host-path PersistentVolume manifest.

    The volume uses the "manual" storage class, the `ReadWriteOnce` access
    mode and carries the label `type: local`.

    Args:
        name: Name of the persistent volume.
        namespace: Kubernetes namespace. Defaults to "default".
        storage_capacity: Storage capacity of the volume. Defaults to `"10Gi"`.
        path: Path where the volume is mounted. Defaults to `"/mnt/data"`.

    Returns:
        Manifest for a persistent volume.
    """
    volume_metadata = {
        "name": name,
        "namespace": namespace,
        "labels": {"type": "local"},
    }
    volume_spec = {
        "storageClassName": "manual",
        "capacity": {"storage": storage_capacity},
        "accessModes": ["ReadWriteOnce"],
        "hostPath": {"path": path},
    }
    manifest: Dict[str, Any] = {
        "apiVersion": "v1",
        "kind": "PersistentVolume",
        "metadata": volume_metadata,
        "spec": volume_spec,
    }
    return manifest
build_pod_manifest(pod_name, run_name, pipeline_name, image_name, command, args, service_account_name=None)
Build a Kubernetes pod manifest for a ZenML run or step.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
pod_name |
str |
Name of the pod. |
required |
run_name |
str |
Name of the ZenML run. |
required |
pipeline_name |
str |
Name of the ZenML pipeline. |
required |
image_name |
str |
Name of the Docker image. |
required |
command |
List[str] |
Command to execute the entrypoint in the pod. |
required |
args |
List[str] |
Arguments provided to the entrypoint command. |
required |
service_account_name |
Optional[str] |
Optional name of a service account. Can be used to assign certain roles to a pod, e.g., to allow it to run Kubernetes commands from within the cluster. |
None |
Returns:
Type | Description |
---|---|
Dict[str, Any] |
Pod manifest. |
Source code in zenml/integrations/kubernetes/orchestrators/manifest_utils.py
def build_pod_manifest(
    pod_name: str,
    run_name: str,
    pipeline_name: str,
    image_name: str,
    command: List[str],
    args: List[str],
    service_account_name: Optional[str] = None,
) -> Dict[str, Any]:
    """Assemble a single-container Pod manifest for a ZenML run or step.

    The pod never restarts, is labelled with the run and pipeline names, and
    its container gets `ENV_ZENML_ENABLE_REPO_INIT_WARNINGS` set to "False".

    Args:
        pod_name: Name of the pod.
        run_name: Name of the ZenML run.
        pipeline_name: Name of the ZenML pipeline.
        image_name: Name of the Docker image.
        command: Command to execute the entrypoint in the pod.
        args: Arguments provided to the entrypoint command.
        service_account_name: Optional name of a service account.
            Can be used to assign certain roles to a pod, e.g., to allow it to
            run Kubernetes commands from within the cluster.

    Returns:
        Pod manifest.
    """
    main_container = {
        "name": "main",
        "image": image_name,
        "command": command,
        "args": args,
        "env": [
            {
                "name": ENV_ZENML_ENABLE_REPO_INIT_WARNINGS,
                "value": "False",
            }
        ],
    }
    pod_spec: Dict[str, Any] = {
        "restartPolicy": "Never",
        "containers": [main_container],
    }
    # The service account key is only emitted when one was requested.
    if service_account_name is not None:
        pod_spec["serviceAccountName"] = service_account_name

    pod_metadata = {
        "name": pod_name,
        "labels": {
            "run": run_name,
            "pipeline": pipeline_name,
        },
    }
    return {
        "apiVersion": "v1",
        "kind": "Pod",
        "metadata": pod_metadata,
        "spec": pod_spec,
    }
build_service_account_manifest(name, namespace='default')
Build the manifest for a service account.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
name |
str |
Name of the service account. |
required |
namespace |
str |
Kubernetes namespace. Defaults to "default". |
'default' |
Returns:
Type | Description |
---|---|
Dict[str, Any] |
Manifest for a service account. |
Source code in zenml/integrations/kubernetes/orchestrators/manifest_utils.py
def build_service_account_manifest(
    name: str, namespace: str = "default"
) -> Dict[str, Any]:
    """Build the manifest for a service account.

    Args:
        name: Name of the service account.
        namespace: Kubernetes namespace. Defaults to "default".

    Returns:
        Manifest for a service account.
    """
    return {
        "apiVersion": "v1",
        # `kind` was previously omitted; without it the manifest is not a
        # self-describing Kubernetes object, unlike every other manifest
        # built in this module.
        "kind": "ServiceAccount",
        "metadata": {
            "name": name,
            "namespace": namespace,
        },
    }
label_studio
special
Initialization of the Label Studio integration.
LabelStudioIntegration (Integration)
Definition of Label Studio integration for ZenML.
Source code in zenml/integrations/label_studio/__init__.py
class LabelStudioIntegration(Integration):
    """Label Studio integration definition for ZenML."""

    NAME = LABEL_STUDIO
    REQUIREMENTS = ["label-studio", "label-studio-sdk"]

    @classmethod
    def flavors(cls) -> List[FlavorWrapper]:
        """List the stack component flavors this integration provides.

        Returns:
            List of stack component flavors for this integration.
        """
        annotator_flavor = FlavorWrapper(
            name=LABEL_STUDIO_ANNOTATOR_FLAVOR,
            source="zenml.integrations.label_studio.annotators.LabelStudioAnnotator",
            type=StackComponentType.ANNOTATOR,
            integration=cls.NAME,
        )
        return [annotator_flavor]
flavors()
classmethod
Declare the stack component flavors for the Label Studio integration.
Returns:
Type | Description |
---|---|
List[zenml.zen_stores.models.flavor_wrapper.FlavorWrapper] |
List of stack component flavors for this integration. |
Source code in zenml/integrations/label_studio/__init__.py
@classmethod
def flavors(cls) -> List[FlavorWrapper]:
    """List the stack component flavors the Label Studio integration provides.

    Returns:
        List of stack component flavors for this integration.
    """
    annotator_flavor = FlavorWrapper(
        name=LABEL_STUDIO_ANNOTATOR_FLAVOR,
        source="zenml.integrations.label_studio.annotators.LabelStudioAnnotator",
        type=StackComponentType.ANNOTATOR,
        integration=cls.NAME,
    )
    return [annotator_flavor]
annotators
special
Initialization of the Label Studio annotators submodule.
label_studio_annotator
Implementation of the Label Studio annotation integration.
LabelStudioAnnotator (BaseAnnotator, AuthenticationMixin)
pydantic-model
Class to interact with the Label Studio annotation interface.
Attributes:
Name | Type | Description |
---|---|---|
port |
int |
The port to use for the annotation interface. |
api_key |
The API key to use for authentication. |
Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
class LabelStudioAnnotator(BaseAnnotator, AuthenticationMixin):
"""Class to interact with the Label Studio annotation interface.
Attributes:
port: The port to use for the annotation interface.
api_key: The API key to use for authentication.
"""
port: int = DEFAULT_LABEL_STUDIO_PORT
FLAVOR: ClassVar[str] = LABEL_STUDIO_ANNOTATOR_FLAVOR
@property
def validator(self) -> Optional["StackValidator"]:
    """Build the stack validator for this annotator.

    The validator requires a secrets manager in the stack and an artifact
    store whose flavor is Azure, GCP or S3.

    Returns:
        StackValidator: Validator for the stack.
    """
    # For now this only works on cloud artifact stores.
    cloud_store_flavors = [
        AZURE_ARTIFACT_STORE_FLAVOR,
        GCP_ARTIFACT_STORE_FLAVOR,
        S3_ARTIFACT_STORE_FLAVOR,
    ]

    def _ensure_cloud_artifact_stores(stack: Stack) -> Tuple[bool, str]:
        uses_cloud_store = stack.artifact_store.FLAVOR in cloud_store_flavors
        return (
            uses_cloud_store,
            "Only cloud artifact stores are currently supported",
        )

    return StackValidator(
        required_components={StackComponentType.SECRETS_MANAGER},
        custom_validation_function=_ensure_cloud_artifact_stores,
    )
def get_url(self) -> str:
    """Return the base URL of the locally served annotation interface.

    Returns:
        The URL of the annotation interface, built from `self.port`.
    """
    local_port = self.port
    return f"http://localhost:{local_port}"
def get_url_for_dataset(self, dataset_name: str) -> str:
    """Return the annotation interface URL of a single dataset.

    Args:
        dataset_name: The name of the dataset.

    Returns:
        The URL of the annotation interface for the dataset's project.
    """
    # NOTE(review): `get_id_from_name` returns None for an unknown dataset,
    # in which case the URL ends in `/projects/None/`.
    dataset_id = self.get_id_from_name(dataset_name)
    base_url = self.get_url()
    return f"{base_url}/projects/{dataset_id}/"
def get_id_from_name(self, dataset_name: str) -> Optional[int]:
    """Gets the ID of the given dataset.

    Args:
        dataset_name: The name of the dataset.

    Returns:
        The ID of the first dataset whose title equals `dataset_name`, or
        None if no dataset has that title.
    """
    for project in self.get_datasets():
        # Fetch the parameters once per project instead of a second time
        # for the matching project.
        params = project.get_params()
        if params["title"] == dataset_name:
            return cast(int, params["id"])
    return None
def get_datasets(self) -> List[Any]:
    """Fetch the datasets (Label Studio projects) available for annotation.

    Returns:
        A list of datasets, as returned by the client's `get_projects()`.
    """
    client = self._get_client()
    return cast(List[Any], client.get_projects())
def get_dataset_names(self) -> List[str]:
    """Collect the titles of all available datasets.

    Returns:
        A list of dataset names, in the order returned by `get_datasets`.
    """
    titles: List[str] = []
    for project in self.get_datasets():
        titles.append(project.get_params()["title"])
    return titles
def get_dataset_stats(self, dataset_name: str) -> Tuple[int, int]:
    """Gets the statistics of the given dataset.

    Args:
        dataset_name: The name of the dataset.

    Returns:
        A tuple containing (labeled_task_count, unlabeled_task_count) for
        the dataset.

    Raises:
        IndexError: If the dataset does not exist.
    """
    for project in self.get_datasets():
        # Match the title exactly, as `get_id_from_name` does. The previous
        # substring test returned the stats of any dataset whose title
        # merely contained `dataset_name`.
        if project.get_params()["title"] == dataset_name:
            labeled_task_count = len(project.get_labeled_tasks())
            unlabeled_task_count = len(project.get_unlabeled_tasks())
            return (labeled_task_count, unlabeled_task_count)
    raise IndexError(
        f"Dataset {dataset_name} not found. Please use "
        f"`zenml annotator dataset list` to list all available datasets."
    )
def launch(self, url: Optional[str] = None) -> None:
    """Launches the annotation interface in a web browser.

    Nothing is opened (a warning is logged instead) when the connection to
    the annotation server is not available.

    Args:
        url: The URL of the annotation interface. Falls back to
            `self.get_url()` when empty or None.
    """
    if not url:
        url = self.get_url()
    if self._connection_available():
        webbrowser.open(url, new=1, autoraise=True)
    else:
        # The original message was missing the space between the two
        # concatenated fragments ("interfacebecause").
        logger.warning(
            "Could not launch annotation interface "
            "because the connection could not be established."
        )
def _get_client(self) -> Client:
    """Gets Label Studio client.

    The API key is read from the `api_key` entry of the authentication
    secret's content.

    Returns:
        Label Studio client.

    Raises:
        ValueError: when unable to access the Label Studio API key.
    """
    secret = self.get_authentication_secret(ArbitrarySecretSchema)
    if not secret:
        # The original message interpolated `secret` itself, which is always
        # falsy on this path. Report the configured secret name instead.
        # NOTE(review): assumes the secret name is exposed as
        # `authentication_secret` by `AuthenticationMixin` - confirm.
        secret_name = getattr(self, "authentication_secret", None)
        raise ValueError(
            f"Unable to access predefined secret '{secret_name}' to access Label Studio API key."
        )
    api_key = secret.content["api_key"]
    return Client(url=self.get_url(), api_key=api_key)
def _connection_available(self) -> bool:
    """Report whether the annotation server can be reached.

    Any exception raised while building the client or checking the
    connection is logged as an error and treated as "not available".

    Returns:
        True if the server reports status "UP", False otherwise.
    """
    try:
        connection_info = self._get_client().check_connection()
        server_is_up: bool = connection_info.get("status") == "UP"
        return server_is_up
    # TODO: [HIGH] refactor to use a more specific exception
    except Exception:
        logger.error(
            "Connection error: No connection was able to be established to the Label Studio backend."
        )
        return False
def add_dataset(self, **kwargs: Any) -> Any:
    """Create a new Label Studio project for a dataset.

    Args:
        **kwargs: Must contain `dataset_name` (used as the project title)
            and `label_config`; other keys are ignored.

    Returns:
        A Label Studio Project object.

    Raises:
        ValueError: if 'dataset_name' and 'label_config' aren't provided.
    """
    title = kwargs.get("dataset_name")
    if not title:
        raise ValueError("`dataset_name` keyword argument is required.")

    labeling_config = kwargs.get("label_config")
    if not labeling_config:
        raise ValueError("`label_config` keyword argument is required.")

    client = self._get_client()
    return client.start_project(title=title, label_config=labeling_config)
def delete_dataset(self, **kwargs: Any) -> None:
    """Delete a dataset from the annotation interface (not yet supported).

    Args:
        **kwargs: Additional keyword arguments to pass to the Label Studio
            client.

    Raises:
        NotImplementedError: Always, as deleting a dataset is not supported.
    """
    # TODO: Awaiting a new Label Studio version to be released with this
    # method. The intended implementation is:
    #   1. read `dataset_name` from kwargs (ValueError if missing),
    #   2. resolve it with `self.get_id_from_name` (ValueError if unknown),
    #   3. call `delete_project(dataset_id)` on `self._get_client()`.
    reason = "Awaiting Label Studio release."
    raise NotImplementedError(reason)
def get_dataset(self, **kwargs: Any) -> Any:
    """Look up a dataset (Label Studio project) by name.

    Args:
        **kwargs: Must contain `dataset_name`; other keys are ignored.

    Returns:
        The LabelStudio Dataset object (a 'Project') for the given name.

    Raises:
        ValueError: If the dataset name is not provided or if the dataset
            does not exist.
    """
    # TODO: check for and raise error if client unavailable
    requested_name = kwargs.get("dataset_name")
    if not requested_name:
        raise ValueError("`dataset_name` keyword argument is required.")

    project_id = self.get_id_from_name(requested_name)
    if not project_id:
        raise ValueError(
            f"Dataset name '{requested_name}' has no corresponding `dataset_id` in Label Studio."
        )

    client = self._get_client()
    return client.get_project(project_id)
def get_converted_dataset(
    self, dataset_name: str, output_format: str
) -> Dict[Any, Any]:
    """Export a dataset's tasks in a given converted format.

    Args:
        dataset_name: Name of the dataset.
        output_format: Export type passed to the project's `export_tasks`.

    Returns:
        A dictionary containing the converted dataset.
    """
    dataset = self.get_dataset(dataset_name=dataset_name)
    converted: Dict[Any, Any] = dataset.export_tasks(export_type=output_format)
    return converted
def get_labeled_data(self, **kwargs: Any) -> Any:
    """Fetch the labeled tasks of a dataset.

    Args:
        **kwargs: Must contain `dataset_name`; other keys are ignored.

    Returns:
        The labeled data, as returned by the project's `get_labeled_tasks()`.

    Raises:
        ValueError: If the dataset name is not provided or if the dataset
            does not exist.
    """
    requested_name = kwargs.get("dataset_name")
    if not requested_name:
        raise ValueError("`dataset_name` keyword argument is required.")

    project_id = self.get_id_from_name(requested_name)
    if not project_id:
        raise ValueError(
            f"Dataset name '{requested_name}' has no corresponding `dataset_id` in Label Studio."
        )

    project = self._get_client().get_project(project_id)
    return project.get_labeled_tasks()
def get_unlabeled_data(self, **kwargs: str) -> Any:
    """Fetch the unlabeled tasks of a dataset.

    Args:
        **kwargs: Must contain `dataset_name`; other keys are ignored.

    Returns:
        The unlabeled data, as returned by the project's
        `get_unlabeled_tasks()`.

    Raises:
        ValueError: If the dataset name is not provided or if the dataset
            does not exist.
    """
    requested_name = kwargs.get("dataset_name")
    if not requested_name:
        raise ValueError("`dataset_name` keyword argument is required.")

    project_id = self.get_id_from_name(requested_name)
    if not project_id:
        raise ValueError(
            f"Dataset name '{requested_name}' has no corresponding `dataset_id` in Label Studio."
        )

    project = self._get_client().get_project(project_id)
    return project.get_unlabeled_tasks()
def register_dataset_for_annotation(
    self,
    config: LabelStudioDatasetRegistrationConfig,
) -> Any:
    """Return the project for a dataset, creating it if it does not exist.

    Args:
        config: Configuration for the dataset; its `dataset_name` and
            `label_config` are used.

    Returns:
        A Label Studio Project object.
    """
    existing_id = self.get_id_from_name(config.dataset_name)
    if not existing_id:
        # No project with that title yet, so create one.
        return self.add_dataset(
            dataset_name=config.dataset_name,
            label_config=config.label_config,
        )
    return self._get_client().get_project(existing_id)
def _get_azure_import_storage_sources(
    self, dataset_id: int
) -> List[Dict[str, Any]]:
    """List the Azure import storage sources of a dataset.

    Args:
        dataset_id: Id of the dataset.

    Returns:
        A list of Azure import storage sources.

    Raises:
        ConnectionError: If the request does not return HTTP status 200.
    """
    # TODO: check if client actually is connected etc
    response = self._get_client().make_request(
        method="GET", url=f"/api/storages/azure?project={dataset_id}"
    )
    if response.status_code != 200:
        raise ConnectionError(
            f"Unable to get list of import storage sources. Client raised HTTP error {response.status_code}."
        )
    return cast(List[Dict[str, Any]], response.json())
def _get_gcs_import_storage_sources(
    self, dataset_id: int
) -> List[Dict[str, Any]]:
    """List the Google Cloud Storage import storage sources of a dataset.

    Args:
        dataset_id: Id of the dataset.

    Returns:
        A list of Google Cloud Storage import storage sources.

    Raises:
        ConnectionError: If the request does not return HTTP status 200.
    """
    # TODO: check if client actually is connected etc
    response = self._get_client().make_request(
        method="GET", url=f"/api/storages/gcs?project={dataset_id}"
    )
    if response.status_code != 200:
        raise ConnectionError(
            f"Unable to get list of import storage sources. Client raised HTTP error {response.status_code}."
        )
    return cast(List[Dict[str, Any]], response.json())
def _get_s3_import_storage_sources(
    self, dataset_id: int
) -> List[Dict[str, Any]]:
    """List the AWS S3 import storage sources of a dataset.

    Args:
        dataset_id: Id of the dataset.

    Returns:
        A list of AWS S3 import storage sources.

    Raises:
        ConnectionError: If the request does not return HTTP status 200.
    """
    # TODO: check if client actually is connected etc
    response = self._get_client().make_request(
        method="GET", url=f"/api/storages/s3?project={dataset_id}"
    )
    if response.status_code != 200:
        raise ConnectionError(
            f"Unable to get list of import storage sources. Client raised HTTP error {response.status_code}."
        )
    return cast(List[Dict[str, Any]], response.json())
def _storage_source_already_exists(
    self, uri: str, config: LabelStudioDatasetSyncConfig, dataset: Project
) -> bool:
    """Check whether an equivalent import storage source is registered.

    A source counts as existing when its presign settings, bucket, regex
    filter, blob-URL flag, title, description and project all equal the
    values derived from `uri`, `config` and `dataset`.

    Args:
        uri: URI of the storage source.
        config: Configuration for the dataset.
        dataset: Label Studio dataset.

    Returns:
        True if the storage source already exists, False otherwise.

    Raises:
        NotImplementedError: If the storage source type is not supported.
    """
    # TODO: check we are already connected
    project_id = int(dataset.get_params()["id"])

    if config.storage_type == "azure":
        existing_sources = self._get_azure_import_storage_sources(project_id)
    elif config.storage_type == "gcs":
        existing_sources = self._get_gcs_import_storage_sources(project_id)
    elif config.storage_type == "s3":
        existing_sources = self._get_s3_import_storage_sources(project_id)
    else:
        raise NotImplementedError(
            f"Storage type '{config.storage_type}' not implemented."
        )

    for candidate in existing_sources:
        # Comparisons are evaluated lazily, in this order.
        if (
            candidate.get("presign") == config.presign
            and candidate.get("bucket") == uri
            and candidate.get("regex_filter") == config.regex_filter
            and candidate.get("use_blob_urls") == config.use_blob_urls
            and candidate.get("title") == dataset.get_params()["title"]
            and candidate.get("description") == config.description
            and candidate.get("presign_ttl") == config.presign_ttl
            and candidate.get("project") == project_id
        ):
            return True
    return False
def get_parsed_label_config(self, dataset_id: int) -> Dict[str, Any]:
    """Fetch the parsed Label Studio label config of a dataset.

    Args:
        dataset_id: Id of the dataset.

    Returns:
        A dictionary containing the parsed label config.

    Raises:
        ValueError: If no dataset is found for the given id.
    """
    # TODO: check if client actually is connected etc
    project = self._get_client().get_project(dataset_id)
    if not project:
        raise ValueError("No dataset found for the given id.")
    return cast(Dict[str, Any], project.parsed_label_config)
def connect_and_sync_external_storage(
    self,
    uri: str,
    config: LabelStudioDatasetSyncConfig,
    dataset: Project,
) -> Optional[Dict[str, Any]]:
    """Connects an external import storage to a project and syncs it.

    Depending on `config.storage_type` ("azure", "gcs" or "s3") the matching
    import storage is connected on `dataset`, using `uri` as the container
    or bucket. A warning is logged, and the connection is still attempted,
    when the credentials for the chosen storage type are incomplete.

    Args:
        uri: URI of the storage source.
        config: Configuration for the dataset.
        dataset: Label Studio dataset.

    Returns:
        A dictionary containing the sync result.

    Raises:
        ValueError: If the storage type is not supported.
    """
    # TODO: check if proposed storage source has differing / new data
    # if self._storage_source_already_exists(uri, config, dataset):
    #     return None

    # Connection settings shared by all storage types.
    storage_connection_args = {
        "prefix": config.prefix,
        "regex_filter": config.regex_filter,
        "use_blob_urls": config.use_blob_urls,
        "presign": config.presign,
        "presign_ttl": config.presign_ttl,
        "title": dataset.get_params()["title"],
        "description": config.description,
    }

    if config.storage_type == "azure":
        if not config.azure_account_name or not config.azure_account_key:
            logger.warning(
                "Authentication credentials for Azure aren't fully "
                "provided. Please update the storage synchronization "
                "settings in the Label Studio web UI as per your needs."
            )
        storage = dataset.connect_azure_import_storage(
            container=uri,
            account_name=config.azure_account_name,
            account_key=config.azure_account_key,
            **storage_connection_args,
        )
    elif config.storage_type == "gcs":
        if not config.google_application_credentials:
            logger.warning(
                "Authentication credentials for Google Cloud Storage "
                "aren't fully provided. Please update the storage "
                "synchronization settings in the Label Studio web UI as "
                "per your needs."
            )
        storage = dataset.connect_google_import_storage(
            bucket=uri,
            google_application_credentials=config.google_application_credentials,
            **storage_connection_args,
        )
    elif config.storage_type == "s3":
        if not config.aws_access_key_id or not config.aws_secret_access_key:
            # Fixed spacing: the fragments previously concatenated to
            # "provided.Please" and "in the  Label Studio".
            logger.warning(
                "Authentication credentials for S3 aren't fully provided. "
                "Please update the storage synchronization settings in the "
                "Label Studio web UI as per your needs."
            )
        storage = dataset.connect_s3_import_storage(
            bucket=uri,
            aws_access_key_id=config.aws_access_key_id,
            aws_secret_access_key=config.aws_secret_access_key,
            aws_session_token=config.aws_session_token,
            region_name=config.s3_region_name,
            s3_endpoint=config.s3_endpoint,
            **storage_connection_args,
        )
    else:
        # The accepted value is 's3'; the message used to advertise 'aws',
        # which this method rejects.
        raise ValueError(
            f"Invalid storage type. '{config.storage_type}' is not supported by ZenML's Label Studio integration. Please choose between 'azure', 'gcs' and 's3'."
        )

    synced_storage = self._get_client().sync_storage(
        storage_id=storage["id"], storage_type=storage["type"]
    )
    return cast(Dict[str, Any], synced_storage)
@property
def root_directory(self) -> str:
    """Path of the directory holding this annotator's local files.

    Returns:
        Path to the root directory: `annotators/<uuid>` inside the global
        config directory.
    """
    config_directory = io_utils.get_global_config_directory()
    component_id = str(self.uuid)
    return os.path.join(config_directory, "annotators", component_id)
@property
def _pid_file_path(self) -> str:
    """Path of the daemon PID file inside the root directory.

    Returns:
        Path to the daemon PID file.
    """
    pid_file_name = "label_studio_daemon.pid"
    return os.path.join(self.root_directory, pid_file_name)
@property
def _log_file(self) -> str:
    """Path of the daemon log file inside the root directory.

    Returns:
        Path to the daemon log file.
    """
    log_file_name = "label_studio_daemon.log"
    return os.path.join(self.root_directory, log_file_name)
@property
def is_provisioned(self) -> bool:
    """Whether the component has provisioned its local resources.

    Returns:
        True if the root directory exists, False otherwise.
    """
    root_path = self.root_directory
    return fileio.exists(root_path)
@property
def is_running(self) -> bool:
    """Whether the component is running locally.

    Returns:
        True if the component is running locally, False otherwise. On
        Windows this is always True because no daemon check is performed.
    """
    if sys.platform != "win32":
        from zenml.utils.daemon import check_if_daemon_is_running

        daemon_alive = check_if_daemon_is_running(self._pid_file_path)
        return True if daemon_alive else False

    # Daemon functionality is not supported on Windows, so the PID file
    # won't exist. The platform check above exists just for mypy to not
    # complain about missing functions.
    return True
def provision(self) -> None:
    """Creates the local root directory used by the annotation server."""
    local_root = self.root_directory
    fileio.makedirs(local_root)
def deprovision(self) -> None:
    """Removes the daemon log file, if one was written."""
    if not fileio.exists(self._log_file):
        return
    fileio.remove(self._log_file)
def resume(self) -> None:
    """Resumes the annotation interface.

    Starts the annotator daemon unless it is already running, in which
    case only an informational message is logged.
    """
    if self.is_running:
        # The previous message was copy-pasted from another component and
        # wrongly mentioned Kubeflow Pipelines.
        logger.info("Local annotation server is already running.")
        return

    self.start_annotator_daemon()
def suspend(self) -> None:
    """Suspends the annotation interface.

    Stops the annotator daemon when it is running; otherwise only logs
    that there is nothing to stop.
    """
    if self.is_running:
        self.stop_annotator_daemon()
    else:
        logger.info("Local annotation server is not running.")
def start_annotator_daemon(self) -> None:
    """Starts the annotation server backend.

    On Windows no daemon is started; a warning containing the command to
    run manually is logged instead. On other platforms the
    `label-studio start` command is run as a daemon whose PID and log
    files live in the component's root directory.

    Raises:
        ProvisioningError: If the configured port is already occupied.
    """
    command = [
        "label-studio",
        "start",
        "--no-browser",
        "--port",
        f"{self.port}",
    ]

    if sys.platform == "win32":
        # Exactly one argument for the single `%s` placeholder; passing
        # an extra argument makes the logging module fail to format the
        # record.
        logger.warning(
            "Daemon functionality not supported on Windows. "
            "In order to access the Label Studio server locally, "
            "please run '%s' in a separate command line shell.",
            " ".join(command),
        )
    elif not networking_utils.port_available(self.port):
        raise ProvisioningError(
            f"Unable to start Label Studio on local "
            f"port {self.port} because the port is occupied. In order to "
            f"access Label Studio locally, please "
            f"change the configuration to use an available "
            f"port or stop the other process currently using the port."
        )
    else:
        from zenml.utils import daemon

        def _daemon_function() -> None:
            """Runs the Label Studio server command inside the daemon."""
            subprocess.check_call(command)

        daemon.run_as_daemon(
            _daemon_function,
            pid_file=self._pid_file_path,
            log_file=self._log_file,
        )
        logger.info(
            "Started Label Studio daemon (check the daemon "
            "logs at `%s` in case you're not able to access the annotation "
            "interface). Please visit `%s/` to use the Label Studio "
            "interface.",
            self._log_file,
            self.get_url(),
        )
def stop_annotator_daemon(self) -> None:
    """Stops the annotation server backend.

    Does nothing when no PID file exists. Otherwise the daemon is stopped
    and its PID file removed (non-Windows platforms only).
    """
    if not fileio.exists(self._pid_file_path):
        return

    if sys.platform != "win32":
        from zenml.utils import daemon

        daemon.stop_daemon(self._pid_file_path)
        fileio.remove(self._pid_file_path)
    else:
        # Daemon functionality is not supported on Windows, so the PID
        # file won't exist. This branch exists just for mypy to not
        # complain about missing functions.
        pass
is_provisioned: bool
property
readonly
If the component provisioned resources to run locally.
Returns:
Type | Description |
---|---|
bool |
True if the component provisioned resources to run locally. |
is_running: bool
property
readonly
If the component is running locally.
Returns:
Type | Description |
---|---|
bool |
True if the component is running locally, False otherwise. |
root_directory: str
property
readonly
Returns path to the root directory.
Returns:
Type | Description |
---|---|
str |
Path to the root directory. |
validator: Optional[StackValidator]
property
readonly
Validates that the stack contains a cloud artifact store.
Returns:
Type | Description |
---|---|
StackValidator |
Validator for the stack. |
add_dataset(self, **kwargs)
Registers a dataset for annotation.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
**kwargs |
Any |
Additional keyword arguments to pass to the Label Studio client. |
{} |
Returns:
Type | Description |
---|---|
Any |
A Label Studio Project object. |
Exceptions:
Type | Description |
---|---|
ValueError |
If 'dataset_name' or 'label_config' isn't provided. |
Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def add_dataset(self, **kwargs: Any) -> Any:
    """Registers a dataset for annotation.

    Args:
        **kwargs: Keyword arguments; `dataset_name` and `label_config`
            are read from them and are both required.

    Returns:
        A Label Studio Project object.

    Raises:
        ValueError: If `dataset_name` or `label_config` is missing or
            empty.
    """
    name = kwargs.get("dataset_name")
    config = kwargs.get("label_config")

    if not name:
        raise ValueError("`dataset_name` keyword argument is required.")
    if not config:
        raise ValueError("`label_config` keyword argument is required.")

    client = self._get_client()
    return client.start_project(title=name, label_config=config)
connect_and_sync_external_storage(self, uri, config, dataset)
Syncs the external storage for the given project.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
uri |
str |
URI of the storage source. |
required |
config |
LabelStudioDatasetSyncConfig |
Configuration for the dataset. |
required |
dataset |
Project |
Label Studio dataset. |
required |
Returns:
Type | Description |
---|---|
Optional[Dict[str, Any]] |
A dictionary containing the sync result. |
Exceptions:
Type | Description |
---|---|
ValueError |
If the storage type is not supported. |
Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def connect_and_sync_external_storage(
    self,
    uri: str,
    config: LabelStudioDatasetSyncConfig,
    dataset: Project,
) -> Optional[Dict[str, Any]]:
    """Syncs the external storage for the given project.

    Connects `uri` to the dataset as an import storage of the type named
    by `config.storage_type` ('azure', 'gcs' or 's3') and then triggers a
    sync of that storage. Missing credentials only produce a warning; the
    connection is attempted regardless.

    Args:
        uri: URI of the storage source.
        config: Configuration for the dataset.
        dataset: Label Studio dataset.

    Returns:
        A dictionary containing the sync result.

    Raises:
        ValueError: If the storage type is not supported.
    """
    # TODO: check if proposed storage source has differing / new data
    # if self._storage_source_already_exists(uri, config, dataset):
    #     return None

    # Options shared by every storage backend.
    storage_connection_args = {
        "prefix": config.prefix,
        "regex_filter": config.regex_filter,
        "use_blob_urls": config.use_blob_urls,
        "presign": config.presign,
        "presign_ttl": config.presign_ttl,
        "title": dataset.get_params()["title"],
        "description": config.description,
    }

    if config.storage_type == "azure":
        if not config.azure_account_name or not config.azure_account_key:
            logger.warning(
                "Authentication credentials for Azure aren't fully "
                "provided. Please update the storage synchronization "
                "settings in the Label Studio web UI as per your needs."
            )
        storage = dataset.connect_azure_import_storage(
            container=uri,
            account_name=config.azure_account_name,
            account_key=config.azure_account_key,
            **storage_connection_args,
        )
    elif config.storage_type == "gcs":
        if not config.google_application_credentials:
            logger.warning(
                "Authentication credentials for Google Cloud Storage "
                "aren't fully provided. Please update the storage "
                "synchronization settings in the Label Studio web UI as "
                "per your needs."
            )
        storage = dataset.connect_google_import_storage(
            bucket=uri,
            google_application_credentials=config.google_application_credentials,
            **storage_connection_args,
        )
    elif config.storage_type == "s3":
        if not config.aws_access_key_id or not config.aws_secret_access_key:
            logger.warning(
                "Authentication credentials for S3 aren't fully provided. "
                "Please update the storage synchronization settings in the "
                "Label Studio web UI as per your needs."
            )
        storage = dataset.connect_s3_import_storage(
            bucket=uri,
            aws_access_key_id=config.aws_access_key_id,
            aws_secret_access_key=config.aws_secret_access_key,
            aws_session_token=config.aws_session_token,
            region_name=config.s3_region_name,
            s3_endpoint=config.s3_endpoint,
            **storage_connection_args,
        )
    else:
        # The accepted value for Amazon storage is 's3', so the message
        # must name 's3' rather than 'aws'.
        raise ValueError(
            f"Invalid storage type. '{config.storage_type}' is not supported by ZenML's Label Studio integration. Please choose between 'azure', 'gcs' and 's3'."
        )

    synced_storage = self._get_client().sync_storage(
        storage_id=storage["id"], storage_type=storage["type"]
    )
    return cast(Dict[str, Any], synced_storage)
delete_dataset(self, **kwargs)
Deletes a dataset from the annotation interface.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
**kwargs |
Any |
Additional keyword arguments to pass to the Label Studio client. |
{} |
Exceptions:
Type | Description |
---|---|
NotImplementedError |
If the deletion of a dataset is not supported. |
Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def delete_dataset(self, **kwargs: Any) -> None:
    """Deletes a dataset from the annotation interface.

    Deletion is not available yet, so this method always raises.

    Args:
        **kwargs: Additional keyword arguments to pass to the Label Studio
            client.

    Raises:
        NotImplementedError: Always, until dataset deletion is supported.
    """
    # TODO: Awaiting a new Label Studio version to be released with this
    # method. Intended implementation:
    # ls = self._get_client()
    # dataset_name = kwargs.get("dataset_name")
    # if not dataset_name:
    #     raise ValueError("`dataset_name` keyword argument is required.")
    # dataset_id = self.get_id_from_name(dataset_name)
    # if not dataset_id:
    #     raise ValueError(
    #         f"Dataset name '{dataset_name}' has no corresponding `dataset_id` in Label Studio."
    #     )
    # ls.delete_project(dataset_id)
    raise NotImplementedError("Awaiting Label Studio release.")
deprovision(self)
Spins down the annotation server backend.
Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def deprovision(self) -> None:
    """Removes the daemon log file, if one was written."""
    if not fileio.exists(self._log_file):
        return
    fileio.remove(self._log_file)
get_converted_dataset(self, dataset_name, output_format)
Extract annotated tasks in a specific converted format.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
dataset_name |
str |
Name of the dataset. |
required |
output_format |
str |
Output format. |
required |
Returns:
Type | Description |
---|---|
Dict[Any, Any] |
A dictionary containing the converted dataset. |
Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def get_converted_dataset(
    self, dataset_name: str, output_format: str
) -> Dict[Any, Any]:
    """Extract annotated tasks in a specific converted format.

    Args:
        dataset_name: Name of the dataset.
        output_format: Output format, passed to the project's task export
            as its export type.

    Returns:
        A dictionary containing the converted dataset.
    """
    dataset = self.get_dataset(dataset_name=dataset_name)
    converted = dataset.export_tasks(export_type=output_format)
    return converted  # type: ignore[no-any-return]
get_dataset(self, **kwargs)
Gets the dataset with the given name.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
**kwargs |
Any |
Additional keyword arguments to pass to the Label Studio client. |
{} |
Returns:
Type | Description |
---|---|
Any |
The LabelStudio Dataset object (a 'Project') for the given name. |
Exceptions:
Type | Description |
---|---|
ValueError |
If the dataset name is not provided or if the dataset does not exist. |
Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def get_dataset(self, **kwargs: Any) -> Any:
    """Gets the dataset with the given name.

    Args:
        **kwargs: Keyword arguments; `dataset_name` is read from them and
            is required.

    Returns:
        The LabelStudio Dataset object (a 'Project') for the given name.

    Raises:
        ValueError: If the dataset name is not provided or if the dataset
            does not exist.
    """
    # TODO: check for and raise error if client unavailable
    name = kwargs.get("dataset_name")
    if not name:
        raise ValueError("`dataset_name` keyword argument is required.")

    project_id = self.get_id_from_name(name)
    if not project_id:
        raise ValueError(
            f"Dataset name '{name}' has no corresponding `dataset_id` in Label Studio."
        )

    client = self._get_client()
    return client.get_project(project_id)
get_dataset_names(self)
Gets the names of the datasets.
Returns:
Type | Description |
---|---|
List[str] |
A list of dataset names. |
Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def get_dataset_names(self) -> List[str]:
    """Gets the names of the datasets.

    Returns:
        A list with the title of every available dataset.
    """
    names = []
    for project in self.get_datasets():
        names.append(project.get_params()["title"])
    return names
get_dataset_stats(self, dataset_name)
Gets the statistics of the given dataset.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
dataset_name |
str |
The name of the dataset. |
required |
Returns:
Type | Description |
---|---|
Tuple[int, int] |
A tuple containing (labeled_task_count, unlabeled_task_count) for the dataset. |
Exceptions:
Type | Description |
---|---|
IndexError |
If the dataset does not exist. |
Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def get_dataset_stats(self, dataset_name: str) -> Tuple[int, int]:
    """Gets the statistics of the given dataset.

    A dataset whose title equals `dataset_name` is preferred. If there is
    none, the first dataset whose title contains `dataset_name` is used,
    which keeps the previous substring-matching behavior working.

    Args:
        dataset_name: The name of the dataset.

    Returns:
        A tuple containing (labeled_task_count, unlabeled_task_count) for
            the dataset.

    Raises:
        IndexError: If the dataset does not exist.
    """
    titled_projects = [
        (project.get_params()["title"], project)
        for project in self.get_datasets()
    ]

    # Exact match first, so "cats" never resolves to "cats-v2" just
    # because that project happens to be listed earlier.
    match = next(
        (project for title, project in titled_projects if title == dataset_name),
        None,
    )
    if match is None:
        match = next(
            (
                project
                for title, project in titled_projects
                if dataset_name in title
            ),
            None,
        )
    if match is None:
        raise IndexError(
            f"Dataset {dataset_name} not found. Please use "
            f"`zenml annotator dataset list` to list all available datasets."
        )

    labeled_task_count = len(match.get_labeled_tasks())
    unlabeled_task_count = len(match.get_unlabeled_tasks())
    return (labeled_task_count, unlabeled_task_count)
get_datasets(self)
Gets the datasets currently available for annotation.
Returns:
Type | Description |
---|---|
List[Any] |
A list of datasets. |
Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def get_datasets(self) -> List[Any]:
    """Gets the datasets currently available for annotation.

    Returns:
        A list of datasets, as returned by the client's project listing.
    """
    return cast(List[Any], self._get_client().get_projects())
get_id_from_name(self, dataset_name)
Gets the ID of the given dataset.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
dataset_name |
str |
The name of the dataset. |
required |
Returns:
Type | Description |
---|---|
Optional[int] |
The ID of the dataset. |
Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def get_id_from_name(self, dataset_name: str) -> Optional[int]:
    """Gets the ID of the given dataset.

    Args:
        dataset_name: The name of the dataset.

    Returns:
        The ID of the first dataset whose title equals `dataset_name`, or
        None if no dataset has that title.
    """
    for project in self.get_datasets():
        # Fetch the parameters once per project instead of once for the
        # title check and again for the id.
        params = project.get_params()
        if params["title"] == dataset_name:
            return cast(int, params["id"])
    return None
get_labeled_data(self, **kwargs)
Gets the labeled data for the given dataset.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
**kwargs |
Any |
Additional keyword arguments to pass to the Label Studio client. |
{} |
Returns:
Type | Description |
---|---|
Any |
The labeled data. |
Exceptions:
Type | Description |
---|---|
ValueError |
If the dataset name is not provided or if the dataset does not exist. |
Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def get_labeled_data(self, **kwargs: Any) -> Any:
    """Gets the labeled data for the given dataset.

    Args:
        **kwargs: Keyword arguments; `dataset_name` is read from them and
            is required.

    Returns:
        The labeled data.

    Raises:
        ValueError: If the dataset name is not provided or if the dataset
            does not exist.
    """
    name = kwargs.get("dataset_name")
    if not name:
        raise ValueError("`dataset_name` keyword argument is required.")

    project_id = self.get_id_from_name(name)
    if not project_id:
        raise ValueError(
            f"Dataset name '{name}' has no corresponding `dataset_id` in Label Studio."
        )

    project = self._get_client().get_project(project_id)
    return project.get_labeled_tasks()
get_parsed_label_config(self, dataset_id)
Returns the parsed Label Studio label config for a dataset.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
dataset_id |
int |
Id of the dataset. |
required |
Returns:
Type | Description |
---|---|
Dict[str, Any] |
A dictionary containing the parsed label config. |
Exceptions:
Type | Description |
---|---|
ValueError |
If no dataset is found for the given id. |
Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def get_parsed_label_config(self, dataset_id: int) -> Dict[str, Any]:
    """Returns the parsed Label Studio label config for a dataset.

    Args:
        dataset_id: Id of the dataset.

    Returns:
        A dictionary containing the parsed label config.

    Raises:
        ValueError: If no dataset is found for the given id.
    """
    # TODO: check if client actually is connected etc
    project = self._get_client().get_project(dataset_id)
    if not project:
        raise ValueError("No dataset found for the given id.")
    return cast(Dict[str, Any], project.parsed_label_config)
get_unlabeled_data(self, **kwargs)
Gets the unlabeled data for the given dataset.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
**kwargs |
str |
Additional keyword arguments to pass to the Label Studio client. |
{} |
Returns:
Type | Description |
---|---|
Any |
The unlabeled data. |
Exceptions:
Type | Description |
---|---|
ValueError |
If the dataset name is not provided. |
Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def get_unlabeled_data(self, **kwargs: str) -> Any:
    """Gets the unlabeled data for the given dataset.

    Args:
        **kwargs: Keyword arguments; `dataset_name` is read from them and
            is required.

    Returns:
        The unlabeled data.

    Raises:
        ValueError: If the dataset name is not provided or if the dataset
            does not exist.
    """
    name = kwargs.get("dataset_name")
    if not name:
        raise ValueError("`dataset_name` keyword argument is required.")

    project_id = self.get_id_from_name(name)
    if not project_id:
        raise ValueError(
            f"Dataset name '{name}' has no corresponding `dataset_id` in Label Studio."
        )

    project = self._get_client().get_project(project_id)
    return project.get_unlabeled_tasks()
get_url(self)
Gets the top-level URL of the annotation interface.
Returns:
Type | Description |
---|---|
str |
The URL of the annotation interface. |
Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def get_url(self) -> str:
    """Gets the top-level URL of the annotation interface.

    Returns:
        The localhost URL at the component's configured port.
    """
    port = self.port
    return f"http://localhost:{port}"
get_url_for_dataset(self, dataset_name)
Gets the URL of the annotation interface for the given dataset.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
dataset_name |
str |
The name of the dataset. |
required |
Returns:
Type | Description |
---|---|
str |
The URL of the annotation interface. |
Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def get_url_for_dataset(self, dataset_name: str) -> str:
    """Gets the URL of the annotation interface for the given dataset.

    Args:
        dataset_name: The name of the dataset.

    Returns:
        The project page URL below the top-level annotation URL.
    """
    base_url = self.get_url
    dataset_id = self.get_id_from_name(dataset_name)
    return f"{base_url()}/projects/{dataset_id}/"
launch(self, url)
Launches the annotation interface.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
url |
Optional[str] |
The URL of the annotation interface. |
required |
Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def launch(self, url: Optional[str]) -> None:
    """Launches the annotation interface.

    Opens the URL in a web browser when a connection to the annotation
    server is available; otherwise only logs a warning.

    Args:
        url: The URL of the annotation interface. Falls back to the
            top-level annotation URL when empty or None.
    """
    if not url:
        url = self.get_url()

    if self._connection_available():
        webbrowser.open(url, new=1, autoraise=True)
    else:
        # A trailing space is needed between the two concatenated
        # literals, otherwise the message reads "interfacebecause".
        logger.warning(
            "Could not launch annotation interface "
            "because the connection could not be established."
        )
provision(self)
Spins up the annotation server backend.
Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def provision(self) -> None:
    """Creates the local root directory used by the annotation server."""
    local_root = self.root_directory
    fileio.makedirs(local_root)
register_dataset_for_annotation(self, config)
Registers a dataset for annotation.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
config |
LabelStudioDatasetRegistrationConfig |
Configuration for the dataset. |
required |
Returns:
Type | Description |
---|---|
Any |
A Label Studio Project object. |
Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def register_dataset_for_annotation(
    self,
    config: LabelStudioDatasetRegistrationConfig,
) -> Any:
    """Registers a dataset for annotation.

    Reuses the existing project when one with the configured name is
    found; otherwise a new one is created from the configured name and
    label config.

    Args:
        config: Configuration for the dataset.

    Returns:
        A Label Studio Project object.
    """
    existing_id = self.get_id_from_name(config.dataset_name)
    if not existing_id:
        return self.add_dataset(
            dataset_name=config.dataset_name,
            label_config=config.label_config,
        )
    return self._get_client().get_project(existing_id)
resume(self)
Resumes the annotation interface.
Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def resume(self) -> None:
    """Resumes the annotation interface.

    Starts the annotator daemon unless it is already running, in which
    case only an informational message is logged.
    """
    if self.is_running:
        # The previous message was copy-pasted from another component and
        # wrongly mentioned Kubeflow Pipelines.
        logger.info("Local annotation server is already running.")
        return

    self.start_annotator_daemon()
start_annotator_daemon(self)
Starts the annotation server backend.
Exceptions:
Type | Description |
---|---|
ProvisioningError |
If the configured port is already occupied. |
Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def start_annotator_daemon(self) -> None:
    """Starts the annotation server backend.

    On Windows no daemon is started; a warning containing the command to
    run manually is logged instead. On other platforms the
    `label-studio start` command is run as a daemon whose PID and log
    files live in the component's root directory.

    Raises:
        ProvisioningError: If the configured port is already occupied.
    """
    command = [
        "label-studio",
        "start",
        "--no-browser",
        "--port",
        f"{self.port}",
    ]

    if sys.platform == "win32":
        # Exactly one argument for the single `%s` placeholder; passing
        # an extra argument makes the logging module fail to format the
        # record.
        logger.warning(
            "Daemon functionality not supported on Windows. "
            "In order to access the Label Studio server locally, "
            "please run '%s' in a separate command line shell.",
            " ".join(command),
        )
    elif not networking_utils.port_available(self.port):
        raise ProvisioningError(
            f"Unable to start Label Studio on local "
            f"port {self.port} because the port is occupied. In order to "
            f"access Label Studio locally, please "
            f"change the configuration to use an available "
            f"port or stop the other process currently using the port."
        )
    else:
        from zenml.utils import daemon

        def _daemon_function() -> None:
            """Runs the Label Studio server command inside the daemon."""
            subprocess.check_call(command)

        daemon.run_as_daemon(
            _daemon_function,
            pid_file=self._pid_file_path,
            log_file=self._log_file,
        )
        logger.info(
            "Started Label Studio daemon (check the daemon "
            "logs at `%s` in case you're not able to access the annotation "
            "interface). Please visit `%s/` to use the Label Studio "
            "interface.",
            self._log_file,
            self.get_url(),
        )
stop_annotator_daemon(self)
Stops the annotation server backend.
Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def stop_annotator_daemon(self) -> None:
    """Stops the annotation server backend.

    Does nothing when no PID file exists. Otherwise the daemon is stopped
    and its PID file removed (non-Windows platforms only).
    """
    if not fileio.exists(self._pid_file_path):
        return

    if sys.platform != "win32":
        from zenml.utils import daemon

        daemon.stop_daemon(self._pid_file_path)
        fileio.remove(self._pid_file_path)
    else:
        # Daemon functionality is not supported on Windows, so the PID
        # file won't exist. This branch exists just for mypy to not
        # complain about missing functions.
        pass
suspend(self)
Suspends the annotation interface.
Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def suspend(self) -> None:
    """Suspends the annotation interface.

    Stops the annotator daemon when it is running; otherwise only logs
    that there is nothing to stop.
    """
    if self.is_running:
        self.stop_annotator_daemon()
    else:
        logger.info("Local annotation server is not running.")
label_config_generators
special
Initialization of the Label Studio config generators submodule.
label_config_generators
Implementation of label config generators for Label Studio.
generate_basic_object_detection_bounding_boxes_label_config(labels)
Generates a Label Studio config for object detection with bounding boxes.
This is based on the basic config example shown at https://labelstud.io/templates/image_bbox.html.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
labels |
List[str] |
A list of labels to be used in the label config. |
required |
Returns:
Type | Description |
---|---|
Tuple[str, str] |
A tuple of the generated label config and the label config type. |
Exceptions:
Type | Description |
---|---|
ValueError |
If no labels are provided. |
Source code in zenml/integrations/label_studio/label_config_generators/label_config_generators.py
def generate_basic_object_detection_bounding_boxes_label_config(
    labels: List[str],
) -> Tuple[str, str]:
    """Generates a Label Studio config for object detection with bounding boxes.

    This is based on the basic config example shown at
    https://labelstud.io/templates/image_bbox.html.

    Label text is XML-escaped before being placed in the single-quoted
    `value` attribute, so labels containing `'`, `&`, `<` or `>` still
    yield a well-formed config.

    Args:
        labels: A list of labels to be used in the label config.

    Returns:
        A tuple of the generated label config and the label config type.

    Raises:
        ValueError: If no labels are provided.
    """
    # Imported locally so the block does not depend on a file-level import.
    from xml.sax.saxutils import escape

    if not labels:
        raise ValueError("No labels provided")

    label_config_type = AnnotationTasks.OBJECT_DETECTION_BOUNDING_BOXES

    label_config_start = """<View>
<Image name="image" value="$image"/>
<RectangleLabels name="label" toName="image">
"""
    # The attribute is single-quoted, so apostrophes must be escaped in
    # addition to the default `&`, `<` and `>`.
    label_config_choices = "".join(
        "<Label value='{}' />\n".format(escape(str(label), {"'": "&apos;"}))
        for label in labels
    )
    label_config_end = "</RectangleLabels>\n</View>"

    label_config = label_config_start + label_config_choices + label_config_end
    return (
        label_config,
        label_config_type,
    )
generate_image_classification_label_config(labels)
Generates a Label Studio label config for image classification.
This is based on the basic config example shown at https://labelstud.io/templates/image_classification.html.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
labels |
List[str] |
A list of labels to be used in the label config. |
required |
Returns:
Type | Description |
---|---|
Tuple[str, str] |
A tuple of the generated label config and the label config type. |
Exceptions:
Type | Description |
---|---|
ValueError |
If no labels are provided. |
Source code in zenml/integrations/label_studio/label_config_generators/label_config_generators.py
def generate_image_classification_label_config(
    labels: List[str],
) -> Tuple[str, str]:
    """Generates a Label Studio label config for image classification.

    This is based on the basic config example shown at
    https://labelstud.io/templates/image_classification.html.

    Label text is XML-escaped before being placed in the single-quoted
    `value` attribute, so labels containing `'`, `&`, `<` or `>` still
    yield a well-formed config.

    Args:
        labels: A list of labels to be used in the label config.

    Returns:
        A tuple of the generated label config and the label config type.

    Raises:
        ValueError: If no labels are provided.
    """
    # Imported locally so the block does not depend on a file-level import.
    from xml.sax.saxutils import escape

    if not labels:
        raise ValueError("No labels provided")

    label_config_type = AnnotationTasks.IMAGE_CLASSIFICATION

    label_config_start = """<View>
<Image name="image" value="$image"/>
<Choices name="choice" toName="image">
"""
    # The attribute is single-quoted, so apostrophes must be escaped in
    # addition to the default `&`, `<` and `>`.
    label_config_choices = "".join(
        "<Choice value='{}' />\n".format(escape(str(label), {"'": "&apos;"}))
        for label in labels
    )
    label_config_end = "</Choices>\n</View>"

    label_config = label_config_start + label_config_choices + label_config_end
    return (
        label_config,
        label_config_type,
    )
label_studio_utils
Utility functions for the Label Studio annotator integration.
convert_pred_filenames_to_task_ids(preds, tasks, filename_reference, storage_type)
Converts a list of predictions from local file references to task id.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
preds |
List[Dict[str, Any]] |
List of predictions. |
required |
tasks |
List[Dict[str, Any]] |
List of tasks. |
required |
filename_reference |
str |
Name of the file reference in the predictions. |
required |
storage_type |
str |
Storage type of the predictions. |
required |
Returns:
Type | Description |
---|---|
List[Dict[str, Any]] |
List of predictions using task ids as reference. |
Source code in zenml/integrations/label_studio/label_studio_utils.py
def convert_pred_filenames_to_task_ids(
    preds: List[Dict[str, Any]],
    tasks: List[Dict[str, Any]],
    filename_reference: str,
    storage_type: str,
) -> List[Dict[str, Any]]:
    """Converts a list of predictions from local file references to task id.

    Tasks and predictions are matched on the base name of their file
    reference.

    Args:
        preds: List of predictions, each with a `filename` and a `result`.
        tasks: List of tasks, each with an `id` and a `data` mapping.
        filename_reference: Key inside each task's `data` that holds the
            file URL.
        storage_type: Storage type of the predictions.

    Returns:
        List of predictions using task ids as reference.
    """
    # Lookup table: base name of the task's file URL -> task id.
    task_id_by_filename = {}
    for task in tasks:
        file_url = task["data"][filename_reference]
        base_name = os.path.basename(urlparse(file_url).path)
        task_id_by_filename[base_name] = task["id"]

    # GCS and S3 URL encodes filenames containing spaces, requiring this
    # separate encoding step
    if storage_type in {"gcs", "s3"}:
        encoded_preds = []
        for pred in preds:
            encoded_preds.append(
                {"filename": quote(pred["filename"]), "result": pred["result"]}
            )
        preds = encoded_preds

    converted = []
    for pred in preds:
        task_id = int(task_id_by_filename[os.path.basename(pred["filename"])])
        converted.append({"task": task_id, "result": pred["result"]})
    return converted
get_file_extension(path_str)
Return the file extension of the given filename.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
path_str |
str |
Path to the file. |
required |
Returns:
Type | Description |
---|---|
str |
File extension. |
Source code in zenml/integrations/label_studio/label_studio_utils.py
def get_file_extension(path_str: str) -> str:
    """Return the file extension of the given filename.

    Only the path component of `path_str` is inspected, so a URL query
    string or fragment does not end up in the extension.

    Args:
        path_str: Path to the file.

    Returns:
        File extension, including the leading dot (empty if there is none).
    """
    url_path = urlparse(path_str).path
    _, extension = os.path.splitext(url_path)
    return extension
is_azure_url(url)
Return whether the given URL is an Azure URL.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
url |
str |
URL to check. |
required |
Returns:
Type | Description |
---|---|
bool |
True if the URL is an Azure URL, False otherwise. |
Source code in zenml/integrations/label_studio/label_studio_utils.py
def is_azure_url(url: str) -> bool:
    """Return whether the given URL is an Azure URL.

    Args:
        url: URL to check.

    Returns:
        True if the URL's network location contains the Azure blob host
        name, False otherwise.
    """
    network_location = urlparse(url).netloc
    return "blob.core.windows.net" in network_location
is_gcs_url(url)
Return whether the given URL is a GCS URL.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
url |
str |
URL to check. |
required |
Returns:
Type | Description |
---|---|
bool |
True if the URL is a GCS URL, False otherwise. |
Source code in zenml/integrations/label_studio/label_studio_utils.py
def is_gcs_url(url: str) -> bool:
    """Return whether the given URL is a GCS URL.

    Args:
        url: URL to check.

    Returns:
        True if the URL's network location contains the Google Cloud
        Storage host name, False otherwise.
    """
    network_location = urlparse(url).netloc
    return "storage.googleapis.com" in network_location
is_s3_url(url)
Return whether the given URL is an S3 URL.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
url |
str |
URL to check. |
required |
Returns:
Type | Description |
---|---|
bool |
True if the URL is an S3 URL, False otherwise. |
Source code in zenml/integrations/label_studio/label_studio_utils.py
def is_s3_url(url: str) -> bool:
    """Return whether the given URL is an S3 URL.

    Args:
        url: URL to check.

    Returns:
        True if the URL's network location contains the Amazon S3 host
        name, False otherwise.
    """
    network_location = urlparse(url).netloc
    return "s3.amazonaws" in network_location
steps
special
Standard steps to be used with the Label Studio annotator integration.
label_studio_standard_steps
Implementation of standard steps for the Label Studio annotator integration.
LabelStudioDatasetRegistrationConfig (BaseStepConfig)
pydantic-model
Step config when registering a dataset with Label Studio.
Attributes:
Name | Type | Description |
---|---|---|
label_config |
str |
The label config to use for the annotation interface. |
dataset_name |
str |
Name of the dataset to register. |
Source code in zenml/integrations/label_studio/steps/label_studio_standard_steps.py
class LabelStudioDatasetRegistrationConfig(BaseStepConfig):
    """Step config when registering a dataset with Label Studio.

    Both fields are declared without a default value.

    Attributes:
        label_config: The label config to use for the annotation interface.
        dataset_name: Name of the dataset to register.
    """

    # Label config for the annotation interface, held as a string.
    label_config: str
    # Name under which the dataset is registered.
    dataset_name: str
LabelStudioDatasetSyncConfig (BaseStepConfig)
pydantic-model
Step config when syncing data to Label Studio.
Attributes:
Name | Type | Description |
---|---|---|
storage_type |
str |
The type of storage to sync to. |
label_config_type |
str |
The type of label config to use. |
prefix |
Optional[str] |
Specify the prefix within the cloud store to import your data from. |
regex_filter |
Optional[str] |
Specify a regex filter to filter the files to import. |
use_blob_urls |
Optional[bool] |
Specify whether your data is raw image or video data, or JSON tasks. |
presign |
Optional[bool] |
Specify whether or not to create presigned URLs. |
presign_ttl |
Optional[int] |
Specify how long to keep presigned URLs active. |
description |
Optional[str] |
Specify a description for the dataset. |
azure_account_name |
Optional[str] |
Specify the Azure account name to use for the storage. |
azure_account_key |
Optional[str] |
Specify the Azure account key to use for the storage. |
google_application_credentials |
Optional[str] |
Specify the Google application credentials to use for the storage. |
aws_access_key_id |
Optional[str] |
Specify the AWS access key ID to use for the storage. |
aws_secret_access_key |
Optional[str] |
Specify the AWS secret access key to use for the storage. |
aws_session_token |
Optional[str] |
Specify the AWS session token to use for the storage. |
s3_region_name |
Optional[str] |
Specify the S3 region name to use for the storage. |
s3_endpoint |
Optional[str] |
Specify the S3 endpoint to use for the storage. |
Source code in zenml/integrations/label_studio/steps/label_studio_standard_steps.py
class LabelStudioDatasetSyncConfig(BaseStepConfig):
    """Step config when syncing data to Label Studio.

    Attributes:
        storage_type: The type of storage to sync to.
        label_config_type: The type of label config to use.
        prefix: Specify the prefix within the cloud store to import your data
            from.
        regex_filter: Specify a regex filter to filter the files to import.
        use_blob_urls: Specify whether your data is raw image or video data, or
            JSON tasks.
        presign: Specify whether or not to create presigned URLs.
        presign_ttl: Specify how long to keep presigned URLs active.
        description: Specify a description for the dataset.
        azure_account_name: Specify the Azure account name to use for the
            storage.
        azure_account_key: Specify the Azure account key to use for the
            storage.
        google_application_credentials: Specify the Google application
            credentials to use for the storage.
        aws_access_key_id: Specify the AWS access key ID to use for the
            storage.
        aws_secret_access_key: Specify the AWS secret access key to use for the
            storage.
        aws_session_token: Specify the AWS session token to use for the
            storage.
        s3_region_name: Specify the S3 region name to use for the storage.
        s3_endpoint: Specify the S3 endpoint to use for the storage.
    """

    storage_type: str
    label_config_type: str
    prefix: Optional[str] = None
    regex_filter: Optional[str] = ".*"
    use_blob_urls: Optional[bool] = True
    presign: Optional[bool] = True
    presign_ttl: Optional[int] = 1
    description: Optional[str] = ""

    # Credentials specific to the main cloud providers. The `None` defaults
    # are spelled out so that the fields stay optional regardless of how the
    # validation library treats an `Optional` annotation without a default.
    azure_account_name: Optional[str] = None
    azure_account_key: Optional[str] = None
    google_application_credentials: Optional[str] = None
    aws_access_key_id: Optional[str] = None
    aws_secret_access_key: Optional[str] = None
    aws_session_token: Optional[str] = None
    s3_region_name: Optional[str] = None
    s3_endpoint: Optional[str] = None
get_labeled_data (BaseStep)
Gets labeled data from the dataset.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
dataset_name |
Name of the dataset. |
required | |
context |
The StepContext. |
required |
Returns:
Type | Description |
---|---|
List of labeled data. |
Exceptions:
Type | Description |
---|---|
TypeError |
If you are trying to use it with an annotator that is not Label Studio. |
StackComponentInterfaceError |
If no active annotator could be found. |
entrypoint(dataset_name, context)
staticmethod
Gets labeled data from the dataset.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
dataset_name |
str |
Name of the dataset. |
required |
context |
StepContext |
The StepContext. |
required |
Returns:
Type | Description |
---|---|
List |
List of labeled data. |
Exceptions:
Type | Description |
---|---|
TypeError |
If you are trying to use it with an annotator that is not Label Studio. |
StackComponentInterfaceError |
If no active annotator could be found. |
Source code in zenml/integrations/label_studio/steps/label_studio_standard_steps.py
@step(enable_cache=False)
def get_labeled_data(dataset_name: str, context: StepContext) -> List:  # type: ignore[type-arg]
    """Fetches the labeled tasks of a Label Studio dataset.

    Args:
        dataset_name: Name of the dataset.
        context: The StepContext.

    Returns:
        List of labeled data.

    Raises:
        TypeError: If you are trying to use it with an annotator that is not
            Label Studio.
        StackComponentInterfaceError: If no active annotator could be found
            or the annotator cannot be reached.
    """
    # TODO [MEDIUM]: have this check for new data *since the last time this step ran*
    active_annotator = context.stack.annotator  # type: ignore[union-attr]
    if not active_annotator:
        raise StackComponentInterfaceError("No active annotator.")

    # Imported lazily, only once an annotator is known to be present.
    from zenml.integrations.label_studio.annotators.label_studio_annotator import (
        LabelStudioAnnotator,
    )

    if not isinstance(active_annotator, LabelStudioAnnotator):
        raise TypeError(
            "This step can only be used with the Label Studio annotator."
        )

    if not active_annotator._connection_available():
        raise StackComponentInterfaceError(
            "Unable to connect to annotator stack component."
        )

    labeled_dataset = active_annotator.get_dataset(dataset_name=dataset_name)
    return labeled_dataset.get_labeled_tasks()  # type: ignore[no-any-return]
get_or_create_dataset (BaseStep)
Gets preexisting dataset or creates a new one.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
config |
Step config. |
required | |
context |
Step context. |
required |
Returns:
Type | Description |
---|---|
The dataset name. |
Exceptions:
Type | Description |
---|---|
TypeError |
If you are trying to use it with an annotator that is not Label Studio. |
StackComponentInterfaceError |
If no active annotator could be found. |
CONFIG_CLASS (BaseStepConfig)
pydantic-model
Step config when registering a dataset with Label Studio.
Attributes:
Name | Type | Description |
---|---|---|
label_config |
str |
The label config to use for the annotation interface. |
dataset_name |
str |
Name of the dataset to register. |
Source code in zenml/integrations/label_studio/steps/label_studio_standard_steps.py
class LabelStudioDatasetRegistrationConfig(BaseStepConfig):
    """Configuration for the step that registers a dataset with Label Studio.

    Attributes:
        label_config: The label config to use for the annotation interface.
        dataset_name: Name of the dataset to register.
    """

    label_config: str
    dataset_name: str
entrypoint(config, context)
staticmethod
Gets preexisting dataset or creates a new one.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
config |
LabelStudioDatasetRegistrationConfig |
Step config. |
required |
context |
StepContext |
Step context. |
required |
Returns:
Type | Description |
---|---|
str |
The dataset name. |
Exceptions:
Type | Description |
---|---|
TypeError |
If you are trying to use it with an annotator that is not Label Studio. |
StackComponentInterfaceError |
If no active annotator could be found. |
Source code in zenml/integrations/label_studio/steps/label_studio_standard_steps.py
@step(enable_cache=False)
def get_or_create_dataset(
    config: LabelStudioDatasetRegistrationConfig,
    context: StepContext,
) -> str:
    """Gets preexisting dataset or creates a new one.

    Existing datasets are matched on their `title` parameter; if none matches
    `config.dataset_name`, a new dataset is registered from `config`.

    Args:
        config: Step config.
        context: Step context.

    Returns:
        The dataset name.

    Raises:
        TypeError: If you are trying to use it with an annotator that is not
            Label Studio.
        StackComponentInterfaceError: If no active annotator could be found
            or the annotator cannot be reached.
    """
    annotator = context.stack.annotator  # type: ignore[union-attr]
    # Test for a missing annotator before the type check: `isinstance` is
    # False for `None`, which would otherwise surface as a `TypeError`.
    if not annotator:
        raise StackComponentInterfaceError("No active annotator.")

    from zenml.integrations.label_studio.annotators.label_studio_annotator import (
        LabelStudioAnnotator,
    )

    if not isinstance(annotator, LabelStudioAnnotator):
        raise TypeError(
            "This step can only be used with the Label Studio annotator."
        )

    if not annotator._connection_available():
        raise StackComponentInterfaceError(
            "Unable to connect to annotator stack component."
        )

    # Reuse an existing dataset with the requested title if there is one.
    for dataset in annotator.get_datasets():
        title = dataset.get_params()["title"]
        if title == config.dataset_name:
            return cast(str, title)

    dataset = annotator.register_dataset_for_annotation(config)
    return cast(str, dataset.get_params()["title"])
sync_new_data_to_label_studio (BaseStep)
Syncs new data to Label Studio.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
uri |
The URI of the data to sync. |
required | |
dataset_name |
The name of the dataset to sync to. |
required | |
predictions |
The predictions to sync. |
required | |
config |
The config for the sync. |
required | |
context |
The StepContext. |
required |
Exceptions:
Type | Description |
---|---|
TypeError |
If you are trying to use it with an annotator that is not Label Studio. |
ValueError |
If you are trying to sync from outside ZenML. |
StackComponentInterfaceError |
If no active annotator could be found. |
CONFIG_CLASS (BaseStepConfig)
pydantic-model
Step config when syncing data to Label Studio.
Attributes:
Name | Type | Description |
---|---|---|
storage_type |
str |
The type of storage to sync to. |
label_config_type |
str |
The type of label config to use. |
prefix |
Optional[str] |
Specify the prefix within the cloud store to import your data from. |
regex_filter |
Optional[str] |
Specify a regex filter to filter the files to import. |
use_blob_urls |
Optional[bool] |
Specify whether your data is raw image or video data, or JSON tasks. |
presign |
Optional[bool] |
Specify whether or not to create presigned URLs. |
presign_ttl |
Optional[int] |
Specify how long to keep presigned URLs active. |
description |
Optional[str] |
Specify a description for the dataset. |
azure_account_name |
Optional[str] |
Specify the Azure account name to use for the storage. |
azure_account_key |
Optional[str] |
Specify the Azure account key to use for the storage. |
google_application_credentials |
Optional[str] |
Specify the Google application credentials to use for the storage. |
aws_access_key_id |
Optional[str] |
Specify the AWS access key ID to use for the storage. |
aws_secret_access_key |
Optional[str] |
Specify the AWS secret access key to use for the storage. |
aws_session_token |
Optional[str] |
Specify the AWS session token to use for the storage. |
s3_region_name |
Optional[str] |
Specify the S3 region name to use for the storage. |
s3_endpoint |
Optional[str] |
Specify the S3 endpoint to use for the storage. |
Source code in zenml/integrations/label_studio/steps/label_studio_standard_steps.py
class LabelStudioDatasetSyncConfig(BaseStepConfig):
    """Step config when syncing data to Label Studio.

    Attributes:
        storage_type: The type of storage to sync to.
        label_config_type: The type of label config to use.
        prefix: Specify the prefix within the cloud store to import your data
            from.
        regex_filter: Specify a regex filter to filter the files to import.
        use_blob_urls: Specify whether your data is raw image or video data, or
            JSON tasks.
        presign: Specify whether or not to create presigned URLs.
        presign_ttl: Specify how long to keep presigned URLs active.
        description: Specify a description for the dataset.
        azure_account_name: Specify the Azure account name to use for the
            storage.
        azure_account_key: Specify the Azure account key to use for the
            storage.
        google_application_credentials: Specify the Google application
            credentials to use for the storage.
        aws_access_key_id: Specify the AWS access key ID to use for the
            storage.
        aws_secret_access_key: Specify the AWS secret access key to use for the
            storage.
        aws_session_token: Specify the AWS session token to use for the
            storage.
        s3_region_name: Specify the S3 region name to use for the storage.
        s3_endpoint: Specify the S3 endpoint to use for the storage.
    """

    storage_type: str
    label_config_type: str
    prefix: Optional[str] = None
    regex_filter: Optional[str] = ".*"
    use_blob_urls: Optional[bool] = True
    presign: Optional[bool] = True
    presign_ttl: Optional[int] = 1
    description: Optional[str] = ""

    # Credentials specific to the main cloud providers. The `None` defaults
    # are spelled out so that the fields stay optional regardless of how the
    # validation library treats an `Optional` annotation without a default.
    azure_account_name: Optional[str] = None
    azure_account_key: Optional[str] = None
    google_application_credentials: Optional[str] = None
    aws_access_key_id: Optional[str] = None
    aws_secret_access_key: Optional[str] = None
    aws_session_token: Optional[str] = None
    s3_region_name: Optional[str] = None
    s3_endpoint: Optional[str] = None
entrypoint(uri, dataset_name, predictions, config, context)
staticmethod
Syncs new data to Label Studio.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
uri |
str |
The URI of the data to sync. |
required |
dataset_name |
str |
The name of the dataset to sync to. |
required |
predictions |
List[Dict[str, Any]] |
The predictions to sync. |
required |
config |
LabelStudioDatasetSyncConfig |
The config for the sync. |
required |
context |
StepContext |
The StepContext. |
required |
Exceptions:
Type | Description |
---|---|
TypeError |
If you are trying to use it with an annotator that is not Label Studio. |
ValueError |
If you are trying to sync from outside ZenML. |
StackComponentInterfaceError |
If no active annotator could be found. |
Source code in zenml/integrations/label_studio/steps/label_studio_standard_steps.py
@step(enable_cache=False)
def sync_new_data_to_label_studio(
    uri: str,
    dataset_name: str,
    predictions: List[Dict[str, Any]],
    config: LabelStudioDatasetSyncConfig,
    context: StepContext,
) -> None:
    """Syncs new data to Label Studio.

    The step fills in `config.prefix` and the storage credentials on `config`
    (the passed-in object is modified), connects the external storage to the
    dataset and, if any predictions are given, uploads them for the matching
    tasks.

    Args:
        uri: The URI of the data to sync.
        dataset_name: The name of the dataset to sync to.
        predictions: The predictions to sync.
        config: The config for the sync.
        context: The StepContext.

    Raises:
        TypeError: If you are trying to use it with an annotator that is not
            Label Studio.
        ValueError: If you are trying to sync from outside ZenML, i.e. `uri`
            does not start with the artifact store path.
        StackComponentInterfaceError: If no active annotator, artifact store
            or secrets manager could be found, or the annotator cannot be
            reached.
    """
    annotator = context.stack.annotator  # type: ignore[union-attr]
    artifact_store = context.stack.artifact_store  # type: ignore[union-attr]
    secrets_manager = context.stack.secrets_manager  # type: ignore[union-attr]
    if not annotator or not artifact_store or not secrets_manager:
        raise StackComponentInterfaceError(
            "An active annotator, artifact store and secrets manager are required to run this step."
        )

    from zenml.integrations.label_studio.annotators.label_studio_annotator import (
        LabelStudioAnnotator,
    )

    if not isinstance(annotator, LabelStudioAnnotator):
        raise TypeError(
            "This step can only be used with the Label Studio annotator."
        )

    # Make sure the annotator is reachable before the first query against it.
    if not annotator._connection_available():
        raise StackComponentInterfaceError(
            "Unable to connect to annotator stack component."
        )

    dataset = annotator.get_dataset(dataset_name=dataset_name)
    if not uri.startswith(artifact_store.path):
        raise ValueError(
            "ZenML only currently supports syncing data passed from other ZenML steps and via the Artifact Store."
        )

    # The storage prefix is the URI path without its leading slash(es); the
    # network location is what gets handed to the storage sync below.
    parsed_uri = urlparse(uri)
    config.prefix = parsed_uri.path.lstrip("/")
    base_uri = parsed_uri.netloc

    # gets the secret used for authentication
    authentication_secret_name = artifact_store.authentication_secret  # type: ignore[union-attr]

    # Each secret is fetched once and then read for all the values it holds.
    if config.storage_type == "azure":
        azure_secret = secrets_manager.get_secret(  # type: ignore[union-attr]
            authentication_secret_name
        )
        config.azure_account_name = azure_secret.account_name
        config.azure_account_key = azure_secret.account_key
    elif config.storage_type == "gcs":
        gcs_secret = secrets_manager.get_secret(  # type: ignore[union-attr]
            authentication_secret_name
        )
        config.google_application_credentials = gcs_secret.token
    elif config.storage_type == "s3":
        aws_secret = secrets_manager.get_secret(  # type: ignore[union-attr]
            LABEL_STUDIO_AWS_SECRET_NAME
        )
        config.aws_access_key_id = aws_secret.aws_access_key_id
        config.aws_secret_access_key = aws_secret.aws_secret_access_key
        config.aws_session_token = aws_secret.aws_session_token

    # TODO: get existing (CHECK!) or create the sync connection
    annotator.connect_and_sync_external_storage(
        uri=base_uri,
        config=config,
        dataset=dataset,
    )

    if predictions:
        filename_reference = TASK_TO_FILENAME_REFERENCE_MAPPING[
            config.label_config_type
        ]
        preds_with_task_ids = convert_pred_filenames_to_task_ids(
            predictions,
            dataset.tasks,
            filename_reference,
            config.storage_type,
        )
        # TODO: filter out any predictions that exist + have already been
        # made (maybe?). Only pass in preds for tasks without pre-annotations.
        dataset.create_predictions(preds_with_task_ids)
lightgbm
special
Initialization of the LightGBM integration.
LightGBMIntegration (Integration)
Definition of lightgbm integration for ZenML.
Source code in zenml/integrations/lightgbm/__init__.py
class LightGBMIntegration(Integration):
    """ZenML integration definition for LightGBM."""

    NAME = LIGHTGBM
    REQUIREMENTS = ["lightgbm>=1.0.0"]

    @classmethod
    def activate(cls) -> None:
        """Activates the integration by importing its materializers module."""
        from zenml.integrations.lightgbm import materializers  # noqa
activate()
classmethod
Activates the integration.
Source code in zenml/integrations/lightgbm/__init__.py
@classmethod
def activate(cls) -> None:
    """Activates the integration by importing its materializers module."""
    from zenml.integrations.lightgbm import materializers  # noqa
materializers
special
Initialization of the LightGBM materializers.
lightgbm_booster_materializer
Implementation of the LightGBM booster materializer.
LightGBMBoosterMaterializer (BaseMaterializer)
Materializer to read data to and from lightgbm.Booster.
Source code in zenml/integrations/lightgbm/materializers/lightgbm_booster_materializer.py
class LightGBMBoosterMaterializer(BaseMaterializer):
    """Materializer to read data to and from lightgbm.Booster."""

    ASSOCIATED_TYPES = (lgb.Booster,)
    ASSOCIATED_ARTIFACT_TYPES = (ModelArtifact,)

    def handle_input(self, data_type: Type[Any]) -> lgb.Booster:
        """Reads a lightgbm Booster model from its serialized model file.

        Args:
            data_type: A lightgbm Booster type.

        Returns:
            A lightgbm Booster object.
        """
        super().handle_input(data_type)
        filepath = os.path.join(self.artifact.uri, DEFAULT_FILENAME)

        # The model is copied to a local temporary directory before loading.
        temp_dir = tempfile.mkdtemp(prefix="zenml-temp-")
        try:
            temp_file = os.path.join(temp_dir, DEFAULT_FILENAME)
            fileio.copy(filepath, temp_file)
            booster = lgb.Booster(model_file=temp_file)
        finally:
            # Remove the temporary directory even if copying or loading fails.
            fileio.rmtree(temp_dir)
        return booster

    def handle_return(self, booster: lgb.Booster) -> None:
        """Serializes a lightgbm Booster model into the artifact store.

        Args:
            booster: A lightgbm Booster model.
        """
        super().handle_return(booster)
        filepath = os.path.join(self.artifact.uri, DEFAULT_FILENAME)

        # Reserve a temporary file name; the handle is closed right away so
        # that the model can be written to the path by name.
        with tempfile.NamedTemporaryFile(
            mode="w", suffix=".txt", delete=False
        ) as f:
            temp_path = f.name
        try:
            booster.save_model(temp_path)
            # Copy it into artifact store
            fileio.copy(temp_path, filepath)
        finally:
            # `delete=False` leaves removal to us, also when saving fails.
            fileio.remove(temp_path)
handle_input(self, data_type)
Reads a lightgbm Booster model from its serialized model file.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
data_type |
Type[Any] |
A lightgbm Booster type. |
required |
Returns:
Type | Description |
---|---|
Booster |
A lightgbm Booster object. |
Source code in zenml/integrations/lightgbm/materializers/lightgbm_booster_materializer.py
def handle_input(self, data_type: Type[Any]) -> lgb.Booster:
    """Reads a lightgbm Booster model from its serialized model file.

    Args:
        data_type: A lightgbm Booster type.

    Returns:
        A lightgbm Booster object.
    """
    super().handle_input(data_type)
    filepath = os.path.join(self.artifact.uri, DEFAULT_FILENAME)

    # The model is copied to a local temporary directory before loading.
    temp_dir = tempfile.mkdtemp(prefix="zenml-temp-")
    try:
        temp_file = os.path.join(temp_dir, DEFAULT_FILENAME)
        fileio.copy(filepath, temp_file)
        booster = lgb.Booster(model_file=temp_file)
    finally:
        # Remove the temporary directory even if copying or loading fails.
        fileio.rmtree(temp_dir)
    return booster
handle_return(self, booster)
Serializes a lightgbm Booster model into the artifact store.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
booster |
Booster |
A lightgbm Booster model. |
required |
Source code in zenml/integrations/lightgbm/materializers/lightgbm_booster_materializer.py
def handle_return(self, booster: lgb.Booster) -> None:
    """Serializes a lightgbm Booster model into the artifact store.

    Args:
        booster: A lightgbm Booster model.
    """
    super().handle_return(booster)
    filepath = os.path.join(self.artifact.uri, DEFAULT_FILENAME)

    # Reserve a temporary file name; the handle is closed right away so
    # that the model can be written to the path by name.
    with tempfile.NamedTemporaryFile(
        mode="w", suffix=".txt", delete=False
    ) as f:
        temp_path = f.name
    try:
        booster.save_model(temp_path)
        # Copy it into artifact store
        fileio.copy(temp_path, filepath)
    finally:
        # `delete=False` leaves removal to us, also when saving fails.
        fileio.remove(temp_path)
lightgbm_dataset_materializer
Implementation of the LightGBM materializer.
LightGBMDatasetMaterializer (BaseMaterializer)
Materializer to read data to and from lightgbm.Dataset.
Source code in zenml/integrations/lightgbm/materializers/lightgbm_dataset_materializer.py
class LightGBMDatasetMaterializer(BaseMaterializer):
    """Materializer to read data to and from lightgbm.Dataset."""

    ASSOCIATED_TYPES = (lgb.Dataset,)
    ASSOCIATED_ARTIFACT_TYPES = (DataArtifact,)

    def handle_input(self, data_type: Type[Any]) -> lgb.Dataset:
        """Reads a lightgbm.Dataset binary file and loads it.

        Args:
            data_type: A lightgbm.Dataset type.

        Returns:
            A lightgbm.Dataset object.
        """
        super().handle_input(data_type)
        filepath = os.path.join(self.artifact.uri, DEFAULT_FILENAME)

        # Copy from artifact store to a file in a temporary folder
        temp_dir = tempfile.mkdtemp(prefix="zenml-temp-")
        temp_file = os.path.join(temp_dir, DEFAULT_FILENAME)
        fileio.copy(filepath, temp_file)
        matrix = lgb.Dataset(temp_file, free_raw_data=False)

        # No clean up this time because matrix is lazy loaded
        return matrix

    def handle_return(self, matrix: lgb.Dataset) -> None:
        """Creates a binary serialization for a lightgbm.Dataset object.

        Args:
            matrix: A lightgbm.Dataset object.
        """
        super().handle_return(matrix)
        filepath = os.path.join(self.artifact.uri, DEFAULT_FILENAME)

        # The dataset is written to a temporary folder and then copied over.
        temp_dir = tempfile.mkdtemp(prefix="zenml-temp-")
        try:
            temp_file = os.path.join(temp_dir, DEFAULT_FILENAME)
            matrix.save_binary(temp_file)
            # Copy it into artifact store
            fileio.copy(temp_file, filepath)
        finally:
            # Remove the temporary folder even if saving or copying fails.
            fileio.rmtree(temp_dir)
handle_input(self, data_type)
Reads a lightgbm.Dataset binary file and loads it.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
data_type |
Type[Any] |
A lightgbm.Dataset type. |
required |
Returns:
Type | Description |
---|---|
Dataset |
A lightgbm.Dataset object. |
Source code in zenml/integrations/lightgbm/materializers/lightgbm_dataset_materializer.py
def handle_input(self, data_type: Type[Any]) -> lgb.Dataset:
    """Loads a lightgbm.Dataset from its binary file in the artifact store.

    Args:
        data_type: A lightgbm.Dataset type.

    Returns:
        A lightgbm.Dataset object.
    """
    super().handle_input(data_type)
    source_path = os.path.join(self.artifact.uri, DEFAULT_FILENAME)

    # Copy the artifact into a fresh temporary folder.
    scratch_dir = tempfile.mkdtemp(prefix="zenml-temp-")
    local_copy = os.path.join(scratch_dir, DEFAULT_FILENAME)
    fileio.copy(source_path, local_copy)

    # No clean up this time because the dataset is lazy loaded
    return lgb.Dataset(local_copy, free_raw_data=False)
handle_return(self, matrix)
Creates a binary serialization for a lightgbm.Dataset object.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
matrix |
Dataset |
A lightgbm.Dataset object. |
required |
Source code in zenml/integrations/lightgbm/materializers/lightgbm_dataset_materializer.py
def handle_return(self, matrix: lgb.Dataset) -> None:
    """Creates a binary serialization for a lightgbm.Dataset object.

    Args:
        matrix: A lightgbm.Dataset object.
    """
    super().handle_return(matrix)
    filepath = os.path.join(self.artifact.uri, DEFAULT_FILENAME)

    # The dataset is written to a temporary folder and then copied over.
    temp_dir = tempfile.mkdtemp(prefix="zenml-temp-")
    try:
        temp_file = os.path.join(temp_dir, DEFAULT_FILENAME)
        matrix.save_binary(temp_file)
        # Copy it into artifact store
        fileio.copy(temp_file, filepath)
    finally:
        # Remove the temporary folder even if saving or copying fails.
        fileio.rmtree(temp_dir)
mlflow
special
Initialization for the ZenML MLflow integration.
The MLflow integration currently enables you to use MLflow tracking as a convenient way to visualize your experiment runs within the MLflow UI.
MlflowIntegration (Integration)
Definition of MLflow integration for ZenML.
Source code in zenml/integrations/mlflow/__init__.py
class MlflowIntegration(Integration):
    """ZenML integration definition for MLflow."""

    NAME = MLFLOW
    REQUIREMENTS = [
        "mlflow>=1.2.0,<1.26.0",
        "mlserver>=0.5.3",
        "mlserver-mlflow>=0.5.3",
    ]

    @classmethod
    def activate(cls) -> None:
        """Activates the integration by importing its services module."""
        from zenml.integrations.mlflow import services  # noqa

    @classmethod
    def flavors(cls) -> List[FlavorWrapper]:
        """Declare the stack component flavors for the MLflow integration.

        Returns:
            The model deployer flavor followed by the experiment tracker
            flavor of this integration.
        """
        model_deployer_flavor = FlavorWrapper(
            name=MLFLOW_MODEL_DEPLOYER_FLAVOR,
            source="zenml.integrations.mlflow.model_deployers.MLFlowModelDeployer",
            type=StackComponentType.MODEL_DEPLOYER,
            integration=cls.NAME,
        )
        experiment_tracker_flavor = FlavorWrapper(
            name=MLFLOW_MODEL_EXPERIMENT_TRACKER_FLAVOR,
            source="zenml.integrations.mlflow.experiment_trackers.MLFlowExperimentTracker",
            type=StackComponentType.EXPERIMENT_TRACKER,
            integration=cls.NAME,
        )
        return [model_deployer_flavor, experiment_tracker_flavor]
activate()
classmethod
Activate the MLflow integration.
Source code in zenml/integrations/mlflow/__init__.py
@classmethod
def activate(cls) -> None:
    """Activates the integration by importing its services module."""
    from zenml.integrations.mlflow import services  # noqa
flavors()
classmethod
Declare the stack component flavors for the MLflow integration.
Returns:
Type | Description |
---|---|
List[zenml.zen_stores.models.flavor_wrapper.FlavorWrapper] |
List of stack component flavors for this integration. |
Source code in zenml/integrations/mlflow/__init__.py
@classmethod
def flavors(cls) -> List[FlavorWrapper]:
    """Declare the stack component flavors for the MLflow integration.

    Returns:
        The model deployer flavor followed by the experiment tracker
        flavor of this integration.
    """
    model_deployer_flavor = FlavorWrapper(
        name=MLFLOW_MODEL_DEPLOYER_FLAVOR,
        source="zenml.integrations.mlflow.model_deployers.MLFlowModelDeployer",
        type=StackComponentType.MODEL_DEPLOYER,
        integration=cls.NAME,
    )
    experiment_tracker_flavor = FlavorWrapper(
        name=MLFLOW_MODEL_EXPERIMENT_TRACKER_FLAVOR,
        source="zenml.integrations.mlflow.experiment_trackers.MLFlowExperimentTracker",
        type=StackComponentType.EXPERIMENT_TRACKER,
        integration=cls.NAME,
    )
    return [model_deployer_flavor, experiment_tracker_flavor]
experiment_trackers
special
Initialization of the MLflow experiment tracker.
mlflow_experiment_tracker
Implementation of the MLflow experiment tracker for ZenML.
MLFlowExperimentTracker (BaseExperimentTracker)
pydantic-model
Stores Mlflow configuration options.
ZenML should take care of configuring MLflow for you, but should you still need access to the configuration inside your step you can do it using a step context:
from zenml.steps import StepContext
@enable_mlflow
@step
def my_step(context: StepContext, ...):
context.stack.experiment_tracker # get the tracking_uri etc. from here
Attributes:
Name | Type | Description |
---|---|---|
tracking_uri |
Optional[str] |
The uri of the mlflow tracking server. If no uri is set,
your stack must contain a `LocalArtifactStore` and ZenML will
point MLflow to a subdirectory of your artifact store instead. |
tracking_username |
Optional[str] |
Username for authenticating with the MLflow
tracking server. When a remote tracking uri is specified,
either `tracking_token` or `tracking_username` and
`tracking_password` must be specified. |
tracking_password |
Optional[str] |
Password for authenticating with the MLflow
tracking server. When a remote tracking uri is specified,
either `tracking_token` or `tracking_username` and
`tracking_password` must be specified. |
tracking_token |
Optional[str] |
Token for authenticating with the MLflow
tracking server. When a remote tracking uri is specified,
either `tracking_token` or `tracking_username` and
`tracking_password` must be specified. |
tracking_insecure_tls |
bool |
Skips verification of TLS connection to the
MLflow tracking server if set to `True`. |
Source code in zenml/integrations/mlflow/experiment_trackers/mlflow_experiment_tracker.py
class MLFlowExperimentTracker(BaseExperimentTracker):
"""Stores Mlflow configuration options.
ZenML should take care of configuring MLflow for you, but should you still
need access to the configuration inside your step you can do it using a
step context:
```python
from zenml.steps import StepContext
@enable_mlflow
@step
def my_step(context: StepContext, ...)
context.stack.experiment_tracker # get the tracking_uri etc. from here
```
Attributes:
tracking_uri: The uri of the mlflow tracking server. If no uri is set,
your stack must contain a `LocalArtifactStore` and ZenML will
point MLflow to a subdirectory of your artifact store instead.
tracking_username: Username for authenticating with the MLflow
tracking server. When a remote tracking uri is specified,
either `tracking_token` or `tracking_username` and
`tracking_password` must be specified.
tracking_password: Password for authenticating with the MLflow
tracking server. When a remote tracking uri is specified,
either `tracking_token` or `tracking_username` and
`tracking_password` must be specified.
tracking_token: Token for authenticating with the MLflow
tracking server. When a remote tracking uri is specified,
either `tracking_token` or `tracking_username` and
`tracking_password` must be specified.
tracking_insecure_tls: Skips verification of TLS connection to the
MLflow tracking server if set to `True`.
"""
tracking_uri: Optional[str] = None
tracking_username: Optional[str] = None
tracking_password: Optional[str] = None
tracking_token: Optional[str] = None
tracking_insecure_tls: bool = False
# Class Configuration
FLAVOR: ClassVar[str] = MLFLOW_MODEL_EXPERIMENT_TRACKER_FLAVOR
@validator("tracking_uri")
def _ensure_valid_tracking_uri(
cls, tracking_uri: Optional[str] = None
) -> Optional[str]:
"""Ensures that the tracking uri is a valid mlflow tracking uri.
Args:
tracking_uri: The tracking uri to validate.
Returns:
The tracking uri if it is valid.
Raises:
ValueError: If the tracking uri is not valid.
"""
if tracking_uri:
valid_schemes = DATABASE_ENGINES + ["http", "https", "file"]
if not any(
tracking_uri.startswith(scheme) for scheme in valid_schemes
):
raise ValueError(
f"MLflow tracking uri does not start with one of the valid "
f"schemes {valid_schemes}. See "
f"https://www.mlflow.org/docs/latest/tracking.html#where-runs-are-recorded "
f"for more information."
)
return tracking_uri
@root_validator(skip_on_failure=True)
def _ensure_authentication_if_necessary(
cls, values: Dict[str, Any]
) -> Dict[str, Any]:
"""Ensures that credentials or a token for authentication exist.
We make this check when running MLflow tracking with a remote backend.
Args:
values: The values to validate.
Returns:
The validated values.
Raises:
ValueError: If neither credentials nor a token are provided.
"""
tracking_uri = values.get("tracking_uri")
if tracking_uri and cls.is_remote_tracking_uri(tracking_uri):
# we need either username + password or a token to authenticate to
# the remote backend
basic_auth = values.get("tracking_username") and values.get(
"tracking_password"
)
token_auth = values.get("tracking_token")
if not (basic_auth or token_auth):
raise ValueError(
f"MLflow experiment tracking with a remote backend "
f"{tracking_uri} is only possible when specifying either "
f"username and password or an authentication token in your "
f"stack component. To update your component, run the "
f"following command: `zenml experiment-tracker update "
f"{values['name']} --tracking_username=MY_USERNAME "
f"--tracking_password=MY_PASSWORD "
f"--tracking_token=MY_TOKEN` and specify either your "
f"username and password or token."
)
return values
@staticmethod
def is_remote_tracking_uri(tracking_uri: str) -> bool:
"""Checks whether the given tracking uri is remote or not.
Args:
tracking_uri: The tracking uri to check.
Returns:
`True` if the tracking uri is remote, `False` otherwise.
"""
return any(
tracking_uri.startswith(prefix)
for prefix in ["http://", "https://"]
)
@staticmethod
def _local_mlflow_backend() -> str:
"""Gets the local MLflow backend inside the ZenML artifact repository directory.
Returns:
The MLflow tracking URI for the local MLflow backend.
"""
repo = Repository(skip_repository_check=True) # type: ignore[call-arg]
artifact_store = repo.active_stack.artifact_store
local_mlflow_backend_uri = os.path.join(artifact_store.path, "mlruns")
if not os.path.exists(local_mlflow_backend_uri):
os.makedirs(local_mlflow_backend_uri)
return "file:" + local_mlflow_backend_uri
def get_tracking_uri(self) -> str:
"""Returns the configured tracking URI or a local fallback.
Returns:
The tracking URI.
"""
return self.tracking_uri or self._local_mlflow_backend()
def configure_mlflow(self) -> None:
    """Points MLflow at the tracking URI and exports the credentials.

    Username, password and token are only written to the process
    environment when they are set on the component. The insecure-TLS flag
    is always written, as the string `"true"` or `"false"`.
    """
    mlflow.set_tracking_uri(self.get_tracking_uri())

    optional_credentials = (
        (MLFLOW_TRACKING_USERNAME, self.tracking_username),
        (MLFLOW_TRACKING_PASSWORD, self.tracking_password),
        (MLFLOW_TRACKING_TOKEN, self.tracking_token),
    )
    for env_name, secret in optional_credentials:
        if secret:
            os.environ[env_name] = secret

    insecure_tls = "true" if self.tracking_insecure_tls else "false"
    os.environ[MLFLOW_TRACKING_INSECURE_TLS] = insecure_tls
def prepare_step_run(self) -> None:
    """Configures MLflow (tracking uri and credentials) before a step runs."""
    # all of the setup lives in `configure_mlflow`
    self.configure_mlflow()
def cleanup_step_run(self) -> None:
    """Clears the MLflow tracking uri once a step has finished."""
    # an empty string resets whatever uri `prepare_step_run` configured
    mlflow.set_tracking_uri("")
@property
def local_path(self) -> Optional[str]:
    """Local directory that holds the MLflow artifacts, if there is one.

    Returns:
        None if configured with a remote tracking URI, otherwise the
        tracking URI with its leading `file:` prefix removed.
    """
    uri = self.get_tracking_uri()
    if not self.is_remote_tracking_uri(uri):
        file_prefix = "file:"
        assert uri.startswith(file_prefix)
        return uri[len(file_prefix):]
    return None
@property
def validator(self) -> Optional["StackValidator"]:
    """Requires a `LocalArtifactStore` when no tracking uri was specified.

    Returns:
        `None` if a tracking uri is configured, otherwise a `StackValidator`
        whose custom validation function checks that the stack's artifact
        store is a `LocalArtifactStore`.
    """
    if self.tracking_uri:
        # user specified a tracking uri, nothing to validate
        return None

    def _check_local_artifact_store(stack):  # type: ignore[no-untyped-def]
        # returns a (passed, message) pair
        return (
            isinstance(stack.artifact_store, LocalArtifactStore),
            "MLflow experiment tracker without a specified tracking "
            "uri only works with a local artifact store.",
        )

    # without a tracking uri we fall back to a directory inside the zenml
    # artifact store. this only works in case of a local artifact store, so
    # stacks with other artifact stores are rejected for now
    return StackValidator(
        custom_validation_function=_check_local_artifact_store
    )
@property
def active_experiment(self) -> Optional[Experiment]:
    """Returns the currently active MLflow experiment.

    The experiment is named after the pipeline of the current step
    environment and is set as the active MLflow experiment before it is
    looked up.

    Returns:
        The active experiment or `None` if no experiment is active.
    """
    step_env = Environment().step_environment
    if step_env:
        experiment_name = step_env.pipeline_name
        mlflow.set_experiment(experiment_name=experiment_name)
        return mlflow.get_experiment_by_name(experiment_name)

    # we're not inside a step
    return None
def _find_active_run(
    self,
) -> Tuple[Optional[mlflow.ActiveRun], Optional[str], Optional[str]]:
    """Find the currently active MLflow run.

    The run is looked up by searching the active experiment for a run whose
    `mlflow.runName` tag equals the pipeline run id of the current step
    environment.

    Returns:
        A tuple of the active MLflow run, the experiment id and the run id.
        All three are `None` outside of a step or when there is no active
        experiment. The active run is `None` when MLflow's current run is
        not the run found by the search.
    """
    step_env = Environment().step_environment
    # `active_experiment` calls into MLflow on every access, so read it
    # once. This also guarantees the id below comes from the same object
    # that was just checked.
    experiment = self.active_experiment
    if not experiment or not step_env:
        return None, None, None

    experiment_id = experiment.experiment_id

    # TODO [ENG-458]: find a solution to avoid race-conditions while
    #  creating the same MLflow run from parallel steps
    runs = mlflow.search_runs(
        experiment_ids=[experiment_id],
        filter_string=f'tags.mlflow.runName = "{step_env.pipeline_run_id}"',
        output_format="list",
    )

    run_id = runs[0].info.run_id if runs else None

    # only report MLflow's current run if it is the run we searched for
    current_active_run = mlflow.active_run()
    if not (
        current_active_run and current_active_run.info.run_id == run_id
    ):
        current_active_run = None

    return current_active_run, experiment_id, run_id
@property
def active_run(self) -> Optional[mlflow.ActiveRun]:
    """Returns the currently active MLflow run.

    If MLflow's current run is not the run belonging to this pipeline run,
    a run is started (or resumed, when a matching run id was found) with
    the pipeline run id as its name.

    Returns:
        The active MLflow run, or `None` when called outside of a step.
    """
    step_env = Environment().step_environment
    if not step_env:
        # Outside of a step there is no pipeline run id to name the run
        # after. Return `None` instead of failing on `None.pipeline_run_id`.
        return None

    current_active_run, experiment_id, run_id = self._find_active_run()
    if current_active_run:
        return current_active_run

    return mlflow.start_run(
        run_id=run_id,
        run_name=step_env.pipeline_run_id,
        experiment_id=experiment_id,
    )
@property
def active_nested_run(self) -> Optional[mlflow.ActiveRun]:
    """Returns a nested run in the currently active MLflow run.

    Returns:
        A nested MLflow run named after the current step, or a falsy value
        when there is no active parent run to nest it in.
    """
    step_env = Environment().step_environment
    parent_run, _, _ = self._find_active_run()
    if not parent_run:
        # no parent run: hand back the (empty) lookup result unchanged
        return parent_run
    return mlflow.start_run(run_name=step_env.step_name, nested=True)
active_experiment: Optional[mlflow.entities.experiment.Experiment]
property
readonly
Returns the currently active MLflow experiment.
Returns:
Type | Description |
---|---|
Optional[mlflow.entities.experiment.Experiment] |
The active experiment, or None if no experiment is active. |
active_nested_run: Optional[mlflow.tracking.fluent.ActiveRun]
property
readonly
Returns a nested run in the currently active MLflow run.
Returns:
Type | Description |
---|---|
Optional[mlflow.tracking.fluent.ActiveRun] |
The nested MLflow run. |
active_run: Optional[mlflow.tracking.fluent.ActiveRun]
property
readonly
Returns the currently active MLflow run.
Returns:
Type | Description |
---|---|
Optional[mlflow.tracking.fluent.ActiveRun] |
The active MLflow run. |
local_path: Optional[str]
property
readonly
Path to the local directory where the MLflow artifacts are stored.
Returns:
Type | Description |
---|---|
Optional[str] |
None if configured with a remote tracking URI, otherwise the path to the local MLflow artifact store directory. |
validator: Optional[StackValidator]
property
readonly
Checks the stack has a LocalArtifactStore
if no tracking uri was specified.
Returns:
Type | Description |
---|---|
Optional[StackValidator] |
An optional StackValidator. |
cleanup_step_run(self)
Resets the MLflow tracking uri.
Source code in zenml/integrations/mlflow/experiment_trackers/mlflow_experiment_tracker.py
def cleanup_step_run(self) -> None:
    """Clears the MLflow tracking uri once a step has finished."""
    # an empty string resets whatever uri `prepare_step_run` configured
    mlflow.set_tracking_uri("")
configure_mlflow(self)
Configures the MLflow tracking URI and any additional credentials.
Source code in zenml/integrations/mlflow/experiment_trackers/mlflow_experiment_tracker.py
def configure_mlflow(self) -> None:
    """Points MLflow at the tracking URI and exports the credentials.

    Username, password and token are only written to the process
    environment when they are set on the component. The insecure-TLS flag
    is always written, as the string `"true"` or `"false"`.
    """
    mlflow.set_tracking_uri(self.get_tracking_uri())

    optional_credentials = (
        (MLFLOW_TRACKING_USERNAME, self.tracking_username),
        (MLFLOW_TRACKING_PASSWORD, self.tracking_password),
        (MLFLOW_TRACKING_TOKEN, self.tracking_token),
    )
    for env_name, secret in optional_credentials:
        if secret:
            os.environ[env_name] = secret

    insecure_tls = "true" if self.tracking_insecure_tls else "false"
    os.environ[MLFLOW_TRACKING_INSECURE_TLS] = insecure_tls
get_tracking_uri(self)
Returns the configured tracking URI or a local fallback.
Returns:
Type | Description |
---|---|
str |
The tracking URI. |
Source code in zenml/integrations/mlflow/experiment_trackers/mlflow_experiment_tracker.py
def get_tracking_uri(self) -> str:
    """Resolves the tracking URI this tracker should use.

    Returns:
        The explicitly configured `tracking_uri` when it is set, otherwise
        the URI of the local MLflow backend.
    """
    configured_uri = self.tracking_uri
    if configured_uri:
        return configured_uri
    # nothing configured: fall back to the local backend
    return self._local_mlflow_backend()
is_remote_tracking_uri(tracking_uri)
staticmethod
Checks whether the given tracking uri is remote or not.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
tracking_uri |
str |
The tracking uri to check. |
required |
Returns:
Type | Description |
---|---|
bool |
True if the tracking uri is remote, False otherwise. |
Source code in zenml/integrations/mlflow/experiment_trackers/mlflow_experiment_tracker.py
@staticmethod
def is_remote_tracking_uri(tracking_uri: str) -> bool:
    """Tells whether a tracking URI points at a remote MLflow server.

    Args:
        tracking_uri: The tracking URI to inspect.

    Returns:
        `True` if the URI starts with `http://` or `https://`, `False`
        otherwise.
    """
    remote_schemes = ("http://", "https://")
    # `str.startswith` accepts a tuple and matches any of its entries.
    return tracking_uri.startswith(remote_schemes)
prepare_step_run(self)
Sets the MLflow tracking uri and credentials.
Source code in zenml/integrations/mlflow/experiment_trackers/mlflow_experiment_tracker.py
def prepare_step_run(self) -> None:
    """Configures MLflow (tracking uri and credentials) before a step runs."""
    # all of the setup lives in `configure_mlflow`
    self.configure_mlflow()
mlflow_step_decorator
Implementation of the MLflow StepDecorator.
enable_mlflow(_step=None, nested=False)
Decorator to enable mlflow for a step function.
Apply this decorator to a ZenML pipeline step to enable MLflow experiment
tracking. The MLflow tracking configuration (tracking URI, experiment name,
run name) will be automatically configured before the step code is executed,
so the step can simply use the mlflow
module to log metrics and artifacts.
The simple usage will log metrics into a run created for the pipeline, like so:
@enable_mlflow
@step
def tf_evaluator(
x_test: np.ndarray,
y_test: np.ndarray,
model: tf.keras.Model,
) -> float:
_, test_acc = model.evaluate(x_test, y_test, verbose=2)
mlflow.log_metric("val_accuracy", test_acc)
return test_acc
You can also log parameters, metrics and artifacts into nested runs, which
will be children of the pipeline run. You only need to add the parameter
nested=True
to the decorator, like so:
@enable_mlflow(nested=True)
@step
def tf_evaluator(
x_test: np.ndarray,
y_test: np.ndarray,
model: tf.keras.Model,
) -> float:
_, test_acc = model.evaluate(x_test, y_test, verbose=2)
mlflow.log_param("some_param", 2)
mlflow.log_metric("val_accuracy", test_acc)
return test_acc
You can also use this decorator with our class-based API like so:
@enable_mlflow
class TFEvaluator(BaseStep):
def entrypoint(
self,
x_test: np.ndarray,
y_test: np.ndarray,
model: tf.keras.Model,
) -> float:
...
All MLflow artifacts and metrics logged from all the steps in a pipeline
run are by default grouped under a single experiment named after the
pipeline. To log MLflow artifacts and metrics from a step in a separate
MLflow experiment, pass a custom experiment_name
argument value to the
decorator.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
_step |
Optional[~S] |
The decorated step class. |
None |
nested |
bool |
Controls whether to create a run as a child of the pipeline run.
All the mlflow logging functions used during a step with
nested=True will be logged into the child run. |
False |
Returns:
Type | Description |
---|---|
Union[~S, Callable[[~S], ~S]] |
The inner decorator which enhances the input step class with mlflow tracking functionality |
Source code in zenml/integrations/mlflow/mlflow_step_decorator.py
def enable_mlflow(
    _step: Optional[S] = None, nested: bool = False
) -> Union[S, Callable[[S], S]]:
    """Decorator to enable mlflow for a step function.

    Apply this decorator to a ZenML pipeline step to enable MLflow experiment
    tracking. The MLflow tracking configuration (tracking URI, experiment name,
    run name) will be automatically configured before the step code is executed,
    so the step can simply use the `mlflow` module to log metrics and artifacts.

    The simple usage will log metrics into a run created for the pipeline, like
    so:

    ```python
    @enable_mlflow
    @step
    def tf_evaluator(
        x_test: np.ndarray,
        y_test: np.ndarray,
        model: tf.keras.Model,
    ) -> float:
        _, test_acc = model.evaluate(x_test, y_test, verbose=2)
        mlflow.log_metric("val_accuracy", test_acc)
        return test_acc
    ```

    You can also log parameters, metrics and artifacts into nested runs, which
    will be children of the pipeline run. You only need to add the parameter
    `nested=True` to the decorator, like so:

    ```python
    @enable_mlflow(nested=True)
    @step
    def tf_evaluator(
        x_test: np.ndarray,
        y_test: np.ndarray,
        model: tf.keras.Model,
    ) -> float:
        _, test_acc = model.evaluate(x_test, y_test, verbose=2)
        mlflow.log_param("some_param", 2)
        mlflow.log_metric("val_accuracy", test_acc)
        return test_acc
    ```

    You can also use this decorator with our class-based API like so:

    ```python
    @enable_mlflow
    class TFEvaluator(BaseStep):
        def entrypoint(
            self,
            x_test: np.ndarray,
            y_test: np.ndarray,
            model: tf.keras.Model,
        ) -> float:
            ...
    ```

    All MLflow artifacts and metrics logged from all the steps in a pipeline
    run are grouped under a single experiment named after the pipeline.

    Args:
        _step: The decorated step class.
        nested: Controls whether to create a run as a child of the pipeline
            run. All the mlflow logging functions used during a step with
            `nested=True` will be logged into the child run.

    Returns:
        The decorated step class when used as a bare decorator, or the inner
        decorator which enhances the input step class with mlflow tracking
        functionality when called with arguments.
    """

    def inner_decorator(_step: S) -> S:
        # `issubclass` raises a bare TypeError when its first argument is not
        # a class (e.g. a plain function because `@enable_mlflow` was placed
        # below `@step`). Check for a class first so that the descriptive
        # error below is raised in that case as well.
        if not (isinstance(_step, type) and issubclass(_step, BaseStep)):
            raise RuntimeError(
                "The `enable_mlflow` decorator can only be applied to a ZenML "
                "`step` decorated function or a BaseStep subclass."
            )

        logger.debug(
            "Applying 'enable_mlflow' decorator to step %s", _step.__name__
        )

        source_fn = getattr(_step, STEP_INNER_FUNC_NAME)
        new_entrypoint = mlflow_step_entrypoint(nested=nested)(source_fn)
        if _step._created_by_functional_api():
            # If the step was created by the functional API, the old entrypoint
            # was a static method -> make sure the new one is as well
            new_entrypoint = staticmethod(new_entrypoint)

        setattr(_step, STEP_INNER_FUNC_NAME, new_entrypoint)
        return _step

    if _step is None:
        # called with arguments: `@enable_mlflow(nested=True)`
        return inner_decorator
    else:
        # used as a bare decorator: `@enable_mlflow`
        return inner_decorator(_step)
mlflow_step_entrypoint(nested=False)
Decorator for a step entrypoint to enable mlflow.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
nested |
bool |
Controls whether to create a run as a child of the pipeline run.
All the mlflow logging functions used during a step with
nested=True will be logged into the child run. |
False |
Returns:
Type | Description |
---|---|
Callable[[~F], ~F] |
the input function enhanced with mlflow profiling functionality |
Source code in zenml/integrations/mlflow/mlflow_step_decorator.py
def mlflow_step_entrypoint(nested: bool = False) -> Callable[[F], F]:
    """Builds a decorator that runs a step entrypoint inside an MLflow run.

    Args:
        nested: Controls whether to create a run as a child of pipeline run.
            All the mlflow logging functions used during a step with
            `nested=True` will be logged into the child run.

    Returns:
        A decorator that wraps the input function so that it executes with
        the MLflow run (and, if requested, a nested run) active.
    """

    def inner_decorator(func: F) -> F:
        logger.debug(
            "Applying 'mlflow_step_entrypoint' decorator to step entrypoint %s",
            func.__name__,
        )

        @functools.wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:  # noqa
            logger.debug(
                "Setting up MLflow backend before running step entrypoint %s",
                func.__name__,
            )
            repo = Repository(  # type: ignore[call-arg]
                skip_repository_check=True
            )
            experiment_tracker = repo.active_stack.experiment_tracker

            if not isinstance(experiment_tracker, MLFlowExperimentTracker):
                raise get_missing_mlflow_experiment_tracker_error()

            # The pipeline-level run is needed in both modes: on its own, or
            # as the parent of the nested run.
            parent_run = experiment_tracker.active_run
            if not parent_run:
                raise RuntimeError("No active mlflow run configured.")

            if not nested:
                with parent_run:
                    return func(*args, **kwargs)

            child_run = experiment_tracker.active_nested_run
            # At this point the nested run can never be `None` as this would
            # mean that there is no parent run, in which case the previous
            # runtime error would have been raised. The following test is to
            # avoid pylint errors
            if not child_run:
                raise RuntimeError(
                    "No active mlflow run configured to create a nested run."
                )
            with parent_run, child_run:
                return func(*args, **kwargs)

        return cast(F, wrapper)

    return inner_decorator
mlflow_utils
Implementation of utils specific to the MLflow integration.
get_missing_mlflow_experiment_tracker_error()
Returns description of how to add an MLflow experiment tracker to your stack.
Returns:
Type | Description |
---|---|
ValueError |
If no MLflow experiment tracker is registered in the active stack. |
Source code in zenml/integrations/mlflow/mlflow_utils.py
def get_missing_mlflow_experiment_tracker_error() -> ValueError:
    """Builds the error explaining how to add an MLflow experiment tracker.

    The exception is returned rather than raised, so the caller decides
    when to raise it.

    Returns:
        A `ValueError` whose message describes how to register an MLflow
        experiment tracker and add it to a stack.
    """
    message = (
        "The active stack needs to have a MLflow experiment tracker "
        "component registered to be able to track experiments using "
        "MLflow. You can create a new stack with a MLflow experiment "
        "tracker component or update your existing stack to add this "
        "component, e.g.:\n\n"
        " 'zenml experiment-tracker register mlflow_tracker "
        "--type=mlflow'\n"
        " 'zenml stack register stack-name -e mlflow_tracker ...'\n"
    )
    return ValueError(message)
get_tracking_uri()
Gets the MLflow tracking URI from the active experiment tracking stack component.
noqa: DAR401
Returns:
Type | Description |
---|---|
str |
MLflow tracking URI. |
Source code in zenml/integrations/mlflow/mlflow_utils.py
def get_tracking_uri() -> str:
    """Looks up the MLflow tracking URI of the active stack.

    The error raised when the active stack has no MLflow experiment tracker
    is the one built by `get_missing_mlflow_experiment_tracker_error`.

    # noqa: DAR401

    Returns:
        MLflow tracking URI.
    """
    tracker = Repository().active_stack.experiment_tracker
    # `isinstance` is already False for `None`, so one check covers both the
    # "no tracker" and the "different tracker flavor" case.
    if isinstance(tracker, MLFlowExperimentTracker):
        return tracker.get_tracking_uri()
    raise get_missing_mlflow_experiment_tracker_error()
model_deployers
special
Initialization of the MLflow model deployers.
mlflow_model_deployer
Implementation of the MLflow model deployer.
MLFlowModelDeployer (BaseModelDeployer)
pydantic-model
MLflow implementation of the BaseModelDeployer.
Attributes:
Name | Type | Description |
---|---|---|
service_path |
str |
the path where the local MLflow deployment service configuration, PID and log files are stored. |
Source code in zenml/integrations/mlflow/model_deployers/mlflow_model_deployer.py
class MLFlowModelDeployer(BaseModelDeployer):
    """MLflow implementation of the BaseModelDeployer.

    Attributes:
        service_path: the path where the local MLflow deployment service
            configuration, PID and log files are stored.
    """

    service_path: str = ""

    # Class Configuration
    FLAVOR: ClassVar[str] = MLFLOW_MODEL_DEPLOYER_FLAVOR

    @root_validator(skip_on_failure=True)
    def set_service_path(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Sets the service_path attribute value according to the component UUID.

        An explicitly configured `service_path` is left untouched.

        Args:
            values: the dictionary of values to be validated.

        Returns:
            The validated dictionary of values.
        """
        if values.get("service_path"):
            return values

        # not likely to happen, due to Pydantic validation, but mypy complains
        assert "uuid" in values

        values["service_path"] = cls.get_service_path(values["uuid"])
        return values

    @staticmethod
    def get_service_path(uuid: uuid.UUID) -> str:
        """Get the path where local MLflow service information is stored.

        This is where the deployment service configuration, PID and log files
        are stored. The directory is created if it does not exist.

        Args:
            uuid: The UUID of the MLflow model deployer.

        Returns:
            The service path.
        """
        service_path = os.path.join(
            get_global_config_directory(),
            LOCAL_STORES_DIRECTORY_NAME,
            str(uuid),
        )
        create_dir_recursive_if_not_exists(service_path)
        return service_path

    @property
    def local_path(self) -> str:
        """Returns the path to the root directory.

        This is where all configurations for MLflow deployment daemon
        processes are stored.

        Returns:
            The path to the local service root directory.
        """
        return self.service_path

    @staticmethod
    def get_model_server_info(  # type: ignore[override]
        service_instance: "MLFlowDeploymentService",
    ) -> Dict[str, Optional[str]]:
        """Return implementation specific information relevant to the user.

        Args:
            service_instance: Instance of a MLFlowDeploymentService

        Returns:
            A dictionary containing the prediction URL, model URI, model
            name, service runtime path and daemon PID.
        """
        return {
            "PREDICTION_URL": service_instance.endpoint.prediction_url,
            "MODEL_URI": service_instance.config.model_uri,
            "MODEL_NAME": service_instance.config.model_name,
            "SERVICE_PATH": service_instance.status.runtime_path,
            "DAEMON_PID": str(service_instance.status.pid),
        }

    @staticmethod
    def get_active_model_deployer() -> "MLFlowModelDeployer":
        """Returns the MLFlowModelDeployer component of the active stack.

        Returns:
            The MLFlowModelDeployer component of the active stack.

        Raises:
            TypeError: If the active stack does not contain an
                MLFlowModelDeployer component.
        """
        model_deployer = Repository(  # type: ignore[call-arg]
            skip_repository_check=True
        ).active_stack.model_deployer

        if not model_deployer or not isinstance(
            model_deployer, MLFlowModelDeployer
        ):
            # The stack hint uses `zenml stack register`, the same command
            # the experiment tracker error message in this integration uses.
            raise TypeError(
                f"The active stack needs to have an MLflow model deployer "
                f"component registered to be able to deploy models with MLflow. "
                f"You can create a new stack with an MLflow model "
                f"deployer component or update your existing stack to add this "
                f"component, e.g.:\n\n"
                f" 'zenml model-deployer register mlflow --flavor={MLFLOW_MODEL_DEPLOYER_FLAVOR}'\n"
                f" 'zenml stack register stack-name -d mlflow ...'\n"
            )
        return model_deployer

    def deploy_model(
        self,
        config: ServiceConfig,
        replace: bool = False,
        timeout: int = DEFAULT_SERVICE_START_STOP_TIMEOUT,
    ) -> BaseService:
        """Create a new MLflow deployment service or update an existing one.

        This should serve the supplied model and deployment configuration.

        This method has two modes of operation, depending on the `replace`
        argument value:

          * if `replace` is False, calling this method will create a new MLflow
            deployment server to reflect the model and other configuration
            parameters specified in the supplied MLflow service `config`.

          * if `replace` is True, this method will first attempt to find an
            existing MLflow deployment service that is *equivalent* to the
            supplied configuration parameters. Two or more MLflow deployment
            services are considered equivalent if they have the same
            `pipeline_name`, `pipeline_step_name` and `model_name` configuration
            parameters. To put it differently, two MLflow deployment services
            are equivalent if they serve versions of the same model deployed by
            the same pipeline step. If an equivalent MLflow deployment is found,
            it will be updated in place to reflect the new configuration
            parameters.

        Callers should set `replace` to True if they want a continuous model
        deployment workflow that doesn't spin up a new MLflow deployment
        server for each new model version. If multiple equivalent MLflow
        deployment servers are found, one is selected at random to be updated
        and the others are deleted.

        Args:
            config: the configuration of the model to be deployed with MLflow.
            replace: set this flag to True to find and update an equivalent
                MLflow deployment server with the new model instead of
                creating and starting a new deployment server.
            timeout: the timeout in seconds to wait for the MLflow server
                to be provisioned and successfully started or updated. If set
                to 0, the method will return immediately after the MLflow
                server is provisioned, without waiting for it to fully start.

        Returns:
            The ZenML MLflow deployment service object that can be used to
            interact with the MLflow model server.
        """
        config = cast(MLFlowDeploymentConfig, config)
        service = None

        # if replace is True, remove all existing services
        if replace is True:
            existing_services = self.find_model_server(
                pipeline_name=config.pipeline_name,
                pipeline_step_name=config.pipeline_step_name,
                model_name=config.model_name,
            )

            for existing_service in existing_services:
                if service is None:
                    # keep the first matching service
                    service = cast(MLFlowDeploymentService, existing_service)
                # NOTE(review): the clean-up below also runs for the service
                # that was just selected to be kept — confirm that is intended.
                try:
                    # delete the older services and don't wait for them to
                    # be deprovisioned
                    self._clean_up_existing_service(
                        existing_service=cast(
                            MLFlowDeploymentService, existing_service
                        ),
                        timeout=timeout,
                        force=True,
                    )
                except RuntimeError:
                    # ignore errors encountered while stopping old services
                    pass

        if service:
            logger.info(
                f"Updating an existing MLflow deployment service: {service}"
            )

            # set the root runtime path with the stack component's UUID
            config.root_runtime_path = self.local_path
            service.stop(timeout=timeout, force=True)
            service.update(config)
            service.start(timeout=timeout)
        else:
            # create a new MLFlowDeploymentService instance
            service = self._create_new_service(timeout, config)
            logger.info(f"Created a new MLflow deployment service: {service}")

        return cast(BaseService, service)

    def _clean_up_existing_service(
        self,
        timeout: int,
        force: bool,
        existing_service: MLFlowDeploymentService,
    ) -> None:
        """Stops a service and removes its runtime directory.

        Args:
            timeout: Timeout in seconds to wait for the service to stop.
            force: If True, force the service to stop.
            existing_service: The service to stop and clean up.
        """
        # stop the older service
        existing_service.stop(timeout=timeout, force=force)

        # delete the old configuration directory. A service without a runtime
        # path has nothing to delete; calling `shutil.rmtree("")` would raise.
        service_directory_path = existing_service.status.runtime_path
        if service_directory_path:
            shutil.rmtree(service_directory_path)

    # the step will receive a config from the user that mentions the number
    # of workers etc. the step implementation will create a new config using
    # all values from the user and add values like pipeline name, model_uri
    def _create_new_service(
        self, timeout: int, config: MLFlowDeploymentConfig
    ) -> MLFlowDeploymentService:
        """Creates a new MLFlowDeploymentService.

        Args:
            timeout: the timeout in seconds to wait for the MLflow server
                to be provisioned and successfully started or updated.
            config: the configuration of the model to be deployed with MLflow.

        Returns:
            The MLFlowDeploymentService object that can be used to interact
            with the MLflow model server.
        """
        # set the root runtime path with the stack component's UUID
        config.root_runtime_path = self.local_path
        # create a new service for the new model
        service = MLFlowDeploymentService(config)
        service.start(timeout=timeout)

        return service

    def find_model_server(
        self,
        running: bool = False,
        service_uuid: Optional[UUID] = None,
        pipeline_name: Optional[str] = None,
        pipeline_run_id: Optional[str] = None,
        pipeline_step_name: Optional[str] = None,
        model_name: Optional[str] = None,
        model_uri: Optional[str] = None,
        model_type: Optional[str] = None,
    ) -> List[BaseService]:
        """Finds one or more model servers that match the given criteria.

        Services are discovered by walking `local_path` and loading every
        service daemon configuration file found there.

        Args:
            running: If true, only running services will be returned.
            service_uuid: The UUID of the service that was originally used
                to deploy the model.
            pipeline_name: Name of the pipeline that the deployed model was part
                of.
            pipeline_run_id: ID of the pipeline run which the deployed model
                was part of.
            pipeline_step_name: The name of the pipeline model deployment step
                that deployed the model.
            model_name: Name of the deployed model.
            model_uri: URI of the deployed model.
            model_type: Type/format of the deployed model. Not used in this
                MLflow case.

        Returns:
            One or more Service objects representing model servers that match
            the input search criteria.

        Raises:
            TypeError: if a loaded service is not an MLFlowDeploymentService.
        """
        services = []
        # unset criteria become empty strings, which the matcher ignores
        config = MLFlowDeploymentConfig(
            model_name=model_name or "",
            model_uri=model_uri or "",
            pipeline_name=pipeline_name or "",
            pipeline_run_id=pipeline_run_id or "",
            pipeline_step_name=pipeline_step_name or "",
        )

        # find all services that match the input criteria
        for root, _, files in os.walk(self.local_path):
            # when a UUID is given, only the directory named after it counts
            if service_uuid and Path(root).name != str(service_uuid):
                continue
            for file in files:
                if file == SERVICE_DAEMON_CONFIG_FILE_NAME:
                    service_config_path = os.path.join(root, file)
                    logger.debug(
                        "Loading service daemon configuration from %s",
                        service_config_path,
                    )
                    with open(service_config_path, "r") as f:
                        existing_service_config = f.read()
                    existing_service = ServiceRegistry().load_service_from_json(
                        existing_service_config
                    )
                    if not isinstance(
                        existing_service, MLFlowDeploymentService
                    ):
                        raise TypeError(
                            f"Expected service type MLFlowDeploymentService but got "
                            f"{type(existing_service)} instead"
                        )
                    existing_service.update_status()
                    if self._matches_search_criteria(existing_service, config):
                        if not running or existing_service.is_running:
                            services.append(cast(BaseService, existing_service))

        return services

    def _matches_search_criteria(
        self,
        existing_service: MLFlowDeploymentService,
        config: MLFlowDeploymentConfig,
    ) -> bool:
        """Returns true if a service matches the input criteria.

        If any of the values in the input criteria are empty, they are
        ignored. This allows listing services just by common pipeline names
        or step names, etc.

        Args:
            existing_service: The materialized Service instance derived from
                the config of the older (existing) service
            config: The MLFlowDeploymentConfig object holding the search
                criteria (or the parameters of the new service to be
                created, when called via `deploy_model`).

        Returns:
            True if the service matches the input criteria.
        """
        existing_service_config = existing_service.config

        # NOTE(review): `config.model_uri` is not compared here, so the
        # `model_uri` argument of `find_model_server` does not filter results
        # — confirm whether that is intended.
        return bool(
            (
                not config.pipeline_name
                or existing_service_config.pipeline_name == config.pipeline_name
            )
            and (
                not config.model_name
                or existing_service_config.model_name == config.model_name
            )
            and (
                not config.pipeline_step_name
                or existing_service_config.pipeline_step_name
                == config.pipeline_step_name
            )
            and (
                not config.pipeline_run_id
                or existing_service_config.pipeline_run_id
                == config.pipeline_run_id
            )
        )

    def stop_model_server(
        self,
        uuid: UUID,
        timeout: int = DEFAULT_SERVICE_START_STOP_TIMEOUT,
        force: bool = False,
    ) -> None:
        """Method to stop a model server.

        Does nothing if no service with the given UUID is found.

        Args:
            uuid: UUID of the model server to stop.
            timeout: Timeout in seconds to wait for the service to stop.
            force: If True, force the service to stop.
        """
        # get list of all services
        existing_services = self.find_model_server(service_uuid=uuid)

        # if the service exists, stop it
        if existing_services:
            existing_services[0].stop(timeout=timeout, force=force)

    def start_model_server(
        self, uuid: UUID, timeout: int = DEFAULT_SERVICE_START_STOP_TIMEOUT
    ) -> None:
        """Method to start a model server.

        Does nothing if no service with the given UUID is found.

        Args:
            uuid: UUID of the model server to start.
            timeout: Timeout in seconds to wait for the service to start.
        """
        # get list of all services
        existing_services = self.find_model_server(service_uuid=uuid)

        # if the service exists, start it
        if existing_services:
            existing_services[0].start(timeout=timeout)

    def delete_model_server(
        self,
        uuid: UUID,
        timeout: int = DEFAULT_SERVICE_START_STOP_TIMEOUT,
        force: bool = False,
    ) -> None:
        """Method to delete all configuration of a model server.

        Does nothing if no service with the given UUID is found.

        Args:
            uuid: UUID of the model server to delete.
            timeout: Timeout in seconds to wait for the service to stop.
            force: If True, force the service to stop.
        """
        # get list of all services
        existing_services = self.find_model_server(service_uuid=uuid)

        # if the service exists, clean it up
        if existing_services:
            service = cast(MLFlowDeploymentService, existing_services[0])
            self._clean_up_existing_service(
                existing_service=service, timeout=timeout, force=force
            )
local_path: str
property
readonly
Returns the path to the root directory.
This is where all configurations for MLflow deployment daemon processes are stored.
Returns:
Type | Description |
---|---|
str |
The path to the local service root directory. |
delete_model_server(self, uuid, timeout=10, force=False)
Method to delete all configuration of a model server.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
uuid |
UUID |
UUID of the model server to delete. |
required |
timeout |
int |
Timeout in seconds to wait for the service to stop. |
10 |
force |
bool |
If True, force the service to stop. |
False |
Source code in zenml/integrations/mlflow/model_deployers/mlflow_model_deployer.py
def delete_model_server(
    self,
    uuid: UUID,
    timeout: int = DEFAULT_SERVICE_START_STOP_TIMEOUT,
    force: bool = False,
) -> None:
    """Stops a model server and removes its configuration.

    Nothing happens when no service with the given UUID is found.

    Args:
        uuid: UUID of the model server to delete.
        timeout: Timeout in seconds to wait for the service to stop.
        force: If True, force the service to stop.
    """
    matches = self.find_model_server(service_uuid=uuid)
    if not matches:
        # unknown service: nothing to clean up
        return

    # only the first match is cleaned up
    target = cast(MLFlowDeploymentService, matches[0])
    self._clean_up_existing_service(
        existing_service=target, timeout=timeout, force=force
    )
deploy_model(self, config, replace=False, timeout=10)
Create a new MLflow deployment service or update an existing one.
This should serve the supplied model and deployment configuration.
This method has two modes of operation, depending on the replace
argument value:
-
if
replace
is False, calling this method will create a new MLflow deployment server to reflect the model and other configuration parameters specified in the supplied MLflow service config
. -
if
replace
is True, this method will first attempt to find an existing MLflow deployment service that is equivalent to the supplied configuration parameters. Two or more MLflow deployment services are considered equivalent if they have the same pipeline_name
, pipeline_step_name
and model_name
configuration parameters. To put it differently, two MLflow deployment services are equivalent if they serve versions of the same model deployed by the same pipeline step. If an equivalent MLflow deployment is found, it will be updated in place to reflect the new configuration parameters.
Callers should set replace
to True if they want a continuous model
deployment workflow that doesn't spin up a new MLflow deployment
server for each new model version. If multiple equivalent MLflow
deployment servers are found, one is selected at random to be updated
and the others are deleted.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
config |
ServiceConfig |
the configuration of the model to be deployed with MLflow. |
required |
replace |
bool |
set this flag to True to find and update an equivalent MLflow deployment server with the new model instead of creating and starting a new deployment server. |
False |
timeout |
int |
the timeout in seconds to wait for the MLflow server to be provisioned and successfully started or updated. If set to 0, the method will return immediately after the MLflow server is provisioned, without waiting for it to fully start. |
10 |
Returns:
Type | Description |
---|---|
BaseService |
The ZenML MLflow deployment service object that can be used to interact with the MLflow model server. |
Source code in zenml/integrations/mlflow/model_deployers/mlflow_model_deployer.py
def deploy_model(
    self,
    config: ServiceConfig,
    replace: bool = False,
    timeout: int = DEFAULT_SERVICE_START_STOP_TIMEOUT,
) -> BaseService:
    """Create a new MLflow deployment service or update an existing one.

    This should serve the supplied model and deployment configuration.

    This method has two modes of operation, depending on the `replace`
    argument value:

      * if `replace` is False, calling this method will create a new MLflow
        deployment server to reflect the model and other configuration
        parameters specified in the supplied MLflow service `config`.

      * if `replace` is True, this method will first attempt to find an
        existing MLflow deployment service that is *equivalent* to the
        supplied configuration parameters. Two or more MLflow deployment
        services are considered equivalent if they have the same
        `pipeline_name`, `pipeline_step_name` and `model_name` configuration
        parameters. To put it differently, two MLflow deployment services
        are equivalent if they serve versions of the same model deployed by
        the same pipeline step. If an equivalent MLflow deployment is found,
        it will be updated in place to reflect the new configuration
        parameters.

    Callers should set `replace` to True if they want a continuous model
    deployment workflow that doesn't spin up a new MLflow deployment
    server for each new model version. If multiple equivalent MLflow
    deployment servers are found, the first one returned by the lookup is
    updated and the others are deleted.

    Args:
        config: the configuration of the model to be deployed with MLflow.
        replace: set this flag to True to find and update an equivalent
            MLflow deployment server with the new model instead of
            creating and starting a new deployment server.
        timeout: the timeout in seconds to wait for the MLflow server
            to be provisioned and successfully started or updated. If set
            to 0, the method will return immediately after the MLflow
            server is provisioned, without waiting for it to fully start.

    Returns:
        The ZenML MLflow deployment service object that can be used to
        interact with the MLflow model server.
    """
    config = cast(MLFlowDeploymentConfig, config)
    service = None

    # in replace mode, look up every service equivalent to the new config
    if replace is True:
        existing_services = self.find_model_server(
            pipeline_name=config.pipeline_name,
            pipeline_step_name=config.pipeline_step_name,
            model_name=config.model_name,
        )

        for existing_service in existing_services:
            if service is None:
                # the first match is the service that gets updated in place
                service = cast(MLFlowDeploymentService, existing_service)
            try:
                # NOTE(review): the clean-up call is also issued for the
                # service retained above, not only for the redundant
                # ones -- confirm this is intended.
                self._clean_up_existing_service(
                    existing_service=cast(
                        MLFlowDeploymentService, existing_service
                    ),
                    timeout=timeout,
                    force=True,
                )
            except RuntimeError as error:
                # clean-up is best effort: report the failure instead of
                # hiding it, then carry on with the deployment
                logger.warning(
                    "Ignoring error encountered while cleaning up existing "
                    "MLflow deployment service %s: %s",
                    existing_service,
                    error,
                )

    if service:
        logger.info(
            f"Updating an existing MLflow deployment service: {service}"
        )
        # set the root runtime path with the stack component's UUID
        config.root_runtime_path = self.local_path
        service.stop(timeout=timeout, force=True)
        service.update(config)
        service.start(timeout=timeout)
    else:
        # nothing to update: create a new MLFlowDeploymentService instance
        service = self._create_new_service(timeout, config)
        logger.info(f"Created a new MLflow deployment service: {service}")

    return cast(BaseService, service)
find_model_server(self, running=False, service_uuid=None, pipeline_name=None, pipeline_run_id=None, pipeline_step_name=None, model_name=None, model_uri=None, model_type=None)
Finds one or more model servers that match the given criteria.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
running |
bool |
If true, only running services will be returned. |
False |
service_uuid |
Optional[uuid.UUID] |
The UUID of the service that was originally used to deploy the model. |
None |
pipeline_name |
Optional[str] |
Name of the pipeline that the deployed model was part of. |
None |
pipeline_run_id |
Optional[str] |
ID of the pipeline run which the deployed model was part of. |
None |
pipeline_step_name |
Optional[str] |
The name of the pipeline model deployment step that deployed the model. |
None |
model_name |
Optional[str] |
Name of the deployed model. |
None |
model_uri |
Optional[str] |
URI of the deployed model. |
None |
model_type |
Optional[str] |
Type/format of the deployed model. Not used in this MLflow case. |
None |
Returns:
Type | Description |
---|---|
List[zenml.services.service.BaseService] |
One or more Service objects representing model servers that match the input search criteria. |
Exceptions:
Type | Description |
---|---|
TypeError |
if any of the input arguments are of an invalid type. |
Source code in zenml/integrations/mlflow/model_deployers/mlflow_model_deployer.py
def find_model_server(
    self,
    running: bool = False,
    service_uuid: Optional[UUID] = None,
    pipeline_name: Optional[str] = None,
    pipeline_run_id: Optional[str] = None,
    pipeline_step_name: Optional[str] = None,
    model_name: Optional[str] = None,
    model_uri: Optional[str] = None,
    model_type: Optional[str] = None,
) -> List[BaseService]:
    """Look up the model servers stored locally that match the criteria.

    The directory tree under `self.local_path` is scanned for service
    daemon configuration files; each one is loaded, its status refreshed
    and then compared against the search criteria.

    Args:
        running: If true, only running services will be returned.
        service_uuid: The UUID of the service that was originally used
            to deploy the model. When set, only directories named after
            this UUID are inspected.
        pipeline_name: Name of the pipeline that the deployed model was
            part of.
        pipeline_run_id: ID of the pipeline run which the deployed model
            was part of.
        pipeline_step_name: The name of the pipeline model deployment step
            that deployed the model.
        model_name: Name of the deployed model.
        model_uri: URI of the deployed model.
        model_type: Type/format of the deployed model. Not used in this
            MLflow case.

    Returns:
        The Service objects representing model servers that match the
        input search criteria (possibly an empty list).

    Raises:
        TypeError: if a service loaded from a configuration file is not
            an MLFlowDeploymentService.
    """
    matches = []
    # criteria that were not supplied are represented by empty strings
    criteria = MLFlowDeploymentConfig(
        model_name=model_name or "",
        model_uri=model_uri or "",
        pipeline_name=pipeline_name or "",
        pipeline_run_id=pipeline_run_id or "",
        pipeline_step_name=pipeline_step_name or "",
    )

    for root, _, files in os.walk(self.local_path):
        # when a UUID is given, skip every directory not named after it
        if service_uuid and Path(root).name != str(service_uuid):
            continue
        for file in files:
            if file != SERVICE_DAEMON_CONFIG_FILE_NAME:
                continue
            service_config_path = os.path.join(root, file)
            logger.debug(
                "Loading service daemon configuration from %s",
                service_config_path,
            )
            with open(service_config_path, "r") as f:
                serialized_service = f.read()
            candidate = ServiceRegistry().load_service_from_json(
                serialized_service
            )
            if not isinstance(candidate, MLFlowDeploymentService):
                raise TypeError(
                    f"Expected service type MLFlowDeploymentService but got "
                    f"{type(candidate)} instead"
                )
            # refresh the status before evaluating `is_running`
            candidate.update_status()
            if self._matches_search_criteria(candidate, criteria) and (
                not running or candidate.is_running
            ):
                matches.append(cast(BaseService, candidate))

    return matches
get_active_model_deployer()
staticmethod
Returns the MLFlowModelDeployer component of the active stack.
Returns:
Type | Description |
---|---|
MLFlowModelDeployer |
The MLFlowModelDeployer component of the active stack. |
Exceptions:
Type | Description |
---|---|
TypeError |
If the active stack does not contain an MLFlowModelDeployer component. |
Source code in zenml/integrations/mlflow/model_deployers/mlflow_model_deployer.py
@staticmethod
def get_active_model_deployer() -> "MLFlowModelDeployer":
    """Fetch the MLFlowModelDeployer registered in the active stack.

    Returns:
        The MLFlowModelDeployer component of the active stack.

    Raises:
        TypeError: If the active stack does not contain an
            MLFlowModelDeployer component.
    """
    repository = Repository(  # type: ignore[call-arg]
        skip_repository_check=True
    )
    model_deployer = repository.active_stack.model_deployer

    # reject both a missing deployer and a deployer of another flavor
    if not (
        model_deployer and isinstance(model_deployer, MLFlowModelDeployer)
    ):
        raise TypeError(
            f"The active stack needs to have an MLflow model deployer "
            f"component registered to be able to deploy models with MLflow. "
            f"You can create a new stack with an MLflow model "
            f"deployer component or update your existing stack to add this "
            f"component, e.g.:\n\n"
            f" 'zenml model-deployer register mlflow --flavor={MLFLOW_MODEL_DEPLOYER_FLAVOR}'\n"
            f" 'zenml stack create stack-name -d mlflow ...'\n"
        )
    return model_deployer
get_model_server_info(service_instance)
staticmethod
Return implementation specific information relevant to the user.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
service_instance |
MLFlowDeploymentService |
Instance of an MLFlowDeploymentService |
required |
Returns:
Type | Description |
---|---|
Dict[str, Optional[str]] |
A dictionary containing the information. |
Source code in zenml/integrations/mlflow/model_deployers/mlflow_model_deployer.py
@staticmethod
def get_model_server_info(  # type: ignore[override]
    service_instance: "MLFlowDeploymentService",
) -> Dict[str, Optional[str]]:
    """Return implementation specific information relevant to the user.

    Args:
        service_instance: Instance of an MLFlowDeploymentService

    Returns:
        A dictionary with the prediction URL, the model URI and name, the
        service runtime path and the daemon PID (converted to a string).
    """
    return {
        "PREDICTION_URL": service_instance.endpoint.prediction_url,
        "MODEL_URI": service_instance.config.model_uri,
        "MODEL_NAME": service_instance.config.model_name,
        "SERVICE_PATH": service_instance.status.runtime_path,
        # the PID is stringified unconditionally, so a missing PID is
        # reported as the string "None"
        "DAEMON_PID": str(service_instance.status.pid),
    }
get_service_path(uuid)
staticmethod
Get the path where local MLflow service information is stored.
This includes the deployment service configuration, PID and log files.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
uuid |
UUID |
The UUID of the MLflow model deployer. |
required |
Returns:
Type | Description |
---|---|
str |
The service path. |
Source code in zenml/integrations/mlflow/model_deployers/mlflow_model_deployer.py
@staticmethod
def get_service_path(uuid: uuid.UUID) -> str:
    """Get the path where local MLflow service information is stored.

    This includes the deployment service configuration, PID and log files.
    The directory is created if it does not exist yet.

    Args:
        uuid: The UUID of the MLflow model deployer.

    Returns:
        The service path.
    """
    path_parts = (
        get_global_config_directory(),
        LOCAL_STORES_DIRECTORY_NAME,
        str(uuid),
    )
    service_path = os.path.join(*path_parts)
    create_dir_recursive_if_not_exists(service_path)
    return service_path
set_service_path(values)
classmethod
Sets the service_path attribute value according to the component UUID.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
values |
Dict[str, Any] |
the dictionary of values to be validated. |
required |
Returns:
Type | Description |
---|---|
Dict[str, Any] |
The validated dictionary of values. |
Source code in zenml/integrations/mlflow/model_deployers/mlflow_model_deployer.py
@root_validator(skip_on_failure=True)
def set_service_path(cls, values: Dict[str, Any]) -> Dict[str, Any]:
    """Fill in `service_path` from the component UUID when it is not set.

    Args:
        values: the dictionary of values to be validated.

    Returns:
        The validated dictionary of values.
    """
    if not values.get("service_path"):
        # not likely to happen, due to Pydantic validation, but mypy
        # complains
        assert "uuid" in values
        values["service_path"] = cls.get_service_path(values["uuid"])
    return values
start_model_server(self, uuid, timeout=10)
Method to start a model server.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
uuid |
UUID |
UUID of the model server to start. |
required |
timeout |
int |
Timeout in seconds to wait for the service to start. |
10 |
Source code in zenml/integrations/mlflow/model_deployers/mlflow_model_deployer.py
def start_model_server(
    self, uuid: UUID, timeout: int = DEFAULT_SERVICE_START_STOP_TIMEOUT
) -> None:
    """Start the model server identified by the given UUID.

    Nothing happens when no matching service is found.

    Args:
        uuid: UUID of the model server to start.
        timeout: Timeout in seconds to wait for the service to start.
    """
    matching_services = self.find_model_server(service_uuid=uuid)
    if not matching_services:
        return
    # only the first match is started
    matching_services[0].start(timeout=timeout)
stop_model_server(self, uuid, timeout=10, force=False)
Method to stop a model server.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
uuid |
UUID |
UUID of the model server to stop. |
required |
timeout |
int |
Timeout in seconds to wait for the service to stop. |
10 |
force |
bool |
If True, force the service to stop. |
False |
Source code in zenml/integrations/mlflow/model_deployers/mlflow_model_deployer.py
def stop_model_server(
    self,
    uuid: UUID,
    timeout: int = DEFAULT_SERVICE_START_STOP_TIMEOUT,
    force: bool = False,
) -> None:
    """Stop the model server identified by the given UUID.

    Nothing happens when no matching service is found.

    Args:
        uuid: UUID of the model server to stop.
        timeout: Timeout in seconds to wait for the service to stop.
        force: If True, force the service to stop.
    """
    matching_services = self.find_model_server(service_uuid=uuid)
    if not matching_services:
        return
    # only the first match is stopped
    matching_services[0].stop(timeout=timeout, force=force)
services
special
Initialization of the MLflow Service.
mlflow_deployment
Implementation of the MLflow deployment functionality.
MLFlowDeploymentConfig (LocalDaemonServiceConfig)
pydantic-model
MLflow model deployment configuration.
Attributes:
Name | Type | Description |
---|---|---|
model_uri |
str |
URI of the MLflow model to serve |
model_name |
str |
the name of the model |
workers |
int |
number of workers to use for the prediction service |
mlserver |
bool |
set to True to use the MLflow MLServer backend (see https://github.com/SeldonIO/MLServer). If False, the MLflow built-in scoring server will be used. |
Source code in zenml/integrations/mlflow/services/mlflow_deployment.py
class MLFlowDeploymentConfig(LocalDaemonServiceConfig):
    """Configuration of an MLflow model deployment.

    Attributes:
        model_uri: URI of the MLflow model to serve
        model_name: the name of the model
        workers: number of workers to use for the prediction service
            (defaults to 1)
        mlserver: set to True to use the MLflow MLServer backend (see
            https://github.com/SeldonIO/MLServer). If False (the default),
            the MLflow built-in scoring server will be used.
    """

    # required fields
    model_uri: str
    model_name: str
    # optional fields
    workers: int = 1
    mlserver: bool = False
MLFlowDeploymentEndpoint (LocalDaemonServiceEndpoint)
pydantic-model
A service endpoint exposed by the MLflow deployment daemon.
Attributes:
Name | Type | Description |
---|---|---|
config |
MLFlowDeploymentEndpointConfig |
service endpoint configuration |
monitor |
HTTPEndpointHealthMonitor |
optional service endpoint health monitor |
Source code in zenml/integrations/mlflow/services/mlflow_deployment.py
class MLFlowDeploymentEndpoint(LocalDaemonServiceEndpoint):
    """Service endpoint exposed by the MLflow deployment daemon.

    Attributes:
        config: service endpoint configuration
        monitor: optional service endpoint health monitor
    """

    config: MLFlowDeploymentEndpointConfig
    monitor: HTTPEndpointHealthMonitor

    @property
    def prediction_url(self) -> Optional[str]:
        """Gets the prediction URL for the endpoint.

        Returns:
            The endpoint URI followed by the configured prediction URL
            path, or None when the endpoint status has no URI.
        """
        base_uri = self.status.uri
        if base_uri:
            return f"{base_uri}{self.config.prediction_url_path}"
        return None
prediction_url: Optional[str]
property
readonly
Gets the prediction URL for the endpoint.
Returns:
Type | Description |
---|---|
Optional[str] |
the prediction URL for the endpoint |
MLFlowDeploymentEndpointConfig (LocalDaemonServiceEndpointConfig)
pydantic-model
MLflow daemon service endpoint configuration.
Attributes:
Name | Type | Description |
---|---|---|
prediction_url_path |
str |
URI subpath for prediction requests |
Source code in zenml/integrations/mlflow/services/mlflow_deployment.py
class MLFlowDeploymentEndpointConfig(LocalDaemonServiceEndpointConfig):
    """Configuration of the MLflow daemon service endpoint.

    Attributes:
        prediction_url_path: URI subpath for prediction requests
    """

    # appended to the endpoint URI to build the prediction URL
    prediction_url_path: str
MLFlowDeploymentService (LocalDaemonService)
pydantic-model
MLflow deployment service used to start a local prediction server for MLflow models.
Attributes:
Name | Type | Description |
---|---|---|
SERVICE_TYPE |
ClassVar[zenml.services.service_type.ServiceType] |
a service type descriptor with information describing the MLflow deployment service class |
config |
MLFlowDeploymentConfig |
service configuration |
endpoint |
MLFlowDeploymentEndpoint |
optional service endpoint |
Source code in zenml/integrations/mlflow/services/mlflow_deployment.py
class MLFlowDeploymentService(LocalDaemonService):
    """MLflow deployment service used to start a local prediction server for MLflow models.

    Attributes:
        SERVICE_TYPE: a service type descriptor with information describing
            the MLflow deployment service class
        config: service configuration
        endpoint: optional service endpoint
    """

    SERVICE_TYPE = ServiceType(
        name="mlflow-deployment",
        type="model-serving",
        flavor="mlflow",
        description="MLflow prediction service",
    )

    config: MLFlowDeploymentConfig
    endpoint: MLFlowDeploymentEndpoint

    def __init__(
        self,
        config: Union[MLFlowDeploymentConfig, Dict[str, Any]],
        **attrs: Any,
    ) -> None:
        """Initialize the MLflow deployment service.

        Args:
            config: service configuration
            attrs: additional attributes to set on the service
        """
        # ensure that the endpoint is created before the service is initialized
        # TODO [ENG-700]: implement a service factory or builder for MLflow
        #   deployment services
        if (
            isinstance(config, MLFlowDeploymentConfig)
            and "endpoint" not in attrs
        ):
            # the URL paths and the health check method depend on the
            # serving backend selected in the configuration
            if config.mlserver:
                prediction_url_path = MLSERVER_PREDICTION_URL_PATH
                healthcheck_uri_path = MLSERVER_HEALTHCHECK_URL_PATH
                use_head_request = False
            else:
                prediction_url_path = MLFLOW_PREDICTION_URL_PATH
                healthcheck_uri_path = MLFLOW_HEALTHCHECK_URL_PATH
                use_head_request = True

            endpoint = MLFlowDeploymentEndpoint(
                config=MLFlowDeploymentEndpointConfig(
                    protocol=ServiceEndpointProtocol.HTTP,
                    prediction_url_path=prediction_url_path,
                ),
                monitor=HTTPEndpointHealthMonitor(
                    config=HTTPEndpointHealthMonitorConfig(
                        healthcheck_uri_path=healthcheck_uri_path,
                        use_head_request=use_head_request,
                    )
                ),
            )
            attrs["endpoint"] = endpoint
        super().__init__(config=config, **attrs)

    def run(self) -> None:
        """Start the service as a blocking process.

        A KeyboardInterrupt raised while serving is caught and logged.
        """
        logger.info(
            "Starting MLflow prediction service as blocking "
            "process... press CTRL+C once to stop it."
        )

        self.endpoint.prepare_for_start()

        try:
            serve_kwargs: Dict[str, Any] = {}
            # MLflow version 1.26 introduces an additional mandatory
            # `timeout` argument to the `PyFuncBackend.serve` function.
            # Compare (major, minor) together: looking at the minor
            # component alone would skip the argument for 2.x releases
            # whose minor version is below 26.
            major, minor = (
                int(part) for part in MLFLOW_VERSION.split(".")[:2]
            )
            if (major, minor) >= (1, 26):
                serve_kwargs["timeout"] = None

            backend = PyFuncBackend(
                config={},
                no_conda=True,
                workers=self.config.workers,
                install_mlflow=False,
            )
            backend.serve(
                model_uri=self.config.model_uri,
                port=self.endpoint.status.port,
                host="localhost",
                enable_mlserver=self.config.mlserver,
                **serve_kwargs,
            )
        except KeyboardInterrupt:
            logger.info(
                "MLflow prediction service stopped. Resuming normal execution."
            )

    @property
    def prediction_url(self) -> Optional[str]:
        """Get the URI where the prediction service is answering requests.

        Returns:
            The URI where the prediction service can be contacted to process
            HTTP/REST inference requests, or None, if the service isn't running.
        """
        if not self.is_running:
            return None
        return self.endpoint.prediction_url

    def predict(self, request: "NDArray[Any]") -> "NDArray[Any]":
        """Make a prediction using the service.

        Args:
            request: a numpy array representing the request

        Returns:
            A numpy array representing the prediction returned by the service.

        Raises:
            Exception: if the service is not running
            ValueError: if the prediction endpoint is unknown.
        """
        if not self.is_running:
            raise Exception(
                "MLflow prediction service is not running. "
                "Please start the service before making predictions."
            )

        if self.endpoint.prediction_url is not None:
            # the request array is sent as a JSON list under "instances"
            response = requests.post(
                self.endpoint.prediction_url,
                json={"instances": request.tolist()},
            )
        else:
            raise ValueError("No endpoint known for prediction.")
        # surface HTTP error responses as exceptions
        response.raise_for_status()
        return np.array(response.json())
prediction_url: Optional[str]
property
readonly
Get the URI where the prediction service is answering requests.
Returns:
Type | Description |
---|---|
Optional[str] |
The URI where the prediction service can be contacted to process HTTP/REST inference requests, or None, if the service isn't running. |
__init__(self, config, **attrs)
special
Initialize the MLflow deployment service.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
config |
Union[zenml.integrations.mlflow.services.mlflow_deployment.MLFlowDeploymentConfig, Dict[str, Any]] |
service configuration |
required |
attrs |
Any |
additional attributes to set on the service |
{} |
Source code in zenml/integrations/mlflow/services/mlflow_deployment.py
def __init__(
    self,
    config: Union[MLFlowDeploymentConfig, Dict[str, Any]],
    **attrs: Any,
) -> None:
    """Set up the MLflow deployment service and, if needed, its endpoint.

    When `config` is an MLFlowDeploymentConfig and no `endpoint` attribute
    is supplied, an HTTP endpoint with a health monitor is built from the
    configuration and passed on to the parent initializer.

    Args:
        config: service configuration
        attrs: additional attributes to set on the service
    """
    # ensure that the endpoint is created before the service is initialized
    # TODO [ENG-700]: implement a service factory or builder for MLflow
    #   deployment services
    needs_endpoint = (
        isinstance(config, MLFlowDeploymentConfig)
        and "endpoint" not in attrs
    )
    if needs_endpoint:
        # URL paths and health check method depend on the serving backend
        if config.mlserver:
            prediction_url_path = MLSERVER_PREDICTION_URL_PATH
            healthcheck_uri_path = MLSERVER_HEALTHCHECK_URL_PATH
            use_head_request = False
        else:
            prediction_url_path = MLFLOW_PREDICTION_URL_PATH
            healthcheck_uri_path = MLFLOW_HEALTHCHECK_URL_PATH
            use_head_request = True

        endpoint_config = MLFlowDeploymentEndpointConfig(
            protocol=ServiceEndpointProtocol.HTTP,
            prediction_url_path=prediction_url_path,
        )
        health_monitor = HTTPEndpointHealthMonitor(
            config=HTTPEndpointHealthMonitorConfig(
                healthcheck_uri_path=healthcheck_uri_path,
                use_head_request=use_head_request,
            )
        )
        attrs["endpoint"] = MLFlowDeploymentEndpoint(
            config=endpoint_config,
            monitor=health_monitor,
        )
    super().__init__(config=config, **attrs)
predict(self, request)
Make a prediction using the service.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
request |
NDArray[Any] |
a numpy array representing the request |
required |
Returns:
Type | Description |
---|---|
NDArray[Any] |
A numpy array representing the prediction returned by the service. |
Exceptions:
Type | Description |
---|---|
Exception |
if the service is not running |
ValueError |
if the prediction endpoint is unknown. |
Source code in zenml/integrations/mlflow/services/mlflow_deployment.py
def predict(self, request: "NDArray[Any]") -> "NDArray[Any]":
    """Send a prediction request to the running service.

    Args:
        request: a numpy array representing the request

    Returns:
        A numpy array representing the prediction returned by the service.

    Raises:
        Exception: if the service is not running
        ValueError: if the prediction endpoint is unknown.
    """
    if not self.is_running:
        raise Exception(
            "MLflow prediction service is not running. "
            "Please start the service before making predictions."
        )

    if self.endpoint.prediction_url is None:
        raise ValueError("No endpoint known for prediction.")

    # the request array is sent as a JSON list under the "instances" key
    payload = {"instances": request.tolist()}
    response = requests.post(self.endpoint.prediction_url, json=payload)
    # surface HTTP error responses as exceptions
    response.raise_for_status()
    return np.array(response.json())
run(self)
Start the service.
Source code in zenml/integrations/mlflow/services/mlflow_deployment.py
def run(self) -> None:
    """Start the service as a blocking process.

    A KeyboardInterrupt raised while serving is caught and logged.
    """
    logger.info(
        "Starting MLflow prediction service as blocking "
        "process... press CTRL+C once to stop it."
    )

    self.endpoint.prepare_for_start()

    try:
        serve_kwargs: Dict[str, Any] = {}
        # MLflow version 1.26 introduces an additional mandatory
        # `timeout` argument to the `PyFuncBackend.serve` function.
        # Compare (major, minor) together: looking at the minor component
        # alone would skip the argument for 2.x releases whose minor
        # version is below 26.
        major, minor = (int(part) for part in MLFLOW_VERSION.split(".")[:2])
        if (major, minor) >= (1, 26):
            serve_kwargs["timeout"] = None

        backend = PyFuncBackend(
            config={},
            no_conda=True,
            workers=self.config.workers,
            install_mlflow=False,
        )
        backend.serve(
            model_uri=self.config.model_uri,
            port=self.endpoint.status.port,
            host="localhost",
            enable_mlserver=self.config.mlserver,
            **serve_kwargs,
        )
    except KeyboardInterrupt:
        logger.info(
            "MLflow prediction service stopped. Resuming normal execution."
        )
steps
special
Initialization of the MLflow standard interface steps.
mlflow_deployer
Implementation of the MLflow model deployer pipeline step.
MLFlowDeployerConfig (BaseStepConfig)
pydantic-model
Model deployer step configuration for MLflow.
Attributes:
Name | Type | Description |
---|---|---|
model_name |
str |
the name of the MLflow model logged in the MLflow artifact store for the current pipeline. |
workers |
int |
number of workers to use for the prediction service |
mlserver |
bool |
set to True to use the MLflow MLServer backend (see https://github.com/SeldonIO/MLServer). If False, the MLflow built-in scoring server will be used. |
timeout |
int |
the number of seconds to wait for the service to start/stop. |
Source code in zenml/integrations/mlflow/steps/mlflow_deployer.py
class MLFlowDeployerConfig(BaseStepConfig):
    """Configuration of the MLflow model deployer step.

    Attributes:
        model_name: the name of the MLflow model logged in the MLflow artifact
            store for the current pipeline (defaults to "model").
        model_uri: defaults to an empty string; presumably the URI of the
            model to deploy -- TODO confirm.
        workers: number of workers to use for the prediction service
        mlserver: set to True to use the MLflow MLServer backend (see
            https://github.com/SeldonIO/MLServer). If False, the
            MLflow built-in scoring server will be used.
        timeout: the number of seconds to wait for the service to start/stop.
    """

    # model identification
    model_name: str = "model"
    model_uri: str = ""
    # serving options
    workers: int = 1
    mlserver: bool = False
    timeout: int = DEFAULT_SERVICE_START_STOP_TIMEOUT
mlflow_model_deployer_step (BaseStep)
Model deployer pipeline step for MLflow.
noqa: DAR401
Parameters:
Name | Type | Description | Default |
---|---|---|---|
deploy_decision |
bool |
whether to deploy the model or not |
required |
model |
ModelArtifact |
the model artifact to deploy |
required |
config |
MLFlowDeployerConfig |
configuration for the deployer step |
required |
Returns:
Type | Description |
---|---|
MLFlowDeploymentService |
MLflow deployment service |
CONFIG_CLASS (BaseStepConfig)
pydantic-model
Model deployer step configuration for MLflow.
Attributes:
Name | Type | Description |
---|---|---|
model_name |
str |
the name of the MLflow model logged in the MLflow artifact store for the current pipeline. |
workers |
int |
number of workers to use for the prediction service |
mlserver |
bool |
set to True to use the MLflow MLServer backend (see https://github.com/SeldonIO/MLServer). If False, the MLflow built-in scoring server will be used. |
timeout |
int |
the number of seconds to wait for the service to start/stop. |
Source code in zenml/integrations/mlflow/steps/mlflow_deployer.py
class MLFlowDeployerConfig(BaseStepConfig):
    """Configuration of the MLflow model deployer step.

    Attributes:
        model_name: the name of the MLflow model logged in the MLflow artifact
            store for the current pipeline (defaults to "model").
        model_uri: defaults to an empty string; presumably the URI of the
            model to deploy -- TODO confirm.
        workers: number of workers to use for the prediction service
        mlserver: set to True to use the MLflow MLServer backend (see
            https://github.com/SeldonIO/MLServer). If False, the
            MLflow built-in scoring server will be used.
        timeout: the number of seconds to wait for the service to start/stop.
    """

    # model identification
    model_name: str = "model"
    model_uri: str = ""
    # serving options
    workers: int = 1
    mlserver: bool = False
    timeout: int = DEFAULT_SERVICE_START_STOP_TIMEOUT
entrypoint(deploy_decision, model, config)
staticmethod
Model deployer pipeline step for MLflow.
noqa: DAR401
Parameters:
Name | Type | Description | Default |
---|---|---|---|
deploy_decision |
bool |
whether to deploy the model or not |
required |
model |
ModelArtifact |
the model artifact to deploy |
required |
config |
MLFlowDeployerConfig |
configuration for the deployer step |
required |
Returns:
Type | Description |
---|---|
MLFlowDeploymentService |
MLflow deployment service |
Source code in zenml/integrations/mlflow/steps/mlflow_deployer.py
@enable_mlflow
@step(enable_cache=False)
def mlflow_model_deployer_step(
    deploy_decision: bool,
    model: ModelArtifact,
    config: MLFlowDeployerConfig,
) -> MLFlowDeploymentService:
    """Deploy the model logged to MLflow in the current run, or reuse a server.

    # noqa: DAR401

    Args:
        deploy_decision: whether to deploy the model or not
        model: the model artifact to deploy
        config: configuration for the deployer step

    Returns:
        MLflow deployment service
    """
    model_deployer = MLFlowModelDeployer.get_active_model_deployer()

    # the MLflow artifacts are looked up through the experiment tracker of
    # the active stack, which must be an MLflow one
    experiment_tracker = Repository(  # type: ignore[call-arg]
        skip_repository_check=True
    ).active_stack.experiment_tracker

    if not isinstance(experiment_tracker, MLFlowExperimentTracker):
        raise get_missing_mlflow_experiment_tracker_error()

    # the model URI stays empty unless the active MLflow run has artifacts
    # logged under the configured model name
    mlflow_client = MlflowClient()
    model_uri = ""
    active_run = experiment_tracker.active_run
    if active_run and mlflow_client.list_artifacts(
        active_run.info.run_id, config.model_name
    ):
        model_uri = get_artifact_uri(config.model_name)

    # pipeline name, step name and run id come from the step environment
    step_env = cast(StepEnvironment, Environment()[STEP_ENVIRONMENT_NAME])
    pipeline_name = step_env.pipeline_name
    run_id = step_env.pipeline_run_id
    step_name = step_env.step_name

    # services previously deployed for the same pipeline, step and model
    existing_services = model_deployer.find_model_server(
        pipeline_name=pipeline_name,
        pipeline_step_name=step_name,
        model_name=config.model_name,
    )

    # configuration describing the model service for this run
    deployment_config = MLFlowDeploymentConfig(
        model_name=config.model_name or "",
        model_uri=model_uri,
        workers=config.workers,
        mlserver=config.mlserver,
        pipeline_name=pipeline_name,
        pipeline_run_id=run_id,
        pipeline_step_name=step_name,
    )

    # a fresh service object is always built; it is superseded by the first
    # existing service when there is one
    service = MLFlowDeploymentService(deployment_config)
    if existing_services:
        service = cast(MLFlowDeploymentService, existing_services[0])

    # case 1: no model was logged in the current run
    if not model_uri:
        if not existing_services:
            logger.warning(
                f"An MLflow model with name `{config.model_name}` was not "
                f"logged in the current pipeline run and no running MLflow "
                f"model server was found. Please ensure that your pipeline "
                f"includes an `@enable_mlflow` decorated step that trains a "
                f"model and logs it to MLflow. This could also happen if "
                f"the current pipeline run did not log an MLflow model  "
                f"because the training step was cached."
            )
            # nothing to reuse: hand back the freshly built service object
            # because the step has to return something
            return service
        logger.info(
            f"An MLflow model with name `{config.model_name}` was not "
            f"trained in the current pipeline run. Reusing the existing "
            f"MLflow model server."
        )
        if not service.is_running:
            service.start(config.timeout)

        # hand back the existing service
        return service

    # case 2: a model was logged, but the deploy decision is negative and
    # a previous server exists, so that server keeps serving
    if not deploy_decision and existing_services:
        logger.info(
            f"Skipping model deployment because the model quality does not "
            f"meet the criteria. Reusing last model server deployed by step "
            f"'{step_name}' and pipeline '{pipeline_name}' for model "
            f"'{config.model_name}'..."
        )
        # restart the previous model server if it is no longer running, to
        # ensure that a model server is available at all times
        if not service.is_running:
            service.start(config.timeout)
        return service

    # case 3: deploy the current model, replacing an old deployment if any
    deployed_service = cast(
        MLFlowDeploymentService,
        model_deployer.deploy_model(
            replace=True,
            config=deployment_config,
            timeout=config.timeout,
        ),
    )

    logger.info(
        f"MLflow deployment service started and reachable at:\n"
        f"    {deployed_service.prediction_url}\n"
    )

    return deployed_service
mlflow_deployer_step(enable_cache=True, name=None)
Creates a pipeline step to deploy a given ML model with a local MLflow prediction server.
The returned step can be used in a pipeline to implement continuous deployment for an MLflow model.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
enable_cache |
bool |
Specify whether caching is enabled for this step. If no value is passed, caching is enabled by default |
True |
name |
Optional[str] |
Name of the step. |
None |
Returns:
Type | Description |
---|---|
Type[zenml.steps.base_step.BaseStep] |
an MLflow model deployer pipeline step |
Source code in zenml/integrations/mlflow/steps/mlflow_deployer.py
def mlflow_deployer_step(
    enable_cache: bool = True,
    name: Optional[str] = None,
) -> Type[BaseStep]:
    """Deprecated factory for the MLflow model deployer pipeline step.

    Kept for backward compatibility: it logs a deprecation warning and hands
    back the built-in `mlflow_model_deployer_step`. Neither argument is used
    by the body of this function.

    Args:
        enable_cache: Specify whether caching is enabled for this step. If no
            value is passed, caching is enabled by default. Not used here.
        name: Name of the step. Not used here.

    Returns:
        an MLflow model deployer pipeline step
    """
    deprecation_notice = (
        "The `mlflow_deployer_step` function is deprecated. Please "
        "use the built-in `mlflow_model_deployer_step` step instead."
    )
    logger.warning(deprecation_notice)
    return mlflow_model_deployer_step
neural_prophet
special
Initialization of the Neural Prophet integration.
NeuralProphetIntegration (Integration)
Definition of NeuralProphet integration for ZenML.
Source code in zenml/integrations/neural_prophet/__init__.py
class NeuralProphetIntegration(Integration):
    """ZenML integration descriptor for NeuralProphet."""

    NAME = NEURAL_PROPHET
    REQUIREMENTS = ["neuralprophet>=0.3.2"]

    @classmethod
    def activate(cls) -> None:
        """Activate the integration by importing its materializers module."""
        # Imported only for its import-time effects; the name is unused.
        from zenml.integrations.neural_prophet import materializers  # noqa
activate()
classmethod
Activates the integration.
Source code in zenml/integrations/neural_prophet/__init__.py
@classmethod
def activate(cls) -> None:
    """Activate the integration by importing its materializers module."""
    # Imported only for its import-time effects; the name is unused.
    from zenml.integrations.neural_prophet import materializers  # noqa
materializers
special
Initialization of the Neural Prophet materializer.
neural_prophet_materializer
Implementation of the Neural Prophet materializer.
NeuralProphetMaterializer (BaseMaterializer)
Materializer to read/write NeuralProphet models.
Source code in zenml/integrations/neural_prophet/materializers/neural_prophet_materializer.py
class NeuralProphetMaterializer(BaseMaterializer):
    """Materializer that stores and restores NeuralProphet models."""

    ASSOCIATED_TYPES = (NeuralProphet,)
    ASSOCIATED_ARTIFACT_TYPES = (ModelArtifact,)

    def handle_input(self, data_type: Type[Any]) -> NeuralProphet:
        """Load a NeuralProphet model from the artifact directory.

        Args:
            data_type: The type requested by the caller; forwarded to the
                base class.

        Returns:
            A loaded NeuralProphet model.
        """
        super().handle_input(data_type)
        model_path = os.path.join(self.artifact.uri, DEFAULT_FILENAME)
        # NOTE: `torch.load` unpickles the file, so the artifact content
        # must come from a trusted source.
        return torch.load(model_path)  # type: ignore[no-untyped-call] # noqa

    def handle_return(self, model: NeuralProphet) -> None:
        """Save a NeuralProphet model into the artifact directory.

        Args:
            model: A NeuralProphet model object.
        """
        super().handle_return(model)
        model_path = os.path.join(self.artifact.uri, DEFAULT_FILENAME)
        torch.save(model, model_path)
handle_input(self, data_type)
Reads and returns a NeuralProphet model.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
data_type |
Type[Any] |
A NeuralProphet model object. |
required |
Returns:
Type | Description |
---|---|
NeuralProphet |
A loaded NeuralProphet model. |
Source code in zenml/integrations/neural_prophet/materializers/neural_prophet_materializer.py
def handle_input(self, data_type: Type[Any]) -> NeuralProphet:
    """Load a NeuralProphet model from the artifact directory.

    Args:
        data_type: The type requested by the caller; forwarded to the base
            class.

    Returns:
        A loaded NeuralProphet model.
    """
    super().handle_input(data_type)
    model_path = os.path.join(self.artifact.uri, DEFAULT_FILENAME)
    # NOTE: `torch.load` unpickles the file, so the artifact content must
    # come from a trusted source.
    return torch.load(model_path)  # type: ignore[no-untyped-call] # noqa
handle_return(self, model)
Writes a NeuralProphet model.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model |
NeuralProphet |
A NeuralProphet model object. |
required |
Source code in zenml/integrations/neural_prophet/materializers/neural_prophet_materializer.py
def handle_return(self, model: NeuralProphet) -> None:
    """Save a NeuralProphet model into the artifact directory.

    Args:
        model: A NeuralProphet model object.
    """
    super().handle_return(model)
    model_path = os.path.join(self.artifact.uri, DEFAULT_FILENAME)
    torch.save(model, model_path)
plotly
special
Initialization of the Plotly integration.
PlotlyIntegration (Integration)
Definition of Plotly integration for ZenML.
Source code in zenml/integrations/plotly/__init__.py
class PlotlyIntegration(Integration):
    """ZenML integration descriptor for Plotly."""

    NAME = PLOTLY
    REQUIREMENTS = ["plotly>=5.4.0"]
visualizers
special
Initialization of the Plotly Visualizer.
pipeline_lineage_visualizer
Implementation of the Plotly Pipeline Lineage Visualizer.
PipelineLineageVisualizer (BasePipelineVisualizer)
Visualize the lineage of runs in a pipeline using plotly.
Source code in zenml/integrations/plotly/visualizers/pipeline_lineage_visualizer.py
class PipelineLineageVisualizer(BasePipelineVisualizer):
    """Visualize the lineage of runs in a pipeline using plotly."""

    # NOTE(review): this method is marked abstract although it has a full
    # implementation; confirm against the base class whether that is intended.
    @abstractmethod
    def visualize(
        self, object: PipelineView, *args: Any, **kwargs: Any
    ) -> Figure:
        """Build and show a parallel-categories lineage diagram.

        Every run becomes one row; every step entrypoint name becomes one
        dimension whose value is the step id as a string.

        Args:
            object: The pipeline view to visualize.
            *args: Additional arguments to pass to the visualization.
            **kwargs: Additional keyword arguments to pass to the visualization.

        Returns:
            A plotly figure.
        """
        logger.warning(
            "This integration is not completed yet. Results might be unexpected."
        )

        rows = {}
        columns = ["run"]
        for run in object.runs:
            row = {"run": run.name}
            rows[run.name] = row
            for step in run.steps:
                step_name = step.entrypoint_name
                row[step_name] = str(step.id)
                # Keep dimensions in first-seen order, without duplicates.
                if step_name not in columns:
                    columns.append(f"{step_name}")

        frame = pd.DataFrame.from_dict(rows, orient="index").reset_index()
        figure = px.parallel_categories(
            frame,
            columns,
            color=None,
            labels="status",
        )
        figure.show()
        return figure
visualize(self, object, *args, **kwargs)
Creates a pipeline lineage diagram using plotly.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
object |
PipelineView |
The pipeline view to visualize. |
required |
*args |
Any |
Additional arguments to pass to the visualization. |
() |
**kwargs |
Any |
Additional keyword arguments to pass to the visualization. |
{} |
Returns:
Type | Description |
---|---|
Figure |
A plotly figure. |
Source code in zenml/integrations/plotly/visualizers/pipeline_lineage_visualizer.py
# NOTE(review): this method is marked abstract although it has a full
# implementation; confirm against the base class whether that is intended.
@abstractmethod
def visualize(
    self, object: PipelineView, *args: Any, **kwargs: Any
) -> Figure:
    """Build and show a parallel-categories lineage diagram.

    Every run becomes one row; every step entrypoint name becomes one
    dimension whose value is the step id as a string.

    Args:
        object: The pipeline view to visualize.
        *args: Additional arguments to pass to the visualization.
        **kwargs: Additional keyword arguments to pass to the visualization.

    Returns:
        A plotly figure.
    """
    logger.warning(
        "This integration is not completed yet. Results might be unexpected."
    )

    rows = {}
    columns = ["run"]
    for run in object.runs:
        row = {"run": run.name}
        rows[run.name] = row
        for step in run.steps:
            step_name = step.entrypoint_name
            row[step_name] = str(step.id)
            # Keep dimensions in first-seen order, without duplicates.
            if step_name not in columns:
                columns.append(f"{step_name}")

    frame = pd.DataFrame.from_dict(rows, orient="index").reset_index()
    figure = px.parallel_categories(
        frame,
        columns,
        color=None,
        labels="status",
    )
    figure.show()
    return figure
pytorch
special
Initialization of the PyTorch integration.
PytorchIntegration (Integration)
Definition of PyTorch integration for ZenML.
Source code in zenml/integrations/pytorch/__init__.py
class PytorchIntegration(Integration):
    """ZenML integration descriptor for PyTorch."""

    NAME = PYTORCH
    REQUIREMENTS = ["torch"]

    @classmethod
    def activate(cls) -> None:
        """Activate the integration by importing its materializers module."""
        # Imported only for its import-time effects; the name is unused.
        from zenml.integrations.pytorch import materializers  # noqa
activate()
classmethod
Activates the integration.
Source code in zenml/integrations/pytorch/__init__.py
@classmethod
def activate(cls) -> None:
    """Activate the integration by importing its materializers module."""
    # Imported only for its import-time effects; the name is unused.
    from zenml.integrations.pytorch import materializers  # noqa
materializers
special
Initialization of the PyTorch Materializer.
pytorch_dataloader_materializer
Implementation of the PyTorch DataLoader materializer.
PyTorchDataLoaderMaterializer (BaseMaterializer)
Materializer to read/write PyTorch dataloaders.
Source code in zenml/integrations/pytorch/materializers/pytorch_dataloader_materializer.py
class PyTorchDataLoaderMaterializer(BaseMaterializer):
    """Materializer that stores and restores PyTorch dataloaders."""

    ASSOCIATED_TYPES = (DataLoader,)
    ASSOCIATED_ARTIFACT_TYPES = (DataArtifact,)

    def handle_input(self, data_type: Type[Any]) -> DataLoader[Any]:
        """Load a PyTorch dataloader from the artifact directory.

        Args:
            data_type: The type of the dataloader to load.

        Returns:
            A loaded PyTorch dataloader.
        """
        super().handle_input(data_type)
        source_path = os.path.join(self.artifact.uri, DEFAULT_FILENAME)
        # NOTE: `torch.load` unpickles the file, so the artifact content
        # must come from a trusted source.
        with fileio.open(source_path, "rb") as stream:
            loaded = torch.load(stream)  # type: ignore[no-untyped-call] # noqa
            return cast(DataLoader[Any], loaded)

    def handle_return(self, dataloader: DataLoader[Any]) -> None:
        """Serialize a PyTorch dataloader into the artifact directory.

        Args:
            dataloader: The object handed to `torch.save`.
        """
        super().handle_return(dataloader)
        # The whole dataloader object is written to a single file.
        target_path = os.path.join(self.artifact.uri, DEFAULT_FILENAME)
        with fileio.open(target_path, "wb") as stream:
            torch.save(dataloader, stream)
handle_input(self, data_type)
Reads and returns a PyTorch dataloader.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
data_type |
Type[Any] |
The type of the dataloader to load. |
required |
Returns:
Type | Description |
---|---|
torch.utils.data.dataloader.DataLoader[Any] |
A loaded PyTorch dataloader. |
Source code in zenml/integrations/pytorch/materializers/pytorch_dataloader_materializer.py
def handle_input(self, data_type: Type[Any]) -> DataLoader[Any]:
    """Load a PyTorch dataloader from the artifact directory.

    Args:
        data_type: The type of the dataloader to load.

    Returns:
        A loaded PyTorch dataloader.
    """
    super().handle_input(data_type)
    source_path = os.path.join(self.artifact.uri, DEFAULT_FILENAME)
    # NOTE: `torch.load` unpickles the file, so the artifact content must
    # come from a trusted source.
    with fileio.open(source_path, "rb") as stream:
        loaded = torch.load(stream)  # type: ignore[no-untyped-call] # noqa
        return cast(DataLoader[Any], loaded)
handle_return(self, dataloader)
Writes a PyTorch dataloader.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
dataloader |
torch.utils.data.dataloader.DataLoader[Any] |
A torch.utils.DataLoader or a dict to pass into dataloader.save |
required |
Source code in zenml/integrations/pytorch/materializers/pytorch_dataloader_materializer.py
def handle_return(self, dataloader: DataLoader[Any]) -> None:
    """Serialize a PyTorch dataloader into the artifact directory.

    Args:
        dataloader: The object handed to `torch.save`.
    """
    super().handle_return(dataloader)
    # The whole dataloader object is written to a single file.
    target_path = os.path.join(self.artifact.uri, DEFAULT_FILENAME)
    with fileio.open(target_path, "wb") as stream:
        torch.save(dataloader, stream)
pytorch_module_materializer
Implementation of the PyTorch Module materializer.
PyTorchModuleMaterializer (BaseMaterializer)
Materializer to read/write Pytorch models.
Inspired by the guide: https://pytorch.org/tutorials/beginner/saving_loading_models.html
Source code in zenml/integrations/pytorch/materializers/pytorch_module_materializer.py
class PyTorchModuleMaterializer(BaseMaterializer):
    """Materializer that stores and restores PyTorch models.

    Inspired by the guide:
    https://pytorch.org/tutorials/beginner/saving_loading_models.html
    """

    ASSOCIATED_TYPES = (Module,)
    ASSOCIATED_ARTIFACT_TYPES = (ModelArtifact,)

    def handle_input(self, data_type: Type[Any]) -> Module:
        """Load the full PyTorch model from the artifact directory.

        The separately stored checkpoint (state dict) is not read.

        Args:
            data_type: The type of the model to load.

        Returns:
            A loaded pytorch model.
        """
        super().handle_input(data_type)
        model_path = os.path.join(self.artifact.uri, DEFAULT_FILENAME)
        # NOTE: `torch.load` unpickles the file, so the artifact content
        # must come from a trusted source.
        with fileio.open(model_path, "rb") as model_file:
            return torch.load(model_file)  # type: ignore[no-untyped-call] # noqa

    def handle_return(self, model: Module) -> None:
        """Write the model, and for `Module` instances also its state dict.

        Args:
            model: The object handed to `torch.save`; a checkpoint is only
                written when it is a `Module`.
        """
        super().handle_return(model)
        artifact_dir = self.artifact.uri

        # Full model: the default for loading during development
        # (training, evaluation).
        with fileio.open(
            os.path.join(artifact_dir, DEFAULT_FILENAME), "wb"
        ) as model_file:
            torch.save(model, model_file)

        if not isinstance(model, Module):
            return

        # State dict checkpoint: the default for loading in production
        # (inference).
        with fileio.open(
            os.path.join(artifact_dir, CHECKPOINT_FILENAME), "wb"
        ) as checkpoint_file:
            torch.save(model.state_dict(), checkpoint_file)
handle_input(self, data_type)
Reads and returns a PyTorch model.
Only loads the model, not the checkpoint.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
data_type |
Type[Any] |
The type of the model to load. |
required |
Returns:
Type | Description |
---|---|
Module |
A loaded pytorch model. |
Source code in zenml/integrations/pytorch/materializers/pytorch_module_materializer.py
def handle_input(self, data_type: Type[Any]) -> Module:
    """Load the full PyTorch model from the artifact directory.

    The separately stored checkpoint (state dict) is not read.

    Args:
        data_type: The type of the model to load.

    Returns:
        A loaded pytorch model.
    """
    super().handle_input(data_type)
    model_path = os.path.join(self.artifact.uri, DEFAULT_FILENAME)
    # NOTE: `torch.load` unpickles the file, so the artifact content must
    # come from a trusted source.
    with fileio.open(model_path, "rb") as model_file:
        return torch.load(model_file)  # type: ignore[no-untyped-call] # noqa
handle_return(self, model)
Writes a PyTorch model, as a model and a checkpoint.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model |
Module |
A torch.nn.Module or a dict to pass into model.save |
required |
Source code in zenml/integrations/pytorch/materializers/pytorch_module_materializer.py
def handle_return(self, model: Module) -> None:
    """Write the model, and for `Module` instances also its state dict.

    Args:
        model: The object handed to `torch.save`; a checkpoint is only
            written when it is a `Module`.
    """
    super().handle_return(model)
    artifact_dir = self.artifact.uri

    # Full model: the default for loading during development
    # (training, evaluation).
    with fileio.open(
        os.path.join(artifact_dir, DEFAULT_FILENAME), "wb"
    ) as model_file:
        torch.save(model, model_file)

    if not isinstance(model, Module):
        return

    # State dict checkpoint: the default for loading in production
    # (inference).
    with fileio.open(
        os.path.join(artifact_dir, CHECKPOINT_FILENAME), "wb"
    ) as checkpoint_file:
        torch.save(model.state_dict(), checkpoint_file)
pytorch_lightning
special
Initialization of the PyTorch Lightning integration.
PytorchLightningIntegration (Integration)
Definition of PyTorch Lightning integration for ZenML.
Source code in zenml/integrations/pytorch_lightning/__init__.py
class PytorchLightningIntegration(Integration):
    """ZenML integration descriptor for PyTorch Lightning."""

    NAME = PYTORCH_L
    REQUIREMENTS = ["pytorch_lightning"]

    @classmethod
    def activate(cls) -> None:
        """Activate the integration by importing its materializers module."""
        # Imported only for its import-time effects; the name is unused.
        from zenml.integrations.pytorch_lightning import materializers  # noqa
activate()
classmethod
Activates the integration.
Source code in zenml/integrations/pytorch_lightning/__init__.py
@classmethod
def activate(cls) -> None:
    """Activate the integration by importing its materializers module."""
    # Imported only for its import-time effects; the name is unused.
    from zenml.integrations.pytorch_lightning import materializers  # noqa
materializers
special
Initialization of the PyTorch Lightning Materializer.
pytorch_lightning_materializer
Implementation of the PyTorch Lightning Materializer.
PyTorchLightningMaterializer (BaseMaterializer)
Materializer to read/write PyTorch models.
Source code in zenml/integrations/pytorch_lightning/materializers/pytorch_lightning_materializer.py
class PyTorchLightningMaterializer(BaseMaterializer):
    """Materializer that stores and restores PyTorch Lightning trainers."""

    ASSOCIATED_TYPES = (Trainer,)
    ASSOCIATED_ARTIFACT_TYPES = (ModelArtifact,)

    def handle_input(self, data_type: Type[Any]) -> Trainer:
        """Create a trainer pointed at the stored checkpoint.

        Args:
            data_type: The type of the trainer to load.

        Returns:
            A PyTorch Lightning trainer object.
        """
        super().handle_input(data_type)
        checkpoint_path = os.path.join(self.artifact.uri, CHECKPOINT_NAME)
        return Trainer(resume_from_checkpoint=checkpoint_path)

    def handle_return(self, trainer: Trainer) -> None:
        """Save the trainer's checkpoint into the artifact directory.

        Args:
            trainer: A PyTorch Lightning trainer object.
        """
        super().handle_return(trainer)
        checkpoint_path = os.path.join(self.artifact.uri, CHECKPOINT_NAME)
        trainer.save_checkpoint(checkpoint_path)
handle_input(self, data_type)
Reads and returns a PyTorch Lightning trainer.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
data_type |
Type[Any] |
The type of the trainer to load. |
required |
Returns:
Type | Description |
---|---|
Trainer |
A PyTorch Lightning trainer object. |
Source code in zenml/integrations/pytorch_lightning/materializers/pytorch_lightning_materializer.py
def handle_input(self, data_type: Type[Any]) -> Trainer:
    """Create a trainer pointed at the stored checkpoint.

    Args:
        data_type: The type of the trainer to load.

    Returns:
        A PyTorch Lightning trainer object.
    """
    super().handle_input(data_type)
    checkpoint_path = os.path.join(self.artifact.uri, CHECKPOINT_NAME)
    return Trainer(resume_from_checkpoint=checkpoint_path)
handle_return(self, trainer)
Writes a PyTorch Lightning trainer.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
trainer |
Trainer |
A PyTorch Lightning trainer object. |
required |
Source code in zenml/integrations/pytorch_lightning/materializers/pytorch_lightning_materializer.py
def handle_return(self, trainer: Trainer) -> None:
    """Save the trainer's checkpoint into the artifact directory.

    Args:
        trainer: A PyTorch Lightning trainer object.
    """
    super().handle_return(trainer)
    checkpoint_path = os.path.join(self.artifact.uri, CHECKPOINT_NAME)
    trainer.save_checkpoint(checkpoint_path)
registry
Implementation of a registry to track ZenML integrations.
IntegrationRegistry
Registry to keep track of ZenML Integrations.
Source code in zenml/integrations/registry.py
class IntegrationRegistry(object):
    """Registry to keep track of ZenML Integrations."""

    def __init__(self) -> None:
        """Initializing the integration registry."""
        # Maps the integration name to the integration class.
        self._integrations: Dict[str, Type["Integration"]] = {}

    @property
    def integrations(self) -> Dict[str, Type["Integration"]]:
        """Method to get integrations dictionary.

        Returns:
            A dict of integration key to type of `Integration`.
        """
        return self._integrations

    @integrations.setter
    def integrations(self, i: Any) -> None:
        """Setter method for the integrations property.

        Args:
            i: Value to set the integrations property to.

        Raises:
            IntegrationError: If you try to manually set the integrations property.
        """
        raise IntegrationError(
            "Please do not manually change the integrations within the "
            "registry. If you would like to register a new integration "
            "manually, please use "
            "`integration_registry.register_integration()`."
        )

    def register_integration(
        self, key: str, type_: Type["Integration"]
    ) -> None:
        """Method to register an integration with a given name.

        An existing entry under the same key is overwritten.

        Args:
            key: Name of the integration.
            type_: Type of the integration.
        """
        self._integrations[key] = type_

    def activate_integrations(self) -> None:
        """Method to activate the integrations which are registered in the registry."""
        for name, integration in self._integrations.items():
            if integration.check_installation():
                integration.activate()
                logger.debug(f"Integration `{name}` is activated.")
            else:
                logger.debug(f"Integration `{name}` could not be activated.")

    @property
    def list_integration_names(self) -> List[str]:
        """Get a list of all possible integrations.

        Returns:
            A list of all possible integrations.
        """
        return list(self._integrations)

    def select_integration_requirements(
        self, integration_name: Optional[str] = None
    ) -> List[str]:
        """Select the requirements for a given integration or all integrations.

        Args:
            integration_name: Name of the integration to check. A falsy value
                (`None` or an empty string) selects all integrations.

        Returns:
            List of requirements for the integration.

        Raises:
            KeyError: If the integration is not found.
        """
        if not integration_name:
            return [
                requirement
                for integration in self._integrations.values()
                for requirement in integration.REQUIREMENTS
            ]

        if integration_name not in self._integrations:
            # The message used to say "Version", which was misleading: the
            # lookup key is an integration name.
            raise KeyError(
                f"Integration {integration_name} does not exist. "
                f"Currently the following integrations are implemented. "
                f"{self.list_integration_names}"
            )
        return self._integrations[integration_name].REQUIREMENTS

    def is_installed(self, integration_name: Optional[str] = None) -> bool:
        """Checks if all requirements for an integration are installed.

        Args:
            integration_name: Name of the integration to check. A falsy value
                checks every registered integration.

        Returns:
            True if all requirements are installed, False otherwise.

        Raises:
            KeyError: If the integration is not found.
        """
        if integration_name in self._integrations:
            return self._integrations[integration_name].check_installation()

        if integration_name:
            raise KeyError(
                f"Integration '{integration_name}' not found. "
                f"Currently the following integrations are available: "
                f"{self.list_integration_names}"
            )

        # Every integration is checked before combining the results, so no
        # check is skipped by short-circuiting.
        all_installed = [
            integration.check_installation()
            for integration in self._integrations.values()
        ]
        return all(all_installed)

    def get_installed_integrations(self) -> List[str]:
        """Returns list of installed integrations.

        Returns:
            List of installed integrations.
        """
        # Read this instance's own registrations instead of the module-level
        # `integration_registry` singleton.
        return [
            name
            for name, integration in self._integrations.items()
            if integration.check_installation()
        ]
integrations: Dict[str, Type[Integration]]
property
writable
Method to get integrations dictionary.
Returns:
Type | Description |
---|---|
Dict[str, Type[Integration]] |
A dict of integration key to type of `Integration`. |
list_integration_names: List[str]
property
readonly
Get a list of all possible integrations.
Returns:
Type | Description |
---|---|
List[str] |
A list of all possible integrations. |
__init__(self)
special
Initializing the integration registry.
Source code in zenml/integrations/registry.py
def __init__(self) -> None:
"""Initializing the integration registry."""
self._integrations: Dict[str, Type["Integration"]] = {}
activate_integrations(self)
Method to activate the integrations which are registered in the registry.
Source code in zenml/integrations/registry.py
def activate_integrations(self) -> None:
    """Method to activate the integrations which are registered in the registry."""
    for name, integration in self._integrations.items():
        if not integration.check_installation():
            logger.debug(f"Integration `{name}` could not be activated.")
            continue
        integration.activate()
        logger.debug(f"Integration `{name}` is activated.")
get_installed_integrations(self)
Returns list of installed integrations.
Returns:
Type | Description |
---|---|
List[str] |
List of installed integrations. |
Source code in zenml/integrations/registry.py
def get_installed_integrations(self) -> List[str]:
    """Returns list of installed integrations.

    Returns:
        List of installed integrations.
    """
    # Read this instance's own registrations instead of the module-level
    # `integration_registry` singleton.
    return [
        name
        for name, integration in self._integrations.items()
        if integration.check_installation()
    ]
is_installed(self, integration_name=None)
Checks if all requirements for an integration are installed.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
integration_name |
Optional[str] |
Name of the integration to check. |
None |
Returns:
Type | Description |
---|---|
bool |
True if all requirements are installed, False otherwise. |
Exceptions:
Type | Description |
---|---|
KeyError |
If the integration is not found. |
Source code in zenml/integrations/registry.py
def is_installed(self, integration_name: Optional[str] = None) -> bool:
    """Report whether one integration, or every integration, is installed.

    Args:
        integration_name: Name of the integration to check. A falsy value
            checks every registered integration.

    Returns:
        True if all requirements are installed, False otherwise.

    Raises:
        KeyError: If the integration is not found.
    """
    known_names = self.list_integration_names

    if integration_name in known_names:
        return self._integrations[integration_name].check_installation()

    if integration_name:
        raise KeyError(
            f"Integration '{integration_name}' not found. "
            f"Currently the following integrations are available: "
            f"{self.list_integration_names}"
        )

    # Every integration is checked before combining the results, so no
    # check is skipped by short-circuiting.
    results = [
        self._integrations[known].check_installation()
        for known in self.list_integration_names
    ]
    return all(results)
register_integration(self, key, type_)
Method to register an integration with a given name.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
key |
str |
Name of the integration. |
required |
type_ |
Type[Integration] |
Type of the integration. |
required |
Source code in zenml/integrations/registry.py
def register_integration(
    self, key: str, type_: Type["Integration"]
) -> None:
    """Store an integration class under the given name.

    An existing entry under the same key is overwritten.

    Args:
        key: Name of the integration.
        type_: Type of the integration.
    """
    self._integrations[key] = type_
select_integration_requirements(self, integration_name=None)
Select the requirements for a given integration or all integrations.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
integration_name |
Optional[str] |
Name of the integration to check. |
None |
Returns:
Type | Description |
---|---|
List[str] |
List of requirements for the integration. |
Exceptions:
Type | Description |
---|---|
KeyError |
If the integration is not found. |
Source code in zenml/integrations/registry.py
def select_integration_requirements(
    self, integration_name: Optional[str] = None
) -> List[str]:
    """Select the requirements for a given integration or all integrations.

    Args:
        integration_name: Name of the integration to check. A falsy value
            (`None` or an empty string) selects all integrations.

    Returns:
        List of requirements for the integration.

    Raises:
        KeyError: If the integration is not found.
    """
    if not integration_name:
        return [
            requirement
            for name in self.list_integration_names
            for requirement in self._integrations[name].REQUIREMENTS
        ]

    if integration_name not in self.list_integration_names:
        # The message used to say "Version", which was misleading: the
        # lookup key is an integration name.
        raise KeyError(
            f"Integration {integration_name} does not exist. "
            f"Currently the following integrations are implemented. "
            f"{self.list_integration_names}"
        )
    return self._integrations[integration_name].REQUIREMENTS
s3
special
Initialization of the S3 integration.
The S3 integration allows the use of cloud artifact stores and file operations on S3 buckets.
S3Integration (Integration)
Definition of S3 integration for ZenML.
Source code in zenml/integrations/s3/__init__.py
class S3Integration(Integration):
    """ZenML integration descriptor for S3."""

    NAME = S3
    REQUIREMENTS = ["s3fs==2022.3.0"]

    @classmethod
    def flavors(cls) -> List[FlavorWrapper]:
        """List the stack component flavors provided by the s3 integration.

        Returns:
            List of stack component flavors for this integration.
        """
        artifact_store_flavor = FlavorWrapper(
            name=S3_ARTIFACT_STORE_FLAVOR,
            source="zenml.integrations.s3.artifact_stores.S3ArtifactStore",
            type=StackComponentType.ARTIFACT_STORE,
            integration=cls.NAME,
        )
        return [artifact_store_flavor]
flavors()
classmethod
Declare the stack component flavors for the s3 integration.
Returns:
Type | Description |
---|---|
List[zenml.zen_stores.models.flavor_wrapper.FlavorWrapper] |
List of stack component flavors for this integration. |
Source code in zenml/integrations/s3/__init__.py
@classmethod
def flavors(cls) -> List[FlavorWrapper]:
    """List the stack component flavors provided by the s3 integration.

    Returns:
        List of stack component flavors for this integration.
    """
    artifact_store_flavor = FlavorWrapper(
        name=S3_ARTIFACT_STORE_FLAVOR,
        source="zenml.integrations.s3.artifact_stores.S3ArtifactStore",
        type=StackComponentType.ARTIFACT_STORE,
        integration=cls.NAME,
    )
    return [artifact_store_flavor]
artifact_stores
special
Initialization of the S3 Artifact Store.
s3_artifact_store
Implementation of the S3 Artifact Store.
S3ArtifactStore (BaseArtifactStore, AuthenticationMixin)
pydantic-model
Artifact Store for S3 based artifacts.
All attributes of this class except path
will be passed to the
s3fs.S3FileSystem
initialization. See
here for more information on how
to use those configuration options to connect to any S3-compatible storage.
When you want to register an S3ArtifactStore from the CLI and need to pass
client_kwargs
, config_kwargs
or s3_additional_kwargs
, you should pass
them as a json string:
zenml artifact-store register my_s3_store --type=s3 --path=s3://my_bucket --client_kwargs='{"endpoint_url": "http://my-s3-endpoint"}'
Source code in zenml/integrations/s3/artifact_stores/s3_artifact_store.py
class S3ArtifactStore(BaseArtifactStore, AuthenticationMixin):
"""Artifact Store for S3 based artifacts.
All attributes of this class except `path` will be passed to the
`s3fs.S3FileSystem` initialization. See
[here](https://s3fs.readthedocs.io/en/latest/) for more information on how
to use those configuration options to connect to any S3-compatible storage.
When you want to register an S3ArtifactStore from the CLI and need to pass
`client_kwargs`, `config_kwargs` or `s3_additional_kwargs`, you should pass
them as a json string:
```
zenml artifact-store register my_s3_store --type=s3 --path=s3://my_bucket \
--client_kwargs='{"endpoint_url": "http://my-s3-endpoint"}'
```
"""
key: Optional[str] = None
secret: Optional[str] = None
token: Optional[str] = None
client_kwargs: Optional[Dict[str, Any]] = None
config_kwargs: Optional[Dict[str, Any]] = None
s3_additional_kwargs: Optional[Dict[str, Any]] = None
_filesystem: Optional[s3fs.S3FileSystem] = None
# Class variables
FLAVOR: ClassVar[str] = S3_ARTIFACT_STORE_FLAVOR
SUPPORTED_SCHEMES: ClassVar[Set[str]] = {"s3://"}
@validator(
    "client_kwargs", "config_kwargs", "s3_additional_kwargs", pre=True
)
def _convert_json_string(
    cls, value: Union[None, str, Dict[str, Any]]
) -> Optional[Dict[str, Any]]:
    """Turn a JSON string (e.g. passed via the CLI) into a dictionary.

    Dictionaries and `None` are passed through unchanged.

    Args:
        value: The value to convert.

    Returns:
        The converted value.

    Raises:
        TypeError: If the value is not a `str`, `Dict` or `None`.
        ValueError: If the value is an invalid json string or a json string
            that does not decode into a dictionary.
    """
    if value is None or isinstance(value, dict):
        return value

    if not isinstance(value, str):
        raise TypeError(f"{value} is not a json string or a dictionary.")

    try:
        decoded = json.loads(value)
    except json.JSONDecodeError as e:
        raise ValueError(f"Invalid json string '{value}'") from e

    # Valid JSON can still be a list, number, string, ... - only objects
    # are accepted.
    if not isinstance(decoded, dict):
        raise ValueError(
            f"Json string '{value}' did not decode into a dictionary."
        )
    return decoded
def _get_credentials(
    self,
) -> Tuple[Optional[str], Optional[str], Optional[str]]:
    """Resolve the credentials used for the S3 filesystem.

    Values from a configured authentication secret take precedence; without
    one, the plain text `key`, `secret` and `token` attributes are used.

    Returns:
        Tuple (key, secret, token) of credentials used to authenticate with
        the S3 filesystem.
    """
    auth_secret = self.get_authentication_secret(
        expected_schema_type=AWSSecretSchema
    )
    if not auth_secret:
        return self.key, self.secret, self.token

    return (
        auth_secret.aws_access_key_id,
        auth_secret.aws_secret_access_key,
        auth_secret.aws_session_token,
    )
@property
def filesystem(self) -> s3fs.S3FileSystem:
    """The s3 filesystem to access this artifact store.

    The filesystem is created on first access and reused afterwards.

    Returns:
        The s3 filesystem.
    """
    if self._filesystem:
        return self._filesystem

    key, secret, token = self._get_credentials()
    self._filesystem = s3fs.S3FileSystem(
        key=key,
        secret=secret,
        token=token,
        client_kwargs=self.client_kwargs,
        config_kwargs=self.config_kwargs,
        s3_additional_kwargs=self.s3_additional_kwargs,
    )
    return self._filesystem
def open(self, path: PathType, mode: str = "r") -> Any:
    """Open a file on the S3 filesystem.

    Args:
        path: Path of the file to open.
        mode: Mode in which to open the file. Currently, only
            'rb' and 'wb' to read and write binary files are supported.

    Returns:
        A file-like object.
    """
    s3_filesystem = self.filesystem
    return s3_filesystem.open(path=path, mode=mode)
def copyfile(
    self, src: PathType, dst: PathType, overwrite: bool = False
) -> None:
    """Copy a file within the S3 filesystem.

    Args:
        src: The path to copy from.
        dst: The path to copy to.
        overwrite: If a file already exists at the destination, this
            method will overwrite it if overwrite=`True` and
            raise a FileExistsError otherwise.

    Raises:
        FileExistsError: If a file already exists at the destination
            and overwrite is not set to `True`.
    """
    # The destination is only probed when overwriting is not allowed.
    if not overwrite:
        if self.filesystem.exists(dst):
            raise FileExistsError(
                f"Unable to copy to destination '{convert_to_str(dst)}', "
                f"file already exists. Set `overwrite=True` to copy anyway."
            )

    # TODO [ENG-151]: Check if it works with overwrite=True or if we need to
    #  manually remove it first
    self.filesystem.copy(path1=src, path2=dst)
def exists(self, path: PathType) -> bool:
    """Check whether a path exists on the s3 filesystem.

    Args:
        path: The path to check.

    Returns:
        True if the path exists, False otherwise.
    """
    found = self.filesystem.exists(path=path)
    return found  # type: ignore[no-any-return]
def glob(self, pattern: PathType) -> List[PathType]:
    """Return all paths that match the given glob pattern.

    The glob pattern may include:
    - '*' to match any number of characters
    - '?' to match a single character
    - '[...]' to match one of the characters inside the brackets
    - '**' as the full name of a path component to match to search
      in subdirectories of any depth (e.g. '/some_dir/**/some_file')

    Args:
        pattern: The glob pattern to match, see details above.

    Returns:
        A list of paths that match the given glob pattern, each prefixed
        with `s3://`.
    """
    prefixed_matches: List[PathType] = []
    for match in self.filesystem.glob(path=pattern):
        prefixed_matches.append(f"s3://{match}")
    return prefixed_matches
def isdir(self, path: PathType) -> bool:
    """Check whether a path is a directory on the s3 filesystem.

    Args:
        path: The path to check.

    Returns:
        True if the path is a directory, False otherwise.
    """
    is_directory = self.filesystem.isdir(path=path)
    return is_directory  # type: ignore[no-any-return]
def listdir(self, path: PathType) -> List[PathType]:
    """Return a list of files in a directory.

    Args:
        path: The path to list.

    Returns:
        A list of paths that are files in the given directory. Only the
        part of each key after the listed directory is returned.
    """
    # Drop the s3 prefix if given: the directory part is cut off every key
    # below, as this method is expected to only return filenames.
    directory = convert_to_str(path)
    if directory.startswith("s3://"):
        directory = directory[5:]
    prefix_length = len(directory)

    names: List[PathType] = []
    for file_dict in self.filesystem.listdir(path=directory):
        key = cast(str, file_dict["Key"])
        name = key[prefix_length:].lstrip("/")
        # s3fs.listdir also returns the root directory; its name is empty
        # once the directory prefix is removed, so it is filtered out here.
        if name:
            names.append(name)
    return names
def makedirs(self, path: PathType) -> None:
    """Create a directory at the given path.

    If needed also create missing parent directories. The call is made
    with `exist_ok=True`.

    Args:
        path: The path to create.
    """
    s3 = self.filesystem
    s3.makedirs(path=path, exist_ok=True)
def mkdir(self, path: PathType) -> None:
    """Create a directory at the given path.

    Delegates to the filesystem's `makedir` method.

    Args:
        path: The path to create.
    """
    s3 = self.filesystem
    s3.makedir(path=path)
def remove(self, path: PathType) -> None:
    """Remove the file at the given path.

    Delegates to the filesystem's `rm_file` method.

    Args:
        path: The path of the file to remove.
    """
    s3 = self.filesystem
    s3.rm_file(path=path)
def rename(
    self, src: PathType, dst: PathType, overwrite: bool = False
) -> None:
    """Rename source file to destination file.

    Args:
        src: The path of the file to rename.
        dst: The path to rename the source file to.
        overwrite: If a file already exists at the destination, this
            method will overwrite it if overwrite=`True` and
            raise a FileExistsError otherwise.

    Raises:
        FileExistsError: If a file already exists at the destination
            and overwrite is not set to `True`.
    """
    # The existence check is only performed when overwriting is not allowed.
    if not overwrite:
        destination_taken = self.filesystem.exists(dst)
        if destination_taken:
            raise FileExistsError(
                f"Unable to rename file to '{convert_to_str(dst)}', "
                f"file already exists. Set `overwrite=True` to rename anyway."
            )

    # TODO [ENG-152]: Check if it works with overwrite=True or if we need
    #  to manually remove it first
    self.filesystem.rename(path1=src, path2=dst)
def rmtree(self, path: PathType) -> None:
    """Remove the given directory.

    Delegates to the filesystem's `delete` method with `recursive=True`.

    Args:
        path: The path of the directory to remove.
    """
    s3 = self.filesystem
    s3.delete(path=path, recursive=True)
def stat(self, path: PathType) -> Dict[str, Any]:
    """Return stat info for the given path.

    Args:
        path: The path to get stat info for.

    Returns:
        A dictionary containing the stat info.
    """
    stat_info = self.filesystem.stat(path=path)
    return stat_info  # type: ignore[no-any-return]
def walk(
    self,
    top: PathType,
    topdown: bool = True,
    onerror: Optional[Callable[..., None]] = None,
) -> Iterable[Tuple[PathType, List[PathType], List[PathType]]]:
    """Return an iterator that walks the contents of the given directory.

    Args:
        top: Path of directory to walk.
        topdown: Unused argument to conform to interface.
        onerror: Unused argument to conform to interface.

    Yields:
        An Iterable of Tuples, each of which contain the path of the current
        directory path (prefixed with `s3://`), a list of directories inside
        the current directory and a list of files inside the current
        directory.
    """
    # TODO [ENG-153]: Additional params
    for entry in self.filesystem.walk(path=top):
        directory, subdirectories, files = entry
        yield f"s3://{directory}", subdirectories, files
filesystem: S3FileSystem
property
readonly
The s3 filesystem to access this artifact store.
Returns:
Type | Description |
---|---|
S3FileSystem |
The s3 filesystem. |
copyfile(self, src, dst, overwrite=False)
Copy a file.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
src |
Union[bytes, str] |
The path to copy from. |
required |
dst |
Union[bytes, str] |
The path to copy to. |
required |
overwrite |
bool |
If a file already exists at the destination, this
method will overwrite it if overwrite=`True` and raise a FileExistsError otherwise. |
False |
Exceptions:
Type | Description |
---|---|
FileExistsError |
If a file already exists at the destination
and overwrite is not set to `True`. |
Source code in zenml/integrations/s3/artifact_stores/s3_artifact_store.py
def copyfile(
    self, src: PathType, dst: PathType, overwrite: bool = False
) -> None:
    """Copy a file from `src` to `dst`.

    Args:
        src: The path to copy from.
        dst: The path to copy to.
        overwrite: If a file already exists at the destination, this
            method will overwrite it if overwrite=`True` and
            raise a FileExistsError otherwise.

    Raises:
        FileExistsError: If a file already exists at the destination
            and overwrite is not set to `True`.
    """
    # The existence check is only performed when overwriting is not allowed.
    if not overwrite:
        destination_taken = self.filesystem.exists(dst)
        if destination_taken:
            raise FileExistsError(
                f"Unable to copy to destination '{convert_to_str(dst)}', "
                f"file already exists. Set `overwrite=True` to copy anyway."
            )

    # TODO [ENG-151]: Check if it works with overwrite=True or if we need to
    #  manually remove it first
    self.filesystem.copy(path1=src, path2=dst)
exists(self, path)
Check whether a path exists.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
path |
Union[bytes, str] |
The path to check. |
required |
Returns:
Type | Description |
---|---|
bool |
True if the path exists, False otherwise. |
Source code in zenml/integrations/s3/artifact_stores/s3_artifact_store.py
def exists(self, path: PathType) -> bool:
    """Check whether a path exists on the s3 filesystem.

    Args:
        path: The path to check.

    Returns:
        True if the path exists, False otherwise.
    """
    found = self.filesystem.exists(path=path)
    return found  # type: ignore[no-any-return]
glob(self, pattern)
Return all paths that match the given glob pattern.
The glob pattern may include: - '*' to match any number of characters - '?' to match a single character - '[...]' to match one of the characters inside the brackets - '**' as the full name of a path component to match to search in subdirectories of any depth (e.g. '/some_dir/**/some_file')
Parameters:
Name | Type | Description | Default |
---|---|---|---|
pattern |
Union[bytes, str] |
The glob pattern to match, see details above. |
required |
Returns:
Type | Description |
---|---|
List[Union[bytes, str]] |
A list of paths that match the given glob pattern. |
Source code in zenml/integrations/s3/artifact_stores/s3_artifact_store.py
def glob(self, pattern: PathType) -> List[PathType]:
    """Return all paths that match the given glob pattern.

    The glob pattern may include:
    - '*' to match any number of characters
    - '?' to match a single character
    - '[...]' to match one of the characters inside the brackets
    - '**' as the full name of a path component to match to search
      in subdirectories of any depth (e.g. '/some_dir/**/some_file')

    Args:
        pattern: The glob pattern to match, see details above.

    Returns:
        A list of paths that match the given glob pattern, each prefixed
        with `s3://`.
    """
    prefixed_matches: List[PathType] = []
    for match in self.filesystem.glob(path=pattern):
        prefixed_matches.append(f"s3://{match}")
    return prefixed_matches
isdir(self, path)
Check whether a path is a directory.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
path |
Union[bytes, str] |
The path to check. |
required |
Returns:
Type | Description |
---|---|
bool |
True if the path is a directory, False otherwise. |
Source code in zenml/integrations/s3/artifact_stores/s3_artifact_store.py
def isdir(self, path: PathType) -> bool:
    """Check whether a path is a directory on the s3 filesystem.

    Args:
        path: The path to check.

    Returns:
        True if the path is a directory, False otherwise.
    """
    is_directory = self.filesystem.isdir(path=path)
    return is_directory  # type: ignore[no-any-return]
listdir(self, path)
Return a list of files in a directory.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
path |
Union[bytes, str] |
The path to list. |
required |
Returns:
Type | Description |
---|---|
List[Union[bytes, str]] |
A list of paths that are files in the given directory. |
Source code in zenml/integrations/s3/artifact_stores/s3_artifact_store.py
def listdir(self, path: PathType) -> List[PathType]:
    """Return a list of files in a directory.

    Args:
        path: The path to list.

    Returns:
        A list of paths that are files in the given directory. Only the
        part of each key after the listed directory is returned.
    """
    # Drop the s3 prefix if given: the directory part is cut off every key
    # below, as this method is expected to only return filenames.
    directory = convert_to_str(path)
    if directory.startswith("s3://"):
        directory = directory[5:]
    prefix_length = len(directory)

    names: List[PathType] = []
    for file_dict in self.filesystem.listdir(path=directory):
        key = cast(str, file_dict["Key"])
        name = key[prefix_length:].lstrip("/")
        # s3fs.listdir also returns the root directory; its name is empty
        # once the directory prefix is removed, so it is filtered out here.
        if name:
            names.append(name)
    return names
makedirs(self, path)
Create a directory at the given path.
If needed also create missing parent directories.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
path |
Union[bytes, str] |
The path to create. |
required |
Source code in zenml/integrations/s3/artifact_stores/s3_artifact_store.py
def makedirs(self, path: PathType) -> None:
    """Create a directory at the given path.

    If needed also create missing parent directories. The call is made
    with `exist_ok=True`.

    Args:
        path: The path to create.
    """
    s3 = self.filesystem
    s3.makedirs(path=path, exist_ok=True)
mkdir(self, path)
Create a directory at the given path.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
path |
Union[bytes, str] |
The path to create. |
required |
Source code in zenml/integrations/s3/artifact_stores/s3_artifact_store.py
def mkdir(self, path: PathType) -> None:
    """Create a directory at the given path.

    Delegates to the filesystem's `makedir` method.

    Args:
        path: The path to create.
    """
    s3 = self.filesystem
    s3.makedir(path=path)
open(self, path, mode='r')
Open a file at the given path.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
path |
Union[bytes, str] |
Path of the file to open. |
required |
mode |
str |
Mode in which to open the file. Currently, only 'rb' and 'wb' to read and write binary files are supported. |
'r' |
Returns:
Type | Description |
---|---|
Any |
A file-like object. |
Source code in zenml/integrations/s3/artifact_stores/s3_artifact_store.py
def open(self, path: PathType, mode: str = "r") -> Any:
    """Open the file at the given path through the s3 filesystem.

    Args:
        path: Path of the file to open.
        mode: Mode in which to open the file. Currently, only
            'rb' and 'wb' to read and write binary files are supported.

    Returns:
        A file-like object.
    """
    s3 = self.filesystem
    return s3.open(path=path, mode=mode)
remove(self, path)
Remove the file at the given path.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
path |
Union[bytes, str] |
The path of the file to remove. |
required |
Source code in zenml/integrations/s3/artifact_stores/s3_artifact_store.py
def remove(self, path: PathType) -> None:
    """Remove the file at the given path.

    Delegates to the filesystem's `rm_file` method.

    Args:
        path: The path of the file to remove.
    """
    s3 = self.filesystem
    s3.rm_file(path=path)
rename(self, src, dst, overwrite=False)
Rename source file to destination file.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
src |
Union[bytes, str] |
The path of the file to rename. |
required |
dst |
Union[bytes, str] |
The path to rename the source file to. |
required |
overwrite |
bool |
If a file already exists at the destination, this
method will overwrite it if overwrite=`True` and raise a FileExistsError otherwise. |
False |
Exceptions:
Type | Description |
---|---|
FileExistsError |
If a file already exists at the destination
and overwrite is not set to `True`. |
Source code in zenml/integrations/s3/artifact_stores/s3_artifact_store.py
def rename(
    self, src: PathType, dst: PathType, overwrite: bool = False
) -> None:
    """Rename source file to destination file.

    Args:
        src: The path of the file to rename.
        dst: The path to rename the source file to.
        overwrite: If a file already exists at the destination, this
            method will overwrite it if overwrite=`True` and
            raise a FileExistsError otherwise.

    Raises:
        FileExistsError: If a file already exists at the destination
            and overwrite is not set to `True`.
    """
    # The existence check is only performed when overwriting is not allowed.
    if not overwrite:
        destination_taken = self.filesystem.exists(dst)
        if destination_taken:
            raise FileExistsError(
                f"Unable to rename file to '{convert_to_str(dst)}', "
                f"file already exists. Set `overwrite=True` to rename anyway."
            )

    # TODO [ENG-152]: Check if it works with overwrite=True or if we need
    #  to manually remove it first
    self.filesystem.rename(path1=src, path2=dst)
rmtree(self, path)
Remove the given directory.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
path |
Union[bytes, str] |
The path of the directory to remove. |
required |
Source code in zenml/integrations/s3/artifact_stores/s3_artifact_store.py
def rmtree(self, path: PathType) -> None:
    """Remove the given directory.

    Delegates to the filesystem's `delete` method with `recursive=True`.

    Args:
        path: The path of the directory to remove.
    """
    s3 = self.filesystem
    s3.delete(path=path, recursive=True)
stat(self, path)
Return stat info for the given path.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
path |
Union[bytes, str] |
The path to get stat info for. |
required |
Returns:
Type | Description |
---|---|
Dict[str, Any] |
A dictionary containing the stat info. |
Source code in zenml/integrations/s3/artifact_stores/s3_artifact_store.py
def stat(self, path: PathType) -> Dict[str, Any]:
    """Return stat info for the given path.

    Args:
        path: The path to get stat info for.

    Returns:
        A dictionary containing the stat info.
    """
    stat_info = self.filesystem.stat(path=path)
    return stat_info  # type: ignore[no-any-return]
walk(self, top, topdown=True, onerror=None)
Return an iterator that walks the contents of the given directory.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
top |
Union[bytes, str] |
Path of directory to walk. |
required |
topdown |
bool |
Unused argument to conform to interface. |
True |
onerror |
Optional[Callable[..., NoneType]] |
Unused argument to conform to interface. |
None |
Yields:
Type | Description |
---|---|
Iterable[Tuple[Union[bytes, str], List[Union[bytes, str]], List[Union[bytes, str]]]] |
An Iterable of Tuples, each of which contain the path of the current directory path, a list of directories inside the current directory and a list of files inside the current directory. |
Source code in zenml/integrations/s3/artifact_stores/s3_artifact_store.py
def walk(
    self,
    top: PathType,
    topdown: bool = True,
    onerror: Optional[Callable[..., None]] = None,
) -> Iterable[Tuple[PathType, List[PathType], List[PathType]]]:
    """Return an iterator that walks the contents of the given directory.

    Args:
        top: Path of directory to walk.
        topdown: Unused argument to conform to interface.
        onerror: Unused argument to conform to interface.

    Yields:
        An Iterable of Tuples, each of which contain the path of the current
        directory path (prefixed with `s3://`), a list of directories inside
        the current directory and a list of files inside the current
        directory.
    """
    # TODO [ENG-153]: Additional params
    for entry in self.filesystem.walk(path=top):
        directory, subdirectories, files = entry
        yield f"s3://{directory}", subdirectories, files
scipy
special
Initialization of the Scipy integration.
ScipyIntegration (Integration)
Definition of scipy integration for ZenML.
Source code in zenml/integrations/scipy/__init__.py
class ScipyIntegration(Integration):
    """Definition of scipy integration for ZenML."""

    # Identifier of this integration (the SCIPY constant).
    NAME = SCIPY
    # Requirement strings declared for this integration.
    REQUIREMENTS = ["scipy"]

    @classmethod
    def activate(cls) -> None:
        """Activates the integration.

        Imports the scipy materializers sub-module. The imported name is not
        used afterwards (hence the `noqa`), so the import presumably exists
        for its side effects - TODO confirm.
        """
        from zenml.integrations.scipy import materializers  # noqa
activate()
classmethod
Activates the integration.
Source code in zenml/integrations/scipy/__init__.py
@classmethod
def activate(cls) -> None:
    """Activates the integration.

    Imports the scipy materializers sub-module. The imported name is not
    used afterwards (hence the `noqa`), so the import presumably exists
    for its side effects - TODO confirm.
    """
    from zenml.integrations.scipy import materializers  # noqa
materializers
special
Initialization of the Scipy materializers.
sparse_materializer
Implementation of the Scipy Sparse Materializer.
SparseMaterializer (BaseMaterializer)
Materializer to read and write scipy sparse matrices.
Source code in zenml/integrations/scipy/materializers/sparse_materializer.py
class SparseMaterializer(BaseMaterializer):
    """Materializer to read and write scipy sparse matrices."""

    ASSOCIATED_TYPES = (spmatrix,)
    ASSOCIATED_ARTIFACT_TYPES = (DataArtifact,)

    def handle_input(self, data_type: Type[Any]) -> spmatrix:
        """Reads spmatrix from npz file.

        The file is `DATA_FILENAME` inside the artifact URI and is opened
        in binary read mode.

        Args:
            data_type: The type of the spmatrix to load.

        Returns:
            A spmatrix object.
        """
        super().handle_input(data_type)
        npz_path = os.path.join(self.artifact.uri, DATA_FILENAME)
        with fileio.open(npz_path, "rb") as npz_file:
            loaded = load_npz(npz_file)
        return loaded

    def handle_return(self, mat: spmatrix) -> None:
        """Writes a spmatrix to the artifact store as a npz file.

        The file is `DATA_FILENAME` inside the artifact URI and is opened
        in binary write mode.

        Args:
            mat: The spmatrix to write.
        """
        super().handle_return(mat)
        npz_path = os.path.join(self.artifact.uri, DATA_FILENAME)
        with fileio.open(npz_path, "wb") as npz_file:
            save_npz(npz_file, mat)
handle_input(self, data_type)
Reads spmatrix from npz file.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
data_type |
Type[Any] |
The type of the spmatrix to load. |
required |
Returns:
Type | Description |
---|---|
spmatrix |
A spmatrix object. |
Source code in zenml/integrations/scipy/materializers/sparse_materializer.py
def handle_input(self, data_type: Type[Any]) -> spmatrix:
    """Reads spmatrix from npz file.

    The file is `DATA_FILENAME` inside the artifact URI and is opened
    in binary read mode.

    Args:
        data_type: The type of the spmatrix to load.

    Returns:
        A spmatrix object.
    """
    super().handle_input(data_type)
    npz_path = os.path.join(self.artifact.uri, DATA_FILENAME)
    with fileio.open(npz_path, "rb") as npz_file:
        loaded = load_npz(npz_file)
    return loaded
handle_return(self, mat)
Writes a spmatrix to the artifact store as a npz file.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
mat |
spmatrix |
The spmatrix to write. |
required |
Source code in zenml/integrations/scipy/materializers/sparse_materializer.py
def handle_return(self, mat: spmatrix) -> None:
    """Writes a spmatrix to the artifact store as a npz file.

    The file is `DATA_FILENAME` inside the artifact URI and is opened
    in binary write mode.

    Args:
        mat: The spmatrix to write.
    """
    super().handle_return(mat)
    npz_path = os.path.join(self.artifact.uri, DATA_FILENAME)
    with fileio.open(npz_path, "wb") as npz_file:
        save_npz(npz_file, mat)
seldon
special
Initialization of the Seldon integration.
The Seldon Core integration allows you to use the Seldon Core model serving platform to implement continuous model deployment.
SeldonIntegration (Integration)
Definition of Seldon Core integration for ZenML.
Source code in zenml/integrations/seldon/__init__.py
class SeldonIntegration(Integration):
    """Definition of Seldon Core integration for ZenML."""

    NAME = SELDON
    REQUIREMENTS = [
        "kubernetes==18.20.0",
    ]

    @classmethod
    def activate(cls) -> None:
        """Activate the Seldon Core integration."""
        from zenml.integrations.seldon import secret_schemas  # noqa
        from zenml.integrations.seldon import services  # noqa

    @classmethod
    def flavors(cls) -> List[FlavorWrapper]:
        """Declare the stack component flavors for the Seldon Core.

        A single model deployer flavor is declared, pointing at the
        `SeldonModelDeployer` class by its source path.

        Returns:
            List of stack component flavors for this integration.
        """
        model_deployer_flavor = FlavorWrapper(
            name=SELDON_MODEL_DEPLOYER_FLAVOR,
            source="zenml.integrations.seldon.model_deployers.SeldonModelDeployer",
            type=StackComponentType.MODEL_DEPLOYER,
            integration=cls.NAME,
        )
        return [model_deployer_flavor]
activate()
classmethod
Activate the Seldon Core integration.
Source code in zenml/integrations/seldon/__init__.py
@classmethod
def activate(cls) -> None:
    """Activate the Seldon Core integration.

    Imports the Seldon `secret_schemas` and `services` sub-modules. Neither
    imported name is used afterwards (hence the `noqa` markers), so the
    imports presumably exist for their side effects - TODO confirm.
    """
    from zenml.integrations.seldon import secret_schemas  # noqa
    from zenml.integrations.seldon import services  # noqa
flavors()
classmethod
Declare the stack component flavors for the Seldon Core.
Returns:
Type | Description |
---|---|
List[zenml.zen_stores.models.flavor_wrapper.FlavorWrapper] |
List of stack component flavors for this integration. |
Source code in zenml/integrations/seldon/__init__.py
@classmethod
def flavors(cls) -> List[FlavorWrapper]:
    """Declare the stack component flavors for the Seldon Core.

    A single model deployer flavor is declared, pointing at the
    `SeldonModelDeployer` class by its source path.

    Returns:
        List of stack component flavors for this integration.
    """
    model_deployer_flavor = FlavorWrapper(
        name=SELDON_MODEL_DEPLOYER_FLAVOR,
        source="zenml.integrations.seldon.model_deployers.SeldonModelDeployer",
        type=StackComponentType.MODEL_DEPLOYER,
        integration=cls.NAME,
    )
    return [model_deployer_flavor]
model_deployers
special
Initialization of the Seldon Model Deployer.
seldon_model_deployer
Implementation of the Seldon Model Deployer.
SeldonModelDeployer (BaseModelDeployer)
pydantic-model
Seldon Core model deployer stack component implementation.
Attributes:
Name | Type | Description |
---|---|---|
kubernetes_context |
Optional[str] |
the Kubernetes context to use to contact the remote Seldon Core installation. If not specified, the current configuration is used. Depending on where the Seldon model deployer is being used, this can be either a locally active context or an in-cluster Kubernetes configuration (if running inside a pod). |
kubernetes_namespace |
Optional[str] |
the Kubernetes namespace where the Seldon Core deployment servers are provisioned and managed by ZenML. If not specified, the namespace set in the current configuration is used. Depending on where the Seldon model deployer is being used, this can be either the current namespace configured in the locally active context or the namespace in the context of which the pod is running (if running inside a pod). |
base_url |
str |
the base URL of the Kubernetes ingress used to expose the Seldon Core deployment servers. |
secret |
Optional[str] |
the name of a ZenML secret containing the credentials used by Seldon Core storage initializers to authenticate to the Artifact Store (i.e. the storage backend where models are stored - see https://docs.seldon.io/projects/seldon-core/en/latest/servers/overview.html#handling-credentials). |
Source code in zenml/integrations/seldon/model_deployers/seldon_model_deployer.py
class SeldonModelDeployer(BaseModelDeployer):
"""Seldon Core model deployer stack component implementation.
Attributes:
kubernetes_context: the Kubernetes context to use to contact the remote
Seldon Core installation. If not specified, the current
configuration is used. Depending on where the Seldon model deployer
is being used, this can be either a locally active context or an
in-cluster Kubernetes configuration (if running inside a pod).
kubernetes_namespace: the Kubernetes namespace where the Seldon Core
deployment servers are provisioned and managed by ZenML. If not
specified, the namespace set in the current configuration is used.
Depending on where the Seldon model deployer is being used, this can
be either the current namespace configured in the locally active
context or the namespace in the context of which the pod is running
(if running inside a pod).
base_url: the base URL of the Kubernetes ingress used to expose the
Seldon Core deployment servers.
secret: the name of a ZenML secret containing the credentials used by
Seldon Core storage initializers to authenticate to the Artifact
Store (i.e. the storage backend where models are stored - see
https://docs.seldon.io/projects/seldon-core/en/latest/servers/overview.html#handling-credentials).
"""
# Class Configuration
FLAVOR: ClassVar[str] = SELDON_MODEL_DEPLOYER_FLAVOR
kubernetes_context: Optional[str]
kubernetes_namespace: Optional[str]
base_url: str
secret: Optional[str]
# private attributes
_client: Optional[SeldonClient] = None
@staticmethod
def get_model_server_info(  # type: ignore[override]
    service_instance: "SeldonDeploymentService",
) -> Dict[str, Optional[str]]:
    """Return implementation specific information that might be relevant to the user.

    Args:
        service_instance: Instance of a SeldonDeploymentService

    Returns:
        Model server information: the prediction URL, the model URI and
        model name taken from the service configuration, and the Seldon
        deployment name.
    """
    info: Dict[str, Optional[str]] = {}
    info["PREDICTION_URL"] = service_instance.prediction_url
    service_config = service_instance.config
    info["MODEL_URI"] = service_config.model_uri
    info["MODEL_NAME"] = service_config.model_name
    info["SELDON_DEPLOYMENT"] = service_instance.seldon_deployment_name
    return info
@staticmethod
def get_active_model_deployer() -> "SeldonModelDeployer":
    """Get the Seldon Core model deployer registered in the active stack.

    Returns:
        The Seldon Core model deployer registered in the active stack.

    Raises:
        TypeError: if the Seldon Core model deployer is not available, i.e.
            the active stack has no model deployer or it is not a
            `SeldonModelDeployer`.
    """
    repository = Repository(  # type: ignore [call-arg]
        skip_repository_check=True
    )
    active_deployer = repository.active_stack.model_deployer

    if active_deployer and isinstance(active_deployer, SeldonModelDeployer):
        return active_deployer

    raise TypeError(
        f"The active stack needs to have a Seldon Core model deployer "
        f"component registered to be able to deploy models with Seldon "
        f"Core. You can create a new stack with a Seldon Core model "
        f"deployer component or update your existing stack to add this "
        f"component, e.g.:\n\n"
        f" 'zenml model-deployer register seldon --flavor={SELDON_MODEL_DEPLOYER_FLAVOR} "
        f"--kubernetes_context=context-name --kubernetes_namespace="
        f"namespace-name --base_url=https://ingress.cluster.kubernetes'\n"
        f" 'zenml stack create stack-name -d seldon ...'\n"
    )
@property
def seldon_client(self) -> SeldonClient:
    """Get the Seldon Core client associated with this model deployer.

    The client is created on first access from the configured Kubernetes
    context and namespace and kept in `self._client` afterwards.

    Returns:
        The Seldon Core client.
    """
    if self._client:
        return self._client

    self._client = SeldonClient(
        context=self.kubernetes_context,
        namespace=self.kubernetes_namespace,
    )
    return self._client
@property
def kubernetes_secret_name(self) -> Optional[str]:
    """Get the Kubernetes secret name associated with this model deployer.

    If a secret is configured for this model deployer, a corresponding
    Kubernetes secret is created in the remote cluster to be used
    by Seldon Core storage initializers to authenticate to the Artifact
    Store. This method returns the unique name that is used for this secret.

    The name is `zenml-seldon-core-<secret>` with every run of characters
    other than ASCII letters, digits and dashes replaced by a single dash,
    leading and trailing dashes removed, and the result lowercased.

    Returns:
        The Seldon Core Kubernetes secret name, or None if no secret is
        configured.
    """
    if not self.secret:
        return None

    raw_name = f"zenml-seldon-core-{self.secret}"
    dashed_name = re.sub(r"[^0-9a-zA-Z-]+", "-", raw_name)
    trimmed_name = dashed_name.strip("-")
    return trimmed_name.lower()
def _create_or_update_kubernetes_secret(self) -> Optional[str]:
    """Create or update a Kubernetes secret.

    Uses the information stored in the ZenML secret configured for the model deployer.

    Returns:
        The name of the Kubernetes secret that was created or updated, or
        None if no secret was configured.

    Raises:
        RuntimeError: if the secret cannot be created or updated, i.e. the
            active stack has no secrets manager or the configured ZenML
            secret is not found in it.
    """
    # without a configured ZenML secret there is nothing to create; the
    # secret name property evaluates to None in that case
    if not self.secret:
        return self.kubernetes_secret_name

    # a ZenML secret was configured in the model deployer: create a
    # Kubernetes secret as a means to pass this information to the
    # Seldon Core deployment
    secrets_manager = Repository(  # type: ignore [call-arg]
        skip_repository_check=True
    ).active_stack.secrets_manager

    if not (
        secrets_manager
        and isinstance(secrets_manager, BaseSecretsManager)
    ):
        raise RuntimeError(
            f"The active stack doesn't have a secret manager component. "
            f"The ZenML secret specified in the Seldon Core Model "
            f"Deployer configuration cannot be fetched: {self.secret}."
        )

    try:
        zenml_secret = secrets_manager.get_secret(self.secret)
    except KeyError:
        raise RuntimeError(
            f"The ZenML secret '{self.secret}' specified in the "
            f"Seldon Core Model Deployer configuration was not found "
            f"in the active stack's secret manager."
        )

    secret_name = self.kubernetes_secret_name
    # should never happen, just making mypy happy
    assert secret_name is not None
    self.seldon_client.create_or_update_secret(secret_name, zenml_secret)
    return secret_name
def _delete_kubernetes_secret(self) -> None:
    """Delete the Kubernetes secret associated with this model deployer.

    Do this if no Seldon Core deployments are using it. Nothing happens
    when no secret name is configured for this model deployer.
    """
    secret_name = self.kubernetes_secret_name
    if not secret_name:
        return

    # check all the Seldon Core deployments for one that is currently
    # configured to use this secret
    secret_in_use = any(
        cast(SeldonDeploymentConfig, service.config).secret_name
        == secret_name
        for service in self.find_model_server()
    )
    if secret_in_use:
        return

    self.seldon_client.delete_secret(secret_name)
def deploy_model(
    self,
    config: ServiceConfig,
    replace: bool = False,
    timeout: int = DEFAULT_SELDON_DEPLOYMENT_START_STOP_TIMEOUT,
) -> BaseService:
    """Create a new Seldon Core deployment or update an existing one.

    # noqa: DAR402

    This should serve the supplied model and deployment configuration.

    This method has two modes of operation, depending on the `replace`
    argument value:

      * if `replace` is False, calling this method will create a new Seldon
        Core deployment server to reflect the model and other configuration
        parameters specified in the supplied Seldon deployment `config`.

      * if `replace` is True, this method will first attempt to find an
        existing Seldon Core deployment that is *equivalent* to the supplied
        configuration parameters. Two or more Seldon Core deployments are
        considered equivalent if they have the same `pipeline_name`,
        `pipeline_step_name` and `model_name` configuration parameters. To
        put it differently, two Seldon Core deployments are equivalent if
        they serve versions of the same model deployed by the same pipeline
        step. If an equivalent Seldon Core deployment is found, it will be
        updated in place to reflect the new configuration parameters. This
        allows an existing Seldon Core deployment to retain its prediction
        URL while performing a rolling update to serve a new model version.

    Callers should set `replace` to True if they want a continuous model
    deployment workflow that doesn't spin up a new Seldon Core deployment
    server for each new model version. If multiple equivalent Seldon Core
    deployments are found, the most recently created deployment is selected
    to be updated and the others are deleted.

    Args:
        config: the configuration of the model to be deployed with Seldon
            Core.
        replace: set this flag to True to find and update an equivalent
            Seldon Core deployment server with the new model instead of
            starting a new deployment server.
        timeout: the timeout in seconds to wait for the Seldon Core server
            to be provisioned and successfully started or updated. If set
            to 0, the method will return immediately after the Seldon Core
            server is provisioned, without waiting for it to fully start.

    Returns:
        The ZenML Seldon Core deployment service object that can be used to
        interact with the remote Seldon Core server.

    Raises:
        SeldonClientError: if a Seldon Core client error is encountered
            while provisioning the Seldon Core deployment server.
        RuntimeError: if `timeout` is set to a positive value that is
            exceeded while waiting for the Seldon Core deployment server
            to start, or if an operational failure is encountered before
            it reaches a ready state.
    """
    config = cast(SeldonDeploymentConfig, config)
    service = None

    # if a custom Kubernetes secret is not explicitly specified in the
    # SeldonDeploymentConfig, try to create one from the ZenML secret
    # configured for the model deployer
    config.secret_name = (
        config.secret_name or self._create_or_update_kubernetes_secret()
    )

    # if replace is True, find equivalent Seldon Core deployments
    if replace is True:
        equivalent_services = self.find_model_server(
            running=False,
            pipeline_name=config.pipeline_name,
            pipeline_step_name=config.pipeline_step_name,
            model_name=config.model_name,
        )

        # find_model_server returns the most recent deployments first, so
        # the first entry is the one to keep and update
        for equivalent_service in equivalent_services:
            if service is None:
                # keep the most recently created service
                service = equivalent_service
            else:
                try:
                    # delete the older services and don't wait for them to
                    # be deprovisioned; this must target the older
                    # `equivalent_service`, not the `service` being kept
                    equivalent_service.stop()
                except RuntimeError:
                    # ignore errors encountered while stopping old services
                    pass

    if service:
        # update an equivalent service in place
        service.update(config)
        logger.info(
            f"Updating an existing Seldon deployment service: {service}"
        )
    else:
        # create a new service
        service = SeldonDeploymentService(config=config)
        logger.info(f"Creating a new Seldon deployment service: {service}")

    # start the service which in turn provisions the Seldon Core
    # deployment server and waits for it to reach a ready state
    service.start(timeout=timeout)
    return service
def find_model_server(
    self,
    running: bool = False,
    service_uuid: Optional[UUID] = None,
    pipeline_name: Optional[str] = None,
    pipeline_run_id: Optional[str] = None,
    pipeline_step_name: Optional[str] = None,
    model_name: Optional[str] = None,
    model_uri: Optional[str] = None,
    model_type: Optional[str] = None,
) -> List[BaseService]:
    """Look up Seldon Core model services matching the given criteria.

    Matching deployments are ordered by creation time, newest first.
    Deployments without a creation timestamp are ordered last.

    Args:
        running: if true, only running services will be returned.
        service_uuid: the UUID of the Seldon Core service that was
            originally used to create the Seldon Core deployment resource.
        pipeline_name: name of the pipeline that the deployed model was
            part of.
        pipeline_run_id: ID of the pipeline run which the deployed model
            was part of.
        pipeline_step_name: the name of the pipeline model deployment step
            that deployed the model.
        model_name: the name of the deployed model.
        model_uri: URI of the deployed model.
        model_type: the Seldon Core server implementation used to serve
            the model.

    Returns:
        One or more Seldon Core service objects representing Seldon Core
        model servers that match the input search criteria.
    """

    def _creation_time(deployment) -> datetime:
        # deployments lacking a timestamp are treated as the oldest ones
        stamp = deployment.metadata.creationTimestamp
        if not stamp:
            return datetime.min
        return datetime.strptime(stamp, "%Y-%m-%dT%H:%M:%SZ")

    # A throwaway deployment configuration is only used here to derive the
    # label selector from the search criteria
    selector = SeldonDeploymentConfig(
        pipeline_name=pipeline_name or "",
        pipeline_run_id=pipeline_run_id or "",
        pipeline_step_name=pipeline_step_name or "",
        model_name=model_name or "",
        model_uri=model_uri or "",
        implementation=model_type or "",
    ).get_seldon_deployment_labels()
    if service_uuid:
        # the service UUID label is not produced by the deployment
        # configuration, so it has to be added by hand
        selector["zenml.service_uuid"] = str(service_uuid)

    matches = self.seldon_client.find_deployments(labels=selector)
    # newest deployments first
    matches.sort(key=_creation_time, reverse=True)

    # lazily rebuild a service object from every deployment resource
    candidates = (
        SeldonDeploymentService.create_from_deployment(deployment=match)
        for match in matches
    )
    # `is_running` is only consulted when the caller asked for running
    # services
    return [
        candidate
        for candidate in candidates
        if not running or candidate.is_running
    ]
def stop_model_server(
    self,
    uuid: UUID,
    timeout: int = DEFAULT_SELDON_DEPLOYMENT_START_STOP_TIMEOUT,
    force: bool = False,
) -> None:
    """Reject a request to stop a Seldon Core model server.

    Args:
        uuid: UUID of the model server to stop.
        timeout: timeout in seconds to wait for the service to stop.
        force: if True, force the service to stop.

    Raises:
        NotImplementedError: always; stopping Seldon Core model servers is
            not supported.
    """
    reason = (
        "Stopping Seldon Core model servers is not implemented. Try "
        "deleting the Seldon Core model server instead."
    )
    raise NotImplementedError(reason)
def start_model_server(
    self,
    uuid: UUID,
    timeout: int = DEFAULT_SELDON_DEPLOYMENT_START_STOP_TIMEOUT,
) -> None:
    """Reject a request to start a Seldon Core model deployment server.

    Args:
        uuid: UUID of the model server to start.
        timeout: timeout in seconds to wait for the service to become
            active. If set to 0, the method will return immediately after
            provisioning the service, without waiting for it to become
            active.

    Raises:
        NotImplementedError: always; starting Seldon Core model servers is
            not supported.
    """
    reason = "Starting Seldon Core model servers is not implemented"
    raise NotImplementedError(reason)
def delete_model_server(
    self,
    uuid: UUID,
    timeout: int = DEFAULT_SELDON_DEPLOYMENT_START_STOP_TIMEOUT,
    force: bool = False,
) -> None:
    """Remove the Seldon Core model deployment server with the given UUID.

    Nothing happens if no model server matches the UUID.

    Args:
        uuid: UUID of the model server to delete.
        timeout: timeout in seconds to wait for the service to stop. If
            set to 0, the method will return immediately after
            deprovisioning the service, without waiting for it to stop.
        force: if True, force the service to stop.
    """
    matches = self.find_model_server(service_uuid=uuid)
    if matches:
        # results are sorted newest first; only the first match is stopped
        matches[0].stop(timeout=timeout, force=force)
        # if this was the last Seldon Core model server, delete the
        # Kubernetes secret used to store the authentication information
        # for the Seldon Core model server storage initializer
        self._delete_kubernetes_secret()
kubernetes_secret_name: Optional[str]
property
readonly
Get the Kubernetes secret name associated with this model deployer.
If a secret is configured for this model deployer, a corresponding Kubernetes secret is created in the remote cluster to be used by Seldon Core storage initializers to authenticate to the Artifact Store. This method returns the unique name that is used for this secret.
Returns:
Type | Description |
---|---|
Optional[str] |
The Seldon Core Kubernetes secret name, or None if no secret is configured. |
seldon_client: SeldonClient
property
readonly
Get the Seldon Core client associated with this model deployer.
Returns:
Type | Description |
---|---|
SeldonClient |
The Seldon Core client. |
delete_model_server(self, uuid, timeout=300, force=False)
Delete a Seldon Core model deployment server.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
uuid |
UUID |
UUID of the model server to delete. |
required |
timeout |
int |
timeout in seconds to wait for the service to stop. If set to 0, the method will return immediately after deprovisioning the service, without waiting for it to stop. |
300 |
force |
bool |
if True, force the service to stop. |
False |
Source code in zenml/integrations/seldon/model_deployers/seldon_model_deployer.py
def delete_model_server(
    self,
    uuid: UUID,
    timeout: int = DEFAULT_SELDON_DEPLOYMENT_START_STOP_TIMEOUT,
    force: bool = False,
) -> None:
    """Remove the Seldon Core model deployment server with the given UUID.

    Nothing happens if no model server matches the UUID.

    Args:
        uuid: UUID of the model server to delete.
        timeout: timeout in seconds to wait for the service to stop. If
            set to 0, the method will return immediately after
            deprovisioning the service, without waiting for it to stop.
        force: if True, force the service to stop.
    """
    matches = self.find_model_server(service_uuid=uuid)
    if matches:
        # results are sorted newest first; only the first match is stopped
        matches[0].stop(timeout=timeout, force=force)
        # if this was the last Seldon Core model server, delete the
        # Kubernetes secret used to store the authentication information
        # for the Seldon Core model server storage initializer
        self._delete_kubernetes_secret()
deploy_model(self, config, replace=False, timeout=300)
Create a new Seldon Core deployment or update an existing one.
noqa: DAR402
This should serve the supplied model and deployment configuration.
This method has two modes of operation, depending on the `replace` argument value:
- if `replace` is False, calling this method will create a new Seldon Core deployment server to reflect the model and other configuration parameters specified in the supplied Seldon deployment `config`.
- if `replace` is True, this method will first attempt to find an existing Seldon Core deployment that is equivalent to the supplied configuration parameters. Two or more Seldon Core deployments are considered equivalent if they have the same `pipeline_name`, `pipeline_step_name` and `model_name` configuration parameters. To put it differently, two Seldon Core deployments are equivalent if they serve versions of the same model deployed by the same pipeline step. If an equivalent Seldon Core deployment is found, it will be updated in place to reflect the new configuration parameters. This allows an existing Seldon Core deployment to retain its prediction URL while performing a rolling update to serve a new model version.
Callers should set replace
to True if they want a continuous model
deployment workflow that doesn't spin up a new Seldon Core deployment
server for each new model version. If multiple equivalent Seldon Core
deployments are found, the most recently created deployment is selected
to be updated and the others are deleted.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
config |
ServiceConfig |
the configuration of the model to be deployed with Seldon Core. |
required |
replace |
bool |
set this flag to True to find and update an equivalent Seldon Core deployment server with the new model instead of starting a new deployment server. |
False |
timeout |
int |
the timeout in seconds to wait for the Seldon Core server to be provisioned and successfully started or updated. If set to 0, the method will return immediately after the Seldon Core server is provisioned, without waiting for it to fully start. |
300 |
Returns:
Type | Description |
---|---|
BaseService |
The ZenML Seldon Core deployment service object that can be used to interact with the remote Seldon Core server. |
Exceptions:
Type | Description |
---|---|
SeldonClientError |
if a Seldon Core client error is encountered while provisioning the Seldon Core deployment server. |
RuntimeError |
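if `timeout` is set to a positive value that is exceeded while waiting for the Seldon Core deployment server to start, or if an operational failure is encountered before it reaches a ready state. |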
Source code in zenml/integrations/seldon/model_deployers/seldon_model_deployer.py
def deploy_model(
    self,
    config: ServiceConfig,
    replace: bool = False,
    timeout: int = DEFAULT_SELDON_DEPLOYMENT_START_STOP_TIMEOUT,
) -> BaseService:
    """Create a new Seldon Core deployment or update an existing one.

    # noqa: DAR402

    This should serve the supplied model and deployment configuration.

    This method has two modes of operation, depending on the `replace`
    argument value:

      * if `replace` is False, calling this method will create a new Seldon
        Core deployment server to reflect the model and other configuration
        parameters specified in the supplied Seldon deployment `config`.

      * if `replace` is True, this method will first attempt to find an
        existing Seldon Core deployment that is *equivalent* to the supplied
        configuration parameters. Two or more Seldon Core deployments are
        considered equivalent if they have the same `pipeline_name`,
        `pipeline_step_name` and `model_name` configuration parameters. To
        put it differently, two Seldon Core deployments are equivalent if
        they serve versions of the same model deployed by the same pipeline
        step. If an equivalent Seldon Core deployment is found, it will be
        updated in place to reflect the new configuration parameters. This
        allows an existing Seldon Core deployment to retain its prediction
        URL while performing a rolling update to serve a new model version.

    Callers should set `replace` to True if they want a continuous model
    deployment workflow that doesn't spin up a new Seldon Core deployment
    server for each new model version. If multiple equivalent Seldon Core
    deployments are found, the most recently created deployment is selected
    to be updated and the others are deleted.

    Args:
        config: the configuration of the model to be deployed with Seldon
            Core.
        replace: set this flag to True to find and update an equivalent
            Seldon Core deployment server with the new model instead of
            starting a new deployment server.
        timeout: the timeout in seconds to wait for the Seldon Core server
            to be provisioned and successfully started or updated. If set
            to 0, the method will return immediately after the Seldon Core
            server is provisioned, without waiting for it to fully start.

    Returns:
        The ZenML Seldon Core deployment service object that can be used to
        interact with the remote Seldon Core server.

    Raises:
        SeldonClientError: if a Seldon Core client error is encountered
            while provisioning the Seldon Core deployment server.
        RuntimeError: if `timeout` is set to a positive value that is
            exceeded while waiting for the Seldon Core deployment server
            to start, or if an operational failure is encountered before
            it reaches a ready state.
    """
    config = cast(SeldonDeploymentConfig, config)
    service = None

    # if a custom Kubernetes secret is not explicitly specified in the
    # SeldonDeploymentConfig, try to create one from the ZenML secret
    # configured for the model deployer
    config.secret_name = (
        config.secret_name or self._create_or_update_kubernetes_secret()
    )

    # if replace is True, find equivalent Seldon Core deployments
    if replace is True:
        equivalent_services = self.find_model_server(
            running=False,
            pipeline_name=config.pipeline_name,
            pipeline_step_name=config.pipeline_step_name,
            model_name=config.model_name,
        )

        # `find_model_server` returns the services sorted newest first, so
        # the first one is kept and every later one is an older duplicate
        for equivalent_service in equivalent_services:
            if service is None:
                # keep the most recently created service
                service = equivalent_service
            else:
                try:
                    # delete the older service (NOT the one that is kept)
                    # and don't wait for it to be deprovisioned
                    equivalent_service.stop()
                except RuntimeError as err:
                    # stopping old services is best-effort: record the
                    # failure and carry on with the deployment
                    logger.debug(
                        "Ignoring error encountered while stopping an "
                        "older Seldon deployment service %s: %s",
                        equivalent_service,
                        err,
                    )

    if service:
        # update an equivalent service in place
        service.update(config)
        logger.info(
            f"Updating an existing Seldon deployment service: {service}"
        )
    else:
        # create a new service
        service = SeldonDeploymentService(config=config)
        logger.info(f"Creating a new Seldon deployment service: {service}")

    # start the service which in turn provisions the Seldon Core
    # deployment server and waits for it to reach a ready state
    service.start(timeout=timeout)
    return service
find_model_server(self, running=False, service_uuid=None, pipeline_name=None, pipeline_run_id=None, pipeline_step_name=None, model_name=None, model_uri=None, model_type=None)
Find one or more Seldon Core model services that match the given criteria.
The Seldon Core deployment services that meet the search criteria are returned sorted in descending order of their creation time (i.e. more recent deployments first).
Parameters:
Name | Type | Description | Default |
---|---|---|---|
running |
bool |
if true, only running services will be returned. |
False |
service_uuid |
Optional[uuid.UUID] |
the UUID of the Seldon Core service that was originally used to create the Seldon Core deployment resource. |
None |
pipeline_name |
Optional[str] |
name of the pipeline that the deployed model was part of. |
None |
pipeline_run_id |
Optional[str] |
ID of the pipeline run which the deployed model was part of. |
None |
pipeline_step_name |
Optional[str] |
the name of the pipeline model deployment step that deployed the model. |
None |
model_name |
Optional[str] |
the name of the deployed model. |
None |
model_uri |
Optional[str] |
URI of the deployed model. |
None |
model_type |
Optional[str] |
the Seldon Core server implementation used to serve the model |
None |
Returns:
Type | Description |
---|---|
List[zenml.services.service.BaseService] |
One or more Seldon Core service objects representing Seldon Core model servers that match the input search criteria. |
Source code in zenml/integrations/seldon/model_deployers/seldon_model_deployer.py
def find_model_server(
    self,
    running: bool = False,
    service_uuid: Optional[UUID] = None,
    pipeline_name: Optional[str] = None,
    pipeline_run_id: Optional[str] = None,
    pipeline_step_name: Optional[str] = None,
    model_name: Optional[str] = None,
    model_uri: Optional[str] = None,
    model_type: Optional[str] = None,
) -> List[BaseService]:
    """Look up Seldon Core model services matching the given criteria.

    Matching deployments are ordered by creation time, newest first.
    Deployments without a creation timestamp are ordered last.

    Args:
        running: if true, only running services will be returned.
        service_uuid: the UUID of the Seldon Core service that was
            originally used to create the Seldon Core deployment resource.
        pipeline_name: name of the pipeline that the deployed model was
            part of.
        pipeline_run_id: ID of the pipeline run which the deployed model
            was part of.
        pipeline_step_name: the name of the pipeline model deployment step
            that deployed the model.
        model_name: the name of the deployed model.
        model_uri: URI of the deployed model.
        model_type: the Seldon Core server implementation used to serve
            the model.

    Returns:
        One or more Seldon Core service objects representing Seldon Core
        model servers that match the input search criteria.
    """

    def _creation_time(deployment) -> datetime:
        # deployments lacking a timestamp are treated as the oldest ones
        stamp = deployment.metadata.creationTimestamp
        if not stamp:
            return datetime.min
        return datetime.strptime(stamp, "%Y-%m-%dT%H:%M:%SZ")

    # A throwaway deployment configuration is only used here to derive the
    # label selector from the search criteria
    selector = SeldonDeploymentConfig(
        pipeline_name=pipeline_name or "",
        pipeline_run_id=pipeline_run_id or "",
        pipeline_step_name=pipeline_step_name or "",
        model_name=model_name or "",
        model_uri=model_uri or "",
        implementation=model_type or "",
    ).get_seldon_deployment_labels()
    if service_uuid:
        # the service UUID label is not produced by the deployment
        # configuration, so it has to be added by hand
        selector["zenml.service_uuid"] = str(service_uuid)

    matches = self.seldon_client.find_deployments(labels=selector)
    # newest deployments first
    matches.sort(key=_creation_time, reverse=True)

    # lazily rebuild a service object from every deployment resource
    candidates = (
        SeldonDeploymentService.create_from_deployment(deployment=match)
        for match in matches
    )
    # `is_running` is only consulted when the caller asked for running
    # services
    return [
        candidate
        for candidate in candidates
        if not running or candidate.is_running
    ]
get_active_model_deployer()
staticmethod
Get the Seldon Core model deployer registered in the active stack.
Returns:
Type | Description |
---|---|
SeldonModelDeployer |
The Seldon Core model deployer registered in the active stack. |
Exceptions:
Type | Description |
---|---|
TypeError |
if the Seldon Core model deployer is not available. |
Source code in zenml/integrations/seldon/model_deployers/seldon_model_deployer.py
@staticmethod
def get_active_model_deployer() -> "SeldonModelDeployer":
    """Fetch the Seldon Core model deployer of the active stack.

    Returns:
        The Seldon Core model deployer registered in the active stack.

    Raises:
        TypeError: if the active stack has no model deployer, or its model
            deployer is not a Seldon Core model deployer.
    """
    repository = Repository(  # type: ignore [call-arg]
        skip_repository_check=True
    )
    candidate = repository.active_stack.model_deployer
    if candidate and isinstance(candidate, SeldonModelDeployer):
        return candidate

    # no usable deployer: explain how to register one
    raise TypeError(
        f"The active stack needs to have a Seldon Core model deployer "
        f"component registered to be able to deploy models with Seldon "
        f"Core. You can create a new stack with a Seldon Core model "
        f"deployer component or update your existing stack to add this "
        f"component, e.g.:\n\n"
        f" 'zenml model-deployer register seldon --flavor={SELDON_MODEL_DEPLOYER_FLAVOR} "
        f"--kubernetes_context=context-name --kubernetes_namespace="
        f"namespace-name --base_url=https://ingress.cluster.kubernetes'\n"
        f" 'zenml stack create stack-name -d seldon ...'\n"
    )
get_model_server_info(service_instance)
staticmethod
Return implementation specific information that might be relevant to the user.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
service_instance |
SeldonDeploymentService |
Instance of a SeldonDeploymentService |
required |
Returns:
Type | Description |
---|---|
Dict[str, Optional[str]] |
Model server information. |
Source code in zenml/integrations/seldon/model_deployers/seldon_model_deployer.py
@staticmethod
def get_model_server_info(  # type: ignore[override]
    service_instance: "SeldonDeploymentService",
) -> Dict[str, Optional[str]]:
    """Collect the user-relevant details of a Seldon deployment service.

    Args:
        service_instance: Instance of a SeldonDeploymentService

    Returns:
        Model server information: the prediction URL, the model URI, the
        model name and the Seldon deployment name.
    """
    prediction_url = service_instance.prediction_url
    model_uri = service_instance.config.model_uri
    model_name = service_instance.config.model_name
    deployment_name = service_instance.seldon_deployment_name
    return dict(
        PREDICTION_URL=prediction_url,
        MODEL_URI=model_uri,
        MODEL_NAME=model_name,
        SELDON_DEPLOYMENT=deployment_name,
    )
start_model_server(self, uuid, timeout=300)
Start a Seldon Core model deployment server.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
uuid |
UUID |
UUID of the model server to start. |
required |
timeout |
int |
timeout in seconds to wait for the service to become active. If set to 0, the method will return immediately after provisioning the service, without waiting for it to become active. |
300 |
Exceptions:
Type | Description |
---|---|
NotImplementedError |
since we don't support starting Seldon Core model servers |
Source code in zenml/integrations/seldon/model_deployers/seldon_model_deployer.py
def start_model_server(
    self,
    uuid: UUID,
    timeout: int = DEFAULT_SELDON_DEPLOYMENT_START_STOP_TIMEOUT,
) -> None:
    """Reject a request to start a Seldon Core model deployment server.

    Args:
        uuid: UUID of the model server to start.
        timeout: timeout in seconds to wait for the service to become
            active. If set to 0, the method will return immediately after
            provisioning the service, without waiting for it to become
            active.

    Raises:
        NotImplementedError: always; starting Seldon Core model servers is
            not supported.
    """
    reason = "Starting Seldon Core model servers is not implemented"
    raise NotImplementedError(reason)
stop_model_server(self, uuid, timeout=300, force=False)
Stop a Seldon Core model server.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
uuid |
UUID |
UUID of the model server to stop. |
required |
timeout |
int |
timeout in seconds to wait for the service to stop. |
300 |
force |
bool |
if True, force the service to stop. |
False |
Exceptions:
Type | Description |
---|---|
NotImplementedError |
stopping Seldon Core model servers is not supported. |
Source code in zenml/integrations/seldon/model_deployers/seldon_model_deployer.py
def stop_model_server(
    self,
    uuid: UUID,
    timeout: int = DEFAULT_SELDON_DEPLOYMENT_START_STOP_TIMEOUT,
    force: bool = False,
) -> None:
    """Reject a request to stop a Seldon Core model server.

    Args:
        uuid: UUID of the model server to stop.
        timeout: timeout in seconds to wait for the service to stop.
        force: if True, force the service to stop.

    Raises:
        NotImplementedError: always; stopping Seldon Core model servers is
            not supported.
    """
    reason = (
        "Stopping Seldon Core model servers is not implemented. Try "
        "deleting the Seldon Core model server instead."
    )
    raise NotImplementedError(reason)
secret_schemas
special
Initialization for the Seldon secret schemas.
These are secret schemas that can be used to authenticate Seldon to the Artifact Store used to store served ML models.
secret_schemas
Implementation for Seldon secret schemas.
SeldonAzureSecretSchema (BaseSecretSchema)
pydantic-model
Seldon Azure Blob Storage credentials.
Based on: https://rclone.org/azureblob/
Attributes:
Name | Type | Description |
---|---|---|
rclone_config_azureblob_type |
Literal['azureblob'] |
the rclone config type. Must be set to "azureblob" for this schema. |
rclone_config_azureblob_account |
Optional[str] |
storage Account Name. Leave blank to use SAS URL or MSI. |
rclone_config_azureblob_key |
Optional[str] |
storage Account Key. Leave blank to use SAS URL or MSI. |
rclone_config_azureblob_sas_url |
Optional[str] |
SAS URL for container level access only. Leave blank if using account/key or MSI. |
rclone_config_azureblob_use_msi |
bool |
use a managed service identity to authenticate (only works in Azure). |
Source code in zenml/integrations/seldon/secret_schemas/secret_schemas.py
class SeldonAzureSecretSchema(BaseSecretSchema):
    """Seldon Azure Blob Storage credentials.

    Secret schema holding the rclone `azureblob` backend settings. All
    credential fields are optional; the account/key pair, the SAS URL and
    MSI are described as alternative ways to authenticate.

    Based on: https://rclone.org/azureblob/

    Attributes:
        rclone_config_azureblob_type: the rclone config type. Must be set to
            "azureblob" for this schema (this is also the default).
        rclone_config_azureblob_account: storage Account Name. Leave blank to
            use SAS URL or MSI.
        rclone_config_azureblob_key: storage Account Key. Leave blank to
            use SAS URL or MSI.
        rclone_config_azureblob_sas_url: SAS URL for container level access
            only. Leave blank if using account/key or MSI.
        rclone_config_azureblob_use_msi: use a managed service identity to
            authenticate (only works in Azure). Defaults to False.
    """

    # identifier of this secret schema type
    TYPE: ClassVar[str] = SELDON_AZUREBLOB_SECRET_SCHEMA_TYPE

    rclone_config_azureblob_type: Literal["azureblob"] = "azureblob"
    rclone_config_azureblob_account: Optional[str]
    rclone_config_azureblob_key: Optional[str]
    rclone_config_azureblob_sas_url: Optional[str]
    rclone_config_azureblob_use_msi: bool = False
SeldonGSSecretSchema (BaseSecretSchema)
pydantic-model
Seldon GCS credentials.
Based on: https://rclone.org/googlecloudstorage/
Attributes:
Name | Type | Description |
---|---|---|
rclone_config_gs_type |
Literal['google cloud storage'] |
the rclone config type. Must be set to "google cloud storage" for this schema. |
rclone_config_gs_client_id |
Optional[str] |
OAuth client id. |
rclone_config_gs_client_secret |
Optional[str] |
OAuth client secret. |
rclone_config_gs_token |
Optional[str] |
OAuth Access Token as a JSON blob. |
rclone_config_gs_project_number |
Optional[str] |
project number. |
rclone_config_gs_service_account_credentials |
Optional[str] |
service account credentials JSON blob. |
rclone_config_gs_anonymous |
bool |
access public buckets and objects without credentials. Set to True if you just want to download files and don't configure credentials. |
rclone_config_gs_auth_url |
Optional[str] |
auth server URL. |
Source code in zenml/integrations/seldon/secret_schemas/secret_schemas.py
class SeldonGSSecretSchema(BaseSecretSchema):
    """Seldon GCS credentials.

    Secret schema holding the rclone `google cloud storage` backend
    settings. All credential fields are optional.

    Based on: https://rclone.org/googlecloudstorage/

    Attributes:
        rclone_config_gs_type: the rclone config type. Must be set to "google
            cloud storage" for this schema (this is also the default).
        rclone_config_gs_client_id: OAuth client id.
        rclone_config_gs_client_secret: OAuth client secret.
        rclone_config_gs_token: OAuth Access Token as a JSON blob.
        rclone_config_gs_project_number: project number.
        rclone_config_gs_service_account_credentials: service account
            credentials JSON blob.
        rclone_config_gs_anonymous: access public buckets and objects without
            credentials. Set to True if you just want to download files and
            don't configure credentials. Defaults to False.
        rclone_config_gs_auth_url: auth server URL.
        rclone_config_gs_token_url: token server URL (presumably the rclone
            `token_url` option; confirm against the rclone documentation).
    """

    # identifier of this secret schema type
    TYPE: ClassVar[str] = SELDON_GS_SECRET_SCHEMA_TYPE

    rclone_config_gs_type: Literal[
        "google cloud storage"
    ] = "google cloud storage"
    rclone_config_gs_client_id: Optional[str]
    rclone_config_gs_client_secret: Optional[str]
    rclone_config_gs_project_number: Optional[str]
    rclone_config_gs_service_account_credentials: Optional[str]
    rclone_config_gs_anonymous: bool = False
    rclone_config_gs_token: Optional[str]
    rclone_config_gs_auth_url: Optional[str]
    rclone_config_gs_token_url: Optional[str]
SeldonS3SecretSchema (BaseSecretSchema)
pydantic-model
Seldon S3 credentials.
Based on: https://rclone.org/s3/#amazon-s3
Attributes:
Name | Type | Description |
---|---|---|
rclone_config_s3_type |
Literal['s3'] |
the rclone config type. Must be set to "s3" for this schema. |
rclone_config_s3_provider |
str |
the S3 provider (e.g. aws, ceph, minio). |
rclone_config_s3_env_auth |
bool |
get AWS credentials from EC2/ECS meta data (i.e. with IAM roles configuration). Only applies if access_key_id and secret_access_key are blank. |
rclone_config_s3_access_key_id |
Optional[str] |
AWS Access Key ID. |
rclone_config_s3_secret_access_key |
Optional[str] |
AWS Secret Access Key. |
rclone_config_s3_session_token |
Optional[str] |
AWS Session Token. |
rclone_config_s3_region |
Optional[str] |
region to connect to. |
rclone_config_s3_endpoint |
Optional[str] |
S3 API endpoint. |
Source code in zenml/integrations/seldon/secret_schemas/secret_schemas.py
class SeldonS3SecretSchema(BaseSecretSchema):
    """Seldon S3 credentials.

    Secret schema holding the rclone `s3` backend settings. All credential
    fields are optional.

    Based on: https://rclone.org/s3/#amazon-s3

    Attributes:
        rclone_config_s3_type: the rclone config type. Must be set to "s3" for
            this schema (this is also the default).
        rclone_config_s3_provider: the S3 provider (e.g. aws, ceph, minio).
            Defaults to "aws".
        rclone_config_s3_env_auth: get AWS credentials from EC2/ECS meta data
            (i.e. with IAM roles configuration). Only applies if access_key_id
            and secret_access_key are blank. Defaults to False.
        rclone_config_s3_access_key_id: AWS Access Key ID.
        rclone_config_s3_secret_access_key: AWS Secret Access Key.
        rclone_config_s3_session_token: AWS Session Token.
        rclone_config_s3_region: region to connect to.
        rclone_config_s3_endpoint: S3 API endpoint.
    """

    # identifier of this secret schema type
    TYPE: ClassVar[str] = SELDON_S3_SECRET_SCHEMA_TYPE

    rclone_config_s3_type: Literal["s3"] = "s3"
    rclone_config_s3_provider: str = "aws"
    rclone_config_s3_env_auth: bool = False
    rclone_config_s3_access_key_id: Optional[str]
    rclone_config_s3_secret_access_key: Optional[str]
    rclone_config_s3_session_token: Optional[str]
    rclone_config_s3_region: Optional[str]
    rclone_config_s3_endpoint: Optional[str]
seldon_client
Implementation of the Seldon client for ZenML.
SeldonClient
A client for interacting with Seldon Deployments.
Source code in zenml/integrations/seldon/seldon_client.py
class SeldonClient:
"""A client for interacting with Seldon Deployments."""
def __init__(self, context: Optional[str], namespace: Optional[str]):
    """Set up a Seldon Core client for a Kubernetes context and namespace.

    The Kubernetes API clients are initialized as part of construction.

    Args:
        context: the Kubernetes context to use.
        namespace: the Kubernetes namespace to use.
    """
    # remember the target cluster coordinates before the API clients are
    # created, because the client initialization reads both of them
    self._context, self._namespace = context, namespace
    self._initialize_k8s_clients()
def _initialize_k8s_clients(self) -> None:
    """Initialize the Kubernetes clients.

    The in-cluster configuration is tried first. If it cannot be loaded,
    the local kube config is loaded for the configured context instead; in
    that case the namespace must have been configured explicitly.

    Raises:
        SeldonClientError: if Kubernetes configuration could not be loaded,
            or if no namespace is configured when running outside of a
            cluster.
    """
    try:
        k8s_config.load_incluster_config()
        if not self._namespace:
            # load the namespace in the context of which the
            # current pod is running; the context manager makes sure the
            # file handle is closed instead of being leaked
            with open(
                "/var/run/secrets/kubernetes.io/serviceaccount/namespace"
            ) as namespace_file:
                self._namespace = namespace_file.read()
    except k8s_config.config_exception.ConfigException:
        # not running inside a cluster: fall back to the local kube config
        if not self._namespace:
            raise SeldonClientError(
                "The Kubernetes namespace must be explicitly "
                "configured when running outside of a cluster."
            )
        try:
            k8s_config.load_kube_config(
                context=self._context, persist_config=False
            )
        except k8s_config.config_exception.ConfigException as e:
            raise SeldonClientError(
                "Could not load the Kubernetes configuration"
            ) from e
    self._core_api = k8s_client.CoreV1Api()
    self._custom_objects_api = k8s_client.CustomObjectsApi()
@staticmethod
def sanitize_labels(labels: Dict[str, str]) -> None:
    """Rewrite label values in place so they are valid Kubernetes labels.

    See:
    https://kubernetes.io/docs/concepts/overview/working-with-objects/labels/#syntax-and-character-set

    Args:
        labels: the labels to sanitize.
    """
    # every run of characters outside the allowed set collapses to a single
    # underscore
    disallowed = re.compile(r"[^0-9a-zA-Z-_\.]+")
    for name in list(labels):
        replaced = disallowed.sub("_", labels[name])
        # Kubernetes label values are limited to 63 characters and must
        # begin and end with an alphanumeric character ([a-z0-9A-Z]), so
        # truncate first and then trim the separator characters
        labels[name] = replaced[:63].strip("-_.")
@property
def namespace(self) -> str:
    """Kubernetes namespace that this client operates in.

    Returns:
        The Kubernetes namespace in use by the client.

    Raises:
        RuntimeError: if the namespace has not been configured.
    """
    configured = self._namespace
    if configured:
        return configured
    # shouldn't happen if the client is initialized, but the explicit check
    # is needed to appease the mypy type checker
    raise RuntimeError("The Kubernetes namespace is not configured")
def create_deployment(
    self,
    deployment: SeldonDeployment,
    poll_timeout: int = 0,
) -> SeldonDeployment:
    """Create a Seldon Core deployment resource.

    Args:
        deployment: the Seldon Core deployment resource to create
        poll_timeout: the maximum time to wait for the deployment to become
            available or to fail. If set to 0, the function will return
            immediately without checking the deployment status. If a timeout
            occurs and the deployment is still pending creation, it will
            be returned anyway and no exception will be raised.

    Returns:
        the created Seldon Core deployment resource with updated status.

    Raises:
        SeldonDeploymentExistsError: if a deployment with the same name
            already exists.
        SeldonClientError: if an unknown error occurs during the creation of
            the deployment.
    """
    try:
        logger.debug(f"Creating SeldonDeployment resource: {deployment}")

        # tag the resource as ZenML-managed so that it can be told apart
        # from deployments that were not created by ZenML
        deployment.mark_as_managed_by_zenml()

        api_response = (
            self._custom_objects_api.create_namespaced_custom_object(
                group="machinelearning.seldon.io",
                version="v1",
                namespace=self._namespace,
                plural="seldondeployments",
                body=deployment.dict(exclude_none=True),
                _request_timeout=poll_timeout or None,
            )
        )
        logger.debug("Seldon Core API response: %s", api_response)
    except k8s_client.rest.ApiException as api_error:
        logger.error(
            "Exception when creating SeldonDeployment resource: %s",
            str(api_error),
        )
        if api_error.status != 409:
            raise SeldonClientError(
                "Exception when creating SeldonDeployment resource"
            ) from api_error
        # HTTP 409 (conflict) is reported as an already existing deployment
        raise SeldonDeploymentExistsError(
            f"A deployment with the name {deployment.name} "
            f"already exists in namespace {self._namespace}"
        )

    # re-read the resource every 5 seconds until it is no longer pending or
    # the polling budget is used up
    remaining = poll_timeout
    current = self.get_deployment(name=deployment.name)
    while remaining > 0 and current.is_pending():
        time.sleep(5)
        remaining -= 5
        current = self.get_deployment(name=deployment.name)
    return current
def delete_deployment(
    self,
    name: str,
    force: bool = False,
    poll_timeout: int = 0,
) -> None:
    """Remove a ZenML-managed Seldon Core deployment resource.

    The deployment is looked up first, so a deployment that does not exist
    or is not managed by ZenML makes `get_deployment` raise a
    SeldonDeploymentNotFoundError before anything is deleted.

    Args:
        name: the name of the Seldon Core deployment resource to delete.
        force: if True, the deletion is forced (the grace period is set to
            zero).
        poll_timeout: the maximum time to wait for the deployment to be
            deleted, consumed in steps of 5 seconds. If set to 0, the
            function returns immediately without checking the deployment
            status. If the timeout expires and the deployment still exists,
            this method returns and no exception is raised.

    Raises:
        SeldonClientError: if an unknown error occurs during the deployment
            removal.
    """
    request_timeout = poll_timeout or None
    grace_period = 0 if force else None

    try:
        logger.debug(f"Deleting SeldonDeployment resource: {name}")

        # existence / ownership check: raises
        # SeldonDeploymentNotFoundError for unknown or unmanaged resources
        self.get_deployment(name=name)

        api_response = (
            self._custom_objects_api.delete_namespaced_custom_object(
                group="machinelearning.seldon.io",
                version="v1",
                namespace=self._namespace,
                plural="seldondeployments",
                name=name,
                _request_timeout=request_timeout,
                grace_period_seconds=grace_period,
            )
        )
        logger.debug("Seldon Core API response: %s", api_response)
    except k8s_client.rest.ApiException as api_error:
        logger.error(
            "Exception when deleting SeldonDeployment resource %s: %s",
            name,
            str(api_error),
        )
        raise SeldonClientError(
            f"Exception when deleting SeldonDeployment resource {name}"
        ) from api_error

    # wait until the resource can no longer be fetched or the time budget
    # has been spent
    remaining = poll_timeout
    while remaining > 0:
        try:
            self.get_deployment(name=name)
        except SeldonDeploymentNotFoundError:
            break
        time.sleep(5)
        remaining -= 5
def update_deployment(
    self,
    deployment: SeldonDeployment,
    poll_timeout: int = 0,
) -> SeldonDeployment:
    """Update a Seldon Core deployment resource.

    The existing resource is looked up first, so updating a deployment that
    does not exist or is not managed by ZenML makes `get_deployment` raise a
    SeldonDeploymentNotFoundError before anything is patched.

    Args:
        deployment: the Seldon Core deployment resource to update
        poll_timeout: the maximum time, in seconds, to wait for the
            deployment to become available or to fail. The status is
            re-checked every 5 seconds. If set to 0, the function will
            return immediately without checking the deployment status. If a
            timeout occurs and the deployment is still pending, it will
            be returned anyway and no exception will be raised.

    Returns:
        the updated Seldon Core deployment resource with updated status.

    Raises:
        SeldonClientError: if an unknown error occurs while updating the
            deployment.
    """
    try:
        logger.debug(
            f"Updating SeldonDeployment resource: {deployment.name}"
        )

        # mark the deployment as managed by ZenML, to differentiate
        # between deployments that are created by ZenML and those that
        # are not
        deployment.mark_as_managed_by_zenml()

        # call `get_deployment` to check that the deployment exists
        # and is managed by ZenML. It will raise
        # a SeldonDeploymentNotFoundError otherwise
        self.get_deployment(name=deployment.name)

        response = self._custom_objects_api.patch_namespaced_custom_object(
            group="machinelearning.seldon.io",
            version="v1",
            namespace=self._namespace,
            plural="seldondeployments",
            name=deployment.name,
            body=deployment.dict(exclude_none=True),
            # a poll timeout of 0 means "no request timeout"
            _request_timeout=poll_timeout or None,
        )
        logger.debug("Seldon Core API response: %s", response)
    except k8s_client.rest.ApiException as e:
        logger.error(
            "Exception when updating SeldonDeployment resource: %s", str(e)
        )
        # report the operation that actually failed (this used to say
        # "creating", copied from `create_deployment`)
        raise SeldonClientError(
            "Exception when updating SeldonDeployment resource"
        ) from e

    # re-read the resource to pick up its status, then keep polling for as
    # long as it is pending and the timeout budget is not used up
    updated_deployment = self.get_deployment(name=deployment.name)
    while poll_timeout > 0 and updated_deployment.is_pending():
        time.sleep(5)
        poll_timeout -= 5
        updated_deployment = self.get_deployment(name=deployment.name)
    return updated_deployment
def get_deployment(self, name: str) -> SeldonDeployment:
    """Get a ZenML managed Seldon Core deployment resource by name.

    Args:
        name: the name of the Seldon Core deployment resource to fetch.

    Returns:
        The Seldon Core deployment resource.

    Raises:
        SeldonDeploymentNotFoundError: if the deployment resource cannot
            be found (API status 404), cannot be parsed into a
            SeldonDeployment, or is not managed by ZenML.
        SeldonClientError: if an unknown error occurs while fetching
            the deployment.
    """
    try:
        logger.debug(f"Retrieving SeldonDeployment resource: {name}")

        response = self._custom_objects_api.get_namespaced_custom_object(
            group="machinelearning.seldon.io",
            version="v1",
            namespace=self._namespace,
            plural="seldondeployments",
            name=name,
        )
        logger.debug("Seldon Core API response: %s", response)

        try:
            deployment = SeldonDeployment(**response)
        except ValidationError as e:
            logger.error(
                "Invalid Seldon Core deployment resource: %s\n%s",
                str(e),
                str(response),
            )
            # a resource that cannot be parsed is reported as "not found";
            # the validation error stays attached as the explicit cause
            raise SeldonDeploymentNotFoundError(
                f"SeldonDeployment resource {name} could not be parsed"
            ) from e

        # Only Seldon deployments managed by ZenML are returned
        if not deployment.is_managed_by_zenml():
            raise SeldonDeploymentNotFoundError(
                f"Seldon Deployment {name} is not managed by ZenML"
            )
        return deployment

    except k8s_client.rest.ApiException as e:
        if e.status == 404:
            raise SeldonDeploymentNotFoundError(
                f"SeldonDeployment resource not found: {name}"
            ) from e
        logger.error(
            "Exception when fetching SeldonDeployment resource %s: %s",
            name,
            str(e),
        )
        raise SeldonClientError(
            f"Unexpected exception when fetching SeldonDeployment "
            f"resource: {name}"
        ) from e
def find_deployments(
    self,
    name: Optional[str] = None,
    labels: Optional[Dict[str, str]] = None,
    fields: Optional[Dict[str, str]] = None,
) -> List[SeldonDeployment]:
    """Find all ZenML-managed Seldon Core deployment resources matching the given criteria.

    The `app=zenml` label is always added to the label selector so that
    only deployments managed by ZenML are returned. The dictionaries passed
    in by the caller are never modified.

    Args:
        name: optional name of the deployment resource to find. When set,
            it is added to the field selector as `metadata.name` (taking
            precedence over a `metadata.name` entry in `fields`).
        fields: optional selector to restrict the list of returned
            Seldon deployments by their fields. Defaults to everything.
        labels: optional selector to restrict the list of returned
            Seldon deployments by their labels. Defaults to everything.

    Returns:
        List of Seldon Core deployments that match the given criteria.
        Resources that cannot be parsed are logged and left out.

    Raises:
        SeldonClientError: if an unknown error occurs while fetching
            the deployments.
    """
    # work on copies: the selectors are extended below and the caller's
    # dictionaries must not change as a side effect
    fields = dict(fields or {})
    labels = dict(labels or {})

    # always filter results to only include Seldon deployments managed
    # by ZenML
    labels["app"] = "zenml"
    if name:
        # combine the name with the other field criteria instead of
        # silently dropping them
        fields["metadata.name"] = name

    field_selector = (
        ",".join(f"{k}={v}" for k, v in fields.items()) if fields else None
    )
    label_selector = (
        ",".join(f"{k}={v}" for k, v in labels.items()) if labels else None
    )

    try:
        logger.debug(
            f"Searching SeldonDeployment resources with label selector "
            f"'{labels or ''}' and field selector '{fields or ''}'"
        )
        response = self._custom_objects_api.list_namespaced_custom_object(
            group="machinelearning.seldon.io",
            version="v1",
            namespace=self._namespace,
            plural="seldondeployments",
            field_selector=field_selector,
            label_selector=label_selector,
        )
        # tolerate a missing or empty `items` entry in the response
        items = response.get("items") or []
        logger.debug("Seldon Core API returned %s items", len(items))

        deployments = []
        for item in items:
            try:
                deployments.append(SeldonDeployment(**item))
            except ValidationError as e:
                # skip resources that do not parse, but leave a trace
                logger.error(
                    "Invalid Seldon Core deployment resource: %s\n%s",
                    str(e),
                    str(item),
                )
        return deployments
    except k8s_client.rest.ApiException as e:
        logger.error(
            "Exception when searching SeldonDeployment resources with "
            "label selector '%s' and field selector '%s': %s",
            label_selector or "",
            field_selector or "",
            str(e),
        )
        raise SeldonClientError(
            f"Unexpected exception when searching SeldonDeployment "
            f"with labels '{labels or ''}' and field '{fields or ''}'"
        ) from e
def get_deployment_logs(
    self,
    name: str,
    follow: bool = False,
    tail: Optional[int] = None,
) -> Generator[str, bool, None]:
    """Get the logs of a Seldon Core deployment resource.

    The logs are read from the first pod labelled with
    `seldon-deployment-id=<name>`. The container named `default` is
    preferred, otherwise the first container of the pod is used. If the
    selected container has neither started nor restarted and the pod has
    init containers, the first init container is used instead.

    This is a generator function: nothing is requested from Kubernetes
    until the first log line is asked for.

    Args:
        name: the name of the Seldon Core deployment to get logs for.
        follow: if True, the logs will be streamed as they are written
        tail: only retrieve the last NUM lines of log output.

    Returns:
        A generator that can be accessed to get the service logs.

    Yields:
        The next log line, without its trailing newline. Sending a true
        value into the generator stops the stream.

    Raises:
        SeldonClientError: if no pods are found for the deployment or an
            unknown error occurs while fetching the logs.
    """
    logger.debug(f"Retrieving logs for SeldonDeployment resource: {name}")
    try:
        response = self._core_api.list_namespaced_pod(
            namespace=self._namespace,
            label_selector=f"seldon-deployment-id={name}",
        )
        logger.debug("Kubernetes API response: %s", response)
        pods = response.items
        if not pods:
            raise SeldonClientError(
                f"The Seldon Core deployment {name} is not currently "
                f"running: no Kubernetes pods associated with it were found"
            )
        pod = pods[0]
        pod_name = pod.metadata.name

        containers = [c.name for c in pod.spec.containers]
        # init containers and container statuses are optional in the pod
        # object and may be None (e.g. no init containers, or a pod that
        # has not been scheduled yet)
        init_containers = [c.name for c in pod.spec.init_containers or []]
        container_statuses = {
            c.name: c.started or c.restart_count
            for c in pod.status.container_statuses or []
        }

        container = "default"
        if container not in containers:
            container = containers[0]

        # some containers might not be running yet and have no logs to show,
        # so fall back to the first init container when there is one
        if not container_statuses.get(container) and init_containers:
            container = init_containers[0]

        logger.info(
            f"Retrieving logs for pod: `{pod_name}` and container "
            f"`{container}` in namespace `{self._namespace}`"
        )
        response = self._core_api.read_namespaced_pod_log(
            name=pod_name,
            namespace=self._namespace,
            container=container,
            follow=follow,
            tail_lines=tail,
            # keep the raw HTTP response so it can be read line by line
            _preload_content=False,
        )
    except k8s_client.rest.ApiException as e:
        logger.error(
            "Exception when fetching logs for SeldonDeployment resource "
            "%s: %s",
            name,
            str(e),
        )
        raise SeldonClientError(
            f"Unexpected exception when fetching logs for SeldonDeployment "
            f"resource: {name}"
        ) from e

    try:
        while True:
            raw_line = response.readline()
            if not raw_line:
                # only an empty read marks the end of the stream; a blank
                # log line still arrives as a newline and must not stop it
                return
            stop = yield raw_line.decode("utf-8").rstrip("\n")
            if stop:
                return
    finally:
        # always hand the connection back, even if the consumer stops early
        response.release_conn()
def create_or_update_secret(
    self,
    name: str,
    secret: BaseSecretSchema,
) -> None:
    """Store a ZenML secret as a Kubernetes Secret resource.

    Every entry of the ZenML secret whose value is not None is stored under
    its upper-cased key, with the value converted to a string and base64
    encoded. An existing Secret with the same name is replaced in place,
    otherwise a new one is created.

    Args:
        name: the name of the Secret resource to create.
        secret: a ZenML secret with key-values that should be
            stored in the Secret resource.

    Raises:
        SeldonClientError: if an unknown error occurs during the creation of
            the secret.
        k8s_client.rest.ApiException: re-raised internally for unexpected
            Kubernetes API errors; it is caught again by this method and
            reaches the caller as a SeldonClientError.
    """
    try:
        logger.debug(f"Creating Secret resource: {name}")

        encoded_values = {}
        for key, value in secret.content.items():
            if value is None:
                continue
            env_key = key.upper()
            raw_bytes = str(value).encode("utf-8")
            encoded_values[env_key] = base64.b64encode(raw_bytes).decode(
                "ascii"
            )

        k8s_secret = k8s_client.V1Secret(
            metadata=k8s_client.V1ObjectMeta(
                name=name,
                labels={"app": "zenml"},
            ),
            type="Opaque",
            data=encoded_values,
        )

        try:
            # probe for an existing secret; a 404 lands in the handler below
            self._core_api.read_namespaced_secret(
                name=name,
                namespace=self._namespace,
            )
            # the secret exists already: overwrite it
            api_response = self._core_api.replace_namespaced_secret(
                name=name,
                namespace=self._namespace,
                body=k8s_secret,
            )
        except k8s_client.rest.ApiException as lookup_error:
            if lookup_error.status != 404:
                # anything but "not found" is unexpected; the outer handler
                # turns it into a SeldonClientError
                raise
            api_response = self._core_api.create_namespaced_secret(
                namespace=self._namespace,
                body=k8s_secret,
            )
        logger.debug("Kubernetes API response: %s", api_response)
    except k8s_client.rest.ApiException as api_error:
        logger.error(
            "Exception when creating Secret resource: %s", str(api_error)
        )
        raise SeldonClientError(
            "Exception when creating Secret resource"
        ) from api_error
def delete_secret(
    self,
    name: str,
) -> None:
    """Remove a Kubernetes Secret resource managed by ZenML.

    A secret that is already gone (API status 404) is not treated as an
    error.

    Args:
        name: the name of the Kubernetes Secret resource to delete.

    Raises:
        SeldonClientError: if an unknown error occurs during the removal
            of the secret.
    """
    logger.debug(f"Deleting Secret resource: {name}")
    try:
        api_response = self._core_api.delete_namespaced_secret(
            name=name,
            namespace=self._namespace,
        )
    except k8s_client.rest.ApiException as api_error:
        if api_error.status != 404:
            logger.error(
                "Exception when deleting Secret resource %s: %s",
                name,
                str(api_error),
            )
            raise SeldonClientError(
                f"Exception when deleting Secret resource {name}"
            ) from api_error
        # 404: the secret is no longer present, nothing to do
    else:
        logger.debug("Kubernetes API response: %s", api_response)
namespace: str
property
readonly
Returns the Kubernetes namespace in use by the client.
Returns:
Type | Description |
---|---|
str |
The Kubernetes namespace in use by the client. |
Exceptions:
Type | Description |
---|---|
RuntimeError |
if the namespace has not been configured. |
__init__(self, context, namespace)
special
Initialize a Seldon Core client.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
context |
Optional[str] |
the Kubernetes context to use. |
required |
namespace |
Optional[str] |
the Kubernetes namespace to use. |
required |
Source code in zenml/integrations/seldon/seldon_client.py
def __init__(self, context: Optional[str], namespace: Optional[str]):
    """Set up a Seldon Core client.

    The context and namespace are stored on the instance and
    `_initialize_k8s_clients` is then called (presumably to build the
    Kubernetes API clients used by the other methods; it is defined
    elsewhere in the class).

    Args:
        context: the Kubernetes context to use.
        namespace: the Kubernetes namespace to use.
    """
    self._context, self._namespace = context, namespace
    self._initialize_k8s_clients()
create_deployment(self, deployment, poll_timeout=0)
Create a Seldon Core deployment resource.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
deployment |
SeldonDeployment |
the Seldon Core deployment resource to create |
required |
poll_timeout |
int |
the maximum time to wait for the deployment to become available or to fail. If set to 0, the function will return immediately without checking the deployment status. If a timeout occurs and the deployment is still pending creation, it will be returned anyway and no exception will be raised. |
0 |
Returns:
Type | Description |
---|---|
SeldonDeployment |
the created Seldon Core deployment resource with updated status. |
Exceptions:
Type | Description |
---|---|
SeldonDeploymentExistsError |
if a deployment with the same name already exists. |
SeldonClientError |
if an unknown error occurs during the creation of the deployment. |
Source code in zenml/integrations/seldon/seldon_client.py
def create_deployment(
    self,
    deployment: SeldonDeployment,
    poll_timeout: int = 0,
) -> SeldonDeployment:
    """Create a Seldon Core deployment resource.

    The deployment is marked as managed by ZenML and then submitted to the
    Kubernetes API as a `seldondeployments` custom object in the namespace
    used by this client.

    Args:
        deployment: the Seldon Core deployment resource to create
        poll_timeout: the maximum time, in seconds, to wait for the
            deployment to become available or to fail. The status is
            re-checked every 5 seconds. If set to 0, the function will
            return immediately without checking the deployment status. If a
            timeout occurs and the deployment is still pending creation, it
            will be returned anyway and no exception will be raised.

    Returns:
        the created Seldon Core deployment resource with updated status.

    Raises:
        SeldonDeploymentExistsError: if a deployment with the same name
            already exists (the API answered with status 409).
        SeldonClientError: if an unknown error occurs during the creation of
            the deployment.
    """
    try:
        logger.debug(f"Creating SeldonDeployment resource: {deployment}")

        # mark the deployment as managed by ZenML, to differentiate
        # between deployments that are created by ZenML and those that
        # are not
        deployment.mark_as_managed_by_zenml()

        response = self._custom_objects_api.create_namespaced_custom_object(
            group="machinelearning.seldon.io",
            version="v1",
            namespace=self._namespace,
            plural="seldondeployments",
            body=deployment.dict(exclude_none=True),
            # a poll timeout of 0 means "no request timeout"
            _request_timeout=poll_timeout or None,
        )
        logger.debug("Seldon Core API response: %s", response)
    except k8s_client.rest.ApiException as e:
        logger.error(
            "Exception when creating SeldonDeployment resource: %s", str(e)
        )
        if e.status == 409:
            # keep the original API error attached as the explicit cause,
            # like every other error raised by this client
            raise SeldonDeploymentExistsError(
                f"A deployment with the name {deployment.name} "
                f"already exists in namespace {self._namespace}"
            ) from e
        raise SeldonClientError(
            "Exception when creating SeldonDeployment resource"
        ) from e

    # re-read the resource to pick up its status, then keep polling for as
    # long as it is pending and the timeout budget is not used up
    created_deployment = self.get_deployment(name=deployment.name)
    while poll_timeout > 0 and created_deployment.is_pending():
        time.sleep(5)
        poll_timeout -= 5
        created_deployment = self.get_deployment(name=deployment.name)
    return created_deployment
create_or_update_secret(self, name, secret)
Create or update a Kubernetes Secret resource.
Uses the information contained in a ZenML secret.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
name |
str |
the name of the Secret resource to create. |
required |
secret |
BaseSecretSchema |
a ZenML secret with key-values that should be stored in the Secret resource. |
required |
Exceptions:
Type | Description |
---|---|
SeldonClientError |
if an unknown error occurs during the creation of the secret. |
k8s_client.rest.ApiException |
unexpected error. |
Source code in zenml/integrations/seldon/seldon_client.py
def create_or_update_secret(
    self,
    name: str,
    secret: BaseSecretSchema,
) -> None:
    """Store a ZenML secret as a Kubernetes Secret resource.

    Every entry of the ZenML secret whose value is not None is stored under
    its upper-cased key, with the value converted to a string and base64
    encoded. An existing Secret with the same name is replaced in place,
    otherwise a new one is created.

    Args:
        name: the name of the Secret resource to create.
        secret: a ZenML secret with key-values that should be
            stored in the Secret resource.

    Raises:
        SeldonClientError: if an unknown error occurs during the creation of
            the secret.
        k8s_client.rest.ApiException: re-raised internally for unexpected
            Kubernetes API errors; it is caught again by this method and
            reaches the caller as a SeldonClientError.
    """
    try:
        logger.debug(f"Creating Secret resource: {name}")

        encoded_values = {}
        for key, value in secret.content.items():
            if value is None:
                continue
            env_key = key.upper()
            raw_bytes = str(value).encode("utf-8")
            encoded_values[env_key] = base64.b64encode(raw_bytes).decode(
                "ascii"
            )

        k8s_secret = k8s_client.V1Secret(
            metadata=k8s_client.V1ObjectMeta(
                name=name,
                labels={"app": "zenml"},
            ),
            type="Opaque",
            data=encoded_values,
        )

        try:
            # probe for an existing secret; a 404 lands in the handler below
            self._core_api.read_namespaced_secret(
                name=name,
                namespace=self._namespace,
            )
            # the secret exists already: overwrite it
            api_response = self._core_api.replace_namespaced_secret(
                name=name,
                namespace=self._namespace,
                body=k8s_secret,
            )
        except k8s_client.rest.ApiException as lookup_error:
            if lookup_error.status != 404:
                # anything but "not found" is unexpected; the outer handler
                # turns it into a SeldonClientError
                raise
            api_response = self._core_api.create_namespaced_secret(
                namespace=self._namespace,
                body=k8s_secret,
            )
        logger.debug("Kubernetes API response: %s", api_response)
    except k8s_client.rest.ApiException as api_error:
        logger.error(
            "Exception when creating Secret resource: %s", str(api_error)
        )
        raise SeldonClientError(
            "Exception when creating Secret resource"
        ) from api_error
delete_deployment(self, name, force=False, poll_timeout=0)
Delete a Seldon Core deployment resource managed by ZenML.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
name |
str |
the name of the Seldon Core deployment resource to delete. |
required |
force |
bool |
if True, the deployment deletion will be forced (the grace period will be set to zero). |
False |
poll_timeout |
int |
the maximum time to wait for the deployment to be deleted. If set to 0, the function will return immediately without checking the deployment status. If a timeout occurs and the deployment still exists, this method will return and no exception will be raised. |
0 |
Exceptions:
Type | Description |
---|---|
SeldonClientError |
if an unknown error occurs during the deployment removal. |
Source code in zenml/integrations/seldon/seldon_client.py
def delete_deployment(
    self,
    name: str,
    force: bool = False,
    poll_timeout: int = 0,
) -> None:
    """Remove a ZenML-managed Seldon Core deployment resource.

    The deployment is looked up first, so a deployment that does not exist
    or is not managed by ZenML makes `get_deployment` raise a
    SeldonDeploymentNotFoundError before anything is deleted.

    Args:
        name: the name of the Seldon Core deployment resource to delete.
        force: if True, the deletion is forced (the grace period is set to
            zero).
        poll_timeout: the maximum time to wait for the deployment to be
            deleted, consumed in steps of 5 seconds. If set to 0, the
            function returns immediately without checking the deployment
            status. If the timeout expires and the deployment still exists,
            this method returns and no exception is raised.

    Raises:
        SeldonClientError: if an unknown error occurs during the deployment
            removal.
    """
    request_timeout = poll_timeout or None
    grace_period = 0 if force else None

    try:
        logger.debug(f"Deleting SeldonDeployment resource: {name}")

        # existence / ownership check: raises
        # SeldonDeploymentNotFoundError for unknown or unmanaged resources
        self.get_deployment(name=name)

        api_response = (
            self._custom_objects_api.delete_namespaced_custom_object(
                group="machinelearning.seldon.io",
                version="v1",
                namespace=self._namespace,
                plural="seldondeployments",
                name=name,
                _request_timeout=request_timeout,
                grace_period_seconds=grace_period,
            )
        )
        logger.debug("Seldon Core API response: %s", api_response)
    except k8s_client.rest.ApiException as api_error:
        logger.error(
            "Exception when deleting SeldonDeployment resource %s: %s",
            name,
            str(api_error),
        )
        raise SeldonClientError(
            f"Exception when deleting SeldonDeployment resource {name}"
        ) from api_error

    # wait until the resource can no longer be fetched or the time budget
    # has been spent
    remaining = poll_timeout
    while remaining > 0:
        try:
            self.get_deployment(name=name)
        except SeldonDeploymentNotFoundError:
            break
        time.sleep(5)
        remaining -= 5
delete_secret(self, name)
Delete a Kubernetes Secret resource managed by ZenML.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
name |
str |
the name of the Kubernetes Secret resource to delete. |
required |
Exceptions:
Type | Description |
---|---|
SeldonClientError |
if an unknown error occurs during the removal of the secret. |
Source code in zenml/integrations/seldon/seldon_client.py
def delete_secret(
    self,
    name: str,
) -> None:
    """Remove a Kubernetes Secret resource managed by ZenML.

    A secret that is already gone (API status 404) is not treated as an
    error.

    Args:
        name: the name of the Kubernetes Secret resource to delete.

    Raises:
        SeldonClientError: if an unknown error occurs during the removal
            of the secret.
    """
    logger.debug(f"Deleting Secret resource: {name}")
    try:
        api_response = self._core_api.delete_namespaced_secret(
            name=name,
            namespace=self._namespace,
        )
    except k8s_client.rest.ApiException as api_error:
        if api_error.status != 404:
            logger.error(
                "Exception when deleting Secret resource %s: %s",
                name,
                str(api_error),
            )
            raise SeldonClientError(
                f"Exception when deleting Secret resource {name}"
            ) from api_error
        # 404: the secret is no longer present, nothing to do
    else:
        logger.debug("Kubernetes API response: %s", api_response)
find_deployments(self, name=None, labels=None, fields=None)
Find all ZenML-managed Seldon Core deployment resources matching the given criteria.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
name |
Optional[str] |
optional name of the deployment resource to find. |
None |
fields |
Optional[Dict[str, str]] |
optional selector to restrict the list of returned Seldon deployments by their fields. Defaults to everything. |
None |
labels |
Optional[Dict[str, str]] |
optional selector to restrict the list of returned Seldon deployments by their labels. Defaults to everything. |
None |
Returns:
Type | Description |
---|---|
List[zenml.integrations.seldon.seldon_client.SeldonDeployment] |
List of Seldon Core deployments that match the given criteria. |
Exceptions:
Type | Description |
---|---|
SeldonClientError |
if an unknown error occurs while fetching the deployments. |
Source code in zenml/integrations/seldon/seldon_client.py
def find_deployments(
    self,
    name: Optional[str] = None,
    labels: Optional[Dict[str, str]] = None,
    fields: Optional[Dict[str, str]] = None,
) -> List[SeldonDeployment]:
    """Find all ZenML-managed Seldon Core deployment resources matching the given criteria.

    The `app=zenml` label is always added to the label selector so that
    only deployments managed by ZenML are returned. The dictionaries passed
    in by the caller are never modified.

    Args:
        name: optional name of the deployment resource to find. When set,
            it is added to the field selector as `metadata.name` (taking
            precedence over a `metadata.name` entry in `fields`).
        fields: optional selector to restrict the list of returned
            Seldon deployments by their fields. Defaults to everything.
        labels: optional selector to restrict the list of returned
            Seldon deployments by their labels. Defaults to everything.

    Returns:
        List of Seldon Core deployments that match the given criteria.
        Resources that cannot be parsed are logged and left out.

    Raises:
        SeldonClientError: if an unknown error occurs while fetching
            the deployments.
    """
    # work on copies: the selectors are extended below and the caller's
    # dictionaries must not change as a side effect
    fields = dict(fields or {})
    labels = dict(labels or {})

    # always filter results to only include Seldon deployments managed
    # by ZenML
    labels["app"] = "zenml"
    if name:
        # combine the name with the other field criteria instead of
        # silently dropping them
        fields["metadata.name"] = name

    field_selector = (
        ",".join(f"{k}={v}" for k, v in fields.items()) if fields else None
    )
    label_selector = (
        ",".join(f"{k}={v}" for k, v in labels.items()) if labels else None
    )

    try:
        logger.debug(
            f"Searching SeldonDeployment resources with label selector "
            f"'{labels or ''}' and field selector '{fields or ''}'"
        )
        response = self._custom_objects_api.list_namespaced_custom_object(
            group="machinelearning.seldon.io",
            version="v1",
            namespace=self._namespace,
            plural="seldondeployments",
            field_selector=field_selector,
            label_selector=label_selector,
        )
        # tolerate a missing or empty `items` entry in the response
        items = response.get("items") or []
        logger.debug("Seldon Core API returned %s items", len(items))

        deployments = []
        for item in items:
            try:
                deployments.append(SeldonDeployment(**item))
            except ValidationError as e:
                # skip resources that do not parse, but leave a trace
                logger.error(
                    "Invalid Seldon Core deployment resource: %s\n%s",
                    str(e),
                    str(item),
                )
        return deployments
    except k8s_client.rest.ApiException as e:
        logger.error(
            "Exception when searching SeldonDeployment resources with "
            "label selector '%s' and field selector '%s': %s",
            label_selector or "",
            field_selector or "",
            str(e),
        )
        raise SeldonClientError(
            f"Unexpected exception when searching SeldonDeployment "
            f"with labels '{labels or ''}' and field '{fields or ''}'"
        ) from e
get_deployment(self, name)
Get a ZenML managed Seldon Core deployment resource by name.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
name |
str |
the name of the Seldon Core deployment resource to fetch. |
required |
Returns:
Type | Description |
---|---|
SeldonDeployment |
The Seldon Core deployment resource. |
Exceptions:
Type | Description |
---|---|
SeldonDeploymentNotFoundError |
if the deployment resource cannot be found or is not managed by ZenML. |
SeldonClientError |
if an unknown error occurs while fetching the deployment. |
Source code in zenml/integrations/seldon/seldon_client.py
def get_deployment(self, name: str) -> SeldonDeployment:
    """Get a ZenML managed Seldon Core deployment resource by name.

    Args:
        name: the name of the Seldon Core deployment resource to fetch.

    Returns:
        The Seldon Core deployment resource.

    Raises:
        SeldonDeploymentNotFoundError: if the deployment resource cannot
            be found (API status 404), cannot be parsed into a
            SeldonDeployment, or is not managed by ZenML.
        SeldonClientError: if an unknown error occurs while fetching
            the deployment.
    """
    try:
        logger.debug(f"Retrieving SeldonDeployment resource: {name}")

        response = self._custom_objects_api.get_namespaced_custom_object(
            group="machinelearning.seldon.io",
            version="v1",
            namespace=self._namespace,
            plural="seldondeployments",
            name=name,
        )
        logger.debug("Seldon Core API response: %s", response)

        try:
            deployment = SeldonDeployment(**response)
        except ValidationError as e:
            logger.error(
                "Invalid Seldon Core deployment resource: %s\n%s",
                str(e),
                str(response),
            )
            # a resource that cannot be parsed is reported as "not found";
            # the validation error stays attached as the explicit cause
            raise SeldonDeploymentNotFoundError(
                f"SeldonDeployment resource {name} could not be parsed"
            ) from e

        # Only Seldon deployments managed by ZenML are returned
        if not deployment.is_managed_by_zenml():
            raise SeldonDeploymentNotFoundError(
                f"Seldon Deployment {name} is not managed by ZenML"
            )
        return deployment

    except k8s_client.rest.ApiException as e:
        if e.status == 404:
            raise SeldonDeploymentNotFoundError(
                f"SeldonDeployment resource not found: {name}"
            ) from e
        logger.error(
            "Exception when fetching SeldonDeployment resource %s: %s",
            name,
            str(e),
        )
        raise SeldonClientError(
            f"Unexpected exception when fetching SeldonDeployment "
            f"resource: {name}"
        ) from e
get_deployment_logs(self, name, follow=False, tail=None)
Get the logs of a Seldon Core deployment resource.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
name |
str |
the name of the Seldon Core deployment to get logs for. |
required |
follow |
bool |
if True, the logs will be streamed as they are written |
False |
tail |
Optional[int] |
only retrieve the last NUM lines of log output. |
None |
Returns:
Type | Description |
---|---|
Generator[str, bool, NoneType] |
A generator that can be accessed to get the service logs. |
Yields:
Type | Description |
---|---|
Generator[str, bool, NoneType] |
The next log line. |
Exceptions:
Type | Description |
---|---|
SeldonClientError |
if an unknown error occurs while fetching the logs. |
Source code in zenml/integrations/seldon/seldon_client.py
def get_deployment_logs(
    self,
    name: str,
    follow: bool = False,
    tail: Optional[int] = None,
) -> Generator[str, bool, None]:
    """Get the logs of a Seldon Core deployment resource.

    The logs are read from the first pod labelled with
    `seldon-deployment-id=<name>`. The container named `default` is
    preferred, otherwise the first container of the pod is used. If the
    selected container has neither started nor restarted and the pod has
    init containers, the first init container is used instead.

    This is a generator function: nothing is requested from Kubernetes
    until the first log line is asked for.

    Args:
        name: the name of the Seldon Core deployment to get logs for.
        follow: if True, the logs will be streamed as they are written
        tail: only retrieve the last NUM lines of log output.

    Returns:
        A generator that can be accessed to get the service logs.

    Yields:
        The next log line, without its trailing newline. Sending a true
        value into the generator stops the stream.

    Raises:
        SeldonClientError: if no pods are found for the deployment or an
            unknown error occurs while fetching the logs.
    """
    logger.debug(f"Retrieving logs for SeldonDeployment resource: {name}")
    try:
        response = self._core_api.list_namespaced_pod(
            namespace=self._namespace,
            label_selector=f"seldon-deployment-id={name}",
        )
        logger.debug("Kubernetes API response: %s", response)
        pods = response.items
        if not pods:
            raise SeldonClientError(
                f"The Seldon Core deployment {name} is not currently "
                f"running: no Kubernetes pods associated with it were found"
            )
        pod = pods[0]
        pod_name = pod.metadata.name

        containers = [c.name for c in pod.spec.containers]
        # init containers and container statuses are optional in the pod
        # object and may be None (e.g. no init containers, or a pod that
        # has not been scheduled yet)
        init_containers = [c.name for c in pod.spec.init_containers or []]
        container_statuses = {
            c.name: c.started or c.restart_count
            for c in pod.status.container_statuses or []
        }

        container = "default"
        if container not in containers:
            container = containers[0]

        # some containers might not be running yet and have no logs to show,
        # so fall back to the first init container when there is one
        if not container_statuses.get(container) and init_containers:
            container = init_containers[0]

        logger.info(
            f"Retrieving logs for pod: `{pod_name}` and container "
            f"`{container}` in namespace `{self._namespace}`"
        )
        response = self._core_api.read_namespaced_pod_log(
            name=pod_name,
            namespace=self._namespace,
            container=container,
            follow=follow,
            tail_lines=tail,
            # keep the raw HTTP response so it can be read line by line
            _preload_content=False,
        )
    except k8s_client.rest.ApiException as e:
        logger.error(
            "Exception when fetching logs for SeldonDeployment resource "
            "%s: %s",
            name,
            str(e),
        )
        raise SeldonClientError(
            f"Unexpected exception when fetching logs for SeldonDeployment "
            f"resource: {name}"
        ) from e

    try:
        while True:
            raw_line = response.readline()
            if not raw_line:
                # only an empty read marks the end of the stream; a blank
                # log line still arrives as a newline and must not stop it
                return
            stop = yield raw_line.decode("utf-8").rstrip("\n")
            if stop:
                return
    finally:
        # always hand the connection back, even if the consumer stops early
        response.release_conn()
sanitize_labels(labels)
staticmethod
Update the label values to be valid Kubernetes labels.
See: https://kubernetes.io/docs/concepts/overview/working-with-objects/labels/#syntax-and-character-set
Parameters:
Name | Type | Description | Default |
---|---|---|---|
labels |
Dict[str, str] |
the labels to sanitize. |
required |
Source code in zenml/integrations/seldon/seldon_client.py
@staticmethod
def sanitize_labels(labels: Dict[str, str]) -> None:
"""Update the label values to be valid Kubernetes labels.
See:
https://kubernetes.io/docs/concepts/overview/working-with-objects/labels/#syntax-and-character-set
Args:
labels: the labels to sanitize.
"""
for key, value in labels.items():
# Kubernetes labels must be alphanumeric, no longer than
# 63 characters, and must begin and end with an alphanumeric
# character ([a-z0-9A-Z])
labels[key] = re.sub(r"[^0-9a-zA-Z-_\.]+", "_", value)[:63].strip(
"-_."
)
update_deployment(self, deployment, poll_timeout=0)
Update a Seldon Core deployment resource.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
deployment |
SeldonDeployment |
the Seldon Core deployment resource to update |
required |
poll_timeout |
int |
the maximum time to wait for the deployment to become available or to fail. If set to 0, the function will return immediately without checking the deployment status. If a timeout occurs and the deployment is still pending, it will be returned anyway and no exception will be raised. |
0 |
Returns:
Type | Description |
---|---|
SeldonDeployment |
the updated Seldon Core deployment resource with updated status. |
Exceptions:
Type | Description |
---|---|
SeldonClientError |
if an unknown error occurs while updating the deployment. |
Source code in zenml/integrations/seldon/seldon_client.py
def update_deployment(
    self,
    deployment: SeldonDeployment,
    poll_timeout: int = 0,
) -> SeldonDeployment:
    """Update a Seldon Core deployment resource.

    Args:
        deployment: the Seldon Core deployment resource to update
        poll_timeout: the maximum time to wait for the deployment to become
            available or to fail. If set to 0, the function will return
            immediately without checking the deployment status. If a timeout
            occurs and the deployment is still pending creation, it will
            be returned anyway and no exception will be raised.

    Returns:
        the updated Seldon Core deployment resource with updated status.

    Raises:
        SeldonClientError: if an unknown error occurs while updating the
            deployment. The `get_deployment` lookup performed before
            patching raises a SeldonDeploymentNotFoundError (a subclass)
            when the deployment does not exist or is not managed by ZenML.
    """
    try:
        logger.debug(
            f"Updating SeldonDeployment resource: {deployment.name}"
        )

        # mark the deployment as managed by ZenML, to differentiate
        # between deployments that are created by ZenML and those that
        # are not
        deployment.mark_as_managed_by_zenml()

        # call `get_deployment` to check that the deployment exists
        # and is managed by ZenML. It will raise
        # a SeldonDeploymentNotFoundError otherwise
        self.get_deployment(name=deployment.name)

        response = self._custom_objects_api.patch_namespaced_custom_object(
            group="machinelearning.seldon.io",
            version="v1",
            namespace=self._namespace,
            plural="seldondeployments",
            name=deployment.name,
            body=deployment.dict(exclude_none=True),
            # a timeout of 0 means "do not wait", so pass None to the API
            # call instead of a zero request timeout
            _request_timeout=poll_timeout or None,
        )
        logger.debug("Seldon Core API response: %s", response)
    except k8s_client.rest.ApiException as e:
        logger.error(
            "Exception when updating SeldonDeployment resource: %s", str(e)
        )
        # Report the operation that actually failed: this is the update
        # path, not the create path.
        raise SeldonClientError(
            "Exception when updating SeldonDeployment resource"
        ) from e

    # Poll in 5 second steps until the deployment leaves the pending state
    # or the timeout budget is used up; a timeout is not an error here.
    updated_deployment = self.get_deployment(name=deployment.name)
    while poll_timeout > 0 and updated_deployment.is_pending():
        time.sleep(5)
        poll_timeout -= 5
        updated_deployment = self.get_deployment(name=deployment.name)
    return updated_deployment
SeldonClientError (Exception)
Base exception class for all exceptions raised by the SeldonClient.
Source code in zenml/integrations/seldon/seldon_client.py
class SeldonClientError(Exception):
    """Base exception class for all exceptions raised by the SeldonClient.

    Catching this type also catches its subclasses, such as
    SeldonClientTimeout, SeldonDeploymentExistsError and
    SeldonDeploymentNotFoundError.
    """
SeldonClientTimeout (SeldonClientError)
Raised when the Seldon client timed out while waiting for a resource to reach the expected status.
Source code in zenml/integrations/seldon/seldon_client.py
class SeldonClientTimeout(SeldonClientError):
    """Raised when the Seldon client timed out waiting on a resource.

    The timeout refers to waiting for a resource to reach the expected
    status.
    """
SeldonDeployment (BaseModel)
pydantic-model
A Seldon Core deployment CRD.
This is a Pydantic representation of some of the fields in the Seldon Core CRD (documented here: https://docs.seldon.io/projects/seldon-core/en/latest/reference/seldon-deployment.html).
Note that not all fields are represented, only those that are relevant to the ZenML integration. The fields that are not represented are silently ignored when the Seldon Deployment is created or updated from an external SeldonDeployment CRD representation.
Attributes:
Name | Type | Description |
---|---|---|
kind |
str |
Kubernetes kind field. |
apiVersion |
str |
Kubernetes apiVersion field. |
metadata |
SeldonDeploymentMetadata |
Kubernetes metadata field. |
spec |
SeldonDeploymentSpec |
Seldon Deployment spec entry. |
status |
Optional[zenml.integrations.seldon.seldon_client.SeldonDeploymentStatus] |
Seldon Deployment status. |
Source code in zenml/integrations/seldon/seldon_client.py
class SeldonDeployment(BaseModel):
    """Pydantic model of a Seldon Core deployment custom resource (CRD).

    Only the subset of the Seldon Core CRD fields that the ZenML integration
    relies on is modeled here (the full CRD is documented at
    https://docs.seldon.io/projects/seldon-core/en/latest/reference/seldon-deployment.html).
    Fields that are not represented are silently ignored when the Seldon
    Deployment is created or updated from an external SeldonDeployment CRD
    representation.

    Attributes:
        kind: Kubernetes kind field.
        apiVersion: Kubernetes apiVersion field.
        metadata: Kubernetes metadata field.
        spec: Seldon Deployment spec entry.
        status: Seldon Deployment status.
    """

    kind: str = Field(SELDON_DEPLOYMENT_KIND, const=True)
    apiVersion: str = Field(SELDON_DEPLOYMENT_API_VERSION, const=True)
    metadata: SeldonDeploymentMetadata = Field(
        default_factory=SeldonDeploymentMetadata
    )
    spec: SeldonDeploymentSpec = Field(default_factory=SeldonDeploymentSpec)
    status: Optional[SeldonDeploymentStatus]

    def __str__(self) -> str:
        """Render the Seldon Deployment as indented JSON.

        Returns:
            A string representation of the Seldon Deployment, with unset
            (None) fields left out.
        """
        payload = self.dict(exclude_none=True)
        return json.dumps(payload, indent=4)

    @classmethod
    def build(
        cls,
        name: Optional[str] = None,
        model_uri: Optional[str] = None,
        model_name: Optional[str] = None,
        implementation: Optional[str] = None,
        secret_name: Optional[str] = None,
        labels: Optional[Dict[str, str]] = None,
        annotations: Optional[Dict[str, str]] = None,
    ) -> "SeldonDeployment":
        """Build a basic Seldon Deployment object.

        Args:
            name: The name of the Seldon Deployment. If not explicitly
                passed, a name is autogenerated from the current time.
            model_uri: The URI of the model.
            model_name: The name of the model.
            implementation: The implementation of the model.
            secret_name: The name of the Kubernetes secret containing
                environment variable values (e.g. with credentials for the
                artifact store) to use with the deployment service.
            labels: A dictionary of labels to apply to the Seldon Deployment.
            annotations: A dictionary of annotations to apply to the Seldon
                Deployment.

        Returns:
            A minimal SeldonDeployment object built from the provided
            parameters: a single predictor whose graph is one MODEL unit
            named "default".
        """
        deployment_name = name or f"zenml-{time.time()}"
        label_map = labels if labels is not None else {}
        annotation_map = annotations if annotations is not None else {}

        metadata = SeldonDeploymentMetadata(
            name=deployment_name,
            labels=label_map,
            annotations=annotation_map,
        )
        model_unit = SeldonDeploymentPredictiveUnit(
            name="default",
            type=SeldonDeploymentPredictiveUnitType.MODEL,
            modelUri=model_uri or "",
            implementation=implementation or "",
            envSecretRefName=secret_name,
        )
        predictor = SeldonDeploymentPredictor(
            name=model_name or "",
            graph=model_unit,
        )
        spec = SeldonDeploymentSpec(
            name=deployment_name,
            predictors=[predictor],
        )
        return SeldonDeployment(metadata=metadata, spec=spec)

    def is_managed_by_zenml(self) -> bool:
        """Checks if this Seldon Deployment is managed by ZenML.

        The convention used to differentiate between SeldonDeployment
        instances that are managed by ZenML and those that are not is to set
        the `app` label value to `zenml`.

        Returns:
            True if the Seldon Deployment is managed by ZenML, False
            otherwise.
        """
        app_label = self.metadata.labels.get("app")
        return app_label == "zenml"

    def mark_as_managed_by_zenml(self) -> None:
        """Marks this Seldon Deployment as managed by ZenML.

        The convention used to differentiate between SeldonDeployment
        instances that are managed by ZenML and those that are not is to set
        the `app` label value to `zenml`.
        """
        deployment_labels = self.metadata.labels
        deployment_labels["app"] = "zenml"

    @property
    def name(self) -> str:
        """Returns the name of this Seldon Deployment.

        This is just a shortcut for `self.metadata.name`.

        Returns:
            The name of this Seldon Deployment.
        """
        metadata = self.metadata
        return metadata.name

    @property
    def state(self) -> SeldonDeploymentStatusState:
        """The state of the Seldon Deployment.

        Returns:
            The state reported in the status, or UNKNOWN when no status is
            set.
        """
        return (
            self.status.state
            if self.status
            else SeldonDeploymentStatusState.UNKNOWN
        )

    def is_pending(self) -> bool:
        """Checks if the Seldon Deployment is in a pending state.

        Returns:
            True if the Seldon Deployment is pending, False otherwise.
        """
        current_state = self.state
        return current_state == SeldonDeploymentStatusState.CREATING

    def is_available(self) -> bool:
        """Checks if the Seldon Deployment is in an available state.

        Returns:
            True if the Seldon Deployment is available, False otherwise.
        """
        current_state = self.state
        return current_state == SeldonDeploymentStatusState.AVAILABLE

    def is_failed(self) -> bool:
        """Checks if the Seldon Deployment is in a failed state.

        Returns:
            True if the Seldon Deployment is failed, False otherwise.
        """
        current_state = self.state
        return current_state == SeldonDeploymentStatusState.FAILED

    def get_error(self) -> Optional[str]:
        """Get a message describing the error, if in an error state.

        Returns:
            A message describing the error, if in an error state, otherwise
            None.
        """
        if not self.status or not self.is_failed():
            return None
        return self.status.description

    def get_pending_message(self) -> Optional[str]:
        """Get a message describing the pending conditions of the Seldon Deployment.

        Returns:
            The message of the first `Ready` condition whose status is
            false, or None, if no conditions are pending.
        """
        status = self.status
        if not status or not status.conditions:
            return None
        for condition in status.conditions:
            if condition.type == "Ready" and not condition.status:
                return condition.message
        return None

    class Config:
        """Pydantic configuration class."""

        # validate attribute assignments
        validate_assignment = True
        # Ignore extra attributes from the CRD that are not reflected here
        extra = "ignore"
name: str
property
readonly
Returns the name of this Seldon Deployment.
This is just a shortcut for `self.metadata.name`.
Returns:
Type | Description |
---|---|
str |
The name of this Seldon Deployment. |
state: SeldonDeploymentStatusState
property
readonly
The state of the Seldon Deployment.
Returns:
Type | Description |
---|---|
SeldonDeploymentStatusState |
The state of the Seldon Deployment. |
Config
Pydantic configuration class.
Source code in zenml/integrations/seldon/seldon_client.py
class Config:
    """Pydantic configuration class."""

    # validate attribute assignments, i.e. run field validation again when
    # an attribute is set on an existing instance
    validate_assignment = True
    # Ignore extra attributes from the CRD that are not reflected here
    extra = "ignore"
__str__(self)
special
Returns a string representation of the Seldon Deployment.
Returns:
Type | Description |
---|---|
str |
A string representation of the Seldon Deployment. |
Source code in zenml/integrations/seldon/seldon_client.py
def __str__(self) -> str:
    """Render the Seldon Deployment as indented JSON.

    Returns:
        A string representation of the Seldon Deployment, with unset (None)
        fields left out.
    """
    payload = self.dict(exclude_none=True)
    return json.dumps(payload, indent=4)
build(name=None, model_uri=None, model_name=None, implementation=None, secret_name=None, labels=None, annotations=None)
classmethod
Build a basic Seldon Deployment object.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
name |
Optional[str] |
The name of the Seldon Deployment. If not explicitly passed, a unique name is autogenerated. |
None |
model_uri |
Optional[str] |
The URI of the model. |
None |
model_name |
Optional[str] |
The name of the model. |
None |
implementation |
Optional[str] |
The implementation of the model. |
None |
secret_name |
Optional[str] |
The name of the Kubernetes secret containing environment variable values (e.g. with credentials for the artifact store) to use with the deployment service. |
None |
labels |
Optional[Dict[str, str]] |
A dictionary of labels to apply to the Seldon Deployment. |
None |
annotations |
Optional[Dict[str, str]] |
A dictionary of annotations to apply to the Seldon Deployment. |
None |
Returns:
Type | Description |
---|---|
SeldonDeployment |
A minimal SeldonDeployment object built from the provided parameters. |
Source code in zenml/integrations/seldon/seldon_client.py
@classmethod
def build(
    cls,
    name: Optional[str] = None,
    model_uri: Optional[str] = None,
    model_name: Optional[str] = None,
    implementation: Optional[str] = None,
    secret_name: Optional[str] = None,
    labels: Optional[Dict[str, str]] = None,
    annotations: Optional[Dict[str, str]] = None,
) -> "SeldonDeployment":
    """Build a basic Seldon Deployment object.

    Args:
        name: The name of the Seldon Deployment. If not explicitly passed,
            a name is autogenerated from the current time.
        model_uri: The URI of the model.
        model_name: The name of the model.
        implementation: The implementation of the model.
        secret_name: The name of the Kubernetes secret containing
            environment variable values (e.g. with credentials for the
            artifact store) to use with the deployment service.
        labels: A dictionary of labels to apply to the Seldon Deployment.
        annotations: A dictionary of annotations to apply to the Seldon
            Deployment.

    Returns:
        A minimal SeldonDeployment object built from the provided
        parameters: a single predictor whose graph is one MODEL unit named
        "default".
    """
    deployment_name = name or f"zenml-{time.time()}"
    label_map = labels if labels is not None else {}
    annotation_map = annotations if annotations is not None else {}

    metadata = SeldonDeploymentMetadata(
        name=deployment_name,
        labels=label_map,
        annotations=annotation_map,
    )
    model_unit = SeldonDeploymentPredictiveUnit(
        name="default",
        type=SeldonDeploymentPredictiveUnitType.MODEL,
        modelUri=model_uri or "",
        implementation=implementation or "",
        envSecretRefName=secret_name,
    )
    predictor = SeldonDeploymentPredictor(
        name=model_name or "",
        graph=model_unit,
    )
    spec = SeldonDeploymentSpec(
        name=deployment_name,
        predictors=[predictor],
    )
    return SeldonDeployment(metadata=metadata, spec=spec)
get_error(self)
Get a message describing the error, if in an error state.
Returns:
Type | Description |
---|---|
Optional[str] |
A message describing the error, if in an error state, otherwise None. |
Source code in zenml/integrations/seldon/seldon_client.py
def get_error(self) -> Optional[str]:
    """Get a message describing the error, if in an error state.

    Returns:
        The status description when a status is set and the deployment is
        failed, otherwise None.
    """
    if not self.status or not self.is_failed():
        return None
    return self.status.description
get_pending_message(self)
Get a message describing the pending conditions of the Seldon Deployment.
Returns:
Type | Description |
---|---|
Optional[str] |
A message describing the pending condition of the Seldon Deployment, or None, if no conditions are pending. |
Source code in zenml/integrations/seldon/seldon_client.py
def get_pending_message(self) -> Optional[str]:
    """Get a message describing the pending conditions of the Seldon Deployment.

    Returns:
        The message of the first `Ready` condition whose status is false,
        or None, if no conditions are pending.
    """
    status = self.status
    if not status or not status.conditions:
        return None
    for condition in status.conditions:
        if condition.type == "Ready" and not condition.status:
            return condition.message
    return None
is_available(self)
Checks if the Seldon Deployment is in an available state.
Returns:
Type | Description |
---|---|
bool |
True if the Seldon Deployment is available, False otherwise. |
Source code in zenml/integrations/seldon/seldon_client.py
def is_available(self) -> bool:
    """Checks if the Seldon Deployment is in an available state.

    Returns:
        True if the Seldon Deployment is available, False otherwise.
    """
    current_state = self.state
    return current_state == SeldonDeploymentStatusState.AVAILABLE
is_failed(self)
Checks if the Seldon Deployment is in a failed state.
Returns:
Type | Description |
---|---|
bool |
True if the Seldon Deployment is failed, False otherwise. |
Source code in zenml/integrations/seldon/seldon_client.py
def is_failed(self) -> bool:
    """Checks if the Seldon Deployment is in a failed state.

    Returns:
        True if the Seldon Deployment is failed, False otherwise.
    """
    current_state = self.state
    return current_state == SeldonDeploymentStatusState.FAILED
is_managed_by_zenml(self)
Checks if this Seldon Deployment is managed by ZenML.
The convention used to differentiate between SeldonDeployment instances
that are managed by ZenML and those that are not is to set the `app`
label value to `zenml`.
Returns:
Type | Description |
---|---|
bool |
True if the Seldon Deployment is managed by ZenML, False otherwise. |
Source code in zenml/integrations/seldon/seldon_client.py
def is_managed_by_zenml(self) -> bool:
    """Checks if this Seldon Deployment is managed by ZenML.

    The convention used to differentiate between SeldonDeployment instances
    that are managed by ZenML and those that are not is to set the `app`
    label value to `zenml`.

    Returns:
        True if the Seldon Deployment is managed by ZenML, False
        otherwise.
    """
    app_label = self.metadata.labels.get("app")
    return app_label == "zenml"
is_pending(self)
Checks if the Seldon Deployment is in a pending state.
Returns:
Type | Description |
---|---|
bool |
True if the Seldon Deployment is pending, False otherwise. |
Source code in zenml/integrations/seldon/seldon_client.py
def is_pending(self) -> bool:
    """Checks if the Seldon Deployment is in a pending state.

    Returns:
        True if the Seldon Deployment is pending, False otherwise.
    """
    current_state = self.state
    return current_state == SeldonDeploymentStatusState.CREATING
mark_as_managed_by_zenml(self)
Marks this Seldon Deployment as managed by ZenML.
The convention used to differentiate between SeldonDeployment instances
that are managed by ZenML and those that are not is to set the `app`
label value to `zenml`.
Source code in zenml/integrations/seldon/seldon_client.py
def mark_as_managed_by_zenml(self) -> None:
    """Marks this Seldon Deployment as managed by ZenML.

    The convention used to differentiate between SeldonDeployment instances
    that are managed by ZenML and those that are not is to set the `app`
    label value to `zenml`. Any other labels are left as they are.
    """
    deployment_labels = self.metadata.labels
    deployment_labels["app"] = "zenml"
SeldonDeploymentExistsError (SeldonClientError)
Raised when a SeldonDeployment resource cannot be created because a resource with the same name already exists.
Source code in zenml/integrations/seldon/seldon_client.py
class SeldonDeploymentExistsError(SeldonClientError):
    """Raised when a SeldonDeployment resource cannot be created.

    Specifically, the resource cannot be created because a resource with the
    same name already exists.
    """
SeldonDeploymentMetadata (BaseModel)
pydantic-model
Metadata for a Seldon Deployment.
Attributes:
Name | Type | Description |
---|---|---|
name |
str |
the name of the Seldon Deployment. |
labels |
Dict[str, str] |
Kubernetes labels for the Seldon Deployment. |
annotations |
Dict[str, str] |
Kubernetes annotations for the Seldon Deployment. |
creationTimestamp |
Optional[str] |
the creation timestamp of the Seldon Deployment. |
Source code in zenml/integrations/seldon/seldon_client.py
class SeldonDeploymentMetadata(BaseModel):
    """Metadata for a Seldon Deployment.

    Attributes:
        name: the name of the Seldon Deployment.
        labels: Kubernetes labels for the Seldon Deployment.
        annotations: Kubernetes annotations for the Seldon Deployment.
        creationTimestamp: the creation timestamp of the Seldon Deployment.
    """

    # The only field declared without a default and without Optional.
    name: str
    # default_factory builds a fresh empty dict for each instance instead of
    # sharing one mutable default between instances.
    labels: Dict[str, str] = Field(default_factory=dict)
    annotations: Dict[str, str] = Field(default_factory=dict)
    # NOTE(review): presumably spelled in camelCase to match the key of the
    # Kubernetes metadata object - confirm against the CRD.
    creationTimestamp: Optional[str]

    class Config:
        """Pydantic configuration class."""

        # validate attribute assignments
        validate_assignment = True
        # Ignore extra attributes from the CRD that are not reflected here
        extra = "ignore"
Config
Pydantic configuration class.
Source code in zenml/integrations/seldon/seldon_client.py
class Config:
    """Pydantic configuration class."""

    # validate attribute assignments, i.e. run field validation again when
    # an attribute is set on an existing instance
    validate_assignment = True
    # Ignore extra attributes from the CRD that are not reflected here
    extra = "ignore"
SeldonDeploymentNotFoundError (SeldonClientError)
Raised when a particular SeldonDeployment resource is not found or is not managed by ZenML.
Source code in zenml/integrations/seldon/seldon_client.py
class SeldonDeploymentNotFoundError(SeldonClientError):
    """Raised when a particular SeldonDeployment resource is unavailable.

    This covers both a resource that is not found and a resource that is
    not managed by ZenML.
    """
SeldonDeploymentPredictiveUnit (BaseModel)
pydantic-model
Seldon Deployment predictive unit.
Attributes:
Name | Type | Description |
---|---|---|
name |
str |
the name of the predictive unit. |
type |
Optional[zenml.integrations.seldon.seldon_client.SeldonDeploymentPredictiveUnitType] |
predictive unit type. |
implementation |
Optional[str] |
the Seldon Core implementation used to serve the model. |
modelUri |
Optional[str] |
URI of the model (or models) to serve. |
serviceAccountName |
Optional[str] |
the name of the service account to associate with the predictive unit container. |
envSecretRefName |
Optional[str] |
the name of a Kubernetes secret that contains environment variables (e.g. credentials) to be configured for the predictive unit container. |
children |
List[zenml.integrations.seldon.seldon_client.SeldonDeploymentPredictiveUnit] |
a list of child predictive units that together make up the model serving graph. |
Source code in zenml/integrations/seldon/seldon_client.py
class SeldonDeploymentPredictiveUnit(BaseModel):
    """Seldon Deployment predictive unit.

    A predictive unit is one node of the model serving graph; units nest
    through `children`.

    Attributes:
        name: the name of the predictive unit.
        type: predictive unit type.
        implementation: the Seldon Core implementation used to serve the model.
        modelUri: URI of the model (or models) to serve.
        serviceAccountName: the name of the service account to associate with
            the predictive unit container.
        envSecretRefName: the name of a Kubernetes secret that contains
            environment variables (e.g. credentials) to be configured for the
            predictive unit container.
        children: a list of child predictive units that together make up the
            model serving graph.
    """

    # The only field declared without a default and without Optional.
    name: str
    # Defaults to MODEL when no type is given.
    type: Optional[
        SeldonDeploymentPredictiveUnitType
    ] = SeldonDeploymentPredictiveUnitType.MODEL
    implementation: Optional[str]
    modelUri: Optional[str]
    serviceAccountName: Optional[str]
    envSecretRefName: Optional[str]
    # Quoted forward reference: the class refers to itself here, which is
    # what allows predictive units to be nested into a graph.
    children: List["SeldonDeploymentPredictiveUnit"] = Field(
        default_factory=list
    )

    class Config:
        """Pydantic configuration class."""

        # validate attribute assignments
        validate_assignment = True
        # Ignore extra attributes from the CRD that are not reflected here
        extra = "ignore"
Config
Pydantic configuration class.
Source code in zenml/integrations/seldon/seldon_client.py
class Config:
    """Pydantic configuration class."""

    # validate attribute assignments, i.e. run field validation again when
    # an attribute is set on an existing instance
    validate_assignment = True
    # Ignore extra attributes from the CRD that are not reflected here
    extra = "ignore"
SeldonDeploymentPredictiveUnitType (StrEnum)
Predictive unit types for a Seldon Deployment.
Source code in zenml/integrations/seldon/seldon_client.py
class SeldonDeploymentPredictiveUnitType(StrEnum):
    """Predictive unit types for a Seldon Deployment.

    Each member's value is identical to its name.
    """

    UNKNOWN_TYPE = "UNKNOWN_TYPE"
    ROUTER = "ROUTER"
    COMBINER = "COMBINER"
    MODEL = "MODEL"
    TRANSFORMER = "TRANSFORMER"
    OUTPUT_TRANSFORMER = "OUTPUT_TRANSFORMER"
SeldonDeploymentPredictor (BaseModel)
pydantic-model
Seldon Deployment predictor.
Attributes:
Name | Type | Description |
---|---|---|
name |
str |
the name of the predictor. |
replicas |
int |
the number of pod replicas for the predictor. |
graph |
SeldonDeploymentPredictiveUnit |
the serving graph composed of one or more predictive units. |
Source code in zenml/integrations/seldon/seldon_client.py
class SeldonDeploymentPredictor(BaseModel):
    """Seldon Deployment predictor.

    Attributes:
        name: the name of the predictor.
        replicas: the number of pod replicas for the predictor.
        graph: the serving graph composed of one or more predictive units.
    """

    name: str
    # A single replica unless stated otherwise.
    replicas: int = 1
    # NOTE(review): the factory calls SeldonDeploymentPredictiveUnit with no
    # arguments, but that model declares `name: str` without a default, so
    # omitting `graph` presumably fails validation - confirm.
    graph: SeldonDeploymentPredictiveUnit = Field(
        default_factory=SeldonDeploymentPredictiveUnit
    )

    class Config:
        """Pydantic configuration class."""

        # validate attribute assignments
        validate_assignment = True
        # Ignore extra attributes from the CRD that are not reflected here
        extra = "ignore"
Config
Pydantic configuration class.
Source code in zenml/integrations/seldon/seldon_client.py
class Config:
    """Pydantic configuration class."""

    # validate attribute assignments, i.e. run field validation again when
    # an attribute is set on an existing instance
    validate_assignment = True
    # Ignore extra attributes from the CRD that are not reflected here
    extra = "ignore"
SeldonDeploymentSpec (BaseModel)
pydantic-model
Spec for a Seldon Deployment.
Attributes:
Name | Type | Description |
---|---|---|
name |
str |
the name of the Seldon Deployment. |
protocol |
Optional[str] |
the API protocol used for the Seldon Deployment. |
predictors |
List[zenml.integrations.seldon.seldon_client.SeldonDeploymentPredictor] |
a list of predictors that make up the serving graph. |
replicas |
int |
the default number of pod replicas used for the predictors. |
Source code in zenml/integrations/seldon/seldon_client.py
class SeldonDeploymentSpec(BaseModel):
    """Spec for a Seldon Deployment.

    Attributes:
        name: the name of the Seldon Deployment.
        protocol: the API protocol used for the Seldon Deployment.
        predictors: a list of predictors that make up the serving graph.
        replicas: the default number of pod replicas used for the predictors.
    """

    name: str
    protocol: Optional[str]
    # Declared without a default, unlike `replicas` below.
    predictors: List[SeldonDeploymentPredictor]
    # A single replica unless stated otherwise.
    replicas: int = 1

    class Config:
        """Pydantic configuration class."""

        # validate attribute assignments
        validate_assignment = True
        # Ignore extra attributes from the CRD that are not reflected here
        extra = "ignore"
Config
Pydantic configuration class.
Source code in zenml/integrations/seldon/seldon_client.py
class Config:
    """Pydantic configuration class."""

    # validate attribute assignments, i.e. run field validation again when
    # an attribute is set on an existing instance
    validate_assignment = True
    # Ignore extra attributes from the CRD that are not reflected here
    extra = "ignore"
SeldonDeploymentStatus (BaseModel)
pydantic-model
The status of a Seldon Deployment.
Attributes:
Name | Type | Description |
---|---|---|
state |
SeldonDeploymentStatusState |
the current state of the Seldon Deployment. |
description |
Optional[str] |
a human-readable description of the current state. |
replicas |
Optional[int] |
the current number of running pod replicas |
address |
Optional[zenml.integrations.seldon.seldon_client.SeldonDeploymentStatusAddress] |
the address where the Seldon Deployment API can be accessed. |
conditions |
List[zenml.integrations.seldon.seldon_client.SeldonDeploymentStatusCondition] |
the list of Kubernetes conditions for the Seldon Deployment. |
Source code in zenml/integrations/seldon/seldon_client.py
class SeldonDeploymentStatus(BaseModel):
    """The status of a Seldon Deployment.

    Attributes:
        state: the current state of the Seldon Deployment.
        description: a human-readable description of the current state.
        replicas: the current number of running pod replicas.
        address: the address where the Seldon Deployment API can be accessed.
        conditions: the list of Kubernetes conditions for the Seldon Deployment.
    """

    # Falls back to UNKNOWN when the status carries no state.
    state: SeldonDeploymentStatusState = SeldonDeploymentStatusState.UNKNOWN
    description: Optional[str]
    replicas: Optional[int]
    address: Optional[SeldonDeploymentStatusAddress]
    # Declared without a default and without Optional, unlike the fields
    # above.
    conditions: List[SeldonDeploymentStatusCondition]

    class Config:
        """Pydantic configuration class."""

        # validate attribute assignments
        validate_assignment = True
        # Ignore extra attributes from the CRD that are not reflected here
        extra = "ignore"
Config
Pydantic configuration class.
Source code in zenml/integrations/seldon/seldon_client.py
class Config:
    """Pydantic configuration class."""

    # validate attribute assignments, i.e. run field validation again when
    # an attribute is set on an existing instance
    validate_assignment = True
    # Ignore extra attributes from the CRD that are not reflected here
    extra = "ignore"
SeldonDeploymentStatusAddress (BaseModel)
pydantic-model
The status address for a Seldon Deployment.
Attributes:
Name | Type | Description |
---|---|---|
url |
str |
the URL where the Seldon Deployment API can be accessed internally. |
Source code in zenml/integrations/seldon/seldon_client.py
class SeldonDeploymentStatusAddress(BaseModel):
    """The status address for a Seldon Deployment.

    Attributes:
        url: the URL where the Seldon Deployment API can be accessed internally.
    """

    # The model's single field.
    url: str
SeldonDeploymentStatusCondition (BaseModel)
pydantic-model
The Kubernetes status condition entry for a Seldon Deployment.
Attributes:
Name | Type | Description |
---|---|---|
type |
str |
Type of runtime condition. |
status |
bool |
Status of the condition. |
reason |
Optional[str] |
Brief CamelCase string containing reason for the condition's last transition. |
message |
Optional[str] |
Human-readable message indicating details about last transition. |
Source code in zenml/integrations/seldon/seldon_client.py
class SeldonDeploymentStatusCondition(BaseModel):
    """The Kubernetes status condition entry for a Seldon Deployment.

    Attributes:
        type: Type of runtime condition.
        status: Status of the condition.
        reason: Brief CamelCase string containing reason for the condition's
            last transition.
        message: Human-readable message indicating details about last
            transition.
    """

    type: str
    # Modeled as a boolean here.
    # NOTE(review): Kubernetes presumably reports condition status as a
    # string such as "True"/"False", relying on pydantic coercion - confirm.
    status: bool
    reason: Optional[str]
    message: Optional[str]
SeldonDeploymentStatusState (StrEnum)
Possible state values for a Seldon Deployment.
Source code in zenml/integrations/seldon/seldon_client.py
class SeldonDeploymentStatusState(StrEnum):
    """Possible state values for a Seldon Deployment.

    The member values are capitalized strings, unlike the upper-case member
    names.
    """

    UNKNOWN = "Unknown"
    AVAILABLE = "Available"
    CREATING = "Creating"
    FAILED = "Failed"
services
special
Initialization for Seldon services.
seldon_deployment
Implementation for the Seldon Deployment step.
SeldonDeploymentConfig (ServiceConfig)
pydantic-model
Seldon Core deployment service configuration.
Attributes:
Name | Type | Description |
---|---|---|
model_uri |
str |
URI of the model (or models) to serve. |
model_name |
str |
the name of the model. Multiple versions of the same model should use the same model name. |
implementation |
str |
the Seldon Core implementation used to serve the model. |
replicas |
int |
number of replicas to use for the prediction service. |
secret_name |
Optional[str] |
the name of a Kubernetes secret containing additional configuration parameters for the Seldon Core deployment (e.g. credentials to access the Artifact Store). |
model_metadata |
Dict[str, Any] |
optional model metadata information (see https://docs.seldon.io/projects/seldon-core/en/latest/reference/apis/metadata.html). |
extra_args |
Dict[str, Any] |
additional arguments to pass to the Seldon Core deployment resource configuration. |
Source code in zenml/integrations/seldon/services/seldon_deployment.py
class SeldonDeploymentConfig(ServiceConfig):
    """Seldon Core deployment service configuration.

    Attributes:
        model_uri: URI of the model (or models) to serve.
        model_name: the name of the model. Multiple versions of the same model
            should use the same model name.
        implementation: the Seldon Core implementation used to serve the model.
        replicas: number of replicas to use for the prediction service.
        secret_name: the name of a Kubernetes secret containing additional
            configuration parameters for the Seldon Core deployment (e.g.
            credentials to access the Artifact Store).
        model_metadata: optional model metadata information (see
            https://docs.seldon.io/projects/seldon-core/en/latest/reference/apis/metadata.html).
        extra_args: additional arguments to pass to the Seldon Core deployment
            resource configuration.
    """

    model_uri: str = ""
    model_name: str = "default"
    # TODO [ENG-775]: have an enum of all supported Seldon Core implementations
    implementation: str
    replicas: int = 1
    secret_name: Optional[str]
    model_metadata: Dict[str, Any] = Field(default_factory=dict)
    extra_args: Dict[str, Any] = Field(default_factory=dict)

    def get_seldon_deployment_labels(self) -> Dict[str, str]:
        """Generate labels for the Seldon Core deployment from the service configuration.

        These labels are attached to the Seldon Core deployment resource
        and may be used as label selectors in lookup operations. Only
        configuration values that are set (truthy) produce a label.

        Returns:
            The labels for the Seldon Core deployment.
        """
        label_sources = (
            ("zenml.pipeline_name", self.pipeline_name),
            ("zenml.pipeline_run_id", self.pipeline_run_id),
            ("zenml.pipeline_step_name", self.pipeline_step_name),
            ("zenml.model_name", self.model_name),
            ("zenml.model_uri", self.model_uri),
            ("zenml.model_type", self.implementation),
        )
        labels = {key: value for key, value in label_sources if value}
        # Raw values (e.g. URIs) may not be valid Kubernetes label values.
        SeldonClient.sanitize_labels(labels)
        return labels

    def get_seldon_deployment_annotations(self) -> Dict[str, str]:
        """Generate annotations for the Seldon Core deployment from the service configuration.

        The annotations are used to store additional information about the
        Seldon Core service that is associated with the deployment that is
        not available in the labels. One annotation particularly important
        is the serialized Service configuration itself, which is used to
        recreate the service configuration from a remote Seldon deployment.

        Returns:
            The annotations for the Seldon Core deployment.
        """
        return {
            "zenml.service_config": self.json(),
            "zenml.version": __version__,
        }

    @classmethod
    def create_from_deployment(
        cls, deployment: SeldonDeployment
    ) -> "SeldonDeploymentConfig":
        """Recreate the configuration of a Seldon Core Service from a deployed instance.

        Args:
            deployment: the Seldon Core deployment resource.

        Returns:
            The Seldon Core service configuration corresponding to the given
            Seldon Core deployment resource.

        Raises:
            ValueError: if the given deployment resource does not contain
                the expected annotations or it contains an invalid or
                incompatible Seldon Core service configuration.
        """
        serialized_config = deployment.metadata.annotations.get(
            "zenml.service_config"
        )
        if not serialized_config:
            raise ValueError(
                f"The given deployment resource does not contain a "
                f"'zenml.service_config' annotation: {deployment}"
            )

        try:
            return cls.parse_raw(serialized_config)
        except ValidationError as e:
            raise ValueError(
                f"The loaded Seldon Core deployment resource contains an "
                f"invalid or incompatible Seldon Core service configuration: "
                f"{serialized_config}"
            ) from e
create_from_deployment(deployment)
classmethod
Recreate the configuration of a Seldon Core Service from a deployed instance.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
deployment |
SeldonDeployment |
the Seldon Core deployment resource. |
required |
Returns:
Type | Description |
---|---|
SeldonDeploymentConfig |
The Seldon Core service configuration corresponding to the given Seldon Core deployment resource. |
Exceptions:
Type | Description |
---|---|
ValueError |
if the given deployment resource does not contain the expected annotations or it contains an invalid or incompatible Seldon Core service configuration. |
Source code in zenml/integrations/seldon/services/seldon_deployment.py
@classmethod
def create_from_deployment(
    cls, deployment: SeldonDeployment
) -> "SeldonDeploymentConfig":
    """Rebuild a service configuration from a deployed Seldon Core resource.

    The configuration is read from the `zenml.service_config` annotation of
    the deployment and parsed back into an instance of this class.

    Args:
        deployment: the Seldon Core deployment resource.

    Returns:
        The Seldon Core service configuration corresponding to the given
        Seldon Core deployment resource.

    Raises:
        ValueError: if the deployment does not carry the expected
            annotation, or the annotation holds an invalid or incompatible
            Seldon Core service configuration.
    """
    serialized_config = deployment.metadata.annotations.get(
        "zenml.service_config"
    )
    if not serialized_config:
        raise ValueError(
            "The given deployment resource does not contain a "
            f"'zenml.service_config' annotation: {deployment}"
        )
    try:
        return cls.parse_raw(serialized_config)
    except ValidationError as err:
        # Surface parsing problems as ValueError, keeping the cause chained.
        raise ValueError(
            "The loaded Seldon Core deployment resource contains an "
            "invalid or incompatible Seldon Core service configuration: "
            f"{serialized_config}"
        ) from err
get_seldon_deployment_annotations(self)
Generate annotations for the Seldon Core deployment from the service configuration.
The annotations are used to store additional information about the Seldon Core service that is associated with the deployment that is not available in the labels. One annotation particularly important is the serialized Service configuration itself, which is used to recreate the service configuration from a remote Seldon deployment.
Returns:
Type | Description |
---|---|
Dict[str, str] |
The annotations for the Seldon Core deployment. |
Source code in zenml/integrations/seldon/services/seldon_deployment.py
def get_seldon_deployment_annotations(self) -> Dict[str, str]:
    """Build the annotations attached to the Seldon Core deployment.

    Annotations carry information about the service that is not available
    in the labels. The most important one is the JSON-serialized service
    configuration, which is used to recreate the service configuration
    from a remote Seldon deployment.

    Returns:
        The annotations for the Seldon Core deployment.
    """
    return {
        "zenml.service_config": self.json(),
        "zenml.version": __version__,
    }
get_seldon_deployment_labels(self)
Generate labels for the Seldon Core deployment from the service configuration.
These labels are attached to the Seldon Core deployment resource and may be used as label selectors in lookup operations.
Returns:
Type | Description |
---|---|
Dict[str, str] |
The labels for the Seldon Core deployment. |
Source code in zenml/integrations/seldon/services/seldon_deployment.py
def get_seldon_deployment_labels(self) -> Dict[str, str]:
    """Build the labels attached to the Seldon Core deployment resource.

    Only configuration fields that hold a truthy value produce a label.
    The resulting labels may be used as label selectors in lookup
    operations.

    Returns:
        The labels for the Seldon Core deployment.
    """
    # (label key, configuration value) pairs, in the order they are emitted.
    label_sources = (
        ("zenml.pipeline_name", self.pipeline_name),
        ("zenml.pipeline_run_id", self.pipeline_run_id),
        ("zenml.pipeline_step_name", self.pipeline_step_name),
        ("zenml.model_name", self.model_name),
        ("zenml.model_uri", self.model_uri),
        ("zenml.model_type", self.implementation),
    )
    labels = {key: value for key, value in label_sources if value}
    # The helper's return value is not used, so it presumably adjusts the
    # dictionary in place -- confirm against SeldonClient.sanitize_labels.
    SeldonClient.sanitize_labels(labels)
    return labels
SeldonDeploymentService (BaseService)
pydantic-model
A service that represents a Seldon Core deployment server.
Attributes:
Name | Type | Description |
---|---|---|
config |
SeldonDeploymentConfig |
service configuration. |
status |
SeldonDeploymentServiceStatus |
service status. |
Source code in zenml/integrations/seldon/services/seldon_deployment.py
class SeldonDeploymentService(BaseService):
    """A service that represents a Seldon Core deployment server.

    Attributes:
        config: service configuration.
        status: service status.
    """

    SERVICE_TYPE = ServiceType(
        name="seldon-deployment",
        type="model-serving",
        flavor="seldon",
        description="Seldon Core prediction service",
    )

    config: SeldonDeploymentConfig = Field(
        default_factory=SeldonDeploymentConfig
    )
    status: SeldonDeploymentServiceStatus = Field(
        default_factory=SeldonDeploymentServiceStatus
    )

    def _get_client(self) -> SeldonClient:
        """Get the Seldon Core client from the active Seldon Core model deployer.

        Returns:
            The Seldon Core client.
        """
        # Imported locally rather than at module level (presumably to avoid
        # a circular import with the model deployer module).
        from zenml.integrations.seldon.model_deployers.seldon_model_deployer import (
            SeldonModelDeployer,
        )

        model_deployer = SeldonModelDeployer.get_active_model_deployer()
        return model_deployer.seldon_client

    def check_status(self) -> Tuple[ServiceState, str]:
        """Check the current operational state of the Seldon Core deployment.

        Returns:
            The operational state of the Seldon Core deployment and a message
            providing additional information about that state (e.g. a
            description of the error, if one is encountered).
        """
        client = self._get_client()
        name = self.seldon_deployment_name
        try:
            deployment = client.get_deployment(name=name)
        except SeldonDeploymentNotFoundError:
            # No remote deployment exists for this service.
            return (ServiceState.INACTIVE, "")

        if deployment.is_available():
            return (
                ServiceState.ACTIVE,
                f"Seldon Core deployment '{name}' is available",
            )

        if deployment.is_failed():
            return (
                ServiceState.ERROR,
                f"Seldon Core deployment '{name}' failed: "
                f"{deployment.get_error()}",
            )

        # Neither available nor failed: report the deployment as starting up.
        pending_message = deployment.get_pending_message() or ""
        return (
            ServiceState.PENDING_STARTUP,
            "Seldon Core deployment is being created: " + pending_message,
        )

    @property
    def seldon_deployment_name(self) -> str:
        """Get the name of the Seldon Core deployment.

        The name is derived from the service UUID, so it uniquely
        corresponds to this service instance.

        Returns:
            The name of the Seldon Core deployment.
        """
        return f"zenml-{str(self.uuid)}"

    def _get_seldon_deployment_labels(self) -> Dict[str, str]:
        """Generate the labels for the Seldon Core deployment from the service configuration.

        The labels produced by the configuration are extended with the
        `zenml.service_uuid` label identifying this service instance.

        Returns:
            The labels for the Seldon Core deployment.
        """
        labels = self.config.get_seldon_deployment_labels()
        labels["zenml.service_uuid"] = str(self.uuid)
        SeldonClient.sanitize_labels(labels)
        return labels

    @classmethod
    def create_from_deployment(
        cls, deployment: SeldonDeployment
    ) -> "SeldonDeploymentService":
        """Recreate a Seldon Core service from a Seldon Core deployment resource.

        The operational status of the recreated service is updated before
        it is returned.

        Args:
            deployment: the Seldon Core deployment resource.

        Returns:
            The Seldon Core service corresponding to the given
            Seldon Core deployment resource.

        Raises:
            ValueError: if the given deployment resource does not contain
                the expected service_uuid label.
        """
        config = SeldonDeploymentConfig.create_from_deployment(deployment)
        uuid = deployment.metadata.labels.get("zenml.service_uuid")
        if not uuid:
            raise ValueError(
                f"The given deployment resource does not contain a valid "
                f"'zenml.service_uuid' label: {deployment}"
            )
        service = cls(uuid=UUID(uuid), config=config)
        service.update_status()
        return service

    def provision(self) -> None:
        """Provision or update the remote Seldon Core deployment instance.

        The remote deployment is created, or updated if it already exists,
        from the current service configuration.
        """
        client = self._get_client()

        name = self.seldon_deployment_name

        deployment = SeldonDeployment.build(
            name=name,
            model_uri=self.config.model_uri,
            model_name=self.config.model_name,
            implementation=self.config.implementation,
            secret_name=self.config.secret_name,
            labels=self._get_seldon_deployment_labels(),
            annotations=self.config.get_seldon_deployment_annotations(),
        )
        # The replica count is set both on the deployment spec and on its
        # first predictor.
        deployment.spec.replicas = self.config.replicas
        deployment.spec.predictors[0].replicas = self.config.replicas

        # check if the Seldon deployment already exists
        try:
            client.get_deployment(name=name)
            # update the existing deployment
            client.update_deployment(deployment)
        except SeldonDeploymentNotFoundError:
            # create the deployment
            client.create_deployment(deployment=deployment)

    def deprovision(self, force: bool = False) -> None:
        """Deprovision the remote Seldon Core deployment instance.

        A deployment that no longer exists is silently ignored.

        Args:
            force: if True, the remote deployment instance will be
                forcefully deprovisioned.
        """
        client = self._get_client()
        name = self.seldon_deployment_name
        try:
            client.delete_deployment(name=name, force=force)
        except SeldonDeploymentNotFoundError:
            # Nothing to delete: the deployment is already gone.
            pass

    def get_logs(
        self,
        follow: bool = False,
        tail: Optional[int] = None,
    ) -> Generator[str, bool, None]:
        """Get the logs of a Seldon Core model deployment.

        Args:
            follow: if True, the logs will be streamed as they are written
            tail: only retrieve the last NUM lines of log output.

        Returns:
            A generator that can be accessed to get the service logs.
        """
        return self._get_client().get_deployment_logs(
            self.seldon_deployment_name,
            follow=follow,
            tail=tail,
        )

    @property
    def prediction_url(self) -> Optional[str]:
        """The prediction URI exposed by the prediction service.

        Returns:
            The prediction URI exposed by the prediction service, or None if
            the service is not yet ready.
        """
        # URLs always use forward slashes; `os.path.join` would insert the
        # platform separator (a backslash on Windows).
        import posixpath

        from zenml.integrations.seldon.model_deployers.seldon_model_deployer import (
            SeldonModelDeployer,
        )

        if not self.is_running:
            return None
        namespace = self._get_client().namespace
        model_deployer = SeldonModelDeployer.get_active_model_deployer()
        return posixpath.join(
            model_deployer.base_url,
            "seldon",
            namespace,
            self.seldon_deployment_name,
            "api/v0.1/predictions",
        )

    def predict(self, request: "NDArray[Any]") -> "NDArray[Any]":
        """Make a prediction using the service.

        Args:
            request: a numpy array representing the request

        Returns:
            A numpy array representing the prediction returned by the service.

        Raises:
            Exception: if the service is not yet ready.
            ValueError: if the prediction_url is not set.
        """
        if not self.is_running:
            raise Exception(
                "Seldon prediction service is not running. "
                "Please start the service before making predictions."
            )

        # Evaluate the property once: every access repeats the running
        # check and the client/model deployer lookups.
        prediction_url = self.prediction_url
        if prediction_url is None:
            raise ValueError("`self.prediction_url` is not set, cannot post.")
        response = requests.post(
            prediction_url,
            json={"data": {"ndarray": request.tolist()}},
        )
        response.raise_for_status()
        return np.array(response.json()["data"]["ndarray"])
prediction_url: Optional[str]
property
readonly
The prediction URI exposed by the prediction service.
Returns:
Type | Description |
---|---|
Optional[str] |
The prediction URI exposed by the prediction service, or None if the service is not yet ready. |
seldon_deployment_name: str
property
readonly
Get the name of the Seldon Core deployment.
It should return the one that uniquely corresponds to this service instance.
Returns:
Type | Description |
---|---|
str |
The name of the Seldon Core deployment. |
check_status(self)
Check the current operational state of the Seldon Core deployment.
Returns:
Type | Description |
---|---|
Tuple[zenml.services.service_status.ServiceState, str] |
The operational state of the Seldon Core deployment and a message providing additional information about that state (e.g. a description of the error, if one is encountered). |
Source code in zenml/integrations/seldon/services/seldon_deployment.py
def check_status(self) -> Tuple[ServiceState, str]:
    """Report the current operational state of the Seldon Core deployment.

    Returns:
        The operational state of the Seldon Core deployment together with
        a message giving more detail about that state (e.g. a description
        of the error, if one is encountered).
    """
    seldon_client = self._get_client()
    deployment_name = self.seldon_deployment_name
    try:
        deployment = seldon_client.get_deployment(name=deployment_name)
    except SeldonDeploymentNotFoundError:
        # No remote deployment exists for this service.
        return (ServiceState.INACTIVE, "")

    if deployment.is_available():
        state = ServiceState.ACTIVE
        message = f"Seldon Core deployment '{deployment_name}' is available"
    elif deployment.is_failed():
        state = ServiceState.ERROR
        message = (
            f"Seldon Core deployment '{deployment_name}' failed: "
            f"{deployment.get_error()}"
        )
    else:
        # Neither available nor failed: report the deployment as starting.
        state = ServiceState.PENDING_STARTUP
        message = "Seldon Core deployment is being created: " + (
            deployment.get_pending_message() or ""
        )
    return (state, message)
create_from_deployment(deployment)
classmethod
Recreate a Seldon Core service from a Seldon Core deployment resource.
The operational status of the recreated service is then updated.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
deployment |
SeldonDeployment |
the Seldon Core deployment resource. |
required |
Returns:
Type | Description |
---|---|
SeldonDeploymentService |
The Seldon Core service corresponding to the given Seldon Core deployment resource. |
Exceptions:
Type | Description |
---|---|
ValueError |
if the given deployment resource does not contain the expected service_uuid label. |
Source code in zenml/integrations/seldon/services/seldon_deployment.py
@classmethod
def create_from_deployment(
    cls, deployment: SeldonDeployment
) -> "SeldonDeploymentService":
    """Rebuild a Seldon Core service from a Seldon Core deployment resource.

    The service configuration is recovered from the deployment, the
    service UUID is read from the `zenml.service_uuid` label, and the
    operational status of the new service is updated before it is returned.

    Args:
        deployment: the Seldon Core deployment resource.

    Returns:
        The Seldon Core service corresponding to the given
        Seldon Core deployment resource.

    Raises:
        ValueError: if the given deployment resource does not contain
            the expected service_uuid label.
    """
    service_config = SeldonDeploymentConfig.create_from_deployment(deployment)
    service_uuid = deployment.metadata.labels.get("zenml.service_uuid")
    if not service_uuid:
        raise ValueError(
            "The given deployment resource does not contain a valid "
            f"'zenml.service_uuid' label: {deployment}"
        )
    recreated = cls(uuid=UUID(service_uuid), config=service_config)
    recreated.update_status()
    return recreated
deprovision(self, force=False)
Deprovision the remote Seldon Core deployment instance.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
force |
bool |
if True, the remote deployment instance will be forcefully deprovisioned. |
False |
Source code in zenml/integrations/seldon/services/seldon_deployment.py
def deprovision(self, force: bool = False) -> None:
    """Remove the remote Seldon Core deployment instance.

    A deployment that is not found is silently ignored.

    Args:
        force: if True, the remote deployment instance will be
            forcefully deprovisioned.
    """
    seldon_client = self._get_client()
    deployment_name = self.seldon_deployment_name
    try:
        seldon_client.delete_deployment(name=deployment_name, force=force)
    except SeldonDeploymentNotFoundError:
        # Nothing to delete: the deployment is already gone.
        return
get_logs(self, follow=False, tail=None)
Get the logs of a Seldon Core model deployment.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
follow |
bool |
if True, the logs will be streamed as they are written |
False |
tail |
Optional[int] |
only retrieve the last NUM lines of log output. |
None |
Returns:
Type | Description |
---|---|
Generator[str, bool, NoneType] |
A generator that can be accessed to get the service logs. |
Source code in zenml/integrations/seldon/services/seldon_deployment.py
def get_logs(
    self,
    follow: bool = False,
    tail: Optional[int] = None,
) -> Generator[str, bool, None]:
    """Fetch the logs of the Seldon Core model deployment.

    The call is delegated to the Seldon Core client, using the deployment
    name of this service.

    Args:
        follow: if True, the logs will be streamed as they are written
        tail: only retrieve the last NUM lines of log output.

    Returns:
        A generator that can be accessed to get the service logs.
    """
    seldon_client = self._get_client()
    deployment_name = self.seldon_deployment_name
    return seldon_client.get_deployment_logs(
        deployment_name, follow=follow, tail=tail
    )
predict(self, request)
Make a prediction using the service.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
request |
NDArray[Any] |
a numpy array representing the request |
required |
Returns:
Type | Description |
---|---|
NDArray[Any] |
A numpy array representing the prediction returned by the service. |
Exceptions:
Type | Description |
---|---|
Exception |
if the service is not yet ready. |
ValueError |
if the prediction_url is not set. |
Source code in zenml/integrations/seldon/services/seldon_deployment.py
def predict(self, request: "NDArray[Any]") -> "NDArray[Any]":
    """Make a prediction using the service.

    The request array is posted as JSON under `data.ndarray` to the
    prediction URL, and the `data.ndarray` field of the JSON response is
    returned as a numpy array.

    Args:
        request: a numpy array representing the request

    Returns:
        A numpy array representing the prediction returned by the service.

    Raises:
        Exception: if the service is not yet ready.
        ValueError: if the prediction_url is not set.
    """
    if not self.is_running:
        raise Exception(
            "Seldon prediction service is not running. "
            "Please start the service before making predictions."
        )

    # Read the property once so the URL that is checked is the URL that is
    # used, and the property is not evaluated a second time.
    prediction_url = self.prediction_url
    if prediction_url is None:
        raise ValueError("`self.prediction_url` is not set, cannot post.")
    response = requests.post(
        prediction_url,
        json={"data": {"ndarray": request.tolist()}},
    )
    # Propagate HTTP error statuses as exceptions.
    response.raise_for_status()
    return np.array(response.json()["data"]["ndarray"])
provision(self)
Provision or update remote Seldon Core deployment instance.
This should then match the current configuration.
Source code in zenml/integrations/seldon/services/seldon_deployment.py
def provision(self) -> None:
    """Create or update the remote Seldon Core deployment instance.

    The deployment resource is built from the current service
    configuration. It is used to update the remote deployment when one
    already exists, and to create it otherwise.
    """
    seldon_client = self._get_client()
    deployment_name = self.seldon_deployment_name

    seldon_deployment = SeldonDeployment.build(
        name=deployment_name,
        model_uri=self.config.model_uri,
        model_name=self.config.model_name,
        implementation=self.config.implementation,
        secret_name=self.config.secret_name,
        labels=self._get_seldon_deployment_labels(),
        annotations=self.config.get_seldon_deployment_annotations(),
    )
    # The replica count is set both on the deployment spec and on its
    # first predictor.
    seldon_deployment.spec.replicas = self.config.replicas
    seldon_deployment.spec.predictors[0].replicas = self.config.replicas

    try:
        # Probe for an existing deployment; a successful lookup means the
        # deployment has to be updated rather than created.
        seldon_client.get_deployment(name=deployment_name)
        seldon_client.update_deployment(seldon_deployment)
    except SeldonDeploymentNotFoundError:
        seldon_client.create_deployment(deployment=seldon_deployment)
SeldonDeploymentServiceStatus (ServiceStatus)
pydantic-model
Seldon Core deployment service status.
Source code in zenml/integrations/seldon/services/seldon_deployment.py
class SeldonDeploymentServiceStatus(ServiceStatus):
    """Status of a Seldon Core deployment service.

    Adds no fields or behavior of its own on top of `ServiceStatus`.
    """
steps
special
Initialization for Seldon steps.
seldon_deployer
Implementation of the Seldon Deployer step.
SeldonDeployerStepConfig (BaseStepConfig)
pydantic-model
Seldon model deployer step configuration.
Attributes:
Name | Type | Description |
---|---|---|
service_config |
SeldonDeploymentConfig |
Seldon Core deployment service configuration. |
secrets |
a list of ZenML secrets containing additional configuration parameters for the Seldon Core deployment (e.g. credentials to access the Artifact Store where the models are stored). If supplied, the information fetched from these secrets is passed to the Seldon Core deployment server as a list of environment variables. |
Source code in zenml/integrations/seldon/steps/seldon_deployer.py
class SeldonDeployerStepConfig(BaseStepConfig):
    """Seldon model deployer step configuration.

    Attributes:
        service_config: Seldon Core deployment service configuration.
        timeout: timeout value that the deployer step forwards when it
            starts an existing model server or deploys a new one
            (presumably expressed in seconds -- confirm against the model
            deployer API).
    """

    service_config: SeldonDeploymentConfig
    timeout: int = DEFAULT_SELDON_DEPLOYMENT_START_STOP_TIMEOUT
seldon_model_deployer_step (BaseStep)
Seldon Core model deployer pipeline step.
This step can be used in a pipeline to implement continuous deployment for a ML model with Seldon Core.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
deploy_decision |
whether to deploy the model or not |
required | |
config |
configuration for the deployer step |
required | |
model |
the model artifact to deploy |
required | |
context |
the step context |
required |
Returns:
Type | Description |
---|---|
Seldon Core deployment service |
CONFIG_CLASS (BaseStepConfig)
pydantic-model
Seldon model deployer step configuration.
Attributes:
Name | Type | Description |
---|---|---|
service_config |
SeldonDeploymentConfig |
Seldon Core deployment service configuration. |
secrets |
a list of ZenML secrets containing additional configuration parameters for the Seldon Core deployment (e.g. credentials to access the Artifact Store where the models are stored). If supplied, the information fetched from these secrets is passed to the Seldon Core deployment server as a list of environment variables. |
Source code in zenml/integrations/seldon/steps/seldon_deployer.py
class SeldonDeployerStepConfig(BaseStepConfig):
    """Seldon model deployer step configuration.

    Attributes:
        service_config: Seldon Core deployment service configuration.
        timeout: timeout value that the deployer step forwards when it
            starts an existing model server or deploys a new one
            (presumably expressed in seconds -- confirm against the model
            deployer API).
    """

    service_config: SeldonDeploymentConfig
    timeout: int = DEFAULT_SELDON_DEPLOYMENT_START_STOP_TIMEOUT
entrypoint(deploy_decision, config, context, model)
staticmethod
Seldon Core model deployer pipeline step.
This step can be used in a pipeline to implement continuous deployment for a ML model with Seldon Core.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
deploy_decision |
bool |
whether to deploy the model or not |
required |
config |
SeldonDeployerStepConfig |
configuration for the deployer step |
required |
model |
ModelArtifact |
the model artifact to deploy |
required |
context |
StepContext |
the step context |
required |
Returns:
Type | Description |
---|---|
SeldonDeploymentService |
Seldon Core deployment service |
Source code in zenml/integrations/seldon/steps/seldon_deployer.py
@step(enable_cache=False)
def seldon_model_deployer_step(
    deploy_decision: bool,
    config: SeldonDeployerStepConfig,
    context: StepContext,
    model: ModelArtifact,
) -> SeldonDeploymentService:
    """Seldon Core model deployer pipeline step.

    This step can be used in a pipeline to implement continuous
    deployment for a ML model with Seldon Core.

    Args:
        deploy_decision: whether to deploy the model or not
        config: configuration for the deployer step
        context: the step context
        model: the model artifact to deploy

    Returns:
        Seldon Core deployment service
    """
    model_deployer = SeldonModelDeployer.get_active_model_deployer()

    # get pipeline name, step name and run id
    step_env = cast(StepEnvironment, Environment()[STEP_ENVIRONMENT_NAME])
    pipeline_name = step_env.pipeline_name
    pipeline_run_id = step_env.pipeline_run_id
    step_name = step_env.step_name

    # update the step configuration with the real pipeline runtime information
    config.service_config.pipeline_name = pipeline_name
    config.service_config.pipeline_run_id = pipeline_run_id
    config.service_config.pipeline_step_name = step_name

    def prepare_service_config(model_uri: str) -> SeldonDeploymentConfig:
        """Prepare the model files for model serving.

        This creates and returns a Seldon service configuration for the model.

        This function ensures that the model files are in the correct format
        and file structure required by the Seldon Core server implementation
        used for model serving.

        Args:
            model_uri: the URI of the model artifact being served

        Returns:
            A copy of the step's service configuration whose `model_uri`
            points to the model files prepared for serving.

        Raises:
            RuntimeError: if the model files were not found
        """
        served_model_uri = os.path.join(
            context.get_output_artifact_uri(), "seldon"
        )
        fileio.makedirs(served_model_uri)

        # TODO [ENG-773]: determine how to formalize how models are organized into
        #   folders and sub-folders depending on the model type/format and the
        #   Seldon Core protocol used to serve the model.

        # TODO [ENG-791]: auto-detect built-in Seldon server implementation
        #   from the model artifact type

        # TODO [ENG-792]: validate the model artifact type against the
        #   supported built-in Seldon server implementations
        if config.service_config.implementation == "TENSORFLOW_SERVER":
            # the TensorFlow server expects model artifacts to be
            # stored in numbered subdirectories, each representing a model
            # version
            io_utils.copy_dir(model_uri, os.path.join(served_model_uri, "1"))
        elif config.service_config.implementation == "SKLEARN_SERVER":
            # the sklearn server expects model artifacts to be
            # stored in a file called model.joblib
            sklearn_model_file = os.path.join(model.uri, "model")
            # Check the file that is actually copied (and named in the error
            # message), not just the artifact directory that contains it.
            if not fileio.exists(sklearn_model_file):
                raise RuntimeError(
                    f"Expected sklearn model artifact was not found at "
                    f"{sklearn_model_file}"
                )
            fileio.copy(
                sklearn_model_file,
                os.path.join(served_model_uri, "model.joblib"),
            )
        else:
            # default treatment for all other server implementations is to
            # simply reuse the model from the artifact store path where it
            # is originally stored
            served_model_uri = model_uri

        service_config = config.service_config.copy()
        service_config.model_uri = served_model_uri
        return service_config

    # fetch existing services with same pipeline name, step name and
    # model name
    existing_services = model_deployer.find_model_server(
        pipeline_name=pipeline_name,
        pipeline_step_name=step_name,
        model_name=config.service_config.model_name,
    )

    # even when the deploy decision is negative, if an existing model server
    # is not running for this pipeline/step, we still have to serve the
    # current model, to ensure that a model server is available at all times
    if not deploy_decision and existing_services:
        logger.info(
            f"Skipping model deployment because the model quality does not "
            f"meet the criteria. Reusing last model server deployed by step "
            f"'{step_name}' and pipeline '{pipeline_name}' for model "
            f"'{config.service_config.model_name}'..."
        )
        service = cast(SeldonDeploymentService, existing_services[0])
        # even when the deploy decision is negative, we still need to start
        # the previous model server if it is no longer running, to ensure that
        # a model server is available at all times
        if not service.is_running:
            service.start(timeout=config.timeout)
        return service

    # invoke the Seldon Core model deployer to create a new service
    # or update an existing one that was previously deployed for the same
    # model
    service_config = prepare_service_config(model.uri)
    service = cast(
        SeldonDeploymentService,
        model_deployer.deploy_model(
            service_config, replace=True, timeout=config.timeout
        ),
    )

    logger.info(
        f"Seldon deployment service started and reachable at:\n"
        f"    {service.prediction_url}\n"
    )

    return service
sklearn
special
Initialization of the sklearn integration.
SklearnIntegration (Integration)
Definition of sklearn integration for ZenML.
Source code in zenml/integrations/sklearn/__init__.py
class SklearnIntegration(Integration):
    """ZenML integration definition for scikit-learn."""

    NAME = SKLEARN
    REQUIREMENTS = ["scikit-learn"]

    @classmethod
    def activate(cls) -> None:
        """Activate the integration by importing its materializers module."""
        # The imported name is never used; the import presumably exists so
        # that the sklearn materializers get registered -- confirm.
        from zenml.integrations.sklearn import materializers  # noqa
activate()
classmethod
Activates the integration.
Source code in zenml/integrations/sklearn/__init__.py
@classmethod
def activate(cls) -> None:
    """Activate the integration by importing its materializers module."""
    # The imported name is never used; the import presumably exists so that
    # the sklearn materializers get registered -- confirm.
    from zenml.integrations.sklearn import materializers  # noqa
helpers
special
Initialization for helper functions for the sklearn digits dataset.
digits
Helper functions for the sklearn digits dataset.
get_digits()
Returns the digits dataset in the form of a tuple of numpy arrays.
Returns:
Type | Description |
---|---|
Tuple[NDArray[np.float64], NDArray[np.float64], NDArray[np.int64], NDArray[np.int64]] |
Tuple of (training_images, testing_images, training_labels, testing_labels) |
Source code in zenml/integrations/sklearn/helpers/digits.py
def get_digits() -> Tuple[
    "NDArray[np.float64]",
    "NDArray[np.float64]",
    "NDArray[np.int64]",
    "NDArray[np.int64]",
]:
    """Load the digits dataset and split it into train and test halves.

    Every image is flattened into a single row, and the data is split
    50/50 without shuffling.

    Returns:
        Tuple of (training_images, testing_images, training_labels, testing_labels)
    """
    dataset = load_digits()
    images = dataset.images
    # One row per image: collapse all remaining dimensions into one.
    flattened = images.reshape((len(images), -1))
    train_images, test_images, train_labels, test_labels = train_test_split(
        flattened, dataset.target, test_size=0.5, shuffle=False
    )
    return train_images, test_images, train_labels, test_labels
get_digits_model()
Creates a support vector classifier for digits dataset.
Returns:
Type | Description |
---|---|
ClassifierMixin |
A support vector classifier. |
Source code in zenml/integrations/sklearn/helpers/digits.py
def get_digits_model() -> ClassifierMixin:
    """Build an untrained support vector classifier for the digits dataset.

    Returns:
        A support vector classifier.
    """
    classifier = SVC(gamma=0.001)
    return classifier
materializers
special
Initialization of the sklearn materializer.
sklearn_materializer
Implementation of the sklearn materializer.
SklearnMaterializer (BaseMaterializer)
Materializer to read data to and from sklearn.
Source code in zenml/integrations/sklearn/materializers/sklearn_materializer.py
class SklearnMaterializer(BaseMaterializer):
    """Materializer that stores and loads sklearn models as pickle files."""

    ASSOCIATED_TYPES = (
        BaseEstimator,
        ClassifierMixin,
        ClusterMixin,
        BiclusterMixin,
        OutlierMixin,
        RegressorMixin,
        MetaEstimatorMixin,
        MultiOutputMixin,
        DensityMixin,
        TransformerMixin,
    )
    ASSOCIATED_ARTIFACT_TYPES = (ModelArtifact,)

    def handle_input(
        self, data_type: Type[Any]
    ) -> Union[
        BaseEstimator,
        ClassifierMixin,
        ClusterMixin,
        BiclusterMixin,
        OutlierMixin,
        RegressorMixin,
        MetaEstimatorMixin,
        MultiOutputMixin,
        DensityMixin,
        TransformerMixin,
    ]:
        """Load a sklearn model from the pickle file in the artifact.

        Args:
            data_type: The type of the model.

        Returns:
            The model.
        """
        super().handle_input(data_type)
        model_path = os.path.join(self.artifact.uri, DEFAULT_FILENAME)
        # NOTE(review): unpickling executes arbitrary code; this is only
        # safe if the artifact store content is trusted -- confirm.
        with fileio.open(model_path, "rb") as model_file:
            return pickle.load(model_file)

    def handle_return(
        self,
        clf: Union[
            BaseEstimator,
            ClassifierMixin,
            ClusterMixin,
            BiclusterMixin,
            OutlierMixin,
            RegressorMixin,
            MetaEstimatorMixin,
            MultiOutputMixin,
            DensityMixin,
            TransformerMixin,
        ],
    ) -> None:
        """Write a sklearn model to a pickle file in the artifact.

        Args:
            clf: A sklearn model.
        """
        super().handle_return(clf)
        model_path = os.path.join(self.artifact.uri, DEFAULT_FILENAME)
        with fileio.open(model_path, "wb") as model_file:
            pickle.dump(clf, model_file)
handle_input(self, data_type)
Reads a base sklearn model from a pickle file.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
data_type |
Type[Any] |
The type of the model. |
required |
Returns:
Type | Description |
---|---|
Union[sklearn.base.BaseEstimator, sklearn.base.ClassifierMixin, sklearn.base.ClusterMixin, sklearn.base.BiclusterMixin, sklearn.base.OutlierMixin, sklearn.base.RegressorMixin, sklearn.base.MetaEstimatorMixin, sklearn.base.MultiOutputMixin, sklearn.base.DensityMixin, sklearn.base.TransformerMixin] |
The model. |
Source code in zenml/integrations/sklearn/materializers/sklearn_materializer.py
def handle_input(
    self, data_type: Type[Any]
) -> Union[
    BaseEstimator,
    ClassifierMixin,
    ClusterMixin,
    BiclusterMixin,
    OutlierMixin,
    RegressorMixin,
    MetaEstimatorMixin,
    MultiOutputMixin,
    DensityMixin,
    TransformerMixin,
]:
    """Load the pickled sklearn object from the artifact location.

    Args:
        data_type: The type of the model.

    Returns:
        The unpickled model.
    """
    super().handle_input(data_type)
    model_path = os.path.join(self.artifact.uri, DEFAULT_FILENAME)
    # NOTE(security): unpickling runs arbitrary code, so the artifact
    # location must only contain trusted data.
    with fileio.open(model_path, "rb") as model_file:
        return pickle.load(model_file)
handle_return(self, clf)
Creates a pickle for a sklearn model.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
clf |
Union[sklearn.base.BaseEstimator, sklearn.base.ClassifierMixin, sklearn.base.ClusterMixin, sklearn.base.BiclusterMixin, sklearn.base.OutlierMixin, sklearn.base.RegressorMixin, sklearn.base.MetaEstimatorMixin, sklearn.base.MultiOutputMixin, sklearn.base.DensityMixin, sklearn.base.TransformerMixin] |
A sklearn model. |
required |
Source code in zenml/integrations/sklearn/materializers/sklearn_materializer.py
def handle_return(
    self,
    clf: Union[
        BaseEstimator,
        ClassifierMixin,
        ClusterMixin,
        BiclusterMixin,
        OutlierMixin,
        RegressorMixin,
        MetaEstimatorMixin,
        MultiOutputMixin,
        DensityMixin,
        TransformerMixin,
    ],
) -> None:
    """Pickle the sklearn object into the artifact location.

    Args:
        clf: A sklearn model.
    """
    super().handle_return(clf)
    model_path = os.path.join(self.artifact.uri, DEFAULT_FILENAME)
    with fileio.open(model_path, "wb") as model_file:
        pickle.dump(clf, model_file)
steps
special
Initialization of the sklearn standard steps.
sklearn_evaluator
Implementation of the sklearn evaluator step.
SklearnEvaluator (BaseEvaluatorStep)
Simple sklearn evaluator step implementation.
This uses sklearn to evaluate the performance of a given model on a given test dataset.
Source code in zenml/integrations/sklearn/steps/sklearn_evaluator.py
class SklearnEvaluator(BaseEvaluatorStep):
    """Simple sklearn evaluator step implementation.

    This uses sklearn to evaluate the performance of a given model on a given
    test dataset.
    """

    def entrypoint(  # type: ignore[override]
        self,
        dataset: pd.DataFrame,
        model: BaseEstimator,
        config: SklearnEvaluatorConfig,
    ) -> dict:  # type: ignore[type-arg]
        """Method which is responsible for the computation of the evaluation.

        The label column is separated from the features without modifying
        the DataFrame that was passed in.

        Args:
            dataset: a pandas DataFrame which represents the test dataset
            model: a trained sklearn model
            config: the configuration for the step

        Returns:
            a dictionary which has the evaluation report
        """
        label_column = config.label_class_column
        labels = dataset[label_column]
        # `drop` returns a new frame, so the caller's dataset keeps its
        # label column (the previous `pop` removed it in place).
        features = dataset.drop(columns=[label_column])

        predictions = model.predict(features)
        # Binarize the predictions at 0.5.
        # NOTE(review): assumes a binary task with 0/1 labels -- confirm.
        predicted_classes = [1 if v > 0.5 else 0 for v in predictions]

        report = classification_report(
            labels, predicted_classes, output_dict=True
        )
        return report  # type: ignore[no-any-return]
CONFIG_CLASS (BaseEvaluatorConfig)
pydantic-model
Config class for the sklearn evaluator.
Source code in zenml/integrations/sklearn/steps/sklearn_evaluator.py
class SklearnEvaluatorConfig(BaseEvaluatorConfig):
    """Config class for the sklearn evaluator."""

    # Name of the dataset column that the evaluator step uses as labels.
    label_class_column: str
entrypoint(self, dataset, model, config)
Method which is responsible for the computation of the evaluation.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
dataset |
DataFrame |
a pandas DataFrame which represents the test dataset |
required |
model |
BaseEstimator |
a trained sklearn model |
required |
config |
SklearnEvaluatorConfig |
the configuration for the step |
required |
Returns:
Type | Description |
---|---|
dict |
a dictionary which has the evaluation report |
Source code in zenml/integrations/sklearn/steps/sklearn_evaluator.py
def entrypoint(  # type: ignore[override]
    self,
    dataset: pd.DataFrame,
    model: BaseEstimator,
    config: SklearnEvaluatorConfig,
) -> dict:  # type: ignore[type-arg]
    """Method which is responsible for the computation of the evaluation.

    The label column is separated from the features without modifying the
    DataFrame that was passed in.

    Args:
        dataset: a pandas DataFrame which represents the test dataset
        model: a trained sklearn model
        config: the configuration for the step

    Returns:
        a dictionary which has the evaluation report
    """
    label_column = config.label_class_column
    labels = dataset[label_column]
    # `drop` returns a new frame, so the caller's dataset keeps its label
    # column (the previous `pop` removed it in place).
    features = dataset.drop(columns=[label_column])

    predictions = model.predict(features)
    # Binarize the predictions at 0.5.
    # NOTE(review): assumes a binary task with 0/1 labels -- confirm.
    predicted_classes = [1 if v > 0.5 else 0 for v in predictions]

    report = classification_report(
        labels, predicted_classes, output_dict=True
    )
    return report  # type: ignore[no-any-return]
SklearnEvaluatorConfig (BaseEvaluatorConfig)
pydantic-model
Config class for the sklearn evaluator.
Source code in zenml/integrations/sklearn/steps/sklearn_evaluator.py
class SklearnEvaluatorConfig(BaseEvaluatorConfig):
    """Config class for the sklearn evaluator."""

    # Name of the dataset column that the evaluator step uses as labels.
    label_class_column: str
sklearn_splitter
Implementation of the sklearn splitter.
SklearnSplitter (BaseSplitStep)
A simple sklearn splitter step implementation.
This uses sklearn to split a given dataset into train, test and validation splits.
Source code in zenml/integrations/sklearn/steps/sklearn_splitter.py
class SklearnSplitter(BaseSplitStep):
    """A simple sklearn splitter step implementation.

    This uses sklearn to split a given dataset into train, test and validation
    splits.
    """

    def entrypoint(  # type: ignore[override]
        self,
        dataset: pd.DataFrame,
        config: SklearnSplitterConfig,
    ) -> Output(  # type:ignore[valid-type]
        train=pd.DataFrame, test=pd.DataFrame, validation=pd.DataFrame
    ):
        """Method which is responsible for the splitting logic.

        Args:
            dataset: a pandas DataFrame holding the entire dataset
            config: the configuration for the step

        Returns:
            three DataFrames representing the splits, in the order
            train, test, validation

        Raises:
            KeyError: if the wrong configuration is used
            ValueError: if the ratios are not valid
        """
        # Local import keeps this block self-contained.
        import math

        # The ratios dict must contain exactly these three keys.
        if set(config.ratios) != {"train", "test", "validation"}:
            raise KeyError(
                f"Make sure that you only use 'train', 'test' and "
                f"'validation' as keys in the ratios dict. Current keys: "
                f"{config.ratios.keys()}"
            )

        # Compare with a tolerance: e.g. 0.7 + 0.2 + 0.1 evaluates to
        # 0.9999999999999999, which an exact `!= 1` check would reject.
        if not math.isclose(sum(config.ratios.values()), 1.0):
            raise ValueError(
                f"Make sure that the ratios sum up to 1. Current "
                f"ratios: {config.ratios}"
            )

        # First carve off the test split from the full dataset ...
        train_dataset, test_dataset = train_test_split(
            dataset, test_size=config.ratios["test"]
        )
        # ... then split the remainder, rescaling the validation ratio so
        # it is relative to the remaining (train + validation) portion.
        train_dataset, val_dataset = train_test_split(
            train_dataset,
            test_size=(
                config.ratios["validation"]
                / (config.ratios["validation"] + config.ratios["train"])
            ),
        )
        return train_dataset, test_dataset, val_dataset
CONFIG_CLASS (BaseSplitStepConfig)
pydantic-model
Config class for the sklearn splitter.
Source code in zenml/integrations/sklearn/steps/sklearn_splitter.py
class SklearnSplitterConfig(BaseSplitStepConfig):
    """Config class for the sklearn splitter."""

    # Fraction of the dataset per split. The splitter step requires exactly
    # the keys "train", "test" and "validation", with values summing to 1.
    ratios: Dict[str, float]
entrypoint(self, dataset, config)
Method which is responsible for the splitting logic.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
dataset |
DataFrame |
a pandas DataFrame holding the entire dataset |
required |
config |
SklearnSplitterConfig |
the configuration for the step |
required |
Returns:
Type | Description |
---|---|
Output(train=DataFrame, test=DataFrame, validation=DataFrame) |
three DataFrames representing the splits |
Exceptions:
Type | Description |
---|---|
KeyError |
if the wrong configuration is used |
ValueError |
if the ratios are not valid |
Source code in zenml/integrations/sklearn/steps/sklearn_splitter.py
def entrypoint(  # type: ignore[override]
    self,
    dataset: pd.DataFrame,
    config: SklearnSplitterConfig,
) -> Output(  # type:ignore[valid-type]
    train=pd.DataFrame, test=pd.DataFrame, validation=pd.DataFrame
):
    """Method which is responsible for the splitting logic.

    Args:
        dataset: a pandas DataFrame holding the entire dataset
        config: the configuration for the step

    Returns:
        three DataFrames representing the splits, in the order
        train, test, validation

    Raises:
        KeyError: if the wrong configuration is used
        ValueError: if the ratios are not valid
    """
    # Local import keeps this block self-contained.
    import math

    # The ratios dict must contain exactly these three keys.
    if set(config.ratios) != {"train", "test", "validation"}:
        raise KeyError(
            f"Make sure that you only use 'train', 'test' and "
            f"'validation' as keys in the ratios dict. Current keys: "
            f"{config.ratios.keys()}"
        )

    # Compare with a tolerance: e.g. 0.7 + 0.2 + 0.1 evaluates to
    # 0.9999999999999999, which an exact `!= 1` check would reject.
    if not math.isclose(sum(config.ratios.values()), 1.0):
        raise ValueError(
            f"Make sure that the ratios sum up to 1. Current "
            f"ratios: {config.ratios}"
        )

    # First carve off the test split from the full dataset ...
    train_dataset, test_dataset = train_test_split(
        dataset, test_size=config.ratios["test"]
    )
    # ... then split the remainder, rescaling the validation ratio so it
    # is relative to the remaining (train + validation) portion.
    train_dataset, val_dataset = train_test_split(
        train_dataset,
        test_size=(
            config.ratios["validation"]
            / (config.ratios["validation"] + config.ratios["train"])
        ),
    )
    return train_dataset, test_dataset, val_dataset
SklearnSplitterConfig (BaseSplitStepConfig)
pydantic-model
Config class for the sklearn splitter.
Source code in zenml/integrations/sklearn/steps/sklearn_splitter.py
class SklearnSplitterConfig(BaseSplitStepConfig):
    """Config class for the sklearn splitter."""

    # Fraction of the dataset per split. The splitter step requires exactly
    # the keys "train", "test" and "validation", with values summing to 1.
    ratios: Dict[str, float]
sklearn_standard_scaler
Implementation of the sklearn standard scaler step.
SklearnStandardScaler (BasePreprocessorStep)
Simple StandardScaler step implementation.
This uses the StandardScaler from sklearn to transform the numeric columns of a pd.DataFrame.
Source code in zenml/integrations/sklearn/steps/sklearn_standard_scaler.py
class SklearnStandardScaler(BasePreprocessorStep):
    """Simple StandardScaler step implementation.

    This uses the StandardScaler from sklearn to transform the numeric columns
    of a pd.DataFrame.
    """

    def entrypoint(  # type: ignore[override]
        self,
        train_dataset: pd.DataFrame,
        test_dataset: pd.DataFrame,
        validation_dataset: pd.DataFrame,
        statistics: pd.DataFrame,
        schema: pd.DataFrame,
        config: SklearnStandardScalerConfig,
    ) -> Output(  # type:ignore[valid-type]
        train_transformed=pd.DataFrame,
        test_transformed=pd.DataFrame,
        validation_transformed=pd.DataFrame,
    ):
        """Main entrypoint function for the StandardScaler.

        The three datasets are scaled in place (their selected columns are
        overwritten) and then returned.

        Args:
            train_dataset: pd.DataFrame, the training dataset
            test_dataset: pd.DataFrame, the test dataset
            validation_dataset: pd.DataFrame, the validation dataset
            statistics: pd.DataFrame, the statistics over the train dataset
            schema: pd.DataFrame, the detected schema of the dataset
            config: the configuration for the step

        Returns:
            the transformed train, test and validation datasets as pd.DataFrames
        """
        schema_dict = {k: v[0] for k, v in schema.to_dict().items()}

        # Exclude columns
        feature_set = set(train_dataset.columns) - set(config.exclude_columns)
        for feature, feature_type in schema_dict.items():
            if feature_type != "int64" and feature_type != "float64":
                # A non-numeric column may already be excluded (or absent
                # from the train dataset); skip it instead of raising.
                if feature not in feature_set:
                    continue
                feature_set.remove(feature)
                logger.warning(
                    f"{feature} column is a not numeric, thus it is excluded "
                    f"from the standard scaling."
                )

        ignored = set(config.ignore_columns)
        # Use an ordered list (train column order) for all indexing: pandas
        # does not accept a set as an indexer, and a single list keeps the
        # statistics and the data columns aligned.
        transform_columns = [
            column
            for column in train_dataset.columns
            if column in feature_set and column not in ignored
        ]
        if not transform_columns:
            # Nothing to scale; hand the datasets back untouched.
            return train_dataset, test_dataset, validation_dataset

        # Transform the datasets
        scaler = StandardScaler()
        # The scaler is not fitted here; its parameters are taken from the
        # precomputed statistics.
        scaler.mean_ = statistics["mean"][transform_columns]
        scaler.scale_ = statistics["std"][transform_columns]

        train_dataset[transform_columns] = scaler.transform(
            train_dataset[transform_columns]
        )
        test_dataset[transform_columns] = scaler.transform(
            test_dataset[transform_columns]
        )
        validation_dataset[transform_columns] = scaler.transform(
            validation_dataset[transform_columns]
        )
        return train_dataset, test_dataset, validation_dataset
CONFIG_CLASS (BasePreprocessorConfig)
pydantic-model
Config class for the sklearn standard scaler.
ignore_columns: a list of column names which should not be scaled. exclude_columns: a list of column names to be excluded from the dataset.
Source code in zenml/integrations/sklearn/steps/sklearn_standard_scaler.py
class SklearnStandardScalerConfig(BasePreprocessorConfig):
    """Config class for the sklearn standard scaler.

    Attributes:
        ignore_columns: a list of column names which should not be scaled
        exclude_columns: a list of column names to be excluded from the
            dataset
    """

    # Both lists default to empty.
    ignore_columns: List[str] = []
    exclude_columns: List[str] = []
entrypoint(self, train_dataset, test_dataset, validation_dataset, statistics, schema, config)
Main entrypoint function for the StandardScaler.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
train_dataset |
DataFrame |
pd.DataFrame, the training dataset |
required |
test_dataset |
DataFrame |
pd.DataFrame, the test dataset |
required |
validation_dataset |
DataFrame |
pd.DataFrame, the validation dataset |
required |
statistics |
DataFrame |
pd.DataFrame, the statistics over the train dataset |
required |
schema |
DataFrame |
pd.DataFrame, the detected schema of the dataset |
required |
config |
SklearnStandardScalerConfig |
the configuration for the step |
required |
Returns:
Type | Description |
---|---|
Output(train_transformed=DataFrame, test_transformed=DataFrame, validation_transformed=DataFrame) |
the transformed train, test and validation datasets as pd.DataFrames |
Source code in zenml/integrations/sklearn/steps/sklearn_standard_scaler.py
def entrypoint(  # type: ignore[override]
    self,
    train_dataset: pd.DataFrame,
    test_dataset: pd.DataFrame,
    validation_dataset: pd.DataFrame,
    statistics: pd.DataFrame,
    schema: pd.DataFrame,
    config: SklearnStandardScalerConfig,
) -> Output(  # type:ignore[valid-type]
    train_transformed=pd.DataFrame,
    test_transformed=pd.DataFrame,
    validation_transformed=pd.DataFrame,
):
    """Main entrypoint function for the StandardScaler.

    The three datasets are scaled in place (their selected columns are
    overwritten) and then returned.

    Args:
        train_dataset: pd.DataFrame, the training dataset
        test_dataset: pd.DataFrame, the test dataset
        validation_dataset: pd.DataFrame, the validation dataset
        statistics: pd.DataFrame, the statistics over the train dataset
        schema: pd.DataFrame, the detected schema of the dataset
        config: the configuration for the step

    Returns:
        the transformed train, test and validation datasets as pd.DataFrames
    """
    schema_dict = {k: v[0] for k, v in schema.to_dict().items()}

    # Exclude columns
    feature_set = set(train_dataset.columns) - set(config.exclude_columns)
    for feature, feature_type in schema_dict.items():
        if feature_type != "int64" and feature_type != "float64":
            # A non-numeric column may already be excluded (or absent from
            # the train dataset); skip it instead of raising.
            if feature not in feature_set:
                continue
            feature_set.remove(feature)
            logger.warning(
                f"{feature} column is a not numeric, thus it is excluded "
                f"from the standard scaling."
            )

    ignored = set(config.ignore_columns)
    # Use an ordered list (train column order) for all indexing: pandas does
    # not accept a set as an indexer, and a single list keeps the statistics
    # and the data columns aligned.
    transform_columns = [
        column
        for column in train_dataset.columns
        if column in feature_set and column not in ignored
    ]
    if not transform_columns:
        # Nothing to scale; hand the datasets back untouched.
        return train_dataset, test_dataset, validation_dataset

    # Transform the datasets
    scaler = StandardScaler()
    # The scaler is not fitted here; its parameters are taken from the
    # precomputed statistics.
    scaler.mean_ = statistics["mean"][transform_columns]
    scaler.scale_ = statistics["std"][transform_columns]

    train_dataset[transform_columns] = scaler.transform(
        train_dataset[transform_columns]
    )
    test_dataset[transform_columns] = scaler.transform(
        test_dataset[transform_columns]
    )
    validation_dataset[transform_columns] = scaler.transform(
        validation_dataset[transform_columns]
    )
    return train_dataset, test_dataset, validation_dataset
SklearnStandardScalerConfig (BasePreprocessorConfig)
pydantic-model
Config class for the sklearn standard scaler.
ignore_columns: a list of column names which should not be scaled. exclude_columns: a list of column names to be excluded from the dataset.
Source code in zenml/integrations/sklearn/steps/sklearn_standard_scaler.py
class SklearnStandardScalerConfig(BasePreprocessorConfig):
    """Config class for the sklearn standard scaler.

    Attributes:
        ignore_columns: a list of column names which should not be scaled
        exclude_columns: a list of column names to be excluded from the
            dataset
    """

    # Both lists default to empty.
    ignore_columns: List[str] = []
    exclude_columns: List[str] = []
slack
special
Slack integration for alerter components.
SlackIntegration (Integration)
Definition of a Slack integration for ZenML.
Implemented using Slack SDK.
Source code in zenml/integrations/slack/__init__.py
class SlackIntegration(Integration):
    """Slack integration for ZenML.

    Built on top of the [Slack SDK](https://pypi.org/project/slack-sdk/).
    """

    NAME = SLACK
    REQUIREMENTS = ["slack-sdk>=3.16.1", "aiohttp>=3.8.1"]

    @classmethod
    def flavors(cls) -> List[FlavorWrapper]:
        """Declare the stack component flavors for the Slack integration.

        Returns:
            List of new flavors defined by the Slack integration.
        """
        # The integration contributes a single flavor: the Slack alerter.
        alerter_flavor = FlavorWrapper(
            name=SLACK_ALERTER_FLAVOR,
            source="zenml.integrations.slack.alerters.slack_alerter.SlackAlerter",
            type=StackComponentType.ALERTER,
            integration=cls.NAME,
        )
        return [alerter_flavor]
flavors()
classmethod
Declare the stack component flavors for the Slack integration.
Returns:
Type | Description |
---|---|
List[zenml.zen_stores.models.flavor_wrapper.FlavorWrapper] |
List of new flavors defined by the Slack integration. |
Source code in zenml/integrations/slack/__init__.py
@classmethod
def flavors(cls) -> List[FlavorWrapper]:
    """Declare the stack component flavors for the Slack integration.

    Returns:
        List of new flavors defined by the Slack integration.
    """
    # The integration contributes a single flavor: the Slack alerter.
    alerter_flavor = FlavorWrapper(
        name=SLACK_ALERTER_FLAVOR,
        source="zenml.integrations.slack.alerters.slack_alerter.SlackAlerter",
        type=StackComponentType.ALERTER,
        integration=cls.NAME,
    )
    return [alerter_flavor]
alerters
special
Alerter components defined by the Slack integration.
slack_alerter
Implementation for slack flavor of alerter component.
SlackAlerter (BaseAlerter)
pydantic-model
Send messages to Slack channels.
Attributes:
Name | Type | Description |
---|---|---|
slack_token |
str |
The Slack token tied to the Slack account to be used. |
Source code in zenml/integrations/slack/alerters/slack_alerter.py
class SlackAlerter(BaseAlerter):
    """Send messages to Slack channels.

    Attributes:
        slack_token: The Slack token tied to the Slack account to be used.
        default_slack_channel_id: Channel ID used when the runtime
            configuration does not provide one.
    """

    slack_token: str
    default_slack_channel_id: Optional[str] = None

    # Class Configuration
    FLAVOR: ClassVar[str] = SLACK_ALERTER_FLAVOR

    def _get_channel_id(self, config: Optional[BaseAlerterStepConfig]) -> str:
        """Get the Slack channel ID to be used by post/ask.

        A channel set on a `SlackAlerterConfig` takes precedence over the
        component's `default_slack_channel_id`. Passing `config=None` is
        allowed and falls back to the default channel.

        Args:
            config: Optional runtime configuration.

        Returns:
            ID of the Slack channel to be used.

        Raises:
            RuntimeError: if a config is given that is not of type
                `BaseAlerterStepConfig`.
            ValueError: if a slack channel was neither defined in the config
                nor in the slack alerter component.
        """
        # `config` is optional: only reject objects of the wrong type, not
        # a missing config (otherwise the default channel is unreachable).
        if config is not None and not isinstance(
            config, BaseAlerterStepConfig
        ):
            raise RuntimeError(
                "The config object must be of type `BaseAlerterStepConfig`."
            )
        if (
            isinstance(config, SlackAlerterConfig)
            and config.slack_channel_id is not None
        ):
            return config.slack_channel_id
        if self.default_slack_channel_id is not None:
            return self.default_slack_channel_id
        raise ValueError(
            "Neither the `SlackAlerterConfig.slack_channel_id` in the runtime "
            "configuration, nor the `default_slack_channel_id` in the alerter "
            "stack component is specified. Please specify at least one."
        )

    def _get_approve_msg_options(
        self, config: Optional[BaseAlerterStepConfig]
    ) -> List[str]:
        """Define which messages will lead to approval during ask().

        Args:
            config: Optional runtime configuration.

        Returns:
            Messages that lead to approval in alerter.ask(); the options
            from a `SlackAlerterConfig` if set, otherwise the defaults.
        """
        if (
            isinstance(config, SlackAlerterConfig)
            and config.approve_msg_options is not None
        ):
            return config.approve_msg_options
        return DEFAULT_APPROVE_MSG_OPTIONS

    def _get_disapprove_msg_options(
        self, config: Optional[BaseAlerterStepConfig]
    ) -> List[str]:
        """Define which messages will lead to disapproval during ask().

        Args:
            config: Optional runtime configuration.

        Returns:
            Messages that lead to disapproval in alerter.ask(); the options
            from a `SlackAlerterConfig` if set, otherwise the defaults.
        """
        if (
            isinstance(config, SlackAlerterConfig)
            and config.disapprove_msg_options is not None
        ):
            return config.disapprove_msg_options
        return DEFAULT_DISAPPROVE_MSG_OPTIONS

    def post(
        self, message: str, config: Optional[BaseAlerterStepConfig]
    ) -> bool:
        """Post a message to a Slack channel.

        Args:
            message: Message to be posted.
            config: Optional runtime configuration.

        Returns:
            True if operation succeeded, else False
        """
        slack_channel_id = self._get_channel_id(config=config)
        client = WebClient(token=self.slack_token)
        try:
            client.chat_postMessage(
                channel=slack_channel_id,
                text=message,
            )
            return True
        except SlackApiError as error:
            # Slack API errors are logged and reported as a failed post.
            reason = error.response["error"]
            logger.error(f"SlackAlerter.post() failed: {reason}")
            return False

    def ask(
        self, message: str, config: Optional[BaseAlerterStepConfig]
    ) -> bool:
        """Post a message to a Slack channel and wait for approval.

        Args:
            message: Initial message to be posted.
            config: Optional runtime configuration.

        Returns:
            True if a user approved the operation, else False
        """
        rtm = RTMClient(token=self.slack_token)
        slack_channel_id = self._get_channel_id(config=config)

        approved = False  # will be modified by handle()

        @RTMClient.run_on(event="hello")  # type: ignore
        def post_initial_message(**payload: Any) -> None:
            """Post an initial message in a channel and start listening.

            Args:
                payload: payload of the received Slack event.
            """
            web_client = payload["web_client"]
            web_client.chat_postMessage(channel=slack_channel_id, text=message)

        @RTMClient.run_on(event="message")  # type: ignore
        def handle(**payload: Any) -> None:
            """Listen / handle messages posted in the channel.

            Args:
                payload: payload of the received Slack event.
            """
            event = payload["data"]
            # Only react to messages in the channel that was asked.
            if event["channel"] == slack_channel_id:

                # approve request (return True)
                if event["text"] in self._get_approve_msg_options(config):
                    print(f"User {event['user']} approved on slack.")
                    nonlocal approved
                    approved = True
                    rtm.stop()  # type: ignore

                # disapprove request (return False)
                elif event["text"] in self._get_disapprove_msg_options(config):
                    print(f"User {event['user']} disapproved on slack.")
                    rtm.stop()  # type:ignore

        # start another thread until `rtm.stop()` is called in handle()
        rtm.start()

        return approved
ask(self, message, config)
Post a message to a Slack channel and wait for approval.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
message |
str |
Initial message to be posted. |
required |
config |
Optional[zenml.steps.step_interfaces.base_alerter_step.BaseAlerterStepConfig] |
Optional runtime configuration. |
required |
Returns:
Type | Description |
---|---|
bool |
True if a user approved the operation, else False |
Source code in zenml/integrations/slack/alerters/slack_alerter.py
def ask(
    self, message: str, config: Optional[BaseAlerterStepConfig]
) -> bool:
    """Post a message to a Slack channel and wait for approval.

    Two callbacks are registered on the RTM client: one posts `message`
    on the "hello" event, the other inspects "message" events in the
    target channel and stops the client once an approve or disapprove
    option is seen.

    Args:
        message: Initial message to be posted.
        config: Optional runtime configuration.

    Returns:
        True if a user approved the operation, else False
    """
    rtm = RTMClient(token=self.slack_token)
    slack_channel_id = self._get_channel_id(config=config)

    approved = False  # will be modified by handle()

    @RTMClient.run_on(event="hello")  # type: ignore
    def post_initial_message(**payload: Any) -> None:
        """Post an initial message in a channel and start listening.

        Args:
            payload: payload of the received Slack event.
        """
        web_client = payload["web_client"]
        web_client.chat_postMessage(channel=slack_channel_id, text=message)

    @RTMClient.run_on(event="message")  # type: ignore
    def handle(**payload: Any) -> None:
        """Listen / handle messages posted in the channel.

        Args:
            payload: payload of the received Slack event.
        """
        event = payload["data"]
        # Messages from other channels are ignored.
        if event["channel"] == slack_channel_id:

            # approve request (return True)
            if event["text"] in self._get_approve_msg_options(config):
                print(f"User {event['user']} approved on slack.")
                nonlocal approved
                approved = True
                rtm.stop()  # type: ignore

            # disapprove request (return False)
            elif event["text"] in self._get_disapprove_msg_options(config):
                print(f"User {event['user']} disapproved on slack.")
                rtm.stop()  # type:ignore

    # start another thread until `rtm.stop()` is called in handle()
    # NOTE(review): presumably this call blocks until the client is
    # stopped -- confirm against the Slack SDK RTMClient docs.
    rtm.start()

    return approved
post(self, message, config)
Post a message to a Slack channel.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
message |
str |
Message to be posted. |
required |
config |
Optional[zenml.steps.step_interfaces.base_alerter_step.BaseAlerterStepConfig] |
Optional runtime configuration. |
required |
Returns:
Type | Description |
---|---|
bool |
True if operation succeeded, else False |
Source code in zenml/integrations/slack/alerters/slack_alerter.py
def post(
    self, message: str, config: Optional[BaseAlerterStepConfig]
) -> bool:
    """Send a message to the resolved Slack channel.

    Args:
        message: Message to be posted.
        config: Optional runtime configuration.

    Returns:
        True if operation succeeded, else False
    """
    channel = self._get_channel_id(config=config)
    web_client = WebClient(token=self.slack_token)
    try:
        web_client.chat_postMessage(channel=channel, text=message)
    except SlackApiError as error:
        # Slack API errors are logged and reported as a failed post.
        reason = error.response["error"]
        logger.error(f"SlackAlerter.post() failed: {reason}")
        return False
    return True
SlackAlerterConfig (BaseAlerterStepConfig)
pydantic-model
Slack alerter config.
Source code in zenml/integrations/slack/alerters/slack_alerter.py
class SlackAlerterConfig(BaseAlerterStepConfig):
    """Slack alerter config.

    All fields are optional and default to None.
    """

    # The ID of the Slack channel to use for communication.
    slack_channel_id: Optional[str] = None

    # Set of messages that lead to approval in alerter.ask()
    approve_msg_options: Optional[List[str]] = None

    # Set of messages that lead to disapproval in alerter.ask()
    disapprove_msg_options: Optional[List[str]] = None
tensorflow
special
Initialization for TensorFlow integration.
TensorflowIntegration (Integration)
Definition of Tensorflow integration for ZenML.
Source code in zenml/integrations/tensorflow/__init__.py
class TensorflowIntegration(Integration):
    """Definition of Tensorflow integration for ZenML."""

    NAME = TENSORFLOW
    # Pinned package versions required by this integration.
    REQUIREMENTS = ["tensorflow==2.8.0", "tensorflow_io==0.24.0"]

    @classmethod
    def activate(cls) -> None:
        """Activates the integration.

        All imports are performed inside this method, so they only run
        when the integration is activated.
        """
        # need to import this explicitly to load the Tensorflow file IO support
        # for S3 and other file systems
        import tensorflow_io  # type: ignore [import]

        # Imported for their side effects only (hence the `noqa`).
        from zenml.integrations.tensorflow import materializers  # noqa
        from zenml.integrations.tensorflow import services  # noqa
activate()
classmethod
Activates the integration.
Source code in zenml/integrations/tensorflow/__init__.py
@classmethod
def activate(cls) -> None:
    """Activates the integration.

    All imports are performed inside this method, so they only run when
    the integration is activated.
    """
    # need to import this explicitly to load the Tensorflow file IO support
    # for S3 and other file systems
    import tensorflow_io  # type: ignore [import]

    # Imported for their side effects only (hence the `noqa`).
    from zenml.integrations.tensorflow import materializers  # noqa
    from zenml.integrations.tensorflow import services  # noqa
materializers
special
Initialization for the TensorFlow materializers.
keras_materializer
Implementation of the TensorFlow Keras materializer.
KerasMaterializer (BaseMaterializer)
Materializer to read/write Keras models.
Source code in zenml/integrations/tensorflow/materializers/keras_materializer.py
class KerasMaterializer(BaseMaterializer):
    """Materializer to read/write Keras models."""

    ASSOCIATED_TYPES = (keras.Model,)
    ASSOCIATED_ARTIFACT_TYPES = (ModelArtifact,)

    def handle_input(self, data_type: Type[Any]) -> keras.Model:
        """Reads and returns a Keras model after copying it to temporary path.

        Args:
            data_type: The type of the data to read.

        Returns:
            A tf.keras.Model model.
        """
        super().handle_input(data_type)

        # The context manager removes the temporary directory even when the
        # copy or the model loading raises.
        with tempfile.TemporaryDirectory() as temp_dir:
            # Copy from artifact store to temporary directory
            io_utils.copy_dir(self.artifact.uri, temp_dir)

            # Load the model from the temporary directory
            model = keras.models.load_model(temp_dir)

        return model

    def handle_return(self, model: keras.Model) -> None:
        """Writes a keras model to the artifact store.

        Args:
            model: A tf.keras.Model model.
        """
        super().handle_return(model)

        # Save locally first, then copy to the artifact store; the context
        # manager removes the temporary directory even on failure.
        with tempfile.TemporaryDirectory() as temp_dir:
            model.save(temp_dir)
            io_utils.copy_dir(temp_dir, self.artifact.uri)
handle_input(self, data_type)
Reads and returns a Keras model after copying it to temporary path.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
data_type |
Type[Any] |
The type of the data to read. |
required |
Returns:
Type | Description |
---|---|
Model |
A tf.keras.Model model. |
Source code in zenml/integrations/tensorflow/materializers/keras_materializer.py
def handle_input(self, data_type: Type[Any]) -> keras.Model:
    """Reads and returns a Keras model after copying it to temporary path.

    Args:
        data_type: The type of the data to read.

    Returns:
        A tf.keras.Model model.
    """
    super().handle_input(data_type)

    # The context manager removes the temporary directory even when the
    # copy or the model loading raises.
    with tempfile.TemporaryDirectory() as temp_dir:
        # Copy from artifact store to temporary directory
        io_utils.copy_dir(self.artifact.uri, temp_dir)

        # Load the model from the temporary directory
        model = keras.models.load_model(temp_dir)

    return model
handle_return(self, model)
Writes a keras model to the artifact store.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model |
Model |
A tf.keras.Model model. |
required |
Source code in zenml/integrations/tensorflow/materializers/keras_materializer.py
def handle_return(self, model: keras.Model) -> None:
    """Writes a keras model to the artifact store.

    Args:
        model: A tf.keras.Model model.
    """
    super().handle_return(model)

    # Save locally first, then copy to the artifact store; the context
    # manager removes the temporary directory even on failure.
    with tempfile.TemporaryDirectory() as temp_dir:
        model.save(temp_dir)
        io_utils.copy_dir(temp_dir, self.artifact.uri)
tf_dataset_materializer
Implementation of the TensorFlow dataset materializer.
TensorflowDatasetMaterializer (BaseMaterializer)
Materializer to read data to and from tf.data.Dataset.
Source code in zenml/integrations/tensorflow/materializers/tf_dataset_materializer.py
class TensorflowDatasetMaterializer(BaseMaterializer):
    """Materializer that saves and loads tf.data.Dataset objects."""

    ASSOCIATED_TYPES = (tf.data.Dataset,)
    ASSOCIATED_ARTIFACT_TYPES = (DataArtifact,)

    def handle_input(self, data_type: Type[Any]) -> Any:
        """Load a tf.data.Dataset from the artifact location.

        Args:
            data_type: The type of the data to read.

        Returns:
            A tf.data.Dataset object.
        """
        super().handle_input(data_type)
        dataset_path = os.path.join(self.artifact.uri, DEFAULT_FILENAME)
        dataset = tf.data.experimental.load(dataset_path)
        return dataset

    def handle_return(self, dataset: tf.data.Dataset) -> None:
        """Save a tf.data.Dataset to the artifact location.

        Args:
            dataset: The dataset to persist.
        """
        super().handle_return(dataset)
        dataset_path = os.path.join(self.artifact.uri, DEFAULT_FILENAME)
        # No compression and no custom sharding function are used.
        save_options = {"compression": None, "shard_func": None}
        tf.data.experimental.save(dataset, dataset_path, **save_options)
handle_input(self, data_type)
Reads data into tf.data.Dataset.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
data_type |
Type[Any] |
The type of the data to read. |
required |
Returns:
Type | Description |
---|---|
Any |
A tf.data.Dataset object. |
Source code in zenml/integrations/tensorflow/materializers/tf_dataset_materializer.py
def handle_input(self, data_type: Type[Any]) -> Any:
    """Load a tf.data.Dataset from the artifact location.

    Args:
        data_type: The type of the data to read.

    Returns:
        A tf.data.Dataset object.
    """
    super().handle_input(data_type)
    dataset_path = os.path.join(self.artifact.uri, DEFAULT_FILENAME)
    dataset = tf.data.experimental.load(dataset_path)
    return dataset
handle_return(self, dataset)
Persists a tf.data.Dataset object.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
dataset |
DatasetV2 |
The dataset to persist. |
required |
Source code in zenml/integrations/tensorflow/materializers/tf_dataset_materializer.py
def handle_return(self, dataset: tf.data.Dataset) -> None:
    """Save a tf.data.Dataset to the artifact location.

    Args:
        dataset: The dataset to persist.
    """
    super().handle_return(dataset)
    dataset_path = os.path.join(self.artifact.uri, DEFAULT_FILENAME)
    # No compression and no custom sharding function are used.
    save_options = {"compression": None, "shard_func": None}
    tf.data.experimental.save(dataset, dataset_path, **save_options)
services
special
Initialization for TensorFlow services.
tensorboard_service
Implementation of the TensorBoard service.
TensorboardService (LocalDaemonService)
pydantic-model
TensorBoard service.
This can be used to start a local TensorBoard server for one or more models.
Attributes:
Name | Type | Description |
---|---|---|
SERVICE_TYPE |
ClassVar[zenml.services.service_type.ServiceType] |
a service type descriptor with information describing the TensorBoard service class |
config |
TensorboardServiceConfig |
service configuration |
endpoint |
LocalDaemonServiceEndpoint |
optional service endpoint |
Source code in zenml/integrations/tensorflow/services/tensorboard_service.py
class TensorboardService(LocalDaemonService):
    """Local daemon service wrapping a TensorBoard server.

    This can be used to start a local TensorBoard server for one or more models.

    Attributes:
        SERVICE_TYPE: a service type descriptor with information describing
            the TensorBoard service class
        config: service configuration
        endpoint: optional service endpoint
    """

    SERVICE_TYPE = ServiceType(
        name="tensorboard",
        type="visualization",
        flavor="tensorboard",
        description="TensorBoard visualization service",
    )

    config: TensorboardServiceConfig
    endpoint: LocalDaemonServiceEndpoint

    def __init__(
        self,
        config: Union[TensorboardServiceConfig, Dict[str, Any]],
        **attrs: Any,
    ) -> None:
        """Create the service, adding a default HTTP endpoint if needed.

        Args:
            config: service configuration
            **attrs: additional attributes
        """
        # ensure that the endpoint is created before the service is initialized
        # TODO [ENG-697]: implement a service factory or builder for TensorBoard
        #   deployment services
        endpoint_missing = "endpoint" not in attrs
        if endpoint_missing and isinstance(config, TensorboardServiceConfig):
            endpoint_config = LocalDaemonServiceEndpointConfig(
                protocol=ServiceEndpointProtocol.HTTP,
            )
            health_monitor = HTTPEndpointHealthMonitor(
                config=HTTPEndpointHealthMonitorConfig(
                    healthcheck_uri_path="",
                    use_head_request=True,
                )
            )
            attrs["endpoint"] = LocalDaemonServiceEndpoint(
                config=endpoint_config,
                monitor=health_monitor,
            )
        super().__init__(config=config, **attrs)

    def run(self) -> None:
        """Configure a TensorBoard server and run it until interrupted."""
        logger.info(
            "Starting TensorBoard service as blocking "
            "process... press CTRL+C once to stop it."
        )

        self.endpoint.prepare_for_start()

        try:
            server = program.TensorBoard(
                plugins=default.get_plugins(),
                subcommands=[uploader_subcommand.UploaderSubcommand()],
            )
            server.configure(
                logdir=self.config.logdir,
                port=self.endpoint.status.port,
                host="localhost",
                max_reload_threads=self.config.max_reload_threads,
                reload_interval=self.config.reload_interval,
            )
            server.main()
        except KeyboardInterrupt:
            # CTRL+C is the documented way to stop the blocking server.
            logger.info(
                "TensorBoard service stopped. Resuming normal execution."
            )
__init__(self, config, **attrs)
special
Initialization for TensorBoard service.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
config |
Union[zenml.integrations.tensorflow.services.tensorboard_service.TensorboardServiceConfig, Dict[str, Any]] |
service configuration |
required |
**attrs |
Any |
additional attributes |
{} |
Source code in zenml/integrations/tensorflow/services/tensorboard_service.py
def __init__(
    self,
    config: Union[TensorboardServiceConfig, Dict[str, Any]],
    **attrs: Any,
) -> None:
    """Create the service, adding a default HTTP endpoint if needed.

    Args:
        config: service configuration
        **attrs: additional attributes
    """
    # ensure that the endpoint is created before the service is initialized
    # TODO [ENG-697]: implement a service factory or builder for TensorBoard
    #   deployment services
    endpoint_missing = "endpoint" not in attrs
    if endpoint_missing and isinstance(config, TensorboardServiceConfig):
        endpoint_config = LocalDaemonServiceEndpointConfig(
            protocol=ServiceEndpointProtocol.HTTP,
        )
        health_monitor = HTTPEndpointHealthMonitor(
            config=HTTPEndpointHealthMonitorConfig(
                healthcheck_uri_path="",
                use_head_request=True,
            )
        )
        attrs["endpoint"] = LocalDaemonServiceEndpoint(
            config=endpoint_config,
            monitor=health_monitor,
        )
    super().__init__(config=config, **attrs)
run(self)
Initialize and run the TensorBoard server.
Source code in zenml/integrations/tensorflow/services/tensorboard_service.py
def run(self) -> None:
    """Configure a TensorBoard server and run it until interrupted."""
    logger.info(
        "Starting TensorBoard service as blocking "
        "process... press CTRL+C once to stop it."
    )

    self.endpoint.prepare_for_start()

    try:
        server = program.TensorBoard(
            plugins=default.get_plugins(),
            subcommands=[uploader_subcommand.UploaderSubcommand()],
        )
        server.configure(
            logdir=self.config.logdir,
            port=self.endpoint.status.port,
            host="localhost",
            max_reload_threads=self.config.max_reload_threads,
            reload_interval=self.config.reload_interval,
        )
        server.main()
    except KeyboardInterrupt:
        # CTRL+C is the documented way to stop the blocking server.
        logger.info(
            "TensorBoard service stopped. Resuming normal execution."
        )
TensorboardServiceConfig (LocalDaemonServiceConfig)
pydantic-model
TensorBoard service configuration.
Attributes:
Name | Type | Description |
---|---|---|
logdir |
str |
location of TensorBoard log files. |
max_reload_threads |
int |
the max number of threads that TensorBoard can use to reload runs. Each thread reloads one run at a time. |
reload_interval |
int |
how often the backend should load more data, in seconds. Set to 0 to load just once at startup. |
Source code in zenml/integrations/tensorflow/services/tensorboard_service.py
class TensorboardServiceConfig(LocalDaemonServiceConfig):
    """Settings consumed by the TensorBoard service.

    Attributes:
        logdir: location of TensorBoard log files.
        max_reload_threads: the max number of threads that TensorBoard can use
            to reload runs. Each thread reloads one run at a time.
        reload_interval: how often the backend should load more data, in
            seconds. Set to 0 to load just once at startup.
    """

    # Required: there is no default log directory.
    logdir: str
    max_reload_threads: int = 1
    reload_interval: int = 5
steps
special
Initialization for TensorFlow standard steps.
tensorflow_trainer
Implementation of a TensorFlow trainer step.
TensorflowBinaryClassifier (BaseTrainerStep)
A TensorFlow binary classifier.
This simple step implementation creates a simple tensorflow feedforward neural network and trains it on a given pd.DataFrame dataset.
Source code in zenml/integrations/tensorflow/steps/tensorflow_trainer.py
class TensorflowBinaryClassifier(BaseTrainerStep):
    """A TensorFlow binary classifier.

    This simple step implementation creates a simple tensorflow feedforward
    neural network and trains it on a given pd.DataFrame dataset.
    """

    def entrypoint(  # type: ignore[override]
        self,
        train_dataset: pd.DataFrame,
        validation_dataset: pd.DataFrame,
        config: TensorflowBinaryClassifierConfig,
    ) -> tf.keras.Model:
        """Main entrypoint for the tensorflow trainer.

        Args:
            train_dataset: pd.DataFrame, the training dataset
            validation_dataset: pd.DataFrame, the validation dataset
            config: the configuration of the step

        Returns:
            the trained tf.keras.Model
        """
        model = tf.keras.Sequential()
        model.add(tf.keras.layers.InputLayer(input_shape=config.input_shape))
        model.add(tf.keras.layers.Flatten())

        # Work on a copy: popping from `config.layers` directly would strip
        # the output layer from the caller's config, so reusing the same
        # config object would silently build a different network.
        hidden_units = list(config.layers)
        output_units = hidden_units.pop()
        for units in hidden_units:
            model.add(tf.keras.layers.Dense(units, activation="relu"))
        model.add(tf.keras.layers.Dense(output_units, activation="sigmoid"))

        model.compile(
            optimizer=tf.keras.optimizers.Adam(config.learning_rate),
            loss=tf.keras.losses.BinaryCrossentropy(),
            metrics=config.metrics,
        )

        # NOTE: `pop` removes the target column from the passed DataFrames
        # in place, leaving only the feature columns for fitting.
        train_target = train_dataset.pop(config.target_column)
        validation_target = validation_dataset.pop(config.target_column)
        model.fit(
            x=train_dataset,
            y=train_target,
            validation_data=(validation_dataset, validation_target),
            batch_size=config.batch_size,
            epochs=config.epochs,
        )
        model.summary()
        return model
CONFIG_CLASS (BaseTrainerConfig)
pydantic-model
Config class for the tensorflow trainer.
Attributes: target_column: the name of the label column; layers: the number of units in each of the fully connected layers; input_shape: the shape of the input; learning_rate: the learning rate; metrics: the list of metrics to be computed; epochs: the number of epochs; batch_size: the size of the batch.
Source code in zenml/integrations/tensorflow/steps/tensorflow_trainer.py
class TensorflowBinaryClassifierConfig(BaseTrainerConfig):
    """Config class for the tensorflow trainer.

    Attributes:
        target_column: the name of the label column
        layers: the number of units in the fully connected layers; the last
            entry is the size of the output layer
        input_shape: the shape of the input
        learning_rate: the learning rate
        metrics: the list of metrics to be computed
        epochs: the number of epochs
        batch_size: the size of the batch
    """

    target_column: str
    layers: List[int] = [256, 64, 1]
    # `Tuple[int, ...]` (variable length) instead of `Tuple[int]`, which
    # only admits exactly one dimension and would reject shapes like (28, 28).
    input_shape: Tuple[int, ...] = (8,)
    learning_rate: float = 0.001
    metrics: List[str] = ["accuracy"]
    epochs: int = 50
    batch_size: int = 8
entrypoint(self, train_dataset, validation_dataset, config)
Main entrypoint for the tensorflow trainer.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
train_dataset |
DataFrame |
pd.DataFrame, the training dataset |
required |
validation_dataset |
DataFrame |
pd.DataFrame, the validation dataset |
required |
config |
TensorflowBinaryClassifierConfig |
the configuration of the step |
required |
Returns:
Type | Description |
---|---|
Model |
the trained tf.keras.Model |
Source code in zenml/integrations/tensorflow/steps/tensorflow_trainer.py
def entrypoint(  # type: ignore[override]
    self,
    train_dataset: pd.DataFrame,
    validation_dataset: pd.DataFrame,
    config: TensorflowBinaryClassifierConfig,
) -> tf.keras.Model:
    """Main entrypoint for the tensorflow trainer.

    Args:
        train_dataset: pd.DataFrame, the training dataset
        validation_dataset: pd.DataFrame, the validation dataset
        config: the configuration of the step

    Returns:
        the trained tf.keras.Model
    """
    model = tf.keras.Sequential()
    model.add(tf.keras.layers.InputLayer(input_shape=config.input_shape))
    model.add(tf.keras.layers.Flatten())

    # Work on a copy: popping from `config.layers` directly would strip the
    # output layer from the caller's config, so reusing the same config
    # object would silently build a different network.
    hidden_units = list(config.layers)
    output_units = hidden_units.pop()
    for units in hidden_units:
        model.add(tf.keras.layers.Dense(units, activation="relu"))
    model.add(tf.keras.layers.Dense(output_units, activation="sigmoid"))

    model.compile(
        optimizer=tf.keras.optimizers.Adam(config.learning_rate),
        loss=tf.keras.losses.BinaryCrossentropy(),
        metrics=config.metrics,
    )

    # NOTE: `pop` removes the target column from the passed DataFrames in
    # place, leaving only the feature columns for fitting.
    train_target = train_dataset.pop(config.target_column)
    validation_target = validation_dataset.pop(config.target_column)
    model.fit(
        x=train_dataset,
        y=train_target,
        validation_data=(validation_dataset, validation_target),
        batch_size=config.batch_size,
        epochs=config.epochs,
    )
    model.summary()
    return model
TensorflowBinaryClassifierConfig (BaseTrainerConfig)
pydantic-model
Config class for the tensorflow trainer.
Attributes: target_column: the name of the label column; layers: the number of units in each of the fully connected layers; input_shape: the shape of the input; learning_rate: the learning rate; metrics: the list of metrics to be computed; epochs: the number of epochs; batch_size: the size of the batch.
Source code in zenml/integrations/tensorflow/steps/tensorflow_trainer.py
class TensorflowBinaryClassifierConfig(BaseTrainerConfig):
    """Config class for the tensorflow trainer.

    Attributes:
        target_column: the name of the label column
        layers: the number of units in the fully connected layers; the last
            entry is the size of the output layer
        input_shape: the shape of the input
        learning_rate: the learning rate
        metrics: the list of metrics to be computed
        epochs: the number of epochs
        batch_size: the size of the batch
    """

    target_column: str
    layers: List[int] = [256, 64, 1]
    # `Tuple[int, ...]` (variable length) instead of `Tuple[int]`, which
    # only admits exactly one dimension and would reject shapes like (28, 28).
    input_shape: Tuple[int, ...] = (8,)
    learning_rate: float = 0.001
    metrics: List[str] = ["accuracy"]
    epochs: int = 50
    batch_size: int = 8
visualizers
special
Initialization for TensorFlow visualizer.
tensorboard_visualizer
Implementation of a TensorFlow visualizer step.
TensorboardVisualizer (BaseStepVisualizer)
The implementation of a TensorBoard visualizer.
Source code in zenml/integrations/tensorflow/visualizers/tensorboard_visualizer.py
class TensorboardVisualizer(BaseStepVisualizer):
    """Step visualizer that serves a step's model artifacts via TensorBoard."""

    @classmethod
    def find_running_tensorboard_server(
        cls, logdir: str
    ) -> Optional[TensorBoardInfo]:
        """Look up a live local TensorBoard server for a log directory.

        Args:
            logdir: The logdir location where the TensorBoard server is running.

        Returns:
            The TensorBoardInfo describing the running TensorBoard server or
            None if no server is running for the supplied logdir location.
        """
        live_servers = (
            info
            for info in get_all()
            if info.logdir == logdir
            and info.pid
            and psutil.pid_exists(info.pid)
        )
        return next(live_servers, None)

    def visualize(
        self,
        object: StepView,
        height: int = 800,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        """Start a TensorBoard server.

        Allows for the visualization of all models logged as artifacts by the
        indicated step. The server will monitor and display all the models
        logged by past and future step runs.

        Args:
            object: StepView fetched from run.get_step().
            height: Height of the generated visualization.
            *args: Additional arguments.
            **kwargs: Additional keyword arguments.
        """
        for artifact_view in object.outputs.values():
            # filter out anything but model artifacts
            if artifact_view.type != ModelArtifact.TYPE_NAME:
                continue

            logdir = os.path.dirname(artifact_view.uri)

            # reuse a server that already serves this logdir, if there is one
            existing = self.find_running_tensorboard_server(logdir)
            if existing:
                self.visualize_tensorboard(existing.port, height)
                return

            if sys.platform != "win32":
                # start a new TensorBoard server
                service = TensorboardService(
                    TensorboardServiceConfig(logdir=logdir)
                )
                service.start(timeout=20)
                if service.endpoint.status.port:
                    self.visualize_tensorboard(
                        service.endpoint.status.port, height
                    )
                return

            # Daemon service functionality is currently not supported on Windows
            print(
                "You can run:\n"
                f"[italic green] tensorboard --logdir {logdir}"
                "[/italic green]\n"
                "...to visualize the TensorBoard logs for your trained model."
            )
            return

    def visualize_tensorboard(
        self,
        port: int,
        height: int,
    ) -> None:
        """Show a TensorBoard inline in a notebook, or print its URL.

        Args:
            port: the TCP port where the TensorBoard server is listening for
                requests.
            height: Height of the generated visualization.
        """
        if not Environment.in_notebook():
            print(
                "You can visit:\n"
                f"[italic green] http://localhost:{port}/[/italic green]\n"
                "...to visualize the TensorBoard logs for your trained model."
            )
            return

        notebook.display(port, height=height)

    def stop(
        self,
        object: StepView,
    ) -> None:
        """Stop the TensorBoard server previously started for a pipeline step.

        Args:
            object: StepView fetched from run.get_step().
        """
        for artifact_view in object.outputs.values():
            # filter out anything but model artifacts
            if artifact_view.type != ModelArtifact.TYPE_NAME:
                continue

            logdir = os.path.dirname(artifact_view.uri)

            # only a server serving this exact logdir is a candidate
            server = self.find_running_tensorboard_server(logdir)
            if not server:
                return

            logger.debug(
                "Stopping tensorboard server with PID '%d' ...",
                server.pid,
            )
            try:
                process = psutil.Process(server.pid)
            except psutil.Error:
                logger.error(
                    "Could not find process for PID '%d' ...",
                    server.pid,
                )
                continue
            process.kill()
            return
find_running_tensorboard_server(logdir)
classmethod
Find a local TensorBoard server instance.
Finds the server, if any, that is running for the supplied logdir location and returns the information describing it (including its TCP port).
Parameters:
Name | Type | Description | Default |
---|---|---|---|
logdir |
str |
The logdir location where the TensorBoard server is running. |
required |
Returns:
Type | Description |
---|---|
Optional[tensorboard.manager.TensorBoardInfo] |
The TensorBoardInfo describing the running TensorBoard server or None if no server is running for the supplied logdir location. |
Source code in zenml/integrations/tensorflow/visualizers/tensorboard_visualizer.py
@classmethod
def find_running_tensorboard_server(
    cls, logdir: str
) -> Optional[TensorBoardInfo]:
    """Look up a live local TensorBoard server for a log directory.

    Args:
        logdir: The logdir location where the TensorBoard server is running.

    Returns:
        The TensorBoardInfo describing the running TensorBoard server or
        None if no server is running for the supplied logdir location.
    """
    live_servers = (
        info
        for info in get_all()
        if info.logdir == logdir
        and info.pid
        and psutil.pid_exists(info.pid)
    )
    return next(live_servers, None)
stop(self, object)
Stop the TensorBoard server previously started for a pipeline step.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
object |
StepView |
StepView fetched from run.get_step(). |
required |
Source code in zenml/integrations/tensorflow/visualizers/tensorboard_visualizer.py
def stop(
    self,
    object: StepView,
) -> None:
    """Stop the TensorBoard server previously started for a pipeline step.

    Args:
        object: StepView fetched from run.get_step().
    """
    for artifact_view in object.outputs.values():
        # filter out anything but model artifacts
        if artifact_view.type != ModelArtifact.TYPE_NAME:
            continue

        logdir = os.path.dirname(artifact_view.uri)

        # only a server serving this exact logdir is a candidate
        server = self.find_running_tensorboard_server(logdir)
        if not server:
            return

        logger.debug(
            "Stopping tensorboard server with PID '%d' ...",
            server.pid,
        )
        try:
            process = psutil.Process(server.pid)
        except psutil.Error:
            logger.error(
                "Could not find process for PID '%d' ...",
                server.pid,
            )
            continue
        process.kill()
        return
visualize(self, object, height=800, *args, **kwargs)
Start a TensorBoard server.
Allows for the visualization of all models logged as artifacts by the indicated step. The server will monitor and display all the models logged by past and future step runs.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
object |
StepView |
StepView fetched from run.get_step(). |
required |
height |
int |
Height of the generated visualization. |
800 |
*args |
Any |
Additional arguments. |
() |
**kwargs |
Any |
Additional keyword arguments. |
{} |
Source code in zenml/integrations/tensorflow/visualizers/tensorboard_visualizer.py
def visualize(
    self,
    object: StepView,
    height: int = 800,
    *args: Any,
    **kwargs: Any,
) -> None:
    """Start a TensorBoard server.

    Allows for the visualization of all models logged as artifacts by the
    indicated step. The server will monitor and display all the models
    logged by past and future step runs.

    Args:
        object: StepView fetched from run.get_step().
        height: Height of the generated visualization.
        *args: Additional arguments.
        **kwargs: Additional keyword arguments.
    """
    for artifact_view in object.outputs.values():
        # filter out anything but model artifacts
        if artifact_view.type != ModelArtifact.TYPE_NAME:
            continue

        logdir = os.path.dirname(artifact_view.uri)

        # reuse a server that already serves this logdir, if there is one
        existing = self.find_running_tensorboard_server(logdir)
        if existing:
            self.visualize_tensorboard(existing.port, height)
            return

        if sys.platform != "win32":
            # start a new TensorBoard server
            service = TensorboardService(
                TensorboardServiceConfig(logdir=logdir)
            )
            service.start(timeout=20)
            if service.endpoint.status.port:
                self.visualize_tensorboard(
                    service.endpoint.status.port, height
                )
            return

        # Daemon service functionality is currently not supported on Windows
        print(
            "You can run:\n"
            f"[italic green] tensorboard --logdir {logdir}"
            "[/italic green]\n"
            "...to visualize the TensorBoard logs for your trained model."
        )
        return
visualize_tensorboard(self, port, height)
Generate a visualization of a TensorBoard.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
port |
int |
the TCP port where the TensorBoard server is listening for requests. |
required |
height |
int |
Height of the generated visualization. |
required |
Source code in zenml/integrations/tensorflow/visualizers/tensorboard_visualizer.py
def visualize_tensorboard(
    self,
    port: int,
    height: int,
) -> None:
    """Show a TensorBoard inline in a notebook, or print its URL.

    Args:
        port: the TCP port where the TensorBoard server is listening for
            requests.
        height: Height of the generated visualization.
    """
    if not Environment.in_notebook():
        print(
            "You can visit:\n"
            f"[italic green] http://localhost:{port}/[/italic green]\n"
            "...to visualize the TensorBoard logs for your trained model."
        )
        return

    notebook.display(port, height=height)
get_step(pipeline_name, step_name)
Get the StepView for the specified pipeline and step name.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
pipeline_name |
str |
The name of the pipeline. |
required |
step_name |
str |
The name of the step. |
required |
Returns:
Type | Description |
---|---|
StepView |
The StepView for the specified pipeline and step name. |
Exceptions:
Type | Description |
---|---|
RuntimeError |
If the step is not found. |
Source code in zenml/integrations/tensorflow/visualizers/tensorboard_visualizer.py
def get_step(pipeline_name: str, step_name: str) -> StepView:
    """Get the StepView for the specified pipeline and step name.

    The step is looked up in the most recent run of the pipeline.

    Args:
        pipeline_name: The name of the pipeline.
        step_name: The name of the step.

    Returns:
        The StepView for the specified pipeline and step name.

    Raises:
        RuntimeError: If the pipeline is not found, has no runs, or does
            not contain the step.
    """
    repo = Repository()
    pipeline = repo.get_pipeline(pipeline_name)
    if pipeline is None:
        raise RuntimeError(f"No pipeline with name `{pipeline_name}` was found")

    # Indexing an empty run list would surface as a bare IndexError; report
    # it the same way as the other lookup failures instead.
    runs = pipeline.runs
    if not runs:
        raise RuntimeError(
            f"No runs were found for pipeline `{pipeline_name}`"
        )

    last_run = runs[-1]
    step = last_run.get_step(step=step_name)
    if step is None:
        raise RuntimeError(
            f"No pipeline step with name `{step_name}` was found in "
            f"pipeline `{pipeline_name}`"
        )
    return cast(StepView, step)
stop_tensorboard_server(pipeline_name, step_name)
Stop the TensorBoard server previously started for a pipeline step.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
pipeline_name |
str |
the name of the pipeline |
required |
step_name |
str |
pipeline step name |
required |
Source code in zenml/integrations/tensorflow/visualizers/tensorboard_visualizer.py
def stop_tensorboard_server(pipeline_name: str, step_name: str) -> None:
    """Shut down the TensorBoard server belonging to a pipeline step.

    Args:
        pipeline_name: the name of the pipeline
        step_name: pipeline step name
    """
    TensorboardVisualizer().stop(get_step(pipeline_name, step_name))
visualize_tensorboard(pipeline_name, step_name)
Start a TensorBoard server.
Allows for the visualization of all models logged as output by the named pipeline step. The server will monitor and display all the models logged by past and future step runs.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
pipeline_name |
str |
the name of the pipeline |
required |
step_name |
str |
pipeline step name |
required |
Source code in zenml/integrations/tensorflow/visualizers/tensorboard_visualizer.py
def visualize_tensorboard(pipeline_name: str, step_name: str) -> None:
    """Start a TensorBoard server for a pipeline step.

    Allows for the visualization of all models logged as output by the named
    pipeline step. The server will monitor and display all the models logged by
    past and future step runs.

    Args:
        pipeline_name: the name of the pipeline
        step_name: pipeline step name
    """
    TensorboardVisualizer().visualize(get_step(pipeline_name, step_name))
utils
Utility functions for the integrations module.
get_integration_for_module(module_name)
Gets the integration class for a module inside an integration.
If the module given by `module_name` is not part of a ZenML integration, this method will return `None`. If it is part of a ZenML integration, it will return the integration class found inside the integration `__init__` file.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
module_name |
str |
The name of the module to get the integration for. |
required |
Returns:
Type | Description |
---|---|
Optional[Type[zenml.integrations.integration.Integration]] |
The integration class for the module. |
Source code in zenml/integrations/utils.py
def get_integration_for_module(
    module_name: str,
) -> Optional[Type[Integration]]:
    """Resolve the integration class that owns a given module.

    If the module given by `module_name` is not part of a ZenML integration,
    this method will return `None`. If it is part of a ZenML integration,
    it will return the integration class found inside the integration
    __init__ file.

    Args:
        module_name: The name of the module to get the integration for.

    Returns:
        The integration class for the module.
    """
    if not module_name.startswith("zenml.integrations."):
        return None

    # Keep only the first three dotted components:
    # `zenml.integrations.<integration>`.
    package_name = ".".join(module_name.split(".", 3)[:3])

    if package_name in sys.modules:
        package = sys.modules[package_name]
    else:
        package = importlib.import_module(package_name)

    # First concrete Integration subclass exposed by the package, if any.
    integration_classes = (
        member
        for _, member in inspect.getmembers(package)
        if member is not Integration
        and isinstance(member, IntegrationMeta)
        and issubclass(member, Integration)
    )
    return cast(
        Optional[Type[Integration]], next(integration_classes, None)
    )
get_requirements_for_module(module_name)
Gets requirements for a module inside an integration.
If the module given by `module_name` is not part of a ZenML integration, this method will return an empty list. If it is part of a ZenML integration, it will return the list of requirements specified inside the integration class found inside the integration `__init__` file.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
module_name |
str |
The name of the module to get requirements for. |
required |
Returns:
Type | Description |
---|---|
List[str] |
A list of requirements for the module. |
Source code in zenml/integrations/utils.py
def get_requirements_for_module(module_name: str) -> List[str]:
    """Look up the requirements of the integration that owns a module.

    If the module given by `module_name` is not part of a ZenML integration,
    this method will return an empty list. If it is part of a ZenML integration,
    it will return the list of requirements specified inside the integration
    class found inside the integration __init__ file.

    Args:
        module_name: The name of the module to get requirements for.

    Returns:
        A list of requirements for the module.
    """
    integration = get_integration_for_module(module_name)
    if integration:
        return integration.REQUIREMENTS
    return []
vault
special
Initialization for the Vault Secrets Manager integration.
The Vault secrets manager integration submodule provides a way to access the HashiCorp Vault secrets manager from within your ZenML pipeline runs.
VaultSecretsManagerIntegration (Integration)
Definition of HashiCorp Vault integration with ZenML.
Source code in zenml/integrations/vault/__init__.py
class VaultSecretsManagerIntegration(Integration):
    """ZenML integration descriptor for HashiCorp Vault."""

    NAME = VAULT
    REQUIREMENTS = ["hvac>=0.11.2"]

    @classmethod
    def flavors(cls) -> List[FlavorWrapper]:
        """Declare the stack component flavors for the Vault integration.

        Returns:
            List of stack component flavors.
        """
        # The integration contributes a single secrets manager flavor.
        secrets_manager_flavor = FlavorWrapper(
            name=VAULT_SECRETS_MANAGER_FLAVOR,
            source="zenml.integrations.vault.secrets_manager.VaultSecretsManager",
            type=StackComponentType.SECRETS_MANAGER,
            integration=cls.NAME,
        )
        return [secrets_manager_flavor]
flavors()
classmethod
Declare the stack component flavors for the Vault integration.
Returns:
Type | Description |
---|---|
List[zenml.zen_stores.models.flavor_wrapper.FlavorWrapper] |
List of stack component flavors. |
Source code in zenml/integrations/vault/__init__.py
@classmethod
def flavors(cls) -> List[FlavorWrapper]:
    """Declare the stack component flavors for the Vault integration.

    Returns:
        List of stack component flavors.
    """
    # The integration contributes a single secrets manager flavor.
    secrets_manager_flavor = FlavorWrapper(
        name=VAULT_SECRETS_MANAGER_FLAVOR,
        source="zenml.integrations.vault.secrets_manager.VaultSecretsManager",
        type=StackComponentType.SECRETS_MANAGER,
        integration=cls.NAME,
    )
    return [secrets_manager_flavor]
secrets_manager
special
HashiCorp Vault Secrets Manager.
vault_secrets_manager
Implementation of the HashiCorp Vault Secrets Manager integration.
VaultSecretsManager (BaseSecretsManager)
pydantic-model
Class to interact with the Vault secrets manager - Key/value Engine.
Attributes:
Name | Type | Description |
---|---|---|
url |
str |
The url of the Vault server. |
token |
str |
The token to use to authenticate with Vault. |
cert |
Optional[str] |
The path to the certificate to use to authenticate with Vault. |
verify |
Optional[str] |
Whether to verify the certificate or not. |
mount_point |
str |
The mount point of the secrets manager. |
Source code in zenml/integrations/vault/secrets_manager/vault_secrets_manager.py
class VaultSecretsManager(BaseSecretsManager):
"""Class to interact with the Vault secrets manager - Key/value Engine.
Attributes:
url: The url of the Vault server.
token: The token to use to authenticate with Vault.
cert: The path to the certificate to use to authenticate with Vault.
verify: Whether to verify the certificate or not.
mount_point: The mount point of the secrets manager.
"""
# Class configuration
FLAVOR: ClassVar[str] = VAULT_SECRETS_MANAGER_FLAVOR
SUPPORTS_SCOPING: ClassVar[bool] = True
CLIENT: ClassVar[Any] = None
url: str
token: str
mount_point: str
cert: Optional[str]
verify: Optional[str]
@classmethod
def _ensure_client_connected(cls, url: str, token: str) -> None:
"""Ensure the client is connected.
This function initializes the client if it is not initialized.
Args:
url: The url of the Vault server.
token: The token to use to authenticate with Vault.
"""
if cls.CLIENT is None:
# Create a Vault Secrets Manager client
cls.CLIENT = hvac.Client(
url=url,
token=token,
)
def _ensure_client_is_authenticated(self) -> None:
"""Ensure the client is authenticated.
Raises:
RuntimeError: If the client is not initialized or authenticated.
"""
self._ensure_client_connected(url=self.url, token=self.token)
if not self.CLIENT.is_authenticated():
raise RuntimeError(
"There was an error authenticating with Vault. Please check "
"your configuration."
)
else:
pass
@classmethod
def _validate_scope(
cls,
scope: SecretsManagerScope,
namespace: Optional[str],
) -> None:
"""Validate the scope and namespace value.
Args:
scope: Scope value.
namespace: Optional namespace value.
"""
if namespace:
cls.validate_secret_name_or_namespace(namespace)
@classmethod
def validate_secret_name_or_namespace(cls, name: str) -> None:
"""Validate a secret name or namespace.
For compatibility across secret managers the secret names should contain
only alphanumeric characters and the characters /_+=.@-. The `/`
character is only used internally to delimit scopes.
Args:
name: the secret name or namespace
Raises:
ValueError: if the secret name or namespace is invalid
"""
if not re.fullmatch(r"[a-zA-Z0-9_+=\.@\-]*", name):
raise ValueError(
f"Invalid secret name or namespace '{name}'. Must contain "
f"only alphanumeric characters and the characters _+=.@-."
)
def register_secret(self, secret: BaseSecretSchema) -> None:
"""Registers a new secret.
Args:
secret: The secret to register.
Raises:
SecretExistsError: If the secret already exists.
"""
self._ensure_client_is_authenticated()
self.validate_secret_name_or_namespace(secret.name)
try:
self.get_secret(secret.name)
raise SecretExistsError(
f"A Secret with the name '{secret.name}' already " f"exists."
)
except KeyError:
pass
secret_path = self._get_scoped_secret_name(secret.name)
secret_value = secret_to_dict(secret, encode=False)
self.CLIENT.secrets.kv.v2.create_or_update_secret(
path=secret_path,
mount_point=self.mount_point,
secret=secret_value,
)
logger.info("Created secret: %s", f"{secret_path}")
logger.info("Added value to secret.")
def get_secret(self, secret_name: str) -> BaseSecretSchema:
"""Gets the value of a secret.
Args:
secret_name: The name of the secret to get.
Returns:
The secret.
Raises:
KeyError: If the secret does not exist.
"""
self._ensure_client_is_authenticated()
secret_path = self._get_scoped_secret_name(secret_name)
try:
secret_items = (
self.CLIENT.secrets.kv.v2.read_secret_version(
path=secret_path,
mount_point=self.mount_point,
)
.get("data", {})
.get("data", {})
)
except InvalidPath as e:
raise KeyError(e)
zenml_schema_name = secret_items.pop(ZENML_SCHEMA_NAME)
secret_schema = SecretSchemaClassRegistry.get_class(
secret_schema=zenml_schema_name
)
secret_items["name"] = secret_name
return secret_schema(**secret_items)
def get_all_secret_keys(self) -> List[str]:
"""List all secrets in Vault without any reformatting.
This function tries to get all secrets from Vault and returns
them as a list of strings (all secrets' names).
Returns:
A list of all secrets in the secrets manager.
"""
self._ensure_client_is_authenticated()
set_of_secrets: Set[str] = set()
secret_path = "/".join(self._get_scope_path())
try:
secrets = self.CLIENT.secrets.kv.v2.list_secrets(
path=secret_path, mount_point=self.mount_point
)
except hvac.exceptions.InvalidPath:
logger.error(
f"There are no secrets created within the scope `{secret_path}`"
)
return list(set_of_secrets)
secrets_keys = secrets.get("data", {}).get("keys", [])
for secret_key in secrets_keys:
# vault scopes end with / and are not themselves secrets
if "/" not in secret_key:
set_of_secrets.add(secret_key)
return list(set_of_secrets)
def update_secret(self, secret: BaseSecretSchema) -> None:
    """Update an existing secret.

    The secret name is validated and must already be listed in the current
    scope; its values are then written to Vault under the scoped path.

    Args:
        secret: The secret to update.

    Raises:
        KeyError: If the secret does not exist.
    """
    self._ensure_client_is_authenticated()
    self.validate_secret_name_or_namespace(secret.name)

    # Guard clause: only secrets that are already listed may be updated.
    if secret.name not in self.get_all_secret_keys():
        raise KeyError(
            f"A Secret with the name '{secret.name}'" f" does not exist."
        )

    secret_path = self._get_scoped_secret_name(secret.name)
    payload = secret_to_dict(secret, encode=False)
    kv_store = self.CLIENT.secrets.kv.v2
    kv_store.create_or_update_secret(
        path=secret_path,
        mount_point=self.mount_point,
        secret=payload,
    )

    logger.info("Updated secret: %s", secret_path)
    logger.info("Added value to secret.")
def delete_secret(self, secret_name: str) -> None:
    """Delete an existing secret.

    Removes the metadata and all versions stored in Vault under the scoped
    path derived from `secret_name`.

    Args:
        secret_name: The name of the secret to delete.
    """
    self._ensure_client_is_authenticated()

    scoped_name = self._get_scoped_secret_name(secret_name)
    kv_store = self.CLIENT.secrets.kv.v2
    kv_store.delete_metadata_and_all_versions(
        mount_point=self.mount_point,
        path=scoped_name,
    )

    logger.info("Deleted secret: %s", f"{scoped_name}")
def delete_all_secrets(self) -> None:
    """Delete all existing secrets.

    Every name returned by `get_all_secret_keys` is deleted one by one.
    """
    self._ensure_client_is_authenticated()

    names_to_delete = self.get_all_secret_keys()
    for name in names_to_delete:
        self.delete_secret(name)

    logger.info("Deleted all secrets.")
delete_all_secrets(self)
Delete all existing secrets.
Source code in zenml/integrations/vault/secrets_manager/vault_secrets_manager.py
def delete_all_secrets(self) -> None:
    """Delete all existing secrets.

    Every name returned by `get_all_secret_keys` is deleted one by one.
    """
    self._ensure_client_is_authenticated()

    names_to_delete = self.get_all_secret_keys()
    for name in names_to_delete:
        self.delete_secret(name)

    logger.info("Deleted all secrets.")
delete_secret(self, secret_name)
Delete an existing secret.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
secret_name |
str |
The name of the secret to delete. |
required |
Source code in zenml/integrations/vault/secrets_manager/vault_secrets_manager.py
def delete_secret(self, secret_name: str) -> None:
    """Delete an existing secret.

    Removes the metadata and all versions stored in Vault under the scoped
    path derived from `secret_name`.

    Args:
        secret_name: The name of the secret to delete.
    """
    self._ensure_client_is_authenticated()

    scoped_name = self._get_scoped_secret_name(secret_name)
    kv_store = self.CLIENT.secrets.kv.v2
    kv_store.delete_metadata_and_all_versions(
        mount_point=self.mount_point,
        path=scoped_name,
    )

    logger.info("Deleted secret: %s", f"{scoped_name}")
get_all_secret_keys(self)
List all secrets in Vault without any reformatting.
This function tries to get all secrets from Vault and returns them as a list of strings (all secrets' names).
Returns:
Type | Description |
---|---|
List[str] |
A list of all secrets in the secrets manager. |
Source code in zenml/integrations/vault/secrets_manager/vault_secrets_manager.py
def get_all_secret_keys(self) -> List[str]:
    """List the names of all secrets in the current scope.

    The scope path is listed in Vault and every returned key that does not
    contain a `/` is reported, without any reformatting. Duplicate keys are
    collapsed; the order of the result is unspecified.

    Returns:
        A list of all secrets in the secrets manager. Empty when Vault
        reports the scope path as invalid.
    """
    self._ensure_client_is_authenticated()

    secret_path = "/".join(self._get_scope_path())
    try:
        listing = self.CLIENT.secrets.kv.v2.list_secrets(
            mount_point=self.mount_point, path=secret_path
        )
    except hvac.exceptions.InvalidPath:
        logger.error(
            f"There are no secrets created within the scope `{secret_path}`"
        )
        return []

    listed_keys = listing.get("data", {}).get("keys", [])
    # vault scopes end with / and are not themselves secrets
    unique_names = {key for key in listed_keys if "/" not in key}
    return list(unique_names)
get_secret(self, secret_name)
Gets the value of a secret.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
secret_name |
str |
The name of the secret to get. |
required |
Returns:
Type | Description |
---|---|
BaseSecretSchema |
The secret. |
Exceptions:
Type | Description |
---|---|
KeyError |
If the secret does not exist. |
Source code in zenml/integrations/vault/secrets_manager/vault_secrets_manager.py
def get_secret(self, secret_name: str) -> BaseSecretSchema:
    """Gets the value of a secret.

    Reads the secret stored in the Vault KV v2 engine under the scoped
    path derived from `secret_name` and rebuilds the matching secret schema
    object from the stored key-value pairs.

    Args:
        secret_name: The name of the secret to get.

    Returns:
        The secret.

    Raises:
        KeyError: If the secret does not exist.
    """
    self._ensure_client_is_authenticated()

    secret_path = self._get_scoped_secret_name(secret_name)
    try:
        response = self.CLIENT.secrets.kv.v2.read_secret_version(
            path=secret_path,
            mount_point=self.mount_point,
        )
    except InvalidPath as e:
        # Chain explicitly so the Vault error is kept as the direct cause
        # of the KeyError that callers catch.
        raise KeyError(e) from e

    # The key-value pairs are nested under two levels of "data".
    secret_items = response.get("data", {}).get("data", {})

    # The schema name is stored next to the secret values; it is removed
    # from the items before they are passed to the schema class. `pop`
    # without a default raises KeyError when the entry is missing.
    zenml_schema_name = secret_items.pop(ZENML_SCHEMA_NAME)
    secret_schema = SecretSchemaClassRegistry.get_class(
        secret_schema=zenml_schema_name
    )

    secret_items["name"] = secret_name
    return secret_schema(**secret_items)
register_secret(self, secret)
Registers a new secret.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
secret |
BaseSecretSchema |
The secret to register. |
required |
Exceptions:
Type | Description |
---|---|
SecretExistsError |
If the secret already exists. |
Source code in zenml/integrations/vault/secrets_manager/vault_secrets_manager.py
def register_secret(self, secret: BaseSecretSchema) -> None:
    """Registers a new secret.

    The secret name is validated, the secret must not exist yet, and its
    values are then written to Vault under the scoped path.

    Args:
        secret: The secret to register.

    Raises:
        SecretExistsError: If the secret already exists.
    """
    self._ensure_client_is_authenticated()
    self.validate_secret_name_or_namespace(secret.name)

    try:
        self.get_secret(secret.name)
    except KeyError:
        # The lookup failed, so the name is free to be registered.
        pass
    else:
        # Raised outside the `try` body so that the `except KeyError`
        # handler above can never swallow it.
        raise SecretExistsError(
            f"A Secret with the name '{secret.name}' already exists."
        )

    secret_path = self._get_scoped_secret_name(secret.name)
    secret_value = secret_to_dict(secret, encode=False)

    self.CLIENT.secrets.kv.v2.create_or_update_secret(
        path=secret_path,
        mount_point=self.mount_point,
        secret=secret_value,
    )

    logger.info("Created secret: %s", f"{secret_path}")
    logger.info("Added value to secret.")
update_secret(self, secret)
Update an existing secret.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
secret |
BaseSecretSchema |
The secret to update. |
required |
Exceptions:
Type | Description |
---|---|
KeyError |
If the secret does not exist. |
Source code in zenml/integrations/vault/secrets_manager/vault_secrets_manager.py
def update_secret(self, secret: BaseSecretSchema) -> None:
    """Update an existing secret.

    The secret name is validated and must already be listed in the current
    scope; its values are then written to Vault under the scoped path.

    Args:
        secret: The secret to update.

    Raises:
        KeyError: If the secret does not exist.
    """
    self._ensure_client_is_authenticated()
    self.validate_secret_name_or_namespace(secret.name)

    # Guard clause: only secrets that are already listed may be updated.
    if secret.name not in self.get_all_secret_keys():
        raise KeyError(
            f"A Secret with the name '{secret.name}'" f" does not exist."
        )

    secret_path = self._get_scoped_secret_name(secret.name)
    payload = secret_to_dict(secret, encode=False)
    kv_store = self.CLIENT.secrets.kv.v2
    kv_store.create_or_update_secret(
        path=secret_path,
        mount_point=self.mount_point,
        secret=payload,
    )

    logger.info("Updated secret: %s", secret_path)
    logger.info("Added value to secret.")
validate_secret_name_or_namespace(name)
classmethod
Validate a secret name or namespace.
For compatibility across secret managers the secret names should contain
only alphanumeric characters and the characters /_+=.@-. The /
character is only used internally to delimit scopes.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
name |
str |
the secret name or namespace |
required |
Exceptions:
Type | Description |
---|---|
ValueError |
if the secret name or namespace is invalid |
Source code in zenml/integrations/vault/secrets_manager/vault_secrets_manager.py
@classmethod
def validate_secret_name_or_namespace(cls, name: str) -> None:
    """Check that a secret name or namespace uses only allowed characters.

    For compatibility across secret managers the secret names should contain
    only alphanumeric characters and the characters /_+=.@-. The `/`
    character is only used internally to delimit scopes, so it is rejected
    here.

    Args:
        name: the secret name or namespace

    Raises:
        ValueError: if the secret name or namespace is invalid
    """
    allowed_pattern = r"[a-zA-Z0-9_+=\.@\-]*"
    if re.fullmatch(allowed_pattern, name) is not None:
        return

    raise ValueError(
        f"Invalid secret name or namespace '{name}'. Must contain "
        f"only alphanumeric characters and the characters _+=.@-."
    )
wandb
special
Initialization for the wandb integration.
The wandb integration currently enables you to use wandb tracking as a convenient way to visualize your experiment runs within the wandb UI.
WandbIntegration (Integration)
Definition of the Weights & Biases (wandb) integration for ZenML.
Source code in zenml/integrations/wandb/__init__.py
class WandbIntegration(Integration):
    """Definition of the Weights & Biases (wandb) integration for ZenML."""

    NAME = WANDB
    REQUIREMENTS = ["wandb>=0.12.12", "Pillow>=9.1.0"]

    @classmethod
    def flavors(cls) -> List[FlavorWrapper]:
        """Declare the stack component flavors for the Weights and Biases integration.

        A single experiment tracker flavor is declared; its implementation
        is referenced by source path.

        Returns:
            List of stack component flavors for this integration.
        """
        return [
            FlavorWrapper(
                name=WANDB_EXPERIMENT_TRACKER_FLAVOR,
                source="zenml.integrations.wandb.experiment_trackers.WandbExperimentTracker",
                type=StackComponentType.EXPERIMENT_TRACKER,
                integration=cls.NAME,
            )
        ]
flavors()
classmethod
Declare the stack component flavors for the Weights and Biases integration.
Returns:
Type | Description |
---|---|
List[zenml.zen_stores.models.flavor_wrapper.FlavorWrapper] |
List of stack component flavors for this integration. |
Source code in zenml/integrations/wandb/__init__.py
@classmethod
def flavors(cls) -> List[FlavorWrapper]:
    """Declare the stack component flavors for the Weights and Biases integration.

    A single experiment tracker flavor is declared; its implementation is
    referenced by source path.

    Returns:
        List of stack component flavors for this integration.
    """
    experiment_tracker_flavor = FlavorWrapper(
        name=WANDB_EXPERIMENT_TRACKER_FLAVOR,
        source="zenml.integrations.wandb.experiment_trackers.WandbExperimentTracker",
        type=StackComponentType.EXPERIMENT_TRACKER,
        integration=cls.NAME,
    )
    return [experiment_tracker_flavor]
experiment_trackers
special
Initialization for the wandb experiment tracker.
wandb_experiment_tracker
Implementation for the wandb experiment tracker.
WandbExperimentTracker (BaseExperimentTracker)
pydantic-model
Stores wandb configuration options.
ZenML should take care of configuring wandb for you, but should you still need access to the configuration inside your step you can do it using a step context:
from zenml.steps import StepContext
@enable_wandb
@step
def my_step(context: StepContext, ...)
context.stack.experiment_tracker # get the tracking_uri etc. from here
Attributes:
Name | Type | Description |
---|---|---|
entity |
Optional[str] |
Name of an existing wandb entity. |
project_name |
Optional[str] |
Name of an existing wandb project to log to. |
api_key |
str |
API key that should be authorized to log to the configured wandb entity and project. |
Source code in zenml/integrations/wandb/experiment_trackers/wandb_experiment_tracker.py
class WandbExperimentTracker(BaseExperimentTracker):
    """Stores wandb configuration options.

    ZenML should take care of configuring wandb for you, but should you still
    need access to the configuration inside your step you can do it using a
    step context:

    ```python
    from zenml.steps import StepContext

    @enable_wandb
    @step
    def my_step(context: StepContext, ...):
        context.stack.experiment_tracker  # get the entity, project name etc. from here
    ```

    Attributes:
        entity: Name of an existing wandb entity.
        project_name: Name of an existing wandb project to log to.
        api_key: API key that should be authorized to log to the configured
            wandb entity and project.
    """

    api_key: str
    entity: Optional[str] = None
    project_name: Optional[str] = None

    # Class Configuration
    FLAVOR: ClassVar[str] = WANDB_EXPERIMENT_TRACKER_FLAVOR

    def prepare_step_run(self) -> None:
        """Sets the wandb api key.

        The key is written to the environment variable named by
        `WANDB_API_KEY`.
        """
        os.environ[WANDB_API_KEY] = self.api_key

    @contextmanager
    def activate_wandb_run(
        self,
        run_name: str,
        tags: Tuple[str, ...] = (),
        settings: Optional[wandb.Settings] = None,
    ) -> Iterator[None]:
        """Activates a wandb run for the duration of this context manager.

        Anything logged to wandb that is run while this context manager is
        active will automatically log to the same wandb run configured by the
        run name passed as an argument to this function.

        Args:
            run_name: Name of the wandb run to create.
            tags: Tags to attach to the wandb run.
            settings: Additional settings for the wandb run.

        Yields:
            None
        """
        try:
            logger.info(
                f"Initializing wandb with project name: {self.project_name}, "
                f"run_name: {run_name}, entity: {self.entity}."
            )
            wandb.init(
                project=self.project_name,
                name=run_name,
                entity=self.entity,
                settings=settings,
                tags=tags,
            )
            yield
        finally:
            # Always runs, including when initialization or the body of the
            # `with` block raised.
            wandb.finish()
activate_wandb_run(self, run_name, tags=(), settings=None)
Activates a wandb run for the duration of this context manager.
Anything logged to wandb that is run while this context manager is active will automatically log to the same wandb run configured by the run name passed as an argument to this function.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
run_name |
str |
Name of the wandb run to create. |
required |
tags |
Tuple[str, ...] |
Tags to attach to the wandb run. |
() |
settings |
Optional[wandb.sdk.wandb_settings.Settings] |
Additional settings for the wandb run. |
None |
Yields:
Type | Description |
---|---|
Iterator[NoneType] |
None |
Source code in zenml/integrations/wandb/experiment_trackers/wandb_experiment_tracker.py
@contextmanager
def activate_wandb_run(
    self,
    run_name: str,
    tags: Tuple[str, ...] = (),
    settings: Optional[wandb.Settings] = None,
) -> Iterator[None]:
    """Keep a wandb run active while the context manager is open.

    The run is initialized with the project and entity configured on this
    tracker plus the given run name, tags and settings. Anything logged to
    wandb while the context is active goes to that run. The run is finished
    when the context exits, whether or not an error was raised.

    Args:
        run_name: Name of the wandb run to create.
        tags: Tags to attach to the wandb run.
        settings: Additional settings for the wandb run.

    Yields:
        None
    """
    try:
        logger.info(
            f"Initializing wandb with project name: {self.project_name}, "
            f"run_name: {run_name}, entity: {self.entity}."
        )
        init_options = {
            "project": self.project_name,
            "name": run_name,
            "entity": self.entity,
            "settings": settings,
            "tags": tags,
        }
        wandb.init(**init_options)
        yield
    finally:
        wandb.finish()
prepare_step_run(self)
Sets the wandb api key.
Source code in zenml/integrations/wandb/experiment_trackers/wandb_experiment_tracker.py
def prepare_step_run(self) -> None:
    """Export the configured wandb API key to the process environment.

    The key is written to the environment variable named by
    `WANDB_API_KEY`.
    """
    wandb_env = {WANDB_API_KEY: self.api_key}
    os.environ.update(wandb_env)
wandb_step_decorator
Implementation for the wandb step decorator.
enable_wandb(_step=None, *, settings=None)
Decorator to enable wandb for a step function.
Apply this decorator to a ZenML pipeline step to enable wandb experiment
tracking. The wandb tracking configuration (project name, experiment name,
entity) will be automatically configured before the step code is executed,
so the step can simply use the wandb
module to log metrics and artifacts,
like so:
@enable_wandb
@step
def tf_evaluator(
x_test: np.ndarray,
y_test: np.ndarray,
model: tf.keras.Model,
) -> float:
_, test_acc = model.evaluate(x_test, y_test, verbose=2)
wandb.log({"val_accuracy": test_acc})
return test_acc
You can also use this decorator with our class-based API like so:
@enable_wandb
class TFEvaluator(BaseStep):
def entrypoint(
self,
x_test: np.ndarray,
y_test: np.ndarray,
model: tf.keras.Model,
) -> float:
...
Every decorated step is tracked in its own wandb run: the wrapped
entrypoint names the run <pipeline_run_id>_<step_name> and tags it with
the pipeline name and the pipeline run ID. The decorator does not take an
experiment_name
argument; use the settings
argument to pass custom wandb settings for the step.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
_step |
Optional[~S] |
The decorated step class. |
None |
settings |
Optional[wandb.sdk.wandb_settings.Settings] |
wandb settings to use for the step. |
None |
Returns:
Type | Description |
---|---|
Union[~S, Callable[[~S], ~S]] |
The inner decorator which enhances the input step class with wandb tracking functionality |
Source code in zenml/integrations/wandb/wandb_step_decorator.py
def enable_wandb(
    _step: Optional[S] = None, *, settings: Optional[wandb.Settings] = None
) -> Union[S, Callable[[S], S]]:
    """Decorator to enable wandb for a step function.

    Apply this decorator to a ZenML pipeline step to enable wandb experiment
    tracking. A wandb run is set up before the step code is executed (see
    `wandb_step_entrypoint`), so the step can simply use the `wandb` module
    to log metrics and artifacts, like so:

    ```python
    @enable_wandb
    @step
    def tf_evaluator(
        x_test: np.ndarray,
        y_test: np.ndarray,
        model: tf.keras.Model,
    ) -> float:
        _, test_acc = model.evaluate(x_test, y_test, verbose=2)
        wandb.log({"val_accuracy": test_acc})
        return test_acc
    ```

    You can also use this decorator with our class-based API like so:

    ```python
    @enable_wandb
    class TFEvaluator(BaseStep):
        def entrypoint(
            self,
            x_test: np.ndarray,
            y_test: np.ndarray,
            model: tf.keras.Model,
        ) -> float:
            ...
    ```

    Every decorated step is tracked in its own wandb run: the wrapped
    entrypoint names the run `<pipeline_run_id>_<step_name>` and tags it with
    the pipeline name and the pipeline run ID. Use the `settings` argument to
    pass custom wandb settings for the step.

    The decorator can be used bare (`@enable_wandb`) or called with keyword
    arguments (`@enable_wandb(settings=...)`).

    Args:
        _step: The decorated step class.
        settings: wandb settings to use for the step.

    Returns:
        The inner decorator which enhances the input step class with wandb
        tracking functionality
    """

    def inner_decorator(_step: S) -> S:
        """Inner decorator for step enable_wandb.

        Args:
            _step: The decorated step class.

        Returns:
            The decorated step class.

        Raises:
            RuntimeError: If the decorator is not being applied to a ZenML step
                decorated function or a BaseStep subclass.
        """
        logger.debug(
            "Applying 'enable_wandb' decorator to step %s", _step.__name__
        )
        if not issubclass(_step, BaseStep):
            raise RuntimeError(
                "The `enable_wandb` decorator can only be applied to a ZenML "
                "`step` decorated function or a BaseStep subclass."
            )
        # Replace the step's entrypoint with a wandb-enabled wrapper.
        source_fn = getattr(_step, STEP_INNER_FUNC_NAME)
        new_entrypoint = wandb_step_entrypoint(
            settings=settings,
        )(source_fn)
        if _step._created_by_functional_api():
            # If the step was created by the functional API, the old entrypoint
            # was a static method -> make sure the new one is as well
            new_entrypoint = staticmethod(new_entrypoint)

        setattr(_step, STEP_INNER_FUNC_NAME, new_entrypoint)
        return _step

    if _step is None:
        # Called with arguments: return the decorator to be applied later.
        return inner_decorator
    else:
        # Used bare: decorate the step class right away.
        return inner_decorator(_step)
wandb_step_entrypoint(settings=None)
Decorator for a step entrypoint to enable wandb.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
settings |
Optional[wandb.sdk.wandb_settings.Settings] |
wandb settings to use for the step. |
None |
Returns:
Type | Description |
---|---|
Callable[[~F], ~F] |
the input function enhanced with wandb tracking functionality |
Source code in zenml/integrations/wandb/wandb_step_decorator.py
def wandb_step_entrypoint(
    settings: Optional[wandb.Settings] = None,
) -> Callable[[F], F]:
    """Build a decorator that runs a step entrypoint inside a wandb run.

    Args:
        settings: wandb settings to use for the step.

    Returns:
        A decorator that returns the input function enhanced with wandb
        tracking functionality
    """

    def inner_decorator(func: F) -> F:
        """Wrap a single step entrypoint.

        Args:
            func: The decorated function.

        Returns:
            the input function enhanced with wandb tracking functionality
        """
        logger.debug(
            "Applying 'wandb_step_entrypoint' decorator to step entrypoint %s",
            func.__name__,
        )

        @functools.wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:  # noqa
            """Run the entrypoint while a wandb run is active.

            Args:
                *args: positional arguments to the decorated function.
                **kwargs: keyword arguments to the decorated function.

            Returns:
                The return value of the decorated function.

            Raises:
                ValueError: if the experiment tracker of the active stack
                    is not a wandb experiment tracker.
            """
            logger.debug(
                "Setting up wandb backend before running step entrypoint %s",
                func.__name__,
            )
            # Run name and tags are derived from the current step environment.
            env = Environment().step_environment
            run_name = f"{env.pipeline_run_id}_{env.step_name}"
            tags = (env.pipeline_name, env.pipeline_run_id)

            repo = Repository(  # type: ignore[call-arg]
                skip_repository_check=True
            )
            tracker = repo.active_stack.experiment_tracker

            if not isinstance(tracker, WandbExperimentTracker):
                raise ValueError(
                    "The active stack needs to have a wandb experiment tracker "
                    "component registered to be able to track experiments "
                    "using wandb. You can create a new stack with a wandb "
                    "experiment tracker component or update your existing "
                    "stack to add this component, e.g.:\n\n"
                    " 'zenml experiment-tracker register wandb_tracker "
                    "--type=wandb --entity=<WANDB_ENTITY> --project_name="
                    "<WANDB_PROJECT_NAME> --api_key=<WANDB_API_KEY>'\n"
                    " 'zenml stack register stack-name -e wandb_tracker ...'\n"
                )

            active_run = tracker.activate_wandb_run(
                run_name=run_name,
                tags=tags,
                settings=settings,
            )
            with active_run:
                result = func(*args, **kwargs)
                return result

        return cast(F, wrapper)

    return inner_decorator
whylogs
special
Initialization of the whylogs integration.
WhylogsIntegration (Integration)
Definition of whylogs integration for ZenML.
Source code in zenml/integrations/whylogs/__init__.py
class WhylogsIntegration(Integration):
    """Definition of [whylogs](https://github.com/whylabs/whylogs) integration for ZenML."""

    NAME = WHYLOGS
    REQUIREMENTS = ["whylogs[viz]~=1.0.5", "whylogs[whylabs]~=1.0.5"]

    @classmethod
    def activate(cls) -> None:
        """Activates the integration.

        Imports the `materializers`, `secret_schemas` and `visualizers`
        sub-modules of the whylogs integration.
        """
        # NOTE(review): the imported names are unused (hence `# noqa`);
        # presumably the imports exist for their import-time side effects.
        from zenml.integrations.whylogs import materializers  # noqa
        from zenml.integrations.whylogs import secret_schemas  # noqa
        from zenml.integrations.whylogs import visualizers  # noqa

    @classmethod
    def flavors(cls) -> List[FlavorWrapper]:
        """Declare the stack component flavors for the whylogs integration.

        A single data validator flavor is declared; its implementation is
        referenced by source path.

        Returns:
            List of stack component flavors for this integration.
        """
        return [
            FlavorWrapper(
                name=WHYLOGS_DATA_VALIDATOR_FLAVOR,
                source="zenml.integrations.whylogs.data_validators.WhylogsDataValidator",
                type=StackComponentType.DATA_VALIDATOR,
                integration=cls.NAME,
            ),
        ]
activate()
classmethod
Activates the integration.
Source code in zenml/integrations/whylogs/__init__.py
@classmethod
def activate(cls) -> None:
"""Activates the integration.

Imports the `materializers`, `secret_schemas` and `visualizers`
sub-modules of the whylogs integration.
"""
# NOTE(review): the imported names are unused (hence `# noqa`); presumably
# the imports exist for their import-time side effects. Confirm before
# reordering or removing any of them.
from zenml.integrations.whylogs import materializers # noqa
from zenml.integrations.whylogs import secret_schemas # noqa
from zenml.integrations.whylogs import visualizers # noqa
flavors()
classmethod
Declare the stack component flavors for the whylogs integration.
Returns:
Type | Description |
---|---|
List[zenml.zen_stores.models.flavor_wrapper.FlavorWrapper] |
List of stack component flavors for this integration. |
Source code in zenml/integrations/whylogs/__init__.py
@classmethod
def flavors(cls) -> List[FlavorWrapper]:
    """Declare the stack component flavors for the whylogs integration.

    A single data validator flavor is declared; its implementation is
    referenced by source path.

    Returns:
        List of stack component flavors for this integration.
    """
    return [
        FlavorWrapper(
            name=WHYLOGS_DATA_VALIDATOR_FLAVOR,
            source="zenml.integrations.whylogs.data_validators.WhylogsDataValidator",
            type=StackComponentType.DATA_VALIDATOR,
            integration=cls.NAME,
        ),
    ]
data_validators
special
Initialization of the whylogs data validator for ZenML.
whylogs_data_validator
Implementation of the whylogs data validator.
WhylogsDataValidator (BaseDataValidator, AuthenticationMixin)
pydantic-model
Whylogs data validator stack component.
Attributes:
Name | Type | Description |
---|---|---|
authentication_secret |
Optional[str] |
Optional ZenML secret with Whylabs credentials. If configured, all the data profiles returned by all pipeline steps will automatically be uploaded to Whylabs in addition to being stored in the ZenML Artifact Store. |
Source code in zenml/integrations/whylogs/data_validators/whylogs_data_validator.py
class WhylogsDataValidator(BaseDataValidator, AuthenticationMixin):
    """Whylogs data validator stack component.

    Attributes:
        authentication_secret: Optional ZenML secret with Whylabs credentials.
            If configured, all the data profiles returned by all pipeline steps
            will automatically be uploaded to Whylabs in addition to being
            stored in the ZenML Artifact Store.
    """

    # Class Configuration
    FLAVOR: ClassVar[str] = WHYLOGS_DATA_VALIDATOR_FLAVOR
    NAME: ClassVar[str] = "whylogs"

    def data_profiling(
        self,
        dataset: pd.DataFrame,
        comparison_dataset: Optional[pd.DataFrame] = None,
        profile_list: Optional[Sequence[str]] = None,
        dataset_timestamp: Optional[datetime.datetime] = None,
        **kwargs: Any,
    ) -> DatasetProfileView:
        """Analyze a dataset and generate a data profile with whylogs.

        Args:
            dataset: Target dataset to be profiled.
            comparison_dataset: Optional dataset to be used for data profiles
                that require a baseline for comparison (e.g. data drift
                profiles). Not used by this implementation.
            profile_list: Optional list identifying the categories of whylogs
                data profiles to be generated (unused).
            dataset_timestamp: timestamp to associate with the generated
                dataset profile (Optional). The current time is used if not
                supplied.
            **kwargs: Extra keyword arguments (unused).

        Returns:
            A whylogs profile view object.
        """
        results = why.log(pandas=dataset)
        profile = results.profile()
        # Fall back to the current (naive) UTC time when no timestamp is given.
        dataset_timestamp = dataset_timestamp or datetime.datetime.utcnow()
        profile.set_dataset_timestamp(dataset_timestamp=dataset_timestamp)
        return profile.view()

    def upload_profile_view(
        self, profile_view: DatasetProfileView, dataset_id: Optional[str] = None
    ) -> None:
        """Upload a whylogs data profile view to Whylabs, if configured to do so.

        Nothing is uploaded when no authentication secret is available.

        Args:
            profile_view: Whylogs profile view to upload.
            dataset_id: Optional dataset identifier to use for the uploaded
                data profile. If omitted, a dataset identifier will be retrieved
                using other means, in order:
                    * the default dataset identifier configured in the Data
                    Validator secret
                    * a dataset ID will be generated automatically based on the
                    current pipeline/step information.

        Raises:
            ValueError: If the dataset ID was not provided and could not be
                retrieved or inferred from other sources.
        """
        secret = self.get_authentication_secret(
            expected_schema_type=WhylabsSecretSchema
        )
        if not secret:
            return

        dataset_id = dataset_id or secret.whylabs_default_dataset_id

        if not dataset_id:
            # use the current pipeline name and the step name to generate a
            # unique dataset name
            try:
                # get pipeline name and step name
                step_env = cast(
                    StepEnvironment, Environment()[STEP_ENVIRONMENT_NAME]
                )
                dataset_id = f"{step_env.pipeline_name}_{step_env.step_name}"
            except KeyError as e:
                # Chain explicitly so the failed lookup is kept as the direct
                # cause of the ValueError.
                raise ValueError(
                    "A dataset ID was not specified and could not be "
                    "generated from the current pipeline and step name."
                ) from e

        # Instantiate WhyLabs Writer
        writer = WhyLabsWriter(
            org_id=secret.whylabs_default_org_id,
            api_key=secret.whylabs_api_key,
            dataset_id=dataset_id,
        )

        # pass a profile view to the writer's write method
        writer.write(profile=profile_view)
data_profiling(self, dataset, comparison_dataset=None, profile_list=None, dataset_timestamp=None, **kwargs)
Analyze a dataset and generate a data profile with whylogs.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
dataset |
DataFrame |
Target dataset to be profiled. |
required |
comparison_dataset |
Optional[pandas.core.frame.DataFrame] |
Optional dataset to be used for data profiles that require a baseline for comparison (e.g. data drift profiles). |
None |
profile_list |
Optional[Sequence[str]] |
Optional list identifying the categories of whylogs data profiles to be generated (unused). |
None |
dataset_timestamp |
Optional[datetime.datetime] |
timestamp to associate with the generated dataset profile (Optional). The current time is used if not supplied. |
None |
**kwargs |
Any |
Extra keyword arguments (unused). |
{} |
Returns:
Type | Description |
---|---|
DatasetProfileView |
A whylogs profile view object. |
Source code in zenml/integrations/whylogs/data_validators/whylogs_data_validator.py
def data_profiling(
    self,
    dataset: pd.DataFrame,
    comparison_dataset: Optional[pd.DataFrame] = None,
    profile_list: Optional[Sequence[str]] = None,
    dataset_timestamp: Optional[datetime.datetime] = None,
    **kwargs: Any,
) -> DatasetProfileView:
    """Profile a dataset with whylogs and return the profile view.

    Args:
        dataset: Target dataset to be profiled.
        comparison_dataset: Optional dataset to be used for data profiles
            that require a baseline for comparison (e.g. data drift
            profiles). Not used by this implementation.
        profile_list: Optional list identifying the categories of whylogs
            data profiles to be generated (unused).
        dataset_timestamp: timestamp to associate with the generated
            dataset profile (Optional). The current time is used if not
            supplied.
        **kwargs: Extra keyword arguments (unused).

    Returns:
        A whylogs profile view object.
    """
    profile = why.log(pandas=dataset).profile()

    # Fall back to the current (naive) UTC time when no timestamp is given.
    if not dataset_timestamp:
        dataset_timestamp = datetime.datetime.utcnow()
    profile.set_dataset_timestamp(dataset_timestamp=dataset_timestamp)

    profile_view = profile.view()
    return profile_view
upload_profile_view(self, profile_view, dataset_id=None)
Upload a whylogs data profile view to Whylabs, if configured to do so.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
profile_view |
DatasetProfileView |
Whylogs profile view to upload. |
required |
dataset_id |
Optional[str] |
Optional dataset identifier to use for the uploaded data profile. If omitted, a dataset identifier will be retrieved using other means, in order: * the default dataset identifier configured in the Data Validator secret * a dataset ID will be generated automatically based on the current pipeline/step information. |
None |
Exceptions:
Type | Description |
---|---|
ValueError |
If the dataset ID was not provided and could not be retrieved or inferred from other sources. |
Source code in zenml/integrations/whylogs/data_validators/whylogs_data_validator.py
def upload_profile_view(
    self, profile_view: DatasetProfileView, dataset_id: Optional[str] = None
) -> None:
    """Upload a whylogs data profile view to Whylabs, if configured to do so.

    Nothing is uploaded when no authentication secret is available.

    Args:
        profile_view: Whylogs profile view to upload.
        dataset_id: Optional dataset identifier to use for the uploaded
            data profile. If omitted, a dataset identifier will be retrieved
            using other means, in order:
                * the default dataset identifier configured in the Data
                Validator secret
                * a dataset ID will be generated automatically based on the
                current pipeline/step information.

    Raises:
        ValueError: If the dataset ID was not provided and could not be
            retrieved or inferred from other sources.
    """
    secret = self.get_authentication_secret(
        expected_schema_type=WhylabsSecretSchema
    )
    if not secret:
        return

    dataset_id = dataset_id or secret.whylabs_default_dataset_id

    if not dataset_id:
        # use the current pipeline name and the step name to generate a
        # unique dataset name
        try:
            # get pipeline name and step name
            step_env = cast(
                StepEnvironment, Environment()[STEP_ENVIRONMENT_NAME]
            )
            dataset_id = f"{step_env.pipeline_name}_{step_env.step_name}"
        except KeyError as e:
            # Chain explicitly so the failed lookup is kept as the direct
            # cause of the ValueError.
            raise ValueError(
                "A dataset ID was not specified and could not be "
                "generated from the current pipeline and step name."
            ) from e

    # Instantiate WhyLabs Writer
    writer = WhyLabsWriter(
        org_id=secret.whylabs_default_org_id,
        api_key=secret.whylabs_api_key,
        dataset_id=dataset_id,
    )

    # pass a profile view to the writer's write method
    writer.write(profile=profile_view)
materializers
special
Initialization of the whylogs materializer.
whylogs_materializer
Implementation of the whylogs materializer.
WhylogsMaterializer (BaseMaterializer)
Materializer to read/write whylogs dataset profile views.
Source code in zenml/integrations/whylogs/materializers/whylogs_materializer.py
class WhylogsMaterializer(BaseMaterializer):
    """Materializer to read/write whylogs dataset profile views."""

    ASSOCIATED_TYPES = (DatasetProfileView,)
    ASSOCIATED_ARTIFACT_TYPES = (StatisticsArtifact,)

    def handle_input(self, data_type: Type[Any]) -> DatasetProfileView:
        """Reads and returns a whylogs dataset profile view.

        The profile file is copied from the artifact URI to a temporary
        directory and read from there.

        Args:
            data_type: The type of the data to read.

        Returns:
            A loaded whylogs dataset profile view object.
        """
        super().handle_input(data_type)
        filepath = os.path.join(self.artifact.uri, PROFILE_FILENAME)

        # Create a temporary folder
        temp_dir = tempfile.mkdtemp(prefix="zenml-temp-")
        try:
            temp_file = os.path.join(str(temp_dir), PROFILE_FILENAME)

            # Copy from artifact store to temporary file
            fileio.copy(filepath, temp_file)
            profile_view = DatasetProfileView.read(temp_file)
        finally:
            # Remove the temporary folder even if the copy or read failed
            fileio.rmtree(temp_dir)

        return profile_view

    def handle_return(self, profile_view: DatasetProfileView) -> None:
        """Writes a whylogs dataset profile view.

        The profile is written to a temporary directory and copied to the
        artifact URI. If Whylabs logging is enabled through the environment,
        the profile is also uploaded using the active data validator.

        Args:
            profile_view: A whylogs dataset profile view object.
        """
        super().handle_return(profile_view)
        filepath = os.path.join(self.artifact.uri, PROFILE_FILENAME)

        # Create a temporary folder
        temp_dir = tempfile.mkdtemp(prefix="zenml-temp-")
        try:
            temp_file = os.path.join(str(temp_dir), PROFILE_FILENAME)

            profile_view.write(temp_file)

            # Copy it into artifact store
            fileio.copy(temp_file, filepath)
        finally:
            # Remove the temporary folder even if the write or copy failed
            fileio.rmtree(temp_dir)

        # Use the data validator to upload the profile view to Whylabs,
        # if configured to do so. This logic is only enabled if the pipeline
        # step was decorated with the `enable_whylabs` decorator
        whylabs_enabled = os.environ.get(WHYLABS_LOGGING_ENABLED_ENV)
        if not whylabs_enabled:
            return
        dataset_id = os.environ.get(WHYLABS_DATASET_ID_ENV)
        data_validator = cast(
            WhylogsDataValidator,
            WhylogsDataValidator.get_active_data_validator(),
        )
        data_validator.upload_profile_view(profile_view, dataset_id=dataset_id)
handle_input(self, data_type)
Reads and returns a whylogs dataset profile view.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
data_type |
Type[Any] |
The type of the data to read. |
required |
Returns:
Type | Description |
---|---|
DatasetProfileView |
A loaded whylogs dataset profile view object. |
Source code in zenml/integrations/whylogs/materializers/whylogs_materializer.py
def handle_input(self, data_type: Type[Any]) -> DatasetProfileView:
    """Reads and returns a whylogs dataset profile view.

    Args:
        data_type: The type of the data to read.

    Returns:
        A loaded whylogs dataset profile view object.
    """
    super().handle_input(data_type)
    filepath = os.path.join(self.artifact.uri, PROFILE_FILENAME)

    # Stage the profile in a local temporary folder before reading it.
    temp_dir = tempfile.mkdtemp(prefix="zenml-temp-")
    try:
        temp_file = os.path.join(temp_dir, PROFILE_FILENAME)
        # Copy from artifact store to temporary file
        fileio.copy(filepath, temp_file)
        return DatasetProfileView.read(temp_file)
    finally:
        # Always remove the staging folder, even if copy/read failed.
        fileio.rmtree(temp_dir)
handle_return(self, profile_view)
Writes a whylogs dataset profile view.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
profile_view |
DatasetProfileView |
A whylogs dataset profile view object. |
required |
Source code in zenml/integrations/whylogs/materializers/whylogs_materializer.py
def handle_return(self, profile_view: DatasetProfileView) -> None:
    """Writes a whylogs dataset profile view.

    After the profile is stored, it is also uploaded through the active
    whylogs data validator when the Whylabs logging environment variable is
    set.

    Args:
        profile_view: A whylogs dataset profile view object.
    """
    super().handle_return(profile_view)
    filepath = os.path.join(self.artifact.uri, PROFILE_FILENAME)

    # Write the profile locally first, then copy it to the artifact store.
    temp_dir = tempfile.mkdtemp(prefix="zenml-temp-")
    try:
        temp_file = os.path.join(temp_dir, PROFILE_FILENAME)
        profile_view.write(temp_file)
        # Copy it into artifact store
        fileio.copy(temp_file, filepath)
    finally:
        # Always remove the staging folder, even if write/copy failed.
        fileio.rmtree(temp_dir)

    # Use the data validator to upload the profile view to Whylabs,
    # if configured to do so. This logic is only enabled if the pipeline
    # step was decorated with the `enable_whylabs` decorator
    whylabs_enabled = os.environ.get(WHYLABS_LOGGING_ENABLED_ENV)
    if not whylabs_enabled:
        return
    dataset_id = os.environ.get(WHYLABS_DATASET_ID_ENV)
    data_validator = cast(
        WhylogsDataValidator,
        WhylogsDataValidator.get_active_data_validator(),
    )
    data_validator.upload_profile_view(profile_view, dataset_id=dataset_id)
secret_schemas
special
Initialization for the Whylabs secret schema.
This schema can be used to configure a ZenML secret that authenticates ZenML with the Whylabs platform, so that all whylogs data profiles generated by pipeline steps are logged automatically.
whylabs_secret_schema
Implementation for Whylabs secret schemas.
WhylabsSecretSchema (BaseSecretSchema)
pydantic-model
Whylabs credentials.
Attributes:
Name | Type | Description |
---|---|---|
whylabs_default_org_id |
str |
the Whylabs organization ID. |
whylabs_api_key |
str |
Whylabs API key. |
whylabs_default_dataset_id |
Optional[str] |
default Whylabs dataset ID to use when logging data profiles. |
Source code in zenml/integrations/whylogs/secret_schemas/whylabs_secret_schema.py
class WhylabsSecretSchema(BaseSecretSchema):
    """Whylabs credentials.

    Attributes:
        whylabs_default_org_id: the Whylabs organization ID.
        whylabs_api_key: Whylabs API key.
        whylabs_default_dataset_id: default Whylabs dataset ID to use when
            logging data profiles. Optional; `None` when not supplied.
    """

    # Schema type identifier for this secret schema.
    TYPE: ClassVar[str] = WHYLABS_SECRET_SCHEMA_TYPE
    whylabs_default_org_id: str
    whylabs_api_key: str
    whylabs_default_dataset_id: Optional[str] = None
steps
special
Initialization of the whylogs steps.
whylogs_profiler
Implementation of the whylogs profiler step.
WhylogsProfilerConfig (BaseAnalyzerConfig)
pydantic-model
Config class for the WhylogsProfiler step.
Attributes:
Name | Type | Description |
---|---|---|
dataset_timestamp |
Optional[datetime.datetime] |
timestamp to associate with the generated dataset profile (Optional). The current time is used if not supplied. |
Source code in zenml/integrations/whylogs/steps/whylogs_profiler.py
class WhylogsProfilerConfig(BaseAnalyzerConfig):
    """Config class for the WhylogsProfiler step.

    Attributes:
        dataset_timestamp: timestamp to associate with the generated
            dataset profile (Optional). The current time is used if not
            supplied.
    """

    # Explicit default so the field is optional regardless of how the
    # underlying pydantic version treats `Optional` fields without a default.
    dataset_timestamp: Optional[datetime.datetime] = None
WhylogsProfilerStep (BaseAnalyzerStep)
Generates a whylogs data profile from a given pd.DataFrame.
Source code in zenml/integrations/whylogs/steps/whylogs_profiler.py
class WhylogsProfilerStep(BaseAnalyzerStep):
    """Analyzer step that produces a whylogs profile for a pd.DataFrame."""

    @staticmethod
    def entrypoint(  # type: ignore[override]
        dataset: pd.DataFrame,
        config: WhylogsProfilerConfig,
    ) -> DatasetProfileView:
        """Profile the input dataset with the active whylogs data validator.

        Args:
            dataset: the pd.DataFrame to profile
            config: step configuration carrying the optional dataset
                timestamp

        Returns:
            The whylogs profile view holding the statistics computed for the
            input dataset.
        """
        active_validator = WhylogsDataValidator.get_active_data_validator()
        whylogs_validator = cast(WhylogsDataValidator, active_validator)
        profile_view = whylogs_validator.data_profiling(
            dataset, dataset_timestamp=config.dataset_timestamp
        )
        return profile_view
CONFIG_CLASS (BaseAnalyzerConfig)
pydantic-model
Config class for the WhylogsProfiler step.
Attributes:
Name | Type | Description |
---|---|---|
dataset_timestamp |
Optional[datetime.datetime] |
timestamp to associate with the generated dataset profile (Optional). The current time is used if not supplied. |
Source code in zenml/integrations/whylogs/steps/whylogs_profiler.py
class WhylogsProfilerConfig(BaseAnalyzerConfig):
    """Config class for the WhylogsProfiler step.

    Attributes:
        dataset_timestamp: timestamp to associate with the generated
            dataset profile (Optional). The current time is used if not
            supplied.
    """

    # Explicit default so the field is optional regardless of how the
    # underlying pydantic version treats `Optional` fields without a default.
    dataset_timestamp: Optional[datetime.datetime] = None
entrypoint(dataset, config)
staticmethod
Main entrypoint function for the whylogs profiler.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
dataset |
DataFrame |
pd.DataFrame, the given dataset |
required |
config |
WhylogsProfilerConfig |
the configuration of the step |
required |
Returns:
Type | Description |
---|---|
DatasetProfileView |
whylogs profile with statistics generated for the input dataset |
Source code in zenml/integrations/whylogs/steps/whylogs_profiler.py
@staticmethod
def entrypoint(  # type: ignore[override]
    dataset: pd.DataFrame,
    config: WhylogsProfilerConfig,
) -> DatasetProfileView:
    """Profile the input dataset with the active whylogs data validator.

    Args:
        dataset: the pd.DataFrame to profile
        config: step configuration carrying the optional dataset timestamp

    Returns:
        The whylogs profile view holding the statistics computed for the
        input dataset.
    """
    active_validator = WhylogsDataValidator.get_active_data_validator()
    whylogs_validator = cast(WhylogsDataValidator, active_validator)
    profile_view = whylogs_validator.data_profiling(
        dataset, dataset_timestamp=config.dataset_timestamp
    )
    return profile_view
whylogs_profiler_step(step_name, config, dataset_id=None)
Shortcut function to create a new instance of the WhylogsProfilerStep step.
The returned WhylogsProfilerStep can be used in a pipeline to generate a whylogs DatasetProfileView from a given pd.DataFrame and save it as an artifact.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
step_name |
str |
The name of the step |
required |
config |
WhylogsProfilerConfig |
The step configuration |
required |
dataset_id |
Optional[str] |
Optional dataset ID to use to upload the profile to Whylabs. |
None |
Returns:
Type | Description |
---|---|
BaseStep |
a WhylogsProfilerStep step instance |
Source code in zenml/integrations/whylogs/steps/whylogs_profiler.py
def whylogs_profiler_step(
    step_name: str,
    config: WhylogsProfilerConfig,
    dataset_id: Optional[str] = None,
) -> BaseStep:
    """Build a ready-to-use, Whylabs-enabled WhylogsProfilerStep instance.

    The WhylogsProfilerStep class is cloned under the given name, wrapped
    with the `enable_whylabs` decorator and instantiated with the supplied
    configuration. The resulting step can be used in a pipeline to generate
    a whylogs DatasetProfileView from a pd.DataFrame and save it as an
    artifact.

    Args:
        step_name: The name of the step
        config: The step configuration
        dataset_id: Optional dataset ID to use to upload the profile to
            Whylabs.

    Returns:
        a WhylogsProfilerStep step instance
    """
    whylabs_decorator = enable_whylabs(dataset_id=dataset_id)
    named_step_class = clone_step(WhylogsProfilerStep, step_name)
    whylabs_step_class = whylabs_decorator(named_step_class)
    return whylabs_step_class(config=config)
visualizers
special
Initialization of the whylogs visualizer.
whylogs_visualizer
Implementation of the whylogs visualizer step.
WhylogsVisualizer (BaseStepVisualizer)
The implementation of a Whylogs Visualizer.
Source code in zenml/integrations/whylogs/visualizers/whylogs_visualizer.py
class WhylogsVisualizer(BaseStepVisualizer):
    """The implementation of a Whylogs Visualizer."""

    def visualize(
        self,
        object: StepView,
        reference_step_view: Optional[StepView] = None,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        """Visualize whylogs dataset profiles present as outputs in the step view.

        If the step view has no whylogs dataset profile output, a warning is
        logged and nothing is rendered.

        Args:
            object: StepView fetched from run.get_step().
            reference_step_view: second StepView fetched from run.get_step() to
                use as a reference to visualize data drift
            *args: additional positional arguments to pass to the visualize
                method
            **kwargs: additional keyword arguments to pass to the visualize
                method
        """

        def extract_profile(
            step_view: StepView,
        ) -> Optional[DatasetProfileView]:
            """Extract a whylogs DatasetProfileView from a step view.

            Args:
                step_view: a step view

            Returns:
                A whylogs DatasetProfileView object loaded from the step view,
                if one could be found, otherwise None.
            """
            whylogs_artifact_datatype = (
                f"{DatasetProfileView.__module__}.{DatasetProfileView.__name__}"
            )
            for artifact_view in step_view.outputs.values():
                # filter out anything but whylogs dataset profile artifacts
                if artifact_view.data_type == whylogs_artifact_datatype:
                    return cast(DatasetProfileView, artifact_view.read())
            return None

        profile = extract_profile(object)
        if profile is None:
            # Without a target profile there is nothing to render; passing
            # None further down would only fail with an obscure error.
            logger.warning(
                "No whylogs dataset profile was found among the outputs of "
                "the given step view; nothing to visualize."
            )
            return

        reference_profile: Optional[DatasetProfileView] = None
        if reference_step_view:
            reference_profile = extract_profile(reference_step_view)
        self.visualize_profile(profile, reference_profile)

    def visualize_profile(
        self,
        profile: DatasetProfileView,
        reference_profile: Optional[DatasetProfileView] = None,
    ) -> None:
        """Generate a visualization of one or two whylogs dataset profiles.

        Inside a notebook the report and per-column histograms are displayed
        inline; otherwise the report is written to a temporary HTML file that
        is opened in the web browser.

        Args:
            profile: whylogs DatasetProfileView to visualize
            reference_profile: second optional DatasetProfileView to use to
                generate a data drift visualization
        """
        # currently, whylogs doesn't support visualizing a single profile, so
        # we trick it by using the same profile twice, both as reference and
        # target, in a drift report
        reference_profile = reference_profile or profile
        visualization = NotebookProfileVisualizer()
        visualization.set_profiles(
            target_profile_view=profile,
            reference_profile_view=reference_profile,
        )
        rendered_html = visualization.summary_drift_report()

        if Environment.in_notebook():
            from IPython.core.display import display

            display(rendered_html)
            for column in sorted(profile.get_columns().keys()):
                display(visualization.double_histogram(feature_name=column))
        else:
            from pathlib import Path

            logger.warning(
                "The magic functions are only usable in a Jupyter notebook."
            )
            with tempfile.NamedTemporaryFile(
                mode="w", delete=False, suffix=".html", encoding="utf-8"
            ) as f:
                f.write(rendered_html.data)
            # The file is closed (and therefore flushed) at this point. Build
            # a well-formed file URI instead of concatenating the raw path,
            # which breaks on POSIX (extra slash) and Windows (backslashes).
            url = Path(f.name).resolve().as_uri()
            logger.info("Opening %s in a new browser..", f.name)
            webbrowser.open(url, new=2)
visualize(self, object, reference_step_view=None, *args, **kwargs)
Visualize whylogs dataset profiles present as outputs in the step view.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
object |
StepView |
StepView fetched from run.get_step(). |
required |
reference_step_view |
Optional[zenml.post_execution.step.StepView] |
second StepView fetched from run.get_step() to use as a reference to visualize data drift |
None |
*args |
Any |
additional positional arguments to pass to the visualize method |
() |
**kwargs |
Any |
additional keyword arguments to pass to the visualize method |
{} |
Source code in zenml/integrations/whylogs/visualizers/whylogs_visualizer.py
def visualize(
    self,
    object: StepView,
    reference_step_view: Optional[StepView] = None,
    *args: Any,
    **kwargs: Any,
) -> None:
    """Visualize whylogs dataset profiles present as outputs in the step view.

    If the step view has no whylogs dataset profile output, a warning is
    logged and nothing is rendered.

    Args:
        object: StepView fetched from run.get_step().
        reference_step_view: second StepView fetched from run.get_step() to
            use as a reference to visualize data drift
        *args: additional positional arguments to pass to the visualize
            method
        **kwargs: additional keyword arguments to pass to the visualize
            method
    """

    def extract_profile(
        step_view: StepView,
    ) -> Optional[DatasetProfileView]:
        """Extract a whylogs DatasetProfileView from a step view.

        Args:
            step_view: a step view

        Returns:
            A whylogs DatasetProfileView object loaded from the step view,
            if one could be found, otherwise None.
        """
        whylogs_artifact_datatype = (
            f"{DatasetProfileView.__module__}.{DatasetProfileView.__name__}"
        )
        for artifact_view in step_view.outputs.values():
            # filter out anything but whylogs dataset profile artifacts
            if artifact_view.data_type == whylogs_artifact_datatype:
                return cast(DatasetProfileView, artifact_view.read())
        return None

    profile = extract_profile(object)
    if profile is None:
        # Without a target profile there is nothing to render; passing None
        # further down would only fail with an obscure error.
        logger.warning(
            "No whylogs dataset profile was found among the outputs of "
            "the given step view; nothing to visualize."
        )
        return

    reference_profile: Optional[DatasetProfileView] = None
    if reference_step_view:
        reference_profile = extract_profile(reference_step_view)
    self.visualize_profile(profile, reference_profile)
visualize_profile(self, profile, reference_profile=None)
Generate a visualization of one or two whylogs dataset profiles.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
profile |
DatasetProfileView |
whylogs DatasetProfileView to visualize |
required |
reference_profile |
Optional[whylogs.core.view.dataset_profile_view.DatasetProfileView] |
second optional DatasetProfileView to use to generate a data drift visualization |
None |
Source code in zenml/integrations/whylogs/visualizers/whylogs_visualizer.py
def visualize_profile(
    self,
    profile: DatasetProfileView,
    reference_profile: Optional[DatasetProfileView] = None,
) -> None:
    """Generate a visualization of one or two whylogs dataset profiles.

    Inside a notebook the report and per-column histograms are displayed
    inline; otherwise the report is written to a temporary HTML file that is
    opened in the web browser.

    Args:
        profile: whylogs DatasetProfileView to visualize
        reference_profile: second optional DatasetProfileView to use to
            generate a data drift visualization
    """
    # currently, whylogs doesn't support visualizing a single profile, so
    # we trick it by using the same profile twice, both as reference and
    # target, in a drift report
    reference_profile = reference_profile or profile
    visualization = NotebookProfileVisualizer()
    visualization.set_profiles(
        target_profile_view=profile,
        reference_profile_view=reference_profile,
    )
    rendered_html = visualization.summary_drift_report()

    if Environment.in_notebook():
        from IPython.core.display import display

        display(rendered_html)
        for column in sorted(profile.get_columns().keys()):
            display(visualization.double_histogram(feature_name=column))
    else:
        from pathlib import Path

        logger.warning(
            "The magic functions are only usable in a Jupyter notebook."
        )
        with tempfile.NamedTemporaryFile(
            mode="w", delete=False, suffix=".html", encoding="utf-8"
        ) as f:
            f.write(rendered_html.data)
        # The file is closed (and therefore flushed) at this point. Build a
        # well-formed file URI instead of concatenating the raw path, which
        # breaks on POSIX (extra slash) and Windows (backslashes).
        url = Path(f.name).resolve().as_uri()
        logger.info("Opening %s in a new browser..", f.name)
        webbrowser.open(url, new=2)
whylabs_step_decorator
Implementation of the Whylabs step decorator.
enable_whylabs(_step=None, *, dataset_id=None)
Decorator to enable Whylabs profiling for a step function.
Apply this decorator to a ZenML pipeline step to enable Whylabs profiling.
Note that you also need to have a whylogs Data Validator part of your active stack with the Whylabs credentials configured for this to have effect.
All the whylogs data profile views returned by the step will automatically be uploaded to the Whylabs platform if the active whylogs Data Validator component is configured with Whylabs credentials, e.g.:
import whylogs as why
from whylogs.core import DatasetProfileView
from zenml.integrations.whylogs.whylabs_step_decorator import enable_whylabs
@enable_whylabs(dataset_id="my_model")
@step
def data_loader() -> Output(data=pd.DataFrame, profile=DatasetProfileView,):
...
data = pd.DataFrame(...)
results = why.log(pandas=data)
profile = results.profile()
...
return data, profile.view()
You can also use this decorator with our class-based API like so:
@enable_whylabs(dataset_id="my_model")
class DataLoader(BaseStep):
def entrypoint(self) -> Output(data=pd.DataFrame, profile=DatasetProfileView,):
...
Parameters:
Name | Type | Description | Default |
---|---|---|---|
_step |
Optional[~S] |
The decorated step class. |
None |
dataset_id |
Optional[str] |
Optional dataset ID to use for the uploaded profile(s) |
None |
Returns:
Type | Description |
---|---|
Union[~S, Callable[[~S], ~S]] |
the inner decorator which enhances the input step class with Whylabs profiling functionality |
Source code in zenml/integrations/whylogs/whylabs_step_decorator.py
def enable_whylabs(
    _step: Optional[S] = None,
    *,
    dataset_id: Optional[str] = None,
) -> Union[S, Callable[[S], S]]:
    """Decorator to enable Whylabs profiling for a step function.

    The entrypoint of the decorated step class is replaced by a version
    wrapped with `whylabs_entrypoint`. The decorator can be applied bare
    (`@enable_whylabs`) or with arguments
    (`@enable_whylabs(dataset_id=...)`).

    Note that you also need to have a whylogs Data Validator part of your
    active stack with the Whylabs credentials configured for this to have
    effect.

    All the whylogs data profile views returned by the step will
    automatically be uploaded to the Whylabs platform if the active whylogs
    Data Validator component is configured with Whylabs credentials, e.g.:

    ```python
    import whylogs as why
    from whylogs.core import DatasetProfileView
    from zenml.integrations.whylogs.whylabs_step_decorator import enable_whylabs

    @enable_whylabs(dataset_id="my_model")
    @step
    def data_loader() -> Output(data=pd.DataFrame, profile=DatasetProfileView,):
        ...
        data = pd.DataFrame(...)
        results = why.log(pandas=data)
        profile = results.profile()
        ...
        return data, profile.view()
    ```

    You can also use this decorator with our class-based API like so:

    ```
    @enable_whylabs(dataset_id="my_model")
    class DataLoader(BaseStep):
        def entrypoint(self) -> Output(data=pd.DataFrame, profile=DatasetProfileView,):
            ...
    ```

    Args:
        _step: The decorated step class.
        dataset_id: Optional dataset ID to use for the uploaded profile(s)

    Returns:
        the inner decorator which enhances the input step class with Whylabs
        profiling functionality
    """

    def _patch_entrypoint(step_class: S) -> S:
        original_entrypoint = getattr(step_class, STEP_INNER_FUNC_NAME)
        wrapped_entrypoint = whylabs_entrypoint(dataset_id)(original_entrypoint)
        # Steps built with the functional API expose their entrypoint as a
        # static method, so the replacement has to be a static method too.
        replacement = (
            staticmethod(wrapped_entrypoint)
            if step_class._created_by_functional_api()
            else wrapped_entrypoint
        )
        setattr(step_class, STEP_INNER_FUNC_NAME, replacement)
        return step_class

    # Bare usage passes the step class directly; parametrized usage does not.
    if _step is not None:
        return _patch_entrypoint(_step)
    return _patch_entrypoint
whylabs_entrypoint(dataset_id=None)
Decorator for a step entrypoint to enable Whylabs logging.
Apply this decorator to a ZenML pipeline step to enable Whylabs profiling.
Note that you also need to have a whylogs Data Validator part of your active stack with the Whylabs credentials configured for this to have effect.
All the whylogs data profile views returned by the step will automatically be uploaded to the Whylabs platform if the active whylogs Data Validator component is configured with Whylabs credentials, e.g.:
.. highlight:: python .. code-block:: python
import whylogs as why
from whylogs.core import DatasetProfileView
from zenml.integrations.whylogs.whylabs_step_decorator import whylabs_entrypoint
@step
@whylabs_entrypoint(dataset_id="my_model")
def data_loader() -> Output(data=pd.DataFrame, profile=DatasetProfileView,):
...
data = pd.DataFrame(...)
results = why.log(pandas=data)
profile = results.profile()
...
return data, profile.view()
Parameters:
Name | Type | Description | Default |
---|---|---|---|
dataset_id |
Optional[str] |
Optional dataset ID to use for the uploaded profile(s) |
None |
Returns:
Type | Description |
---|---|
Callable[[~F], ~F] |
the input function enhanced with Whylabs profiling functionality |
Source code in zenml/integrations/whylogs/whylabs_step_decorator.py
def whylabs_entrypoint(
    dataset_id: Optional[str] = None,
) -> Callable[[F], F]:
    """Decorator for a step entrypoint to enable Whylabs logging.

    While the decorated function runs, the Whylabs logging environment
    variable is set to "true" and, if a dataset ID was given, the Whylabs
    dataset ID environment variable is set to it. When the function exits
    (normally or with an exception) both variables are restored to the
    values they had before the call, or removed if they were not set.

    Note that you also need to have a whylogs Data Validator part of your
    active stack with the Whylabs credentials configured for this to have
    effect.

    All the whylogs data profile views returned by the step will
    automatically be uploaded to the Whylabs platform if the active whylogs
    Data Validator component is configured with Whylabs credentials, e.g.:

    ```python
    import whylogs as why
    from whylogs.core import DatasetProfileView
    from zenml.integrations.whylogs.whylabs_step_decorator import whylabs_entrypoint

    @step
    @whylabs_entrypoint(dataset_id="my_model")
    def data_loader() -> Output(data=pd.DataFrame, profile=DatasetProfileView,):
        ...
        data = pd.DataFrame(...)
        results = why.log(pandas=data)
        profile = results.profile()
        ...
        return data, profile.view()
    ```

    Args:
        dataset_id: Optional dataset ID to use for the uploaded profile(s)

    Returns:
        the input function enhanced with Whylabs profiling functionality
    """

    def inner_decorator(func: F) -> F:
        @functools.wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:  # noqa
            overrides = {WHYLABS_LOGGING_ENABLED_ENV: "true"}
            if dataset_id:
                overrides[WHYLABS_DATASET_ID_ENV] = dataset_id

            # Remember the previous values so that nested or pre-configured
            # settings are restored rather than unconditionally deleted.
            previous = {key: os.environ.get(key) for key in overrides}
            os.environ.update(overrides)
            try:
                return func(*args, **kwargs)
            finally:
                for key, old_value in previous.items():
                    if old_value is None:
                        # pop() instead of del: never raise from cleanup and
                        # mask an exception coming out of the step itself.
                        os.environ.pop(key, None)
                    else:
                        os.environ[key] = old_value

        return cast(F, wrapper)

    return inner_decorator
xgboost
special
Initialization of the XGBoost integration.
XgboostIntegration (Integration)
Definition of xgboost integration for ZenML.
Source code in zenml/integrations/xgboost/__init__.py
class XgboostIntegration(Integration):
    """Definition of xgboost integration for ZenML."""

    # Integration name and the pip requirement specifiers it declares.
    NAME = XGBOOST
    REQUIREMENTS = ["xgboost>=1.0.0"]

    @classmethod
    def activate(cls) -> None:
        """Activates the integration.

        Imports the xgboost materializers sub-module; the imported name is
        not used here, so the import is done purely for its side effects.
        """
        # NOTE(review): presumably importing the module is what registers
        # the xgboost materializers with ZenML - confirm.
        from zenml.integrations.xgboost import materializers  # noqa
activate()
classmethod
Activates the integration.
Source code in zenml/integrations/xgboost/__init__.py
@classmethod
def activate(cls) -> None:
    """Activates the integration.

    Imports the xgboost materializers sub-module; the imported name is not
    used here, so the import is done purely for its side effects.
    """
    # NOTE(review): presumably importing the module is what registers the
    # xgboost materializers with ZenML - confirm.
    from zenml.integrations.xgboost import materializers  # noqa
materializers
special
Initialization of the XGBoost materializers.
xgboost_booster_materializer
Implementation of an XGBoost booster materializer.
XgboostBoosterMaterializer (BaseMaterializer)
Materializer to read data to and from xgboost.Booster.
Source code in zenml/integrations/xgboost/materializers/xgboost_booster_materializer.py
class XgboostBoosterMaterializer(BaseMaterializer):
    """Materializer to read data to and from xgboost.Booster."""

    ASSOCIATED_TYPES = (xgb.Booster,)
    ASSOCIATED_ARTIFACT_TYPES = (ModelArtifact,)

    def handle_input(self, data_type: Type[Any]) -> xgb.Booster:
        """Reads a xgboost Booster model from a serialized JSON file.

        Args:
            data_type: A xgboost Booster type.

        Returns:
            A xgboost Booster object.
        """
        super().handle_input(data_type)
        filepath = os.path.join(self.artifact.uri, DEFAULT_FILENAME)

        # Stage the model file in a local temporary folder before loading.
        temp_dir = tempfile.mkdtemp(prefix="zenml-temp-")
        try:
            temp_file = os.path.join(temp_dir, DEFAULT_FILENAME)
            # Copy from artifact store to temporary file
            fileio.copy(filepath, temp_file)
            booster = xgb.Booster()
            booster.load_model(temp_file)
            return booster
        finally:
            # Always remove the staging folder, even if copy/load failed.
            fileio.rmtree(temp_dir)

    def handle_return(self, booster: xgb.Booster) -> None:
        """Creates a JSON serialization for a xgboost Booster model.

        Args:
            booster: A xgboost Booster model.
        """
        super().handle_return(booster)
        filepath = os.path.join(self.artifact.uri, DEFAULT_FILENAME)

        # Reserve a uniquely named ".json" path and close the handle right
        # away, so the model is not written to a file that is still open.
        with tempfile.NamedTemporaryFile(
            mode="w", suffix=".json", delete=False
        ) as f:
            temp_path = f.name
        try:
            booster.save_model(temp_path)
            # Copy it into artifact store
            fileio.copy(temp_path, filepath)
        finally:
            # Always remove the temporary file, even if save/copy failed.
            fileio.remove(temp_path)
handle_input(self, data_type)
Reads a xgboost Booster model from a serialized JSON file.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
data_type |
Type[Any] |
A xgboost Booster type. |
required |
Returns:
Type | Description |
---|---|
Booster |
A xgboost Booster object. |
Source code in zenml/integrations/xgboost/materializers/xgboost_booster_materializer.py
def handle_input(self, data_type: Type[Any]) -> xgb.Booster:
    """Reads a xgboost Booster model from a serialized JSON file.

    Args:
        data_type: A xgboost Booster type.

    Returns:
        A xgboost Booster object.
    """
    super().handle_input(data_type)
    filepath = os.path.join(self.artifact.uri, DEFAULT_FILENAME)

    # Stage the model file in a local temporary folder before loading.
    temp_dir = tempfile.mkdtemp(prefix="zenml-temp-")
    try:
        temp_file = os.path.join(temp_dir, DEFAULT_FILENAME)
        # Copy from artifact store to temporary file
        fileio.copy(filepath, temp_file)
        booster = xgb.Booster()
        booster.load_model(temp_file)
        return booster
    finally:
        # Always remove the staging folder, even if copy/load failed.
        fileio.rmtree(temp_dir)
handle_return(self, booster)
Creates a JSON serialization for a xgboost Booster model.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
booster |
Booster |
A xgboost Booster model. |
required |
Source code in zenml/integrations/xgboost/materializers/xgboost_booster_materializer.py
def handle_return(self, booster: xgb.Booster) -> None:
    """Creates a JSON serialization for a xgboost Booster model.

    Args:
        booster: A xgboost Booster model.
    """
    super().handle_return(booster)
    filepath = os.path.join(self.artifact.uri, DEFAULT_FILENAME)

    # Reserve a uniquely named ".json" path and close the handle right away,
    # so the model is not written to a file that is still open.
    with tempfile.NamedTemporaryFile(
        mode="w", suffix=".json", delete=False
    ) as f:
        temp_path = f.name
    try:
        booster.save_model(temp_path)
        # Copy it into artifact store
        fileio.copy(temp_path, filepath)
    finally:
        # Always remove the temporary file, even if save/copy failed.
        fileio.remove(temp_path)
xgboost_dmatrix_materializer
Implementation of the XGBoost dmatrix materializer.
XgboostDMatrixMaterializer (BaseMaterializer)
Materializer to read data to and from xgboost.DMatrix.
Source code in zenml/integrations/xgboost/materializers/xgboost_dmatrix_materializer.py
class XgboostDMatrixMaterializer(BaseMaterializer):
    """Materializer to read data to and from xgboost.DMatrix."""

    ASSOCIATED_TYPES = (xgb.DMatrix,)
    ASSOCIATED_ARTIFACT_TYPES = (DataArtifact,)

    def handle_input(self, data_type: Type[Any]) -> xgb.DMatrix:
        """Reads a xgboost.DMatrix binary file and loads it.

        Args:
            data_type: The datatype which should be read.

        Returns:
            Materialized xgboost matrix.
        """
        super().handle_input(data_type)
        filepath = os.path.join(self.artifact.uri, DEFAULT_FILENAME)

        # Stage the matrix file in a local temporary folder before loading.
        temp_dir = tempfile.mkdtemp(prefix="zenml-temp-")
        try:
            temp_file = os.path.join(temp_dir, DEFAULT_FILENAME)
            # Copy from artifact store to temporary file
            fileio.copy(filepath, temp_file)
            return xgb.DMatrix(temp_file)
        finally:
            # Always remove the staging folder, even if copy/load failed.
            fileio.rmtree(temp_dir)

    def handle_return(self, matrix: xgb.DMatrix) -> None:
        """Creates a binary serialization for a xgboost.DMatrix object.

        Args:
            matrix: A xgboost.DMatrix object.
        """
        super().handle_return(matrix)
        filepath = os.path.join(self.artifact.uri, DEFAULT_FILENAME)

        # Reserve a uniquely named path and close the handle right away, so
        # the matrix is not written to a file that is still open.
        with tempfile.NamedTemporaryFile(mode="wb", delete=False) as f:
            temp_path = f.name
        try:
            matrix.save_binary(temp_path)
            # Copy it into artifact store
            fileio.copy(temp_path, filepath)
        finally:
            # Always remove the temporary file, even if save/copy failed.
            fileio.remove(temp_path)
handle_input(self, data_type)
Reads a xgboost.DMatrix binary file and loads it.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
data_type |
Type[Any] |
The datatype which should be read. |
required |
Returns:
Type | Description |
---|---|
DMatrix |
Materialized xgboost matrix. |
Source code in zenml/integrations/xgboost/materializers/xgboost_dmatrix_materializer.py
def handle_input(self, data_type: Type[Any]) -> xgb.DMatrix:
    """Reads a xgboost.DMatrix binary file and loads it.

    Args:
        data_type: The datatype which should be read.

    Returns:
        Materialized xgboost matrix.
    """
    super().handle_input(data_type)
    filepath = os.path.join(self.artifact.uri, DEFAULT_FILENAME)

    # Stage the matrix file in a local temporary folder before loading.
    temp_dir = tempfile.mkdtemp(prefix="zenml-temp-")
    try:
        temp_file = os.path.join(temp_dir, DEFAULT_FILENAME)
        # Copy from artifact store to temporary file
        fileio.copy(filepath, temp_file)
        return xgb.DMatrix(temp_file)
    finally:
        # Always remove the staging folder, even if copy/load failed.
        fileio.rmtree(temp_dir)
handle_return(self, matrix)
Creates a binary serialization for a xgboost.DMatrix object.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
matrix |
DMatrix |
A xgboost.DMatrix object. |
required |
Source code in zenml/integrations/xgboost/materializers/xgboost_dmatrix_materializer.py
def handle_return(self, matrix: xgb.DMatrix) -> None:
    """Creates a binary serialization for a xgboost.DMatrix object.

    Args:
        matrix: A xgboost.DMatrix object.
    """
    super().handle_return(matrix)
    filepath = os.path.join(self.artifact.uri, DEFAULT_FILENAME)

    # Reserve a uniquely named path and close the handle right away, so the
    # matrix is not written to a file that is still open.
    with tempfile.NamedTemporaryFile(mode="wb", delete=False) as f:
        temp_path = f.name
    try:
        matrix.save_binary(temp_path)
        # Copy it into artifact store
        fileio.copy(temp_path, filepath)
    finally:
        # Always remove the temporary file, even if save/copy failed.
        fileio.remove(temp_path)