Integrations
zenml.integrations
special
ZenML integrations module.
The ZenML integrations module contains sub-modules for each integration that we support. This includes orchestrators such as Apache Airflow, visualization tools such as the Facets library, and deep learning libraries such as PyTorch.
airflow
special
Airflow integration for ZenML.
The Airflow integration sub-module powers an alternative to the local orchestrator. You can enable it by registering the Airflow orchestrator with the CLI tool and then bootstrapping it with the `zenml orchestrator up` command.
AirflowIntegration (Integration)
Definition of Airflow Integration for ZenML.
Source code in zenml/integrations/airflow/__init__.py
class AirflowIntegration(Integration):
    """Definition of Airflow Integration for ZenML."""

    NAME = AIRFLOW
    REQUIREMENTS = ["apache-airflow==2.2.0"]

    @classmethod
    def flavors(cls) -> List[Type[Flavor]]:
        """Declare the stack component flavors for the Airflow integration.

        Returns:
            List of stack component flavors for this integration.
        """
        from zenml.integrations.airflow.flavors import AirflowOrchestratorFlavor

        return [AirflowOrchestratorFlavor]
flavors()
classmethod
Declare the stack component flavors for the Airflow integration.
Returns:
Type | Description |
---|---|
List[Type[zenml.stack.flavor.Flavor]] | List of stack component flavors for this integration. |
Source code in zenml/integrations/airflow/__init__.py
@classmethod
def flavors(cls) -> List[Type[Flavor]]:
    """Declare the stack component flavors for the Airflow integration.

    Returns:
        List of stack component flavors for this integration.
    """
    from zenml.integrations.airflow.flavors import AirflowOrchestratorFlavor

    return [AirflowOrchestratorFlavor]
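As a quick illustration (a sketch, assuming the integration's requirements are installed), the declared flavor classes can be inspected directly; each `Flavor` exposes a `name` property as documented below:

```python
from zenml.integrations.airflow import AirflowIntegration

# Print the registry name of every flavor this integration declares.
for flavor_class in AirflowIntegration.flavors():
    print(flavor_class().name)
```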
flavors
special
Airflow integration flavors.
airflow_orchestrator_flavor
Airflow orchestrator flavor.
AirflowOrchestratorFlavor (BaseOrchestratorFlavor)
Flavor for the Airflow orchestrator.
Source code in zenml/integrations/airflow/flavors/airflow_orchestrator_flavor.py
class AirflowOrchestratorFlavor(BaseOrchestratorFlavor):
    """Flavor for the Airflow orchestrator."""

    @property
    def name(self) -> str:
        """Name of the flavor.

        Returns:
            The name of the flavor.
        """
        return AIRFLOW_ORCHESTRATOR_FLAVOR

    @property
    def implementation_class(self) -> Type["AirflowOrchestrator"]:
        """Implementation class.

        Returns:
            The implementation class.
        """
        from zenml.integrations.airflow.orchestrators import AirflowOrchestrator

        return AirflowOrchestrator
implementation_class: Type[AirflowOrchestrator]
property
readonly
Implementation class.
Returns:
Type | Description |
---|---|
Type[AirflowOrchestrator] | The implementation class. |
name: str
property
readonly
Name of the flavor.
Returns:
Type | Description |
---|---|
str | The name of the flavor. |
orchestrators
special
The Airflow integration enables the use of Airflow as a pipeline orchestrator.
airflow_orchestrator
Implementation of Airflow orchestrator integration.
AirflowOrchestrator (BaseOrchestrator)
Orchestrator responsible for running pipelines using Airflow.
Source code in zenml/integrations/airflow/orchestrators/airflow_orchestrator.py
class AirflowOrchestrator(BaseOrchestrator):
    """Orchestrator responsible for running pipelines using Airflow."""

    def __init__(self, **values: Any):
        """Sets environment variables to configure airflow.

        Args:
            **values: Values to set in the orchestrator.
        """
        super().__init__(**values)
        self.airflow_home = os.path.join(
            io_utils.get_global_config_directory(),
            AIRFLOW_ROOT_DIR,
            str(self.id),
        )
        self._set_env()

    @staticmethod
    def _translate_schedule(
        schedule: Optional[Schedule] = None,
    ) -> Dict[str, Any]:
        """Convert ZenML schedule into Airflow schedule.

        The Airflow schedule uses slightly different naming and needs some
        default entries for execution without a schedule.

        Args:
            schedule: Containing the interval, start and end date and
                a boolean flag that defines if past runs should be caught up
                on.

        Returns:
            Airflow configuration dict.
        """
        if schedule:
            if schedule.cron_expression:
                start_time = schedule.start_time or (
                    datetime.datetime.now() - datetime.timedelta(1)
                )
                return {
                    "schedule_interval": schedule.cron_expression,
                    "start_date": start_time,
                    "end_date": schedule.end_time,
                    "catchup": schedule.catchup,
                }
            else:
                return {
                    "schedule_interval": schedule.interval_second,
                    "start_date": schedule.start_time,
                    "end_date": schedule.end_time,
                    "catchup": schedule.catchup,
                }

        return {
            "schedule_interval": "@once",
            # set a start time in the past and disable catchup so airflow
            # runs the dag immediately
            "start_date": datetime.datetime.now() - datetime.timedelta(7),
            "catchup": False,
        }

    def prepare_or_run_pipeline(
        self,
        deployment: "PipelineDeployment",
        stack: "Stack",
    ) -> Any:
        """Creates an Airflow DAG as the intermediate representation for the pipeline.

        This DAG will be loaded by airflow in the target environment
        and used for orchestration of the pipeline.

        How it works:
        -------------
        A new airflow_dag is instantiated with the pipeline name and among
        other things the run schedule.

        For each step of the pipeline a callable is created. This callable
        uses the run_step() method to execute the step. The parameters of
        this callable are pre-filled and an airflow step_operator is created
        within the dag. The dependencies to upstream steps are then
        configured.

        Finally, the dag is fully complete and can be returned.

        Args:
            deployment: The pipeline deployment to prepare or run.
            stack: The stack the pipeline will run on.

        Returns:
            The Airflow DAG.
        """
        import airflow
        from airflow.operators import python as airflow_python

        # Instantiate and configure airflow Dag with name and schedule
        airflow_dag = airflow.DAG(
            dag_id=deployment.pipeline.name,
            is_paused_upon_creation=False,
            **self._translate_schedule(deployment.schedule),
        )

        # Dictionary mapping step names to airflow_operators. This will be needed
        # to configure airflow operator dependencies
        step_name_to_airflow_operator = {}

        for step in deployment.steps.values():
            # Create callable that will be used by airflow to execute the step
            # within the orchestrated environment
            def _step_callable(step_instance: "Step", **kwargs):
                if self.requires_resources_in_orchestration_environment(step):
                    logger.warning(
                        "Specifying step resources is not yet supported for "
                        "the Airflow orchestrator, ignoring resource "
                        "configuration for step %s.",
                        step.name,
                    )

                # Extract run name for the kwargs that will be passed to the
                # callable
                run_name = kwargs["ti"].get_dagrun().run_id
                self._prepare_run(deployment=deployment)
                self.run_step(step=step_instance, run_name=run_name)
                self._cleanup_run()

            # Create airflow python operator that contains the step callable
            airflow_operator = airflow_python.PythonOperator(
                dag=airflow_dag,
                task_id=step.config.name,
                provide_context=True,
                python_callable=functools.partial(
                    _step_callable, step_instance=step
                ),
            )

            # Configure the current airflow operator to run after all upstream
            # operators finished executing
            step_name_to_airflow_operator[step.config.name] = airflow_operator
            for upstream_step_name in step.spec.upstream_steps:
                airflow_operator.set_upstream(
                    step_name_to_airflow_operator[upstream_step_name]
                )

        # Return the finished airflow dag
        return airflow_dag

    @property
    def dags_directory(self) -> str:
        """Returns path to the airflow dags directory.

        Returns:
            Path to the airflow dags directory.
        """
        return os.path.join(self.airflow_home, "dags")

    @property
    def pid_file(self) -> str:
        """Returns path to the daemon PID file.

        Returns:
            Path to the daemon PID file.
        """
        return os.path.join(self.airflow_home, "airflow_daemon.pid")

    @property
    def log_file(self) -> str:
        """Returns path to the airflow log file.

        Returns:
            str: Path to the airflow log file.
        """
        return os.path.join(self.airflow_home, "airflow_orchestrator.log")

    @property
    def password_file(self) -> str:
        """Returns path to the webserver password file.

        Returns:
            Path to the webserver password file.
        """
        return os.path.join(self.airflow_home, "standalone_admin_password.txt")

    def _set_env(self) -> None:
        """Sets environment variables to configure airflow."""
        os.environ["AIRFLOW_HOME"] = self.airflow_home
        os.environ["AIRFLOW__CORE__DAGS_FOLDER"] = self.dags_directory
        os.environ["AIRFLOW__CORE__DAG_DISCOVERY_SAFE_MODE"] = "false"
        os.environ["AIRFLOW__CORE__LOAD_EXAMPLES"] = "false"
        # check the DAG folder every 10 seconds for new files
        os.environ["AIRFLOW__SCHEDULER__DAG_DIR_LIST_INTERVAL"] = "10"

    def _copy_to_dag_directory_if_necessary(self, dag_filepath: str) -> None:
        """Copies DAG module to the Airflow DAGs directory if not already present.

        Args:
            dag_filepath: Path to the file in which the DAG is defined.
        """
        dags_directory = io_utils.resolve_relative_path(self.dags_directory)

        if dags_directory == os.path.dirname(dag_filepath):
            logger.debug("File is already in airflow DAGs directory.")
        else:
            logger.debug(
                "Copying dag file '%s' to DAGs directory.", dag_filepath
            )
            destination_path = os.path.join(
                dags_directory, os.path.basename(dag_filepath)
            )
            if fileio.exists(destination_path):
                logger.info(
                    "File '%s' already exists, overwriting with new DAG file",
                    destination_path,
                )
            fileio.copy(dag_filepath, destination_path, overwrite=True)

    def _log_webserver_credentials(self) -> None:
        """Logs URL and credentials to log in to the airflow webserver.

        Raises:
            FileNotFoundError: If the password file does not exist.
        """
        if fileio.exists(self.password_file):
            with open(self.password_file) as file:
                password = file.read().strip()
        else:
            raise FileNotFoundError(
                f"Can't find password file '{self.password_file}'"
            )
        logger.info(
            "To inspect your DAGs, login to http://localhost:8080 "
            "with username: admin password: %s",
            password,
        )

    def prepare_pipeline_deployment(
        self,
        deployment: "PipelineDeployment",
        stack: "Stack",
    ) -> None:
        """Checks Airflow is running and copies DAG file to the DAGs directory.

        Args:
            deployment: The pipeline deployment configuration.
            stack: The stack on which the pipeline will be deployed.

        Raises:
            RuntimeError: If Airflow is not running or no DAG filepath runtime
                option is provided.
        """
        if not self.is_running:
            raise RuntimeError(
                "Airflow orchestrator is currently not running. Run `zenml "
                "stack up` to provision resources for the active stack."
            )

        if Environment.in_notebook():
            raise RuntimeError(
                "Unable to run the Airflow orchestrator from within a "
                "notebook. Airflow requires a python file which contains a "
                "global Airflow DAG object and therefore does not work with "
                "notebooks. Please copy your ZenML pipeline code in a python "
                "file and try again."
            )

        try:
            dag_filepath = deployment.pipeline.extra[DAG_FILEPATH_OPTION_KEY]
        except KeyError:
            raise RuntimeError(
                f"No DAG filepath found in runtime configuration. Make sure "
                f"to add the filepath to your airflow DAG file as a runtime "
                f"option (key: '{DAG_FILEPATH_OPTION_KEY}')."
            )

        self._copy_to_dag_directory_if_necessary(dag_filepath=dag_filepath)

    @property
    def is_running(self) -> bool:
        """Returns whether the airflow daemon is currently running.

        Returns:
            True if the daemon is running, False otherwise.

        Raises:
            RuntimeError: If port 8080 is occupied.
        """
        from airflow.cli.commands.standalone_command import StandaloneCommand
        from airflow.jobs.triggerer_job import TriggererJob

        daemon_running = daemon.check_if_daemon_is_running(self.pid_file)

        command = StandaloneCommand()
        webserver_port_open = command.port_open(8080)

        if not daemon_running:
            if webserver_port_open:
                raise RuntimeError(
                    "The airflow daemon does not seem to be running but "
                    "local port 8080 is occupied. Make sure the port is "
                    "available and try again."
                )

            # exit early so we don't check non-existing airflow databases
            return False

        # we can't use StandaloneCommand().is_ready() here as the
        # Airflow SequentialExecutor apparently does not send a heartbeat
        # while running a task which would result in this returning `False`
        # even if Airflow is running.
        airflow_running = webserver_port_open and command.job_running(
            TriggererJob
        )
        return airflow_running

    @property
    def is_provisioned(self) -> bool:
        """Returns whether the airflow daemon is currently running.

        Returns:
            True if the airflow daemon is running, False otherwise.
        """
        return self.is_running

    def provision(self) -> None:
        """Ensures that Airflow is running."""
        if self.is_running:
            logger.info("Airflow is already running.")
            self._log_webserver_credentials()
            return

        if not fileio.exists(self.dags_directory):
            io_utils.create_dir_recursive_if_not_exists(self.dags_directory)

        from airflow.cli.commands.standalone_command import StandaloneCommand

        try:
            command = StandaloneCommand()
            # Skip pipeline registration inside the airflow server process.
            # When searching for DAGs, airflow imports the runner file in a
            # randomly generated module. If we don't skip pipeline registration,
            # it would fail by trying to register a pipeline with an existing
            # name but different module sources for the steps.
            with set_environment_variable(
                key=ENV_ZENML_SKIP_PIPELINE_REGISTRATION, value="True"
            ):
                # Run the daemon with a working directory inside the current
                # zenml repo so the same repo will be used to run the DAGs
                daemon.run_as_daemon(
                    command.run,
                    pid_file=self.pid_file,
                    log_file=self.log_file,
                    working_directory=get_source_root_path(),
                )
            while not self.is_running:
                # Wait until the daemon started all the relevant airflow
                # processes
                time.sleep(0.1)
            self._log_webserver_credentials()
        except Exception as e:
            logger.error(e)
            logger.error(
                "An error occurred while starting the Airflow daemon. If you "
                "want to start it manually, use the commands described in the "
                "official Airflow quickstart guide for running Airflow locally."
            )
            self.deprovision()

    def deprovision(self) -> None:
        """Stops the airflow daemon if necessary and tears down resources."""
        if self.is_running:
            daemon.stop_daemon(self.pid_file)

        fileio.rmtree(self.airflow_home)
        logger.info("Airflow spun down.")
dags_directory: str
property
readonly
Returns path to the airflow dags directory.
Returns:
Type | Description |
---|---|
str | Path to the airflow dags directory. |
is_provisioned: bool
property
readonly
Returns whether the airflow daemon is currently running.
Returns:
Type | Description |
---|---|
bool | True if the airflow daemon is running, False otherwise. |
is_running: bool
property
readonly
Returns whether the airflow daemon is currently running.
Returns:
Type | Description |
---|---|
bool | True if the daemon is running, False otherwise. |
Exceptions:
Type | Description |
---|---|
RuntimeError | If port 8080 is occupied. |
log_file: str
property
readonly
Returns path to the airflow log file.
Returns:
Type | Description |
---|---|
str | Path to the airflow log file. |
password_file: str
property
readonly
Returns path to the webserver password file.
Returns:
Type | Description |
---|---|
str | Path to the webserver password file. |
pid_file: str
property
readonly
Returns path to the daemon PID file.
Returns:
Type | Description |
---|---|
str | Path to the daemon PID file. |
__init__(self, **values)
special
Sets environment variables to configure airflow.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
**values | Any | Values to set in the orchestrator. | {} |
Source code in zenml/integrations/airflow/orchestrators/airflow_orchestrator.py
def __init__(self, **values: Any):
    """Sets environment variables to configure airflow.

    Args:
        **values: Values to set in the orchestrator.
    """
    super().__init__(**values)
    self.airflow_home = os.path.join(
        io_utils.get_global_config_directory(),
        AIRFLOW_ROOT_DIR,
        str(self.id),
    )
    self._set_env()
deprovision(self)
Stops the airflow daemon if necessary and tears down resources.
Source code in zenml/integrations/airflow/orchestrators/airflow_orchestrator.py
def deprovision(self) -> None:
    """Stops the airflow daemon if necessary and tears down resources."""
    if self.is_running:
        daemon.stop_daemon(self.pid_file)

    fileio.rmtree(self.airflow_home)
    logger.info("Airflow spun down.")
prepare_or_run_pipeline(self, deployment, stack)
Creates an Airflow DAG as the intermediate representation for the pipeline.
This DAG will be loaded by airflow in the target environment and used for orchestration of the pipeline.
How it works:
A new airflow_dag is instantiated with the pipeline name and, among other things, the run schedule.
For each step of the pipeline a callable is created. This callable uses the run_step() method to execute the step. The parameters of this callable are pre-filled and an airflow step_operator is created within the dag. The dependencies to upstream steps are then configured.
Finally, the dag is fully complete and can be returned.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
deployment | PipelineDeployment | The pipeline deployment to prepare or run. | required |
stack | Stack | The stack the pipeline will run on. | required |
Returns:
Type | Description |
---|---|
Any | The Airflow DAG. |
Source code in zenml/integrations/airflow/orchestrators/airflow_orchestrator.py
def prepare_or_run_pipeline(
    self,
    deployment: "PipelineDeployment",
    stack: "Stack",
) -> Any:
    """Creates an Airflow DAG as the intermediate representation for the pipeline.

    This DAG will be loaded by airflow in the target environment
    and used for orchestration of the pipeline.

    How it works:
    -------------
    A new airflow_dag is instantiated with the pipeline name and among
    other things the run schedule.

    For each step of the pipeline a callable is created. This callable
    uses the run_step() method to execute the step. The parameters of
    this callable are pre-filled and an airflow step_operator is created
    within the dag. The dependencies to upstream steps are then
    configured.

    Finally, the dag is fully complete and can be returned.

    Args:
        deployment: The pipeline deployment to prepare or run.
        stack: The stack the pipeline will run on.

    Returns:
        The Airflow DAG.
    """
    import airflow
    from airflow.operators import python as airflow_python

    # Instantiate and configure airflow Dag with name and schedule
    airflow_dag = airflow.DAG(
        dag_id=deployment.pipeline.name,
        is_paused_upon_creation=False,
        **self._translate_schedule(deployment.schedule),
    )

    # Dictionary mapping step names to airflow_operators. This will be needed
    # to configure airflow operator dependencies
    step_name_to_airflow_operator = {}

    for step in deployment.steps.values():
        # Create callable that will be used by airflow to execute the step
        # within the orchestrated environment
        def _step_callable(step_instance: "Step", **kwargs):
            if self.requires_resources_in_orchestration_environment(step):
                logger.warning(
                    "Specifying step resources is not yet supported for "
                    "the Airflow orchestrator, ignoring resource "
                    "configuration for step %s.",
                    step.name,
                )

            # Extract run name for the kwargs that will be passed to the
            # callable
            run_name = kwargs["ti"].get_dagrun().run_id
            self._prepare_run(deployment=deployment)
            self.run_step(step=step_instance, run_name=run_name)
            self._cleanup_run()

        # Create airflow python operator that contains the step callable
        airflow_operator = airflow_python.PythonOperator(
            dag=airflow_dag,
            task_id=step.config.name,
            provide_context=True,
            python_callable=functools.partial(
                _step_callable, step_instance=step
            ),
        )

        # Configure the current airflow operator to run after all upstream
        # operators finished executing
        step_name_to_airflow_operator[step.config.name] = airflow_operator
        for upstream_step_name in step.spec.upstream_steps:
            airflow_operator.set_upstream(
                step_name_to_airflow_operator[upstream_step_name]
            )

    # Return the finished airflow dag
    return airflow_dag
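The schedule handling above is easiest to see on a concrete value. Below is a minimal sketch of what `_translate_schedule` produces for a cron-based schedule; the `Schedule` import path is an assumption based on ZenML's public API of this version:

```python
import datetime

from zenml.pipelines import Schedule  # assumed import path

from zenml.integrations.airflow.orchestrators import AirflowOrchestrator

schedule = Schedule(
    cron_expression="*/5 * * * *",
    start_time=datetime.datetime(2022, 1, 1),
    catchup=False,
)

# `_translate_schedule` is a staticmethod, so no orchestrator instance is needed.
print(AirflowOrchestrator._translate_schedule(schedule))
# {'schedule_interval': '*/5 * * * *',
#  'start_date': datetime.datetime(2022, 1, 1, 0, 0),
#  'end_date': None,
#  'catchup': False}
```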
prepare_pipeline_deployment(self, deployment, stack)
Checks Airflow is running and copies DAG file to the DAGs directory.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
deployment | PipelineDeployment | The pipeline deployment configuration. | required |
stack | Stack | The stack on which the pipeline will be deployed. | required |
Exceptions:
Type | Description |
---|---|
RuntimeError | If Airflow is not running or no DAG filepath runtime option is provided. |
Source code in zenml/integrations/airflow/orchestrators/airflow_orchestrator.py
def prepare_pipeline_deployment(
    self,
    deployment: "PipelineDeployment",
    stack: "Stack",
) -> None:
    """Checks Airflow is running and copies DAG file to the DAGs directory.

    Args:
        deployment: The pipeline deployment configuration.
        stack: The stack on which the pipeline will be deployed.

    Raises:
        RuntimeError: If Airflow is not running or no DAG filepath runtime
            option is provided.
    """
    if not self.is_running:
        raise RuntimeError(
            "Airflow orchestrator is currently not running. Run `zenml "
            "stack up` to provision resources for the active stack."
        )

    if Environment.in_notebook():
        raise RuntimeError(
            "Unable to run the Airflow orchestrator from within a "
            "notebook. Airflow requires a python file which contains a "
            "global Airflow DAG object and therefore does not work with "
            "notebooks. Please copy your ZenML pipeline code in a python "
            "file and try again."
        )

    try:
        dag_filepath = deployment.pipeline.extra[DAG_FILEPATH_OPTION_KEY]
    except KeyError:
        raise RuntimeError(
            f"No DAG filepath found in runtime configuration. Make sure "
            f"to add the filepath to your airflow DAG file as a runtime "
            f"option (key: '{DAG_FILEPATH_OPTION_KEY}')."
        )

    self._copy_to_dag_directory_if_necessary(dag_filepath=dag_filepath)
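A hedged sketch of satisfying this requirement from pipeline code follows. Both the availability of an `extra` argument on the `pipeline` decorator in this ZenML version and the constant's import location are assumptions; only the constant itself appears in the source above:

```python
from zenml.integrations.airflow.orchestrators.airflow_orchestrator import (
    DAG_FILEPATH_OPTION_KEY,  # the same key the orchestrator reads
)
from zenml.pipelines import pipeline


# Point the orchestrator at the file containing this pipeline definition so
# that it can be copied into the Airflow DAGs directory.
@pipeline(extra={DAG_FILEPATH_OPTION_KEY: __file__})
def my_pipeline(step_1, step_2):
    step_2(step_1())
```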
provision(self)
Ensures that Airflow is running.
Source code in zenml/integrations/airflow/orchestrators/airflow_orchestrator.py
def provision(self) -> None:
    """Ensures that Airflow is running."""
    if self.is_running:
        logger.info("Airflow is already running.")
        self._log_webserver_credentials()
        return

    if not fileio.exists(self.dags_directory):
        io_utils.create_dir_recursive_if_not_exists(self.dags_directory)

    from airflow.cli.commands.standalone_command import StandaloneCommand

    try:
        command = StandaloneCommand()
        # Skip pipeline registration inside the airflow server process.
        # When searching for DAGs, airflow imports the runner file in a
        # randomly generated module. If we don't skip pipeline registration,
        # it would fail by trying to register a pipeline with an existing
        # name but different module sources for the steps.
        with set_environment_variable(
            key=ENV_ZENML_SKIP_PIPELINE_REGISTRATION, value="True"
        ):
            # Run the daemon with a working directory inside the current
            # zenml repo so the same repo will be used to run the DAGs
            daemon.run_as_daemon(
                command.run,
                pid_file=self.pid_file,
                log_file=self.log_file,
                working_directory=get_source_root_path(),
            )
        while not self.is_running:
            # Wait until the daemon started all the relevant airflow
            # processes
            time.sleep(0.1)
        self._log_webserver_credentials()
    except Exception as e:
        logger.error(e)
        logger.error(
            "An error occurred while starting the Airflow daemon. If you "
            "want to start it manually, use the commands described in the "
            "official Airflow quickstart guide for running Airflow locally."
        )
        self.deprovision()
set_environment_variable(key, value)
Temporarily sets an environment variable.
The value will only be set while this context manager is active and will be reset to the previous value afterward.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
key | str | The environment variable key. | required |
value | str | The environment variable value. | required |
Yields:
Type | Description |
---|---|
Iterator[NoneType] | None. |
Source code in zenml/integrations/airflow/orchestrators/airflow_orchestrator.py
@contextmanager
def set_environment_variable(key: str, value: str) -> Iterator[None]:
    """Temporarily sets an environment variable.

    The value will only be set while this context manager is active and will
    be reset to the previous value afterward.

    Args:
        key: The environment variable key.
        value: The environment variable value.

    Yields:
        None.
    """
    old_value = os.environ.get(key, None)
    try:
        os.environ[key] = value
        yield
    finally:
        if old_value:
            os.environ[key] = old_value
        else:
            del os.environ[key]
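For example, the variable is restored (or removed) on exit, even if the body raises:

```python
import os

from zenml.integrations.airflow.orchestrators.airflow_orchestrator import (
    set_environment_variable,
)

assert "MY_TEMP_VAR" not in os.environ
with set_environment_variable("MY_TEMP_VAR", "42"):
    assert os.environ["MY_TEMP_VAR"] == "42"
# The variable did not exist before, so it is deleted again afterwards.
assert "MY_TEMP_VAR" not in os.environ
```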
aws
special
Integrates multiple AWS Tools as Stack Components.
The AWS integration provides a way for our users to manage their secrets through AWS, as well as a way to use the AWS container registry. Additionally, the Sagemaker integration submodule provides a way to run ZenML steps in Sagemaker.
AWSIntegration (Integration)
Definition of AWS integration for ZenML.
Source code in zenml/integrations/aws/__init__.py
class AWSIntegration(Integration):
    """Definition of AWS integration for ZenML."""

    NAME = AWS
    REQUIREMENTS = ["boto3==1.21.0", "sagemaker==2.82.2"]

    @classmethod
    def flavors(cls) -> List[Type[Flavor]]:
        """Declare the stack component flavors for the AWS integration.

        Returns:
            List of stack component flavors for this integration.
        """
        from zenml.integrations.aws.flavors import (
            AWSContainerRegistryFlavor,
            AWSSecretsManagerFlavor,
            SagemakerStepOperatorFlavor,
        )

        return [
            AWSSecretsManagerFlavor,
            AWSContainerRegistryFlavor,
            SagemakerStepOperatorFlavor,
        ]
flavors()
classmethod
Declare the stack component flavors for the AWS integration.
Returns:
Type | Description |
---|---|
List[Type[zenml.stack.flavor.Flavor]] | List of stack component flavors for this integration. |
Source code in zenml/integrations/aws/__init__.py
@classmethod
def flavors(cls) -> List[Type[Flavor]]:
    """Declare the stack component flavors for the AWS integration.

    Returns:
        List of stack component flavors for this integration.
    """
    from zenml.integrations.aws.flavors import (
        AWSContainerRegistryFlavor,
        AWSSecretsManagerFlavor,
        SagemakerStepOperatorFlavor,
    )

    return [
        AWSSecretsManagerFlavor,
        AWSContainerRegistryFlavor,
        SagemakerStepOperatorFlavor,
    ]
container_registries
special
Initialization of AWS Container Registry integration.
aws_container_registry
Implementation of the AWS container registry integration.
AWSContainerRegistry (BaseContainerRegistry)
Class for AWS Container Registry.
Source code in zenml/integrations/aws/container_registries/aws_container_registry.py
class AWSContainerRegistry(BaseContainerRegistry):
    """Class for AWS Container Registry."""

    @property
    def config(self) -> AWSContainerRegistryConfig:
        """Returns the `AWSContainerRegistryConfig` config.

        Returns:
            The configuration.
        """
        return cast(AWSContainerRegistryConfig, self._config)

    def _get_region(self) -> str:
        """Parses the AWS region from the registry URI.

        Raises:
            RuntimeError: If the region parsing fails due to an invalid URI.

        Returns:
            The region string.
        """
        match = re.fullmatch(
            r".*\.dkr\.ecr\.(.*)\.amazonaws\.com", self.config.uri
        )
        if not match:
            raise RuntimeError(
                f"Unable to parse region from ECR URI {self.config.uri}."
            )

        return match.group(1)

    def prepare_image_push(self, image_name: str) -> None:
        """Logs warning message if trying to push an image for which no repository exists.

        Args:
            image_name: Name of the docker image that will be pushed.

        Raises:
            ValueError: If the docker image name is invalid.
        """
        response = boto3.client(
            "ecr", region_name=self._get_region()
        ).describe_repositories()
        try:
            repo_uris: List[str] = [
                repository["repositoryUri"]
                for repository in response["repositories"]
            ]
        except (KeyError, ClientError) as e:
            # invalid boto response, let's hope for the best and just push
            logger.debug("Error while trying to fetch ECR repositories: %s", e)
            return

        repo_exists = any(image_name.startswith(f"{uri}:") for uri in repo_uris)
        if not repo_exists:
            match = re.search(f"{self.config.uri}/(.*):.*", image_name)
            if not match:
                raise ValueError(f"Invalid docker image name '{image_name}'.")
            repo_name = match.group(1)
            logger.warning(
                "Amazon ECR requires you to create a repository before you can "
                f"push an image to it. ZenML is trying to push the image "
                f"{image_name} but could only detect the following "
                f"repositories: {repo_uris}. We will try to push anyway, but "
                f"in case it fails you need to create a repository named "
                f"`{repo_name}`."
            )

    @property
    def post_registration_message(self) -> Optional[str]:
        """Optional message printed after the stack component is registered.

        Returns:
            Info message regarding docker repositories in AWS.
        """
        return (
            "Amazon ECR requires you to create a repository before you can "
            "push an image to it. If you want to for example run a pipeline "
            "using our Kubeflow orchestrator, ZenML will automatically build a "
            f"docker image called `{self.config.uri}/zenml-kubeflow:<PIPELINE_NAME>` "
            f"and try to push it. This will fail unless you create the "
            f"repository `zenml-kubeflow` inside your amazon registry."
        )
config: AWSContainerRegistryConfig
property
readonly
Returns the `AWSContainerRegistryConfig` config.
Returns:
Type | Description |
---|---|
AWSContainerRegistryConfig | The configuration. |
post_registration_message: Optional[str]
property
readonly
Optional message printed after the stack component is registered.
Returns:
Type | Description |
---|---|
Optional[str] | Info message regarding docker repositories in AWS. |
prepare_image_push(self, image_name)
Logs warning message if trying to push an image for which no repository exists.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
image_name | str | Name of the docker image that will be pushed. | required |
Exceptions:
Type | Description |
---|---|
ValueError | If the docker image name is invalid. |
Source code in zenml/integrations/aws/container_registries/aws_container_registry.py
def prepare_image_push(self, image_name: str) -> None:
    """Logs warning message if trying to push an image for which no repository exists.

    Args:
        image_name: Name of the docker image that will be pushed.

    Raises:
        ValueError: If the docker image name is invalid.
    """
    response = boto3.client(
        "ecr", region_name=self._get_region()
    ).describe_repositories()
    try:
        repo_uris: List[str] = [
            repository["repositoryUri"]
            for repository in response["repositories"]
        ]
    except (KeyError, ClientError) as e:
        # invalid boto response, let's hope for the best and just push
        logger.debug("Error while trying to fetch ECR repositories: %s", e)
        return

    repo_exists = any(image_name.startswith(f"{uri}:") for uri in repo_uris)
    if not repo_exists:
        match = re.search(f"{self.config.uri}/(.*):.*", image_name)
        if not match:
            raise ValueError(f"Invalid docker image name '{image_name}'.")
        repo_name = match.group(1)
        logger.warning(
            "Amazon ECR requires you to create a repository before you can "
            f"push an image to it. ZenML is trying to push the image "
            f"{image_name} but could only detect the following "
            f"repositories: {repo_uris}. We will try to push anyway, but "
            f"in case it fails you need to create a repository named "
            f"`{repo_name}`."
        )
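The region detection in `_get_region` is a plain regular expression over the registry URI. A standalone sketch of the same pattern, using the example URI from the config validator below:

```python
import re

uri = "715803424592.dkr.ecr.us-east-1.amazonaws.com"

# The region is the segment between ".dkr.ecr." and ".amazonaws.com".
match = re.fullmatch(r".*\.dkr\.ecr\.(.*)\.amazonaws\.com", uri)
assert match is not None
print(match.group(1))  # us-east-1
```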
flavors
special
AWS integration flavors.
aws_container_registry_flavor
AWS container registry flavor.
AWSContainerRegistryConfig (BaseContainerRegistryConfig)
pydantic-model
Configuration for AWS Container Registry.
Source code in zenml/integrations/aws/flavors/aws_container_registry_flavor.py
class AWSContainerRegistryConfig(BaseContainerRegistryConfig):
    """Configuration for AWS Container Registry."""

    @validator("uri")
    def validate_aws_uri(cls, uri: str) -> str:
        """Validates that the URI is in the correct format.

        Args:
            uri: URI to validate.

        Returns:
            URI in the correct format.

        Raises:
            ValueError: If the URI contains a slash character.
        """
        if "/" in uri:
            raise ValueError(
                "Property `uri` can not contain a `/`. An example of a valid "
                "URI is: `715803424592.dkr.ecr.us-east-1.amazonaws.com`"
            )

        return uri
validate_aws_uri(uri)
classmethod
Validates that the URI is in the correct format.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
uri | str | URI to validate. | required |
Returns:
Type | Description |
---|---|
str | URI in the correct format. |
Exceptions:
Type | Description |
---|---|
ValueError | If the URI contains a slash character. |
Source code in zenml/integrations/aws/flavors/aws_container_registry_flavor.py
@validator("uri")
def validate_aws_uri(cls, uri: str) -> str:
    """Validates that the URI is in the correct format.

    Args:
        uri: URI to validate.

    Returns:
        URI in the correct format.

    Raises:
        ValueError: If the URI contains a slash character.
    """
    if "/" in uri:
        raise ValueError(
            "Property `uri` can not contain a `/`. An example of a valid "
            "URI is: `715803424592.dkr.ecr.us-east-1.amazonaws.com`"
        )

    return uri
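Because this is a pydantic validator, an invalid URI is rejected at construction time. A minimal sketch, assuming `uri` is the only required field inherited from `BaseContainerRegistryConfig`:

```python
from pydantic import ValidationError

from zenml.integrations.aws.flavors.aws_container_registry_flavor import (
    AWSContainerRegistryConfig,
)

# Valid: a bare registry URI without a repository path.
AWSContainerRegistryConfig(uri="715803424592.dkr.ecr.us-east-1.amazonaws.com")

try:
    # Invalid: the URI must not contain a `/`.
    AWSContainerRegistryConfig(
        uri="715803424592.dkr.ecr.us-east-1.amazonaws.com/repo"
    )
except ValidationError as e:
    print(e)
```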
AWSContainerRegistryFlavor (BaseContainerRegistryFlavor)
AWS Container Registry flavor.
Source code in zenml/integrations/aws/flavors/aws_container_registry_flavor.py
class AWSContainerRegistryFlavor(BaseContainerRegistryFlavor):
    """AWS Container Registry flavor."""

    @property
    def name(self) -> str:
        """Name of the flavor.

        Returns:
            The name of the flavor.
        """
        return AWS_CONTAINER_REGISTRY_FLAVOR

    @property
    def config_class(self) -> Type[AWSContainerRegistryConfig]:
        """Config class for this flavor.

        Returns:
            The config class.
        """
        return AWSContainerRegistryConfig

    @property
    def implementation_class(self) -> Type["AWSContainerRegistry"]:
        """Implementation class.

        Returns:
            The implementation class.
        """
        from zenml.integrations.aws.container_registries import (
            AWSContainerRegistry,
        )

        return AWSContainerRegistry
config_class: Type[zenml.integrations.aws.flavors.aws_container_registry_flavor.AWSContainerRegistryConfig]
property
readonly
Config class for this flavor.
Returns:
Type | Description |
---|---|
Type[zenml.integrations.aws.flavors.aws_container_registry_flavor.AWSContainerRegistryConfig] | The config class. |
implementation_class: Type[AWSContainerRegistry]
property
readonly
Implementation class.
Returns:
Type | Description |
---|---|
Type[AWSContainerRegistry] | The implementation class. |
name: str
property
readonly
Name of the flavor.
Returns:
Type | Description |
---|---|
str | The name of the flavor. |
aws_secrets_manager_flavor
AWS secrets manager flavor.
AWSSecretsManagerConfig (BaseSecretsManagerConfig)
pydantic-model
Configuration for the AWS Secrets Manager.
Attributes:
Name | Type | Description |
---|---|---|
region_name | str | The region name of the AWS Secrets Manager. |
Source code in zenml/integrations/aws/flavors/aws_secrets_manager_flavor.py
class AWSSecretsManagerConfig(BaseSecretsManagerConfig):
    """Configuration for the AWS Secrets Manager.

    Attributes:
        region_name: The region name of the AWS Secrets Manager.
    """

    SUPPORTS_SCOPING: ClassVar[bool] = True
    region_name: str

    @classmethod
    def _validate_scope(
        cls,
        scope: "SecretsManagerScope",
        namespace: Optional[str],
    ) -> None:
        """Validate the scope and namespace value.

        Args:
            scope: Scope value.
            namespace: Optional namespace value.
        """
        if namespace:
            validate_aws_secret_name_or_namespace(namespace)
AWSSecretsManagerFlavor (BaseSecretsManagerFlavor)
Class for the `AWSSecretsManagerFlavor`.
Source code in zenml/integrations/aws/flavors/aws_secrets_manager_flavor.py
class AWSSecretsManagerFlavor(BaseSecretsManagerFlavor):
    """Class for the `AWSSecretsManagerFlavor`."""

    @property
    def name(self) -> str:
        """Name of the flavor.

        Returns:
            Name of the flavor.
        """
        return AWS_SECRET_MANAGER_FLAVOR

    @property
    def config_class(self) -> Type[AWSSecretsManagerConfig]:
        """Config class for this flavor.

        Returns:
            Config class for this flavor.
        """
        return AWSSecretsManagerConfig

    @property
    def implementation_class(self) -> Type["AWSSecretsManager"]:
        """Implementation class.

        Returns:
            Implementation class.
        """
        from zenml.integrations.aws.secrets_managers import AWSSecretsManager

        return AWSSecretsManager
config_class: Type[zenml.integrations.aws.flavors.aws_secrets_manager_flavor.AWSSecretsManagerConfig]
property
readonly
Config class for this flavor.
Returns:
Type | Description |
---|---|
Type[zenml.integrations.aws.flavors.aws_secrets_manager_flavor.AWSSecretsManagerConfig] | Config class for this flavor. |
implementation_class: Type[AWSSecretsManager]
property
readonly
Implementation class.
Returns:
Type | Description |
---|---|
Type[AWSSecretsManager] | Implementation class. |
name: str
property
readonly
Name of the flavor.
Returns:
Type | Description |
---|---|
str | Name of the flavor. |
validate_aws_secret_name_or_namespace(name)
Validate a secret name or namespace.
AWS secret names must contain only alphanumeric characters and the characters `/_+=.@-`. The `/` character is only used internally to delimit scopes.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
name | str | the secret name or namespace | required |
Exceptions:
Type | Description |
---|---|
ValueError | if the secret name or namespace is invalid |
Source code in zenml/integrations/aws/flavors/aws_secrets_manager_flavor.py
def validate_aws_secret_name_or_namespace(name: str) -> None:
    """Validate a secret name or namespace.

    AWS secret names must contain only alphanumeric characters and the
    characters /_+=.@-. The `/` character is only used internally to delimit
    scopes.

    Args:
        name: the secret name or namespace

    Raises:
        ValueError: if the secret name or namespace is invalid
    """
    if not re.fullmatch(r"[a-zA-Z0-9_+=\.@\-]*", name):
        raise ValueError(
            f"Invalid secret name or namespace '{name}'. Must contain "
            f"only alphanumeric characters and the characters _+=.@-."
        )
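The validation boils down to the single regular expression above; for instance:

```python
from zenml.integrations.aws.flavors.aws_secrets_manager_flavor import (
    validate_aws_secret_name_or_namespace,
)

validate_aws_secret_name_or_namespace("my_secret.v2")  # passes silently

try:
    validate_aws_secret_name_or_namespace("my/secret")  # `/` is reserved
except ValueError as e:
    print(e)
```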
sagemaker_step_operator_flavor
Amazon SageMaker step operator flavor.
SagemakerStepOperatorConfig (BaseStepOperatorConfig)
pydantic-model
Config for the Sagemaker step operator.
Attributes:
Name | Type | Description |
---|---|---|
role | str | The role that has to be assigned to the jobs which are running in Sagemaker. |
instance_type | str | The type of the compute instance where jobs will run. |
bucket | Optional[str] | Name of the S3 bucket to use for storing artifacts from the job run. If not provided, a default bucket will be created based on the following format: "sagemaker-{region}-{aws-account-id}". |
experiment_name | Optional[str] | The name for the experiment to which the job will be associated. If not provided, the job runs would be independent. |
Source code in zenml/integrations/aws/flavors/sagemaker_step_operator_flavor.py
class SagemakerStepOperatorConfig(BaseStepOperatorConfig):
    """Config for the Sagemaker step operator.

    Attributes:
        role: The role that has to be assigned to the jobs which are
            running in Sagemaker.
        instance_type: The type of the compute instance where jobs will run.
        bucket: Name of the S3 bucket to use for storing artifacts
            from the job run. If not provided, a default bucket will be created
            based on the following format: "sagemaker-{region}-{aws-account-id}".
        experiment_name: The name for the experiment to which the job
            will be associated. If not provided, the job runs would be
            independent.
    """

    role: str
    instance_type: str
    bucket: Optional[str] = None
    experiment_name: Optional[str] = None

    @property
    def is_remote(self) -> bool:
        """Checks if this stack component is running remotely.

        This designation is used to determine if the stack component can be
        used with a local ZenML database or if it requires a remote ZenML
        server.

        Returns:
            True if this config is for a remote component, False otherwise.
        """
        return True
is_remote: bool
property
readonly
Checks if this stack component is running remotely.
This designation is used to determine if the stack component can be used with a local ZenML database or if it requires a remote ZenML server.
Returns:
Type | Description |
---|---|
bool | True if this config is for a remote component, False otherwise. |
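A short sketch of a config instance with the two required fields; the role ARN and instance type below are placeholder values:

```python
from zenml.integrations.aws.flavors.sagemaker_step_operator_flavor import (
    SagemakerStepOperatorConfig,
)

config = SagemakerStepOperatorConfig(
    role="arn:aws:iam::123456789012:role/sagemaker-role",  # placeholder
    instance_type="ml.m5.xlarge",
    # `bucket` and `experiment_name` are optional; without a bucket,
    # SageMaker falls back to "sagemaker-{region}-{aws-account-id}".
)
assert config.is_remote
```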
SagemakerStepOperatorFlavor (BaseStepOperatorFlavor)
Flavor for the Sagemaker step operator.
Source code in zenml/integrations/aws/flavors/sagemaker_step_operator_flavor.py
class SagemakerStepOperatorFlavor(BaseStepOperatorFlavor):
    """Flavor for the Sagemaker step operator."""

    @property
    def name(self) -> str:
        """Name of the flavor.

        Returns:
            The name of the flavor.
        """
        return AWS_SAGEMAKER_STEP_OPERATOR_FLAVOR

    @property
    def config_class(self) -> Type[SagemakerStepOperatorConfig]:
        """Returns SagemakerStepOperatorConfig config class.

        Returns:
            The config class.
        """
        return SagemakerStepOperatorConfig

    @property
    def implementation_class(self) -> Type["SagemakerStepOperator"]:
        """Implementation class.

        Returns:
            The implementation class.
        """
        from zenml.integrations.aws.step_operators import SagemakerStepOperator

        return SagemakerStepOperator
config_class: Type[zenml.integrations.aws.flavors.sagemaker_step_operator_flavor.SagemakerStepOperatorConfig]
property
readonly
Returns SagemakerStepOperatorConfig config class.
Returns:
Type | Description |
---|---|
Type[zenml.integrations.aws.flavors.sagemaker_step_operator_flavor.SagemakerStepOperatorConfig] | The config class. |
implementation_class: Type[SagemakerStepOperator]
property
readonly
Implementation class.
Returns:
Type | Description |
---|---|
Type[SagemakerStepOperator] | The implementation class. |
name: str
property
readonly
Name of the flavor.
Returns:
Type | Description |
---|---|
str | The name of the flavor. |
secrets_managers
special
AWS Secrets Manager.
aws_secrets_manager
Implementation of the AWS Secrets Manager integration.
AWSSecretsManager (BaseSecretsManager)
Class to interact with the AWS secrets manager.
Source code in zenml/integrations/aws/secrets_managers/aws_secrets_manager.py
class AWSSecretsManager(BaseSecretsManager):
    """Class to interact with the AWS secrets manager."""

    CLIENT: ClassVar[Any] = None

    @property
    def config(self) -> AWSSecretsManagerConfig:
        """Returns the `AWSSecretsManagerConfig` config.

        Returns:
            The configuration.
        """
        return cast(AWSSecretsManagerConfig, self._config)

    @classmethod
    def _ensure_client_connected(cls, region_name: str) -> None:
        """Ensure that the client is connected to the AWS secrets manager.

        Args:
            region_name: the AWS region name
        """
        if cls.CLIENT is None:
            # Create a Secrets Manager client
            session = boto3.session.Session()
            cls.CLIENT = session.client(
                service_name="secretsmanager", region_name=region_name
            )

    def _get_secret_tags(
        self, secret: BaseSecretSchema
    ) -> List[Dict[str, str]]:
        """Return a list of AWS secret tag values for a given secret.

        Args:
            secret: the secret object

        Returns:
            A list of AWS secret tag values
        """
        metadata = self._get_secret_metadata(secret)
        return [{"Key": k, "Value": v} for k, v in metadata.items()]

    def _get_secret_scope_filters(
        self,
        secret_name: Optional[str] = None,
    ) -> List[Dict[str, Any]]:
        """Return a list of AWS filters for the entire scope or just a scoped secret.

        These filters can be used when querying the AWS Secrets Manager
        for all secrets or for a single secret available in the configured
        scope. For more information see: https://docs.aws.amazon.com/secretsmanager/latest/userguide/manage_search-secret.html

        Example AWS filters for all secrets in the current (namespace) scope:

        ```python
        [
            {
                "Key": "tag-key",
                "Values": ["zenml_scope"],
            },
            {
                "Key": "tag-value",
                "Values": ["namespace"],
            },
            {
                "Key": "tag-key",
                "Values": ["zenml_namespace"],
            },
            {
                "Key": "tag-value",
                "Values": ["my_namespace"],
            },
        ]
        ```

        Example AWS filters for a particular secret in the current (namespace)
        scope:

        ```python
        [
            {
                "Key": "tag-key",
                "Values": ["zenml_secret_name"],
            },
            {
                "Key": "tag-value",
                "Values": ["my_secret"],
            },
            {
                "Key": "tag-key",
                "Values": ["zenml_scope"],
            },
            {
                "Key": "tag-value",
                "Values": ["namespace"],
            },
            {
                "Key": "tag-key",
                "Values": ["zenml_namespace"],
            },
            {
                "Key": "tag-value",
                "Values": ["my_namespace"],
            },
        ]
        ```

        Args:
            secret_name: Optional secret name to filter for.

        Returns:
            A list of AWS filters uniquely identifying all secrets
            or a named secret within the configured scope.
        """
        metadata = self._get_secret_scope_metadata(secret_name)
        filters: List[Dict[str, Any]] = []
        for k, v in metadata.items():
            filters.append(
                {
                    "Key": "tag-key",
                    "Values": [
                        k,
                    ],
                }
            )
            filters.append(
                {
                    "Key": "tag-value",
                    "Values": [
                        str(v),
                    ],
                }
            )

        return filters

    def _list_secrets(self, secret_name: Optional[str] = None) -> List[str]:
        """List all secrets matching a name.

        This method lists all the secrets in the current scope without loading
        their contents. An optional secret name can be supplied to filter out
        all but a single secret identified by name.

        Args:
            secret_name: Optional secret name to filter for.

        Returns:
            A list of secret names in the current scope and the optional
            secret name.
        """
        self._ensure_client_connected(self.config.region_name)

        filters: List[Dict[str, Any]] = []
        prefix: Optional[str] = None

        if self.config.scope == SecretsManagerScope.NONE:
            # unscoped (legacy) secrets don't have tags. We want to filter out
            # non-legacy secrets
            filters = [
                {
                    "Key": "tag-key",
                    "Values": [
                        "!zenml_scope",
                    ],
                },
            ]
            if secret_name:
                prefix = secret_name
        else:
            filters = self._get_secret_scope_filters()
            if secret_name:
                prefix = self._get_scoped_secret_name(secret_name)
            else:
                # add the name prefix to the filters to account for the fact
                # that AWS does not do exact matching but prefix-matching on the
                # filters
                prefix = self._get_scoped_secret_name_prefix()

        if prefix:
            filters.append(
                {
                    "Key": "name",
                    "Values": [
                        f"{prefix}",
                    ],
                }
            )

        # TODO [ENG-720]: Deal with pagination in the aws secret manager when
        # listing all secrets
        # TODO [ENG-721]: take out this magic maxresults number
        response = self.CLIENT.list_secrets(MaxResults=100, Filters=filters)

        results = []
        for secret in response["SecretList"]:
            name = self._get_unscoped_secret_name(secret["Name"])
            # keep only the names that are in scope and filter by secret name,
            # if one was given
            if name and (not secret_name or secret_name == name):
                results.append(name)

        return results

    def register_secret(self, secret: BaseSecretSchema) -> None:
        """Registers a new secret.

        Args:
            secret: the secret to register

        Raises:
            SecretExistsError: if the secret already exists
        """
        validate_aws_secret_name_or_namespace(secret.name)
        self._ensure_client_connected(self.config.region_name)

        if self._list_secrets(secret.name):
            raise SecretExistsError(
                f"A Secret with the name {secret.name} already exists"
            )

        secret_value = json.dumps(secret_to_dict(secret, encode=False))
        kwargs: Dict[str, Any] = {
            "Name": self._get_scoped_secret_name(secret.name),
            "SecretString": secret_value,
            "Tags": self._get_secret_tags(secret),
        }
        self.CLIENT.create_secret(**kwargs)

        logger.debug("Created AWS secret: %s", kwargs["Name"])

    def get_secret(self, secret_name: str) -> BaseSecretSchema:
        """Gets a secret.

        Args:
            secret_name: the name of the secret to get

        Returns:
            The secret.

        Raises:
            KeyError: if the secret does not exist
        """
        validate_aws_secret_name_or_namespace(secret_name)
        self._ensure_client_connected(self.config.region_name)

        if not self._list_secrets(secret_name):
            raise KeyError(f"Can't find the specified secret '{secret_name}'")

        get_secret_value_response = self.CLIENT.get_secret_value(
            SecretId=self._get_scoped_secret_name(secret_name)
        )
        if "SecretString" not in get_secret_value_response:
            # secrets without a string value (e.g. binary secrets) are not
            # supported
            raise KeyError(
                f"Secret '{secret_name}' does not contain a string value."
            )

        return secret_from_dict(
            json.loads(get_secret_value_response["SecretString"]),
            secret_name=secret_name,
            decode=False,
        )

    def get_all_secret_keys(self) -> List[str]:
        """Get all secret keys.

        Returns:
            A list of all secret keys
        """
        return self._list_secrets()

    def update_secret(self, secret: BaseSecretSchema) -> None:
        """Update an existing secret.

        Args:
            secret: the secret to update

        Raises:
            KeyError: if the secret does not exist
        """
        validate_aws_secret_name_or_namespace(secret.name)
        self._ensure_client_connected(self.config.region_name)

        if not self._list_secrets(secret.name):
            raise KeyError(f"Can't find the specified secret '{secret.name}'")

        secret_value = json.dumps(secret_to_dict(secret))

        kwargs = {
            "SecretId": self._get_scoped_secret_name(secret.name),
            "SecretString": secret_value,
        }

        self.CLIENT.put_secret_value(**kwargs)

    def delete_secret(self, secret_name: str) -> None:
        """Delete an existing secret.

        Args:
            secret_name: the name of the secret to delete

        Raises:
            KeyError: if the secret does not exist
        """
        self._ensure_client_connected(self.config.region_name)

        if not self._list_secrets(secret_name):
            raise KeyError(f"Can't find the specified secret '{secret_name}'")

        self.CLIENT.delete_secret(
            SecretId=self._get_scoped_secret_name(secret_name),
            ForceDeleteWithoutRecovery=True,
        )

    def delete_all_secrets(self) -> None:
        """Delete all existing secrets.

        This method will force delete all your secrets. You will not be able to
        recover them once this method is called.
        """
        self._ensure_client_connected(self.config.region_name)
        for secret_name in self._list_secrets():
            self.CLIENT.delete_secret(
                SecretId=self._get_scoped_secret_name(secret_name),
                ForceDeleteWithoutRecovery=True,
            )
config: AWSSecretsManagerConfig
property
readonly
Returns the `AWSSecretsManagerConfig` config.
Returns:
Type | Description |
---|---|
AWSSecretsManagerConfig | The configuration. |
delete_all_secrets(self)
Delete all existing secrets.
This method will force delete all your secrets. You will not be able to recover them once this method is called.
Source code in zenml/integrations/aws/secrets_managers/aws_secrets_manager.py
def delete_all_secrets(self) -> None:
    """Delete all existing secrets.

    This method will force delete all your secrets. You will not be able to
    recover them once this method is called.
    """
    self._ensure_client_connected(self.config.region_name)
    for secret_name in self._list_secrets():
        self.CLIENT.delete_secret(
            SecretId=self._get_scoped_secret_name(secret_name),
            ForceDeleteWithoutRecovery=True,
        )
delete_secret(self, secret_name)
Delete an existing secret.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
secret_name | str | the name of the secret to delete | required |
Exceptions:
Type | Description |
---|---|
KeyError | if the secret does not exist |
Source code in zenml/integrations/aws/secrets_managers/aws_secrets_manager.py
def delete_secret(self, secret_name: str) -> None:
    """Delete an existing secret.

    Args:
        secret_name: the name of the secret to delete

    Raises:
        KeyError: if the secret does not exist
    """
    self._ensure_client_connected(self.config.region_name)

    if not self._list_secrets(secret_name):
        raise KeyError(f"Can't find the specified secret '{secret_name}'")

    self.CLIENT.delete_secret(
        SecretId=self._get_scoped_secret_name(secret_name),
        ForceDeleteWithoutRecovery=True,
    )
get_all_secret_keys(self)
Get all secret keys.
Returns:
Type | Description |
---|---|
List[str] | A list of all secret keys |
Source code in zenml/integrations/aws/secrets_managers/aws_secrets_manager.py
def get_all_secret_keys(self) -> List[str]:
    """Get all secret keys.

    Returns:
        A list of all secret keys
    """
    return self._list_secrets()
get_secret(self, secret_name)
Gets a secret.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
secret_name | str | the name of the secret to get | required |
Returns:
Type | Description |
---|---|
BaseSecretSchema | The secret. |
Exceptions:
Type | Description |
---|---|
KeyError | if the secret does not exist |
Source code in zenml/integrations/aws/secrets_managers/aws_secrets_manager.py
def get_secret(self, secret_name: str) -> BaseSecretSchema:
    """Gets a secret.

    Args:
        secret_name: the name of the secret to get

    Returns:
        The secret.

    Raises:
        KeyError: if the secret does not exist
    """
    validate_aws_secret_name_or_namespace(secret_name)
    self._ensure_client_connected(self.config.region_name)

    if not self._list_secrets(secret_name):
        raise KeyError(f"Can't find the specified secret '{secret_name}'")

    get_secret_value_response = self.CLIENT.get_secret_value(
        SecretId=self._get_scoped_secret_name(secret_name)
    )
    if "SecretString" not in get_secret_value_response:
        # secrets without a string value (e.g. binary secrets) are not
        # supported
        raise KeyError(
            f"Secret '{secret_name}' does not contain a string value."
        )

    return secret_from_dict(
        json.loads(get_secret_value_response["SecretString"]),
        secret_name=secret_name,
        decode=False,
    )
register_secret(self, secret)
Registers a new secret.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
secret | BaseSecretSchema | the secret to register | required |
Exceptions:
Type | Description |
---|---|
SecretExistsError | if the secret already exists |
Source code in zenml/integrations/aws/secrets_managers/aws_secrets_manager.py
def register_secret(self, secret: BaseSecretSchema) -> None:
    """Registers a new secret.

    Args:
        secret: the secret to register

    Raises:
        SecretExistsError: if the secret already exists
    """
    validate_aws_secret_name_or_namespace(secret.name)
    self._ensure_client_connected(self.config.region_name)

    if self._list_secrets(secret.name):
        raise SecretExistsError(
            f"A Secret with the name {secret.name} already exists"
        )

    secret_value = json.dumps(secret_to_dict(secret, encode=False))
    kwargs: Dict[str, Any] = {
        "Name": self._get_scoped_secret_name(secret.name),
        "SecretString": secret_value,
        "Tags": self._get_secret_tags(secret),
    }
    self.CLIENT.create_secret(**kwargs)

    logger.debug("Created AWS secret: %s", kwargs["Name"])
update_secret(self, secret)
Update an existing secret.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
secret | BaseSecretSchema | the secret to update | required |
Exceptions:
Type | Description |
---|---|
KeyError | if the secret does not exist |
Source code in zenml/integrations/aws/secrets_managers/aws_secrets_manager.py
def update_secret(self, secret: BaseSecretSchema) -> None:
    """Update an existing secret.

    Args:
        secret: the secret to update

    Raises:
        KeyError: if the secret does not exist
    """
    validate_aws_secret_name_or_namespace(secret.name)
    self._ensure_client_connected(self.config.region_name)

    if not self._list_secrets(secret.name):
        raise KeyError(f"Can't find the specified secret '{secret.name}'")

    secret_value = json.dumps(secret_to_dict(secret))

    kwargs = {
        "SecretId": self._get_scoped_secret_name(secret.name),
        "SecretString": secret_value,
    }

    self.CLIENT.put_secret_value(**kwargs)
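The scoping mechanics above reduce to plain dictionaries: scope metadata becomes AWS tags on write and alternating tag filters on read. A self-contained sketch of that mapping, mirroring `_get_secret_tags` and `_get_secret_scope_filters` with the metadata keys from their docstrings:

```python
# Example scope metadata, as in the `_get_secret_scope_filters` docstring.
metadata = {"zenml_scope": "namespace", "zenml_namespace": "my_namespace"}

# On write: metadata -> AWS secret tags.
tags = [{"Key": k, "Value": v} for k, v in metadata.items()]

# On read: metadata -> alternating tag-key/tag-value filters.
filters = []
for k, v in metadata.items():
    filters.append({"Key": "tag-key", "Values": [k]})
    filters.append({"Key": "tag-value", "Values": [str(v)]})

print(tags)
print(filters)
```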
step_operators
special
Initialization of the Sagemaker Step Operator.
sagemaker_step_operator
Implementation of the Sagemaker Step Operator.
SagemakerStepOperator (BaseStepOperator)
Step operator to run a step on Sagemaker.
This class defines code that builds an image with the ZenML entrypoint to run using Sagemaker's Estimator.
Source code in zenml/integrations/aws/step_operators/sagemaker_step_operator.py
class SagemakerStepOperator(BaseStepOperator):
    """Step operator to run a step on Sagemaker.

    This class defines code that builds an image with the ZenML entrypoint
    to run using Sagemaker's Estimator.
    """

    @property
    def config(self) -> SagemakerStepOperatorConfig:
        """Returns the `SagemakerStepOperatorConfig` config.

        Returns:
            The configuration.
        """
        return cast(SagemakerStepOperatorConfig, self._config)

    @property
    def validator(self) -> Optional[StackValidator]:
        """Validates the stack.

        Returns:
            A validator that checks that the stack contains a remote container
            registry and a remote artifact store.
        """

        def _validate_remote_components(stack: "Stack") -> Tuple[bool, str]:
            if stack.artifact_store.config.is_local:
                return False, (
                    "The SageMaker step operator runs code remotely and "
                    "needs to write files into the artifact store, but the "
                    f"artifact store `{stack.artifact_store.name}` of the "
                    "active stack is local. Please ensure that your stack "
                    "contains a remote artifact store when using the SageMaker "
                    "step operator."
                )

            container_registry = stack.container_registry
            assert container_registry is not None

            if container_registry.config.is_local:
                return False, (
                    "The SageMaker step operator runs code remotely and "
                    "needs to push/pull Docker images, but the "
                    f"container registry `{container_registry.name}` of the "
                    "active stack is local. Please ensure that your stack "
                    "contains a remote container registry when using the "
                    "SageMaker step operator."
                )

            return True, ""

        return StackValidator(
            required_components={StackComponentType.CONTAINER_REGISTRY},
            custom_validation_function=_validate_remote_components,
        )

    def prepare_pipeline_deployment(
        self,
        deployment: "PipelineDeployment",
        stack: "Stack",
    ) -> None:
        """Build a Docker image and push it to the container registry.

        Args:
            deployment: The pipeline deployment configuration.
            stack: The stack on which the pipeline will be deployed.
        """
        steps_to_run = [
            step
            for step in deployment.steps.values()
            if step.config.step_operator == self.name
        ]
        if not steps_to_run:
            return

        docker_image_builder = PipelineDockerImageBuilder()
        image_digest = docker_image_builder.build_and_push_docker_image(
            deployment=deployment,
            stack=stack,
            entrypoint=f"${_ENTRYPOINT_ENV_VARIABLE}",
        )
        for step in steps_to_run:
            step.config.extra[SAGEMAKER_DOCKER_IMAGE_KEY] = image_digest

    def launch(
        self,
        info: "StepRunInfo",
        entrypoint_command: List[str],
    ) -> None:
        """Launches a step on SageMaker.

        Args:
            info: Information about the step run.
            entrypoint_command: Command that executes the step.
        """
        if not info.config.resource_settings.empty:
            logger.warning(
                "Specifying custom step resources is not supported for "
                "the SageMaker step operator. If you want to run this step "
                "operator on specific resources, you can do so by configuring "
                "a different instance type like this: "
                "`zenml step-operator update %s "
                "--instance_type=<INSTANCE_TYPE>`",
                self.name,
            )

        image_name = info.config.extra[SAGEMAKER_DOCKER_IMAGE_KEY]
        environment = {_ENTRYPOINT_ENV_VARIABLE: " ".join(entrypoint_command)}

        session = sagemaker.Session(default_bucket=self.config.bucket)
        estimator = sagemaker.estimator.Estimator(
            image_name,
            self.config.role,
            environment=environment,
            instance_count=1,
            instance_type=self.config.instance_type,
            sagemaker_session=session,
        )

        # Sagemaker doesn't allow any underscores in job/experiment/trial names
        sanitized_run_name = info.run_name.replace("_", "-")

        experiment_config = {}
        if self.config.experiment_name:
            experiment_config = {
                "ExperimentName": self.config.experiment_name,
                "TrialName": sanitized_run_name,
            }

        estimator.fit(
            wait=True,
            experiment_config=experiment_config,
            job_name=sanitized_run_name,
        )
config: SagemakerStepOperatorConfig
property
readonly
Returns the `SagemakerStepOperatorConfig` config.
Returns:
Type | Description
---|---
SagemakerStepOperatorConfig | The configuration.
validator: Optional[zenml.stack.stack_validator.StackValidator]
property
readonly
Validates the stack.
Returns:
Type | Description
---|---
Optional[zenml.stack.stack_validator.StackValidator] | A validator that checks that the stack contains a remote container registry and a remote artifact store.
launch(self, info, entrypoint_command)
Launches a step on SageMaker.
Parameters:
Name | Type | Description | Default
---|---|---|---
info | StepRunInfo | Information about the step run. | required
entrypoint_command | List[str] | Command that executes the step. | required
Source code in zenml/integrations/aws/step_operators/sagemaker_step_operator.py
def launch(
self,
info: "StepRunInfo",
entrypoint_command: List[str],
) -> None:
"""Launches a step on SageMaker.
Args:
info: Information about the step run.
entrypoint_command: Command that executes the step.
"""
if not info.config.resource_settings.empty:
logger.warning(
"Specifying custom step resources is not supported for "
"the SageMaker step operator. If you want to run this step "
"operator on specific resources, you can do so by configuring "
"a different instance type like this: "
"`zenml step-operator update %s "
"--instance_type=<INSTANCE_TYPE>`",
self.name,
)
image_name = info.config.extra[SAGEMAKER_DOCKER_IMAGE_KEY]
environment = {_ENTRYPOINT_ENV_VARIABLE: " ".join(entrypoint_command)}
session = sagemaker.Session(default_bucket=self.config.bucket)
estimator = sagemaker.estimator.Estimator(
image_name,
self.config.role,
environment=environment,
instance_count=1,
instance_type=self.config.instance_type,
sagemaker_session=session,
)
# Sagemaker doesn't allow any underscores in job/experiment/trial names
sanitized_run_name = info.run_name.replace("_", "-")
experiment_config = {}
if self.config.experiment_name:
experiment_config = {
"ExperimentName": self.config.experiment_name,
"TrialName": sanitized_run_name,
}
estimator.fit(
wait=True,
experiment_config=experiment_config,
job_name=sanitized_run_name,
)
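The run-name sanitization in `launch` matters because SageMaker rejects underscores in job, experiment, and trial names. A minimal sketch of that step in isolation, with a hypothetical run name:

```python
# SageMaker forbids underscores in job/experiment/trial names, so the
# pipeline run name is rewritten before `estimator.fit(...)` is called.
run_name = "my_pipeline_run_2022"  # hypothetical ZenML run name
sanitized_run_name = run_name.replace("_", "-")
assert sanitized_run_name == "my-pipeline-run-2022"
```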
prepare_pipeline_deployment(self, deployment, stack)
Build a Docker image and push it to the container registry.
Parameters:
Name | Type | Description | Default
---|---|---|---
deployment | PipelineDeployment | The pipeline deployment configuration. | required
stack | Stack | The stack on which the pipeline will be deployed. | required
Source code in zenml/integrations/aws/step_operators/sagemaker_step_operator.py
def prepare_pipeline_deployment(
self,
deployment: "PipelineDeployment",
stack: "Stack",
) -> None:
"""Build a Docker image and push it to the container registry.
Args:
deployment: The pipeline deployment configuration.
stack: The stack on which the pipeline will be deployed.
"""
steps_to_run = [
step
for step in deployment.steps.values()
if step.config.step_operator == self.name
]
if not steps_to_run:
return
docker_image_builder = PipelineDockerImageBuilder()
image_digest = docker_image_builder.build_and_push_docker_image(
deployment=deployment,
stack=stack,
entrypoint=f"${_ENTRYPOINT_ENV_VARIABLE}",
)
for step in steps_to_run:
step.config.extra[SAGEMAKER_DOCKER_IMAGE_KEY] = image_digest
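Note how the two methods above cooperate: the image is built once with a shell-expanded entrypoint (`f"${_ENTRYPOINT_ENV_VARIABLE}"`), and `launch` later injects the concrete step command through the Estimator's `environment`. A minimal sketch of that indirection, using a placeholder variable name since the real one lives in `_ENTRYPOINT_ENV_VARIABLE`:

```python
# Hypothetical illustration of the entrypoint indirection: the image's
# entrypoint is literally the string "$ZENML_SAGEMAKER_ENTRYPOINT"
# (a placeholder name here), which the container shell expands at
# runtime to whatever command was placed in the environment.
entrypoint_env_variable = "ZENML_SAGEMAKER_ENTRYPOINT"    # assumed name
entrypoint_command = ["python", "-m", "my_step_module"]   # hypothetical
environment = {entrypoint_env_variable: " ".join(entrypoint_command)}
# At container start, `$ZENML_SAGEMAKER_ENTRYPOINT` expands to
# "python -m my_step_module", which is then executed as the step.
```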
azure
special
Initialization of the ZenML Azure integration.
The Azure integration submodule provides a way to run ZenML pipelines in a cloud
environment. Specifically, it allows the use of cloud artifact stores,
and an `io` module to handle file operations on Azure Blob Storage.
The Azure Step Operator integration submodule provides a way to run ZenML steps
in AzureML.
AzureIntegration (Integration)
Definition of Azure integration for ZenML.
Source code in zenml/integrations/azure/__init__.py
class AzureIntegration(Integration):
"""Definition of Azure integration for ZenML."""
NAME = AZURE
REQUIREMENTS = [
"adlfs==2021.10.0",
"azure-keyvault-keys",
"azure-keyvault-secrets",
"azure-identity",
"azureml-core==1.42.0.post1",
]
@classmethod
def flavors(cls) -> List[Type[Flavor]]:
"""Declares the flavors for the integration.
Returns:
List of stack component flavors for this integration.
"""
from zenml.integrations.azure.flavors import (
AzureArtifactStoreFlavor,
AzureMLStepOperatorFlavor,
AzureSecretsManagerFlavor,
)
return [
AzureArtifactStoreFlavor,
AzureSecretsManagerFlavor,
AzureMLStepOperatorFlavor,
]
flavors()
classmethod
Declares the flavors for the integration.
Returns:
Type | Description
---|---
List[Type[zenml.stack.flavor.Flavor]] | List of stack component flavors for this integration.
Source code in zenml/integrations/azure/__init__.py
@classmethod
def flavors(cls) -> List[Type[Flavor]]:
"""Declares the flavors for the integration.
Returns:
List of stack component flavors for this integration.
"""
from zenml.integrations.azure.flavors import (
AzureArtifactStoreFlavor,
AzureMLStepOperatorFlavor,
AzureSecretsManagerFlavor,
)
return [
AzureArtifactStoreFlavor,
AzureSecretsManagerFlavor,
AzureMLStepOperatorFlavor,
]
artifact_stores
special
Initialization of the Azure Artifact Store integration.
azure_artifact_store
Implementation of the Azure Artifact Store integration.
AzureArtifactStore (BaseArtifactStore, AuthenticationMixin)
Artifact Store for Microsoft Azure based artifacts.
Source code in zenml/integrations/azure/artifact_stores/azure_artifact_store.py
class AzureArtifactStore(BaseArtifactStore, AuthenticationMixin):
"""Artifact Store for Microsoft Azure based artifacts."""
_filesystem: Optional[adlfs.AzureBlobFileSystem] = None
@property
def config(self) -> AzureArtifactStoreConfig:
"""Returns the `AzureArtifactStoreConfig` config.
Returns:
The configuration.
"""
return cast(AzureArtifactStoreConfig, self._config)
@property
def filesystem(self) -> adlfs.AzureBlobFileSystem:
"""The adlfs filesystem to access this artifact store.
Returns:
The adlfs filesystem to access this artifact store.
"""
if not self._filesystem:
secret = self.get_authentication_secret(
expected_schema_type=AzureSecretSchema
)
credentials = secret.content if secret else {}
self._filesystem = adlfs.AzureBlobFileSystem(
**credentials,
anon=False,
use_listings_cache=False,
)
return self._filesystem
def _split_path(self, path: PathType) -> Tuple[str, str]:
"""Splits a path into the filesystem prefix and remainder.
Example:
```python
prefix, remainder = ZenAzure._split_path("az://my_container/test.txt")
print(prefix, remainder) # "az://" "my_container/test.txt"
```
Args:
path: The path to split.
Returns:
A tuple of the filesystem prefix and the remainder.
"""
path = convert_to_str(path)
prefix = ""
for potential_prefix in self.config.SUPPORTED_SCHEMES:
if path.startswith(potential_prefix):
prefix = potential_prefix
path = path[len(potential_prefix) :]
break
return prefix, path
def open(self, path: PathType, mode: str = "r") -> Any:
"""Open a file at the given path.
Args:
path: Path of the file to open.
mode: Mode in which to open the file. Currently, only
'rb' and 'wb' to read and write binary files are supported.
Returns:
A file-like object.
"""
return self.filesystem.open(path=path, mode=mode)
def copyfile(
self, src: PathType, dst: PathType, overwrite: bool = False
) -> None:
"""Copy a file.
Args:
src: The path to copy from.
dst: The path to copy to.
overwrite: If a file already exists at the destination, this
method will overwrite it if overwrite=`True` and
raise a FileExistsError otherwise.
Raises:
FileExistsError: If a file already exists at the destination
and overwrite is not set to `True`.
"""
if not overwrite and self.filesystem.exists(dst):
raise FileExistsError(
f"Unable to copy to destination '{convert_to_str(dst)}', "
f"file already exists. Set `overwrite=True` to copy anyway."
)
# TODO [ENG-151]: Check if it works with overwrite=True or if we need to
# manually remove it first
self.filesystem.copy(path1=src, path2=dst)
def exists(self, path: PathType) -> bool:
"""Check whether a path exists.
Args:
path: The path to check.
Returns:
True if the path exists, False otherwise.
"""
return self.filesystem.exists(path=path) # type: ignore[no-any-return]
def glob(self, pattern: PathType) -> List[PathType]:
"""Return all paths that match the given glob pattern.
The glob pattern may include:
- '*' to match any number of characters
- '?' to match a single character
- '[...]' to match one of the characters inside the brackets
- '**' as the full name of a path component to match to search
in subdirectories of any depth (e.g. '/some_dir/**/some_file')
Args:
pattern: The glob pattern to match, see details above.
Returns:
A list of paths that match the given glob pattern.
"""
prefix, _ = self._split_path(pattern)
return [
f"{prefix}{path}" for path in self.filesystem.glob(path=pattern)
]
def isdir(self, path: PathType) -> bool:
"""Check whether a path is a directory.
Args:
path: The path to check.
Returns:
True if the path is a directory, False otherwise.
"""
return self.filesystem.isdir(path=path) # type: ignore[no-any-return]
def listdir(self, path: PathType) -> List[PathType]:
"""Return a list of files in a directory.
Args:
path: The path to list.
Returns:
A list of files in the given directory.
"""
_, path = self._split_path(path)
def _extract_basename(file_dict: Dict[str, Any]) -> str:
"""Extracts the basename from a dictionary returned by the Azure filesystem.
Args:
file_dict: A dictionary returned by the Azure filesystem.
Returns:
The basename of the file.
"""
file_path = cast(str, file_dict["name"])
base_name = file_path[len(path) :]
return base_name.lstrip("/")
return [
_extract_basename(dict_)
for dict_ in self.filesystem.listdir(path=path)
]
def makedirs(self, path: PathType) -> None:
"""Create a directory at the given path.
If needed also create missing parent directories.
Args:
path: The path to create.
"""
self.filesystem.makedirs(path=path, exist_ok=True)
def mkdir(self, path: PathType) -> None:
"""Create a directory at the given path.
Args:
path: The path to create.
"""
self.filesystem.makedir(path=path, exist_ok=True)
def remove(self, path: PathType) -> None:
"""Remove the file at the given path.
Args:
path: The path to remove.
"""
self.filesystem.rm_file(path=path)
def rename(
self, src: PathType, dst: PathType, overwrite: bool = False
) -> None:
"""Rename source file to destination file.
Args:
src: The path of the file to rename.
dst: The path to rename the source file to.
overwrite: If a file already exists at the destination, this
method will overwrite it if overwrite=`True` and
raise a FileExistsError otherwise.
Raises:
FileExistsError: If a file already exists at the destination
and overwrite is not set to `True`.
"""
if not overwrite and self.filesystem.exists(dst):
raise FileExistsError(
f"Unable to rename file to '{convert_to_str(dst)}', "
f"file already exists. Set `overwrite=True` to rename anyway."
)
# TODO [ENG-152]: Check if it works with overwrite=True or if we need
# to manually remove it first
self.filesystem.rename(path1=src, path2=dst)
def rmtree(self, path: PathType) -> None:
"""Remove the given directory.
Args:
path: The path of the directory to remove.
"""
self.filesystem.delete(path=path, recursive=True)
def stat(self, path: PathType) -> Dict[str, Any]:
"""Return stat info for the given path.
Args:
path: The path to get stat info for.
Returns:
Stat info.
"""
return self.filesystem.stat(path=path) # type: ignore[no-any-return]
def walk(
self,
top: PathType,
topdown: bool = True,
onerror: Optional[Callable[..., None]] = None,
) -> Iterable[Tuple[PathType, List[PathType], List[PathType]]]:
"""Return an iterator that walks the contents of the given directory.
Args:
top: Path of directory to walk.
topdown: Unused argument to conform to interface.
onerror: Unused argument to conform to interface.
Yields:
An Iterable of Tuples, each of which contain the path of the current
directory path, a list of directories inside the current directory
and a list of files inside the current directory.
"""
# TODO [ENG-153]: Additional params
prefix, _ = self._split_path(top)
for (
directory,
subdirectories,
files,
) in self.filesystem.walk(path=top):
yield f"{prefix}{directory}", subdirectories, files
config: AzureArtifactStoreConfig
property
readonly
Returns the `AzureArtifactStoreConfig` config.
Returns:
Type | Description
---|---
AzureArtifactStoreConfig | The configuration.
filesystem: AzureBlobFileSystem
property
readonly
The adlfs filesystem to access this artifact store.
Returns:
Type | Description
---|---
AzureBlobFileSystem | The adlfs filesystem to access this artifact store.
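The property above lazily builds the filesystem from the credentials stored in the artifact store's authentication secret. For intuition, here is a standalone sketch of what that construction amounts to; the credential keys are assumptions, since the actual keys come from `AzureSecretSchema`:

```python
import adlfs

# Hypothetical standalone equivalent of the `filesystem` property:
# credentials from the authentication secret are passed straight
# through to adlfs as keyword arguments.
credentials = {
    "account_name": "mystorageaccount",     # assumed secret keys
    "account_key": "<storage-account-key>",
}
fs = adlfs.AzureBlobFileSystem(
    **credentials,
    anon=False,                 # same flags as in the property above
    use_listings_cache=False,
)
print(fs.ls("my-container"))    # list blobs under a container
```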
copyfile(self, src, dst, overwrite=False)
Copy a file.
Parameters:
Name | Type | Description | Default
---|---|---|---
src | Union[bytes, str] | The path to copy from. | required
dst | Union[bytes, str] | The path to copy to. | required
overwrite | bool | If a file already exists at the destination, this method will overwrite it if overwrite=`True` and raise a FileExistsError otherwise. | False
Exceptions:
Type | Description
---|---
FileExistsError | If a file already exists at the destination and overwrite is not set to `True`.
Source code in zenml/integrations/azure/artifact_stores/azure_artifact_store.py
def copyfile(
self, src: PathType, dst: PathType, overwrite: bool = False
) -> None:
"""Copy a file.
Args:
src: The path to copy from.
dst: The path to copy to.
overwrite: If a file already exists at the destination, this
method will overwrite it if overwrite=`True` and
raise a FileExistsError otherwise.
Raises:
FileExistsError: If a file already exists at the destination
and overwrite is not set to `True`.
"""
if not overwrite and self.filesystem.exists(dst):
raise FileExistsError(
f"Unable to copy to destination '{convert_to_str(dst)}', "
f"file already exists. Set `overwrite=True` to copy anyway."
)
# TODO [ENG-151]: Check if it works with overwrite=True or if we need to
# manually remove it first
self.filesystem.copy(path1=src, path2=dst)
exists(self, path)
Check whether a path exists.
Parameters:
Name | Type | Description | Default
---|---|---|---
path | Union[bytes, str] | The path to check. | required
Returns:
Type | Description
---|---
bool | True if the path exists, False otherwise.
Source code in zenml/integrations/azure/artifact_stores/azure_artifact_store.py
def exists(self, path: PathType) -> bool:
"""Check whether a path exists.
Args:
path: The path to check.
Returns:
True if the path exists, False otherwise.
"""
return self.filesystem.exists(path=path) # type: ignore[no-any-return]
glob(self, pattern)
Return all paths that match the given glob pattern.
The glob pattern may include:
- '*' to match any number of characters
- '?' to match a single character
- '[...]' to match one of the characters inside the brackets
- '**' as the full name of a path component to match to search in subdirectories of any depth (e.g. '/some_dir/**/some_file')
Parameters:
Name | Type | Description | Default
---|---|---|---
pattern | Union[bytes, str] | The glob pattern to match, see details above. | required
Returns:
Type | Description
---|---
List[Union[bytes, str]] | A list of paths that match the given glob pattern.
Source code in zenml/integrations/azure/artifact_stores/azure_artifact_store.py
def glob(self, pattern: PathType) -> List[PathType]:
"""Return all paths that match the given glob pattern.
The glob pattern may include:
- '*' to match any number of characters
- '?' to match a single character
- '[...]' to match one of the characters inside the brackets
- '**' as the full name of a path component to match to search
in subdirectories of any depth (e.g. '/some_dir/**/some_file')
Args:
pattern: The glob pattern to match, see details above.
Returns:
A list of paths that match the given glob pattern.
"""
prefix, _ = self._split_path(pattern)
return [
f"{prefix}{path}" for path in self.filesystem.glob(path=pattern)
]
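A hedged usage sketch, where `store` stands for a configured `AzureArtifactStore` instance and the container name is made up:

```python
# Match all CSV files at any depth under a data/ prefix. The prefix
# handling above guarantees each result keeps its "az://" scheme.
matches = store.glob("az://my-container/data/**/*.csv")
# e.g. ["az://my-container/data/2022/train.csv", ...]
```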
isdir(self, path)
Check whether a path is a directory.
Parameters:
Name | Type | Description | Default
---|---|---|---
path | Union[bytes, str] | The path to check. | required
Returns:
Type | Description
---|---
bool | True if the path is a directory, False otherwise.
Source code in zenml/integrations/azure/artifact_stores/azure_artifact_store.py
def isdir(self, path: PathType) -> bool:
"""Check whether a path is a directory.
Args:
path: The path to check.
Returns:
True if the path is a directory, False otherwise.
"""
return self.filesystem.isdir(path=path) # type: ignore[no-any-return]
listdir(self, path)
Return a list of files in a directory.
Parameters:
Name | Type | Description | Default
---|---|---|---
path | Union[bytes, str] | The path to list. | required
Returns:
Type | Description
---|---
List[Union[bytes, str]] | A list of files in the given directory.
Source code in zenml/integrations/azure/artifact_stores/azure_artifact_store.py
def listdir(self, path: PathType) -> List[PathType]:
"""Return a list of files in a directory.
Args:
path: The path to list.
Returns:
A list of files in the given directory.
"""
_, path = self._split_path(path)
def _extract_basename(file_dict: Dict[str, Any]) -> str:
"""Extracts the basename from a dictionary returned by the Azure filesystem.
Args:
file_dict: A dictionary returned by the Azure filesystem.
Returns:
The basename of the file.
"""
file_path = cast(str, file_dict["name"])
base_name = file_path[len(path) :]
return base_name.lstrip("/")
return [
_extract_basename(dict_)
for dict_ in self.filesystem.listdir(path=path)
]
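The inner `_extract_basename` helper exists because adlfs reports the full path of each entry rather than names relative to the queried directory. A small sketch of that stripping logic; the dictionary shape mirrors what the code above expects from adlfs:

```python
# adlfs reports the full blob path of each entry; listdir only wants
# the name relative to the directory that was queried.
path = "my-container/data"                            # scheme already stripped
file_dict = {"name": "my-container/data/train.csv"}   # assumed adlfs shape
base_name = file_dict["name"][len(path):].lstrip("/")
assert base_name == "train.csv"
```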
makedirs(self, path)
Create a directory at the given path.
If needed also create missing parent directories.
Parameters:
Name | Type | Description | Default
---|---|---|---
path | Union[bytes, str] | The path to create. | required
Source code in zenml/integrations/azure/artifact_stores/azure_artifact_store.py
def makedirs(self, path: PathType) -> None:
"""Create a directory at the given path.
If needed also create missing parent directories.
Args:
path: The path to create.
"""
self.filesystem.makedirs(path=path, exist_ok=True)
mkdir(self, path)
Create a directory at the given path.
Parameters:
Name | Type | Description | Default
---|---|---|---
path | Union[bytes, str] | The path to create. | required
Source code in zenml/integrations/azure/artifact_stores/azure_artifact_store.py
def mkdir(self, path: PathType) -> None:
"""Create a directory at the given path.
Args:
path: The path to create.
"""
self.filesystem.makedir(path=path, exist_ok=True)
open(self, path, mode='r')
Open a file at the given path.
Parameters:
Name | Type | Description | Default
---|---|---|---
path | Union[bytes, str] | Path of the file to open. | required
mode | str | Mode in which to open the file. Currently, only 'rb' and 'wb' to read and write binary files are supported. | 'r'
Returns:
Type | Description
---|---
Any | A file-like object.
Source code in zenml/integrations/azure/artifact_stores/azure_artifact_store.py
def open(self, path: PathType, mode: str = "r") -> Any:
"""Open a file at the given path.
Args:
path: Path of the file to open.
mode: Mode in which to open the file. Currently, only
'rb' and 'wb' to read and write binary files are supported.
Returns:
A file-like object.
"""
return self.filesystem.open(path=path, mode=mode)
remove(self, path)
Remove the file at the given path.
Parameters:
Name | Type | Description | Default
---|---|---|---
path | Union[bytes, str] | The path to remove. | required
Source code in zenml/integrations/azure/artifact_stores/azure_artifact_store.py
def remove(self, path: PathType) -> None:
"""Remove the file at the given path.
Args:
path: The path to remove.
"""
self.filesystem.rm_file(path=path)
rename(self, src, dst, overwrite=False)
Rename source file to destination file.
Parameters:
Name | Type | Description | Default
---|---|---|---
src | Union[bytes, str] | The path of the file to rename. | required
dst | Union[bytes, str] | The path to rename the source file to. | required
overwrite | bool | If a file already exists at the destination, this method will overwrite it if overwrite=`True` and raise a FileExistsError otherwise. | False
Exceptions:
Type | Description
---|---
FileExistsError | If a file already exists at the destination and overwrite is not set to `True`.
Source code in zenml/integrations/azure/artifact_stores/azure_artifact_store.py
def rename(
self, src: PathType, dst: PathType, overwrite: bool = False
) -> None:
"""Rename source file to destination file.
Args:
src: The path of the file to rename.
dst: The path to rename the source file to.
overwrite: If a file already exists at the destination, this
method will overwrite it if overwrite=`True` and
raise a FileExistsError otherwise.
Raises:
FileExistsError: If a file already exists at the destination
and overwrite is not set to `True`.
"""
if not overwrite and self.filesystem.exists(dst):
raise FileExistsError(
f"Unable to rename file to '{convert_to_str(dst)}', "
f"file already exists. Set `overwrite=True` to rename anyway."
)
# TODO [ENG-152]: Check if it works with overwrite=True or if we need
# to manually remove it first
self.filesystem.rename(path1=src, path2=dst)
rmtree(self, path)
Remove the given directory.
Parameters:
Name | Type | Description | Default
---|---|---|---
path | Union[bytes, str] | The path of the directory to remove. | required
Source code in zenml/integrations/azure/artifact_stores/azure_artifact_store.py
def rmtree(self, path: PathType) -> None:
"""Remove the given directory.
Args:
path: The path of the directory to remove.
"""
self.filesystem.delete(path=path, recursive=True)
stat(self, path)
Return stat info for the given path.
Parameters:
Name | Type | Description | Default
---|---|---|---
path | Union[bytes, str] | The path to get stat info for. | required
Returns:
Type | Description
---|---
Dict[str, Any] | Stat info.
Source code in zenml/integrations/azure/artifact_stores/azure_artifact_store.py
def stat(self, path: PathType) -> Dict[str, Any]:
"""Return stat info for the given path.
Args:
path: The path to get stat info for.
Returns:
Stat info.
"""
return self.filesystem.stat(path=path) # type: ignore[no-any-return]
walk(self, top, topdown=True, onerror=None)
Return an iterator that walks the contents of the given directory.
Parameters:
Name | Type | Description | Default
---|---|---|---
top | Union[bytes, str] | Path of directory to walk. | required
topdown | bool | Unused argument to conform to interface. | True
onerror | Optional[Callable[..., NoneType]] | Unused argument to conform to interface. | None
Yields:
Type | Description
---|---
Iterable[Tuple[Union[bytes, str], List[Union[bytes, str]], List[Union[bytes, str]]]] | An Iterable of Tuples, each of which contain the path of the current directory, a list of directories inside the current directory and a list of files inside the current directory.
Source code in zenml/integrations/azure/artifact_stores/azure_artifact_store.py
def walk(
self,
top: PathType,
topdown: bool = True,
onerror: Optional[Callable[..., None]] = None,
) -> Iterable[Tuple[PathType, List[PathType], List[PathType]]]:
"""Return an iterator that walks the contents of the given directory.
Args:
top: Path of directory to walk.
topdown: Unused argument to conform to interface.
onerror: Unused argument to conform to interface.
Yields:
An Iterable of Tuples, each of which contain the path of the current
directory path, a list of directories inside the current directory
and a list of files inside the current directory.
"""
# TODO [ENG-153]: Additional params
prefix, _ = self._split_path(top)
for (
directory,
subdirectories,
files,
) in self.filesystem.walk(path=top):
yield f"{prefix}{directory}", subdirectories, files
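A hedged usage sketch, where `store` again stands for a configured `AzureArtifactStore` and the container path is made up:

```python
# Walk an Azure-backed directory tree; each yielded directory keeps
# its original "az://" prefix thanks to the prefix handling above.
for directory, subdirectories, files in store.walk("az://my-container/models"):
    print(directory, len(subdirectories), len(files))
```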
flavors
special
Azure integration flavors.
azure_artifact_store_flavor
Azure artifact store flavor.
AzureArtifactStoreConfig (BaseArtifactStoreConfig, AuthenticationConfigMixin)
pydantic-model
Configuration class for Azure Artifact Store.
Source code in zenml/integrations/azure/flavors/azure_artifact_store_flavor.py
class AzureArtifactStoreConfig(
BaseArtifactStoreConfig, AuthenticationConfigMixin
):
"""Configuration class for Azure Artifact Store."""
SUPPORTED_SCHEMES: ClassVar[Set[str]] = {"abfs://", "az://"}
AzureArtifactStoreFlavor (BaseArtifactStoreFlavor)
Azure Artifact Store flavor.
Source code in zenml/integrations/azure/flavors/azure_artifact_store_flavor.py
class AzureArtifactStoreFlavor(BaseArtifactStoreFlavor):
"""Azure Artifact Store flavor."""
@property
def name(self) -> str:
"""Name of the flavor.
Returns:
The name of the flavor.
"""
return AZURE_ARTIFACT_STORE_FLAVOR
@property
def config_class(self) -> Type[AzureArtifactStoreConfig]:
"""Returns AzureArtifactStoreConfig config class.
Returns:
The config class.
"""
return AzureArtifactStoreConfig
@property
def implementation_class(self) -> Type["AzureArtifactStore"]:
"""Implementation class.
Returns:
The implementation class.
"""
from zenml.integrations.azure.artifact_stores import AzureArtifactStore
return AzureArtifactStore
config_class: Type[zenml.integrations.azure.flavors.azure_artifact_store_flavor.AzureArtifactStoreConfig]
property
readonly
Returns `AzureArtifactStoreConfig` config class.
Returns:
Type | Description
---|---
Type[zenml.integrations.azure.flavors.azure_artifact_store_flavor.AzureArtifactStoreConfig] | The config class.
implementation_class: Type[AzureArtifactStore]
property
readonly
Implementation class.
Returns:
Type | Description
---|---
Type[AzureArtifactStore] | The implementation class.
name: str
property
readonly
Name of the flavor.
Returns:
Type | Description
---|---
str | The name of the flavor.
azure_secrets_manager_flavor
Azure secrets manager flavor.
AzureSecretsManagerConfig (BaseSecretsManagerConfig)
pydantic-model
Configuration for the Azure Secrets Manager.
Attributes:
Name | Type | Description
---|---|---
key_vault_name | str | Name of an Azure Key Vault that this secrets manager will use to store secrets.
Source code in zenml/integrations/azure/flavors/azure_secrets_manager_flavor.py
class AzureSecretsManagerConfig(BaseSecretsManagerConfig):
"""Configuration for the Azure Secrets Manager.
Attributes:
key_vault_name: Name of an Azure Key Vault that this secrets manager
will use to store secrets.
"""
SUPPORTS_SCOPING: ClassVar[bool] = True
key_vault_name: str
@classmethod
def _validate_scope(
cls,
scope: SecretsManagerScope,
namespace: Optional[str],
) -> None:
"""Validate the scope and namespace value.
Args:
scope: Scope value.
namespace: Optional namespace value.
"""
if namespace:
validate_azure_secret_name_or_namespace(namespace, scope)
AzureSecretsManagerFlavor (BaseSecretsManagerFlavor)
Class for the `AzureSecretsManagerFlavor`.
Source code in zenml/integrations/azure/flavors/azure_secrets_manager_flavor.py
class AzureSecretsManagerFlavor(BaseSecretsManagerFlavor):
"""Class for the `AzureSecretsManagerFlavor`."""
@property
def name(self) -> str:
"""Name of the flavor.
Returns:
The name of the flavor.
"""
return AZURE_SECRETS_MANAGER_FLAVOR
@property
def config_class(self) -> Type[AzureSecretsManagerConfig]:
"""Returns AzureSecretsManagerConfig config class.
Returns:
The config class.
"""
return AzureSecretsManagerConfig
@property
def implementation_class(self) -> Type["AzureSecretsManager"]:
"""Implementation class.
Returns:
The implementation class.
"""
from zenml.integrations.azure.secrets_managers import (
AzureSecretsManager,
)
return AzureSecretsManager
config_class: Type[zenml.integrations.azure.flavors.azure_secrets_manager_flavor.AzureSecretsManagerConfig]
property
readonly
Returns `AzureSecretsManagerConfig` config class.
Returns:
Type | Description
---|---
Type[zenml.integrations.azure.flavors.azure_secrets_manager_flavor.AzureSecretsManagerConfig] | The config class.
implementation_class: Type[AzureSecretsManager]
property
readonly
Implementation class.
Returns:
Type | Description
---|---
Type[AzureSecretsManager] | The implementation class.
name: str
property
readonly
Name of the flavor.
Returns:
Type | Description
---|---
str | The name of the flavor.
validate_azure_secret_name_or_namespace(name, scope)
Validate a secret name or namespace.
Azure secret names must contain only alphanumeric characters and the character `-`.
Given that we also save secret names and namespaces as labels, we are also limited by the 256 maximum size limitation that Azure imposes on label values. An arbitrary length of 100 characters is used here for the maximum size for the secret name and namespace.
Parameters:
Name | Type | Description | Default
---|---|---|---
name | str | the secret name or namespace | required
scope | SecretsManagerScope | the current scope | required
Exceptions:
Type | Description
---|---
ValueError | if the secret name or namespace is invalid
Source code in zenml/integrations/azure/flavors/azure_secrets_manager_flavor.py
def validate_azure_secret_name_or_namespace(
name: str,
scope: SecretsManagerScope,
) -> None:
"""Validate a secret name or namespace.
Azure secret names must contain only alphanumeric characters and the
character `-`.
Given that we also save secret names and namespaces as labels, we are
also limited by the 256 maximum size limitation that Azure imposes on
label values. An arbitrary length of 100 characters is used here for
the maximum size for the secret name and namespace.
Args:
name: the secret name or namespace
scope: the current scope
Raises:
ValueError: if the secret name or namespace is invalid
"""
if scope == SecretsManagerScope.NONE:
# to preserve backwards compatibility, we don't validate the
# secret name for unscoped secrets.
return
if not re.fullmatch(r"[0-9a-zA-Z-]+", name):
raise ValueError(
f"Invalid secret name or namespace '{name}'. Must contain "
f"only alphanumeric characters and the character -."
)
if len(name) > 100:
raise ValueError(
f"Invalid secret name or namespace '{name}'. The length is "
f"limited to maximum 100 characters."
)
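For intuition, a couple of hedged examples of these rules. The validator's module path is shown above; the `SecretsManagerScope` import path and the `GLOBAL` member are assumptions here, chosen so that validation is not skipped as it is for unscoped secrets:

```python
from zenml.integrations.azure.flavors.azure_secrets_manager_flavor import (
    validate_azure_secret_name_or_namespace,
)
from zenml.secrets_managers.base_secrets_manager import (  # assumed path
    SecretsManagerScope,
)

# Passes: only alphanumerics and '-' are used, length is under 100.
validate_azure_secret_name_or_namespace("db-credentials", SecretsManagerScope.GLOBAL)

# Raises ValueError: underscores are not in the allowed character set.
try:
    validate_azure_secret_name_or_namespace("db_credentials", SecretsManagerScope.GLOBAL)
except ValueError as e:
    print(e)

# Raises ValueError: longer than the 100-character cap enforced above.
try:
    validate_azure_secret_name_or_namespace("x" * 101, SecretsManagerScope.GLOBAL)
except ValueError as e:
    print(e)
```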
azureml_step_operator_flavor
AzureML step operator flavor.
AzureMLStepOperatorConfig (BaseStepOperatorConfig)
pydantic-model
Config for the AzureML step operator.
Attributes:
Name | Type | Description
---|---|---
subscription_id | str | The Azure account's subscription ID
resource_group | str | The resource group to which the AzureML workspace is deployed.
workspace_name | str | The name of the AzureML Workspace.
compute_target_name | str | The name of the configured ComputeTarget. An instance of it has to be created on the portal if it doesn't exist already.
environment_name | Optional[str] | The name of the environment if there already exists one.
tenant_id | Optional[str] | The Azure Tenant ID.
service_principal_id | Optional[str] | The ID for the service principal that is created to allow apps to access secure resources.
service_principal_password | Optional[str] | Password for the service principal.
Source code in zenml/integrations/azure/flavors/azureml_step_operator_flavor.py
class AzureMLStepOperatorConfig(BaseStepOperatorConfig):
"""Config for the AzureML step operator.
Attributes:
subscription_id: The Azure account's subscription ID
resource_group: The resource group to which the AzureML workspace
is deployed.
workspace_name: The name of the AzureML Workspace.
compute_target_name: The name of the configured ComputeTarget.
An instance of it has to be created on the portal if it doesn't
exist already.
environment_name: The name of the environment if there
already exists one.
tenant_id: The Azure Tenant ID.
service_principal_id: The ID for the service principal that is created
to allow apps to access secure resources.
service_principal_password: Password for the service principal.
"""
subscription_id: str
resource_group: str
workspace_name: str
compute_target_name: str
# Environment
environment_name: Optional[str] = None
# Service principal authentication
# https://docs.microsoft.com/en-us/azure/machine-learning/how-to-setup-authentication#configure-a-service-principal
tenant_id: Optional[str] = SecretField()
service_principal_id: Optional[str] = SecretField()
service_principal_password: Optional[str] = SecretField()
@property
def is_remote(self) -> bool:
"""Checks if this stack component is running remotely.
This designation is used to determine if the stack component can be
used with a local ZenML database or if it requires a remote ZenML
server.
Returns:
True if this config is for a remote component, False otherwise.
"""
return True
is_remote: bool
property
readonly
Checks if this stack component is running remotely.
This designation is used to determine if the stack component can be used with a local ZenML database or if it requires a remote ZenML server.
Returns:
Type | Description
---|---
bool | True if this config is for a remote component, False otherwise.
AzureMLStepOperatorFlavor (BaseStepOperatorFlavor)
Flavor for the AzureML step operator.
Source code in zenml/integrations/azure/flavors/azureml_step_operator_flavor.py
class AzureMLStepOperatorFlavor(BaseStepOperatorFlavor):
"""Flavor for the AzureML step operator."""
@property
def name(self) -> str:
"""Name of the flavor.
Returns:
The name of the flavor.
"""
return AZUREML_STEP_OPERATOR_FLAVOR
@property
def config_class(self) -> Type[AzureMLStepOperatorConfig]:
"""Returns AzureMLStepOperatorConfig config class.
Returns:
The config class.
"""
return AzureMLStepOperatorConfig
@property
def implementation_class(self) -> Type["AzureMLStepOperator"]:
"""Implementation class.
Returns:
The implementation class.
"""
from zenml.integrations.azure.step_operators import AzureMLStepOperator
return AzureMLStepOperator
config_class: Type[zenml.integrations.azure.flavors.azureml_step_operator_flavor.AzureMLStepOperatorConfig]
property
readonly
Returns `AzureMLStepOperatorConfig` config class.
Returns:
Type | Description
---|---
Type[zenml.integrations.azure.flavors.azureml_step_operator_flavor.AzureMLStepOperatorConfig] | The config class.
implementation_class: Type[AzureMLStepOperator]
property
readonly
Implementation class.
Returns:
Type | Description
---|---
Type[AzureMLStepOperator] | The implementation class.
name: str
property
readonly
Name of the flavor.
Returns:
Type | Description
---|---
str | The name of the flavor.
secrets_managers
special
Initialization of the Azure Secrets Manager integration.
azure_secrets_manager
Implementation of the Azure Secrets Manager integration.
AzureSecretsManager (BaseSecretsManager)
Class to interact with the Azure secrets manager.
Source code in zenml/integrations/azure/secrets_managers/azure_secrets_manager.py
class AzureSecretsManager(BaseSecretsManager):
"""Class to interact with the Azure secrets manager."""
CLIENT: ClassVar[Any] = None
@property
def config(self) -> AzureSecretsManagerConfig:
"""Returns the `AzureSecretsManagerConfig` config.
Returns:
The configuration.
"""
return cast(AzureSecretsManagerConfig, self._config)
@classmethod
def _ensure_client_connected(cls, vault_name: str) -> None:
if cls.CLIENT is None:
KVUri = f"https://{vault_name}.vault.azure.net"
credential = DefaultAzureCredential()
cls.CLIENT = SecretClient(vault_url=KVUri, credential=credential)
def validate_secret_name(self, name: str) -> None:
"""Validate a secret name.
Args:
name: the secret name
"""
validate_azure_secret_name_or_namespace(name, self.config.scope)
def _create_or_update_secret(self, secret: BaseSecretSchema) -> None:
"""Creates a new secret or updated an existing one.
Args:
secret: the secret to register or update
"""
if self.config.scope == SecretsManagerScope.NONE:
# legacy, non-scoped secrets
for key, value in secret.content.items():
encoded_key = base64.b64encode(
f"{secret.name}-{key}".encode()
).hex()
azure_secret_name = f"zenml-{encoded_key}"
self.CLIENT.set_secret(azure_secret_name, value)
self.CLIENT.update_secret_properties(
azure_secret_name,
tags={
ZENML_GROUP_KEY: secret.name,
ZENML_KEY_NAME: key,
ZENML_SCHEMA_NAME: secret.TYPE,
},
)
logger.debug(
"Secret `%s` written to the Azure Key Vault.",
azure_secret_name,
)
else:
azure_secret_name = self._get_scoped_secret_name(
secret.name,
separator=ZENML_AZURE_SECRET_SCOPE_PATH_SEPARATOR,
)
self.CLIENT.set_secret(
azure_secret_name,
json.dumps(secret_to_dict(secret)),
)
self.CLIENT.update_secret_properties(
azure_secret_name,
tags=self._get_secret_metadata(secret),
)
def register_secret(self, secret: BaseSecretSchema) -> None:
"""Registers a new secret.
Args:
secret: the secret to register
Raises:
SecretExistsError: if the secret already exists
"""
self.validate_secret_name(secret.name)
self._ensure_client_connected(self.config.key_vault_name)
if secret.name in self.get_all_secret_keys():
raise SecretExistsError(
f"A Secret with the name '{secret.name}' already exists."
)
self._create_or_update_secret(secret)
def get_secret(self, secret_name: str) -> BaseSecretSchema:
"""Get a secret by its name.
Args:
secret_name: the name of the secret to get
Returns:
The secret.
Raises:
KeyError: if the secret does not exist
ValueError: if the secret is named 'name'
"""
self.validate_secret_name(secret_name)
self._ensure_client_connected(self.config.key_vault_name)
zenml_secret: Optional[BaseSecretSchema] = None
if self.config.scope == SecretsManagerScope.NONE:
# Legacy secrets are mapped to multiple Azure secrets, one for
# each secret key
secret_contents = {}
zenml_schema_name = ""
for secret_property in self.CLIENT.list_properties_of_secrets():
tags = secret_property.tags
if tags and tags.get(ZENML_GROUP_KEY) == secret_name:
secret_key = tags.get(ZENML_KEY_NAME)
if not secret_key:
raise ValueError("Missing secret key tag.")
if secret_key == "name":
raise ValueError("The secret's key cannot be 'name'.")
response = self.CLIENT.get_secret(secret_property.name)
secret_contents[secret_key] = response.value
zenml_schema_name = tags.get(ZENML_SCHEMA_NAME)
if secret_contents:
secret_contents["name"] = secret_name
secret_schema = SecretSchemaClassRegistry.get_class(
secret_schema=zenml_schema_name
)
zenml_secret = secret_schema(**secret_contents)
else:
# Scoped secrets are mapped 1-to-1 with Azure secrets
try:
response = self.CLIENT.get_secret(
self._get_scoped_secret_name(
secret_name,
separator=ZENML_AZURE_SECRET_SCOPE_PATH_SEPARATOR,
),
)
scope_tags = self._get_secret_scope_metadata(secret_name)
# all scope tags need to be included in the Azure secret tags,
# otherwise the secret does not belong to the current scope,
# even if it has the same name
if scope_tags.items() <= response.properties.tags.items():
zenml_secret = secret_from_dict(
json.loads(response.value), secret_name=secret_name
)
except ResourceNotFoundError:
pass
if not zenml_secret:
raise KeyError(f"Can't find the specified secret '{secret_name}'")
return zenml_secret
def get_all_secret_keys(self) -> List[str]:
"""Get all secret keys.
Returns:
A list of all secret keys
"""
self._ensure_client_connected(self.config.key_vault_name)
set_of_secrets = set()
for secret_property in self.CLIENT.list_properties_of_secrets():
tags = secret_property.tags
if not tags:
continue
if self.config.scope == SecretsManagerScope.NONE:
# legacy, non-scoped secrets
if ZENML_GROUP_KEY in tags:
set_of_secrets.add(tags.get(ZENML_GROUP_KEY))
continue
scope_tags = self._get_secret_scope_metadata()
# all scope tags need to be included in the Azure secret tags,
# otherwise the secret does not belong to the current scope
if scope_tags.items() <= tags.items():
set_of_secrets.add(tags.get(ZENML_SECRET_NAME_LABEL))
return list(set_of_secrets)
def update_secret(self, secret: BaseSecretSchema) -> None:
"""Update an existing secret by creating new versions of the existing secrets.
Args:
secret: the secret to update
Raises:
KeyError: if the secret does not exist
"""
self.validate_secret_name(secret.name)
self._ensure_client_connected(self.config.key_vault_name)
if secret.name not in self.get_all_secret_keys():
raise KeyError(f"Can't find the specified secret '{secret.name}'")
self._create_or_update_secret(secret)
def delete_secret(self, secret_name: str) -> None:
"""Delete an existing secret. by name.
Args:
secret_name: the name of the secret to delete
Raises:
KeyError: if the secret no longer exists
"""
self.validate_secret_name(secret_name)
self._ensure_client_connected(self.config.key_vault_name)
if self.config.scope == SecretsManagerScope.NONE:
# legacy, non-scoped secrets
# Go through all Azure secrets and delete the ones with the
# secret_name as label.
for secret_property in self.CLIENT.list_properties_of_secrets():
tags = secret_property.tags
if tags and tags.get(ZENML_GROUP_KEY) == secret_name:
self.CLIENT.begin_delete_secret(
secret_property.name
).result()
else:
if secret_name not in self.get_all_secret_keys():
raise KeyError(
f"Can't find the specified secret '{secret_name}'"
)
self.CLIENT.begin_delete_secret(
self._get_scoped_secret_name(
secret_name,
separator=ZENML_AZURE_SECRET_SCOPE_PATH_SEPARATOR,
),
).result()
def delete_all_secrets(self) -> None:
"""Delete all existing secrets."""
self._ensure_client_connected(self.config.key_vault_name)
# List all secrets.
for secret_property in self.CLIENT.list_properties_of_secrets():
tags = secret_property.tags
if not tags:
continue
if self.config.scope == SecretsManagerScope.NONE:
# legacy, non-scoped secrets
if ZENML_GROUP_KEY in tags:
logger.info(
"Deleted key-value pair {`%s`, `***`} from secret "
"`%s`",
secret_property.name,
tags.get(ZENML_GROUP_KEY),
)
self.CLIENT.begin_delete_secret(
secret_property.name
).result()
continue
scope_tags = self._get_secret_scope_metadata()
# all scope tags need to be included in the Azure secret tags,
# otherwise the secret does not belong to the current scope
if scope_tags.items() <= tags.items():
self.CLIENT.begin_delete_secret(secret_property.name).result()
config: AzureSecretsManagerConfig
property
readonly
Returns the `AzureSecretsManagerConfig` config.
Returns:
Type | Description
---|---
AzureSecretsManagerConfig | The configuration.
delete_all_secrets(self)
Delete all existing secrets.
Source code in zenml/integrations/azure/secrets_managers/azure_secrets_manager.py
def delete_all_secrets(self) -> None:
"""Delete all existing secrets."""
self._ensure_client_connected(self.config.key_vault_name)
# List all secrets.
for secret_property in self.CLIENT.list_properties_of_secrets():
tags = secret_property.tags
if not tags:
continue
if self.config.scope == SecretsManagerScope.NONE:
# legacy, non-scoped secrets
if ZENML_GROUP_KEY in tags:
logger.info(
"Deleted key-value pair {`%s`, `***`} from secret "
"`%s`",
secret_property.name,
tags.get(ZENML_GROUP_KEY),
)
self.CLIENT.begin_delete_secret(
secret_property.name
).result()
continue
scope_tags = self._get_secret_scope_metadata()
# all scope tags need to be included in the Azure secret tags,
# otherwise the secret does not belong to the current scope
if scope_tags.items() <= tags.items():
self.CLIENT.begin_delete_secret(secret_property.name).result()
delete_secret(self, secret_name)
Delete an existing secret by name.
Parameters:
Name | Type | Description | Default
---|---|---|---
secret_name | str | the name of the secret to delete | required
Exceptions:
Type | Description
---|---
KeyError | if the secret no longer exists
Source code in zenml/integrations/azure/secrets_managers/azure_secrets_manager.py
def delete_secret(self, secret_name: str) -> None:
"""Delete an existing secret. by name.
Args:
secret_name: the name of the secret to delete
Raises:
KeyError: if the secret no longer exists
"""
self.validate_secret_name(secret_name)
self._ensure_client_connected(self.config.key_vault_name)
if self.config.scope == SecretsManagerScope.NONE:
# legacy, non-scoped secrets
# Go through all Azure secrets and delete the ones with the
# secret_name as label.
for secret_property in self.CLIENT.list_properties_of_secrets():
tags = secret_property.tags
if tags and tags.get(ZENML_GROUP_KEY) == secret_name:
self.CLIENT.begin_delete_secret(
secret_property.name
).result()
else:
if secret_name not in self.get_all_secret_keys():
raise KeyError(
f"Can't find the specified secret '{secret_name}'"
)
self.CLIENT.begin_delete_secret(
self._get_scoped_secret_name(
secret_name,
separator=ZENML_AZURE_SECRET_SCOPE_PATH_SEPARATOR,
),
).result()
get_all_secret_keys(self)
Get all secret keys.
Returns:
Type | Description
---|---
List[str] | A list of all secret keys
Source code in zenml/integrations/azure/secrets_managers/azure_secrets_manager.py
def get_all_secret_keys(self) -> List[str]:
"""Get all secret keys.
Returns:
A list of all secret keys
"""
self._ensure_client_connected(self.config.key_vault_name)
set_of_secrets = set()
for secret_property in self.CLIENT.list_properties_of_secrets():
tags = secret_property.tags
if not tags:
continue
if self.config.scope == SecretsManagerScope.NONE:
# legacy, non-scoped secrets
if ZENML_GROUP_KEY in tags:
set_of_secrets.add(tags.get(ZENML_GROUP_KEY))
continue
scope_tags = self._get_secret_scope_metadata()
# all scope tags need to be included in the Azure secret tags,
# otherwise the secret does not belong to the current scope
if scope_tags.items() <= tags.items():
set_of_secrets.add(tags.get(ZENML_SECRET_NAME_LABEL))
return list(set_of_secrets)
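The scope filtering above leans on Python's dict-items subset comparison: a secret belongs to the current scope only if every scope tag appears, with the same value, among the Azure secret's tags. A tiny self-contained illustration with made-up tag keys:

```python
# `items() <=` tests that the left dict is a sub-dict of the right one.
scope_tags = {"zenml-scope": "global"}                        # hypothetical tags
secret_tags = {"zenml-scope": "global", "zenml-secret": "db"}
assert scope_tags.items() <= secret_tags.items()              # in scope
assert not ({"zenml-scope": "component"}.items() <= secret_tags.items())
```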
get_secret(self, secret_name)
Get a secret by its name.
Parameters:
Name | Type | Description | Default
---|---|---|---
secret_name | str | the name of the secret to get | required
Returns:
Type | Description
---|---
BaseSecretSchema | The secret.
Exceptions:
Type | Description
---|---
KeyError | if the secret does not exist
ValueError | if the secret is named 'name'
Source code in zenml/integrations/azure/secrets_managers/azure_secrets_manager.py
def get_secret(self, secret_name: str) -> BaseSecretSchema:
"""Get a secret by its name.
Args:
secret_name: the name of the secret to get
Returns:
The secret.
Raises:
KeyError: if the secret does not exist
ValueError: if the secret is named 'name'
"""
self.validate_secret_name(secret_name)
self._ensure_client_connected(self.config.key_vault_name)
zenml_secret: Optional[BaseSecretSchema] = None
if self.config.scope == SecretsManagerScope.NONE:
# Legacy secrets are mapped to multiple Azure secrets, one for
# each secret key
secret_contents = {}
zenml_schema_name = ""
for secret_property in self.CLIENT.list_properties_of_secrets():
tags = secret_property.tags
if tags and tags.get(ZENML_GROUP_KEY) == secret_name:
secret_key = tags.get(ZENML_KEY_NAME)
if not secret_key:
raise ValueError("Missing secret key tag.")
if secret_key == "name":
raise ValueError("The secret's key cannot be 'name'.")
response = self.CLIENT.get_secret(secret_property.name)
secret_contents[secret_key] = response.value
zenml_schema_name = tags.get(ZENML_SCHEMA_NAME)
if secret_contents:
secret_contents["name"] = secret_name
secret_schema = SecretSchemaClassRegistry.get_class(
secret_schema=zenml_schema_name
)
zenml_secret = secret_schema(**secret_contents)
else:
# Scoped secrets are mapped 1-to-1 with Azure secrets
try:
response = self.CLIENT.get_secret(
self._get_scoped_secret_name(
secret_name,
separator=ZENML_AZURE_SECRET_SCOPE_PATH_SEPARATOR,
),
)
scope_tags = self._get_secret_scope_metadata(secret_name)
# all scope tags need to be included in the Azure secret tags,
# otherwise the secret does not belong to the current scope,
# even if it has the same name
if scope_tags.items() <= response.properties.tags.items():
zenml_secret = secret_from_dict(
json.loads(response.value), secret_name=secret_name
)
except ResourceNotFoundError:
pass
if not zenml_secret:
raise KeyError(f"Can't find the specified secret '{secret_name}'")
return zenml_secret
register_secret(self, secret)
Registers a new secret.
Parameters:
Name | Type | Description | Default
---|---|---|---
secret | BaseSecretSchema | the secret to register | required
Exceptions:
Type | Description
---|---
SecretExistsError | if the secret already exists
Source code in zenml/integrations/azure/secrets_managers/azure_secrets_manager.py
def register_secret(self, secret: BaseSecretSchema) -> None:
"""Registers a new secret.
Args:
secret: the secret to register
Raises:
SecretExistsError: if the secret already exists
"""
self.validate_secret_name(secret.name)
self._ensure_client_connected(self.config.key_vault_name)
if secret.name in self.get_all_secret_keys():
raise SecretExistsError(
f"A Secret with the name '{secret.name}' already exists."
)
self._create_or_update_secret(secret)
update_secret(self, secret)
Update an existing secret by creating new versions of the existing secrets.
Parameters:
Name | Type | Description | Default
---|---|---|---
secret | BaseSecretSchema | the secret to update | required
Exceptions:
Type | Description
---|---
KeyError | if the secret does not exist
Source code in zenml/integrations/azure/secrets_managers/azure_secrets_manager.py
def update_secret(self, secret: BaseSecretSchema) -> None:
"""Update an existing secret by creating new versions of the existing secrets.
Args:
secret: the secret to update
Raises:
KeyError: if the secret does not exist
"""
self.validate_secret_name(secret.name)
self._ensure_client_connected(self.config.key_vault_name)
if secret.name not in self.get_all_secret_keys():
raise KeyError(f"Can't find the specified secret '{secret.name}'")
self._create_or_update_secret(secret)
validate_secret_name(self, name)
Validate a secret name.
Parameters:
Name | Type | Description | Default
---|---|---|---
name | str | the secret name | required
Source code in zenml/integrations/azure/secrets_managers/azure_secrets_manager.py
def validate_secret_name(self, name: str) -> None:
"""Validate a secret name.
Args:
name: the secret name
"""
validate_azure_secret_name_or_namespace(name, self.config.scope)
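A note on the two storage layouts handled throughout this class: scoped secrets map 1-to-1 to Azure Key Vault secrets, while legacy (unscoped) secrets store each key-value pair as its own Azure secret under an encoded name. The encoding used in `_create_or_update_secret` can be reproduced standalone; the secret name and key below are hypothetical:

```python
import base64

# Mirror of the legacy naming scheme above: the "<secret>-<key>" string
# is base64-encoded, the resulting bytes are hex-encoded, and the whole
# thing is prefixed with "zenml-".
secret_name, key = "db-credentials", "password"
encoded_key = base64.b64encode(f"{secret_name}-{key}".encode()).hex()
azure_secret_name = f"zenml-{encoded_key}"
print(azure_secret_name)  # "zenml-" followed by a long hex string
```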
step_operators
special
Initialization of AzureML Step Operator integration.
azureml_step_operator
Implementation of the ZenML AzureML Step Operator.
AzureMLStepOperator (BaseStepOperator)
Step operator to run a step on AzureML.
This class defines code that can set up an AzureML environment and run the ZenML entrypoint command in it.
Source code in zenml/integrations/azure/step_operators/azureml_step_operator.py
class AzureMLStepOperator(BaseStepOperator):
"""Step operator to run a step on AzureML.
This class defines code that can set up an AzureML environment and run the
ZenML entrypoint command in it.
"""
@property
def config(self) -> AzureMLStepOperatorConfig:
"""Returns the `AzureMLStepOperatorConfig` config.
Returns:
The configuration.
"""
return cast(AzureMLStepOperatorConfig, self._config)
@property
def validator(self) -> Optional[StackValidator]:
"""Validates the stack.
Returns:
A validator that checks that the stack contains a remote artifact
store.
"""
def _validate_remote_artifact_store(stack: "Stack") -> Tuple[bool, str]:
if stack.artifact_store.config.is_local:
return False, (
"The AzureML step operator runs code remotely and "
"needs to write files into the artifact store, but the "
f"artifact store `{stack.artifact_store.name}` of the "
"active stack is local. Please ensure that your stack "
"contains a remote artifact store when using the AzureML "
"step operator."
)
return True, ""
return StackValidator(
custom_validation_function=_validate_remote_artifact_store,
)
def _get_authentication(self) -> Optional[AbstractAuthentication]:
"""Returns the authentication object for the AzureML environment.
Returns:
The authentication object for the AzureML environment.
"""
if (
self.config.tenant_id
and self.config.service_principal_id
and self.config.service_principal_password
):
return ServicePrincipalAuthentication(
tenant_id=self.config.tenant_id,
service_principal_id=self.config.service_principal_id,
service_principal_password=self.config.service_principal_password,
)
return None
def _prepare_environment(
self,
workspace: Workspace,
docker_settings: "DockerSettings",
run_name: str,
) -> Environment:
"""Prepares the environment in which Azure will run all jobs.
Args:
workspace: The AzureML Workspace that has configuration
for a storage account, container registry among other
things.
docker_settings: The Docker settings for this step.
run_name: The name of the pipeline run that can be used
for naming environments and runs.
Returns:
The AzureML Environment object.
"""
docker_image_builder = PipelineDockerImageBuilder()
requirements_files = docker_image_builder._gather_requirements_files(
docker_settings=docker_settings,
stack=Client().active_stack,
)
requirements = list(
itertools.chain.from_iterable(
r[1].split("\n") for r in requirements_files
)
)
requirements.append(f"zenml=={zenml.__version__}")
logger.info(
"Using requirements for AzureML step operator environment: %s",
requirements,
)
if self.config.environment_name:
environment = Environment.get(
workspace=workspace, name=self.config.environment_name
)
if not environment.python.conda_dependencies:
environment.python.conda_dependencies = (
CondaDependencies.create(
python_version=ZenMLEnvironment.python_version()
)
)
for requirement in requirements:
environment.python.conda_dependencies.add_pip_package(
requirement
)
else:
environment = Environment(name=f"zenml-{run_name}")
environment.python.conda_dependencies = CondaDependencies.create(
pip_packages=requirements,
python_version=ZenMLEnvironment.python_version(),
)
if docker_settings.parent_image:
# replace the default azure base image
environment.docker.base_image = docker_settings.parent_image
environment_variables = {
"ENV_ZENML_PREVENT_PIPELINE_EXECUTION": "True",
}
# set credentials to access azure storage
for key in [
"AZURE_STORAGE_ACCOUNT_KEY",
"AZURE_STORAGE_ACCOUNT_NAME",
"AZURE_STORAGE_CONNECTION_STRING",
"AZURE_STORAGE_SAS_TOKEN",
]:
value = os.getenv(key)
if value:
environment_variables[key] = value
environment_variables[
ENV_ZENML_CONFIG_PATH
] = f"./{DOCKER_IMAGE_ZENML_CONFIG_DIR}"
environment_variables.update(docker_settings.environment)
environment.environment_variables = environment_variables
return environment
def launch(
self,
info: "StepRunInfo",
entrypoint_command: List[str],
) -> None:
"""Launches a step on AzureML.
Args:
info: Information about the step run.
entrypoint_command: Command that executes the step.
"""
if not info.config.resource_settings.empty:
logger.warning(
"Specifying custom step resources is not supported for "
"the AzureML step operator. If you want to run this step "
"operator on specific resources, you can do so by creating an "
"Azure compute target (https://docs.microsoft.com/en-us/azure/machine-learning/concept-compute-target) "
"with a specific machine type and then updating this step "
"operator: `zenml step-operator update %s "
"--compute_target_name=<COMPUTE_TARGET_NAME>`",
self.name,
)
unused_docker_fields = [
"dockerfile",
"build_context_root",
"build_options",
"docker_target_repository",
"dockerignore",
"copy_files",
"copy_global_config",
]
docker_settings = info.pipeline.docker_settings
ignored_docker_fields = docker_settings.__fields_set__.intersection(
unused_docker_fields
)
if ignored_docker_fields:
logger.warning(
"The AzureML step operator currently does not support all "
"options defined in your Docker configuration. Ignoring all "
"values set for the attributes: %s",
ignored_docker_fields,
)
workspace = Workspace.get(
subscription_id=self.config.subscription_id,
resource_group=self.config.resource_group,
name=self.config.workspace_name,
auth=self._get_authentication(),
)
source_directory = get_source_root_path()
with _include_global_config(
build_context_root=source_directory,
load_config_path=PurePosixPath(
f"./{DOCKER_IMAGE_ZENML_CONFIG_DIR}"
),
):
environment = self._prepare_environment(
workspace=workspace,
docker_settings=docker_settings,
run_name=info.run_name,
)
compute_target = ComputeTarget(
workspace=workspace, name=self.config.compute_target_name
)
run_config = ScriptRunConfig(
source_directory=source_directory,
environment=environment,
compute_target=compute_target,
command=entrypoint_command,
)
experiment = Experiment(
workspace=workspace, name=info.pipeline.name
)
run = experiment.submit(config=run_config)
run.display_name = info.run_name
run.wait_for_completion(show_output=True)
config: AzureMLStepOperatorConfig
property
readonly
Returns the `AzureMLStepOperatorConfig` config.
Returns:
Type | Description
---|---
AzureMLStepOperatorConfig | The configuration.
validator: Optional[zenml.stack.stack_validator.StackValidator]
property
readonly
Validates the stack.
Returns:
Type | Description
---|---
Optional[zenml.stack.stack_validator.StackValidator] | A validator that checks that the stack contains a remote artifact store.
launch(self, info, entrypoint_command)
Launches a step on AzureML.
Parameters:
Name | Type | Description | Default
---|---|---|---
info | StepRunInfo | Information about the step run. | required
entrypoint_command | List[str] | Command that executes the step. | required
Source code in zenml/integrations/azure/step_operators/azureml_step_operator.py
def launch(
self,
info: "StepRunInfo",
entrypoint_command: List[str],
) -> None:
"""Launches a step on AzureML.
Args:
info: Information about the step run.
entrypoint_command: Command that executes the step.
"""
if not info.config.resource_settings.empty:
logger.warning(
"Specifying custom step resources is not supported for "
"the AzureML step operator. If you want to run this step "
"operator on specific resources, you can do so by creating an "
"Azure compute target (https://docs.microsoft.com/en-us/azure/machine-learning/concept-compute-target) "
"with a specific machine type and then updating this step "
"operator: `zenml step-operator update %s "
"--compute_target_name=<COMPUTE_TARGET_NAME>`",
self.name,
)
unused_docker_fields = [
"dockerfile",
"build_context_root",
"build_options",
"docker_target_repository",
"dockerignore",
"copy_files",
"copy_global_config",
]
docker_settings = info.pipeline.docker_settings
ignored_docker_fields = docker_settings.__fields_set__.intersection(
unused_docker_fields
)
if ignored_docker_fields:
logger.warning(
"The AzureML step operator currently does not support all "
"options defined in your Docker configuration. Ignoring all "
"values set for the attributes: %s",
ignored_docker_fields,
)
workspace = Workspace.get(
subscription_id=self.config.subscription_id,
resource_group=self.config.resource_group,
name=self.config.workspace_name,
auth=self._get_authentication(),
)
source_directory = get_source_root_path()
with _include_global_config(
build_context_root=source_directory,
load_config_path=PurePosixPath(
f"./{DOCKER_IMAGE_ZENML_CONFIG_DIR}"
),
):
environment = self._prepare_environment(
workspace=workspace,
docker_settings=docker_settings,
run_name=info.run_name,
)
compute_target = ComputeTarget(
workspace=workspace, name=self.config.compute_target_name
)
run_config = ScriptRunConfig(
source_directory=source_directory,
environment=environment,
compute_target=compute_target,
command=entrypoint_command,
)
experiment = Experiment(
workspace=workspace, name=info.pipeline.name
)
run = experiment.submit(config=run_config)
run.display_name = info.run_name
run.wait_for_completion(show_output=True)
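Note that, as the warning in launch indicates, per-step resource settings are ignored by this step operator: compute sizing is determined entirely by the Azure compute target the operator is configured with. To point an already registered step operator (here assumed to be named `azureml`) at a differently sized compute target, the CLI invocation from the warning message applies: `zenml step-operator update azureml --compute_target_name=<COMPUTE_TARGET_NAME>`.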
constants
Constants for ZenML integrations.
dash
special
Initialization of the Dash integration.
DashIntegration (Integration)
Definition of Dash integration for ZenML.
Source code in zenml/integrations/dash/__init__.py
class DashIntegration(Integration):
"""Definition of Dash integration for ZenML."""
NAME = DASH
REQUIREMENTS = [
"dash>=2.0.0",
"dash-cytoscape>=0.3.0",
"dash-bootstrap-components>=1.0.1",
"jupyter-dash>=0.4.2",
]
visualizers
special
Initialization of the Pipeline Run Visualizer.
pipeline_run_lineage_visualizer
Implementation of the pipeline run lineage visualizer.
PipelineRunLineageVisualizer (BaseVisualizer)
Implementation of a lineage diagram via the dash and dash-cytoscape libraries.
Source code in zenml/integrations/dash/visualizers/pipeline_run_lineage_visualizer.py
class PipelineRunLineageVisualizer(BaseVisualizer):
"""Implementation of a lineage diagram via the dash and dash-cytoscape libraries."""
ARTIFACT_PREFIX = "artifact_"
STEP_PREFIX = "step_"
STATUS_CLASS_MAPPING = {
ExecutionStatus.CACHED: "green",
ExecutionStatus.FAILED: "red",
ExecutionStatus.RUNNING: "yellow",
ExecutionStatus.COMPLETED: "blue",
}
def visualize(
self,
object: PipelineRunView,
magic: bool = False,
*args: Any,
**kwargs: Any,
) -> dash.Dash:
"""Method to visualize pipeline runs via the Dash library.
The layout puts every layer of the DAG in a column.
Args:
object: The pipeline run to visualize.
magic: If True, the visualization is rendered in a magic mode.
*args: Additional positional arguments.
**kwargs: Additional keyword arguments.
Returns:
The Dash application.
"""
# flake8: noqa: C901
external_stylesheets = [
dbc.themes.BOOTSTRAP,
dbc.icons.BOOTSTRAP,
]
if magic and Environment.in_notebook():
# Only import jupyter_dash in this case
from jupyter_dash import JupyterDash  # noqa
JupyterDash.infer_jupyter_proxy_config()
app = JupyterDash(
__name__,
external_stylesheets=external_stylesheets,
)
mode = "inline"
else:
if magic:
# `magic` requires a notebook; fall back to a regular Dash app
cli_utils.warning(
"Cannot set magic flag in non-notebook environments."
)
app = dash.Dash(
__name__,
external_stylesheets=external_stylesheets,
)
mode = None
graph = LineageGraph()
graph.generate_run_nodes_and_edges(object)
first_step_id = graph.root_step_id
# Parse lineage graph nodes
nodes = []
for node in graph.nodes:
node_dict = node.dict()
node_data = node_dict.pop("data")
node_dict = {**node_dict, **node_data}
node_dict["label"] = node_dict["name"]
classes = self.STATUS_CLASS_MAPPING[node.data.status]
if isinstance(node, ArtifactNode):
classes = "rectangle " + classes
node_dict["label"] += f" ({node_dict['artifact_data_type']})"
dash_node = {"data": node_dict, "classes": classes}
nodes.append(dash_node)
# Parse lineage graph edges
node_mapping = {node.id: node for node in graph.nodes}
edges = []
for edge in graph.edges:
source_node = node_mapping[edge.source]
if isinstance(source_node, StepNode):
is_input_artifact = False
step_node = node_mapping[edge.source]
artifact_node = node_mapping[edge.target]
else:
is_input_artifact = True
step_node = node_mapping[edge.target]
artifact_node = node_mapping[edge.source]
assert isinstance(artifact_node, ArtifactNode)
artifact_is_cached = artifact_node.data.is_cached
if is_input_artifact and artifact_is_cached:
edge_status = self.STATUS_CLASS_MAPPING[ExecutionStatus.CACHED]
else:
edge_status = self.STATUS_CLASS_MAPPING[step_node.data.status]
edge_style = "dashed" if artifact_node.data.is_cached else "solid"
edges.append(
{
"data": edge.dict(),
"classes": f"edge-arrow {edge_status} {edge_style}",
}
)
app.layout = dbc.Row(
[
dbc.Container(f"Run: {object.name}", class_name="h2"),
dbc.Row(
[
dbc.Col(
[
dbc.Row(
[
html.Span(
[
html.Span(
[
html.I(
className="bi bi-circle-fill me-1"
),
"Step",
],
className="me-2",
),
html.Span(
[
html.I(
className="bi bi-square-fill me-1"
),
"Artifact",
],
className="me-4",
),
dbc.Badge(
"Completed",
color=COLOR_BLUE,
className="me-1",
),
dbc.Badge(
"Cached",
color=COLOR_GREEN,
className="me-1",
),
dbc.Badge(
"Running",
color=COLOR_YELLOW,
className="me-1",
),
dbc.Badge(
"Failed",
color=COLOR_RED,
className="me-1",
),
]
),
]
),
dbc.Row(
[
cyto.Cytoscape(
id="cytoscape",
layout={
"name": "breadthfirst",
"roots": f'[id = "{first_step_id}"]',
},
elements=edges + nodes,
stylesheet=STYLESHEET,
style={
"width": "100%",
"height": "800px",
},
zoom=1,
)
]
),
dbc.Row(
[
dbc.Button(
"Reset",
id="bt-reset",
color="primary",
className="me-1",
)
]
),
]
),
dbc.Col(
[
dcc.Markdown(id="markdown-selected-node-data"),
]
),
]
),
],
className="p-5",
)
@app.callback( # type: ignore[misc]
Output("markdown-selected-node-data", "children"),
Input("cytoscape", "selectedNodeData"),
)
def display_data(data_list: List[Dict[str, Any]]) -> str:
"""Callback for the text area below the graph.
Args:
data_list: The selected node data.
Returns:
str: The selected node data.
"""
if data_list is None:
return "Click on a node in the diagram."
text = ""
for data in data_list:
if data["type"] == "artifact":
text += f"### Artifact '{data['name']}'" + "\n\n"
text += "#### Attributes:" + "\n\n"
for item in [
"execution_id",
"status",
"artifact_data_type",
"producer_step_id",
"parent_step_id",
"uri",
]:
text += f"**{item}**: {data[item]}" + "\n\n"
elif data["type"] == "step":
text += f"### Step '{data['name']}'" + "\n\n"
text += "#### Attributes:" + "\n\n"
for item in [
"execution_id",
"status",
]:
text += f"**{item}**: {data[item]}" + "\n\n"
if data["inputs"]:
text += "#### Inputs:" + "\n\n"
for k, v in data["inputs"].items():
text += f"**{k}**: {v}" + "\n\n"
if data["outputs"]:
text += "#### Outputs:" + "\n\n"
for k, v in data["outputs"].items():
text += f"**{k}**: {v}" + "\n\n"
if data["parameters"]:
text += "#### Parameters:" + "\n\n"
for k, v in data["parameters"].items():
text += f"**{k}**: {v}" + "\n\n"
if data["configuration"]:
text += "#### Configuration:" + "\n\n"
for k, v in data["configuration"].items():
text += f"**{k}**: {v}" + "\n\n"
return text
@app.callback( # type: ignore[misc]
[Output("cytoscape", "zoom"), Output("cytoscape", "elements")],
[Input("bt-reset", "n_clicks")],
)
def reset_layout(
n_clicks: int,
) -> List[Union[int, List[Dict[str, Collection[str]]]]]:
"""Resets the layout.
Args:
n_clicks: The number of clicks on the reset button.
Returns:
The zoom and the elements.
"""
logger.debug("Reset button clicked %s times.", n_clicks)
return [1, edges + nodes]
if mode is not None:
app.run_server(mode=mode)
else:
app.run_server()
return app
visualize(self, object, magic=False, *args, **kwargs)
Method to visualize pipeline runs via the Dash library.
The layout puts every layer of the DAG in a column.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
object | PipelineRunView | The pipeline run to visualize. | required |
magic | bool | If True, the visualization is rendered in a magic mode. | False |
*args | Any | Additional positional arguments. | () |
**kwargs | Any | Additional keyword arguments. | {} |
Returns:
Type | Description |
---|---|
Dash | The Dash application. |
Source code in zenml/integrations/dash/visualizers/pipeline_run_lineage_visualizer.py
def visualize(
self,
object: PipelineRunView,
magic: bool = False,
*args: Any,
**kwargs: Any,
) -> dash.Dash:
"""Method to visualize pipeline runs via the Dash library.
The layout puts every layer of the DAG in a column.
Args:
object: The pipeline run to visualize.
magic: If True, the visualization is rendered in a magic mode.
*args: Additional positional arguments.
**kwargs: Additional keyword arguments.
Returns:
The Dash application.
"""
# flake8: noqa: C901
external_stylesheets = [
dbc.themes.BOOTSTRAP,
dbc.icons.BOOTSTRAP,
]
if magic and Environment.in_notebook():
# Only import jupyter_dash in this case
from jupyter_dash import JupyterDash  # noqa
JupyterDash.infer_jupyter_proxy_config()
app = JupyterDash(
__name__,
external_stylesheets=external_stylesheets,
)
mode = "inline"
else:
if magic:
# `magic` requires a notebook; fall back to a regular Dash app
cli_utils.warning(
"Cannot set magic flag in non-notebook environments."
)
app = dash.Dash(
__name__,
external_stylesheets=external_stylesheets,
)
mode = None
graph = LineageGraph()
graph.generate_run_nodes_and_edges(object)
first_step_id = graph.root_step_id
# Parse lineage graph nodes
nodes = []
for node in graph.nodes:
node_dict = node.dict()
node_data = node_dict.pop("data")
node_dict = {**node_dict, **node_data}
node_dict["label"] = node_dict["name"]
classes = self.STATUS_CLASS_MAPPING[node.data.status]
if isinstance(node, ArtifactNode):
classes = "rectangle " + classes
node_dict["label"] += f" ({node_dict['artifact_data_type']})"
dash_node = {"data": node_dict, "classes": classes}
nodes.append(dash_node)
# Parse lineage graph edges
node_mapping = {node.id: node for node in graph.nodes}
edges = []
for edge in graph.edges:
source_node = node_mapping[edge.source]
if isinstance(source_node, StepNode):
is_input_artifact = False
step_node = node_mapping[edge.source]
artifact_node = node_mapping[edge.target]
else:
is_input_artifact = True
step_node = node_mapping[edge.target]
artifact_node = node_mapping[edge.source]
assert isinstance(artifact_node, ArtifactNode)
artifact_is_cached = artifact_node.data.is_cached
if is_input_artifact and artifact_is_cached:
edge_status = self.STATUS_CLASS_MAPPING[ExecutionStatus.CACHED]
else:
edge_status = self.STATUS_CLASS_MAPPING[step_node.data.status]
edge_style = "dashed" if artifact_node.data.is_cached else "solid"
edges.append(
{
"data": edge.dict(),
"classes": f"edge-arrow {edge_status} {edge_style}",
}
)
app.layout = dbc.Row(
[
dbc.Container(f"Run: {object.name}", class_name="h2"),
dbc.Row(
[
dbc.Col(
[
dbc.Row(
[
html.Span(
[
html.Span(
[
html.I(
className="bi bi-circle-fill me-1"
),
"Step",
],
className="me-2",
),
html.Span(
[
html.I(
className="bi bi-square-fill me-1"
),
"Artifact",
],
className="me-4",
),
dbc.Badge(
"Completed",
color=COLOR_BLUE,
className="me-1",
),
dbc.Badge(
"Cached",
color=COLOR_GREEN,
className="me-1",
),
dbc.Badge(
"Running",
color=COLOR_YELLOW,
className="me-1",
),
dbc.Badge(
"Failed",
color=COLOR_RED,
className="me-1",
),
]
),
]
),
dbc.Row(
[
cyto.Cytoscape(
id="cytoscape",
layout={
"name": "breadthfirst",
"roots": f'[id = "{first_step_id}"]',
},
elements=edges + nodes,
stylesheet=STYLESHEET,
style={
"width": "100%",
"height": "800px",
},
zoom=1,
)
]
),
dbc.Row(
[
dbc.Button(
"Reset",
id="bt-reset",
color="primary",
className="me-1",
)
]
),
]
),
dbc.Col(
[
dcc.Markdown(id="markdown-selected-node-data"),
]
),
]
),
],
className="p-5",
)
@app.callback( # type: ignore[misc]
Output("markdown-selected-node-data", "children"),
Input("cytoscape", "selectedNodeData"),
)
def display_data(data_list: List[Dict[str, Any]]) -> str:
"""Callback for the text area below the graph.
Args:
data_list: The selected node data.
Returns:
str: The selected node data.
"""
if data_list is None:
return "Click on a node in the diagram."
text = ""
for data in data_list:
if data["type"] == "artifact":
text += f"### Artifact '{data['name']}'" + "\n\n"
text += "#### Attributes:" + "\n\n"
for item in [
"execution_id",
"status",
"artifact_data_type",
"producer_step_id",
"parent_step_id",
"uri",
]:
text += f"**{item}**: {data[item]}" + "\n\n"
elif data["type"] == "step":
text += f"### Step '{data['name']}'" + "\n\n"
text += "#### Attributes:" + "\n\n"
for item in [
"execution_id",
"status",
]:
text += f"**{item}**: {data[item]}" + "\n\n"
if data["inputs"]:
text += "#### Inputs:" + "\n\n"
for k, v in data["inputs"].items():
text += f"**{k}**: {v}" + "\n\n"
if data["outputs"]:
text += "#### Outputs:" + "\n\n"
for k, v in data["outputs"].items():
text += f"**{k}**: {v}" + "\n\n"
if data["parameters"]:
text += "#### Parameters:" + "\n\n"
for k, v in data["parameters"].items():
text += f"**{k}**: {v}" + "\n\n"
if data["configuration"]:
text += "#### Configuration:" + "\n\n"
for k, v in data["configuration"].items():
text += f"**{k}**: {v}" + "\n\n"
return text
@app.callback( # type: ignore[misc]
[Output("cytoscape", "zoom"), Output("cytoscape", "elements")],
[Input("bt-reset", "n_clicks")],
)
def reset_layout(
n_clicks: int,
) -> List[Union[int, List[Dict[str, Collection[str]]]]]:
"""Resets the layout.
Args:
n_clicks: The number of clicks on the reset button.
Returns:
The zoom and the elements.
"""
logger.debug("Reset button clicked %s times.", n_clicks)
return [1, edges + nodes]
if mode is not None:
app.run_server(mode=mode)
else:
app.run_server()
return app
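A minimal usage sketch for this visualizer, assuming a finished pipeline run fetched via ZenML's post-execution API (the pipeline name `my_pipeline` is a placeholder, and the `get_pipeline` import path may differ between ZenML versions):
from zenml.integrations.dash.visualizers.pipeline_run_lineage_visualizer import (
    PipelineRunLineageVisualizer,
)
from zenml.post_execution import get_pipeline

# Fetch the latest run of the pipeline to visualize
pipeline = get_pipeline("my_pipeline")
run = pipeline.runs[-1]

# Starts a Dash server rendering the lineage DAG; pass magic=True from
# inside a notebook to render the diagram inline instead.
PipelineRunLineageVisualizer().visualize(run)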
deepchecks
special
Deepchecks integration for ZenML.
The Deepchecks integration provides a way to validate the data used in your pipelines: it can detect data anomalies and lets you define checks that ensure the quality of your data.
The integration includes custom materializers to store Deepchecks SuiteResults and a visualizer to display the results in a notebook or in your browser.
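The integration's pinned requirements (listed in REQUIREMENTS below) can be installed through the ZenML CLI with `zenml integration install deepchecks`.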
DeepchecksIntegration (Integration)
Definition of Deepchecks integration for ZenML.
Source code in zenml/integrations/deepchecks/__init__.py
class DeepchecksIntegration(Integration):
"""Definition of [Deepchecks](https://github.com/deepchecks/deepchecks) integration for ZenML."""
NAME = DEEPCHECKS
REQUIREMENTS = ["deepchecks[vision]==0.8.0", "torchvision==0.11.2"]
@staticmethod
def activate() -> None:
"""Activate the Deepchecks integration."""
from zenml.integrations.deepchecks import materializers # noqa
from zenml.integrations.deepchecks import visualizers # noqa
@classmethod
def flavors(cls) -> List[Type[Flavor]]:
"""Declare the stack component flavors for the Deepchecks integration.
Returns:
List of stack component flavors for this integration.
"""
from zenml.integrations.deepchecks.flavors import (
DeepchecksDataValidatorFlavor,
)
return [DeepchecksDataValidatorFlavor]
activate()
staticmethod
Activate the Deepchecks integration.
Source code in zenml/integrations/deepchecks/__init__.py
@staticmethod
def activate() -> None:
"""Activate the Deepchecks integration."""
from zenml.integrations.deepchecks import materializers # noqa
from zenml.integrations.deepchecks import visualizers # noqa
flavors()
classmethod
Declare the stack component flavors for the Deepchecks integration.
Returns:
Type | Description |
---|---|
List[Type[zenml.stack.flavor.Flavor]] | List of stack component flavors for this integration. |
Source code in zenml/integrations/deepchecks/__init__.py
@classmethod
def flavors(cls) -> List[Type[Flavor]]:
"""Declare the stack component flavors for the Deepchecks integration.
Returns:
List of stack component flavors for this integration.
"""
from zenml.integrations.deepchecks.flavors import (
DeepchecksDataValidatorFlavor,
)
return [DeepchecksDataValidatorFlavor]
data_validators
special
Initialization of the Deepchecks data validator for ZenML.
deepchecks_data_validator
Implementation of the Deepchecks data validator.
DeepchecksDataValidator (BaseDataValidator)
Deepchecks data validator stack component.
Source code in zenml/integrations/deepchecks/data_validators/deepchecks_data_validator.py
class DeepchecksDataValidator(BaseDataValidator):
"""Deepchecks data validator stack component."""
NAME: ClassVar[str] = "Deepchecks"
@staticmethod
def _split_checks(
check_list: Sequence[str],
) -> Tuple[Sequence[str], Sequence[str]]:
"""Split a list of check identifiers in two lists, one for tabular and one for computer vision checks.
Args:
check_list: A list of check identifiers.
Returns:
List of tabular check identifiers and list of computer vision
check identifiers.
"""
tabular_checks = list(
filter(
lambda check: DeepchecksValidationCheck.is_tabular_check(check),
check_list,
)
)
vision_checks = list(
filter(
lambda check: DeepchecksValidationCheck.is_vision_check(check),
check_list,
)
)
return tabular_checks, vision_checks
# flake8: noqa: C901
@classmethod
def _create_and_run_check_suite(
cls,
check_enum: Type[DeepchecksValidationCheck],
reference_dataset: Union[pd.DataFrame, DataLoader[Any]],
comparison_dataset: Optional[
Union[pd.DataFrame, DataLoader[Any]]
] = None,
model: Optional[Union[ClassifierMixin, Module]] = None,
check_list: Optional[Sequence[str]] = None,
dataset_kwargs: Dict[str, Any] = {},
check_kwargs: Dict[str, Dict[str, Any]] = {},
run_kwargs: Dict[str, Any] = {},
) -> SuiteResult:
"""Create and run a Deepchecks check suite corresponding to the input parameters.
This method contains generic logic common to all Deepchecks data
validator methods that validates the input arguments and uses them to
generate and run a Deepchecks check suite.
Args:
check_enum: ZenML enum type grouping together Deepchecks checks with
the same characteristics. This is used to generate a default
list of checks, if a custom list isn't provided via the
`check_list` argument.
reference_dataset: Primary (reference) dataset argument used during
validation.
comparison_dataset: Optional secondary (comparison) dataset argument
used during comparison checks.
model: Optional model argument used during validation.
check_list: Optional list of ZenML Deepchecks check identifiers
specifying the list of Deepchecks checks to be performed.
dataset_kwargs: Additional keyword arguments to be passed to the
Deepchecks tabular.Dataset or vision.VisionData constructor.
check_kwargs: Additional keyword arguments to be passed to the
Deepchecks check object constructors. Arguments are grouped for
each check and indexed using the full check class name or
check enum value as dictionary keys.
run_kwargs: Additional keyword arguments to be passed to the
Deepchecks Suite `run` method.
Returns:
Deepchecks SuiteResult object with the Suite run results.
Raises:
TypeError: If the datasets, model and check list arguments combine
data types and/or checks from different categories (tabular and
computer vision).
"""
# Detect what type of check to perform (tabular or computer vision) from
# the dataset/model datatypes and the check list. At the same time,
# validate the combination of data types used for dataset and model
# arguments and the check list.
is_tabular = False
is_vision = False
for dataset in [reference_dataset, comparison_dataset]:
if dataset is None:
continue
if isinstance(dataset, pd.DataFrame):
is_tabular = True
elif isinstance(dataset, DataLoader):
is_vision = True
else:
raise TypeError(
f"Unsupported dataset data type found: {type(dataset)}. "
f"Supported data types are {str(pd.DataFrame)} for tabular "
f"data and {str(DataLoader)} for computer vision data."
)
if model:
if isinstance(model, ClassifierMixin):
is_tabular = True
elif isinstance(model, Module):
is_vision = True
else:
raise TypeError(
f"Unsupported model data type found: {type(model)}. "
f"Supported data types are {str(ClassifierMixin)} for "
f"tabular data and {str(Module)} for computer vision "
f"data."
)
if is_tabular and is_vision:
raise TypeError(
f"Tabular and computer vision data types used for datasets and "
f"models cannot be mixed. They must all belong to the same "
f"category. Supported data types for tabular data are "
f"{str(pd.DataFrame)} for datasets and {str(ClassifierMixin)} "
f"for models. Supported data types for computer vision data "
f"are {str(pd.DataFrame)} for datasets and and {str(Module)} "
f"for models."
)
if not check_list:
# default to executing all the checks listed in the supplied
# checks enum type if a custom check list is not supplied
tabular_checks, vision_checks = cls._split_checks(
check_enum.values()
)
if is_tabular:
check_list = tabular_checks
vision_checks = []
else:
check_list = vision_checks
tabular_checks = []
else:
tabular_checks, vision_checks = cls._split_checks(check_list)
if tabular_checks and vision_checks:
raise TypeError(
f"The check list cannot mix tabular checks "
f"({tabular_checks}) and computer vision checks ("
f"{vision_checks})."
)
if is_tabular and vision_checks:
raise TypeError(
f"Tabular data types used for datasets and models can only "
f"be used with tabular validation checks. The following "
f"computer vision checks included in the check list are "
f"not valid: {vision_checks}."
)
if is_vision and tabular_checks:
raise TypeError(
f"Computer vision data types used for datasets and models "
f"can only be used with computer vision validation checks. "
f"The following tabular checks included in the check list "
f"are not valid: {tabular_checks}."
)
check_classes = map(
lambda check: (
check,
check_enum.get_check_class(check),
),
check_list,
)
# use the pipeline name and the step name to generate a unique suite
# name
try:
# get pipeline name and step name
step_env = cast(
StepEnvironment, Environment()[STEP_ENVIRONMENT_NAME]
)
suite_name = f"{step_env.pipeline_name}_{step_env.step_name}"
except KeyError:
# if not running inside a pipeline step, use random values
suite_name = f"suite_{random_str(5)}"
if is_tabular:
dataset_class = TabularData
suite_class = TabularSuite
full_suite = full_tabular_suite()
else:
dataset_class = VisionData
suite_class = VisionSuite
full_suite = full_vision_suite()
train_dataset = dataset_class(reference_dataset, **dataset_kwargs)
test_dataset = None
if comparison_dataset is not None:
test_dataset = dataset_class(comparison_dataset, **dataset_kwargs)
suite = suite_class(name=suite_name)
# Some Deepchecks checks require a minimum configuration such as
# conditions to be configured (see https://docs.deepchecks.com/stable/user-guide/general/customizations/examples/plot_configure_check_conditions.html#sphx-glr-user-guide-general-customizations-examples-plot-configure-check-conditions-py)
# for their execution to have meaning. For checks that don't have
# custom configuration attributes explicitly specified in the
# `check_kwargs` input parameter, we use the default check
# instances extracted from the full suite shipped with Deepchecks.
default_checks = {
check.__class__: check for check in full_suite.checks.values()
}
for check_name, check_class in check_classes:
extra_kwargs = check_kwargs.get(check_name, {})
default_check = default_checks.get(check_class)
check: BaseCheck
if extra_kwargs or not default_check:
check = check_class(**extra_kwargs)
else:
check = default_check
# extract the condition kwargs from the check kwargs
for arg_name, condition_kwargs in extra_kwargs.items():
if not arg_name.startswith("condition_") or not isinstance(
condition_kwargs, dict
):
continue
condition_method = getattr(check, f"add_{arg_name}", None)
if not condition_method or not callable(condition_method):
logger.warning(
f"Deepchecks check type {check.__class__} has no "
f"condition named {arg_name}. Ignoring the check "
f"argument."
)
continue
condition_method(**condition_kwargs)
suite.add(check)
return suite.run(
train_dataset=train_dataset,
test_dataset=test_dataset,
model=model,
**run_kwargs,
)
def data_validation(
self,
dataset: Union[pd.DataFrame, DataLoader[Any]],
comparison_dataset: Optional[Any] = None,
check_list: Optional[Sequence[str]] = None,
dataset_kwargs: Dict[str, Any] = {},
check_kwargs: Dict[str, Dict[str, Any]] = {},
run_kwargs: Dict[str, Any] = {},
**kwargs: Any,
) -> SuiteResult:
"""Run one or more Deepchecks data validation checks on a dataset.
Call this method to analyze and identify potential integrity problems
with a single dataset (e.g. missing values, conflicting labels, mixed
data types etc.) and dataset comparison checks (e.g. data drift
checks). Dataset comparison checks require that a second dataset be
supplied via the `comparison_dataset` argument.
The `check_list` argument may be used to specify a custom set of
Deepchecks data integrity checks to perform, identified by
`DeepchecksDataIntegrityCheck` and `DeepchecksDataDriftCheck` enum
values. If omitted:
* if the `comparison_dataset` is omitted, a suite with all available
data integrity checks will be performed on the input data. See
`DeepchecksDataIntegrityCheck` for a list of Deepchecks builtin
checks that are compatible with this method.
* if the `comparison_dataset` is supplied, a suite with all
available data drift checks will be performed on the input
data. See `DeepchecksDataDriftCheck` for a list of Deepchecks
builtin checks that are compatible with this method.
Args:
dataset: Target dataset to be validated.
comparison_dataset: Optional second dataset to be used for data
comparison checks (e.g. data drift checks).
check_list: Optional list of ZenML Deepchecks check identifiers
specifying the data validation checks to be performed.
`DeepchecksDataIntegrityCheck` enum values should be used for
single data validation checks and `DeepchecksDataDriftCheck`
enum values for data comparison checks. If not supplied, the
entire set of checks applicable to the input dataset(s)
will be performed.
dataset_kwargs: Additional keyword arguments to be passed to the
Deepchecks `tabular.Dataset` or `vision.VisionData` constructor.
check_kwargs: Additional keyword arguments to be passed to the
Deepchecks check object constructors. Arguments are grouped for
each check and indexed using the full check class name or
check enum value as dictionary keys.
run_kwargs: Additional keyword arguments to be passed to the
Deepchecks Suite `run` method.
kwargs: Additional keyword arguments (unused).
Returns:
A Deepchecks SuiteResult with the results of the validation.
"""
check_enum: Type[DeepchecksValidationCheck]
if comparison_dataset is None:
check_enum = DeepchecksDataIntegrityCheck
else:
check_enum = DeepchecksDataDriftCheck
return self._create_and_run_check_suite(
check_enum=check_enum,
reference_dataset=dataset,
comparison_dataset=comparison_dataset,
check_list=check_list,
dataset_kwargs=dataset_kwargs,
check_kwargs=check_kwargs,
run_kwargs=run_kwargs,
)
def model_validation(
self,
dataset: Union[pd.DataFrame, DataLoader[Any]],
model: Union[ClassifierMixin, Module],
comparison_dataset: Optional[Any] = None,
check_list: Optional[Sequence[str]] = None,
dataset_kwargs: Dict[str, Any] = {},
check_kwargs: Dict[str, Dict[str, Any]] = {},
run_kwargs: Dict[str, Any] = {},
**kwargs: Any,
) -> Any:
"""Run one or more Deepchecks model validation checks.
Call this method to perform model validation checks (e.g. confusion
matrix validation, performance reports, model error analyses, etc).
A second dataset is required for model performance comparison tests
(i.e. tests that identify changes in a model behavior by comparing how
it performs on two different datasets).
The `check_list` argument may be used to specify a custom set of
Deepchecks model validation checks to perform, identified by
`DeepchecksModelValidationCheck` and `DeepchecksModelDriftCheck` enum
values. If omitted:
* if the `comparison_dataset` is omitted, a suite with all available
model validation checks will be performed on the input data. See
`DeepchecksModelValidationCheck` for a list of Deepchecks builtin
checks that are compatible with this method.
* if the `comparison_dataset` is supplied, a suite with all
available model comparison checks will be performed on the input
data. See `DeepchecksModelValidationCheck` for a list of Deepchecks
builtin checks that are compatible with this method.
Args:
dataset: Target dataset to be validated.
model: Target model to be validated.
comparison_dataset: Optional second dataset to be used for model
comparison checks.
check_list: Optional list of ZenML Deepchecks check identifiers
specifying the model validation checks to be performed.
`DeepchecksModelValidationCheck` enum values should be used for
model validation checks and `DeepchecksModelDriftCheck` enum
values for model comparison checks. If not supplied, the
entire set of checks applicable to the input dataset(s)
will be performed.
dataset_kwargs: Additional keyword arguments to be passed to the
Deepchecks tabular.Dataset or vision.VisionData constructor.
check_kwargs: Additional keyword arguments to be passed to the
Deepchecks check object constructors. Arguments are grouped for
each check and indexed using the full check class name or
check enum value as dictionary keys.
run_kwargs: Additional keyword arguments to be passed to the
Deepchecks Suite `run` method.
kwargs: Additional keyword arguments (unused).
Returns:
A Deepchecks SuiteResult with the results of the validation.
"""
check_enum: Type[DeepchecksValidationCheck]
if comparison_dataset is None:
check_enum = DeepchecksModelValidationCheck
else:
check_enum = DeepchecksModelDriftCheck
return self._create_and_run_check_suite(
check_enum=check_enum,
reference_dataset=dataset,
comparison_dataset=comparison_dataset,
model=model,
check_list=check_list,
dataset_kwargs=dataset_kwargs,
check_kwargs=check_kwargs,
run_kwargs=run_kwargs,
)
data_validation(self, dataset, comparison_dataset=None, check_list=None, dataset_kwargs={}, check_kwargs={}, run_kwargs={}, **kwargs)
Run one or more Deepchecks data validation checks on a dataset.
Call this method to analyze and identify potential integrity problems
with a single dataset (e.g. missing values, conflicting labels, mixed
data types etc.) and dataset comparison checks (e.g. data drift
checks). Dataset comparison checks require that a second dataset be
supplied via the `comparison_dataset` argument.
The `check_list` argument may be used to specify a custom set of
Deepchecks data integrity checks to perform, identified by
`DeepchecksDataIntegrityCheck` and `DeepchecksDataDriftCheck` enum
values. If omitted:
* if the `comparison_dataset` is omitted, a suite with all available
data integrity checks will be performed on the input data. See
`DeepchecksDataIntegrityCheck` for a list of Deepchecks builtin
checks that are compatible with this method.
* if the `comparison_dataset` is supplied, a suite with all
available data drift checks will be performed on the input
data. See `DeepchecksDataDriftCheck` for a list of Deepchecks
builtin checks that are compatible with this method.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
dataset | Union[pandas.core.frame.DataFrame, torch.utils.data.dataloader.DataLoader[Any]] | Target dataset to be validated. | required |
comparison_dataset | Optional[Any] | Optional second dataset to be used for data comparison checks (e.g. data drift checks). | None |
check_list | Optional[Sequence[str]] | Optional list of ZenML Deepchecks check identifiers specifying the data validation checks to be performed. `DeepchecksDataIntegrityCheck` enum values should be used for single data validation checks and `DeepchecksDataDriftCheck` enum values for data comparison checks. If not supplied, the entire set of checks applicable to the input dataset(s) will be performed. | None |
dataset_kwargs | Dict[str, Any] | Additional keyword arguments to be passed to the Deepchecks `tabular.Dataset` or `vision.VisionData` constructor. | {} |
check_kwargs | Dict[str, Dict[str, Any]] | Additional keyword arguments to be passed to the Deepchecks check object constructors. Arguments are grouped for each check and indexed using the full check class name or check enum value as dictionary keys. | {} |
run_kwargs | Dict[str, Any] | Additional keyword arguments to be passed to the Deepchecks Suite `run` method. | {} |
kwargs | Any | Additional keyword arguments (unused). | {} |
Returns:
Type | Description |
---|---|
SuiteResult | A Deepchecks SuiteResult with the results of the validation. |
Source code in zenml/integrations/deepchecks/data_validators/deepchecks_data_validator.py
def data_validation(
self,
dataset: Union[pd.DataFrame, DataLoader[Any]],
comparison_dataset: Optional[Any] = None,
check_list: Optional[Sequence[str]] = None,
dataset_kwargs: Dict[str, Any] = {},
check_kwargs: Dict[str, Dict[str, Any]] = {},
run_kwargs: Dict[str, Any] = {},
**kwargs: Any,
) -> SuiteResult:
"""Run one or more Deepchecks data validation checks on a dataset.
Call this method to analyze and identify potential integrity problems
with a single dataset (e.g. missing values, conflicting labels, mixed
data types etc.) and dataset comparison checks (e.g. data drift
checks). Dataset comparison checks require that a second dataset be
supplied via the `comparison_dataset` argument.
The `check_list` argument may be used to specify a custom set of
Deepchecks data integrity checks to perform, identified by
`DeepchecksDataIntegrityCheck` and `DeepchecksDataDriftCheck` enum
values. If omitted:
* if the `comparison_dataset` is omitted, a suite with all available
data integrity checks will be performed on the input data. See
`DeepchecksDataIntegrityCheck` for a list of Deepchecks builtin
checks that are compatible with this method.
* if the `comparison_dataset` is supplied, a suite with all
available data drift checks will be performed on the input
data. See `DeepchecksDataDriftCheck` for a list of Deepchecks
builtin checks that are compatible with this method.
Args:
dataset: Target dataset to be validated.
comparison_dataset: Optional second dataset to be used for data
comparison checks (e.g. data drift checks).
check_list: Optional list of ZenML Deepchecks check identifiers
specifying the data validation checks to be performed.
`DeepchecksDataIntegrityCheck` enum values should be used for
single data validation checks and `DeepchecksDataDriftCheck`
enum values for data comparison checks. If not supplied, the
entire set of checks applicable to the input dataset(s)
will be performed.
dataset_kwargs: Additional keyword arguments to be passed to the
Deepchecks `tabular.Dataset` or `vision.VisionData` constructor.
check_kwargs: Additional keyword arguments to be passed to the
Deepchecks check object constructors. Arguments are grouped for
each check and indexed using the full check class name or
check enum value as dictionary keys.
run_kwargs: Additional keyword arguments to be passed to the
Deepchecks Suite `run` method.
kwargs: Additional keyword arguments (unused).
Returns:
A Deepchecks SuiteResult with the results of the validation.
"""
check_enum: Type[DeepchecksValidationCheck]
if comparison_dataset is None:
check_enum = DeepchecksDataIntegrityCheck
else:
check_enum = DeepchecksDataDriftCheck
return self._create_and_run_check_suite(
check_enum=check_enum,
reference_dataset=dataset,
comparison_dataset=comparison_dataset,
check_list=check_list,
dataset_kwargs=dataset_kwargs,
check_kwargs=check_kwargs,
run_kwargs=run_kwargs,
)
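A hedged sketch of calling data_validation directly, assuming an active ZenML stack that includes the Deepchecks data validator (the toy dataframes and column names are illustrative only):
import pandas as pd

from zenml.integrations.deepchecks.data_validators import (
    DeepchecksDataValidator,
)

# Toy reference and target frames; in a pipeline these would be step inputs.
reference = pd.DataFrame({"age": [23, 35, 47, 51], "income": [40, 55, 70, 62]})
target = pd.DataFrame({"age": [22, 80, 46, 50], "income": [41, 300, 69, 61]})

validator = DeepchecksDataValidator.get_active_data_validator()

# Supplying comparison_dataset selects the data drift check suite by default.
suite_result = validator.data_validation(
    dataset=reference,
    comparison_dataset=target,
)

# SuiteResult supports HTML export for inspection outside a notebook.
suite_result.save_as_html("drift_report.html")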
model_validation(self, dataset, model, comparison_dataset=None, check_list=None, dataset_kwargs={}, check_kwargs={}, run_kwargs={}, **kwargs)
Run one or more Deepchecks model validation checks.
Call this method to perform model validation checks (e.g. confusion matrix validation, performance reports, model error analyses, etc). A second dataset is required for model performance comparison tests (i.e. tests that identify changes in a model behavior by comparing how it performs on two different datasets).
The `check_list` argument may be used to specify a custom set of
Deepchecks model validation checks to perform, identified by
`DeepchecksModelValidationCheck` and `DeepchecksModelDriftCheck` enum
values. If omitted:
* if the `comparison_dataset` is omitted, a suite with all available
model validation checks will be performed on the input data. See
`DeepchecksModelValidationCheck` for a list of Deepchecks builtin
checks that are compatible with this method.
* if the `comparison_dataset` is supplied, a suite with all
available model comparison checks will be performed on the input
data. See `DeepchecksModelValidationCheck` for a list of Deepchecks
builtin checks that are compatible with this method.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
dataset | Union[pandas.core.frame.DataFrame, torch.utils.data.dataloader.DataLoader[Any]] | Target dataset to be validated. | required |
model | Union[sklearn.base.ClassifierMixin, torch.nn.modules.module.Module] | Target model to be validated. | required |
comparison_dataset | Optional[Any] | Optional second dataset to be used for model comparison checks. | None |
check_list | Optional[Sequence[str]] | Optional list of ZenML Deepchecks check identifiers specifying the model validation checks to be performed. `DeepchecksModelValidationCheck` enum values should be used for model validation checks and `DeepchecksModelDriftCheck` enum values for model comparison checks. If not supplied, the entire set of checks applicable to the input dataset(s) will be performed. | None |
dataset_kwargs | Dict[str, Any] | Additional keyword arguments to be passed to the Deepchecks tabular.Dataset or vision.VisionData constructor. | {} |
check_kwargs | Dict[str, Dict[str, Any]] | Additional keyword arguments to be passed to the Deepchecks check object constructors. Arguments are grouped for each check and indexed using the full check class name or check enum value as dictionary keys. | {} |
run_kwargs | Dict[str, Any] | Additional keyword arguments to be passed to the Deepchecks Suite `run` method. | {} |
kwargs | Any | Additional keyword arguments (unused). | {} |
Returns:
Type | Description |
---|---|
Any | A Deepchecks SuiteResult with the results of the validation. |
Source code in zenml/integrations/deepchecks/data_validators/deepchecks_data_validator.py
def model_validation(
self,
dataset: Union[pd.DataFrame, DataLoader[Any]],
model: Union[ClassifierMixin, Module],
comparison_dataset: Optional[Any] = None,
check_list: Optional[Sequence[str]] = None,
dataset_kwargs: Dict[str, Any] = {},
check_kwargs: Dict[str, Dict[str, Any]] = {},
run_kwargs: Dict[str, Any] = {},
**kwargs: Any,
) -> Any:
"""Run one or more Deepchecks model validation checks.
Call this method to perform model validation checks (e.g. confusion
matrix validation, performance reports, model error analyses, etc).
A second dataset is required for model performance comparison tests
(i.e. tests that identify changes in a model behavior by comparing how
it performs on two different datasets).
The `check_list` argument may be used to specify a custom set of
Deepchecks model validation checks to perform, identified by
`DeepchecksModelValidationCheck` and `DeepchecksModelDriftCheck` enum
values. If omitted:
* if the `comparison_dataset` is omitted, a suite with all available
model validation checks will be performed on the input data. See
`DeepchecksModelValidationCheck` for a list of Deepchecks builtin
checks that are compatible with this method.
* if the `comparison_dataset` is supplied, a suite with all
available model comparison checks will be performed on the input
data. See `DeepchecksModelValidationCheck` for a list of Deepchecks
builtin checks that are compatible with this method.
Args:
dataset: Target dataset to be validated.
model: Target model to be validated.
comparison_dataset: Optional second dataset to be used for model
comparison checks.
check_list: Optional list of ZenML Deepchecks check identifiers
specifying the model validation checks to be performed.
`DeepchecksModelValidationCheck` enum values should be used for
model validation checks and `DeepchecksModelDriftCheck` enum
values for model comparison checks. If not supplied, the
entire set of checks applicable to the input dataset(s)
will be performed.
dataset_kwargs: Additional keyword arguments to be passed to the
Deepchecks tabular.Dataset or vision.VisionData constructor.
check_kwargs: Additional keyword arguments to be passed to the
Deepchecks check object constructors. Arguments are grouped for
each check and indexed using the full check class name or
check enum value as dictionary keys.
run_kwargs: Additional keyword arguments to be passed to the
Deepchecks Suite `run` method.
kwargs: Additional keyword arguments (unused).
Returns:
A Deepchecks SuiteResult with the results of the validation.
"""
check_enum: Type[DeepchecksValidationCheck]
if comparison_dataset is None:
check_enum = DeepchecksModelValidationCheck
else:
check_enum = DeepchecksModelDriftCheck
return self._create_and_run_check_suite(
check_enum=check_enum,
reference_dataset=dataset,
comparison_dataset=comparison_dataset,
model=model,
check_list=check_list,
dataset_kwargs=dataset_kwargs,
check_kwargs=check_kwargs,
run_kwargs=run_kwargs,
)
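And an equivalent sketch for model_validation with a tabular model (the scikit-learn classifier and the "label" column are assumptions for illustration; the label column must be declared to Deepchecks via dataset_kwargs):
import pandas as pd
from sklearn.tree import DecisionTreeClassifier

from zenml.integrations.deepchecks.data_validators import (
    DeepchecksDataValidator,
)

# Tiny labeled frame with a "label" column, plus a fitted classifier.
df = pd.DataFrame(
    {"x1": [0, 1, 0, 1], "x2": [1, 1, 0, 0], "label": [0, 1, 0, 1]}
)
model = DecisionTreeClassifier().fit(df[["x1", "x2"]], df["label"])

validator = DeepchecksDataValidator.get_active_data_validator()

# Without a comparison dataset, the full model validation suite runs.
suite_result = validator.model_validation(
    dataset=df,
    model=model,
    dataset_kwargs={"label": "label"},
)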
flavors
special
Deepchecks integration flavors.
deepchecks_data_validator_flavor
Deepchecks data validator flavor.
DeepchecksDataValidatorFlavor (BaseDataValidatorFlavor)
Flavor of the Deepchecks data validator.
Source code in zenml/integrations/deepchecks/flavors/deepchecks_data_validator_flavor.py
class DeepchecksDataValidatorFlavor(BaseDataValidatorFlavor):
"""Flavor of the Deepchecks data validator."""
@property
def name(self) -> str:
"""Name of the flavor.
Returns:
The name of the flavor.
"""
return DEEPCHECKS_DATA_VALIDATOR_FLAVOR
@property
def implementation_class(self) -> Type["DeepchecksDataValidator"]:
"""Implementation class.
Returns:
The implementation class.
"""
from zenml.integrations.deepchecks.data_validators import (
DeepchecksDataValidator,
)
return DeepchecksDataValidator
implementation_class: Type[DeepchecksDataValidator]
property
readonly
Implementation class.
Returns:
Type | Description |
---|---|
Type[DeepchecksDataValidator] | The implementation class. |
name: str
property
readonly
Name of the flavor.
Returns:
Type | Description |
---|---|
str | The name of the flavor. |
materializers
special
Deepchecks materializers.
deepchecks_dataset_materializer
Implementation of Deepchecks dataset materializer.
DeepchecksDatasetMaterializer (BaseMaterializer)
Materializer to read and write Deepchecks datasets.
Source code in zenml/integrations/deepchecks/materializers/deepchecks_dataset_materializer.py
class DeepchecksDatasetMaterializer(BaseMaterializer):
"""Materializer to read data to and from Deepchecks dataset."""
ASSOCIATED_TYPES = (Dataset,)
ASSOCIATED_ARTIFACT_TYPES = (DataArtifact,)
def handle_input(self, data_type: Type[Any]) -> Dataset:
"""Reads pandas dataframes and creates deepchecks.Dataset from it.
Args:
data_type: The type of the data to read.
Returns:
A Deepchecks Dataset.
"""
super().handle_input(data_type)
# Outsource to pandas
pandas_materializer = PandasMaterializer(self.artifact)
df = pandas_materializer.handle_input(data_type)
# Recreate from pandas dataframe
return Dataset(df)
def handle_return(self, df: Dataset) -> None:
"""Serializes pandas dataframe within a Dataset object.
Args:
df: A deepchecks.Dataset object.
"""
super().handle_return(df)
# Outsource to pandas
pandas_materializer = PandasMaterializer(self.artifact)
pandas_materializer.handle_return(df.data)
handle_input(self, data_type)
Reads a pandas dataframe and creates a deepchecks.Dataset from it.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
data_type | Type[Any] | The type of the data to read. | required |
Returns:
Type | Description |
---|---|
Dataset | A Deepchecks Dataset. |
Source code in zenml/integrations/deepchecks/materializers/deepchecks_dataset_materializer.py
def handle_input(self, data_type: Type[Any]) -> Dataset:
"""Reads pandas dataframes and creates deepchecks.Dataset from it.
Args:
data_type: The type of the data to read.
Returns:
A Deepchecks Dataset.
"""
super().handle_input(data_type)
# Outsource to pandas
pandas_materializer = PandasMaterializer(self.artifact)
df = pandas_materializer.handle_input(data_type)
# Recreate from pandas dataframe
return Dataset(df)
handle_return(self, df)
Serializes the pandas dataframe within a Dataset object.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
df | Dataset | A deepchecks.Dataset object. | required |
Source code in zenml/integrations/deepchecks/materializers/deepchecks_dataset_materializer.py
def handle_return(self, df: Dataset) -> None:
"""Serializes pandas dataframe within a Dataset object.
Args:
df: A deepchecks.Dataset object.
"""
super().handle_return(df)
# Outsource to pandas
pandas_materializer = PandasMaterializer(self.artifact)
pandas_materializer.handle_return(df.data)
deepchecks_results_materializer
Implementation of Deepchecks suite results materializer.
DeepchecksResultMaterializer (BaseMaterializer)
Materializer to read and write CheckResult and SuiteResult objects.
Source code in zenml/integrations/deepchecks/materializers/deepchecks_results_materializer.py
class DeepchecksResultMaterializer(BaseMaterializer):
"""Materializer to read data to and from CheckResult and SuiteResult objects."""
ASSOCIATED_TYPES = (
CheckResult,
SuiteResult,
)
ASSOCIATED_ARTIFACT_TYPES = (DataAnalysisArtifact,)
def handle_input(
self, data_type: Type[Any]
) -> Union[CheckResult, SuiteResult]:
"""Reads a Deepchecks check or suite result from a serialized JSON file.
Args:
data_type: The type of the data to read.
Returns:
A Deepchecks CheckResult or SuiteResult.
Raises:
RuntimeError: if the input data type is not supported.
"""
super().handle_input(data_type)
filepath = os.path.join(self.artifact.uri, RESULTS_FILENAME)
json_res = io_utils.read_file_contents_as_string(filepath)
if data_type == SuiteResult:
res = SuiteResult.from_json(json_res)
elif data_type == CheckResult:
res = CheckResult.from_json(json_res)
else:
raise RuntimeError(f"Unknown data type: {data_type}")
return res
def handle_return(self, result: Union[CheckResult, SuiteResult]) -> None:
"""Creates a JSON serialization for a CheckResult or SuiteResult.
Args:
result: A Deepchecks CheckResult or SuiteResult.
"""
super().handle_return(result)
filepath = os.path.join(self.artifact.uri, RESULTS_FILENAME)
serialized_json = result.to_json(True)
io_utils.write_file_contents_as_string(filepath, serialized_json)
handle_input(self, data_type)
Reads a Deepchecks check or suite result from a serialized JSON file.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
data_type | Type[Any] | The type of the data to read. | required |
Returns:
Type | Description |
---|---|
Union[deepchecks.core.check_result.CheckResult, deepchecks.core.suite.SuiteResult] | A Deepchecks CheckResult or SuiteResult. |
Exceptions:
Type | Description |
---|---|
RuntimeError | if the input data type is not supported. |
Source code in zenml/integrations/deepchecks/materializers/deepchecks_results_materializer.py
def handle_input(
self, data_type: Type[Any]
) -> Union[CheckResult, SuiteResult]:
"""Reads a Deepchecks check or suite result from a serialized JSON file.
Args:
data_type: The type of the data to read.
Returns:
A Deepchecks CheckResult or SuiteResult.
Raises:
RuntimeError: if the input data type is not supported.
"""
super().handle_input(data_type)
filepath = os.path.join(self.artifact.uri, RESULTS_FILENAME)
json_res = io_utils.read_file_contents_as_string(filepath)
if data_type == SuiteResult:
res = SuiteResult.from_json(json_res)
elif data_type == CheckResult:
res = CheckResult.from_json(json_res)
else:
raise RuntimeError(f"Unknown data type: {data_type}")
return res
handle_return(self, result)
Creates a JSON serialization for a CheckResult or SuiteResult.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
result | Union[deepchecks.core.check_result.CheckResult, deepchecks.core.suite.SuiteResult] | A Deepchecks CheckResult or SuiteResult. | required |
Source code in zenml/integrations/deepchecks/materializers/deepchecks_results_materializer.py
def handle_return(self, result: Union[CheckResult, SuiteResult]) -> None:
"""Creates a JSON serialization for a CheckResult or SuiteResult.
Args:
result: A Deepchecks CheckResult or SuiteResult.
"""
super().handle_return(result)
filepath = os.path.join(self.artifact.uri, RESULTS_FILENAME)
serialized_json = result.to_json(True)
io_utils.write_file_contents_as_string(filepath, serialized_json)
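The materializer round-trips results through Deepchecks' own JSON serialization; a small standalone sketch of the same mechanism (the IsSingleValue check is just an arbitrary example, assuming Deepchecks 0.8):
import pandas as pd
from deepchecks.core.check_result import CheckResult
from deepchecks.tabular import Dataset
from deepchecks.tabular.checks import IsSingleValue

df = pd.DataFrame({"a": [1, 1, 1], "b": [1, 2, 3]})
result = IsSingleValue().run(Dataset(df))

# Serialize including the display output, mirroring result.to_json(True) ...
payload = result.to_json(True)

# ... and restore it, as handle_input does for stored artifacts.
restored = CheckResult.from_json(payload)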
steps
special
Initialization of the Deepchecks Standard Steps.
deepchecks_data_drift
Implementation of the Deepchecks data drift validation step.
DeepchecksDataDriftCheckStep (BaseStep)
Deepchecks data drift validator step.
Source code in zenml/integrations/deepchecks/steps/deepchecks_data_drift.py
class DeepchecksDataDriftCheckStep(BaseStep):
"""Deepchecks data drift validator step."""
def entrypoint( # type: ignore[override]
self,
reference_dataset: pd.DataFrame,
target_dataset: pd.DataFrame,
params: DeepchecksDataDriftCheckStepParameters,
) -> SuiteResult:
"""Main entrypoint for the Deepchecks data drift validator step.
Args:
reference_dataset: Reference dataset for the data drift check.
target_dataset: Target dataset to be used for the data drift check.
params: The parameters for the step
Returns:
A Deepchecks suite result with the validation results.
"""
data_validator = cast(
DeepchecksDataValidator,
DeepchecksDataValidator.get_active_data_validator(),
)
return data_validator.data_validation(
dataset=reference_dataset,
comparison_dataset=target_dataset,
check_list=cast(Optional[Sequence[str]], params.check_list),
dataset_kwargs=params.dataset_kwargs,
check_kwargs=params.check_kwargs,
run_kwargs=params.run_kwargs,
)
PARAMETERS_CLASS (BaseParameters)
pydantic-model
Parameter class for the Deepchecks data drift validator step.
Attributes:
Name | Type | Description |
---|---|---|
check_list | Optional[Sequence[zenml.integrations.deepchecks.validation_checks.DeepchecksDataDriftCheck]] | Optional list of DeepchecksDataDriftCheck identifiers specifying the subset of Deepchecks data drift checks to be performed. If not supplied, the entire set of data drift checks will be performed. |
dataset_kwargs | Dict[str, Any] | Additional keyword arguments to be passed to the Deepchecks `tabular.Dataset` or `vision.VisionData` constructor. |
check_kwargs | Dict[str, Dict[str, Any]] | Additional keyword arguments to be passed to the Deepchecks check object constructors. Arguments are grouped for each check and indexed using the full check class name or check enum value as dictionary keys. |
run_kwargs | Dict[str, Any] | Additional keyword arguments to be passed to the Deepchecks Suite `run` method. |
Source code in zenml/integrations/deepchecks/steps/deepchecks_data_drift.py
class DeepchecksDataDriftCheckStepParameters(BaseParameters):
"""Parameter class for the Deepchecks data drift validator step.
Attributes:
check_list: Optional list of DeepchecksDataDriftCheck identifiers
specifying the subset of Deepchecks data drift checks to be
performed. If not supplied, the entire set of data drift checks will
be performed.
dataset_kwargs: Additional keyword arguments to be passed to the
Deepchecks `tabular.Dataset` or `vision.VisionData` constructor.
check_kwargs: Additional keyword arguments to be passed to the
Deepchecks check object constructors. Arguments are grouped for
each check and indexed using the full check class name or
check enum value as dictionary keys.
run_kwargs: Additional keyword arguments to be passed to the
Deepchecks Suite `run` method.
"""
check_list: Optional[Sequence[DeepchecksDataDriftCheck]] = None
dataset_kwargs: Dict[str, Any] = Field(default_factory=dict)
check_kwargs: Dict[str, Dict[str, Any]] = Field(default_factory=dict)
run_kwargs: Dict[str, Any] = Field(default_factory=dict)
entrypoint(self, reference_dataset, target_dataset, params)
Main entrypoint for the Deepchecks data drift validator step.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
reference_dataset | DataFrame | Reference dataset for the data drift check. | required |
target_dataset | DataFrame | Target dataset to be used for the data drift check. | required |
params | DeepchecksDataDriftCheckStepParameters | The parameters for the step | required |
Returns:
Type | Description |
---|---|
SuiteResult | A Deepchecks suite result with the validation results. |
Source code in zenml/integrations/deepchecks/steps/deepchecks_data_drift.py
def entrypoint( # type: ignore[override]
self,
reference_dataset: pd.DataFrame,
target_dataset: pd.DataFrame,
params: DeepchecksDataDriftCheckStepParameters,
) -> SuiteResult:
"""Main entrypoint for the Deepchecks data drift validator step.
Args:
reference_dataset: Reference dataset for the data drift check.
target_dataset: Target dataset to be used for the data drift check.
params: The parameters for the step
Returns:
A Deepchecks suite result with the validation results.
"""
data_validator = cast(
DeepchecksDataValidator,
DeepchecksDataValidator.get_active_data_validator(),
)
return data_validator.data_validation(
dataset=reference_dataset,
comparison_dataset=target_dataset,
check_list=cast(Optional[Sequence[str]], params.check_list),
dataset_kwargs=params.dataset_kwargs,
check_kwargs=params.check_kwargs,
run_kwargs=params.run_kwargs,
)
DeepchecksDataDriftCheckStepParameters (BaseParameters)
pydantic-model
Parameter class for the Deepchecks data drift validator step.
Attributes:
Name | Type | Description |
---|---|---|
check_list | Optional[Sequence[zenml.integrations.deepchecks.validation_checks.DeepchecksDataDriftCheck]] | Optional list of DeepchecksDataDriftCheck identifiers specifying the subset of Deepchecks data drift checks to be performed. If not supplied, the entire set of data drift checks will be performed. |
dataset_kwargs | Dict[str, Any] | Additional keyword arguments to be passed to the Deepchecks `tabular.Dataset` or `vision.VisionData` constructor. |
check_kwargs | Dict[str, Dict[str, Any]] | Additional keyword arguments to be passed to the Deepchecks check object constructors. Arguments are grouped for each check and indexed using the full check class name or check enum value as dictionary keys. |
run_kwargs | Dict[str, Any] | Additional keyword arguments to be passed to the Deepchecks Suite `run` method. |
Source code in zenml/integrations/deepchecks/steps/deepchecks_data_drift.py
class DeepchecksDataDriftCheckStepParameters(BaseParameters):
"""Parameter class for the Deepchecks data drift validator step.
Attributes:
check_list: Optional list of DeepchecksDataDriftCheck identifiers
specifying the subset of Deepchecks data drift checks to be
performed. If not supplied, the entire set of data drift checks will
be performed.
dataset_kwargs: Additional keyword arguments to be passed to the
Deepchecks `tabular.Dataset` or `vision.VisionData` constructor.
check_kwargs: Additional keyword arguments to be passed to the
Deepchecks check object constructors. Arguments are grouped for
each check and indexed using the full check class name or
check enum value as dictionary keys.
run_kwargs: Additional keyword arguments to be passed to the
Deepchecks Suite `run` method.
"""
check_list: Optional[Sequence[DeepchecksDataDriftCheck]] = None
dataset_kwargs: Dict[str, Any] = Field(default_factory=dict)
check_kwargs: Dict[str, Dict[str, Any]] = Field(default_factory=dict)
run_kwargs: Dict[str, Any] = Field(default_factory=dict)
deepchecks_data_drift_check_step(step_name, params)
Shortcut function to create a new instance of the DeepchecksDataDriftCheckStep step.
The returned DeepchecksDataDriftCheckStep can be used in a pipeline to run data drift checks on two input pd.DataFrame objects and return the results as a Deepchecks SuiteResult object.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
step_name | str | The name of the step | required |
params | DeepchecksDataDriftCheckStepParameters | The parameters for the step | required |
Returns:
Type | Description |
---|---|
BaseStep | a DeepchecksDataDriftCheckStep step instance |
Source code in zenml/integrations/deepchecks/steps/deepchecks_data_drift.py
def deepchecks_data_drift_check_step(
step_name: str,
params: DeepchecksDataDriftCheckStepParameters,
) -> BaseStep:
"""Shortcut function to create a new instance of the DeepchecksDataDriftCheckStep step.
The returned DeepchecksDataDriftCheckStep can be used in a pipeline to
run data drift checks on two input pd.DataFrame objects and return the results
as a Deepchecks SuiteResult object.
Args:
step_name: The name of the step
params: The parameters for the step
Returns:
a DeepchecksDataDriftCheckStep step instance
"""
return clone_step(DeepchecksDataDriftCheckStep, step_name)(params=params)
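For orientation, here is a hypothetical wiring sketch; the `importer` step, the pipeline name and the `zenml.integrations.deepchecks.steps` import path are assumptions used only for illustration.
```python
from zenml.pipelines import pipeline

from zenml.integrations.deepchecks.steps import (
    DeepchecksDataDriftCheckStepParameters,
    deepchecks_data_drift_check_step,
)

# instantiate the standard step; empty parameters run all data drift checks
drift_check = deepchecks_data_drift_check_step(
    step_name="data_drift_check",
    params=DeepchecksDataDriftCheckStepParameters(),
)

@pipeline
def drift_pipeline(importer, data_drift_check):
    reference, target = importer()  # assumed step producing two DataFrames
    data_drift_check(reference_dataset=reference, target_dataset=target)

# drift_pipeline(importer=my_importer(), data_drift_check=drift_check).run()
```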
deepchecks_data_integrity
Implementation of the Deepchecks data integrity validation step.
DeepchecksDataIntegrityCheckStep (BaseStep)
Deepchecks data integrity validator step.
Source code in zenml/integrations/deepchecks/steps/deepchecks_data_integrity.py
class DeepchecksDataIntegrityCheckStep(BaseStep):
"""Deepchecks data integrity validator step."""
def entrypoint( # type: ignore[override]
self,
dataset: pd.DataFrame,
params: DeepchecksDataIntegrityCheckStepParameters,
) -> SuiteResult:
"""Main entrypoint for the Deepchecks data integrity validator step.
Args:
dataset: a Pandas DataFrame to validate
params: The parameters for the step
Returns:
A Deepchecks suite result with the validation results.
"""
data_validator = cast(
DeepchecksDataValidator,
DeepchecksDataValidator.get_active_data_validator(),
)
return data_validator.data_validation(
dataset=dataset,
check_list=cast(Optional[Sequence[str]], params.check_list),
dataset_kwargs=params.dataset_kwargs,
check_kwargs=params.check_kwargs,
run_kwargs=params.run_kwargs,
)
entrypoint(self, dataset, params)
Main entrypoint for the Deepchecks data integrity validator step.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
dataset | DataFrame | a Pandas DataFrame to validate | required |
params | DeepchecksDataIntegrityCheckStepParameters | The parameters for the step | required |
Returns:
Type | Description |
---|---|
SuiteResult | A Deepchecks suite result with the validation results. |
Source code in zenml/integrations/deepchecks/steps/deepchecks_data_integrity.py
def entrypoint( # type: ignore[override]
self,
dataset: pd.DataFrame,
params: DeepchecksDataIntegrityCheckStepParameters,
) -> SuiteResult:
"""Main entrypoint for the Deepchecks data integrity validator step.
Args:
dataset: a Pandas DataFrame to validate
params: The parameters for the step
Returns:
A Deepchecks suite result with the validation results.
"""
data_validator = cast(
DeepchecksDataValidator,
DeepchecksDataValidator.get_active_data_validator(),
)
return data_validator.data_validation(
dataset=dataset,
check_list=cast(Optional[Sequence[str]], params.check_list),
dataset_kwargs=params.dataset_kwargs,
check_kwargs=params.check_kwargs,
run_kwargs=params.run_kwargs,
)
DeepchecksDataIntegrityCheckStepParameters (BaseParameters)
pydantic-model
Parameters class for the Deepchecks data integrity validator step.
Attributes:
Name | Type | Description |
---|---|---|
check_list | Optional[Sequence[zenml.integrations.deepchecks.validation_checks.DeepchecksDataIntegrityCheck]] | Optional list of DeepchecksDataIntegrityCheck identifiers specifying the subset of Deepchecks data integrity checks to be performed. If not supplied, the entire set of data integrity checks will be performed. |
dataset_kwargs | Dict[str, Any] | Additional keyword arguments to be passed to the Deepchecks `tabular.Dataset` or `vision.VisionData` constructor. |
check_kwargs | Dict[str, Dict[str, Any]] | Additional keyword arguments to be passed to the Deepchecks check object constructors. Arguments are grouped for each check and indexed using the full check class name or check enum value as dictionary keys. |
run_kwargs | Dict[str, Any] | Additional keyword arguments to be passed to the Deepchecks Suite `run` method. |
Source code in zenml/integrations/deepchecks/steps/deepchecks_data_integrity.py
class DeepchecksDataIntegrityCheckStepParameters(BaseParameters):
"""Parameters class for the Deepchecks data integrity validator step.
Attributes:
check_list: Optional list of DeepchecksDataIntegrityCheck identifiers
specifying the subset of Deepchecks data integrity checks to be
performed. If not supplied, the entire set of data integrity checks
will be performed.
dataset_kwargs: Additional keyword arguments to be passed to the
Deepchecks `tabular.Dataset` or `vision.VisionData` constructor.
check_kwargs: Additional keyword arguments to be passed to the
Deepchecks check object constructors. Arguments are grouped for
each check and indexed using the full check class name or
check enum value as dictionary keys.
run_kwargs: Additional keyword arguments to be passed to the
Deepchecks Suite `run` method.
"""
check_list: Optional[Sequence[DeepchecksDataIntegrityCheck]] = None
dataset_kwargs: Dict[str, Any] = Field(default_factory=dict)
check_kwargs: Dict[str, Dict[str, Any]] = Field(default_factory=dict)
run_kwargs: Dict[str, Any] = Field(default_factory=dict)
deepchecks_data_integrity_check_step(step_name, params)
Shortcut function to create a new instance of the DeepchecksDataIntegrityCheckStep step.
The returned DeepchecksDataIntegrityCheckStep can be used in a pipeline to run data integrity checks on an input pd.DataFrame and return the results as a Deepchecks SuiteResult object.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
step_name | str | The name of the step | required |
params | DeepchecksDataIntegrityCheckStepParameters | The parameters for the step | required |
Returns:
Type | Description |
---|---|
BaseStep | a DeepchecksDataIntegrityCheckStep step instance |
Source code in zenml/integrations/deepchecks/steps/deepchecks_data_integrity.py
def deepchecks_data_integrity_check_step(
step_name: str,
params: DeepchecksDataIntegrityCheckStepParameters,
) -> BaseStep:
"""Shortcut function to create a new instance of the DeepchecksDataIntegrityCheckStep step.
The returned DeepchecksDataIntegrityCheckStep can be used in a pipeline to
run data integrity checks on an input pd.DataFrame and return the results
as a Deepchecks SuiteResult object.
Args:
step_name: The name of the step
params: The parameters for the step
Returns:
a DeepchecksDataIntegrityCheckStep step instance
"""
return clone_step(DeepchecksDataIntegrityCheckStep, step_name)(
params=params
)
deepchecks_model_drift
Implementation of the Deepchecks model drift validation step.
DeepchecksModelDriftCheckStep (BaseStep)
Deepchecks model drift step.
Source code in zenml/integrations/deepchecks/steps/deepchecks_model_drift.py
class DeepchecksModelDriftCheckStep(BaseStep):
"""Deepchecks model drift step."""
def entrypoint( # type: ignore[override]
self,
reference_dataset: pd.DataFrame,
target_dataset: pd.DataFrame,
model: ClassifierMixin,
params: DeepchecksModelDriftCheckStepParameters,
) -> SuiteResult:
"""Main entrypoint for the Deepchecks model drift step.
Args:
reference_dataset: Reference dataset for the model drift check.
target_dataset: Target dataset to be used for the model drift check.
model: a scikit-learn model to validate
params: the parameters for the step
Returns:
A Deepchecks suite result with the validation results.
"""
data_validator = cast(
DeepchecksDataValidator,
DeepchecksDataValidator.get_active_data_validator(),
)
return data_validator.model_validation(
dataset=reference_dataset,
comparison_dataset=target_dataset,
model=model,
check_list=cast(Optional[Sequence[str]], params.check_list),
dataset_kwargs=params.dataset_kwargs,
check_kwargs=params.check_kwargs,
run_kwargs=params.run_kwargs,
)
entrypoint(self, reference_dataset, target_dataset, model, params)
Main entrypoint for the Deepchecks model drift step.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
reference_dataset | DataFrame | Reference dataset for the model drift check. | required |
target_dataset | DataFrame | Target dataset to be used for the model drift check. | required |
model | ClassifierMixin | a scikit-learn model to validate | required |
params | DeepchecksModelDriftCheckStepParameters | the parameters for the step | required |
Returns:
Type | Description |
---|---|
SuiteResult | A Deepchecks suite result with the validation results. |
Source code in zenml/integrations/deepchecks/steps/deepchecks_model_drift.py
def entrypoint( # type: ignore[override]
self,
reference_dataset: pd.DataFrame,
target_dataset: pd.DataFrame,
model: ClassifierMixin,
params: DeepchecksModelDriftCheckStepParameters,
) -> SuiteResult:
"""Main entrypoint for the Deepchecks model drift step.
Args:
reference_dataset: Reference dataset for the model drift check.
target_dataset: Target dataset to be used for the model drift check.
model: a scikit-learn model to validate
params: the parameters for the step
Returns:
A Deepchecks suite result with the validation results.
"""
data_validator = cast(
DeepchecksDataValidator,
DeepchecksDataValidator.get_active_data_validator(),
)
return data_validator.model_validation(
dataset=reference_dataset,
comparison_dataset=target_dataset,
model=model,
check_list=cast(Optional[Sequence[str]], params.check_list),
dataset_kwargs=params.dataset_kwargs,
check_kwargs=params.check_kwargs,
run_kwargs=params.run_kwargs,
)
DeepchecksModelDriftCheckStepParameters (BaseParameters)
pydantic-model
Parameters class for the Deepchecks model drift validator step.
Attributes:
Name | Type | Description |
---|---|---|
check_list | Optional[Sequence[zenml.integrations.deepchecks.validation_checks.DeepchecksModelDriftCheck]] | Optional list of DeepchecksModelDriftCheck identifiers specifying the subset of Deepchecks model drift checks to be performed. If not supplied, the entire set of model drift checks will be performed. |
dataset_kwargs | Dict[str, Any] | Additional keyword arguments to be passed to the Deepchecks `tabular.Dataset` or `vision.VisionData` constructor. |
check_kwargs | Dict[str, Dict[str, Any]] | Additional keyword arguments to be passed to the Deepchecks check object constructors. Arguments are grouped for each check and indexed using the full check class name or check enum value as dictionary keys. |
run_kwargs | Dict[str, Any] | Additional keyword arguments to be passed to the Deepchecks Suite `run` method. |
Source code in zenml/integrations/deepchecks/steps/deepchecks_model_drift.py
class DeepchecksModelDriftCheckStepParameters(BaseParameters):
"""Parameters class for the Deepchecks model drift validator step.
Attributes:
check_list: Optional list of DeepchecksModelDriftCheck identifiers
specifying the subset of Deepchecks model drift checks to be
performed. If not supplied, the entire set of model drift checks
will be performed.
dataset_kwargs: Additional keyword arguments to be passed to the
Deepchecks `tabular.Dataset` or `vision.VisionData` constructor.
check_kwargs: Additional keyword arguments to be passed to the
Deepchecks check object constructors. Arguments are grouped for
each check and indexed using the full check class name or
check enum value as dictionary keys.
run_kwargs: Additional keyword arguments to be passed to the
Deepchecks Suite `run` method.
"""
check_list: Optional[Sequence[DeepchecksModelDriftCheck]] = None
dataset_kwargs: Dict[str, Any] = Field(default_factory=dict)
check_kwargs: Dict[str, Dict[str, Any]] = Field(default_factory=dict)
run_kwargs: Dict[str, Any] = Field(default_factory=dict)
deepchecks_model_drift_check_step(step_name, params)
Shortcut function to create a new instance of the DeepchecksModelDriftCheckStep step.
The returned DeepchecksModelDriftCheckStep can be used in a pipeline to run model drift checks on two input pd.DataFrame datasets and an input scikit-learn ClassifierMixin model and return the results as a Deepchecks SuiteResult object.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
step_name | str | The name of the step | required |
params | DeepchecksModelDriftCheckStepParameters | The parameters for the step | required |
Returns:
Type | Description |
---|---|
BaseStep | a DeepchecksModelDriftCheckStep step instance |
Source code in zenml/integrations/deepchecks/steps/deepchecks_model_drift.py
def deepchecks_model_drift_check_step(
step_name: str,
params: DeepchecksModelDriftCheckStepParameters,
) -> BaseStep:
"""Shortcut function to create a new instance of the DeepchecksModelDriftCheckStep step.
The returned DeepchecksModelDriftCheckStep can be used in a pipeline to
run model drift checks on two input pd.DataFrame datasets and an input
scikit-learn ClassifierMixin model and return the results as a Deepchecks
SuiteResult object.
Args:
step_name: The name of the step
params: The parameters for the step
Returns:
a DeepchecksModelDriftCheckStep step instance
"""
return clone_step(DeepchecksModelDriftCheckStep, step_name)(params=params)
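The model drift step is wired the same way, except that it also consumes the trained model. A sketch under the same assumptions as the data drift example above, with an `importer` and a `trainer` step standing in for real ones:
```python
from zenml.pipelines import pipeline

from zenml.integrations.deepchecks.steps import (
    DeepchecksModelDriftCheckStepParameters,
    deepchecks_model_drift_check_step,
)

model_drift_check = deepchecks_model_drift_check_step(
    step_name="model_drift_check",
    params=DeepchecksModelDriftCheckStepParameters(),
)

@pipeline
def model_drift_pipeline(importer, trainer, model_drift_check):
    train_df, test_df = importer()
    model = trainer(train_df)  # assumed to return a sklearn ClassifierMixin
    model_drift_check(
        reference_dataset=train_df,
        target_dataset=test_df,
        model=model,
    )
```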
deepchecks_model_validation
Implementation of the Deepchecks model validation step.
DeepchecksModelValidationCheckStep (BaseStep)
Deepchecks model validation step.
Source code in zenml/integrations/deepchecks/steps/deepchecks_model_validation.py
class DeepchecksModelValidationCheckStep(BaseStep):
"""Deepchecks model validation step."""
def entrypoint( # type: ignore[override]
self,
dataset: pd.DataFrame,
model: ClassifierMixin,
params: DeepchecksModelValidationCheckStepParameters,
) -> SuiteResult:
"""Main entrypoint for the Deepchecks model validation step.
Args:
dataset: a Pandas DataFrame to use for the validation
model: a scikit-learn model to validate
params: the parameters for the step
Returns:
A Deepchecks suite result with the validation results.
"""
data_validator = cast(
DeepchecksDataValidator,
DeepchecksDataValidator.get_active_data_validator(),
)
return data_validator.model_validation(
dataset=dataset,
model=model,
check_list=cast(Optional[Sequence[str]], params.check_list),
dataset_kwargs=params.dataset_kwargs,
check_kwargs=params.check_kwargs,
run_kwargs=params.run_kwargs,
)
entrypoint(self, dataset, model, params)
Main entrypoint for the Deepchecks model validation step.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
dataset | DataFrame | a Pandas DataFrame to use for the validation | required |
model | ClassifierMixin | a scikit-learn model to validate | required |
params | DeepchecksModelValidationCheckStepParameters | the parameters for the step | required |
Returns:
Type | Description |
---|---|
SuiteResult | A Deepchecks suite result with the validation results. |
Source code in zenml/integrations/deepchecks/steps/deepchecks_model_validation.py
def entrypoint( # type: ignore[override]
self,
dataset: pd.DataFrame,
model: ClassifierMixin,
params: DeepchecksModelValidationCheckStepParameters,
) -> SuiteResult:
"""Main entrypoint for the Deepchecks model validation step.
Args:
dataset: a Pandas DataFrame to use for the validation
model: a scikit-learn model to validate
params: the parameters for the step
Returns:
A Deepchecks suite result with the validation results.
"""
data_validator = cast(
DeepchecksDataValidator,
DeepchecksDataValidator.get_active_data_validator(),
)
return data_validator.model_validation(
dataset=dataset,
model=model,
check_list=cast(Optional[Sequence[str]], params.check_list),
dataset_kwargs=params.dataset_kwargs,
check_kwargs=params.check_kwargs,
run_kwargs=params.run_kwargs,
)
DeepchecksModelValidationCheckStepParameters (BaseParameters)
pydantic-model
Parameters class for the Deepchecks model validation validator step.
Attributes:
Name | Type | Description |
---|---|---|
check_list | Optional[Sequence[zenml.integrations.deepchecks.validation_checks.DeepchecksModelValidationCheck]] | Optional list of DeepchecksModelValidationCheck identifiers specifying the subset of Deepchecks model validation checks to be performed. If not supplied, the entire set of model validation checks will be performed. |
dataset_kwargs | Dict[str, Any] | Additional keyword arguments to be passed to the Deepchecks `tabular.Dataset` or `vision.VisionData` constructor. |
check_kwargs | Dict[str, Dict[str, Any]] | Additional keyword arguments to be passed to the Deepchecks check object constructors. Arguments are grouped for each check and indexed using the full check class name or check enum value as dictionary keys. |
run_kwargs | Dict[str, Any] | Additional keyword arguments to be passed to the Deepchecks Suite `run` method. |
Source code in zenml/integrations/deepchecks/steps/deepchecks_model_validation.py
class DeepchecksModelValidationCheckStepParameters(BaseParameters):
"""Parameters class for the Deepchecks model validation validator step.
Attributes:
check_list: Optional list of DeepchecksModelValidationCheck identifiers
specifying the subset of Deepchecks model validation checks to be
performed. If not supplied, the entire set of model validation checks
will be performed.
dataset_kwargs: Additional keyword arguments to be passed to the
Deepchecks `tabular.Dataset` or `vision.VisionData` constructor.
check_kwargs: Additional keyword arguments to be passed to the
Deepchecks check object constructors. Arguments are grouped for
each check and indexed using the full check class name or
check enum value as dictionary keys.
run_kwargs: Additional keyword arguments to be passed to the
Deepchecks Suite `run` method.
"""
check_list: Optional[Sequence[DeepchecksModelValidationCheck]] = None
dataset_kwargs: Dict[str, Any] = Field(default_factory=dict)
check_kwargs: Dict[str, Dict[str, Any]] = Field(default_factory=dict)
run_kwargs: Dict[str, Any] = Field(default_factory=dict)
deepchecks_model_validation_check_step(step_name, params)
Shortcut function to create a new instance of the DeepchecksModelValidationCheckStep step.
The returned DeepchecksModelValidationCheckStep can be used in a pipeline to run model validation checks on an input pd.DataFrame dataset and an input scikit-learn ClassifierMixin model and return the results as a Deepchecks SuiteResult object.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
step_name | str | The name of the step | required |
params | DeepchecksModelValidationCheckStepParameters | The parameters for the step | required |
Returns:
Type | Description |
---|---|
BaseStep | a DeepchecksModelValidationCheckStep step instance |
Source code in zenml/integrations/deepchecks/steps/deepchecks_model_validation.py
def deepchecks_model_validation_check_step(
step_name: str,
params: DeepchecksModelValidationCheckStepParameters,
) -> BaseStep:
"""Shortcut function to create a new instance of the DeepchecksModelValidationCheckStep step.
The returned DeepchecksModelValidationCheckStep can be used in a pipeline to
run model validation checks on an input pd.DataFrame dataset and an input
scikit-learn ClassifierMixin model and return the results as a Deepchecks
SuiteResult object.
Args:
step_name: The name of the step
params: The parameters for the step
Returns:
a DeepchecksModelValidationCheckStep step instance
"""
return clone_step(DeepchecksModelValidationCheckStep, step_name)(
params=params
)
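All of the standard steps above delegate to the active `DeepchecksDataValidator`, so the same checks can also be run directly. A minimal sketch, assuming the `zenml.integrations.deepchecks.data_validators` import path and in-memory `train_df`, `test_df` and `clf` objects:
```python
from zenml.integrations.deepchecks.data_validators import (
    DeepchecksDataValidator,
)
from zenml.integrations.deepchecks.validation_checks import (
    DeepchecksModelDriftCheck,
)

validator = DeepchecksDataValidator.get_active_data_validator()
suite_result = validator.model_validation(
    dataset=train_df,
    comparison_dataset=test_df,
    model=clf,
    # enum values are plain strings, so they can be passed as check names
    check_list=[DeepchecksModelDriftCheck.TABULAR_PERFORMANCE_REPORT],
)
suite_result.save_as_html("model_drift_report.html")
```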
validation_checks
Definition of the Deepchecks validation check types.
DeepchecksDataDriftCheck (DeepchecksValidationCheck)
Categories of Deepchecks data drift checks.
This list reflects the set of train-test validation checks provided by Deepchecks:
* [for tabular data](https://docs.deepchecks.com/stable/checks_gallery/tabular.html#train-test-validation)
* [for computer vision](https://docs.deepchecks.com/stable/checks_gallery/vision.html#train-test-validation)
All these checks inherit from `deepchecks.tabular.TrainTestCheck` or `deepchecks.vision.TrainTestCheck` and require two datasets as input.
Source code in zenml/integrations/deepchecks/validation_checks.py
class DeepchecksDataDriftCheck(DeepchecksValidationCheck):
"""Categories of Deepchecks data drift checks.
This list reflects the set of train-test validation checks provided by
Deepchecks:
* [for tabular data](https://docs.deepchecks.com/stable/checks_gallery/tabular.html#train-test-validation)
* [for computer vision](https://docs.deepchecks.com/stable/checks_gallery/vision.html#train-test-validation)
All these checks inherit from `deepchecks.tabular.TrainTestCheck` or
`deepchecks.vision.TrainTestCheck` and require two datasets as input.
"""
TABULAR_CATEGORY_MISMATCH_TRAIN_TEST = resolve_class(
tabular_checks.CategoryMismatchTrainTest
)
TABULAR_DATASET_SIZE_COMPARISON = resolve_class(
tabular_checks.DatasetsSizeComparison
)
TABULAR_DATE_TRAIN_TEST_LEAKAGE_DUPLICATES = resolve_class(
tabular_checks.DateTrainTestLeakageDuplicates
)
TABULAR_DATE_TRAIN_TEST_LEAKAGE_OVERLAP = resolve_class(
tabular_checks.DateTrainTestLeakageOverlap
)
TABULAR_DOMINANT_FREQUENCY_CHANGE = resolve_class(
tabular_checks.DominantFrequencyChange
)
TABULAR_FEATURE_LABEL_CORRELATION_CHANGE = resolve_class(
tabular_checks.FeatureLabelCorrelationChange
)
TABULAR_INDEX_LEAKAGE = resolve_class(tabular_checks.IndexTrainTestLeakage)
TABULAR_NEW_LABEL_TRAIN_TEST = resolve_class(
tabular_checks.NewLabelTrainTest
)
TABULAR_STRING_MISMATCH_COMPARISON = resolve_class(
tabular_checks.StringMismatchComparison
)
TABULAR_TRAIN_TEST_FEATURE_DRIFT = resolve_class(
tabular_checks.TrainTestFeatureDrift
)
TABULAR_TRAIN_TEST_LABEL_DRIFT = resolve_class(
tabular_checks.TrainTestLabelDrift
)
TABULAR_TRAIN_TEST_SAMPLES_MIX = resolve_class(
tabular_checks.TrainTestSamplesMix
)
TABULAR_WHOLE_DATASET_DRIFT = resolve_class(
tabular_checks.WholeDatasetDrift
)
VISION_FEATURE_LABEL_CORRELATION_CHANGE = resolve_class(
vision_checks.FeatureLabelCorrelationChange
)
VISION_HEATMAP_COMPARISON = resolve_class(vision_checks.HeatmapComparison)
VISION_IMAGE_DATASET_DRIFT = resolve_class(vision_checks.ImageDatasetDrift)
VISION_IMAGE_PROPERTY_DRIFT = resolve_class(
vision_checks.ImagePropertyDrift
)
VISION_NEW_LABELS = resolve_class(vision_checks.NewLabels)
VISION_SIMILAR_IMAGE_LEAKAGE = resolve_class(
vision_checks.SimilarImageLeakage
)
VISION_TRAIN_TEST_LABEL_DRIFT = resolve_class(
vision_checks.TrainTestLabelDrift
)
DeepchecksDataIntegrityCheck (DeepchecksValidationCheck)
Categories of Deepchecks data integrity checks.
This list reflects the set of data integrity checks provided by Deepchecks:
* [for tabular data](https://docs.deepchecks.com/en/stable/checks_gallery/tabular.html#data-integrity)
* [for computer vision](https://docs.deepchecks.com/en/stable/checks_gallery/vision.html#data-integrity)
All these checks inherit from `deepchecks.tabular.SingleDatasetCheck` or `deepchecks.vision.SingleDatasetCheck` and require a single dataset as input.
Source code in zenml/integrations/deepchecks/validation_checks.py
class DeepchecksDataIntegrityCheck(DeepchecksValidationCheck):
"""Categories of Deepchecks data integrity checks.
This list reflects the set of data integrity checks provided by Deepchecks:
* [for tabular data](https://docs.deepchecks.com/en/stable/checks_gallery/tabular.html#data-integrity)
* [for computer vision](https://docs.deepchecks.com/en/stable/checks_gallery/vision.html#data-integrity)
All these checks inherit from `deepchecks.tabular.SingleDatasetCheck` or
`deepchecks.vision.SingleDatasetCheck` and require a single dataset as input.
"""
TABULAR_COLUMNS_INFO = resolve_class(tabular_checks.ColumnsInfo)
TABULAR_CONFLICTING_LABELS = resolve_class(tabular_checks.ConflictingLabels)
TABULAR_DATA_DUPLICATES = resolve_class(tabular_checks.DataDuplicates)
TABULAR_FEATURE_FEATURE_CORRELATION = resolve_class(
FeatureFeatureCorrelation
)
TABULAR_FEATURE_LABEL_CORRELATION = resolve_class(
tabular_checks.FeatureLabelCorrelation
)
TABULAR_IDENTIFIER_LEAKAGE = resolve_class(tabular_checks.IdentifierLeakage)
TABULAR_IS_SINGLE_VALUE = resolve_class(tabular_checks.IsSingleValue)
TABULAR_MIXED_DATA_TYPES = resolve_class(tabular_checks.MixedDataTypes)
TABULAR_MIXED_NULLS = resolve_class(tabular_checks.MixedNulls)
TABULAR_OUTLIER_SAMPLE_DETECTION = resolve_class(
tabular_checks.OutlierSampleDetection
)
TABULAR_SPECIAL_CHARS = resolve_class(tabular_checks.SpecialCharacters)
TABULAR_STRING_LENGTH_OUT_OF_BOUNDS = resolve_class(
tabular_checks.StringLengthOutOfBounds
)
TABULAR_STRING_MISMATCH = resolve_class(tabular_checks.StringMismatch)
VISION_IMAGE_PROPERTY_OUTLIERS = resolve_class(
vision_checks.ImagePropertyOutliers
)
VISION_LABEL_PROPERTY_OUTLIERS = resolve_class(
vision_checks.LabelPropertyOutliers
)
DeepchecksModelDriftCheck (DeepchecksValidationCheck)
Categories of Deepchecks model drift checks.
This list includes a subset of the model evaluation checks provided by Deepchecks that require two datasets and a mandatory model as input:
* [for tabular data](https://docs.deepchecks.com/en/stable/checks_gallery/tabular.html#model-evaluation)
* [for computer vision](https://docs.deepchecks.com/stable/checks_gallery/vision.html#model-evaluation)
All these checks inherit from `deepchecks.tabular.TrainTestCheck` or `deepchecks.vision.TrainTestCheck` and require two datasets and a mandatory model as input.
Source code in zenml/integrations/deepchecks/validation_checks.py
class DeepchecksModelDriftCheck(DeepchecksValidationCheck):
"""Categories of Deepchecks model drift checks.
This list includes a subset of the model evaluation checks provided by
Deepchecks that require two datasets and a mandatory model as input:
* [for tabular data](https://docs.deepchecks.com/en/stable/checks_gallery/tabular.html#model-evaluation)
* [for computer vision](https://docs.deepchecks.com/stable/checks_gallery/vision.html#model-evaluation)
All these checks inherit from `deepchecks.tabular.TrainTestCheck` or
`deepchecks.vision.TrainTestCheck` and require two datasets and a mandatory
model as input.
"""
TABULAR_BOOSTING_OVERFIT = resolve_class(tabular_checks.BoostingOverfit)
TABULAR_MODEL_ERROR_ANALYSIS = resolve_class(
tabular_checks.ModelErrorAnalysis
)
TABULAR_PERFORMANCE_REPORT = resolve_class(tabular_checks.PerformanceReport)
TABULAR_SIMPLE_MODEL_COMPARISON = resolve_class(
tabular_checks.SimpleModelComparison
)
TABULAR_TRAIN_TEST_PREDICTION_DRIFT = resolve_class(
tabular_checks.TrainTestPredictionDrift
)
TABULAR_UNUSED_FEATURES = resolve_class(tabular_checks.UnusedFeatures)
VISION_CLASS_PERFORMANCE = resolve_class(vision_checks.ClassPerformance)
VISION_MODEL_ERROR_ANALYSIS = resolve_class(
vision_checks.ModelErrorAnalysis
)
VISION_SIMPLE_MODEL_COMPARISON = resolve_class(
vision_checks.SimpleModelComparison
)
VISION_TRAIN_TEST_PREDICTION_DRIFT = resolve_class(
vision_checks.TrainTestPredictionDrift
)
DeepchecksModelValidationCheck (DeepchecksValidationCheck)
Categories of Deepchecks model validation checks.
This list includes a subset of the model evaluation checks provided by Deepchecks that require a single dataset and a mandatory model as input:
* [for tabular data](https://docs.deepchecks.com/en/stable/checks_gallery/tabular.html#model-evaluation)
* [for computer vision](https://docs.deepchecks.com/stable/checks_gallery/vision.html#model-evaluation)
All these checks inherit from `deepchecks.tabular.SingleDatasetCheck` or `deepchecks.vision.SingleDatasetCheck` and require a dataset and a mandatory model as input.
Source code in zenml/integrations/deepchecks/validation_checks.py
class DeepchecksModelValidationCheck(DeepchecksValidationCheck):
"""Categories of Deepchecks model validation checks.
This list includes a subset of the model evaluation checks provided by
Deepchecks that require a single dataset and a mandatory model as input:
* [for tabular data](https://docs.deepchecks.com/en/stable/checks_gallery/tabular.html#model-evaluation)
* [for computer vision](https://docs.deepchecks.com/stable/checks_gallery/vision.html#model-evaluation)
All these checks inherit from `deepchecks.tabular.SingleDatasetCheck` or
`deepchecks.vision.SingleDatasetCheck` and require a dataset and a mandatory
model as input.
"""
TABULAR_CALIBRATION_SCORE = resolve_class(tabular_checks.CalibrationScore)
TABULAR_CONFUSION_MATRIX_REPORT = resolve_class(
tabular_checks.ConfusionMatrixReport
)
TABULAR_MODEL_INFERENCE_TIME = resolve_class(
tabular_checks.ModelInferenceTime
)
TABULAR_REGRESSION_ERROR_DISTRIBUTION = resolve_class(
tabular_checks.RegressionErrorDistribution
)
TABULAR_REGRESSION_SYSTEMATIC_ERROR = resolve_class(
tabular_checks.RegressionSystematicError
)
TABULAR_ROC_REPORT = resolve_class(tabular_checks.RocReport)
TABULAR_SEGMENT_PERFORMANCE = resolve_class(
tabular_checks.SegmentPerformance
)
VISION_CONFUSION_MATRIX_REPORT = resolve_class(
vision_checks.ConfusionMatrixReport
)
VISION_IMAGE_SEGMENT_PERFORMANCE = resolve_class(
vision_checks.ImageSegmentPerformance
)
VISION_MEAN_AVERAGE_PRECISION_REPORT = resolve_class(
vision_checks.MeanAveragePrecisionReport
)
VISION_MEAN_AVERAGE_RECALL_REPORT = resolve_class(
vision_checks.MeanAverageRecallReport
)
VISION_ROBUSTNESS_REPORT = resolve_class(vision_checks.RobustnessReport)
VISION_SINGLE_DATASET_SCALAR_PERFORMANCE = resolve_class(
vision_checks.SingleDatasetScalarPerformance
)
DeepchecksValidationCheck (StrEnum)
Base class for all Deepchecks categories of validation checks.
This base class defines some conventions used for all enum values used to identify the various validation checks that can be performed with Deepchecks:
* enum values represent fully formed class paths pointing to Deepchecks BaseCheck subclasses
* all tabular data checks are located under the `deepchecks.tabular.checks` module sub-tree
* all computer vision data checks are located under the `deepchecks.vision.checks` module sub-tree
Source code in zenml/integrations/deepchecks/validation_checks.py
class DeepchecksValidationCheck(StrEnum):
"""Base class for all Deepchecks categories of validation checks.
This base class defines some conventions used for all enum values used to
identify the various validation checks that can be performed with
Deepchecks:
* enum values represent fully formed class paths pointing to Deepchecks
BaseCheck subclasses
* all tabular data checks are located under the
`deepchecks.tabular.checks` module sub-tree
* all computer vision data checks are located under the
`deepchecks.vision.checks` module sub-tree
"""
@classmethod
def validate_check_name(cls, check_name: str) -> None:
"""Validate a Deepchecks check identifier.
Args:
check_name: Identifies a builtin Deepchecks check. The identifier
must be formatted as `deepchecks.{tabular|vision}.checks.<...>.<class-name>`.
Raises:
ValueError: If the check identifier does not follow the convention
used by ZenML to identify Deepchecks builtin checks.
"""
if not re.match(
r"^deepchecks\.(tabular|vision)\.checks\.",
check_name,
):
raise ValueError(
f"The supplied Deepcheck check identifier does not follow the "
f"convention used by ZenML: `{check_name}`. The identifier "
f"must be formatted as `deepchecks.<tabular|vision>.checks...` "
f"and must be resolvable to a valid Deepchecks BaseCheck "
f"subclass."
)
@classmethod
def is_tabular_check(cls, check_name: str) -> bool:
"""Check if a validation check is applicable to tabular data.
Args:
check_name: Identifies a builtin Deepchecks check.
Returns:
True if the check is applicable to tabular data, otherwise False.
"""
cls.validate_check_name(check_name)
return check_name.startswith("deepchecks.tabular.")
@classmethod
def is_vision_check(cls, check_name: str) -> bool:
"""Check if a validation check is applicable to computer vision data.
Args:
check_name: Identifies a builtin Deepchecks check.
Returns:
True if the check is applicable to compute vision data, otherwise
False.
"""
cls.validate_check_name(check_name)
return check_name.startswith("deepchecks.vision.")
@classmethod
def get_check_class(cls, check_name: str) -> Type[BaseCheck]:
"""Get the Deepchecks check class associated with an enum value or a custom check name.
Args:
check_name: Identifies a builtin Deepchecks check. The identifier
must be formatted as `deepchecks.{tabular|vision}.checks.<class-name>`
and must be resolvable to a valid Deepchecks BaseCheck class.
Returns:
The Deepchecks check class associated with this enum value.
Raises:
ValueError: If the check name could not be converted to a valid
Deepchecks check class. This can happen for example if the enum
values fall out of sync with the Deepchecks code base or if a
custom check name is supplied that cannot be resolved to a valid
Deepchecks BaseCheck class.
"""
cls.validate_check_name(check_name)
try:
check_class = import_class_by_path(check_name)
except AttributeError:
raise ValueError(
f"Could not map the `{check_name}` check identifier to a valid "
f"Deepchecks check class."
)
if not issubclass(check_class, BaseCheck):
raise ValueError(
f"The `{check_name}` check identifier is mapped to an invalid "
f"data type. Expected a {str(BaseCheck)} subclass, but instead "
f"got: {str(check_class)}."
)
if check_name not in cls.values():
logger.warning(
f"You are using a custom Deepchecks check identifier that is "
f"not listed in the `{str(cls)}` enum type. This could lead "
f"to unexpected behavior."
)
return check_class
@property
def check_class(self) -> Type[BaseCheck]:
"""Convert the enum value to a valid Deepchecks check class.
Returns:
The Deepchecks check class associated with the enum value.
"""
return self.get_check_class(self.value)
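A short sketch of how these conventions play out in practice, assuming the deepchecks package is installed:
```python
from zenml.integrations.deepchecks.validation_checks import (
    DeepchecksDataIntegrityCheck,
)

# resolve a builtin enum value to its Deepchecks check class
check_cls = DeepchecksDataIntegrityCheck.TABULAR_DATA_DUPLICATES.check_class

# custom identifiers are accepted as well, as long as they follow the
# `deepchecks.{tabular|vision}.checks.<...>` convention; a warning may be
# logged if the string is not one of the enum's own values
custom_cls = DeepchecksDataIntegrityCheck.get_check_class(
    "deepchecks.tabular.checks.DataDuplicates"
)
```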
visualizers
special
Deepchecks visualizer.
deepchecks_visualizer
Implementation of the Deepchecks visualizer.
DeepchecksVisualizer (BaseVisualizer)
The implementation of a Deepchecks Visualizer.
Source code in zenml/integrations/deepchecks/visualizers/deepchecks_visualizer.py
class DeepchecksVisualizer(BaseVisualizer):
"""The implementation of a Deepchecks Visualizer."""
@abstractmethod
def visualize(self, object: StepView, *args: Any, **kwargs: Any) -> None:
"""Method to visualize components.
Args:
object: StepView fetched from run.get_step().
*args: Additional arguments (unused).
**kwargs: Additional keyword arguments (unused).
"""
for artifact_view in object.outputs.values():
# filter out anything but data analysis artifacts
if artifact_view.type == DataAnalysisArtifact.__name__:
artifact = artifact_view.read()
self.generate_report(artifact)
def generate_report(self, result: Union[CheckResult, SuiteResult]) -> None:
"""Generate a Deepchecks Report.
Args:
result: A SuiteResult.
"""
print(result)
if Environment.in_notebook():
result.show()
else:
logger.warning(
"The magic functions are only usable in a Jupyter notebook."
)
with tempfile.NamedTemporaryFile(
mode="w", delete=False, suffix=".html", encoding="utf-8"
) as f:
result.save_as_html(f)
url = f"file:///{f.name}"
logger.info("Opening %s in a new browser.." % f.name)
webbrowser.open(url, new=2)
generate_report(self, result)
Generate a Deepchecks Report.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
result | Union[deepchecks.core.check_result.CheckResult, deepchecks.core.suite.SuiteResult] | A SuiteResult. | required |
Source code in zenml/integrations/deepchecks/visualizers/deepchecks_visualizer.py
def generate_report(self, result: Union[CheckResult, SuiteResult]) -> None:
"""Generate a Deepchecks Report.
Args:
result: A SuiteResult.
"""
print(result)
if Environment.in_notebook():
result.show()
else:
logger.warning(
"The magic functions are only usable in a Jupyter notebook."
)
with tempfile.NamedTemporaryFile(
mode="w", delete=False, suffix=".html", encoding="utf-8"
) as f:
result.save_as_html(f)
url = f"file:///{f.name}"
logger.info("Opening %s in a new browser.." % f.name)
webbrowser.open(url, new=2)
visualize(self, object, *args, **kwargs)
Method to visualize components.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
object | StepView | StepView fetched from run.get_step(). | required |
*args | Any | Additional arguments (unused). | () |
**kwargs | Any | Additional keyword arguments (unused). | {} |
Source code in zenml/integrations/deepchecks/visualizers/deepchecks_visualizer.py
@abstractmethod
def visualize(self, object: StepView, *args: Any, **kwargs: Any) -> None:
"""Method to visualize components.
Args:
object: StepView fetched from run.get_step().
*args: Additional arguments (unused).
**kwargs: Additional keyword arguments (unused).
"""
for artifact_view in object.outputs.values():
# filter out anything but data analysis artifacts
if artifact_view.type == DataAnalysisArtifact.__name__:
artifact = artifact_view.read()
self.generate_report(artifact)
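A hypothetical post-execution sketch of the visualizer; the pipeline and step names are illustrative, and the exact post-execution helpers may differ between ZenML versions:
```python
from zenml.integrations.deepchecks.visualizers import DeepchecksVisualizer
from zenml.post_execution import get_pipeline

last_run = get_pipeline("drift_pipeline").runs[-1]  # assumed pipeline name
step_view = last_run.get_step("data_drift_check")   # assumed step name
DeepchecksVisualizer().visualize(step_view)
```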
evidently
special
Initialization of the Evidently integration.
The Evidently integration provides a way to monitor your models in production, including the detection of data drift and different kinds of model performance issues.
The results of Evidently calculations can either be exported as an interactive dashboard (visualized as an HTML file or in your Jupyter notebook) or as a JSON file.
EvidentlyIntegration (Integration)
Evidently integration for ZenML.
Source code in zenml/integrations/evidently/__init__.py
class EvidentlyIntegration(Integration):
"""[Evidently](https://github.com/evidentlyai/evidently) integration for ZenML."""
NAME = EVIDENTLY
REQUIREMENTS = ["evidently==0.1.52dev0"]
@staticmethod
def activate() -> None:
"""Activate the Deepchecks integration."""
from zenml.integrations.evidently import materializers # noqa
from zenml.integrations.evidently import visualizers # noqa
@classmethod
def flavors(cls) -> List[Type[Flavor]]:
"""Declare the stack component flavors for the Great Expectations integration.
Returns:
List of stack component flavors for this integration.
"""
from zenml.integrations.evidently.flavors import (
EvidentlyDataValidatorFlavor,
)
return [EvidentlyDataValidatorFlavor]
activate()
staticmethod
Activate the Evidently integration.
Source code in zenml/integrations/evidently/__init__.py
@staticmethod
def activate() -> None:
"""Activate the Deepchecks integration."""
from zenml.integrations.evidently import materializers # noqa
from zenml.integrations.evidently import visualizers # noqa
flavors()
classmethod
Declare the stack component flavors for the Evidently integration.
Returns:
Type | Description |
---|---|
List[Type[zenml.stack.flavor.Flavor]] | List of stack component flavors for this integration. |
Source code in zenml/integrations/evidently/__init__.py
@classmethod
def flavors(cls) -> List[Type[Flavor]]:
"""Declare the stack component flavors for the Great Expectations integration.
Returns:
List of stack component flavors for this integration.
"""
from zenml.integrations.evidently.flavors import (
EvidentlyDataValidatorFlavor,
)
return [EvidentlyDataValidatorFlavor]
data_validators
special
Initialization of the Evidently data validator for ZenML.
evidently_data_validator
Implementation of the Evidently data validator.
EvidentlyDataValidator (BaseDataValidator)
Evidently data validator stack component.
Source code in zenml/integrations/evidently/data_validators/evidently_data_validator.py
class EvidentlyDataValidator(BaseDataValidator):
"""Evidently data validator stack component."""
NAME: ClassVar[str] = "Evidently"
@classmethod
def _unpack_options(
cls, option_list: Sequence[Tuple[str, Dict[str, Any]]]
) -> Sequence[Any]:
"""Unpack Evidently options.
Implements de-serialization for [Evidently options](https://docs.evidentlyai.com/user-guide/customization)
that can be passed as constructor arguments when creating Profile and
Dashboard objects. The convention used is that each item in the list
consists of two elements:
* a string containing the full class path of a `dataclass` based
class with Evidently options
* a dictionary with kwargs used as parameters for the option instance
Example:
```python
options = [
(
"evidently.options.ColorOptions",{
"primary_color": "#5a86ad",
"fill_color": "#fff4f2",
"zero_line_color": "#016795",
"current_data_color": "#c292a1",
"reference_data_color": "#017b92",
}
),
]
```
This is the same as saying:
```python
from evidently.options import ColorOptions
color_scheme = ColorOptions()
color_scheme.primary_color = "#5a86ad"
color_scheme.fill_color = "#fff4f2"
color_scheme.zero_line_color = "#016795"
color_scheme.current_data_color = "#c292a1"
color_scheme.reference_data_color = "#017b92"
```
Args:
option_list: list of packed Evidently options
Returns:
A list of unpacked Evidently options
Raises:
ValueError: if one of the passed Evidently class paths cannot be
resolved to an actual class.
"""
options = []
for option_clspath, option_args in option_list:
try:
option_cls = load_source_path_class(option_clspath)
except AttributeError:
raise ValueError(
f"Could not map the `{option_clspath}` Evidently option "
f"class path to a valid class."
)
option = option_cls(**option_args)
options.append(option)
return options
def data_profiling(
self,
dataset: pd.DataFrame,
comparison_dataset: Optional[pd.DataFrame] = None,
profile_list: Optional[Sequence[str]] = None,
column_mapping: Optional[ColumnMapping] = None,
verbose_level: int = 1,
profile_options: Sequence[Tuple[str, Dict[str, Any]]] = [],
dashboard_options: Sequence[Tuple[str, Dict[str, Any]]] = [],
**kwargs: Any,
) -> Tuple[Profile, Dashboard]:
"""Analyze a dataset and generate a data profile with Evidently.
The method takes in an optional list of Evidently options to be passed
to the profile constructor (`profile_options`) and the dashboard
constructor (`dashboard_options`). Each element in the list must be
composed of two items: the first is a full class path of an Evidently
option `dataclass`, the second is a dictionary of kwargs with the actual
option parameters, e.g.:
```python
options = [
(
"evidently.options.ColorOptions",{
"primary_color": "#5a86ad",
"fill_color": "#fff4f2",
"zero_line_color": "#016795",
"current_data_color": "#c292a1",
"reference_data_color": "#017b92",
}
),
]
```
Args:
dataset: Target dataset to be profiled.
comparison_dataset: Optional dataset to be used for data profiles
that require a baseline for comparison (e.g data drift profiles).
profile_list: Optional list identifying the categories of Evidently
data profiles to be generated.
column_mapping: Properties of the DataFrame columns used
verbose_level: Level of verbosity for the Evidently dashboards. Use
0 for a brief dashboard, 1 for a detailed dashboard.
profile_options: Optional list of options to pass to the
profile constructor.
dashboard_options: Optional list of options to pass to the
dashboard constructor.
**kwargs: Extra keyword arguments (unused).
Returns:
The Evidently Profile and Dashboard objects corresponding to the set
of generated profiles.
"""
sections, tabs = get_profile_sections_and_tabs(
profile_list, verbose_level
)
unpacked_profile_options = self._unpack_options(profile_options)
unpacked_dashboard_options = self._unpack_options(dashboard_options)
dashboard = Dashboard(tabs=tabs, options=unpacked_dashboard_options)
dashboard.calculate(
reference_data=dataset,
current_data=comparison_dataset,
column_mapping=column_mapping,
)
profile = Profile(sections=sections, options=unpacked_profile_options)
profile.calculate(
reference_data=dataset,
current_data=comparison_dataset,
column_mapping=column_mapping,
)
return profile, dashboard
data_profiling(self, dataset, comparison_dataset=None, profile_list=None, column_mapping=None, verbose_level=1, profile_options=[], dashboard_options=[], **kwargs)
Analyze a dataset and generate a data profile with Evidently.
The method takes in an optional list of Evidently options to be passed to the profile constructor (`profile_options`) and the dashboard constructor (`dashboard_options`). Each element in the list must be composed of two items: the first is a full class path of an Evidently option `dataclass`, the second is a dictionary of kwargs with the actual option parameters, e.g.:
```python
options = [
    (
        "evidently.options.ColorOptions", {
            "primary_color": "#5a86ad",
            "fill_color": "#fff4f2",
            "zero_line_color": "#016795",
            "current_data_color": "#c292a1",
            "reference_data_color": "#017b92",
        }
    ),
]
```
Parameters:
Name | Type | Description | Default |
---|---|---|---|
dataset | DataFrame | Target dataset to be profiled. | required |
comparison_dataset | Optional[pandas.core.frame.DataFrame] | Optional dataset to be used for data profiles that require a baseline for comparison (e.g. data drift profiles). | None |
profile_list | Optional[Sequence[str]] | Optional list identifying the categories of Evidently data profiles to be generated. | None |
column_mapping | Optional[evidently.pipeline.column_mapping.ColumnMapping] | Properties of the DataFrame columns used | None |
verbose_level | int | Level of verbosity for the Evidently dashboards. Use 0 for a brief dashboard, 1 for a detailed dashboard. | 1 |
profile_options | Sequence[Tuple[str, Dict[str, Any]]] | Optional list of options to pass to the profile constructor. | [] |
dashboard_options | Sequence[Tuple[str, Dict[str, Any]]] | Optional list of options to pass to the dashboard constructor. | [] |
**kwargs | Any | Extra keyword arguments (unused). | {} |
Returns:
Type | Description |
---|---|
Tuple[evidently.model_profile.model_profile.Profile, evidently.dashboard.dashboard.Dashboard] | The Evidently Profile and Dashboard objects corresponding to the set of generated profiles. |
Source code in zenml/integrations/evidently/data_validators/evidently_data_validator.py
def data_profiling(
self,
dataset: pd.DataFrame,
comparison_dataset: Optional[pd.DataFrame] = None,
profile_list: Optional[Sequence[str]] = None,
column_mapping: Optional[ColumnMapping] = None,
verbose_level: int = 1,
profile_options: Sequence[Tuple[str, Dict[str, Any]]] = [],
dashboard_options: Sequence[Tuple[str, Dict[str, Any]]] = [],
**kwargs: Any,
) -> Tuple[Profile, Dashboard]:
"""Analyze a dataset and generate a data profile with Evidently.
The method takes in an optional list of Evidently options to be passed
to the profile constructor (`profile_options`) and the dashboard
constructor (`dashboard_options`). Each element in the list must be
composed of two items: the first is a full class path of an Evidently
option `dataclass`, the second is a dictionary of kwargs with the actual
option parameters, e.g.:
```python
options = [
(
"evidently.options.ColorOptions",{
"primary_color": "#5a86ad",
"fill_color": "#fff4f2",
"zero_line_color": "#016795",
"current_data_color": "#c292a1",
"reference_data_color": "#017b92",
}
),
]
```
Args:
dataset: Target dataset to be profiled.
comparison_dataset: Optional dataset to be used for data profiles
that require a baseline for comparison (e.g data drift profiles).
profile_list: Optional list identifying the categories of Evidently
data profiles to be generated.
column_mapping: Properties of the DataFrame columns used
verbose_level: Level of verbosity for the Evidently dashboards. Use
0 for a brief dashboard, 1 for a detailed dashboard.
profile_options: Optional list of options to pass to the
profile constructor.
dashboard_options: Optional list of options to pass to the
dashboard constructor.
**kwargs: Extra keyword arguments (unused).
Returns:
The Evidently Profile and Dashboard objects corresponding to the set
of generated profiles.
"""
sections, tabs = get_profile_sections_and_tabs(
profile_list, verbose_level
)
unpacked_profile_options = self._unpack_options(profile_options)
unpacked_dashboard_options = self._unpack_options(dashboard_options)
dashboard = Dashboard(tabs=tabs, options=unpacked_dashboard_options)
dashboard.calculate(
reference_data=dataset,
current_data=comparison_dataset,
column_mapping=column_mapping,
)
profile = Profile(sections=sections, options=unpacked_profile_options)
profile.calculate(
reference_data=dataset,
current_data=comparison_dataset,
column_mapping=column_mapping,
)
return profile, dashboard
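To make the flow concrete, a minimal sketch of calling the data validator directly; the `"datadrift"` profile identifier and the in-memory DataFrames are assumptions:
```python
from zenml.integrations.evidently.data_validators import (
    EvidentlyDataValidator,
)

data_validator = EvidentlyDataValidator.get_active_data_validator()
profile, dashboard = data_validator.data_profiling(
    dataset=reference_df,           # baseline pd.DataFrame
    comparison_dataset=current_df,  # data compared against the baseline
    profile_list=["datadrift"],     # assumed profile identifier
)
dashboard.save("drift_dashboard.html")  # interactive HTML report
print(profile.json())                   # JSON export of the profile
```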
get_profile_sections_and_tabs(profile_list, verbose_level=1)
Get the profile sections and dashboard tabs for a profile list.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
profile_list | Optional[Sequence[str]] | List of identifiers for Evidently profiles. | required |
verbose_level | int | Verbosity level for the rendered dashboard. Use 0 for a brief dashboard, 1 for a detailed dashboard. | 1 |
Returns:
Type | Description |
---|---|
Tuple[List[evidently.model_profile.sections.base_profile_section.ProfileSection], List[evidently.dashboard.tabs.base_tab.Tab]] | A tuple of two lists of profile sections and tabs. |
Exceptions:
Type | Description |
---|---|
ValueError | if the profile_section is not supported. |
Source code in zenml/integrations/evidently/data_validators/evidently_data_validator.py
def get_profile_sections_and_tabs(
profile_list: Optional[Sequence[str]],
verbose_level: int = 1,
) -> Tuple[List[ProfileSection], List[Tab]]:
"""Get the profile sections and dashboard tabs for a profile list.
Args:
profile_list: List of identifiers for Evidently profiles.
verbose_level: Verbosity level for the rendered dashboard. Use
0 for a brief dashboard, 1 for a detailed dashboard.
Returns:
A tuple of two lists of profile sections and tabs.
Raises:
ValueError: if the profile_section is not supported.
"""
profile_list = profile_list or list(profile_mapper.keys())
try:
return (
[profile_mapper[profile]() for profile in profile_list],
[
dashboard_mapper[profile](verbose_level=verbose_level)
for profile in profile_list
],
)
except KeyError as e:
nl = "\n"
raise ValueError(
f"Invalid profile sections: {profile_list} \n\n"
f"Valid and supported options are: {nl}- "
f'{f"{nl}- ".join(list(profile_mapper.keys()))}'
) from e
flavors
special
Evidently integration flavors.
evidently_data_validator_flavor
Evidently data validator flavor.
EvidentlyDataValidatorFlavor (BaseDataValidatorFlavor)
Evidently data validator flavor.
Source code in zenml/integrations/evidently/flavors/evidently_data_validator_flavor.py
class EvidentlyDataValidatorFlavor(BaseDataValidatorFlavor):
"""Evidently data validator flavor."""
@property
def name(self) -> str:
"""Name of the flavor.
Returns:
The name of the flavor.
"""
return EVIDENTLY_DATA_VALIDATOR_FLAVOR
@property
def implementation_class(self) -> Type["EvidentlyDataValidator"]:
"""Implementation class.
Returns:
The implementation class.
"""
from zenml.integrations.evidently.data_validators import (
EvidentlyDataValidator,
)
return EvidentlyDataValidator
implementation_class: Type[EvidentlyDataValidator]
property
readonly
Implementation class.
Returns:
Type | Description |
---|---|
Type[EvidentlyDataValidator] | The implementation class. |
name: str
property
readonly
Name of the flavor.
Returns:
Type | Description |
---|---|
str | The name of the flavor. |
materializers
special
Evidently materializers.
evidently_profile_materializer
Implementation of Evidently profile materializer.
EvidentlyProfileMaterializer (BaseMaterializer)
Materializer to read data to and from an Evidently Profile.
Source code in zenml/integrations/evidently/materializers/evidently_profile_materializer.py
class EvidentlyProfileMaterializer(BaseMaterializer):
"""Materializer to read data to and from an Evidently Profile."""
ASSOCIATED_TYPES = (Profile,)
ASSOCIATED_ARTIFACT_TYPES = (DataAnalysisArtifact,)
def handle_input(self, data_type: Type[Any]) -> Profile:
"""Reads an Evidently Profile object from a json file.
Args:
data_type: The type of the data to read.
Returns:
The Evidently Profile
Raises:
TypeError: if the json file contains an invalid data type.
"""
super().handle_input(data_type)
filepath = os.path.join(self.artifact.uri, DEFAULT_FILENAME)
contents = yaml_utils.read_json(filepath)
if type(contents) != dict:
raise TypeError(
f"Contents {contents} was type {type(contents)} but expected "
f"dictionary"
)
section_types = contents.pop("section_types", [])
sections = []
for section_type in section_types:
section_cls = import_class_by_path(section_type)
section = section_cls()
section._result = contents[section.part_id()]
sections.append(section)
return Profile(sections=sections)
def handle_return(self, data: Profile) -> None:
"""Serialize an Evidently Profile to a json file.
Args:
data: The Evidently Profile to be serialized.
"""
super().handle_return(data)
contents = data.object()
# include the list of profile sections in the serialized dictionary,
# so we'll be able to re-create them during de-serialization
contents["section_types"] = [
resolve_class(stage.__class__) for stage in data.stages
]
filepath = os.path.join(self.artifact.uri, DEFAULT_FILENAME)
yaml_utils.write_json(filepath, contents, encoder=NumpyEncoder)
handle_input(self, data_type)
Reads an Evidently Profile object from a json file.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
data_type |
Type[Any] |
The type of the data to read. |
required |
Returns:
Type | Description |
---|---|
Profile |
The Evidently Profile |
Exceptions:
Type | Description |
---|---|
TypeError |
if the json file contains an invalid data type. |
Source code in zenml/integrations/evidently/materializers/evidently_profile_materializer.py
def handle_input(self, data_type: Type[Any]) -> Profile:
"""Reads an Evidently Profile object from a json file.
Args:
data_type: The type of the data to read.
Returns:
The Evidently Profile
Raises:
TypeError: if the json file contains an invalid data type.
"""
super().handle_input(data_type)
filepath = os.path.join(self.artifact.uri, DEFAULT_FILENAME)
contents = yaml_utils.read_json(filepath)
if type(contents) != dict:
raise TypeError(
f"Contents {contents} was type {type(contents)} but expected "
f"dictionary"
)
section_types = contents.pop("section_types", [])
sections = []
for section_type in section_types:
section_cls = import_class_by_path(section_type)
section = section_cls()
section._result = contents[section.part_id()]
sections.append(section)
return Profile(sections=sections)
handle_return(self, data)
Serialize an Evidently Profile to a json file.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
data |
Profile |
The Evidently Profile to be serialized. |
required |
Source code in zenml/integrations/evidently/materializers/evidently_profile_materializer.py
def handle_return(self, data: Profile) -> None:
"""Serialize an Evidently Profile to a json file.
Args:
data: The Evidently Profile to be serialized.
"""
super().handle_return(data)
contents = data.object()
# include the list of profile sections in the serialized dictionary,
# so we'll be able to re-create them during de-serialization
contents["section_types"] = [
resolve_class(stage.__class__) for stage in data.stages
]
filepath = os.path.join(self.artifact.uri, DEFAULT_FILENAME)
yaml_utils.write_json(filepath, contents, encoder=NumpyEncoder)
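For orientation, the sketch below replays the same round trip in memory, outside any artifact store; it assumes the Evidently 0.1-style Profile API used above, and the DataFrames are placeholders.
import pandas as pd
from evidently.model_profile import Profile
from evidently.model_profile.sections import DataDriftProfileSection

reference = pd.DataFrame({"a": [1, 2, 3], "b": [4.0, 5.0, 6.0]})
current = pd.DataFrame({"a": [2, 3, 4], "b": [5.0, 6.0, 7.0]})

profile = Profile(sections=[DataDriftProfileSection()])
profile.calculate(reference, current)

# handle_return: store each section's import path alongside its results
contents = profile.object()
contents["section_types"] = [
    f"{s.__class__.__module__}.{s.__class__.__name__}" for s in profile.stages
]

# handle_input: re-import the section classes and restore their results
sections = []
for path in contents.pop("section_types", []):
    module_name, _, cls_name = path.rpartition(".")
    section_cls = getattr(__import__(module_name, fromlist=[cls_name]), cls_name)
    section = section_cls()
    section._result = contents[section.part_id()]
    sections.append(section)
restored = Profile(sections=sections)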
steps
special
Initialization of the Evidently Standard Steps.
evidently_profile
Implementation of the Evidently Profile Step.
EvidentlyColumnMapping (BaseModel)
pydantic-model
Column mapping configuration for Evidently.
This class is a 1-to-1 serializable analog of Evidently's ColumnMapping data type that can be used as a step configuration field (see https://docs.evidentlyai.com/features/dashboards/column_mapping).
Attributes:
Name | Type | Description |
---|---|---|
target |
Optional[str] |
target column |
prediction |
Union[str, Sequence[str]] |
prediction column |
datetime |
Optional[str] |
datetime column |
id |
Optional[str] |
id column |
numerical_features |
Optional[List[str]] |
numerical features |
categorical_features |
Optional[List[str]] |
categorical features |
datetime_features |
Optional[List[str]] |
datetime features |
target_names |
Optional[List[str]] |
target column names |
task |
Optional[Literal['classification', 'regression']] |
model task (regression or classification) |
Source code in zenml/integrations/evidently/steps/evidently_profile.py
class EvidentlyColumnMapping(BaseModel):
"""Column mapping configuration for Evidently.
This class is a 1-to-1 serializable analog of Evidently's
ColumnMapping data type that can be used as a step configuration field
(see https://docs.evidentlyai.com/features/dashboards/column_mapping).
Attributes:
target: target column
prediction: prediction column
datetime: datetime column
id: id column
numerical_features: numerical features
categorical_features: categorical features
datetime_features: datetime features
target_names: target column names
task: model task (regression or classification)
"""
target: Optional[str] = None
prediction: Optional[Union[str, Sequence[str]]] = None
datetime: Optional[str] = None
id: Optional[str] = None
numerical_features: Optional[List[str]] = None
categorical_features: Optional[List[str]] = None
datetime_features: Optional[List[str]] = None
target_names: Optional[List[str]] = None
task: Optional[Literal["classification", "regression"]] = None
def to_evidently_column_mapping(self) -> ColumnMapping:
"""Convert this Pydantic object to an Evidently ColumnMapping object.
Returns:
An Evidently column mapping converted from this Pydantic object.
"""
column_mapping = ColumnMapping()
# preserve the Evidently defaults where possible
column_mapping.target = self.target or column_mapping.target
column_mapping.prediction = self.prediction or column_mapping.prediction
column_mapping.datetime = self.datetime or column_mapping.datetime
column_mapping.id = self.id or column_mapping.id
column_mapping.numerical_features = (
self.numerical_features or column_mapping.numerical_features
)
# also carry over categorical_features (previously dropped here)
column_mapping.categorical_features = (
self.categorical_features or column_mapping.categorical_features
)
column_mapping.datetime_features = (
self.datetime_features or column_mapping.datetime_features
)
column_mapping.target_names = (
self.target_names or column_mapping.target_names
)
column_mapping.task = self.task or column_mapping.task
return column_mapping
to_evidently_column_mapping(self)
Convert this Pydantic object to an Evidently ColumnMapping object.
Returns:
Type | Description |
---|---|
ColumnMapping |
An Evidently column mapping converted from this Pydantic object. |
Source code in zenml/integrations/evidently/steps/evidently_profile.py
def to_evidently_column_mapping(self) -> ColumnMapping:
"""Convert this Pydantic object to an Evidently ColumnMapping object.
Returns:
An Evidently column mapping converted from this Pydantic object.
"""
column_mapping = ColumnMapping()
# preserve the Evidently defaults where possible
column_mapping.target = self.target or column_mapping.target
column_mapping.prediction = self.prediction or column_mapping.prediction
column_mapping.datetime = self.datetime or column_mapping.datetime
column_mapping.id = self.id or column_mapping.id
column_mapping.numerical_features = (
self.numerical_features or column_mapping.numerical_features
)
# also carry over categorical_features (previously dropped here)
column_mapping.categorical_features = (
self.categorical_features or column_mapping.categorical_features
)
column_mapping.datetime_features = (
self.datetime_features or column_mapping.datetime_features
)
column_mapping.target_names = (
self.target_names or column_mapping.target_names
)
column_mapping.task = self.task or column_mapping.task
return column_mapping
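A short usage sketch; the column names are placeholders for whatever your DataFrames contain.
# Hedged example; "label", "pred" and the feature names are made up.
mapping = EvidentlyColumnMapping(
    target="label",
    prediction="pred",
    numerical_features=["age", "income"],
    categorical_features=["country"],
    task="classification",
)
evidently_mapping = mapping.to_evidently_column_mapping()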
EvidentlyProfileParameters (BaseDriftDetectionParameters)
pydantic-model
Parameters class for Evidently profile steps.
Attributes:
Name | Type | Description |
---|---|---|
column_mapping |
Optional[zenml.integrations.evidently.steps.evidently_profile.EvidentlyColumnMapping] |
properties of the DataFrame columns used |
ignored_cols |
Optional[List[str]] |
columns to ignore during the Evidently profile step |
profile_sections |
Optional[Sequence[str]] |
a list identifying the Evidently profile sections to be used. The following are valid options supported by Evidently: - "datadrift" - "categoricaltargetdrift" - "numericaltargetdrift" - "classificationmodelperformance" - "regressionmodelperformance" - "probabilisticmodelperformance" |
verbose_level |
int |
Verbosity level for the Evidently dashboards. Use 0 for a brief dashboard, 1 for a detailed dashboard. |
profile_options |
Sequence[Tuple[str, Dict[str, Any]]] |
Optional list of options to pass to the
profile constructor. See |
dashboard_options |
Sequence[Tuple[str, Dict[str, Any]]] |
Optional list of options to pass to the
dashboard constructor. See |
Source code in zenml/integrations/evidently/steps/evidently_profile.py
class EvidentlyProfileParameters(BaseDriftDetectionParameters):
"""Parameters class for Evidently profile steps.
Attributes:
column_mapping: properties of the DataFrame columns used
ignored_cols: columns to ignore during the Evidently profile step
profile_sections: a list identifying the Evidently profile sections to be
used. The following are valid options supported by Evidently:
- "datadrift"
- "categoricaltargetdrift"
- "numericaltargetdrift"
- "classificationmodelperformance"
- "regressionmodelperformance"
- "probabilisticmodelperformance"
verbose_level: Verbosity level for the Evidently dashboards. Use
0 for a brief dashboard, 1 for a detailed dashboard.
profile_options: Optional list of options to pass to the
profile constructor. See `EvidentlyDataValidator._unpack_options`.
dashboard_options: Optional list of options to pass to the
dashboard constructor. See `EvidentlyDataValidator._unpack_options`.
"""
column_mapping: Optional[EvidentlyColumnMapping] = None
ignored_cols: Optional[List[str]] = None
profile_sections: Optional[Sequence[str]] = None
verbose_level: int = 1
profile_options: Sequence[Tuple[str, Dict[str, Any]]] = Field(
default_factory=list
)
dashboard_options: Sequence[Tuple[str, Dict[str, Any]]] = Field(
default_factory=list
)
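As an illustration, a parameters object for a plain data drift profile might look like this; the ignored column name is a placeholder.
# "datadrift" is one of the supported section names listed above.
params = EvidentlyProfileParameters(
    profile_sections=["datadrift"],
    ignored_cols=["id"],
    verbose_level=1,
)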
EvidentlyProfileStep (BaseDriftDetectionStep)
Step implementation of an Evidently Profile Step.
Source code in zenml/integrations/evidently/steps/evidently_profile.py
class EvidentlyProfileStep(BaseDriftDetectionStep):
"""Step implementation implementing an Evidently Profile Step."""
def entrypoint( # type: ignore[override]
self,
reference_dataset: pd.DataFrame,
comparison_dataset: pd.DataFrame,
params: EvidentlyProfileParameters,
) -> Output( # type:ignore[valid-type]
profile=Profile, dashboard=str
):
"""Main entrypoint for the Evidently categorical target drift detection step.
Args:
reference_dataset: a Pandas DataFrame
comparison_dataset: a Pandas DataFrame of new data you wish to
compare against the reference data
params: the parameters for the step
Raises:
ValueError: If ignored_cols is an empty list
ValueError: If column is not found in reference or comparison
dataset
Returns:
profile: Evidently Profile generated for the data drift
dashboard: HTML report extracted from an Evidently Dashboard
generated for the data drift
"""
data_validator = cast(
EvidentlyDataValidator,
EvidentlyDataValidator.get_active_data_validator(),
)
column_mapping = None
if params.ignored_cols is None:
pass
elif not params.ignored_cols:
raise ValueError(
f"Expects None or list of columns in strings, but got {params.ignored_cols}"
)
elif not (
set(params.ignored_cols).issubset(set(reference_dataset.columns))
) or not (
set(params.ignored_cols).issubset(set(comparison_dataset.columns))
):
raise ValueError(
"Column is not found in reference or comparison datasets"
)
else:
reference_dataset = reference_dataset.drop(
labels=list(params.ignored_cols), axis=1
)
comparison_dataset = comparison_dataset.drop(
labels=list(params.ignored_cols), axis=1
)
if params.column_mapping:
column_mapping = params.column_mapping.to_evidently_column_mapping()
profile, dashboard = data_validator.data_profiling(
dataset=reference_dataset,
comparison_dataset=comparison_dataset,
profile_list=params.profile_sections,
column_mapping=column_mapping,
verbose_level=params.verbose_level,
profile_options=params.profile_options,
dashboard_options=params.dashboard_options,
)
return [profile, dashboard.html()]
PARAMETERS_CLASS (BaseDriftDetectionParameters)
pydantic-model
Parameters class for Evidently profile steps.
Attributes:
Name | Type | Description |
---|---|---|
column_mapping |
Optional[zenml.integrations.evidently.steps.evidently_profile.EvidentlyColumnMapping] |
properties of the DataFrame columns used |
ignored_cols |
Optional[List[str]] |
columns to ignore during the Evidently profile step |
profile_sections |
Optional[Sequence[str]] |
a list identifying the Evidently profile sections to be used. The following are valid options supported by Evidently: - "datadrift" - "categoricaltargetdrift" - "numericaltargetdrift" - "classificationmodelperformance" - "regressionmodelperformance" - "probabilisticmodelperformance" |
verbose_level |
int |
Verbosity level for the Evidently dashboards. Use 0 for a brief dashboard, 1 for a detailed dashboard. |
profile_options |
Sequence[Tuple[str, Dict[str, Any]]] |
Optional list of options to pass to the
profile constructor. See |
dashboard_options |
Sequence[Tuple[str, Dict[str, Any]]] |
Optional list of options to pass to the
dashboard constructor. See |
Source code in zenml/integrations/evidently/steps/evidently_profile.py
class EvidentlyProfileParameters(BaseDriftDetectionParameters):
"""Parameters class for Evidently profile steps.
Attributes:
column_mapping: properties of the DataFrame columns used
ignored_cols: columns to ignore during the Evidently profile step
profile_sections: a list identifying the Evidently profile sections to be
used. The following are valid options supported by Evidently:
- "datadrift"
- "categoricaltargetdrift"
- "numericaltargetdrift"
- "classificationmodelperformance"
- "regressionmodelperformance"
- "probabilisticmodelperformance"
verbose_level: Verbosity level for the Evidently dashboards. Use
0 for a brief dashboard, 1 for a detailed dashboard.
profile_options: Optional list of options to pass to the
profile constructor. See `EvidentlyDataValidator._unpack_options`.
dashboard_options: Optional list of options to pass to the
dashboard constructor. See `EvidentlyDataValidator._unpack_options`.
"""
column_mapping: Optional[EvidentlyColumnMapping] = None
ignored_cols: Optional[List[str]] = None
profile_sections: Optional[Sequence[str]] = None
verbose_level: int = 1
profile_options: Sequence[Tuple[str, Dict[str, Any]]] = Field(
default_factory=list
)
dashboard_options: Sequence[Tuple[str, Dict[str, Any]]] = Field(
default_factory=list
)
entrypoint(self, reference_dataset, comparison_dataset, params)
Main entrypoint for the Evidently profile step.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
reference_dataset |
DataFrame |
a Pandas DataFrame |
required |
comparison_dataset |
DataFrame |
a Pandas DataFrame of new data you wish to compare against the reference data |
required |
params |
EvidentlyProfileParameters |
the parameters for the step |
required |
Exceptions:
Type | Description |
---|---|
ValueError |
If ignored_cols is an empty list |
ValueError |
If column is not found in reference or comparison dataset |
Returns:
Type | Description |
---|---|
profile |
Evidently Profile generated for the data drift |
dashboard |
HTML report extracted from an Evidently Dashboard generated for the data drift |
Source code in zenml/integrations/evidently/steps/evidently_profile.py
def entrypoint( # type: ignore[override]
self,
reference_dataset: pd.DataFrame,
comparison_dataset: pd.DataFrame,
params: EvidentlyProfileParameters,
) -> Output( # type:ignore[valid-type]
profile=Profile, dashboard=str
):
"""Main entrypoint for the Evidently categorical target drift detection step.
Args:
reference_dataset: a Pandas DataFrame
comparison_dataset: a Pandas DataFrame of new data you wish to
compare against the reference data
params: the parameters for the step
Raises:
ValueError: If ignored_cols is an empty list
ValueError: If column is not found in reference or comparison
dataset
Returns:
profile: Evidently Profile generated for the data drift
dashboard: HTML report extracted from an Evidently Dashboard
generated for the data drift
"""
data_validator = cast(
EvidentlyDataValidator,
EvidentlyDataValidator.get_active_data_validator(),
)
column_mapping = None
if params.ignored_cols is None:
pass
elif not params.ignored_cols:
raise ValueError(
f"Expects None or list of columns in strings, but got {params.ignored_cols}"
)
elif not (
set(params.ignored_cols).issubset(set(reference_dataset.columns))
) or not (
set(params.ignored_cols).issubset(set(comparison_dataset.columns))
):
raise ValueError(
"Column is not found in reference or comparison datasets"
)
else:
reference_dataset = reference_dataset.drop(
labels=list(params.ignored_cols), axis=1
)
comparison_dataset = comparison_dataset.drop(
labels=list(params.ignored_cols), axis=1
)
if params.column_mapping:
column_mapping = params.column_mapping.to_evidently_column_mapping()
profile, dashboard = data_validator.data_profiling(
dataset=reference_dataset,
comparison_dataset=comparison_dataset,
profile_list=params.profile_sections,
column_mapping=column_mapping,
verbose_level=params.verbose_level,
profile_options=params.profile_options,
dashboard_options=params.dashboard_options,
)
return [profile, dashboard.html()]
evidently_profile_step(step_name, params)
Shortcut function to create a new instance of the EvidentlyProfileStep.
The returned EvidentlyProfileStep can be used in a pipeline to run model drift analyses on two input pd.DataFrame datasets and return the results as an Evidently profile object and a rendered dashboard object.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
step_name |
str |
The name of the step |
required |
params |
EvidentlyProfileParameters |
The parameters for the step |
required |
Returns:
Type | Description |
---|---|
BaseStep |
an EvidentlyProfileStep step instance |
Source code in zenml/integrations/evidently/steps/evidently_profile.py
def evidently_profile_step(
step_name: str,
params: EvidentlyProfileParameters,
) -> BaseStep:
"""Shortcut function to create a new instance of the EvidentlyProfileConfig step.
The returned EvidentlyProfileStep can be used in a pipeline to
run model drift analyses on two input pd.DataFrame datasets and return the
results as an Evidently profile object and a rendered dashboard object.
Args:
step_name: The name of the step
params: The parameters for the step
Returns:
an EvidentlyProfileStep step instance
"""
return clone_step(EvidentlyProfileStep, step_name)(params=params)
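A hedged sketch of using the shortcut inside a pipeline definition; the ingestion steps and pipeline are hypothetical, and only evidently_profile_step and EvidentlyProfileParameters come from this module.
drift_detector = evidently_profile_step(
    step_name="drift_detector",
    params=EvidentlyProfileParameters(profile_sections=["datadrift"]),
)
# Inside a pipeline function the step then receives the two DataFrames:
#   profile, dashboard = drift_detector(reference_df, comparison_df)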
visualizers
special
Initialization for Evidently visualizer.
evidently_visualizer
Implementation of the Evidently visualizer.
EvidentlyVisualizer (BaseVisualizer)
The implementation of an Evidently Visualizer.
Source code in zenml/integrations/evidently/visualizers/evidently_visualizer.py
class EvidentlyVisualizer(BaseVisualizer):
"""The implementation of an Evidently Visualizer."""
@abstractmethod
def visualize(self, object: StepView, *args: Any, **kwargs: Any) -> None:
"""Method to visualize components.
Args:
object: StepView fetched from run.get_step().
*args: Additional arguments.
**kwargs: Additional keyword arguments.
"""
for artifact_view in object.outputs.values():
# filter out anything but data artifacts
if (
artifact_view.type == DataArtifact.__name__
and artifact_view.data_type == "builtins.str"
):
artifact = artifact_view.read()
self.generate_facet(artifact)
def generate_facet(self, html_: str) -> None:
"""Generate a Facet Overview.
Args:
html_: HTML represented as a string.
"""
if Environment.in_notebook() or Environment.in_google_colab():
from IPython.core.display import HTML, display
display(HTML(html_))
else:
logger.warning(
"The magic functions are only usable in a Jupyter notebook."
)
with tempfile.NamedTemporaryFile(
mode="w", delete=False, suffix=".html", encoding="utf-8"
) as f:
f.write(html_)
url = f"file:///{f.name}"
logger.info("Opening %s in a new browser.." % f.name)
webbrowser.open(url, new=2)
generate_facet(self, html_)
Generate a Facet Overview.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
html_ |
str |
HTML represented as a string. |
required |
Source code in zenml/integrations/evidently/visualizers/evidently_visualizer.py
def generate_facet(self, html_: str) -> None:
"""Generate a Facet Overview.
Args:
html_: HTML represented as a string.
"""
if Environment.in_notebook() or Environment.in_google_colab():
from IPython.core.display import HTML, display
display(HTML(html_))
else:
logger.warning(
"The magic functions are only usable in a Jupyter notebook."
)
with tempfile.NamedTemporaryFile(
mode="w", delete=False, suffix=".html", encoding="utf-8"
) as f:
f.write(html_)
url = f"file:///{f.name}"
logger.info("Opening %s in a new browser.." % f.name)
webbrowser.open(url, new=2)
visualize(self, object, *args, **kwargs)
Method to visualize components.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
object |
StepView |
StepView fetched from run.get_step(). |
required |
*args |
Any |
Additional arguments. |
() |
**kwargs |
Any |
Additional keyword arguments. |
{} |
Source code in zenml/integrations/evidently/visualizers/evidently_visualizer.py
@abstractmethod
def visualize(self, object: StepView, *args: Any, **kwargs: Any) -> None:
"""Method to visualize components.
Args:
object: StepView fetched from run.get_step().
*args: Additional arguments.
**kwargs: Additional keyword arguments.
"""
for artifact_view in object.outputs.values():
# filter out anything but data artifacts
if (
artifact_view.type == DataArtifact.__name__
and artifact_view.data_type == "builtins.str"
):
artifact = artifact_view.read()
self.generate_facet(artifact)
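Because visualize is declared abstract, a thin concrete subclass is needed before the class can be instantiated; a minimal, hedged sketch follows (step_view is a placeholder StepView fetched from a finished run).
from typing import Any

from zenml.post_execution.step import StepView

class DriftVisualizer(EvidentlyVisualizer):
    """Concrete subclass; the base body already renders string artifacts."""

    def visualize(self, object: StepView, *args: Any, **kwargs: Any) -> None:
        super().visualize(object, *args, **kwargs)

DriftVisualizer().visualize(step_view)  # step_view: placeholder StepView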
facets
special
Facets integration for ZenML.
The Facets integration provides a simple way to visualize post-execution objects
like PipelineView, PipelineRunView and StepView. These objects can be extended
using the BaseVisualization class. This integration requires facets-overview be
installed in your Python environment.
FacetsIntegration (Integration)
Definition of Facet integration for ZenML.
Source code in zenml/integrations/facets/__init__.py
class FacetsIntegration(Integration):
"""Definition of [Facet](https://pair-code.github.io/facets/) integration for ZenML."""
NAME = FACETS
REQUIREMENTS = ["facets-overview>=1.0.0", "IPython"]
visualizers
special
Initialization of the Facet Visualizer.
facet_statistics_visualizer
Implementation of the Facet Statistics Visualizer.
FacetStatisticsVisualizer (BaseVisualizer)
Visualize and compare dataset statistics with Facets.
Source code in zenml/integrations/facets/visualizers/facet_statistics_visualizer.py
class FacetStatisticsVisualizer(BaseVisualizer):
"""Visualize and compare dataset statistics with Facets."""
@abstractmethod
def visualize(
self,
object: Union[StepView, Dict[str, Union[ArtifactView, pd.DataFrame]]],
magic: bool = False,
*args: Any,
**kwargs: Any,
) -> None:
"""Method to visualize components.
Args:
object: Either a StepView fetched from run.get_step() whose outputs
are all datasets that should be visualized, or a dict that maps
dataset names to datasets.
magic: Whether to render in a Jupyter notebook or not.
*args: Additional arguments.
**kwargs: Additional keyword arguments.
"""
data_dict = object.outputs if isinstance(object, StepView) else object
datasets = []
for dataset_name, data in data_dict.items():
df = data.read() if isinstance(data, ArtifactView) else data
if type(df) is not pd.DataFrame:
logger.warning(
"`%s` is not a pd.DataFrame. You can only visualize "
"statistics of steps that output pandas DataFrames. "
"Skipping this output.." % dataset_name
)
else:
datasets.append({"name": dataset_name, "table": df})
html_ = self.generate_html(datasets)
self.generate_facet(html_, magic)
def generate_html(self, datasets: List[Dict[Text, pd.DataFrame]]) -> str:
"""Generates html for facet.
Args:
datasets: List of dicts of DataFrames to be visualized as stats.
Returns:
HTML template with proto string embedded.
"""
proto = GenericFeatureStatisticsGenerator().ProtoFromDataFrames(
datasets
)
protostr = base64.b64encode(proto.SerializeToString()).decode("utf-8")
template = os.path.join(
os.path.abspath(os.path.dirname(__file__)),
"stats.html",
)
html_template = io_utils.read_file_contents_as_string(template)
html_ = html_template.replace("protostr", protostr)
return html_
def generate_facet(self, html_: str, magic: bool = False) -> None:
"""Generate a Facet Overview.
Args:
html_: HTML represented as a string.
magic: Whether to magically materialize facet in a notebook.
Raises:
EnvironmentError: If magic is True and not in a notebook.
"""
if magic:
if not (Environment.in_notebook() or Environment.in_google_colab()):
raise EnvironmentError(
"The magic functions are only usable in a Jupyter notebook."
)
display(HTML(html_))
else:
with tempfile.NamedTemporaryFile(delete=False, suffix=".html") as f:
io_utils.write_file_contents_as_string(f.name, html_)
url = f"file:///{f.name}"
logger.info("Opening %s in a new browser.." % f.name)
webbrowser.open(url, new=2)
generate_facet(self, html_, magic=False)
Generate a Facet Overview.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
html_ |
str |
HTML represented as a string. |
required |
magic |
bool |
Whether to magically materialize facet in a notebook. |
False |
Exceptions:
Type | Description |
---|---|
EnvironmentError |
If magic is True and not in a notebook. |
Source code in zenml/integrations/facets/visualizers/facet_statistics_visualizer.py
def generate_facet(self, html_: str, magic: bool = False) -> None:
"""Generate a Facet Overview.
Args:
html_: HTML represented as a string.
magic: Whether to magically materialize facet in a notebook.
Raises:
EnvironmentError: If magic is True and not in a notebook.
"""
if magic:
if not (Environment.in_notebook() or Environment.in_google_colab()):
raise EnvironmentError(
"The magic functions are only usable in a Jupyter notebook."
)
display(HTML(html_))
else:
with tempfile.NamedTemporaryFile(delete=False, suffix=".html") as f:
io_utils.write_file_contents_as_string(f.name, html_)
url = f"file:///{f.name}"
logger.info("Opening %s in a new browser.." % f.name)
webbrowser.open(url, new=2)
generate_html(self, datasets)
Generates HTML for the Facets overview.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
datasets |
List[Dict[str, pandas.core.frame.DataFrame]] |
List of dicts of DataFrames to be visualized as stats. |
required |
Returns:
Type | Description |
---|---|
str |
HTML template with proto string embedded. |
Source code in zenml/integrations/facets/visualizers/facet_statistics_visualizer.py
def generate_html(self, datasets: List[Dict[Text, pd.DataFrame]]) -> str:
"""Generates html for facet.
Args:
datasets: List of dicts of DataFrames to be visualized as stats.
Returns:
HTML template with proto string embedded.
"""
proto = GenericFeatureStatisticsGenerator().ProtoFromDataFrames(
datasets
)
protostr = base64.b64encode(proto.SerializeToString()).decode("utf-8")
template = os.path.join(
os.path.abspath(os.path.dirname(__file__)),
"stats.html",
)
html_template = io_utils.read_file_contents_as_string(template)
html_ = html_template.replace("protostr", protostr)
return html_
visualize(self, object, magic=False, *args, **kwargs)
Method to visualize components.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
object |
Union[zenml.post_execution.step.StepView, Dict[str, Union[zenml.post_execution.artifact.ArtifactView, pandas.core.frame.DataFrame]]] |
Either a StepView fetched from run.get_step() whose outputs are all datasets that should be visualized, or a dict that maps dataset names to datasets. |
required |
magic |
bool |
Whether to render in a Jupyter notebook or not. |
False |
*args |
Any |
Additional arguments. |
() |
**kwargs |
Any |
Additional keyword arguments. |
{} |
Source code in zenml/integrations/facets/visualizers/facet_statistics_visualizer.py
@abstractmethod
def visualize(
self,
object: Union[StepView, Dict[str, Union[ArtifactView, pd.DataFrame]]],
magic: bool = False,
*args: Any,
**kwargs: Any,
) -> None:
"""Method to visualize components.
Args:
object: Either a StepView fetched from run.get_step() whose outputs
are all datasets that should be visualized, or a dict that maps
dataset names to datasets.
magic: Whether to render in a Jupyter notebook or not.
*args: Additional arguments.
**kwargs: Additional keyword arguments.
"""
data_dict = object.outputs if isinstance(object, StepView) else object
datasets = []
for dataset_name, data in data_dict.items():
df = data.read() if isinstance(data, ArtifactView) else data
if type(df) is not pd.DataFrame:
logger.warning(
"`%s` is not a pd.DataFrame. You can only visualize "
"statistics of steps that output pandas DataFrames. "
"Skipping this output.." % dataset_name
)
else:
datasets.append({"name": dataset_name, "table": df})
html_ = self.generate_html(datasets)
self.generate_facet(html_, magic)
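Here too, visualize is abstract, so a thin concrete subclass is needed; the sketch below feeds two placeholder DataFrames through the dict form that visualize accepts.
import pandas as pd

class MyFacetsVisualizer(FacetStatisticsVisualizer):
    """Concrete subclass; the base body already builds the Facets overview."""

    def visualize(self, object, magic=False, *args, **kwargs):
        super().visualize(object, magic, *args, **kwargs)

train = pd.DataFrame({"x": [1, 2, 3]})
test = pd.DataFrame({"x": [2, 3, 4]})
MyFacetsVisualizer().visualize({"train": train, "test": test})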
feast
special
Initialization for Feast integration.
The Feast integration offers a way to connect to a Feast Feature Store. ZenML implements a dedicated stack component that you can access as part of your ZenML steps in the usual ways.
FeastIntegration (Integration)
Definition of Feast integration for ZenML.
Source code in zenml/integrations/feast/__init__.py
class FeastIntegration(Integration):
"""Definition of Feast integration for ZenML."""
NAME = FEAST
REQUIREMENTS = ["feast[redis]~=0.26.0", "redis-server"]
@classmethod
def flavors(cls) -> List[Type[Flavor]]:
"""Declare the stack component flavors for the Feast integration.
Returns:
List of stack component flavors for this integration.
"""
from zenml.integrations.feast.flavors import FeastFeatureStoreFlavor
return [FeastFeatureStoreFlavor]
flavors()
classmethod
Declare the stack component flavors for the Feast integration.
Returns:
Type | Description |
---|---|
List[Type[zenml.stack.flavor.Flavor]] |
List of stack component flavors for this integration. |
Source code in zenml/integrations/feast/__init__.py
@classmethod
def flavors(cls) -> List[Type[Flavor]]:
"""Declare the stack component flavors for the Feast integration.
Returns:
List of stack component flavors for this integration.
"""
from zenml.integrations.feast.flavors import FeastFeatureStoreFlavor
return [FeastFeatureStoreFlavor]
feature_stores
special
Feast Feature Store integration for ZenML.
Feature stores allow data teams to serve data via an offline store and an online low-latency store where data is kept in sync between the two. They also offer a centralized registry where features (and feature schemas) are stored for use within a team or wider organization. Feature stores are a relatively recent addition to commonly-used machine learning stacks. Feast is a leading open-source feature store, first developed by Gojek in collaboration with Google.
feast_feature_store
Implementation of the Feast Feature Store for ZenML.
FeastFeatureStore (BaseFeatureStore)
Class to interact with the Feast feature store.
Source code in zenml/integrations/feast/feature_stores/feast_feature_store.py
class FeastFeatureStore(BaseFeatureStore):
"""Class to interact with the Feast feature store."""
@property
def config(self) -> FeastFeatureStoreConfig:
"""Returns the `FeastFeatureStoreConfig` config.
Returns:
The configuration.
"""
return cast(FeastFeatureStoreConfig, self._config)
def _validate_connection(self) -> None:
"""Validates the connection to the feature store.
Raises:
ConnectionError: If the online component (Redis) is not available.
"""
client = redis.Redis(
host=self.config.online_host, port=self.config.online_port
)
try:
client.ping()
except redis.exceptions.ConnectionError as e:
raise redis.exceptions.ConnectionError(
"Could not connect to feature store's online component. "
"Please make sure that Redis is running."
) from e
def get_historical_features(
self,
entity_df: Union[pd.DataFrame, str],
features: List[str],
full_feature_names: bool = False,
) -> pd.DataFrame:
"""Returns the historical features for training or batch scoring.
Args:
entity_df: The entity DataFrame or entity name.
features: The features to retrieve.
full_feature_names: Whether to return the full feature names.
Raises:
ConnectionError: If the online component (Redis) is not available.
Returns:
The historical features as a Pandas DataFrame.
"""
fs = FeatureStore(repo_path=self.config.feast_repo)
return fs.get_historical_features(
entity_df=entity_df,
features=features,
full_feature_names=full_feature_names,
).to_df()
def get_online_features(
self,
entity_rows: List[Dict[str, Any]],
features: List[str],
full_feature_names: bool = False,
) -> Dict[str, Any]:
"""Returns the latest online feature data.
Args:
entity_rows: The entity rows to retrieve.
features: The features to retrieve.
full_feature_names: Whether to return the full feature names.
Raises:
ConnectionError: If the online component (Redis) is not available.
Returns:
The latest online feature data as a dictionary.
"""
self._validate_connection()
fs = FeatureStore(repo_path=self.config.feast_repo)
return fs.get_online_features( # type: ignore[no-any-return]
entity_rows=entity_rows,
features=features,
full_feature_names=full_feature_names,
).to_dict()
def get_data_sources(self) -> List[str]:
"""Returns the data sources' names.
Raises:
ConnectionError: If the online component (Redis) is not available.
Returns:
The data sources' names.
"""
self._validate_connection()
fs = FeatureStore(repo_path=self.config.feast_repo)
return [ds.name for ds in fs.list_data_sources()]
def get_entities(self) -> List[str]:
"""Returns the entity names.
Raises:
ConnectionError: If the online component (Redis) is not available.
Returns:
The entity names.
"""
self._validate_connection()
fs = FeatureStore(repo_path=self.config.feast_repo)
return [ds.name for ds in fs.list_entities()]
def get_feature_services(self) -> List[str]:
"""Returns the feature service names.
Raises:
ConnectionError: If the online component (Redis) is not available.
Returns:
The feature service names.
"""
self._validate_connection()
fs = FeatureStore(repo_path=self.config.feast_repo)
return [ds.name for ds in fs.list_feature_services()]
def get_feature_views(self) -> List[str]:
"""Returns the feature view names.
Raises:
ConnectionError: If the online component (Redis) is not available.
Returns:
The feature view names.
"""
self._validate_connection()
fs = FeatureStore(repo_path=self.config.feast_repo)
return [ds.name for ds in fs.list_feature_views()]
def get_project(self) -> str:
"""Returns the project name.
Raises:
ConnectionError: If the online component (Redis) is not available.
Returns:
The project name.
"""
fs = FeatureStore(repo_path=self.config.feast_repo)
return str(fs.project)
def get_registry(self) -> Registry:
"""Returns the feature store registry.
Raises:
ConnectionError: If the online component (Redis) is not available.
Returns:
The registry.
"""
fs: FeatureStore = FeatureStore(repo_path=self.config.feast_repo)
return fs.registry
def get_feast_version(self) -> str:
"""Returns the version of Feast used.
Raises:
ConnectionError: If the online component (Redis) is not available.
Returns:
The version of Feast currently being used.
"""
fs = FeatureStore(repo_path=self.config.feast_repo)
return str(fs.version())
config: FeastFeatureStoreConfig
property
readonly
Returns the FeastFeatureStoreConfig
config.
Returns:
Type | Description |
---|---|
FeastFeatureStoreConfig |
The configuration. |
get_data_sources(self)
Returns the data sources' names.
Exceptions:
Type | Description |
---|---|
ConnectionError |
If the online component (Redis) is not available. |
Returns:
Type | Description |
---|---|
List[str] |
The data sources' names. |
Source code in zenml/integrations/feast/feature_stores/feast_feature_store.py
def get_data_sources(self) -> List[str]:
"""Returns the data sources' names.
Raises:
ConnectionError: If the online component (Redis) is not available.
Returns:
The data sources' names.
"""
self._validate_connection()
fs = FeatureStore(repo_path=self.config.feast_repo)
return [ds.name for ds in fs.list_data_sources()]
get_entities(self)
Returns the entity names.
Exceptions:
Type | Description |
---|---|
ConnectionError |
If the online component (Redis) is not available. |
Returns:
Type | Description |
---|---|
List[str] |
The entity names. |
Source code in zenml/integrations/feast/feature_stores/feast_feature_store.py
def get_entities(self) -> List[str]:
"""Returns the entity names.
Raises:
ConnectionError: If the online component (Redis) is not available.
Returns:
The entity names.
"""
self._validate_connection()
fs = FeatureStore(repo_path=self.config.feast_repo)
return [ds.name for ds in fs.list_entities()]
get_feast_version(self)
Returns the version of Feast used.
Exceptions:
Type | Description |
---|---|
ConnectionError |
If the online component (Redis) is not available. |
Returns:
Type | Description |
---|---|
str |
The version of Feast currently being used. |
Source code in zenml/integrations/feast/feature_stores/feast_feature_store.py
def get_feast_version(self) -> str:
"""Returns the version of Feast used.
Raises:
ConnectionError: If the online component (Redis) is not available.
Returns:
The version of Feast currently being used.
"""
fs = FeatureStore(repo_path=self.config.feast_repo)
return str(fs.version())
get_feature_services(self)
Returns the feature service names.
Exceptions:
Type | Description |
---|---|
ConnectionError |
If the online component (Redis) is not available. |
Returns:
Type | Description |
---|---|
List[str] |
The feature service names. |
Source code in zenml/integrations/feast/feature_stores/feast_feature_store.py
def get_feature_services(self) -> List[str]:
"""Returns the feature service names.
Raises:
ConnectionError: If the online component (Redis) is not available.
Returns:
The feature service names.
"""
self._validate_connection()
fs = FeatureStore(repo_path=self.config.feast_repo)
return [ds.name for ds in fs.list_feature_services()]
get_feature_views(self)
Returns the feature view names.
Exceptions:
Type | Description |
---|---|
ConnectionError |
If the online component (Redis) is not available. |
Returns:
Type | Description |
---|---|
List[str] |
The feature view names. |
Source code in zenml/integrations/feast/feature_stores/feast_feature_store.py
def get_feature_views(self) -> List[str]:
"""Returns the feature view names.
Raises:
ConnectionError: If the online component (Redis) is not available.
Returns:
The feature view names.
"""
self._validate_connection()
fs = FeatureStore(repo_path=self.config.feast_repo)
return [ds.name for ds in fs.list_feature_views()]
get_historical_features(self, entity_df, features, full_feature_names=False)
Returns the historical features for training or batch scoring.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
entity_df |
Union[pandas.core.frame.DataFrame, str] |
The entity DataFrame or entity name. |
required |
features |
List[str] |
The features to retrieve. |
required |
full_feature_names |
bool |
Whether to return the full feature names. |
False |
Exceptions:
Type | Description |
---|---|
ConnectionError |
If the online component (Redis) is not available. |
Returns:
Type | Description |
---|---|
DataFrame |
The historical features as a Pandas DataFrame. |
Source code in zenml/integrations/feast/feature_stores/feast_feature_store.py
def get_historical_features(
self,
entity_df: Union[pd.DataFrame, str],
features: List[str],
full_feature_names: bool = False,
) -> pd.DataFrame:
"""Returns the historical features for training or batch scoring.
Args:
entity_df: The entity DataFrame or entity name.
features: The features to retrieve.
full_feature_names: Whether to return the full feature names.
Raises:
ConnectionError: If the online component (Redis) is not available.
Returns:
The historical features as a Pandas DataFrame.
"""
fs = FeatureStore(repo_path=self.config.feast_repo)
return fs.get_historical_features(
entity_df=entity_df,
features=features,
full_feature_names=full_feature_names,
).to_df()
get_online_features(self, entity_rows, features, full_feature_names=False)
Returns the latest online feature data.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
entity_rows |
List[Dict[str, Any]] |
The entity rows to retrieve. |
required |
features |
List[str] |
The features to retrieve. |
required |
full_feature_names |
bool |
Whether to return the full feature names. |
False |
Exceptions:
Type | Description |
---|---|
ConnectionError |
If the online component (Redis) is not available. |
Returns:
Type | Description |
---|---|
Dict[str, Any] |
The latest online feature data as a dictionary. |
Source code in zenml/integrations/feast/feature_stores/feast_feature_store.py
def get_online_features(
self,
entity_rows: List[Dict[str, Any]],
features: List[str],
full_feature_names: bool = False,
) -> Dict[str, Any]:
"""Returns the latest online feature data.
Args:
entity_rows: The entity rows to retrieve.
features: The features to retrieve.
full_feature_names: Whether to return the full feature names.
Raises:
ConnectionError: If the online component (Redis) is not available.
Returns:
The latest online feature data as a dictionary.
"""
self._validate_connection()
fs = FeatureStore(repo_path=self.config.feast_repo)
return fs.get_online_features( # type: ignore[no-any-return]
entity_rows=entity_rows,
features=features,
full_feature_names=full_feature_names,
).to_dict()
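A hedged usage sketch: feature_store stands for the FeastFeatureStore component of the active stack, and the entity keys and feature references belong to a hypothetical Feast repository.
import pandas as pd

online = feature_store.get_online_features(
    entity_rows=[{"driver_id": 1001}],  # hypothetical entity key
    features=["driver_hourly_stats:conv_rate"],  # hypothetical feature
)

entity_df = pd.DataFrame(
    {"driver_id": [1001], "event_timestamp": [pd.Timestamp.utcnow()]}
)
historical = feature_store.get_historical_features(
    entity_df=entity_df,
    features=["driver_hourly_stats:conv_rate"],
)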
get_project(self)
Returns the project name.
Exceptions:
Type | Description |
---|---|
ConnectionError |
If the online component (Redis) is not available. |
Returns:
Type | Description |
---|---|
str |
The project name. |
Source code in zenml/integrations/feast/feature_stores/feast_feature_store.py
def get_project(self) -> str:
"""Returns the project name.
Raises:
ConnectionError: If the online component (Redis) is not available.
Returns:
The project name.
"""
fs = FeatureStore(repo_path=self.config.feast_repo)
return str(fs.project)
get_registry(self)
Returns the feature store registry.
Exceptions:
Type | Description |
---|---|
ConnectionError |
If the online component (Redis) is not available. |
Returns:
Type | Description |
---|---|
Registry |
The registry. |
Source code in zenml/integrations/feast/feature_stores/feast_feature_store.py
def get_registry(self) -> Registry:
"""Returns the feature store registry.
Raises:
ConnectionError: If the online component (Redis) is not available.
Returns:
The registry.
"""
fs: FeatureStore = FeatureStore(repo_path=self.config.feast_repo)
return fs.registry
flavors
special
Feast integration flavors.
feast_feature_store_flavor
Feast feature store flavor.
FeastFeatureStoreConfig (BaseFeatureStoreConfig)
pydantic-model
Config for Feast feature store.
Source code in zenml/integrations/feast/flavors/feast_feature_store_flavor.py
class FeastFeatureStoreConfig(BaseFeatureStoreConfig):
"""Config for Feast feature store."""
online_host: str = "localhost"
online_port: int = 6379
feast_repo: str
@property
def is_local(self) -> bool:
"""Checks if this stack component is running locally.
This designation is used to determine if the stack component can be
shared with other users or if it is only usable on the local host.
Returns:
True if this config is for a local component, False otherwise.
"""
return (
self.online_host == "localhost" or self.online_host == "127.0.0.1"
)
is_local: bool
property
readonly
Checks if this stack component is running locally.
This designation is used to determine if the stack component can be shared with other users or if it is only usable on the local host.
Returns:
Type | Description |
---|---|
bool |
True if this config is for a local component, False otherwise. |
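A minimal, hedged configuration example; the repository path is a placeholder.
# Defaults point the online store at a local Redis on port 6379.
config = FeastFeatureStoreConfig(feast_repo="./feast_repo")
assert config.is_local  # online_host defaults to "localhost"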
FeastFeatureStoreFlavor (BaseFeatureStoreFlavor)
Feast Feature store flavor.
Source code in zenml/integrations/feast/flavors/feast_feature_store_flavor.py
class FeastFeatureStoreFlavor(BaseFeatureStoreFlavor):
"""Feast Feature store flavor."""
@property
def name(self) -> str:
"""Name of the flavor.
Returns:
The name of the flavor.
"""
return FEAST_FEATURE_STORE_FLAVOR
@property
def config_class(self) -> Type[FeastFeatureStoreConfig]:
"""Returns FeastFeatureStoreConfig config class.
Returns:
The config class.
"""
"""Config class for this flavor."""
return FeastFeatureStoreConfig
@property
def implementation_class(self) -> Type["FeastFeatureStore"]:
"""Implementation class for this flavor.
Returns:
The implementation class.
"""
from zenml.integrations.feast.feature_stores import FeastFeatureStore
return FeastFeatureStore
config_class: Type[zenml.integrations.feast.flavors.feast_feature_store_flavor.FeastFeatureStoreConfig]
property
readonly
Returns FeastFeatureStoreConfig config class.
Returns:
Type | Description |
---|---|
Type[zenml.integrations.feast.flavors.feast_feature_store_flavor.FeastFeatureStoreConfig] |
The config class. |
implementation_class: Type[FeastFeatureStore]
property
readonly
Implementation class for this flavor.
Returns:
Type | Description |
---|---|
Type[FeastFeatureStore] |
The implementation class. |
name: str
property
readonly
Name of the flavor.
Returns:
Type | Description |
---|---|
str |
The name of the flavor. |
gcp
special
Initialization of the GCP ZenML integration.
The GCP integration submodule provides a way to run ZenML pipelines in a cloud
environment. Specifically, it allows the use of cloud artifact stores and
provides an io module to handle file operations on Google Cloud Storage (GCS).
Additionally, the GCP secrets manager integration submodule provides a way to
access the GCP secrets manager from within your ZenML pipeline runs.
The Vertex AI integration submodule provides a way to run ZenML pipelines in a
Vertex AI environment.
GcpIntegration (Integration)
Definition of Google Cloud Platform integration for ZenML.
Source code in zenml/integrations/gcp/__init__.py
class GcpIntegration(Integration):
"""Definition of Google Cloud Platform integration for ZenML."""
NAME = GCP
REQUIREMENTS = [
"kfp==1.8.13",
"gcsfs",
"google-cloud-secret-manager",
"google-cloud-aiplatform>=1.11.0",
]
@classmethod
def flavors(cls) -> List[Type[Flavor]]:
"""Declare the stack component flavors for the GCP integration.
Returns:
List of stack component flavors for this integration.
"""
from zenml.integrations.gcp.flavors import (
GCPArtifactStoreFlavor,
GCPSecretsManagerFlavor,
VertexOrchestratorFlavor,
VertexStepOperatorFlavor,
)
return [
VertexOrchestratorFlavor,
VertexStepOperatorFlavor,
GCPSecretsManagerFlavor,
GCPArtifactStoreFlavor,
]
flavors()
classmethod
Declare the stack component flavors for the GCP integration.
Returns:
Type | Description |
---|---|
List[Type[zenml.stack.flavor.Flavor]] |
List of stack component flavors for this integration. |
Source code in zenml/integrations/gcp/__init__.py
@classmethod
def flavors(cls) -> List[Type[Flavor]]:
"""Declare the stack component flavors for the GCP integration.
Returns:
List of stack component flavors for this integration.
"""
from zenml.integrations.gcp.flavors import (
GCPArtifactStoreFlavor,
GCPSecretsManagerFlavor,
VertexOrchestratorFlavor,
VertexStepOperatorFlavor,
)
return [
VertexOrchestratorFlavor,
VertexStepOperatorFlavor,
GCPSecretsManagerFlavor,
GCPArtifactStoreFlavor,
]
artifact_stores
special
Initialization of the GCP Artifact Store.
gcp_artifact_store
Implementation of the GCP Artifact Store.
GCPArtifactStore (BaseArtifactStore, AuthenticationMixin)
Artifact Store for Google Cloud Storage based artifacts.
Source code in zenml/integrations/gcp/artifact_stores/gcp_artifact_store.py
class GCPArtifactStore(BaseArtifactStore, AuthenticationMixin):
"""Artifact Store for Google Cloud Storage based artifacts."""
_filesystem: Optional[gcsfs.GCSFileSystem] = None
@property
def config(self) -> GCPArtifactStoreConfig:
"""Returns the `GCPArtifactStoreConfig` config.
Returns:
The configuration.
"""
return cast(GCPArtifactStoreConfig, self._config)
@property
def filesystem(self) -> gcsfs.GCSFileSystem:
"""The gcsfs filesystem to access this artifact store.
Returns:
The gcsfs filesystem to access this artifact store.
"""
if not self._filesystem:
secret = self.get_authentication_secret(
expected_schema_type=GCPSecretSchema
)
token = secret.get_credential_dict() if secret else None
self._filesystem = gcsfs.GCSFileSystem(token=token)
return self._filesystem
def open(self, path: PathType, mode: str = "r") -> Any:
"""Open a file at the given path.
Args:
path: Path of the file to open.
mode: Mode in which to open the file. Currently, only
'rb' and 'wb' to read and write binary files are supported.
Returns:
A file-like object that can be used to read or write to the file.
"""
return self.filesystem.open(path=path, mode=mode)
def copyfile(
self, src: PathType, dst: PathType, overwrite: bool = False
) -> None:
"""Copy a file.
Args:
src: The path to copy from.
dst: The path to copy to.
overwrite: If a file already exists at the destination, this
method will overwrite it if overwrite=`True` and
raise a FileExistsError otherwise.
Raises:
FileExistsError: If a file already exists at the destination
and overwrite is not set to `True`.
"""
if not overwrite and self.filesystem.exists(dst):
raise FileExistsError(
f"Unable to copy to destination '{convert_to_str(dst)}', "
f"file already exists. Set `overwrite=True` to copy anyway."
)
# TODO [ENG-151]: Check if it works with overwrite=True or if we need to
# manually remove it first
self.filesystem.copy(path1=src, path2=dst)
def exists(self, path: PathType) -> bool:
"""Check whether a path exists.
Args:
path: The path to check.
Returns:
True if the path exists, False otherwise.
"""
return self.filesystem.exists(path=path) # type: ignore[no-any-return]
def glob(self, pattern: PathType) -> List[PathType]:
"""Return all paths that match the given glob pattern.
The glob pattern may include:
- '*' to match any number of characters
- '?' to match a single character
- '[...]' to match one of the characters inside the brackets
- '**' as the full name of a path component to match to search
in subdirectories of any depth (e.g. '/some_dir/**/some_file')
Args:
pattern: The glob pattern to match, see details above.
Returns:
A list of paths that match the given glob pattern.
"""
return [
f"{GCP_PATH_PREFIX}{path}"
for path in self.filesystem.glob(path=pattern)
]
def isdir(self, path: PathType) -> bool:
"""Check whether a path is a directory.
Args:
path: The path to check.
Returns:
True if the path is a directory, False otherwise.
"""
return self.filesystem.isdir(path=path) # type: ignore[no-any-return]
def listdir(self, path: PathType) -> List[PathType]:
"""Return a list of files in a directory.
Args:
path: The path of the directory to list.
Returns:
A list of paths of files in the directory.
"""
path_without_prefix = convert_to_str(path)
if path_without_prefix.startswith(GCP_PATH_PREFIX):
path_without_prefix = path_without_prefix[len(GCP_PATH_PREFIX) :]
def _extract_basename(file_dict: Dict[str, Any]) -> str:
"""Extracts the basename from a file info dict returned by GCP.
Args:
file_dict: A file info dict returned by the GCP filesystem.
Returns:
The basename of the file.
"""
file_path = cast(str, file_dict["name"])
base_name = file_path[len(path_without_prefix) :]
return base_name.lstrip("/")
return [
_extract_basename(dict_)
for dict_ in self.filesystem.listdir(path=path)
]
def makedirs(self, path: PathType) -> None:
"""Create a directory at the given path.
If needed also create missing parent directories.
Args:
path: The path of the directory to create.
"""
self.filesystem.makedirs(path=path, exist_ok=True)
def mkdir(self, path: PathType) -> None:
"""Create a directory at the given path.
Args:
path: The path of the directory to create.
"""
self.filesystem.makedir(path=path)
def remove(self, path: PathType) -> None:
"""Remove the file at the given path.
Args:
path: The path of the file to remove.
"""
self.filesystem.rm_file(path=path)
def rename(
self, src: PathType, dst: PathType, overwrite: bool = False
) -> None:
"""Rename source file to destination file.
Args:
src: The path of the file to rename.
dst: The path to rename the source file to.
overwrite: If a file already exists at the destination, this
method will overwrite it if overwrite=`True` and
raise a FileExistsError otherwise.
Raises:
FileExistsError: If a file already exists at the destination
and overwrite is not set to `True`.
"""
if not overwrite and self.filesystem.exists(dst):
raise FileExistsError(
f"Unable to rename file to '{convert_to_str(dst)}', "
f"file already exists. Set `overwrite=True` to rename anyway."
)
# TODO [ENG-152]: Check if it works with overwrite=True or if we need
# to manually remove it first
self.filesystem.rename(path1=src, path2=dst)
def rmtree(self, path: PathType) -> None:
"""Remove the given directory.
Args:
path: The path of the directory to remove.
"""
self.filesystem.delete(path=path, recursive=True)
def stat(self, path: PathType) -> Dict[str, Any]:
"""Return stat info for the given path.
Args:
path: the path to get stat info for.
Returns:
A dictionary with the stat info.
"""
return self.filesystem.stat(path=path) # type: ignore[no-any-return]
def walk(
self,
top: PathType,
topdown: bool = True,
onerror: Optional[Callable[..., None]] = None,
) -> Iterable[Tuple[PathType, List[PathType], List[PathType]]]:
"""Return an iterator that walks the contents of the given directory.
Args:
top: Path of directory to walk.
topdown: Unused argument to conform to interface.
onerror: Unused argument to conform to interface.
Yields:
An Iterable of Tuples, each of which contain the path of the current
directory path, a list of directories inside the current directory
and a list of files inside the current directory.
"""
# TODO [ENG-153]: Additional params
for (
directory,
subdirectories,
files,
) in self.filesystem.walk(path=top):
yield f"{GCP_PATH_PREFIX}{directory}", subdirectories, files
config: GCPArtifactStoreConfig
property
readonly
Returns the GCPArtifactStoreConfig
config.
Returns:
Type | Description |
---|---|
GCPArtifactStoreConfig |
The configuration. |
filesystem: GCSFileSystem
property
readonly
The gcsfs filesystem to access this artifact store.
Returns:
Type | Description |
---|---|
GCSFileSystem |
The gcsfs filesystem to access this artifact store. |
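For orientation, a hedged sketch of direct use; store stands for a GCPArtifactStore from the active stack and the bucket path is a placeholder.
path = "gs://my-bucket/artifacts/example.bin"  # hypothetical bucket/path
with store.open(path, mode="wb") as f:  # only 'rb'/'wb' are supported
    f.write(b"hello")
assert store.exists(path)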
copyfile(self, src, dst, overwrite=False)
Copy a file.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
src |
Union[bytes, str] |
The path to copy from. |
required |
dst |
Union[bytes, str] |
The path to copy to. |
required |
overwrite |
bool |
If a file already exists at the destination, this
method will overwrite it if overwrite= |
False |
Exceptions:
Type | Description |
---|---|
FileExistsError |
If a file already exists at the destination
and overwrite is not set to |
Source code in zenml/integrations/gcp/artifact_stores/gcp_artifact_store.py
def copyfile(
self, src: PathType, dst: PathType, overwrite: bool = False
) -> None:
"""Copy a file.
Args:
src: The path to copy from.
dst: The path to copy to.
overwrite: If a file already exists at the destination, this
method will overwrite it if overwrite=`True` and
raise a FileExistsError otherwise.
Raises:
FileExistsError: If a file already exists at the destination
and overwrite is not set to `True`.
"""
if not overwrite and self.filesystem.exists(dst):
raise FileExistsError(
f"Unable to copy to destination '{convert_to_str(dst)}', "
f"file already exists. Set `overwrite=True` to copy anyway."
)
# TODO [ENG-151]: Check if it works with overwrite=True or if we need to
# manually remove it first
self.filesystem.copy(path1=src, path2=dst)
exists(self, path)
Check whether a path exists.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
path |
Union[bytes, str] |
The path to check. |
required |
Returns:
Type | Description |
---|---|
bool |
True if the path exists, False otherwise. |
Source code in zenml/integrations/gcp/artifact_stores/gcp_artifact_store.py
def exists(self, path: PathType) -> bool:
"""Check whether a path exists.
Args:
path: The path to check.
Returns:
True if the path exists, False otherwise.
"""
return self.filesystem.exists(path=path) # type: ignore[no-any-return]
glob(self, pattern)
Return all paths that match the given glob pattern.
The glob pattern may include: - '*' to match any number of characters - '?' to match a single character - '[...]' to match one of the characters inside the brackets - '**' as the full name of a path component to match to search in subdirectories of any depth (e.g. '/some_dir/**/some_file')
Parameters:
Name | Type | Description | Default |
---|---|---|---|
pattern |
Union[bytes, str] |
The glob pattern to match, see details above. |
required |
Returns:
Type | Description |
---|---|
List[Union[bytes, str]] |
A list of paths that match the given glob pattern. |
Source code in zenml/integrations/gcp/artifact_stores/gcp_artifact_store.py
def glob(self, pattern: PathType) -> List[PathType]:
"""Return all paths that match the given glob pattern.
The glob pattern may include:
- '*' to match any number of characters
- '?' to match a single character
- '[...]' to match one of the characters inside the brackets
- '**' as the full name of a path component to match to search
in subdirectories of any depth (e.g. '/some_dir/**/some_file')
Args:
pattern: The glob pattern to match, see details above.
Returns:
A list of paths that match the given glob pattern.
"""
return [
f"{GCP_PATH_PREFIX}{path}"
for path in self.filesystem.glob(path=pattern)
]
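A short, hedged example of the pattern syntax; the bucket prefix is a placeholder.
# Recursively match CSV files under a hypothetical prefix; results come back
# with the gs:// prefix re-attached.
csv_paths = store.glob("gs://my-bucket/data/**/*.csv")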
isdir(self, path)
Check whether a path is a directory.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
path |
Union[bytes, str] |
The path to check. |
required |
Returns:
Type | Description |
---|---|
bool |
True if the path is a directory, False otherwise. |
Source code in zenml/integrations/gcp/artifact_stores/gcp_artifact_store.py
def isdir(self, path: PathType) -> bool:
"""Check whether a path is a directory.
Args:
path: The path to check.
Returns:
True if the path is a directory, False otherwise.
"""
return self.filesystem.isdir(path=path) # type: ignore[no-any-return]
listdir(self, path)
Return a list of files in a directory.
Parameters:

Name | Type | Description | Default |
---|---|---|---|
`path` | `Union[bytes, str]` | The path of the directory to list. | required |

Returns:

Type | Description |
---|---|
`List[Union[bytes, str]]` | A list of paths of files in the directory. |
Source code in zenml/integrations/gcp/artifact_stores/gcp_artifact_store.py
def listdir(self, path: PathType) -> List[PathType]:
"""Return a list of files in a directory.
Args:
path: The path of the directory to list.
Returns:
A list of paths of files in the directory.
"""
path_without_prefix = convert_to_str(path)
if path_without_prefix.startswith(GCP_PATH_PREFIX):
path_without_prefix = path_without_prefix[len(GCP_PATH_PREFIX) :]
def _extract_basename(file_dict: Dict[str, Any]) -> str:
"""Extracts the basename from a file info dict returned by GCP.
Args:
file_dict: A file info dict returned by the GCP filesystem.
Returns:
The basename of the file.
"""
file_path = cast(str, file_dict["name"])
base_name = file_path[len(path_without_prefix) :]
return base_name.lstrip("/")
return [
_extract_basename(dict_)
for dict_ in self.filesystem.listdir(path=path)
]
makedirs(self, path)
Create a directory at the given path.
If needed also create missing parent directories.
Parameters:

Name | Type | Description | Default |
---|---|---|---|
`path` | `Union[bytes, str]` | The path of the directory to create. | required |
Source code in zenml/integrations/gcp/artifact_stores/gcp_artifact_store.py
def makedirs(self, path: PathType) -> None:
"""Create a directory at the given path.
If needed also create missing parent directories.
Args:
path: The path of the directory to create.
"""
self.filesystem.makedirs(path=path, exist_ok=True)
mkdir(self, path)
Create a directory at the given path.
Parameters:

Name | Type | Description | Default |
---|---|---|---|
`path` | `Union[bytes, str]` | The path of the directory to create. | required |
Source code in zenml/integrations/gcp/artifact_stores/gcp_artifact_store.py
def mkdir(self, path: PathType) -> None:
"""Create a directory at the given path.
Args:
path: The path of the directory to create.
"""
self.filesystem.makedir(path=path)
open(self, path, mode='r')
Open a file at the given path.
Parameters:

Name | Type | Description | Default |
---|---|---|---|
`path` | `Union[bytes, str]` | Path of the file to open. | required |
`mode` | `str` | Mode in which to open the file. Currently, only 'rb' and 'wb' to read and write binary files are supported. | `'r'` |

Returns:

Type | Description |
---|---|
`Any` | A file-like object that can be used to read or write to the file. |
Source code in zenml/integrations/gcp/artifact_stores/gcp_artifact_store.py
def open(self, path: PathType, mode: str = "r") -> Any:
"""Open a file at the given path.
Args:
path: Path of the file to open.
mode: Mode in which to open the file. Currently, only
'rb' and 'wb' to read and write binary files are supported.
Returns:
A file-like object that can be used to read or write to the file.
"""
return self.filesystem.open(path=path, mode=mode)
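A short round-trip sketch using the binary modes documented above (the path is hypothetical):

```python
from zenml.io import fileio

# Write a small binary payload, then read it back.
with fileio.open("gs://my-bucket/artifacts/example.bin", "wb") as f:
    f.write(b"zenml")

with fileio.open("gs://my-bucket/artifacts/example.bin", "rb") as f:
    assert f.read() == b"zenml"
```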
remove(self, path)
Remove the file at the given path.
Parameters:

Name | Type | Description | Default |
---|---|---|---|
`path` | `Union[bytes, str]` | The path of the file to remove. | required |
Source code in zenml/integrations/gcp/artifact_stores/gcp_artifact_store.py
def remove(self, path: PathType) -> None:
"""Remove the file at the given path.
Args:
path: The path of the file to remove.
"""
self.filesystem.rm_file(path=path)
rename(self, src, dst, overwrite=False)
Rename source file to destination file.
Parameters:

Name | Type | Description | Default |
---|---|---|---|
`src` | `Union[bytes, str]` | The path of the file to rename. | required |
`dst` | `Union[bytes, str]` | The path to rename the source file to. | required |
`overwrite` | `bool` | If a file already exists at the destination, this method will overwrite it if overwrite=`True` and raise a FileExistsError otherwise. | `False` |

Exceptions:

Type | Description |
---|---|
`FileExistsError` | If a file already exists at the destination and overwrite is not set to `True`. |
Source code in zenml/integrations/gcp/artifact_stores/gcp_artifact_store.py
def rename(
self, src: PathType, dst: PathType, overwrite: bool = False
) -> None:
"""Rename source file to destination file.
Args:
src: The path of the file to rename.
dst: The path to rename the source file to.
overwrite: If a file already exists at the destination, this
method will overwrite it if overwrite=`True` and
raise a FileExistsError otherwise.
Raises:
FileExistsError: If a file already exists at the destination
and overwrite is not set to `True`.
"""
if not overwrite and self.filesystem.exists(dst):
raise FileExistsError(
f"Unable to rename file to '{convert_to_str(dst)}', "
f"file already exists. Set `overwrite=True` to rename anyway."
)
# TODO [ENG-152]: Check if it works with overwrite=True or if we need
# to manually remove it first
self.filesystem.rename(path1=src, path2=dst)
rmtree(self, path)
Remove the given directory.
Parameters:

Name | Type | Description | Default |
---|---|---|---|
`path` | `Union[bytes, str]` | The path of the directory to remove. | required |
Source code in zenml/integrations/gcp/artifact_stores/gcp_artifact_store.py
def rmtree(self, path: PathType) -> None:
"""Remove the given directory.
Args:
path: The path of the directory to remove.
"""
self.filesystem.delete(path=path, recursive=True)
stat(self, path)
Return stat info for the given path.
Parameters:

Name | Type | Description | Default |
---|---|---|---|
`path` | `Union[bytes, str]` | The path to get stat info for. | required |

Returns:

Type | Description |
---|---|
`Dict[str, Any]` | A dictionary with the stat info. |
Source code in zenml/integrations/gcp/artifact_stores/gcp_artifact_store.py
def stat(self, path: PathType) -> Dict[str, Any]:
"""Return stat info for the given path.
Args:
path: the path to get stat info for.
Returns:
A dictionary with the stat info.
"""
return self.filesystem.stat(path=path) # type: ignore[no-any-return]
walk(self, top, topdown=True, onerror=None)
Return an iterator that walks the contents of the given directory.
Parameters:

Name | Type | Description | Default |
---|---|---|---|
`top` | `Union[bytes, str]` | Path of the directory to walk. | required |
`topdown` | `bool` | Unused argument to conform to interface. | `True` |
`onerror` | `Optional[Callable[..., NoneType]]` | Unused argument to conform to interface. | `None` |

Yields:

Type | Description |
---|---|
`Iterable[Tuple[Union[bytes, str], List[Union[bytes, str]], List[Union[bytes, str]]]]` | An iterable of tuples, each of which contains the path of the current directory, a list of directories inside the current directory, and a list of files inside the current directory. |
Source code in zenml/integrations/gcp/artifact_stores/gcp_artifact_store.py
def walk(
self,
top: PathType,
topdown: bool = True,
onerror: Optional[Callable[..., None]] = None,
) -> Iterable[Tuple[PathType, List[PathType], List[PathType]]]:
"""Return an iterator that walks the contents of the given directory.
Args:
top: Path of directory to walk.
topdown: Unused argument to conform to interface.
onerror: Unused argument to conform to interface.
Yields:
An Iterable of Tuples, each of which contain the path of the current
directory path, a list of directories inside the current directory
and a list of files inside the current directory.
"""
# TODO [ENG-153]: Additional params
for (
directory,
subdirectories,
files,
) in self.filesystem.walk(path=top):
yield f"{GCP_PATH_PREFIX}{directory}", subdirectories, files
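The yielded tuples mirror `os.walk`, so a recursive listing can be sketched as follows (prefix hypothetical):

```python
from zenml.io import fileio

# Print every file below a hypothetical prefix, depth-first.
for directory, subdirectories, files in fileio.walk("gs://my-bucket/artifacts"):
    for file_name in files:
        print(f"{directory}/{file_name}")
```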
constants
Constants for the VertexAI integration.
flavors
special
GCP integration flavors.
gcp_artifact_store_flavor
GCP artifact store flavor.
GCPArtifactStoreConfig (BaseArtifactStoreConfig, AuthenticationConfigMixin)
pydantic-model
Configuration for GCP Artifact Store.
Source code in zenml/integrations/gcp/flavors/gcp_artifact_store_flavor.py
class GCPArtifactStoreConfig(
BaseArtifactStoreConfig, AuthenticationConfigMixin
):
"""Configuration for GCP Artifact Store."""
SUPPORTED_SCHEMES: ClassVar[Set[str]] = {GCP_PATH_PREFIX}
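Because `SUPPORTED_SCHEMES` only contains the `gs://` prefix, any other scheme should be rejected when the config is constructed, assuming the base artifact store config validates `path` against `SUPPORTED_SCHEMES` (bucket name hypothetical):

```python
from zenml.integrations.gcp.flavors.gcp_artifact_store_flavor import (
    GCPArtifactStoreConfig,
)

config = GCPArtifactStoreConfig(path="gs://my-bucket/zenml")  # accepted

# Expected to fail validation: "s3://" is not a supported scheme here.
GCPArtifactStoreConfig(path="s3://my-bucket/zenml")
```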
GCPArtifactStoreFlavor (BaseArtifactStoreFlavor)
Flavor of the GCP artifact store.
Source code in zenml/integrations/gcp/flavors/gcp_artifact_store_flavor.py
class GCPArtifactStoreFlavor(BaseArtifactStoreFlavor):
"""Flavor of the GCP artifact store."""
@property
def name(self) -> str:
"""Name of the flavor.
Returns:
The name of the flavor.
"""
return GCP_ARTIFACT_STORE_FLAVOR
@property
def config_class(self) -> Type[GCPArtifactStoreConfig]:
"""Returns GCPArtifactStoreConfig config class.
Returns:
The config class.
"""
return GCPArtifactStoreConfig
@property
def implementation_class(self) -> Type["GCPArtifactStore"]:
"""Implementation class for this flavor.
Returns:
The implementation class.
"""
from zenml.integrations.gcp.artifact_stores import GCPArtifactStore
return GCPArtifactStore
config_class: Type[zenml.integrations.gcp.flavors.gcp_artifact_store_flavor.GCPArtifactStoreConfig]
property
readonly
Returns GCPArtifactStoreConfig config class.
Returns:

Type | Description |
---|---|
`Type[zenml.integrations.gcp.flavors.gcp_artifact_store_flavor.GCPArtifactStoreConfig]` | The config class. |
implementation_class: Type[GCPArtifactStore]
property
readonly
Implementation class for this flavor.
Returns:

Type | Description |
---|---|
`Type[GCPArtifactStore]` | The implementation class. |
name: str
property
readonly
Name of the flavor.
Returns:

Type | Description |
---|---|
`str` | The name of the flavor. |
gcp_secrets_manager_flavor
GCP secrets manager flavor.
GCPSecretsManagerConfig (BaseSecretsManagerConfig)
pydantic-model
Configuration for the GCP Secrets Manager.
Attributes:

Name | Type | Description |
---|---|---|
`project_id` | `str` | This is necessary to access the correct GCP project. The project_id of your GCP project space that contains the Secret Manager. |
Source code in zenml/integrations/gcp/flavors/gcp_secrets_manager_flavor.py
class GCPSecretsManagerConfig(BaseSecretsManagerConfig):
"""Configuration for the GCP Secrets Manager.
Attributes:
project_id: This is necessary to access the correct GCP project.
The project_id of your GCP project space that contains the Secret
Manager.
"""
SUPPORTS_SCOPING: ClassVar[bool] = True
project_id: str
@classmethod
def _validate_scope(
cls,
scope: SecretsManagerScope,
namespace: Optional[str],
) -> None:
"""Validate the scope and namespace value.
Args:
scope: Scope value.
namespace: Optional namespace value.
"""
if namespace:
validate_gcp_secret_name_or_namespace(namespace)
GCPSecretsManagerFlavor (BaseSecretsManagerFlavor)
Class for the `GCPSecretsManagerFlavor`.
Source code in zenml/integrations/gcp/flavors/gcp_secrets_manager_flavor.py
class GCPSecretsManagerFlavor(BaseSecretsManagerFlavor):
"""Class for the `GCPSecretsManagerFlavor`."""
@property
def name(self) -> str:
"""Name of the flavor.
Returns:
The name of the flavor.
"""
return GCP_SECRETS_MANAGER_FLAVOR
@property
def config_class(self) -> Type[GCPSecretsManagerConfig]:
"""Returns GCPSecretsManagerConfig config class.
Returns:
The config class.
"""
return GCPSecretsManagerConfig
@property
def implementation_class(self) -> Type["GCPSecretsManager"]:
"""Implementation class for this flavor.
Returns:
The implementation class.
"""
from zenml.integrations.gcp.secrets_manager import GCPSecretsManager
return GCPSecretsManager
config_class: Type[zenml.integrations.gcp.flavors.gcp_secrets_manager_flavor.GCPSecretsManagerConfig]
property
readonly
Returns GCPSecretsManagerConfig config class.
Returns:

Type | Description |
---|---|
`Type[zenml.integrations.gcp.flavors.gcp_secrets_manager_flavor.GCPSecretsManagerConfig]` | The config class. |
implementation_class: Type[GCPSecretsManager]
property
readonly
Implementation class for this flavor.
Returns:

Type | Description |
---|---|
`Type[GCPSecretsManager]` | The implementation class. |
name: str
property
readonly
Name of the flavor.
Returns:

Type | Description |
---|---|
`str` | The name of the flavor. |
validate_gcp_secret_name_or_namespace(name)
Validate a secret name or namespace.
A Google secret ID is a string with a maximum length of 255 characters and can contain uppercase and lowercase letters, numerals, and the hyphen (-) and underscore (_) characters. For scoped secrets, we have to limit the size of the name and namespace even further to allow space for both in the Google secret ID.
Given that we also save secret names and namespaces as labels, we are further bound by the restrictions that Google imposes on label values: a maximum of 63 characters, containing only lowercase letters, numerals, hyphens (-) and underscores (_).
Parameters:

Name | Type | Description | Default |
---|---|---|---|
`name` | `str` | The secret name or namespace. | required |

Exceptions:

Type | Description |
---|---|
`ValueError` | If the secret name or namespace is invalid. |
Source code in zenml/integrations/gcp/flavors/gcp_secrets_manager_flavor.py
def validate_gcp_secret_name_or_namespace(name: str) -> None:
"""Validate a secret name or namespace.
A Google secret ID is a string with a maximum length of 255 characters
and can contain uppercase and lowercase letters, numerals, and the
hyphen (-) and underscore (_) characters. For scoped secrets, we have to
limit the size of the name and namespace even further to allow space for
both in the Google secret ID.
Given that we also save secret names and namespaces as labels, we are
also limited by the limitation that Google imposes on label values: max
63 characters and must only contain lowercase letters, numerals
and the hyphen (-) and underscore (_) characters
Args:
name: the secret name or namespace
Raises:
ValueError: if the secret name or namespace is invalid
"""
if not re.fullmatch(r"[a-z0-9_\-]+", name):
raise ValueError(
f"Invalid secret name or namespace '{name}'. Must contain "
f"only lowercase alphanumeric characters and the hyphen (-) and "
f"underscore (_) characters."
)
if name and len(name) > 63:
raise ValueError(
f"Invalid secret name or namespace '{name}'. The length is "
f"limited to maximum 63 characters."
)
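For example, following the rules enforced above:

```python
from zenml.integrations.gcp.flavors.gcp_secrets_manager_flavor import (
    validate_gcp_secret_name_or_namespace,
)

validate_gcp_secret_name_or_namespace("training_secrets")  # passes silently

# Raises ValueError: uppercase characters are not allowed.
validate_gcp_secret_name_or_namespace("Training-Secrets")
```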
vertex_orchestrator_flavor
Vertex orchestrator flavor.
VertexOrchestratorConfig (BaseOrchestratorConfig, GoogleCredentialsConfigMixin)
pydantic-model
Configuration for the Vertex orchestrator.
Attributes:

Name | Type | Description |
---|---|---|
`project` | `Optional[str]` | GCP project name. If `None`, the project will be inferred from the environment. |
`location` | `str` | Name of the GCP region where the pipeline job will be executed. Vertex AI Pipelines is available in the following regions: https://cloud.google.com/vertex-ai/docs/general/locations#feature-availability |
`labels` | `Dict[str, str]` | Labels to assign to the pipeline job. |
`pipeline_root` | `Optional[str]` | A Cloud Storage URI that will be used by the Vertex AI Pipelines. If not provided but the artifact store in the stack used to execute the pipeline is a `zenml.integrations.gcp.artifact_stores.GCPArtifactStore`, then a subdirectory of the artifact store will be used. |
`encryption_spec_key_name` | `Optional[str]` | The Cloud KMS resource identifier of the customer managed encryption key used to protect the job. Has the form `projects/<PRJCT>/locations/<REGION>/keyRings/<KR>/cryptoKeys/<KEY>`. The key needs to be in the same region as where the compute resource is created. |
`workload_service_account` | `Optional[str]` | The service account for the workload run-as account. Users submitting jobs must have act-as permission on this run-as account. If not provided, the default service account will be used. |
`network` | `Optional[str]` | The full name of the Compute Engine network to which the job should be peered, for example `projects/12345/global/networks/myVPC`. If not provided, the job will not be peered with any network. |
`synchronous` | `bool` | If `True`, running a pipeline using this orchestrator will block until all steps have finished running on the Vertex AI Pipelines service. |
`cpu_limit` | `Optional[str]` | The maximum CPU limit for this operator. This string value can be a number (integer value for number of CPUs) as string, or a number followed by "m", which means 1/1000. You can specify at most 96 CPUs (see https://cloud.google.com/vertex-ai/docs/pipelines/machine-types). |
`memory_limit` | `Optional[str]` | The maximum memory limit for this operator. This string value can be a number, or a number followed by "K" (kilobyte), "M" (megabyte), or "G" (gigabyte). At most 624GB is supported. |
`node_selector_constraint` | `Optional[Tuple[str, str]]` | Each constraint is a key-value pair label. For the container to be eligible to run on a node, the node must have each of the constraints appear as labels. For example, a GPU type can be provided by one of the following tuples: ("cloud.google.com/gke-accelerator", "NVIDIA_TESLA_A100"), ("cloud.google.com/gke-accelerator", "NVIDIA_TESLA_K80"), ("cloud.google.com/gke-accelerator", "NVIDIA_TESLA_P4"), ("cloud.google.com/gke-accelerator", "NVIDIA_TESLA_P100"), ("cloud.google.com/gke-accelerator", "NVIDIA_TESLA_T4"), ("cloud.google.com/gke-accelerator", "NVIDIA_TESLA_V100"). Hint: the selected region (location) must provide the requested accelerator (see https://cloud.google.com/compute/docs/gpus/gpu-regions-zones). |
`gpu_limit` | `Optional[int]` | The GPU limit (positive number) for the operator. For more information about GPU resources, see https://cloud.google.com/vertex-ai/docs/training/configure-compute#specifying_gpus |
Source code in zenml/integrations/gcp/flavors/vertex_orchestrator_flavor.py
class VertexOrchestratorConfig(
BaseOrchestratorConfig,
GoogleCredentialsConfigMixin,
):
"""Configuration for the Vertex orchestrator.
Attributes:
project: GCP project name. If `None`, the project will be inferred from
the environment.
location: Name of GCP region where the pipeline job will be executed.
Vertex AI Pipelines is available in the following regions:
https://cloud.google.com/vertex-ai/docs/general/locations#feature
-availability
labels: Labels to assign to the pipeline job.
pipeline_root: a Cloud Storage URI that will be used by the Vertex AI
Pipelines. If not provided but the artifact store in the stack used
to execute the pipeline is a
`zenml.integrations.gcp.artifact_stores.GCPArtifactStore`,
then a subdirectory of the artifact store will be used.
encryption_spec_key_name: The Cloud KMS resource identifier of the
customer managed encryption key used to protect the job. Has the form:
`projects/<PRJCT>/locations/<REGION>/keyRings/<KR>/cryptoKeys/<KEY>`
. The key needs to be in the same region as where the compute
resource is created.
workload_service_account: the service account for workload run-as
account. Users submitting jobs must have act-as permission on this
run-as account.
If not provided, the default service account will be used.
network: the full name of the Compute Engine Network to which the job
should be peered. For example, `projects/12345/global/networks/myVPC`
If not provided, the job will not be peered with any network.
synchronous: If `True`, running a pipeline using this orchestrator will
block until all steps finished running on Vertex AI Pipelines
service.
cpu_limit: The maximum CPU limit for this operator. This string value
can be a number (integer value for number of CPUs) as string,
or a number followed by "m", which means 1/1000. You can specify
at most 96 CPUs.
(see. https://cloud.google.com/vertex-ai/docs/pipelines/machine-types)
memory_limit: The maximum memory limit for this operator. This string
value can be a number, or a number followed by "K" (kilobyte),
"M" (megabyte), or "G" (gigabyte). At most 624GB is supported.
node_selector_constraint: Each constraint is a key-value pair label.
For the container to be eligible to run on a node, the node must have
each of the constraints appear as labels.
For example, a GPU type can be provided by one of the following tuples:
- ("cloud.google.com/gke-accelerator", "NVIDIA_TESLA_A100")
- ("cloud.google.com/gke-accelerator", "NVIDIA_TESLA_K80")
- ("cloud.google.com/gke-accelerator", "NVIDIA_TESLA_P4")
- ("cloud.google.com/gke-accelerator", "NVIDIA_TESLA_P100")
- ("cloud.google.com/gke-accelerator", "NVIDIA_TESLA_T4")
- ("cloud.google.com/gke-accelerator", "NVIDIA_TESLA_V100")
Hint: the selected region (location) must provide the requested accelerator
(see https://cloud.google.com/compute/docs/gpus/gpu-regions-zones).
gpu_limit: The GPU limit (positive number) for the operator.
For more information about GPU resources, see:
https://cloud.google.com/vertex-ai/docs/training/configure-compute#specifying_gpus
"""
project: Optional[str] = None
location: str
pipeline_root: Optional[str] = None
labels: Dict[str, str] = {}
encryption_spec_key_name: Optional[str] = None
workload_service_account: Optional[str] = None
network: Optional[str] = None
synchronous: bool = False
cpu_limit: Optional[str] = None
memory_limit: Optional[str] = None
node_selector_constraint: Optional[Tuple[str, str]] = None
gpu_limit: Optional[int] = None
@property
def is_remote(self) -> bool:
"""Checks if this stack component is running remotely.
This designation is used to determine if the stack component can be
used with a local ZenML database or if it requires a remote ZenML
server.
Returns:
True if this config is for a remote component, False otherwise.
"""
return True
is_remote: bool
property
readonly
Checks if this stack component is running remotely.
This designation is used to determine if the stack component can be used with a local ZenML database or if it requires a remote ZenML server.
Returns:

Type | Description |
---|---|
`bool` | True if this config is for a remote component, False otherwise. |
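A configuration sketch with placeholder values; `location` is the only required field shown here, everything else is optional:

```python
from zenml.integrations.gcp.flavors.vertex_orchestrator_flavor import (
    VertexOrchestratorConfig,
)

config = VertexOrchestratorConfig(
    location="europe-west1",  # region where the pipeline job runs
    synchronous=True,         # block until the run finishes
    cpu_limit="4",
    memory_limit="16G",
    gpu_limit=1,
    node_selector_constraint=(
        "cloud.google.com/gke-accelerator",
        "NVIDIA_TESLA_T4",
    ),
)
```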
VertexOrchestratorFlavor (BaseOrchestratorFlavor)
Vertex Orchestrator flavor.
Source code in zenml/integrations/gcp/flavors/vertex_orchestrator_flavor.py
class VertexOrchestratorFlavor(BaseOrchestratorFlavor):
"""Vertex Orchestrator flavor."""
@property
def name(self) -> str:
"""Name of the orchestrator flavor.
Returns:
Name of the orchestrator flavor.
"""
return GCP_VERTEX_ORCHESTRATOR_FLAVOR
@property
def config_class(self) -> Type[VertexOrchestratorConfig]:
"""Returns VertexOrchestratorConfig config class.
Returns:
The config class.
"""
return VertexOrchestratorConfig
@property
def implementation_class(self) -> Type["VertexOrchestrator"]:
"""Implementation class for this flavor.
Returns:
Implementation class for this flavor.
"""
from zenml.integrations.gcp.orchestrators import VertexOrchestrator
return VertexOrchestrator
config_class: Type[zenml.integrations.gcp.flavors.vertex_orchestrator_flavor.VertexOrchestratorConfig]
property
readonly
Returns VertexOrchestratorConfig config class.
Returns:

Type | Description |
---|---|
`Type[zenml.integrations.gcp.flavors.vertex_orchestrator_flavor.VertexOrchestratorConfig]` | The config class. |
implementation_class: Type[VertexOrchestrator]
property
readonly
Implementation class for this flavor.
Returns:

Type | Description |
---|---|
`Type[VertexOrchestrator]` | Implementation class for this flavor. |
name: str
property
readonly
Name of the orchestrator flavor.
Returns:

Type | Description |
---|---|
`str` | Name of the orchestrator flavor. |
vertex_step_operator_flavor
Vertex step operator flavor.
VertexStepOperatorConfig (BaseStepOperatorConfig, GoogleCredentialsConfigMixin)
pydantic-model
Configuration for the Vertex step operator.
Attributes:

Name | Type | Description |
---|---|---|
`region` | `str` | Region name, e.g., `europe-west1`. |
`project` | `Optional[str]` | GCP project name. If left None, inferred from the environment. |
`accelerator_type` | `Optional[str]` | Accelerator type from list: https://cloud.google.com/vertex-ai/docs/reference/rest/v1/MachineSpec#AcceleratorType |
`accelerator_count` | `int` | Defines the number of accelerators to be used for the job. |
`machine_type` | `str` | Machine type specified here: https://cloud.google.com/vertex-ai/docs/training/configure-compute#machine-types |
`encryption_spec_key_name` | `Optional[str]` | Encryption spec key name. |
Source code in zenml/integrations/gcp/flavors/vertex_step_operator_flavor.py
class VertexStepOperatorConfig(
BaseStepOperatorConfig,
GoogleCredentialsConfigMixin,
):
"""Configuration for the Vertex step operator.
Attributes:
region: Region name, e.g., `europe-west1`.
project: GCP project name. If left None, inferred from the
environment.
accelerator_type: Accelerator type from list: https://cloud.google.com/vertex-ai/docs/reference/rest/v1/MachineSpec#AcceleratorType
accelerator_count: Defines number of accelerators to be
used for the job.
machine_type: Machine type specified here: https://cloud.google.com/vertex-ai/docs/training/configure-compute#machine-types
encryption_spec_key_name: Encryption spec key name.
"""
region: str
project: Optional[str] = None
accelerator_type: Optional[str] = None
accelerator_count: int = 0
machine_type: str = "n1-standard-4"
# customer managed encryption key resource name
# will be applied to all Vertex AI resources if set
encryption_spec_key_name: Optional[str] = None
@property
def is_remote(self) -> bool:
"""Checks if this stack component is running remotely.
This designation is used to determine if the stack component can be
used with a local ZenML database or if it requires a remote ZenML
server.
Returns:
True if this config is for a remote component, False otherwise.
"""
return True
is_remote: bool
property
readonly
Checks if this stack component is running remotely.
This designation is used to determine if the stack component can be used with a local ZenML database or if it requires a remote ZenML server.
Returns:

Type | Description |
---|---|
`bool` | True if this config is for a remote component, False otherwise. |
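A configuration sketch with placeholder values (`region` is the only required field):

```python
from zenml.integrations.gcp.flavors.vertex_step_operator_flavor import (
    VertexStepOperatorConfig,
)

config = VertexStepOperatorConfig(
    region="europe-west4",
    machine_type="n1-standard-8",
    accelerator_type="NVIDIA_TESLA_T4",
    accelerator_count=1,
)
```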
VertexStepOperatorFlavor (BaseStepOperatorFlavor)
Vertex Step Operator flavor.
Source code in zenml/integrations/gcp/flavors/vertex_step_operator_flavor.py
class VertexStepOperatorFlavor(BaseStepOperatorFlavor):
"""Vertex Step Operator flavor."""
@property
def name(self) -> str:
"""Name of the flavor.
Returns:
Name of the flavor.
"""
return GCP_VERTEX_STEP_OPERATOR_FLAVOR
@property
def config_class(self) -> Type[VertexStepOperatorConfig]:
"""Returns `VertexStepOperatorConfig` config class.
Returns:
The config class.
"""
return VertexStepOperatorConfig
@property
def implementation_class(self) -> Type["VertexStepOperator"]:
"""Implementation class for this flavor.
Returns:
The implementation class.
"""
from zenml.integrations.gcp.step_operators import VertexStepOperator
return VertexStepOperator
config_class: Type[zenml.integrations.gcp.flavors.vertex_step_operator_flavor.VertexStepOperatorConfig]
property
readonly
Returns `VertexStepOperatorConfig` config class.

Returns:

Type | Description |
---|---|
`Type[zenml.integrations.gcp.flavors.vertex_step_operator_flavor.VertexStepOperatorConfig]` | The config class. |
implementation_class: Type[VertexStepOperator]
property
readonly
Implementation class for this flavor.
Returns:

Type | Description |
---|---|
`Type[VertexStepOperator]` | The implementation class. |
name: str
property
readonly
Name of the flavor.
Returns:

Type | Description |
---|---|
`str` | Name of the flavor. |
google_credentials_mixin
Implementation of the Google credentials mixin.
GoogleCredentialsConfigMixin (StackComponentConfig)
pydantic-model
Config mixin for Google Cloud Platform credentials.
Attributes:

Name | Type | Description |
---|---|---|
`service_account_path` | `Optional[str]` | Path to the service account credentials file to be used for authentication. If not provided, the default credentials will be used. |
Source code in zenml/integrations/gcp/google_credentials_mixin.py
class GoogleCredentialsConfigMixin(StackComponentConfig):
"""Config mixin for Google Cloud Platform credentials.
Attributes:
service_account_path: path to the service account credentials file to be
used for authentication. If not provided, the default credentials
will be used.
"""
service_account_path: Optional[str] = None
GoogleCredentialsMixin (StackComponent)
StackComponent mixin to get Google Cloud Platform credentials.
Source code in zenml/integrations/gcp/google_credentials_mixin.py
class GoogleCredentialsMixin(StackComponent):
"""StackComponent mixin to get Google Cloud Platform credentials."""
@property
def config(self) -> GoogleCredentialsConfigMixin:
"""Returns the `GoogleCredentialsConfigMixin` config.
Returns:
The configuration.
"""
return cast(GoogleCredentialsConfigMixin, self._config)
def _get_authentication(self) -> Tuple["Credentials", str]:
"""Get GCP credentials and the project ID associated with the credentials.
If `service_account_path` is provided, then the credentials will be
loaded from the file at that path. Otherwise, the default credentials
will be used.
Returns:
A tuple containing the credentials and the project ID associated to
the credentials.
"""
if self.config.service_account_path:
credentials, project_id = load_credentials_from_file(
self.config.service_account_path
)
else:
credentials, project_id = default()
return credentials, project_id
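The fallback behind `_get_authentication()` can be sketched standalone with `google-auth` (the file path is hypothetical):

```python
from google.auth import default, load_credentials_from_file

service_account_path = "/path/to/service-account.json"  # hypothetical

if service_account_path:
    # An explicit service account file takes precedence.
    credentials, project_id = load_credentials_from_file(service_account_path)
else:
    # Otherwise fall back to application default credentials.
    credentials, project_id = default()
```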
config: GoogleCredentialsConfigMixin
property
readonly
Returns the `GoogleCredentialsConfigMixin` config.

Returns:

Type | Description |
---|---|
`GoogleCredentialsConfigMixin` | The configuration. |
orchestrators
special
Initialization for the VertexAI orchestrator.
vertex_entrypoint_configuration
Implementation of the VertexAI entrypoint configuration.
VertexEntrypointConfiguration (StepEntrypointConfiguration)
Entrypoint configuration for running steps on Vertex AI Pipelines.
Source code in zenml/integrations/gcp/orchestrators/vertex_entrypoint_configuration.py
class VertexEntrypointConfiguration(StepEntrypointConfiguration):
"""Entrypoint configuration for running steps on Vertex AI Pipelines."""
@classmethod
def get_entrypoint_options(cls) -> Set[str]:
"""Gets all options required for running with this configuration.
Returns:
The superclass options as well as an option for the Vertex job id.
"""
return super().get_entrypoint_options() | {VERTEX_JOB_ID_OPTION}
@classmethod
def get_entrypoint_arguments(
cls,
**kwargs: Any,
) -> List[str]:
"""Gets all arguments that the entrypoint command should be called with.
Args:
**kwargs: Kwargs, must include the Vertex job id.
Returns:
The superclass arguments as well as arguments for the Vertex job id.
"""
return super().get_entrypoint_arguments(**kwargs) + [
f"--{VERTEX_JOB_ID_OPTION}",
kwargs[VERTEX_JOB_ID_OPTION],
]
def get_run_name(self, pipeline_name: str) -> Optional[str]:
"""Returns the Vertex AI Pipeline job id.
Args:
pipeline_name: The name of the pipeline.
Returns:
The Vertex AI Pipeline job id.
"""
job_id: str = self.entrypoint_args[VERTEX_JOB_ID_OPTION]
return job_id
get_entrypoint_arguments(**kwargs)
classmethod
Gets all arguments that the entrypoint command should be called with.
Parameters:

Name | Type | Description | Default |
---|---|---|---|
`**kwargs` | `Any` | Kwargs, must include the Vertex job id. | `{}` |

Returns:

Type | Description |
---|---|
`List[str]` | The superclass arguments as well as arguments for the Vertex job id. |
Source code in zenml/integrations/gcp/orchestrators/vertex_entrypoint_configuration.py
@classmethod
def get_entrypoint_arguments(
cls,
**kwargs: Any,
) -> List[str]:
"""Gets all arguments that the entrypoint command should be called with.
Args:
**kwargs: Kwargs, must include the Vertex job id.
Returns:
The superclass arguments as well as arguments for the Vertex job id.
"""
return super().get_entrypoint_arguments(**kwargs) + [
f"--{VERTEX_JOB_ID_OPTION}",
kwargs[VERTEX_JOB_ID_OPTION],
]
get_entrypoint_options()
classmethod
Gets all options required for running with this configuration.
Returns:

Type | Description |
---|---|
`Set[str]` | The superclass options as well as an option for the Vertex job id. |
Source code in zenml/integrations/gcp/orchestrators/vertex_entrypoint_configuration.py
@classmethod
def get_entrypoint_options(cls) -> Set[str]:
"""Gets all options required for running with this configuration.
Returns:
The superclass options as well as an option for the Vertex job id.
"""
return super().get_entrypoint_options() | {VERTEX_JOB_ID_OPTION}
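The option-merging pattern used here is a plain set union across the class hierarchy; a self-contained sketch with placeholder option names:

```python
class ParentEntrypointConfiguration:
    @classmethod
    def get_entrypoint_options(cls) -> set:
        return {"step_name"}  # placeholder for the superclass options


class ChildEntrypointConfiguration(ParentEntrypointConfiguration):
    @classmethod
    def get_entrypoint_options(cls) -> set:
        # Extend the inherited options with one extra required flag.
        return super().get_entrypoint_options() | {"vertex_job_id"}


# Prints both options, e.g. {'step_name', 'vertex_job_id'} (order may vary).
print(ChildEntrypointConfiguration.get_entrypoint_options())
```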
get_run_name(self, pipeline_name)
Returns the Vertex AI Pipeline job id.
Parameters:

Name | Type | Description | Default |
---|---|---|---|
`pipeline_name` | `str` | The name of the pipeline. | required |

Returns:

Type | Description |
---|---|
`Optional[str]` | The Vertex AI Pipeline job id. |
Source code in zenml/integrations/gcp/orchestrators/vertex_entrypoint_configuration.py
def get_run_name(self, pipeline_name: str) -> Optional[str]:
"""Returns the Vertex AI Pipeline job id.
Args:
pipeline_name: The name of the pipeline.
Returns:
The Vertex AI Pipeline job id.
"""
job_id: str = self.entrypoint_args[VERTEX_JOB_ID_OPTION]
return job_id
vertex_orchestrator
Implementation of the VertexAI orchestrator.
VertexOrchestrator (BaseOrchestrator, GoogleCredentialsMixin)
Orchestrator responsible for running pipelines on Vertex AI.
Source code in zenml/integrations/gcp/orchestrators/vertex_orchestrator.py
class VertexOrchestrator(BaseOrchestrator, GoogleCredentialsMixin):
"""Orchestrator responsible for running pipelines on Vertex AI."""
_pipeline_root: str
@property
def config(self) -> VertexOrchestratorConfig:
"""Returns the `VertexOrchestratorConfig` config.
Returns:
The configuration.
"""
return cast(VertexOrchestratorConfig, self._config)
@property
def validator(self) -> Optional[StackValidator]:
"""Validates that the stack contains a container registry.
Also validates that the artifact store is not local.
Returns:
A StackValidator instance.
"""
def _validate_stack_requirements(stack: "Stack") -> Tuple[bool, str]:
"""Validates that all the stack components are not local.
Args:
stack: The stack to validate.
Returns:
A tuple of (is_valid, error_message).
"""
# Validate that the container registry is not local.
container_registry = stack.container_registry
if container_registry and container_registry.config.is_local:
return False, (
f"The Vertex orchestrator does not support local "
f"container registries. You should replace the component '"
f"{container_registry.name}' "
f"{container_registry.type.value} to a remote one."
)
# Validate that the rest of the components are not local.
for stack_comp in stack.components.values():
# For Forward compatibility a list of components is returned,
# but only the first item is relevant for now
# TODO: [server] make sure the ComponentModel actually has
# a local_path property or implement similar check
local_path = stack_comp.local_path
if not local_path:
continue
return False, (
f"The '{stack_comp.name}' {stack_comp.type.value} is a "
f"local stack component. The Vertex AI Pipelines "
f"orchestrator requires that all the components in the "
f"stack used to execute the pipeline have to be not local, "
f"because there is no way for Vertex to connect to your "
f"local machine. You should use a flavor of "
f"{stack_comp.type.value} other than '"
f"{stack_comp.flavor}'."
)
# If the `pipeline_root` has not been defined in the orchestrator
# configuration, and the artifact store is not a GCP artifact store,
# then raise an error.
if (
not self.config.pipeline_root
and stack.artifact_store.flavor != GCP_ARTIFACT_STORE_FLAVOR
):
return False, (
f"The attribute `pipeline_root` has not been set and it "
f"cannot be generated using the path of the artifact store "
f"because it is not a "
f"`zenml.integrations.gcp.artifact_store.GCPArtifactStore`."
f" To solve this issue, set the `pipeline_root` attribute "
f"manually executing the following command: "
f"`zenml orchestrator update {stack.orchestrator.name} "
f'--pipeline_root="<Cloud Storage URI>"`.'
)
return True, ""
return StackValidator(
required_components={StackComponentType.CONTAINER_REGISTRY},
custom_validation_function=_validate_stack_requirements,
)
@property
def root_directory(self) -> str:
"""Returns path to the root directory for files for this orchestrator.
Returns:
The path to the root directory for all files concerning this
orchestrator.
"""
return os.path.join(
get_global_config_directory(), "vertex", str(self.id)
)
@property
def pipeline_directory(self) -> str:
"""Returns path to directory where kubeflow pipelines files are stored.
Returns:
Path to the pipeline directory.
"""
return os.path.join(self.root_directory, "pipelines")
def prepare_pipeline_deployment(
self,
deployment: "PipelineDeployment",
stack: "Stack",
) -> None:
"""Build a Docker image and push it to the container registry.
Args:
deployment: The pipeline deployment configuration.
stack: The stack on which the pipeline will be deployed.
"""
docker_image_builder = PipelineDockerImageBuilder()
repo_digest = docker_image_builder.build_and_push_docker_image(
deployment=deployment, stack=stack
)
deployment.add_extra(ORCHESTRATOR_DOCKER_IMAGE_KEY, repo_digest)
def _configure_container_resources(
self,
container_op: dsl.ContainerOp,
resource_settings: "ResourceSettings",
) -> None:
"""Adds resource requirements to the container.
Args:
container_op: The kubeflow container operation to configure.
resource_settings: The resource settings to use for this
container.
"""
# Set optional CPU, RAM and GPU constraints for the pipeline
cpu_limit = resource_settings.cpu_count or self.config.cpu_limit
if cpu_limit is not None:
container_op = container_op.set_cpu_limit(str(cpu_limit))
memory_limit = (
resource_settings.memory[:-1]
if resource_settings.memory
else self.config.memory_limit
)
if memory_limit is not None:
container_op = container_op.set_memory_limit(memory_limit)
gpu_limit = resource_settings.gpu_count or self.config.gpu_limit
if gpu_limit is not None:
container_op = container_op.set_gpu_limit(gpu_limit)
if self.config.node_selector_constraint is not None:
constraint_label = self.config.node_selector_constraint[0]
value = self.config.node_selector_constraint[1]
if not (
constraint_label
== GKE_ACCELERATOR_NODE_SELECTOR_CONSTRAINT_LABEL
and gpu_limit == 0
):
container_op.add_node_selector_constraint(
constraint_label, value
)
def prepare_or_run_pipeline(
self,
deployment: "PipelineDeployment",
stack: "Stack",
) -> Any:
"""Creates a KFP JSON pipeline.
# noqa: DAR402
This is an intermediary representation of the pipeline which is then
deployed to Vertex AI Pipelines service.
How it works:
-------------
Before this method is called the `prepare_pipeline_deployment()` method
builds a Docker image that contains the code for the pipeline, all steps,
and the context around these files.
Based on this Docker image a callable is created which builds
container_ops for each step (`_construct_kfp_pipeline`). The function
`kfp.components.load_component_from_text` is used to create the
`ContainerOp`, because using the `dsl.ContainerOp` class directly is
deprecated when using the Kubeflow SDK v2. The step entrypoint command
with the entrypoint arguments is the command that will be executed by
the container created using the previously created Docker image.
This callable is then compiled into a JSON file that is used as the
intermediary representation of the Kubeflow pipeline.
This file then is submitted to the Vertex AI Pipelines service for
execution.
Args:
deployment: The pipeline deployment to prepare or run.
stack: The stack the pipeline will run on.
Raises:
ValueError: If the attribute `pipeline_root` is not set and it
cannot be generated using the path of the artifact store in the
stack because it is not a
`zenml.integrations.gcp.artifact_store.GCPArtifactStore`.
"""
# If the `pipeline_root` has not been defined in the orchestrator
# configuration,
# try to create it from the artifact store if it is a
# `GCPArtifactStore`.
if not self.config.pipeline_root:
artifact_store = stack.artifact_store
self._pipeline_root = f"{artifact_store.path.rstrip('/')}/vertex_pipeline_root/{deployment.pipeline.name}/{deployment.run_name}"
logger.info(
"The attribute `pipeline_root` has not been set in the "
"orchestrator configuration. One has been generated "
"automatically based on the path of the `GCPArtifactStore` "
"artifact store in the stack used to execute the pipeline. "
"The generated `pipeline_root` is `%s`.",
self._pipeline_root,
)
else:
self._pipeline_root = self.config.pipeline_root
if deployment.schedule:
logger.warning(
"Pipeline scheduling configuration was provided, but Vertex "
"AI Pipelines does not support scheduling yet. Creating "
"a one-time run instead."
)
# Build the Docker image that will be used to run the steps of the
# pipeline.
image_name = deployment.pipeline.extra[ORCHESTRATOR_DOCKER_IMAGE_KEY]
def _construct_kfp_pipeline() -> None:
"""Create a `ContainerOp` for each step.
This should contain the name of the Docker image and configures the
entrypoint of the Docker image to run the step.
Additionally, this gives each `ContainerOp` information about its
direct downstream steps.
If this callable is passed to the `compile()` method of
`KFPV2Compiler` all `dsl.ContainerOp` instances will be
automatically added to a singular `dsl.Pipeline` instance.
"""
step_name_to_container_op: Dict[str, dsl.ContainerOp] = {}
for step_name, step in deployment.steps.items():
# The command will be needed to eventually call the python step
# within the docker container
command = VertexEntrypointConfiguration.get_entrypoint_command()
# The arguments are passed to configure the entrypoint of the
# docker container when the step is called.
arguments = VertexEntrypointConfiguration.get_entrypoint_arguments(
step_name=step_name,
**{VERTEX_JOB_ID_OPTION: dslv2.PIPELINE_JOB_ID_PLACEHOLDER},
)
# Create the `ContainerOp` for the step. Using the
# `dsl.ContainerOp`
# class directly is deprecated when using the Kubeflow SDK v2.
container_op = kfp.components.load_component_from_text(
f"""
name: {step.config.name}
implementation:
container:
image: {image_name}
command: {command + arguments}"""
)()
# Set upstream tasks as a dependency of the current step
for upstream_step_name in step.spec.upstream_steps:
upstream_container_op = step_name_to_container_op[
upstream_step_name
]
container_op.after(upstream_container_op)
self._configure_container_resources(
container_op=container_op,
resource_settings=step.config.resource_settings,
)
step_name_to_container_op[step.config.name] = container_op
# Save the generated pipeline to a file.
fileio.makedirs(self.pipeline_directory)
pipeline_file_path = os.path.join(
self.pipeline_directory,
f"{deployment.run_name}.json",
)
# Compile the pipeline using the Kubeflow SDK V2 compiler that allows
# to generate a JSON representation of the pipeline that can be later
# uploaded to the Vertex AI Pipelines service.
logger.debug(
"Compiling pipeline using Kubeflow SDK V2 compiler and saving it "
"to `%s`",
pipeline_file_path,
)
KFPV2Compiler().compile(
pipeline_func=_construct_kfp_pipeline,
package_path=pipeline_file_path,
pipeline_name=_clean_pipeline_name(deployment.pipeline.name),
)
# Using the Google Cloud AIPlatform client, upload and execute the
# pipeline
# on the Vertex AI Pipelines service.
self._upload_and_run_pipeline(
pipeline_name=deployment.pipeline.name,
pipeline_file_path=pipeline_file_path,
run_name=deployment.run_name,
enable_cache=deployment.pipeline.enable_cache,
)
def _upload_and_run_pipeline(
self,
pipeline_name: str,
pipeline_file_path: str,
run_name: str,
enable_cache: bool,
) -> None:
"""Uploads and run the pipeline on the Vertex AI Pipelines service.
Args:
pipeline_name: Name of the pipeline.
pipeline_file_path: Path of the JSON file containing the compiled
Kubeflow pipeline (compiled with Kubeflow SDK v2).
run_name: Name of the pipeline run.
enable_cache: Whether caching is enabled for this pipeline run.
"""
# We have to replace the hyphens in the pipeline name with underscores
# and lower case the string, because the Vertex AI Pipelines service
# requires this format.
job_id = _clean_pipeline_name(run_name)
# Get the credentials that would be used to create the Vertex AI
# Pipelines
# job.
credentials, project_id = self._get_authentication()
if self.config.project and self.config.project != project_id:
logger.warning(
"Authenticated with project `%s`, but this orchestrator is "
"configured to use the project `%s`.",
project_id,
self.config.project,
)
# If the project was set in the configuration, use it. Otherwise, use
# the project that was used to authenticate.
project_id = self.config.project if self.config.project else project_id
# Instantiate the Vertex AI Pipelines job
run = aiplatform.PipelineJob(
display_name=pipeline_name,
template_path=pipeline_file_path,
job_id=job_id,
pipeline_root=self._pipeline_root,
parameter_values=None,
enable_caching=enable_cache,
encryption_spec_key_name=self.config.encryption_spec_key_name,
labels=self.config.labels,
credentials=credentials,
project=self.config.project,
location=self.config.location,
)
logger.info(
"Submitting pipeline job with job_id `%s` to Vertex AI Pipelines "
"service.",
job_id,
)
# Submit the job to Vertex AI Pipelines service.
try:
if self.config.workload_service_account:
logger.info(
"The Vertex AI Pipelines job workload will be executed "
"using `%s` "
"service account.",
self.config.workload_service_account,
)
if self.config.network:
logger.info(
"The Vertex AI Pipelines job will be peered with `%s` "
"network.",
self.config.network,
)
run.submit(
service_account=self.config.workload_service_account,
network=self.config.network,
)
logger.info(
"View the Vertex AI Pipelines job at %s", run._dashboard_uri()
)
if self.config.synchronous:
logger.info(
"Waiting for the Vertex AI Pipelines job to finish..."
)
run.wait()
except google_exceptions.ClientError as e:
logger.warning(
"Failed to create the Vertex AI Pipelines job: %s", e
)
except RuntimeError as e:
logger.error(
"The Vertex AI Pipelines job execution has failed: %s", e
)
config: VertexOrchestratorConfig
property
readonly
Returns the `VertexOrchestratorConfig` config.

Returns:

Type | Description |
---|---|
`VertexOrchestratorConfig` | The configuration. |
pipeline_directory: str
property
readonly
Returns path to directory where kubeflow pipelines files are stored.
Returns:

Type | Description |
---|---|
`str` | Path to the pipeline directory. |
root_directory: str
property
readonly
Returns path to the root directory for files for this orchestrator.
Returns:

Type | Description |
---|---|
`str` | The path to the root directory for all files concerning this orchestrator. |
validator: Optional[zenml.stack.stack_validator.StackValidator]
property
readonly
Validates that the stack contains a container registry.
Also validates that the artifact store is not local.
Returns:

Type | Description |
---|---|
`Optional[zenml.stack.stack_validator.StackValidator]` | A StackValidator instance. |
prepare_or_run_pipeline(self, deployment, stack)
Creates a KFP JSON pipeline.
This is an intermediary representation of the pipeline which is then deployed to the Vertex AI Pipelines service.
How it works:
Before this method is called, the `prepare_pipeline_deployment()` method builds a Docker image that contains the code for the pipeline, all steps, and the context around these files.
Based on this Docker image, a callable is created which builds container_ops for each step (`_construct_kfp_pipeline`). The function `kfp.components.load_component_from_text` is used to create the `ContainerOp`, because using the `dsl.ContainerOp` class directly is deprecated when using the Kubeflow SDK v2. The step entrypoint command with the entrypoint arguments is the command that will be executed by the container created using the previously built Docker image.
This callable is then compiled into a JSON file that is used as the intermediary representation of the Kubeflow pipeline.
This file is then submitted to the Vertex AI Pipelines service for execution.
Parameters:

Name | Type | Description | Default |
---|---|---|---|
`deployment` | `PipelineDeployment` | The pipeline deployment to prepare or run. | required |
`stack` | `Stack` | The stack the pipeline will run on. | required |

Exceptions:

Type | Description |
---|---|
`ValueError` | If the attribute `pipeline_root` is not set and it cannot be generated using the path of the artifact store in the stack because it is not a `zenml.integrations.gcp.artifact_store.GCPArtifactStore`. |
Source code in zenml/integrations/gcp/orchestrators/vertex_orchestrator.py
def prepare_or_run_pipeline(
self,
deployment: "PipelineDeployment",
stack: "Stack",
) -> Any:
"""Creates a KFP JSON pipeline.
# noqa: DAR402
This is an intermediary representation of the pipeline which is then
deployed to Vertex AI Pipelines service.
How it works:
-------------
Before this method is called the `prepare_pipeline_deployment()` method
builds a Docker image that contains the code for the pipeline, all steps,
and the context around these files.
Based on this Docker image a callable is created which builds
container_ops for each step (`_construct_kfp_pipeline`). The function
`kfp.components.load_component_from_text` is used to create the
`ContainerOp`, because using the `dsl.ContainerOp` class directly is
deprecated when using the Kubeflow SDK v2. The step entrypoint command
with the entrypoint arguments is the command that will be executed by
the container created using the previously created Docker image.
This callable is then compiled into a JSON file that is used as the
intermediary representation of the Kubeflow pipeline.
This file then is submitted to the Vertex AI Pipelines service for
execution.
Args:
deployment: The pipeline deployment to prepare or run.
stack: The stack the pipeline will run on.
Raises:
ValueError: If the attribute `pipeline_root` is not set and it
cannot be generated using the path of the artifact store in the
stack because it is not a
`zenml.integrations.gcp.artifact_store.GCPArtifactStore`.
"""
# If the `pipeline_root` has not been defined in the orchestrator
# configuration,
# try to create it from the artifact store if it is a
# `GCPArtifactStore`.
if not self.config.pipeline_root:
artifact_store = stack.artifact_store
self._pipeline_root = f"{artifact_store.path.rstrip('/')}/vertex_pipeline_root/{deployment.pipeline.name}/{deployment.run_name}"
logger.info(
"The attribute `pipeline_root` has not been set in the "
"orchestrator configuration. One has been generated "
"automatically based on the path of the `GCPArtifactStore` "
"artifact store in the stack used to execute the pipeline. "
"The generated `pipeline_root` is `%s`.",
self._pipeline_root,
)
else:
self._pipeline_root = self.config.pipeline_root
if deployment.schedule:
logger.warning(
"Pipeline scheduling configuration was provided, but Vertex "
"AI Pipelines does not support scheduling yet. Creating "
"a one-time run instead."
)
# Build the Docker image that will be used to run the steps of the
# pipeline.
image_name = deployment.pipeline.extra[ORCHESTRATOR_DOCKER_IMAGE_KEY]
def _construct_kfp_pipeline() -> None:
"""Create a `ContainerOp` for each step.
This should contain the name of the Docker image and configures the
entrypoint of the Docker image to run the step.
Additionally, this gives each `ContainerOp` information about its
direct downstream steps.
If this callable is passed to the `compile()` method of
`KFPV2Compiler` all `dsl.ContainerOp` instances will be
automatically added to a singular `dsl.Pipeline` instance.
"""
step_name_to_container_op: Dict[str, dsl.ContainerOp] = {}
for step_name, step in deployment.steps.items():
# The command will be needed to eventually call the python step
# within the docker container
command = VertexEntrypointConfiguration.get_entrypoint_command()
# The arguments are passed to configure the entrypoint of the
# docker container when the step is called.
arguments = VertexEntrypointConfiguration.get_entrypoint_arguments(
step_name=step_name,
**{VERTEX_JOB_ID_OPTION: dslv2.PIPELINE_JOB_ID_PLACEHOLDER},
)
# Create the `ContainerOp` for the step. Using the
# `dsl.ContainerOp`
# class directly is deprecated when using the Kubeflow SDK v2.
container_op = kfp.components.load_component_from_text(
f"""
name: {step.config.name}
implementation:
container:
image: {image_name}
command: {command + arguments}"""
)()
# Set upstream tasks as a dependency of the current step
for upstream_step_name in step.spec.upstream_steps:
upstream_container_op = step_name_to_container_op[
upstream_step_name
]
container_op.after(upstream_container_op)
self._configure_container_resources(
container_op=container_op,
resource_settings=step.config.resource_settings,
)
step_name_to_container_op[step.config.name] = container_op
# Save the generated pipeline to a file.
fileio.makedirs(self.pipeline_directory)
pipeline_file_path = os.path.join(
self.pipeline_directory,
f"{deployment.run_name}.json",
)
# Compile the pipeline using the Kubeflow SDK V2 compiler that allows
# to generate a JSON representation of the pipeline that can be later
# uploaded to the Vertex AI Pipelines service.
logger.debug(
"Compiling pipeline using Kubeflow SDK V2 compiler and saving it "
"to `%s`",
pipeline_file_path,
)
KFPV2Compiler().compile(
pipeline_func=_construct_kfp_pipeline,
package_path=pipeline_file_path,
pipeline_name=_clean_pipeline_name(deployment.pipeline.name),
)
# Using the Google Cloud AIPlatform client, upload and execute the
# pipeline
# on the Vertex AI Pipelines service.
self._upload_and_run_pipeline(
pipeline_name=deployment.pipeline.name,
pipeline_file_path=pipeline_file_path,
run_name=deployment.run_name,
enable_cache=deployment.pipeline.enable_cache,
)
prepare_pipeline_deployment(self, deployment, stack)
Build a Docker image and push it to the container registry.
Parameters:

Name | Type | Description | Default |
---|---|---|---|
`deployment` | `PipelineDeployment` | The pipeline deployment configuration. | required |
`stack` | `Stack` | The stack on which the pipeline will be deployed. | required |
Source code in zenml/integrations/gcp/orchestrators/vertex_orchestrator.py
def prepare_pipeline_deployment(
self,
deployment: "PipelineDeployment",
stack: "Stack",
) -> None:
"""Build a Docker image and push it to the container registry.
Args:
deployment: The pipeline deployment configuration.
stack: The stack on which the pipeline will be deployed.
"""
docker_image_builder = PipelineDockerImageBuilder()
repo_digest = docker_image_builder.build_and_push_docker_image(
deployment=deployment, stack=stack
)
deployment.add_extra(ORCHESTRATOR_DOCKER_IMAGE_KEY, repo_digest)
secrets_manager
special
ZenML integration for GCP Secrets Manager.
The GCP Secrets Manager integration allows your pipeline to access the GCP Secrets Manager directly and use its secrets during runtime.
gcp_secrets_manager
Implementation of the GCP Secrets Manager.
GCPSecretsManager (BaseSecretsManager)
Class to interact with the GCP secrets manager.
Source code in zenml/integrations/gcp/secrets_manager/gcp_secrets_manager.py
class GCPSecretsManager(BaseSecretsManager):
"""Class to interact with the GCP secrets manager."""
CLIENT: ClassVar[Any] = None
@property
def config(self) -> GCPSecretsManagerConfig:
"""Returns the `GCPSecretsManagerConfig` config.
Returns:
The configuration.
"""
return cast(GCPSecretsManagerConfig, self._config)
@classmethod
def _ensure_client_connected(cls) -> None:
if cls.CLIENT is None:
cls.CLIENT = secretmanager.SecretManagerServiceClient()
@property
def parent_name(self) -> str:
"""Construct the GCP parent path to the secret manager.
Returns:
The parent path to the secret manager
"""
return f"projects/{self.config.project_id}"
def _convert_secret_content(
self, secret: BaseSecretSchema
) -> Dict[str, str]:
"""Convert the secret content into a Google compatible representation.
This method implements two currently supported modes of adapting between
the naming schemas used for ZenML secrets and Google secrets:
* for a scoped Secrets Manager, a Google secret is created for each
ZenML secret with a name that reflects the ZenML secret name and scope
and a value that contains all its key-value pairs in JSON format.
* for an unscoped (i.e. legacy) Secrets Manager, this method creates
multiple Google secret entries for a single ZenML secret by adding the
secret name to the key name of each secret key-value pair. This allows
using the same key across multiple secrets. This is only kept for
backwards compatibility and will be removed some time in the future.
Args:
secret: The ZenML secret
Returns:
A dictionary with the Google secret name as key and the secret
contents as value.
"""
if self.config.scope == SecretsManagerScope.NONE:
# legacy per-key secret mapping
return {f"{secret.name}_{k}": v for k, v in secret.content.items()}
return {
self._get_scoped_secret_name(
secret.name, separator=ZENML_GCP_SECRET_SCOPE_PATH_SEPARATOR
): json.dumps(secret_to_dict(secret)),
}
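# Illustration (hypothetical values): a ZenML secret named "wandb" with
# content {"api_key": "abc"} becomes {"wandb_api_key": "abc"} in the
# legacy unscoped mode, while in scoped mode it becomes a single Google
# secret whose scoped name embeds "wandb" and whose value is the JSON
# dump of the whole secret.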
def _get_secret_labels(
self, secret: BaseSecretSchema
) -> List[Tuple[str, str]]:
"""Return a list of Google secret label values for a given secret.
Args:
secret: the secret object
Returns:
A list of Google secret label values
"""
if self.config.scope == SecretsManagerScope.NONE:
# legacy per-key secret labels
return [
(ZENML_GROUP_KEY, secret.name),
(ZENML_SCHEMA_NAME, secret.TYPE),
]
metadata = self._get_secret_metadata(secret)
return list(metadata.items())
def _get_secret_scope_filters(
self,
secret_name: Optional[str] = None,
) -> str:
"""Return a Google filter expression for the entire scope or just a scoped secret.
These filters can be used when querying the Google Secrets Manager
for all secrets or for a single secret available in the configured
scope (see https://cloud.google.com/secret-manager/docs/filtering).
Args:
secret_name: Optional secret name to include in the scope metadata.
Returns:
Google filter expression uniquely identifying all secrets
or a named secret within the configured scope.
"""
if self.config.scope == SecretsManagerScope.NONE:
# legacy per-key secret label filters
if secret_name:
return f"labels.{ZENML_GROUP_KEY}={secret_name}"
else:
return f"labels.{ZENML_GROUP_KEY}:*"
metadata = self._get_secret_scope_metadata(secret_name)
filters = [f"labels.{l}={v}" for (l, v) in metadata.items()]
if secret_name:
filters.append(f"name:{secret_name}")
return " AND ".join(filters)
def _list_secrets(self, secret_name: Optional[str] = None) -> List[str]:
"""List all secrets matching a name.
This method lists all the secrets in the current scope without loading
their contents. An optional secret name can be supplied to filter out
all but a single secret identified by name.
Args:
secret_name: Optional secret name to filter for.
Returns:
A list of secret names in the current scope and the optional
secret name.
"""
self._ensure_client_connected()
set_of_secrets = set()
# List all secrets.
for secret in self.CLIENT.list_secrets(
request={
"parent": self.parent_name,
"filter": self._get_secret_scope_filters(secret_name),
}
):
if self.config.scope == SecretsManagerScope.NONE:
name = secret.labels[ZENML_GROUP_KEY]
else:
name = secret.labels[ZENML_SECRET_NAME_LABEL]
# filter by secret name, if one was given
if name and (not secret_name or name == secret_name):
set_of_secrets.add(name)
return list(set_of_secrets)
def register_secret(self, secret: BaseSecretSchema) -> None:
"""Registers a new secret.
Args:
secret: the secret to register
Raises:
SecretExistsError: if the secret already exists
"""
validate_gcp_secret_name_or_namespace(secret.name)
self._ensure_client_connected()
if self._list_secrets(secret.name):
raise SecretExistsError(
f"A Secret with the name {secret.name} already exists"
)
adjusted_content = self._convert_secret_content(secret)
for k, v in adjusted_content.items():
# Create the secret, this only creates an empty secret with the
# supplied name.
gcp_secret = self.CLIENT.create_secret(
request={
"parent": self.parent_name,
"secret_id": k,
"secret": {
"replication": {"automatic": {}},
"labels": self._get_secret_labels(secret),
},
}
)
logger.debug("Created empty secret: %s", gcp_secret.name)
self.CLIENT.add_secret_version(
request={
"parent": gcp_secret.name,
"payload": {"data": str(v).encode()},
}
)
logger.debug("Added value to secret.")
def get_secret(self, secret_name: str) -> BaseSecretSchema:
"""Get a secret by its name.
Args:
secret_name: the name of the secret to get
Returns:
The secret.
Raises:
KeyError: if the secret does not exist
"""
validate_gcp_secret_name_or_namespace(secret_name)
self._ensure_client_connected()
zenml_secret: Optional[BaseSecretSchema] = None
if self.config.scope == SecretsManagerScope.NONE:
# Legacy secrets are mapped to multiple Google secrets, one for
# each secret key
secret_contents = {}
zenml_schema_name = ""
# List all secrets.
for google_secret in self.CLIENT.list_secrets(
request={
"parent": self.parent_name,
"filter": self._get_secret_scope_filters(secret_name),
}
):
secret_version_name = google_secret.name + "/versions/latest"
response = self.CLIENT.access_secret_version(
request={"name": secret_version_name}
)
secret_value = response.payload.data.decode("UTF-8")
secret_key = remove_group_name_from_key(
google_secret.name.split("/")[-1], secret_name
)
secret_contents[secret_key] = secret_value
zenml_schema_name = google_secret.labels[ZENML_SCHEMA_NAME]
if not secret_contents:
raise KeyError(
f"Can't find the specified secret '{secret_name}'"
)
secret_contents["name"] = secret_name
secret_schema = SecretSchemaClassRegistry.get_class(
secret_schema=zenml_schema_name
)
zenml_secret = secret_schema(**secret_contents)
else:
# Scoped secrets are mapped 1-to-1 with Google secrets
google_secret_name = self.CLIENT.secret_path(
self.config.project_id,
self._get_scoped_secret_name(
secret_name,
separator=ZENML_GCP_SECRET_SCOPE_PATH_SEPARATOR,
),
)
try:
# fetch the latest secret version
google_secret = self.CLIENT.get_secret(name=google_secret_name)
except google_exceptions.NotFound:
raise KeyError(
f"Can't find the specified secret '{secret_name}'"
)
# make sure the secret has the correct scope labels to filter out
# unscoped secrets with similar names
scope_labels = self._get_secret_scope_metadata(secret_name)
# all scope labels need to be included in the google secret labels,
# otherwise the secret does not belong to the current scope
if not scope_labels.items() <= google_secret.labels.items():
raise KeyError(
f"Can't find the specified secret '{secret_name}'"
)
try:
# fetch the latest secret version
response = self.CLIENT.access_secret_version(
name=f"{google_secret_name}/versions/latest"
)
except google_exceptions.NotFound:
raise KeyError(
f"Can't find the specified secret '{secret_name}'"
)
secret_value = response.payload.data.decode("UTF-8")
zenml_secret = secret_from_dict(
json.loads(secret_value), secret_name=secret_name
)
return zenml_secret
def get_all_secret_keys(self) -> List[str]:
"""Get all secret keys.
Returns:
A list of all secret keys
"""
return self._list_secrets()
def update_secret(self, secret: BaseSecretSchema) -> None:
"""Update an existing secret by creating new versions of the existing secrets.
Args:
secret: the secret to update
Raises:
KeyError: if the secret does not exist
"""
validate_gcp_secret_name_or_namespace(secret.name)
self._ensure_client_connected()
if not self._list_secrets(secret.name):
raise KeyError(f"Can't find the specified secret '{secret.name}'")
adjusted_content = self._convert_secret_content(secret)
for k, v in adjusted_content.items():
# Create the secret, this only creates an empty secret with the
# supplied name.
google_secret_name = self.CLIENT.secret_path(
self.config.project_id, k
)
payload = {"data": str(v).encode()}
self.CLIENT.add_secret_version(
request={"parent": google_secret_name, "payload": payload}
)
def delete_secret(self, secret_name: str) -> None:
"""Delete an existing secret by name.
Args:
secret_name: the name of the secret to delete
Raises:
KeyError: if the secret no longer exists
"""
validate_gcp_secret_name_or_namespace(secret_name)
self._ensure_client_connected()
if not self._list_secrets(secret_name):
raise KeyError(f"Can't find the specified secret '{secret_name}'")
# Go through all gcp secrets and delete the ones with the secret_name
# as label.
for secret in self.CLIENT.list_secrets(
request={
"parent": self.parent_name,
"filter": self._get_secret_scope_filters(secret_name),
}
):
self.CLIENT.delete_secret(request={"name": secret.name})
def delete_all_secrets(self) -> None:
"""Delete all existing secrets."""
self._ensure_client_connected()
# List all secrets.
for secret in self.CLIENT.list_secrets(
request={
"parent": self.parent_name,
"filter": self._get_secret_scope_filters(),
}
):
logger.info(f"Deleting Google secret {secret.name}")
self.CLIENT.delete_secret(request={"name": secret.name})
config: GCPSecretsManagerConfig
property
readonly
Returns the `GCPSecretsManagerConfig` config.
Returns:
Type | Description |
---|---|
GCPSecretsManagerConfig | The configuration. |
parent_name: str
property
readonly
Construct the GCP parent path to the secret manager.
Returns:
Type | Description |
---|---|
str | The parent path to the secret manager. |
delete_all_secrets(self)
Delete all existing secrets.
Source code in zenml/integrations/gcp/secrets_manager/gcp_secrets_manager.py
def delete_all_secrets(self) -> None:
"""Delete all existing secrets."""
self._ensure_client_connected()
# List all secrets.
for secret in self.CLIENT.list_secrets(
request={
"parent": self.parent_name,
"filter": self._get_secret_scope_filters(),
}
):
logger.info(f"Deleting Google secret {secret.name}")
self.CLIENT.delete_secret(request={"name": secret.name})
delete_secret(self, secret_name)
Delete an existing secret by name.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
secret_name | str | the name of the secret to delete | required |
Exceptions:
Type | Description |
---|---|
KeyError | if the secret no longer exists |
Source code in zenml/integrations/gcp/secrets_manager/gcp_secrets_manager.py
def delete_secret(self, secret_name: str) -> None:
"""Delete an existing secret by name.
Args:
secret_name: the name of the secret to delete
Raises:
KeyError: if the secret no longer exists
"""
validate_gcp_secret_name_or_namespace(secret_name)
self._ensure_client_connected()
if not self._list_secrets(secret_name):
raise KeyError(f"Can't find the specified secret '{secret_name}'")
# Go through all gcp secrets and delete the ones with the secret_name
# as label.
for secret in self.CLIENT.list_secrets(
request={
"parent": self.parent_name,
"filter": self._get_secret_scope_filters(secret_name),
}
):
self.CLIENT.delete_secret(request={"name": secret.name})
get_all_secret_keys(self)
Get all secret keys.
Returns:
Type | Description |
---|---|
List[str] | A list of all secret keys |
Source code in zenml/integrations/gcp/secrets_manager/gcp_secrets_manager.py
def get_all_secret_keys(self) -> List[str]:
"""Get all secret keys.
Returns:
A list of all secret keys
"""
return self._list_secrets()
get_secret(self, secret_name)
Get a secret by its name.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
secret_name | str | the name of the secret to get | required |
Returns:
Type | Description |
---|---|
BaseSecretSchema | The secret. |
Exceptions:
Type | Description |
---|---|
KeyError | if the secret does not exist |
Source code in zenml/integrations/gcp/secrets_manager/gcp_secrets_manager.py
def get_secret(self, secret_name: str) -> BaseSecretSchema:
"""Get a secret by its name.
Args:
secret_name: the name of the secret to get
Returns:
The secret.
Raises:
KeyError: if the secret does not exist
"""
validate_gcp_secret_name_or_namespace(secret_name)
self._ensure_client_connected()
zenml_secret: Optional[BaseSecretSchema] = None
if self.config.scope == SecretsManagerScope.NONE:
# Legacy secrets are mapped to multiple Google secrets, one for
# each secret key
secret_contents = {}
zenml_schema_name = ""
# List all secrets.
for google_secret in self.CLIENT.list_secrets(
request={
"parent": self.parent_name,
"filter": self._get_secret_scope_filters(secret_name),
}
):
secret_version_name = google_secret.name + "/versions/latest"
response = self.CLIENT.access_secret_version(
request={"name": secret_version_name}
)
secret_value = response.payload.data.decode("UTF-8")
secret_key = remove_group_name_from_key(
google_secret.name.split("/")[-1], secret_name
)
secret_contents[secret_key] = secret_value
zenml_schema_name = google_secret.labels[ZENML_SCHEMA_NAME]
if not secret_contents:
raise KeyError(
f"Can't find the specified secret '{secret_name}'"
)
secret_contents["name"] = secret_name
secret_schema = SecretSchemaClassRegistry.get_class(
secret_schema=zenml_schema_name
)
zenml_secret = secret_schema(**secret_contents)
else:
# Scoped secrets are mapped 1-to-1 with Google secrets
google_secret_name = self.CLIENT.secret_path(
self.config.project_id,
self._get_scoped_secret_name(
secret_name,
separator=ZENML_GCP_SECRET_SCOPE_PATH_SEPARATOR,
),
)
try:
# fetch the latest secret version
google_secret = self.CLIENT.get_secret(name=google_secret_name)
except google_exceptions.NotFound:
raise KeyError(
f"Can't find the specified secret '{secret_name}'"
)
# make sure the secret has the correct scope labels to filter out
# unscoped secrets with similar names
scope_labels = self._get_secret_scope_metadata(secret_name)
# all scope labels need to be included in the google secret labels,
# otherwise the secret does not belong to the current scope
if not scope_labels.items() <= google_secret.labels.items():
raise KeyError(
f"Can't find the specified secret '{secret_name}'"
)
try:
# fetch the latest secret version
response = self.CLIENT.access_secret_version(
name=f"{google_secret_name}/versions/latest"
)
except google_exceptions.NotFound:
raise KeyError(
f"Can't find the specified secret '{secret_name}'"
)
secret_value = response.payload.data.decode("UTF-8")
zenml_secret = secret_from_dict(
json.loads(secret_value), secret_name=secret_name
)
return zenml_secret
register_secret(self, secret)
Registers a new secret.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
secret | BaseSecretSchema | the secret to register | required |
Exceptions:
Type | Description |
---|---|
SecretExistsError | if the secret already exists |
Source code in zenml/integrations/gcp/secrets_manager/gcp_secrets_manager.py
def register_secret(self, secret: BaseSecretSchema) -> None:
"""Registers a new secret.
Args:
secret: the secret to register
Raises:
SecretExistsError: if the secret already exists
"""
validate_gcp_secret_name_or_namespace(secret.name)
self._ensure_client_connected()
if self._list_secrets(secret.name):
raise SecretExistsError(
f"A Secret with the name {secret.name} already exists"
)
adjusted_content = self._convert_secret_content(secret)
for k, v in adjusted_content.items():
# Create the secret, this only creates an empty secret with the
# supplied name.
gcp_secret = self.CLIENT.create_secret(
request={
"parent": self.parent_name,
"secret_id": k,
"secret": {
"replication": {"automatic": {}},
"labels": self._get_secret_labels(secret),
},
}
)
logger.debug("Created empty secret: %s", gcp_secret.name)
self.CLIENT.add_secret_version(
request={
"parent": gcp_secret.name,
"payload": {"data": str(v).encode()},
}
)
logger.debug("Added value to secret.")
update_secret(self, secret)
Update an existing secret by creating new versions of the existing secrets.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
secret | BaseSecretSchema | the secret to update | required |
Exceptions:
Type | Description |
---|---|
KeyError | if the secret does not exist |
Source code in zenml/integrations/gcp/secrets_manager/gcp_secrets_manager.py
def update_secret(self, secret: BaseSecretSchema) -> None:
"""Update an existing secret by creating new versions of the existing secrets.
Args:
secret: the secret to update
Raises:
KeyError: if the secret does not exist
"""
validate_gcp_secret_name_or_namespace(secret.name)
self._ensure_client_connected()
if not self._list_secrets(secret.name):
raise KeyError(f"Can't find the specified secret '{secret.name}'")
adjusted_content = self._convert_secret_content(secret)
for k, v in adjusted_content.items():
# Create the secret, this only creates an empty secret with the
# supplied name.
google_secret_name = self.CLIENT.secret_path(
self.config.project_id, k
)
payload = {"data": str(v).encode()}
self.CLIENT.add_secret_version(
request={"parent": google_secret_name, "payload": payload}
)
remove_group_name_from_key(combined_key_name, group_name)
Removes the secret group name from the secret key.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
combined_key_name | str | Full name as it is within the gcp secrets manager | required |
group_name | str | Group name (the ZenML Secret name) | required |
Returns:
Type | Description |
---|---|
str | The cleaned key |
Exceptions:
Type | Description |
---|---|
RuntimeError | If the group name is not found in the key |
Source code in zenml/integrations/gcp/secrets_manager/gcp_secrets_manager.py
def remove_group_name_from_key(combined_key_name: str, group_name: str) -> str:
"""Removes the secret group name from the secret key.
Args:
combined_key_name: Full name as it is within the gcp secrets manager
group_name: Group name (the ZenML Secret name)
Returns:
The cleaned key
Raises:
RuntimeError: If the group name is not found in the key
"""
if combined_key_name.startswith(group_name + "_"):
return combined_key_name[len(group_name + "_") :]
else:
raise RuntimeError(
f"Key-name `{combined_key_name}` does not have the "
f"prefix `{group_name}`. Key could not be "
f"extracted."
)
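A quick usage example for this helper (the values are hypothetical):

```python
remove_group_name_from_key("wandb_api_key", "wandb")
# -> "api_key"

remove_group_name_from_key("other_api_key", "wandb")
# -> raises RuntimeError: the key does not carry the "wandb_" prefix
```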
step_operators
special
Initialization for the VertexAI Step Operator.
vertex_step_operator
Implementation of a VertexAI step operator.
Code heavily inspired by the TFX implementation: https://github.com/tensorflow/tfx/blob/master/tfx/extensions/google_cloud_ai_platform/training_clients.py
VertexStepOperator (BaseStepOperator, GoogleCredentialsMixin)
Step operator to run a step on Vertex AI.
This class defines code that can set up a Vertex AI environment and run the ZenML entrypoint command in it.
Source code in zenml/integrations/gcp/step_operators/vertex_step_operator.py
class VertexStepOperator(BaseStepOperator, GoogleCredentialsMixin):
"""Step operator to run a step on Vertex AI.
This class defines code that can set up a Vertex AI environment and run the
ZenML entrypoint command in it.
"""
def __init__(self, *args: Any, **kwargs: Any) -> None:
"""Initializes the step operator and validates the accelerator type.
Args:
*args: Variable length argument list.
**kwargs: Arbitrary keyword arguments.
"""
super().__init__(*args, **kwargs)
self._validate_accelerator_type()
def _validate_accelerator_type(self) -> None:
"""Validates that the accelerator type is valid.
Raises:
ValueError: If the accelerator type is not valid.
"""
accepted_vals = list(
aiplatform.gapic.AcceleratorType.__members__.keys()
)
accelerator_type = self.config.accelerator_type
if accelerator_type and accelerator_type.upper() not in accepted_vals:
raise ValueError(
f"Accelerator must be one of the following: {accepted_vals}"
)
@property
def config(self) -> VertexStepOperatorConfig:
"""Returns the `VertexStepOperatorConfig` config.
Returns:
The configuration.
"""
return cast(VertexStepOperatorConfig, self._config)
@property
def validator(self) -> Optional[StackValidator]:
"""Validates the stack.
Returns:
A validator that checks that the stack contains a remote container
registry and a remote artifact store.
"""
def _validate_remote_components(stack: "Stack") -> Tuple[bool, str]:
if stack.artifact_store.config.is_local:
return False, (
"The Vertex step operator runs code remotely and "
"needs to write files into the artifact store, but the "
f"artifact store `{stack.artifact_store.name}` of the "
"active stack is local. Please ensure that your stack "
"contains a remote artifact store when using the Vertex "
"step operator."
)
container_registry = stack.container_registry
assert container_registry is not None
if container_registry.config.is_local:
return False, (
"The Vertex step operator runs code remotely and "
"needs to push/pull Docker images, but the "
f"container registry `{container_registry.name}` of the "
"active stack is local. Please ensure that your stack "
"contains a remote container registry when using the "
"Vertex step operator."
)
return True, ""
return StackValidator(
required_components={StackComponentType.CONTAINER_REGISTRY},
custom_validation_function=_validate_remote_components,
)
def prepare_pipeline_deployment(
self,
deployment: "PipelineDeployment",
stack: "Stack",
) -> None:
"""Build a Docker image and push it to the container registry.
Args:
deployment: The pipeline deployment configuration.
stack: The stack on which the pipeline will be deployed.
"""
steps_to_run = [
step
for step in deployment.steps.values()
if step.config.step_operator == self.name
]
if not steps_to_run:
return
docker_image_builder = PipelineDockerImageBuilder()
image_digest = docker_image_builder.build_and_push_docker_image(
deployment=deployment,
stack=stack,
)
for step in steps_to_run:
step.config.extra[VERTEX_DOCKER_IMAGE_DIGEST_KEY] = image_digest
def launch(
self,
info: "StepRunInfo",
entrypoint_command: List[str],
) -> None:
"""Launches a step on VertexAI.
Args:
info: Information about the step run.
entrypoint_command: Command that executes the step.
Raises:
RuntimeError: If the run fails.
ConnectionError: If the run fails due to a connection error.
"""
resource_settings = info.config.resource_settings
if resource_settings.cpu_count or resource_settings.memory:
logger.warning(
"Specifying cpus or memory is not supported for "
"the Vertex step operator. If you want to run this step "
"operator on specific resources, you can do so by configuring "
"a different machine_type type like this: "
"`zenml step-operator update %s "
"--machine_type=<MACHINE_TYPE>`",
self.name,
)
job_labels = {"source": f"zenml-{__version__.replace('.', '_')}"}
# Step 1: Authenticate with Google
credentials, project_id = self._get_authentication()
if self.config.project:
if self.config.project != project_id:
logger.warning(
"Authenticated with project `%s`, but this orchestrator is "
"configured to use the project `%s`.",
project_id,
self.config.project,
)
else:
self.config.project = project_id
image_name = info.config.extra[VERTEX_DOCKER_IMAGE_DIGEST_KEY]
# Step 3: Launch the job
# The AI Platform services require regional API endpoints.
client_options = {
"api_endpoint": self.config.region + VERTEX_ENDPOINT_SUFFIX
}
# Initialize client that will be used to create and send requests.
# This client only needs to be created once, and can be reused for multiple requests.
client = aiplatform.gapic.JobServiceClient(
credentials=credentials, client_options=client_options
)
accelerator_count = (
resource_settings.gpu_count or self.config.accelerator_count
)
custom_job = {
"display_name": info.run_name,
"job_spec": {
"worker_pool_specs": [
{
"machine_spec": {
"machine_type": self.config.machine_type,
"accelerator_type": self.config.accelerator_type,
"accelerator_count": accelerator_count
if self.config.accelerator_type
else 0,
},
"replica_count": 1,
"container_spec": {
"image_uri": image_name,
"command": entrypoint_command,
"args": [],
},
}
]
},
"labels": job_labels,
"encryption_spec": {
"kmsKeyName": self.config.encryption_spec_key_name
}
if self.config.encryption_spec_key_name
else {},
}
logger.debug("Vertex AI Job=%s", custom_job)
parent = (
f"projects/{self.config.project}/locations/{self.config.region}"
)
logger.info(
"Submitting custom job='%s', path='%s' to Vertex AI Training.",
custom_job["display_name"],
parent,
)
response = client.create_custom_job(
parent=parent, custom_job=custom_job
)
logger.debug("Vertex AI response:", response)
# Step 4: Monitor the job
# Monitors the long-running operation by polling the job state
# periodically, and retries the polling when a transient connectivity
# issue is encountered.
#
# Long-running operation monitoring:
# The possible states of "get job" response can be found at
# https://cloud.google.com/ai-platform/training/docs/reference/rest/v1/projects.jobs#State
# where SUCCEEDED/FAILED/CANCELED are considered to be final states.
# The following logic will keep polling the state of the job until
# the job enters a final state.
#
# During the polling, if a connection error was encountered, the GET
# request will be retried by recreating the Python API client to
# refresh the lifecycle of the connection being used. See
# https://github.com/googleapis/google-api-python-client/issues/218
# for a detailed description of the problem. If the error persists for
# _CONNECTION_ERROR_RETRY_LIMIT consecutive attempts, the function
# will raise ConnectionError.
retry_count = 0
job_id = response.name
while response.state not in VERTEX_JOB_STATES_COMPLETED:
time.sleep(POLLING_INTERVAL_IN_SECONDS)
try:
response = client.get_custom_job(name=job_id)
retry_count = 0
# Handle transient connection error.
except ConnectionError as err:
if retry_count < CONNECTION_ERROR_RETRY_LIMIT:
retry_count += 1
logger.warning(
"ConnectionError (%s) encountered when polling job: "
"%s. Trying to recreate the API client.",
err,
job_id,
)
# Recreate the Python API client.
client = aiplatform.gapic.JobServiceClient(
client_options=client_options
)
else:
logger.error(
"Request failed after %s retries.",
CONNECTION_ERROR_RETRY_LIMIT,
)
raise
if response.state in VERTEX_JOB_STATES_FAILED:
err_msg = (
"Job '{}' did not succeed. Detailed response {}.".format(
job_id, response
)
)
logger.error(err_msg)
raise RuntimeError(err_msg)
# Cloud training complete
logger.info("Job '%s' successful.", job_id)
config: VertexStepOperatorConfig
property
readonly
Returns the `VertexStepOperatorConfig` config.
Returns:
Type | Description |
---|---|
VertexStepOperatorConfig | The configuration. |
validator: Optional[zenml.stack.stack_validator.StackValidator]
property
readonly
Validates the stack.
Returns:
Type | Description |
---|---|
Optional[zenml.stack.stack_validator.StackValidator] | A validator that checks that the stack contains a remote container registry and a remote artifact store. |
__init__(self, *args, **kwargs)
special
Initializes the step operator and validates the accelerator type.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
*args | Any | Variable length argument list. | () |
**kwargs | Any | Arbitrary keyword arguments. | {} |
Source code in zenml/integrations/gcp/step_operators/vertex_step_operator.py
def __init__(self, *args: Any, **kwargs: Any) -> None:
"""Initializes the step operator and validates the accelerator type.
Args:
*args: Variable length argument list.
**kwargs: Arbitrary keyword arguments.
"""
super().__init__(*args, **kwargs)
self._validate_accelerator_type()
launch(self, info, entrypoint_command)
Launches a step on VertexAI.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
info | StepRunInfo | Information about the step run. | required |
entrypoint_command | List[str] | Command that executes the step. | required |
Exceptions:
Type | Description |
---|---|
RuntimeError | If the run fails. |
ConnectionError | If the run fails due to a connection error. |
Source code in zenml/integrations/gcp/step_operators/vertex_step_operator.py
def launch(
self,
info: "StepRunInfo",
entrypoint_command: List[str],
) -> None:
"""Launches a step on VertexAI.
Args:
info: Information about the step run.
entrypoint_command: Command that executes the step.
Raises:
RuntimeError: If the run fails.
ConnectionError: If the run fails due to a connection error.
"""
resource_settings = info.config.resource_settings
if resource_settings.cpu_count or resource_settings.memory:
logger.warning(
"Specifying cpus or memory is not supported for "
"the Vertex step operator. If you want to run this step "
"operator on specific resources, you can do so by configuring "
"a different machine_type type like this: "
"`zenml step-operator update %s "
"--machine_type=<MACHINE_TYPE>`",
self.name,
)
job_labels = {"source": f"zenml-{__version__.replace('.', '_')}"}
# Step 1: Authenticate with Google
credentials, project_id = self._get_authentication()
if self.config.project:
if self.config.project != project_id:
logger.warning(
"Authenticated with project `%s`, but this orchestrator is "
"configured to use the project `%s`.",
project_id,
self.config.project,
)
else:
self.config.project = project_id
image_name = info.config.extra[VERTEX_DOCKER_IMAGE_DIGEST_KEY]
# Step 3: Launch the job
# The AI Platform services require regional API endpoints.
client_options = {
"api_endpoint": self.config.region + VERTEX_ENDPOINT_SUFFIX
}
# Initialize client that will be used to create and send requests.
# This client only needs to be created once, and can be reused for multiple requests.
client = aiplatform.gapic.JobServiceClient(
credentials=credentials, client_options=client_options
)
accelerator_count = (
resource_settings.gpu_count or self.config.accelerator_count
)
custom_job = {
"display_name": info.run_name,
"job_spec": {
"worker_pool_specs": [
{
"machine_spec": {
"machine_type": self.config.machine_type,
"accelerator_type": self.config.accelerator_type,
"accelerator_count": accelerator_count
if self.config.accelerator_type
else 0,
},
"replica_count": 1,
"container_spec": {
"image_uri": image_name,
"command": entrypoint_command,
"args": [],
},
}
]
},
"labels": job_labels,
"encryption_spec": {
"kmsKeyName": self.config.encryption_spec_key_name
}
if self.config.encryption_spec_key_name
else {},
}
logger.debug("Vertex AI Job=%s", custom_job)
parent = (
f"projects/{self.config.project}/locations/{self.config.region}"
)
logger.info(
"Submitting custom job='%s', path='%s' to Vertex AI Training.",
custom_job["display_name"],
parent,
)
response = client.create_custom_job(
parent=parent, custom_job=custom_job
)
logger.debug("Vertex AI response:", response)
# Step 4: Monitor the job
# Monitors the long-running operation by polling the job state
# periodically, and retries the polling when a transient connectivity
# issue is encountered.
#
# Long-running operation monitoring:
# The possible states of "get job" response can be found at
# https://cloud.google.com/ai-platform/training/docs/reference/rest/v1/projects.jobs#State
# where SUCCEEDED/FAILED/CANCELED are considered to be final states.
# The following logic will keep polling the state of the job until
# the job enters a final state.
#
# During the polling, if a connection error was encountered, the GET
# request will be retried by recreating the Python API client to
# refresh the lifecycle of the connection being used. See
# https://github.com/googleapis/google-api-python-client/issues/218
# for a detailed description of the problem. If the error persists for
# _CONNECTION_ERROR_RETRY_LIMIT consecutive attempts, the function
# will raise ConnectionError.
retry_count = 0
job_id = response.name
while response.state not in VERTEX_JOB_STATES_COMPLETED:
time.sleep(POLLING_INTERVAL_IN_SECONDS)
try:
response = client.get_custom_job(name=job_id)
retry_count = 0
# Handle transient connection error.
except ConnectionError as err:
if retry_count < CONNECTION_ERROR_RETRY_LIMIT:
retry_count += 1
logger.warning(
"ConnectionError (%s) encountered when polling job: "
"%s. Trying to recreate the API client.",
err,
job_id,
)
# Recreate the Python API client.
client = aiplatform.gapic.JobServiceClient(
client_options=client_options
)
else:
logger.error(
"Request failed after %s retries.",
CONNECTION_ERROR_RETRY_LIMIT,
)
raise
if response.state in VERTEX_JOB_STATES_FAILED:
err_msg = (
"Job '{}' did not succeed. Detailed response {}.".format(
job_id, response
)
)
logger.error(err_msg)
raise RuntimeError(err_msg)
# Cloud training complete
logger.info("Job '%s' successful.", job_id)
prepare_pipeline_deployment(self, deployment, stack)
Build a Docker image and push it to the container registry.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
deployment | PipelineDeployment | The pipeline deployment configuration. | required |
stack | Stack | The stack on which the pipeline will be deployed. | required |
Source code in zenml/integrations/gcp/step_operators/vertex_step_operator.py
def prepare_pipeline_deployment(
self,
deployment: "PipelineDeployment",
stack: "Stack",
) -> None:
"""Build a Docker image and push it to the container registry.
Args:
deployment: The pipeline deployment configuration.
stack: The stack on which the pipeline will be deployed.
"""
steps_to_run = [
step
for step in deployment.steps.values()
if step.config.step_operator == self.name
]
if not steps_to_run:
return
docker_image_builder = PipelineDockerImageBuilder()
image_digest = docker_image_builder.build_and_push_docker_image(
deployment=deployment,
stack=stack,
)
for step in steps_to_run:
step.config.extra[VERTEX_DOCKER_IMAGE_DIGEST_KEY] = image_digest
github
special
Initialization of the GitHub ZenML integration.
The GitHub integration provides a way to orchestrate pipelines using GitHub Actions.
GitHubIntegration (Integration)
Definition of GitHub integration for ZenML.
Source code in zenml/integrations/github/__init__.py
class GitHubIntegration(Integration):
"""Definition of GitHub integration for ZenML."""
NAME = GITHUB
REQUIREMENTS: List[str] = ["PyNaCl~=1.5.0"]
@classmethod
def flavors(cls) -> List[Type[Flavor]]:
"""Declare the stack component flavors for the GitHub integration.
Returns:
List of stack component flavors for this integration.
"""
from zenml.integrations.github.flavors import (
GitHubActionsOrchestratorFlavor,
GitHubSecretsManagerFlavor,
)
return [GitHubActionsOrchestratorFlavor, GitHubSecretsManagerFlavor]
flavors()
classmethod
Declare the stack component flavors for the GitHub integration.
Returns:
Type | Description |
---|---|
List[Type[zenml.stack.flavor.Flavor]] | List of stack component flavors for this integration. |
Source code in zenml/integrations/github/__init__.py
@classmethod
def flavors(cls) -> List[Type[Flavor]]:
"""Declare the stack component flavors for the GitHub integration.
Returns:
List of stack component flavors for this integration.
"""
from zenml.integrations.github.flavors import (
GitHubActionsOrchestratorFlavor,
GitHubSecretsManagerFlavor,
)
return [GitHubActionsOrchestratorFlavor, GitHubSecretsManagerFlavor]
flavors
special
GitHub integration flavors.
github_actions_orchestrator_flavor
GitHub Actions orchestrator flavor.
GitHubActionsOrchestratorConfig (BaseOrchestratorConfig)
pydantic-model
Configuration for the GitHub Actions orchestrator.
Attributes:
Name | Type | Description |
---|---|---|
skip_dirty_repository_check | bool | If `True`, this orchestrator will not raise an exception when trying to run a pipeline while there are still untracked/uncommitted files in the git repository. |
skip_github_repository_check | bool | If `True`, the orchestrator will not check if your git repository is pointing to a GitHub remote. |
push | bool | If `True`, this orchestrator will automatically commit and push the GitHub workflow file when running a pipeline. If `False`, the workflow file will be written to the correct location but needs to be committed and pushed manually. |
Source code in zenml/integrations/github/flavors/github_actions_orchestrator_flavor.py
class GitHubActionsOrchestratorConfig(BaseOrchestratorConfig):
"""Configuration for the GitHub Actions orchestrator.
Attributes:
skip_dirty_repository_check: If `True`, this orchestrator will not
raise an exception when trying to run a pipeline while there are
still untracked/uncommitted files in the git repository.
skip_github_repository_check: If `True`, the orchestrator will not check
if your git repository is pointing to a GitHub remote.
push: If `True`, this orchestrator will automatically commit and push
the GitHub workflow file when running a pipeline. If `False`, the
workflow file will be written to the correct location but needs to
be committed and pushed manually.
"""
skip_dirty_repository_check: bool = False
skip_github_repository_check: bool = False
push: bool = False
@property
def is_remote(self) -> bool:
"""Checks if this stack component is running remotely.
This designation is used to determine if the stack component can be
used with a local ZenML database or if it requires a remote ZenML
server.
Returns:
True if this config is for a remote component, False otherwise.
"""
return True
is_remote: bool
property
readonly
Checks if this stack component is running remotely.
This designation is used to determine if the stack component can be used with a local ZenML database or if it requires a remote ZenML server.
Returns:
Type | Description |
---|---|
bool | True if this config is for a remote component, False otherwise. |
GitHubActionsOrchestratorFlavor (BaseOrchestratorFlavor)
GitHub Actions orchestrator flavor.
Source code in zenml/integrations/github/flavors/github_actions_orchestrator_flavor.py
class GitHubActionsOrchestratorFlavor(BaseOrchestratorFlavor):
"""GitHub Actions orchestrator flavor."""
@property
def name(self) -> str:
"""Name of the orchestrator flavor.
Returns:
Name of the orchestrator flavor.
"""
return GITHUB_ORCHESTRATOR_FLAVOR
@property
def config_class(self) -> Type[GitHubActionsOrchestratorConfig]:
"""Returns `GitHubActionsOrchestratorConfig` config class.
Returns:
The config class.
"""
return GitHubActionsOrchestratorConfig
@property
def implementation_class(self) -> Type["GitHubActionsOrchestrator"]:
"""Implementation class for this flavor.
Returns:
The implementation class.
"""
from zenml.integrations.github.orchestrators import (
GitHubActionsOrchestrator,
)
return GitHubActionsOrchestrator
config_class: Type[zenml.integrations.github.flavors.github_actions_orchestrator_flavor.GitHubActionsOrchestratorConfig]
property
readonly
Returns the `GitHubActionsOrchestratorConfig` config class.
Returns:
Type | Description |
---|---|
Type[zenml.integrations.github.flavors.github_actions_orchestrator_flavor.GitHubActionsOrchestratorConfig] | The config class. |
implementation_class: Type[GitHubActionsOrchestrator]
property
readonly
Implementation class for this flavor.
Returns:
Type | Description |
---|---|
Type[GitHubActionsOrchestrator] | The implementation class. |
name: str
property
readonly
Name of the orchestrator flavor.
Returns:
Type | Description |
---|---|
str | Name of the orchestrator flavor. |
github_secrets_manager_flavor
GitHub secrets manager flavor.
GitHubSecretsManagerConfig (BaseSecretsManagerConfig)
pydantic-model
The configuration for the GitHub Secrets Manager.
Attributes:
Name | Type | Description |
---|---|---|
owner | str | The owner (either individual or organization) of the repository. |
repository | str | Name of the GitHub repository. |
Source code in zenml/integrations/github/flavors/github_secrets_manager_flavor.py
class GitHubSecretsManagerConfig(BaseSecretsManagerConfig):
"""The configuration for the GitHub Secrets Manager.
Attributes:
owner: The owner (either individual or organization) of the repository.
repository: Name of the GitHub repository.
"""
owner: str
repository: str
GitHubSecretsManagerFlavor (BaseSecretsManagerFlavor)
Class for the `GitHubSecretsManagerFlavor`.
Source code in zenml/integrations/github/flavors/github_secrets_manager_flavor.py
class GitHubSecretsManagerFlavor(BaseSecretsManagerFlavor):
"""Class for the `GitHubSecretsManagerFlavor`."""
@property
def name(self) -> str:
"""Name of the flavor.
Returns:
The name of the flavor.
"""
return GITHUB_SECRET_MANAGER_FLAVOR
@property
def config_class(self) -> Type[GitHubSecretsManagerConfig]:
"""Returns `GitHubSecretsManagerConfig` config class.
Returns:
The config class.
"""
return GitHubSecretsManagerConfig
@property
def implementation_class(self) -> Type["GitHubSecretsManager"]:
"""Implementation class for this flavor.
Returns:
The implementation class.
"""
from zenml.integrations.github.secrets_managers import (
GitHubSecretsManager,
)
return GitHubSecretsManager
config_class: Type[zenml.integrations.github.flavors.github_secrets_manager_flavor.GitHubSecretsManagerConfig]
property
readonly
Returns the `GitHubSecretsManagerConfig` config class.
Returns:
Type | Description |
---|---|
Type[zenml.integrations.github.flavors.github_secrets_manager_flavor.GitHubSecretsManagerConfig] | The config class. |
implementation_class: Type[GitHubSecretsManager]
property
readonly
Implementation class for this flavor.
Returns:
Type | Description |
---|---|
Type[GitHubSecretsManager] | The implementation class. |
name: str
property
readonly
Name of the flavor.
Returns:
Type | Description |
---|---|
str | The name of the flavor. |
orchestrators
special
Initialization of the GitHub Actions Orchestrator.
github_actions_entrypoint_configuration
Implementation of the GitHub Actions Orchestrator entrypoint.
GitHubActionsEntrypointConfiguration (StepEntrypointConfiguration)
Entrypoint configuration for running steps on GitHub Actions runners.
Source code in zenml/integrations/github/orchestrators/github_actions_entrypoint_configuration.py
class GitHubActionsEntrypointConfiguration(StepEntrypointConfiguration):
"""Entrypoint configuration for running steps on GitHub Actions runners."""
@classmethod
def get_entrypoint_options(cls) -> Set[str]:
"""Gets all options required for running with this configuration.
Returns:
The superclass options as well as an option for the run id.
"""
return super().get_entrypoint_options() | {RUN_ID_OPTION}
@classmethod
def get_entrypoint_arguments(cls, **kwargs: Any) -> List[str]:
"""Gets all arguments that the entrypoint command should be called with.
Args:
**kwargs: Kwargs.
Returns:
The superclass arguments as well as arguments for the run id.
"""
# These placeholders in the workflow file will be replaced with
# concrete values by the GitHub Actions runner
run_id = (
"${{ github.run_id }}_${{ github.run_number }}_"
"${{ github.run_attempt }}"
)
return super().get_entrypoint_arguments(**kwargs) + [
f"--{RUN_ID_OPTION}",
run_id,
]
def get_run_name(self, pipeline_name: str) -> Optional[str]:
"""Returns the pipeline run name.
Args:
pipeline_name: Name of the pipeline which will run.
Returns:
The run name.
"""
run_id = cast(str, self.entrypoint_args[RUN_ID_OPTION])
return f"{pipeline_name}-{run_id}"
get_entrypoint_arguments(**kwargs)
classmethod
Gets all arguments that the entrypoint command should be called with.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
**kwargs | Any | Kwargs. | {} |
Returns:
Type | Description |
---|---|
List[str] | The superclass arguments as well as arguments for the run id. |
Source code in zenml/integrations/github/orchestrators/github_actions_entrypoint_configuration.py
@classmethod
def get_entrypoint_arguments(cls, **kwargs: Any) -> List[str]:
"""Gets all arguments that the entrypoint command should be called with.
Args:
**kwargs: Kwargs.
Returns:
The superclass arguments as well as arguments for the run id.
"""
# These placeholders in the workflow file will be replaced with
# concrete values by the GitHub Actions runner
run_id = (
"${{ github.run_id }}_${{ github.run_number }}_"
"${{ github.run_attempt }}"
)
return super().get_entrypoint_arguments(**kwargs) + [
f"--{RUN_ID_OPTION}",
run_id,
]
get_entrypoint_options()
classmethod
Gets all options required for running with this configuration.
Returns:
Type | Description |
---|---|
Set[str] | The superclass options as well as an option for the run id. |
Source code in zenml/integrations/github/orchestrators/github_actions_entrypoint_configuration.py
@classmethod
def get_entrypoint_options(cls) -> Set[str]:
"""Gets all options required for running with this configuration.
Returns:
The superclass options as well as an option for the run id.
"""
return super().get_entrypoint_options() | {RUN_ID_OPTION}
get_run_name(self, pipeline_name)
Returns the pipeline run name.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
pipeline_name | str | Name of the pipeline which will run. | required |
Returns:
Type | Description |
---|---|
Optional[str] | The run name. |
Source code in zenml/integrations/github/orchestrators/github_actions_entrypoint_configuration.py
def get_run_name(self, pipeline_name: str) -> Optional[str]:
"""Returns the pipeline run name.
Args:
pipeline_name: Name of the pipeline which will run.
Returns:
The run name.
"""
run_id = cast(str, self.entrypoint_args[RUN_ID_OPTION])
return f"{pipeline_name}-{run_id}"
github_actions_orchestrator
Implementation of the GitHub Actions Orchestrator.
GitHubActionsOrchestrator (BaseOrchestrator)
Orchestrator responsible for running pipelines using GitHub Actions.
Source code in zenml/integrations/github/orchestrators/github_actions_orchestrator.py
class GitHubActionsOrchestrator(BaseOrchestrator):
"""Orchestrator responsible for running pipelines using GitHub Actions."""
_git_repo: Optional[Repo] = None
@property
def config(self) -> GitHubActionsOrchestratorConfig:
"""Returns the `GitHubActionsOrchestratorConfig` config.
Returns:
The configuration.
"""
return cast(GitHubActionsOrchestratorConfig, self._config)
@property
def git_repo(self) -> Repo:
"""Returns the git repository for the current working directory.
Returns:
Git repository for the current working directory.
Raises:
RuntimeError: If there is no git repository for the current working
directory or the repository remote is not pointing to GitHub.
"""
if not self._git_repo:
try:
self._git_repo = Repo(search_parent_directories=True)
except InvalidGitRepositoryError:
raise RuntimeError(
"Unable to find git repository in current working "
f"directory {os.getcwd()} or its parent directories."
)
remote_url = self.git_repo.remote().url
is_github_repo = any(
remote_url.startswith(prefix)
for prefix in GITHUB_REMOTE_URL_PREFIXES
)
if not (is_github_repo or self.config.skip_github_repository_check):
raise RuntimeError(
f"The remote URL '{remote_url}' of your git repo "
f"({self._git_repo.git_dir}) is not pointing to a GitHub "
"repository. The GitHub Actions orchestrator runs "
"pipelines using GitHub Actions and therefore only works "
"with GitHub repositories. If you want to skip this check "
"and run this orchestrator anyway, run: \n"
f"`zenml orchestrator update {self.name} "
"--skip_github_repository_check=true`"
)
return self._git_repo
@property
def workflow_directory(self) -> str:
"""Returns path to the GitHub workflows directory.
Returns:
The GitHub workflows directory.
"""
assert self.git_repo.working_dir
return os.path.join(self.git_repo.working_dir, ".github", "workflows")
@property
def validator(self) -> Optional[StackValidator]:
"""Validator that ensures that the stack is compatible.
Makes sure that the stack contains a container registry and only
remote components.
Returns:
The stack validator.
"""
def _validate_local_requirements(stack: "Stack") -> Tuple[bool, str]:
container_registry = stack.container_registry
assert container_registry is not None
if container_registry.config.is_local:
return False, (
"The GitHub Actions orchestrator requires a remote "
f"container registry, but the '{container_registry.name}' "
"container registry of your active stack points to a local "
f"URI '{container_registry.config.uri}'. Please make sure "
"stacks with a GitHub Actions orchestrator always contain "
"remote container registries."
)
if container_registry.requires_authentication:
return False, (
"The GitHub Actions orchestrator currently only works with "
"GitHub container registries or public container "
f"registries, but your {container_registry.flavor} "
f"container registry '{container_registry.name}' requires "
"authentication."
)
for component in stack.components.values():
if component.local_path is not None:
return False, (
"The GitHub Actions orchestrator runs pipelines on "
"remote GitHub Actions runners, but the "
f"'{component.name}' {component.type.value} of your "
"active stack is a local component. Please make sure "
"to only use remote stack components in combination "
"with the GitHub Actions orchestrator. "
)
return True, ""
return StackValidator(
required_components={StackComponentType.CONTAINER_REGISTRY},
custom_validation_function=_validate_local_requirements,
)
def _docker_login_step(
self,
container_registry: BaseContainerRegistry,
) -> Optional[Dict[str, Any]]:
"""GitHub Actions step for authenticating with the container registry.
Args:
container_registry: The container registry which (potentially)
requires a step to authenticate.
Returns:
Dictionary specifying the GitHub Actions step for authenticating
with the container registry if that is required, `None` otherwise.
"""
if (
isinstance(container_registry, GitHubContainerRegistryFlavor)
and container_registry.config.automatic_token_authentication
):
# Use GitHub Actions specific placeholder if the container registry
# specifies automatic token authentication
username = "${{ github.actor }}"
password = "${{ secrets.GITHUB_TOKEN }}"
# TODO: Uncomment these lines once we support different private
# container registries in GitHub Actions
# elif container_registry.requires_authentication:
# username = cast(str, container_registry.username)
# password = cast(str, container_registry.password)
else:
return None
return {
"name": "Authenticate with the container registry",
"uses": DOCKER_LOGIN_ACTION,
"with": {
"registry": container_registry.uri,
"username": username,
"password": password,
},
}
def _write_environment_file_step(
self,
file_name: str,
secrets_manager: Optional[BaseSecretsManager] = None,
) -> Optional[Dict[str, Any]]:
"""GitHub Actions step for writing secrets to an environment file.
Args:
file_name: Name of the environment file that should be written.
secrets_manager: Secrets manager that will be used to read secrets
during pipeline execution.
Returns:
Dictionary specifying the GitHub Actions step for writing the
environment file.
"""
if not isinstance(secrets_manager, GitHubSecretsManager):
return None
# Always include the environment variable that specifies whether
# we're running in a GitHub Action workflow so the secret manager knows
# how to query secret values
command = (
f'echo {ENV_IN_GITHUB_ACTIONS}="${ENV_IN_GITHUB_ACTIONS}" '
f"> {file_name}; "
)
# Write all ZenML secrets into the environment file. Explicitly writing
# these `${{ secrets.<SECRET_NAME> }}` placeholders into the workflow
# yaml is the only way for us to access the GitHub secrets in a GitHub
# Actions workflow.
append_secret_placeholder = (
"echo {secret_name}=${{{{ secrets.{secret_name} }}}} >> {file}; "
)
for secret_name in secrets_manager.get_all_secret_keys(
include_prefix=True
):
command += append_secret_placeholder.format(
secret_name=secret_name, file=file_name
)
return {
"name": "Write environment file",
"run": command,
}
def prepare_pipeline_deployment(
self,
deployment: "PipelineDeployment",
stack: "Stack",
) -> None:
"""Build a Docker image and push it to the container registry.
Args:
deployment: The pipeline deployment configuration.
stack: The stack on which the pipeline will be deployed.
Raises:
RuntimeError: If the orchestrator should only run in a clean git
repository and the repository is dirty.
"""
if (
not self.config.skip_dirty_repository_check
and self.git_repo.is_dirty(untracked_files=True)
):
raise RuntimeError(
"Trying to run a pipeline from within a dirty (=containing "
"untracked/uncommitted files) git repository."
"If you want this orchestrator to skip the dirty repo check in "
f"the future, run\n `zenml orchestrator update {self.name} "
"--skip_dirty_repository_check=true`"
)
docker_image_builder = PipelineDockerImageBuilder()
repo_digest = docker_image_builder.build_and_push_docker_image(
deployment=deployment, stack=stack
)
deployment.add_extra(ORCHESTRATOR_DOCKER_IMAGE_KEY, repo_digest)
def prepare_or_run_pipeline(
self,
deployment: "PipelineDeployment",
stack: "Stack",
) -> Any:
"""Writes a GitHub Action workflow yaml and optionally pushes it.
Args:
deployment: The pipeline deployment to prepare or run.
stack: The stack the pipeline will run on.
Raises:
ValueError: If a schedule without a cron expression or with an
invalid cron expression is passed.
"""
schedule = deployment.schedule
workflow_name = deployment.pipeline.name
if schedule:
# Add a suffix to the workflow filename so we don't overwrite
# scheduled pipeline by future schedules or single pipeline runs.
datetime_string = datetime.now().strftime("%y_%m_%d_%H_%M_%S")
workflow_name += f"-scheduled-{datetime_string}"
workflow_path = os.path.join(
self.workflow_directory,
f"{workflow_name}.yaml",
)
workflow_dict: Dict[str, Any] = {
"name": workflow_name,
}
if schedule:
if not schedule.cron_expression:
raise ValueError(
"GitHub Action workflows can only be scheduled using cron "
"expressions and not using a periodic schedule. If you "
"want to schedule pipelines using this GitHub Action "
"orchestrator, please include a cron expression in your "
"schedule object. For more information on GitHub workflow "
"schedules check out https://docs.github.com/en/actions/using-workflows/events-that-trigger-workflows#schedule."
)
# GitHub workflows requires a schedule interval of at least 5
# minutes. Invalid cron expressions would be something like
# `*/3 * * * *` (all stars except the first part of the expression,
# which will have the format `*/minute_interval`)
if re.fullmatch(r"\*/[1-4]( \*){4,}", schedule.cron_expression):
raise ValueError(
"GitHub workflows requires a schedule interval of at "
"least 5 minutes which is incompatible with your cron "
f"expression '{schedule.cron_expression}'. An example of a "
"valid cron expression would be '* 1 * * *' to run "
"every hour. For more information on GitHub workflow "
"schedules check out https://docs.github.com/en/actions/using-workflows/events-that-trigger-workflows#schedule."
)
logger.warning(
"GitHub only runs scheduled workflows once the "
"workflow file is merged to the default branch of the "
"repository (https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/about-branches#about-the-default-branch). "
"Please make sure to merge your current branch into the "
"default branch for this scheduled pipeline to run."
)
workflow_dict["on"] = {
"schedule": [{"cron": schedule.cron_expression}]
}
else:
# The pipeline should only run once. The only fool-proof way to
# only execute a workflow once seems to be running on specific tags.
# We don't want to create tags for each pipeline run though, so
# instead we only run this workflow if the workflow file is
# modified. As long as users don't manually modify these files this
# should be sufficient.
workflow_path_in_repo = os.path.relpath(
workflow_path, self.git_repo.working_dir
)
workflow_dict["on"] = {"push": {"paths": [workflow_path_in_repo]}}
image_name = deployment.pipeline.extra[ORCHESTRATOR_DOCKER_IMAGE_KEY]
# Prepare the step that writes an environment file which will get
# passed to the docker image
env_file_name = ".zenml_docker_env"
write_env_file_step = self._write_environment_file_step(
file_name=env_file_name, secrets_manager=stack.secrets_manager
)
docker_run_args = (
["--env-file", env_file_name] if write_env_file_step else []
)
# Prepare the docker login step if necessary
container_registry = stack.container_registry
assert container_registry
docker_login_step = self._docker_login_step(container_registry)
# The base command that each job will execute with specific arguments
base_command = [
"docker",
"run",
*docker_run_args,
image_name,
] + GitHubActionsEntrypointConfiguration.get_entrypoint_command()
jobs = {}
for step_name, step in deployment.steps.items():
if self.requires_resources_in_orchestration_environment(step):
logger.warning(
"Specifying step resources is not supported for the "
"GitHub Actions orchestrator, ignoring resource "
"configuration for step %s.",
step.config.name,
)
job_steps = []
# Copy the shared dicts here to avoid creating yaml anchors (which
# are currently not supported in GitHub workflow yaml files)
if write_env_file_step:
job_steps.append(copy.deepcopy(write_env_file_step))
if docker_login_step:
job_steps.append(copy.deepcopy(docker_login_step))
entrypoint_args = (
GitHubActionsEntrypointConfiguration.get_entrypoint_arguments(
step_name=step_name,
)
)
command = base_command + entrypoint_args
docker_run_step = {
"name": "Run the docker image",
"run": " ".join(command),
}
job_steps.append(docker_run_step)
job_dict = {
"runs-on": "ubuntu-latest",
"needs": step.spec.upstream_steps,
"steps": job_steps,
}
jobs[step.config.name] = job_dict
workflow_dict["jobs"] = jobs
fileio.makedirs(self.workflow_directory)
yaml_utils.write_yaml(workflow_path, workflow_dict, sort_keys=False)
logger.info("Wrote GitHub workflow file to %s", workflow_path)
if self.config.push:
# Add, commit and push the pipeline workflow yaml
self.git_repo.index.add(workflow_path)
self.git_repo.index.commit(
"[ZenML GitHub Actions Orchestrator] Add github workflow for "
f"pipeline {deployment.pipeline.name}."
)
self.git_repo.remote().push()
logger.info("Pushed workflow file '%s'", workflow_path)
else:
logger.info(
"Automatically committing and pushing is disabled for this "
"orchestrator. To run the pipeline, you'll have to commit and "
"push the workflow file '%s' manually.\n"
"If you want to update this orchestrator to automatically "
"commit and push in the future, run "
"`zenml orchestrator update %s --push=true`",
workflow_path,
self.name,
)
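To make the generated output concrete, here is a sketch of the shape of the workflow dictionary that `prepare_or_run_pipeline` writes out as YAML (the pipeline, step, path, and image names are hypothetical), together with a worked example of the minimum-interval cron check:

```python
import re

# Illustrative shape of the workflow written for a single (unscheduled) run.
workflow_dict = {
    "name": "my_pipeline",
    # Single run: triggered only when the workflow file itself is pushed.
    "on": {"push": {"paths": [".github/workflows/my_pipeline.yaml"]}},
    "jobs": {
        "my_step": {
            "runs-on": "ubuntu-latest",
            "needs": [],  # names of upstream steps
            "steps": [
                {
                    "name": "Run the docker image",
                    "run": "docker run <image> <entrypoint arguments>",
                }
            ],
        }
    },
}

# The cron guard rejects only sub-5-minute intervals of the form "*/1".."*/4":
pattern = r"\*/[1-4]( \*){4,}"
re.fullmatch(pattern, "*/3 * * * *")   # match    -> rejected (< 5 minutes)
re.fullmatch(pattern, "*/10 * * * *")  # no match -> accepted
re.fullmatch(pattern, "0 * * * *")     # no match -> accepted (hourly)
```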
config: GitHubActionsOrchestratorConfig
property
readonly
Returns the `GitHubActionsOrchestratorConfig` config.
Returns:
Type | Description |
---|---|
GitHubActionsOrchestratorConfig | The configuration. |
git_repo: Repo
property
readonly
Returns the git repository for the current working directory.
Returns:
Type | Description |
---|---|
Repo | Git repository for the current working directory. |
Exceptions:
Type | Description |
---|---|
RuntimeError | If there is no git repository for the current working directory or the repository remote is not pointing to GitHub. |
validator: Optional[zenml.stack.stack_validator.StackValidator]
property
readonly
Validator that ensures that the stack is compatible.
Makes sure that the stack contains a container registry and only remote components.
Returns:
Type | Description |
---|---|
Optional[zenml.stack.stack_validator.StackValidator] | The stack validator. |
workflow_directory: str
property
readonly
Returns path to the GitHub workflows directory.
Returns:
Type | Description |
---|---|
str | The GitHub workflows directory. |
prepare_or_run_pipeline(self, deployment, stack)
Writes a GitHub Action workflow yaml and optionally pushes it.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
deployment | PipelineDeployment | The pipeline deployment to prepare or run. | required |
stack | Stack | The stack the pipeline will run on. | required |
Exceptions:
Type | Description |
---|---|
ValueError | If a schedule without a cron expression or with an invalid cron expression is passed. |
Source code in zenml/integrations/github/orchestrators/github_actions_orchestrator.py
def prepare_or_run_pipeline(
self,
deployment: "PipelineDeployment",
stack: "Stack",
) -> Any:
"""Writes a GitHub Action workflow yaml and optionally pushes it.
Args:
deployment: The pipeline deployment to prepare or run.
stack: The stack the pipeline will run on.
Raises:
ValueError: If a schedule without a cron expression or with an
invalid cron expression is passed.
"""
schedule = deployment.schedule
workflow_name = deployment.pipeline.name
if schedule:
# Add a suffix to the workflow filename so we don't overwrite a
# scheduled pipeline with future schedules or single pipeline runs.
datetime_string = datetime.now().strftime("%y_%m_%d_%H_%M_%S")
workflow_name += f"-scheduled-{datetime_string}"
workflow_path = os.path.join(
self.workflow_directory,
f"{workflow_name}.yaml",
)
workflow_dict: Dict[str, Any] = {
"name": workflow_name,
}
if schedule:
if not schedule.cron_expression:
raise ValueError(
"GitHub Action workflows can only be scheduled using cron "
"expressions and not using a periodic schedule. If you "
"want to schedule pipelines using this GitHub Action "
"orchestrator, please include a cron expression in your "
"schedule object. For more information on GitHub workflow "
"schedules check out https://docs.github.com/en/actions/using-workflows/events-that-trigger-workflows#schedule."
)
# GitHub workflows require a schedule interval of at least 5
# minutes. Invalid cron expressions would be something like
# `*/3 * * * *` (all stars except the first part of the expression,
# which will have the format `*/minute_interval`)
if re.fullmatch(r"\*/[1-4]( \*){4,}", schedule.cron_expression):
raise ValueError(
"GitHub workflows requires a schedule interval of at "
"least 5 minutes which is incompatible with your cron "
f"expression '{schedule.cron_expression}'. An example of a "
"valid cron expression would be '* 1 * * *' to run "
"every hour. For more information on GitHub workflow "
"schedules check out https://docs.github.com/en/actions/using-workflows/events-that-trigger-workflows#schedule."
)
logger.warning(
"GitHub only runs scheduled workflows once the "
"workflow file is merged to the default branch of the "
"repository (https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/about-branches#about-the-default-branch). "
"Please make sure to merge your current branch into the "
"default branch for this scheduled pipeline to run."
)
workflow_dict["on"] = {
"schedule": [{"cron": schedule.cron_expression}]
}
else:
# The pipeline should only run once. The only foolproof way to
# execute a workflow exactly once seems to be triggering it on specific tags.
# We don't want to create tags for each pipeline run though, so
# instead we only run this workflow if the workflow file is
# modified. As long as users don't manually modify these files this
# should be sufficient.
workflow_path_in_repo = os.path.relpath(
workflow_path, self.git_repo.working_dir
)
workflow_dict["on"] = {"push": {"paths": [workflow_path_in_repo]}}
image_name = deployment.pipeline.extra[ORCHESTRATOR_DOCKER_IMAGE_KEY]
# Prepare the step that writes an environment file which will get
# passed to the docker image
env_file_name = ".zenml_docker_env"
write_env_file_step = self._write_environment_file_step(
file_name=env_file_name, secrets_manager=stack.secrets_manager
)
docker_run_args = (
["--env-file", env_file_name] if write_env_file_step else []
)
# Prepare the docker login step if necessary
container_registry = stack.container_registry
assert container_registry
docker_login_step = self._docker_login_step(container_registry)
# The base command that each job will execute with specific arguments
base_command = [
"docker",
"run",
*docker_run_args,
image_name,
] + GitHubActionsEntrypointConfiguration.get_entrypoint_command()
jobs = {}
for step_name, step in deployment.steps.items():
if self.requires_resources_in_orchestration_environment(step):
logger.warning(
"Specifying step resources is not supported for the "
"GitHub Actions orchestrator, ignoring resource "
"configuration for step %s.",
step.config.name,
)
job_steps = []
# Copy the shared dicts here to avoid creating yaml anchors (which
# are currently not supported in GitHub workflow yaml files)
if write_env_file_step:
job_steps.append(copy.deepcopy(write_env_file_step))
if docker_login_step:
job_steps.append(copy.deepcopy(docker_login_step))
entrypoint_args = (
GitHubActionsEntrypointConfiguration.get_entrypoint_arguments(
step_name=step_name,
)
)
command = base_command + entrypoint_args
docker_run_step = {
"name": "Run the docker image",
"run": " ".join(command),
}
job_steps.append(docker_run_step)
job_dict = {
"runs-on": "ubuntu-latest",
"needs": step.spec.upstream_steps,
"steps": job_steps,
}
jobs[step.config.name] = job_dict
workflow_dict["jobs"] = jobs
fileio.makedirs(self.workflow_directory)
yaml_utils.write_yaml(workflow_path, workflow_dict, sort_keys=False)
logger.info("Wrote GitHub workflow file to %s", workflow_path)
if self.config.push:
# Add, commit and push the pipeline workflow yaml
self.git_repo.index.add(workflow_path)
self.git_repo.index.commit(
"[ZenML GitHub Actions Orchestrator] Add github workflow for "
f"pipeline {deployment.pipeline.name}."
)
self.git_repo.remote().push()
logger.info("Pushed workflow file '%s'", workflow_path)
else:
logger.info(
"Automatically committing and pushing is disabled for this "
"orchestrator. To run the pipeline, you'll have to commit and "
"push the workflow file '%s' manually.\n"
"If you want to update this orchestrator to automatically "
"commit and push in the future, run "
"`zenml orchestrator update %s --push=true`",
workflow_path,
self.name,
)
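To make the output concrete, here is a rough sketch of the `workflow_dict` this method could produce for an unscheduled two-step pipeline. All names, the image reference, and the entrypoint arguments are hypothetical; the real values depend on your pipeline and stack:
# Hypothetical workflow dict for a pipeline "my_pipeline" with steps
# "importer" and "trainer" (written to .github/workflows/my_pipeline.yaml):
workflow_dict = {
    "name": "my_pipeline",
    # Unscheduled runs trigger only when this specific file is pushed.
    "on": {"push": {"paths": [".github/workflows/my_pipeline.yaml"]}},
    "jobs": {
        "importer": {
            "runs-on": "ubuntu-latest",
            "needs": [],
            "steps": [
                {
                    "name": "Run the docker image",
                    "run": "docker run --env-file .zenml_docker_env "
                    "<image@sha256:digest> <entrypoint command and args>",
                }
            ],
        },
        "trainer": {
            "runs-on": "ubuntu-latest",
            # GitHub starts this job only after "importer" has finished.
            "needs": ["importer"],
            "steps": [
                {
                    "name": "Run the docker image",
                    "run": "docker run --env-file .zenml_docker_env "
                    "<image@sha256:digest> <entrypoint command and args>",
                }
            ],
        },
    },
}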
prepare_pipeline_deployment(self, deployment, stack)
Build a Docker image and push it to the container registry.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
deployment | PipelineDeployment | The pipeline deployment configuration. | required |
stack | Stack | The stack on which the pipeline will be deployed. | required |
Exceptions:
Type | Description |
---|---|
RuntimeError | If the orchestrator should only run in a clean git repository and the repository is dirty. |
Source code in zenml/integrations/github/orchestrators/github_actions_orchestrator.py
def prepare_pipeline_deployment(
self,
deployment: "PipelineDeployment",
stack: "Stack",
) -> None:
"""Build a Docker image and push it to the container registry.
Args:
deployment: The pipeline deployment configuration.
stack: The stack on which the pipeline will be deployed.
Raises:
RuntimeError: If the orchestrator should only run in a clean git
repository and the repository is dirty.
"""
if (
not self.config.skip_dirty_repository_check
and self.git_repo.is_dirty(untracked_files=True)
):
raise RuntimeError(
"Trying to run a pipeline from within a dirty (=containing "
"untracked/uncommitted files) git repository."
"If you want this orchestrator to skip the dirty repo check in "
f"the future, run\n `zenml orchestrator update {self.name} "
"--skip_dirty_repository_check=true`"
)
docker_image_builder = PipelineDockerImageBuilder()
repo_digest = docker_image_builder.build_and_push_docker_image(
deployment=deployment, stack=stack
)
deployment.add_extra(ORCHESTRATOR_DOCKER_IMAGE_KEY, repo_digest)
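The dirty-repository guard is plain GitPython; a standalone sketch of the same check, assuming it runs from inside a clone:
from git import Repo

repo = Repo(".", search_parent_directories=True)
# untracked_files=True makes never-committed files count as dirty too,
# mirroring the check in prepare_pipeline_deployment above.
if repo.is_dirty(untracked_files=True):
    raise RuntimeError("Commit or stash your changes before running the pipeline.")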
secrets_managers
special
Initialization of the GitHub Secrets Manager.
github_secrets_manager
Implementation of the GitHub Secrets Manager.
GitHubSecretsManager (BaseSecretsManager)
Class to interact with the GitHub secrets manager.
Source code in zenml/integrations/github/secrets_managers/github_secrets_manager.py
class GitHubSecretsManager(BaseSecretsManager):
"""Class to interact with the GitHub secrets manager."""
_session: Optional[requests.Session] = None
@property
def config(self) -> GitHubSecretsManagerConfig:
"""Returns the `GitHubSecretsManagerConfig` config.
Returns:
The configuration.
"""
return cast(GitHubSecretsManagerConfig, self._config)
@property
def post_registration_message(self) -> Optional[str]:
"""Info message regarding GitHub API authentication env variables.
Returns:
The info message.
"""
return AUTHENTICATION_CREDENTIALS_MESSAGE
@property
def session(self) -> requests.Session:
"""Session to send requests to the GitHub API.
Returns:
Session to use for GitHub API calls.
Raises:
RuntimeError: If authentication credentials for the GitHub API are
not set.
"""
if not self._session:
session = requests.Session()
github_username = os.getenv(ENV_GITHUB_USERNAME)
authentication_token = os.getenv(ENV_GITHUB_AUTHENTICATION_TOKEN)
if not github_username or not authentication_token:
raise RuntimeError(
"Missing authentication credentials for GitHub secrets "
"manager. " + AUTHENTICATION_CREDENTIALS_MESSAGE
)
session.auth = HTTPBasicAuth(github_username, authentication_token)
session.headers["Accept"] = "application/vnd.github.v3+json"
self._session = session
return self._session
def _send_request(
self, method: str, resource: Optional[str] = None, **kwargs: Any
) -> requests.Response:
"""Sends an HTTP request to the GitHub API.
Args:
method: Method of the HTTP request that should be sent.
resource: Optional resource to which the request should be sent. If
none is given, the default GitHub API secrets endpoint will be
used.
**kwargs: Will be passed to the `requests` library.
Returns:
HTTP response.
# noqa: DAR402
Raises:
HTTPError: If the request failed due to a client or server error.
"""
url = (
f"https://api.github.com/repos/{self.config.owner}"
f"/{self.config.repository}/actions/secrets"
)
if resource:
url += resource
response = self.session.request(method=method, url=url, **kwargs)
# Raise an exception in case of a client or server error
response.raise_for_status()
return response
def _encrypt_secret(self, secret_value: str) -> Tuple[str, str]:
"""Encrypts a secret value.
This method first fetches a public key from the GitHub API and then uses
this key to encrypt the secret value. This is needed in order to
register GitHub secrets using the API.
Args:
secret_value: Secret value to encrypt.
Returns:
The encrypted secret value and the key id of the GitHub public key.
"""
from nacl.encoding import Base64Encoder
from nacl.public import PublicKey, SealedBox
response_json = self._send_request("GET", resource="/public-key").json()
public_key = PublicKey(
response_json["key"].encode("utf-8"), Base64Encoder
)
sealed_box = SealedBox(public_key)
encrypted_bytes = sealed_box.encrypt(secret_value.encode("utf-8"))
encrypted_string = base64.b64encode(encrypted_bytes).decode("utf-8")
return encrypted_string, cast(str, response_json["key_id"])
def _has_secret(self, secret_name: str) -> bool:
"""Checks whether a secret exists for the given name.
Args:
secret_name: Name of the secret which should be checked.
Returns:
`True` if a secret with the given name exists, `False` otherwise.
"""
secret_name = _convert_secret_name(secret_name, remove_prefix=True)
return secret_name in self.get_all_secret_keys(include_prefix=False)
def get_secret(self, secret_name: str) -> BaseSecretSchema:
"""Gets the value of a secret.
This method only works when called from within a GitHub Actions
environment.
Args:
secret_name: The name of the secret to get.
Returns:
The secret.
Raises:
KeyError: If a secret with this name doesn't exist.
RuntimeError: If not inside a GitHub Actions environment.
"""
full_secret_name = _convert_secret_name(secret_name, add_prefix=True)
# Raise a KeyError if the secret doesn't exist. We can do that even
# if we're not inside a GitHub Actions environment
if not self._has_secret(secret_name):
raise KeyError(
f"Unable to find secret '{secret_name}'. Please check the "
"GitHub UI to see if a **Repository** secret called "
f"'{full_secret_name}' exists. (ZenML uses the "
f"'{GITHUB_SECRET_PREFIX}' to differentiate ZenML "
"secrets from other GitHub secrets)"
)
if not inside_github_action_environment():
stack_name = Client().active_stack_model.name
commands = [
f"zenml stack copy {stack_name} <NEW_STACK_NAME>",
"zenml secrets_manager register <NEW_SECRETS_MANAGER_NAME> "
"--flavor=local",
"zenml stack update <NEW_STACK_NAME> "
"--secrets_manager=<NEW_SECRETS_MANAGER_NAME>",
"zenml stack set <NEW_STACK_NAME>",
f"zenml secrets-manager secret register {secret_name} ...",
]
raise RuntimeError(
"Getting GitHub secrets is only possible within a GitHub "
"Actions workflow. If you need this secret to access "
"stack components locally, you need to "
"register this secret in a different secrets manager. "
"You can do this by running the following commands: \n\n"
+ "\n".join(commands)
)
# If we're running inside a GitHub Actions environment using a
# workflow generated by the GitHub Actions orchestrator, all ZenML
# secrets stored in the GitHub secrets manager will be accessible as
# environment variables
secret_value = cast(str, os.getenv(full_secret_name))
secret_dict = json.loads(string_utils.b64_decode(secret_value))
schema_class = SecretSchemaClassRegistry.get_class(
secret_schema=secret_dict[SECRET_SCHEMA_DICT_KEY]
)
secret_content = secret_dict[SECRET_CONTENT_DICT_KEY]
return schema_class(name=secret_name, **secret_content)
def get_all_secret_keys(self, include_prefix: bool = False) -> List[str]:
"""Get all secret keys.
If we're running inside a GitHub Actions environment, this will return
the names of all environment variables starting with a ZenML internal
prefix. Otherwise, this will return all GitHub **Repository** secrets
created by ZenML.
Args:
include_prefix: Whether or not the internal prefix that is used to
differentiate ZenML secrets from other GitHub secrets should be
included in the returned names.
Returns:
List of all secret keys.
"""
if inside_github_action_environment():
potential_secret_keys = list(os.environ)
else:
logger.info(
"Fetching list of secrets for repository %s/%s",
self.config.owner,
self.config.repository,
)
response = self._send_request("GET", params={"per_page": 100})
potential_secret_keys = [
secret_dict["name"]
for secret_dict in response.json()["secrets"]
]
keys = [
_convert_secret_name(key, remove_prefix=not include_prefix)
for key in potential_secret_keys
if key.startswith(GITHUB_SECRET_PREFIX)
]
return keys
def register_secret(self, secret: BaseSecretSchema) -> None:
"""Registers a new secret.
Args:
secret: The secret to register.
Raises:
SecretExistsError: If a secret with this name already exists.
"""
if self._has_secret(secret.name):
raise SecretExistsError(
f"A secret with name '{secret.name}' already exists for this "
"GitHub repository. If you want to register a new value for "
f"this secret, please run `zenml secrets-manager secret delete {secret.name}` "
f"followed by `zenml secrets-manager secret register {secret.name} ...`."
)
secret_dict = {
SECRET_SCHEMA_DICT_KEY: secret.TYPE,
SECRET_CONTENT_DICT_KEY: secret.content,
}
secret_value = string_utils.b64_encode(json.dumps(secret_dict))
encrypted_secret, public_key_id = self._encrypt_secret(
secret_value=secret_value
)
body = {
"encrypted_value": encrypted_secret,
"key_id": public_key_id,
}
full_secret_name = _convert_secret_name(secret.name, add_prefix=True)
self._send_request("PUT", resource=f"/{full_secret_name}", json=body)
def update_secret(self, secret: BaseSecretSchema) -> NoReturn:
"""Update an existing secret.
Args:
secret: The secret to update.
Raises:
NotImplementedError: Always, as this functionality is not possible
using GitHub secrets, which don't allow retrieving the
secret values outside of a GitHub Actions environment.
"""
raise NotImplementedError(
"Updating secrets is not possible with the GitHub secrets manager "
"as it is not possible to retrieve GitHub secrets values outside "
"of a GitHub Actions environment."
)
def delete_secret(self, secret_name: str) -> None:
"""Delete an existing secret.
Args:
secret_name: The name of the secret to delete.
"""
full_secret_name = _convert_secret_name(secret_name, add_prefix=True)
self._send_request("DELETE", resource=f"/{full_secret_name}")
def delete_all_secrets(self) -> None:
"""Delete all existing secrets."""
for secret_name in self.get_all_secret_keys(include_prefix=False):
self.delete_secret(secret_name=secret_name)
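The `_encrypt_secret` helper above implements the sealed-box scheme the GitHub REST API requires when creating repository secrets. A condensed, self-contained sketch using PyNaCl; in practice the base64 public key and its id come from the `GET .../actions/secrets/public-key` endpoint, as in `_encrypt_secret`:
import base64
from nacl.encoding import Base64Encoder
from nacl.public import PublicKey, SealedBox

def encrypt_for_github(public_key_b64: str, secret_value: str) -> str:
    """Seal a secret value with a repository's base64-encoded public key."""
    public_key = PublicKey(public_key_b64.encode("utf-8"), Base64Encoder)
    encrypted = SealedBox(public_key).encrypt(secret_value.encode("utf-8"))
    # GitHub expects the ciphertext itself base64-encoded in the request body.
    return base64.b64encode(encrypted).decode("utf-8")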
config: GitHubSecretsManagerConfig
property
readonly
Returns the GitHubSecretsManagerConfig config.
Returns:
Type | Description |
---|---|
GitHubSecretsManagerConfig | The configuration. |
post_registration_message: Optional[str]
property
readonly
Info message regarding GitHub API authentication env variables.
Returns:
Type | Description |
---|---|
Optional[str] | The info message. |
session: Session
property
readonly
Session to send requests to the GitHub API.
Returns:
Type | Description |
---|---|
Session | Session to use for GitHub API calls. |
Exceptions:
Type | Description |
---|---|
RuntimeError | If authentication credentials for the GitHub API are not set. |
delete_all_secrets(self)
Delete all existing secrets.
Source code in zenml/integrations/github/secrets_managers/github_secrets_manager.py
def delete_all_secrets(self) -> None:
"""Delete all existing secrets."""
for secret_name in self.get_all_secret_keys(include_prefix=False):
self.delete_secret(secret_name=secret_name)
delete_secret(self, secret_name)
Delete an existing secret.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
secret_name | str | The name of the secret to delete. | required |
Source code in zenml/integrations/github/secrets_managers/github_secrets_manager.py
def delete_secret(self, secret_name: str) -> None:
"""Delete an existing secret.
Args:
secret_name: The name of the secret to delete.
"""
full_secret_name = _convert_secret_name(secret_name, add_prefix=True)
self._send_request("DELETE", resource=f"/{full_secret_name}")
get_all_secret_keys(self, include_prefix=False)
Get all secret keys.
If we're running inside a GitHub Actions environment, this will return the names of all environment variables starting with a ZenML internal prefix. Otherwise, this will return all GitHub Repository secrets created by ZenML.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
include_prefix | bool | Whether or not the internal prefix that is used to differentiate ZenML secrets from other GitHub secrets should be included in the returned names. | False |
Returns:
Type | Description |
---|---|
List[str] | List of all secret keys. |
Source code in zenml/integrations/github/secrets_managers/github_secrets_manager.py
def get_all_secret_keys(self, include_prefix: bool = False) -> List[str]:
"""Get all secret keys.
If we're running inside a GitHub Actions environment, this will return
the names of all environment variables starting with a ZenML internal
prefix. Otherwise, this will return all GitHub **Repository** secrets
created by ZenML.
Args:
include_prefix: Whether or not the internal prefix that is used to
differentiate ZenML secrets from other GitHub secrets should be
included in the returned names.
Returns:
List of all secret keys.
"""
if inside_github_action_environment():
potential_secret_keys = list(os.environ)
else:
logger.info(
"Fetching list of secrets for repository %s/%s",
self.config.owner,
self.config.repository,
)
response = self._send_request("GET", params={"per_page": 100})
potential_secret_keys = [
secret_dict["name"]
for secret_dict in response.json()["secrets"]
]
keys = [
_convert_secret_name(key, remove_prefix=not include_prefix)
for key in potential_secret_keys
if key.startswith(GITHUB_SECRET_PREFIX)
]
return keys
get_secret(self, secret_name)
Gets the value of a secret.
This method only works when called from within a GitHub Actions environment.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
secret_name | str | The name of the secret to get. | required |
Returns:
Type | Description |
---|---|
BaseSecretSchema | The secret. |
Exceptions:
Type | Description |
---|---|
KeyError | If a secret with this name doesn't exist. |
RuntimeError | If not inside a GitHub Actions environment. |
Source code in zenml/integrations/github/secrets_managers/github_secrets_manager.py
def get_secret(self, secret_name: str) -> BaseSecretSchema:
"""Gets the value of a secret.
This method only works when called from within a GitHub Actions
environment.
Args:
secret_name: The name of the secret to get.
Returns:
The secret.
Raises:
KeyError: If a secret with this name doesn't exist.
RuntimeError: If not inside a GitHub Actions environment.
"""
full_secret_name = _convert_secret_name(secret_name, add_prefix=True)
# Raise a KeyError if the secret doesn't exist. We can do that even
# if we're not inside a GitHub Actions environment
if not self._has_secret(secret_name):
raise KeyError(
f"Unable to find secret '{secret_name}'. Please check the "
"GitHub UI to see if a **Repository** secret called "
f"'{full_secret_name}' exists. (ZenML uses the "
f"'{GITHUB_SECRET_PREFIX}' to differentiate ZenML "
"secrets from other GitHub secrets)"
)
if not inside_github_action_environment():
stack_name = Client().active_stack_model.name
commands = [
f"zenml stack copy {stack_name} <NEW_STACK_NAME>",
"zenml secrets_manager register <NEW_SECRETS_MANAGER_NAME> "
"--flavor=local",
"zenml stack update <NEW_STACK_NAME> "
"--secrets_manager=<NEW_SECRETS_MANAGER_NAME>",
"zenml stack set <NEW_STACK_NAME>",
f"zenml secrets-manager secret register {secret_name} ...",
]
raise RuntimeError(
"Getting GitHub secrets is only possible within a GitHub "
"Actions workflow. If you need this secret to access "
"stack components locally, you need to "
"register this secret in a different secrets manager. "
"You can do this by running the following commands: \n\n"
+ "\n".join(commands)
)
# If we're running inside a GitHub Actions environment using a
# workflow generated by the GitHub Actions orchestrator, all ZenML
# secrets stored in the GitHub secrets manager will be accessible as
# environment variables
secret_value = cast(str, os.getenv(full_secret_name))
secret_dict = json.loads(string_utils.b64_decode(secret_value))
schema_class = SecretSchemaClassRegistry.get_class(
secret_schema=secret_dict[SECRET_SCHEMA_DICT_KEY]
)
secret_content = secret_dict[SECRET_CONTENT_DICT_KEY]
return schema_class(name=secret_name, **secret_content)
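Taken together with `register_secret`, the storage format is a base64-encoded JSON document holding the schema type and the secret content. A sketch of the round trip with stdlib helpers; the actual dict keys are the `SECRET_SCHEMA_DICT_KEY` and `SECRET_CONTENT_DICT_KEY` constants, shortened here for illustration:
import base64
import json

secret_dict = {"schema": "arbitrary", "content": {"api_key": "..."}}  # assumed shape
encoded = base64.b64encode(json.dumps(secret_dict).encode("utf-8")).decode("utf-8")
# Inside a GitHub Actions job the encoded value arrives as an environment
# variable and is decoded the same way in reverse:
decoded = json.loads(base64.b64decode(encoded).decode("utf-8"))
assert decoded == secret_dict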
register_secret(self, secret)
Registers a new secret.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
secret | BaseSecretSchema | The secret to register. | required |
Exceptions:
Type | Description |
---|---|
SecretExistsError | If a secret with this name already exists. |
Source code in zenml/integrations/github/secrets_managers/github_secrets_manager.py
def register_secret(self, secret: BaseSecretSchema) -> None:
"""Registers a new secret.
Args:
secret: The secret to register.
Raises:
SecretExistsError: If a secret with this name already exists.
"""
if self._has_secret(secret.name):
raise SecretExistsError(
f"A secret with name '{secret.name}' already exists for this "
"GitHub repository. If you want to register a new value for "
f"this secret, please run `zenml secrets-manager secret delete {secret.name}` "
f"followed by `zenml secrets-manager secret register {secret.name} ...`."
)
secret_dict = {
SECRET_SCHEMA_DICT_KEY: secret.TYPE,
SECRET_CONTENT_DICT_KEY: secret.content,
}
secret_value = string_utils.b64_encode(json.dumps(secret_dict))
encrypted_secret, public_key_id = self._encrypt_secret(
secret_value=secret_value
)
body = {
"encrypted_value": encrypted_secret,
"key_id": public_key_id,
}
full_secret_name = _convert_secret_name(secret.name, add_prefix=True)
self._send_request("PUT", resource=f"/{full_secret_name}", json=body)
update_secret(self, secret)
Update an existing secret.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
secret | BaseSecretSchema | The secret to update. | required |
Exceptions:
Type | Description |
---|---|
NotImplementedError | Always, as this functionality is not possible using GitHub secrets, which don't allow retrieving secret values outside of a GitHub Actions environment. |
Source code in zenml/integrations/github/secrets_managers/github_secrets_manager.py
def update_secret(self, secret: BaseSecretSchema) -> NoReturn:
"""Update an existing secret.
Args:
secret: The secret to update.
Raises:
NotImplementedError: Always, as this functionality is not possible
using GitHub secrets, which don't allow retrieving the
secret values outside of a GitHub Actions environment.
"""
raise NotImplementedError(
"Updating secrets is not possible with the GitHub secrets manager "
"as it is not possible to retrieve GitHub secrets values outside "
"of a GitHub Actions environment."
)
inside_github_action_environment()
Returns whether the current code is executing in a GitHub Actions environment.
Returns:
Type | Description |
---|---|
bool | `True` if running in a GitHub Actions environment, `False` otherwise. |
Source code in zenml/integrations/github/secrets_managers/github_secrets_manager.py
def inside_github_action_environment() -> bool:
"""Returns if the current code is executing in a GitHub Actions environment.
Returns:
`True` if running in a GitHub Actions environment, `False` otherwise.
"""
return os.getenv(ENV_IN_GITHUB_ACTIONS) == "true"
graphviz
special
Initialization of the Graphviz integration.
GraphvizIntegration (Integration)
Definition of Graphviz integration for ZenML.
Source code in zenml/integrations/graphviz/__init__.py
class GraphvizIntegration(Integration):
"""Definition of Graphviz integration for ZenML."""
NAME = GRAPHVIZ
REQUIREMENTS = ["graphviz>=0.17"]
SYSTEM_REQUIREMENTS = {"graphviz": "dot"}
visualizers
special
Initialization of Graphviz visualizers.
pipeline_run_dag_visualizer
Implementation of the Graphviz pipeline run DAG visualizer.
PipelineRunDagVisualizer (BaseVisualizer)
Visualize the lineage of runs in a pipeline.
Source code in zenml/integrations/graphviz/visualizers/pipeline_run_dag_visualizer.py
class PipelineRunDagVisualizer(BaseVisualizer):
"""Visualize the lineage of runs in a pipeline."""
ARTIFACT_DEFAULT_COLOR = "blue"
ARTIFACT_CACHED_COLOR = "green"
ARTIFACT_SHAPE = "box"
ARTIFACT_PREFIX = "artifact_"
STEP_COLOR = "#431D93"
STEP_SHAPE = "ellipse"
STEP_PREFIX = "step_"
FONT = "Roboto"
@abstractmethod
def visualize(
self, object: PipelineRunView, *args: Any, **kwargs: Any
) -> graphviz.Digraph:
"""Creates a pipeline lineage diagram using graphviz.
Args:
object: The pipeline run view to visualize.
*args: Additional arguments to pass to the visualization.
**kwargs: Additional keyword arguments to pass to the visualization.
Returns:
A graphviz digraph object.
"""
logger.warning(
"This integration is not completed yet. Results might be unexpected."
)
dot = graphviz.Digraph(comment=object.name)
# link the steps together
for step in object.steps:
# add each step as a node
dot.node(
self.STEP_PREFIX + str(step.id),
step.entrypoint_name,
shape=self.STEP_SHAPE,
)
# for each parent of a step, add an edge
for artifact_name, artifact in step.outputs.items():
dot.node(
self.ARTIFACT_PREFIX + str(artifact.id),
f"{artifact_name} \n" f"({artifact.data_type})",
shape=self.ARTIFACT_SHAPE,
)
dot.edge(
self.STEP_PREFIX + str(step.id),
self.ARTIFACT_PREFIX + str(artifact.id),
)
for artifact_name, artifact in step.inputs.items():
dot.edge(
self.ARTIFACT_PREFIX + str(artifact.id),
self.STEP_PREFIX + str(step.id),
)
with tempfile.NamedTemporaryFile(delete=False, suffix=".html") as f:
dot.render(filename=f.name, format="png", view=True, cleanup=True)
return dot
visualize(self, object, *args, **kwargs)
Creates a pipeline lineage diagram using graphviz.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
object | PipelineRunView | The pipeline run view to visualize. | required |
*args | Any | Additional arguments to pass to the visualization. | () |
**kwargs | Any | Additional keyword arguments to pass to the visualization. | {} |
Returns:
Type | Description |
---|---|
Digraph | A graphviz digraph object. |
Source code in zenml/integrations/graphviz/visualizers/pipeline_run_dag_visualizer.py
@abstractmethod
def visualize(
self, object: PipelineRunView, *args: Any, **kwargs: Any
) -> graphviz.Digraph:
"""Creates a pipeline lineage diagram using graphviz.
Args:
object: The pipeline run view to visualize.
*args: Additional arguments to pass to the visualization.
**kwargs: Additional keyword arguments to pass to the visualization.
Returns:
A graphviz digraph object.
"""
logger.warning(
"This integration is not completed yet. Results might be unexpected."
)
dot = graphviz.Digraph(comment=object.name)
# link the steps together
for step in object.steps:
# add each step as a node
dot.node(
self.STEP_PREFIX + str(step.id),
step.entrypoint_name,
shape=self.STEP_SHAPE,
)
# for each parent of a step, add an edge
for artifact_name, artifact in step.outputs.items():
dot.node(
self.ARTIFACT_PREFIX + str(artifact.id),
f"{artifact_name} \n" f"({artifact.data_type})",
shape=self.ARTIFACT_SHAPE,
)
dot.edge(
self.STEP_PREFIX + str(step.id),
self.ARTIFACT_PREFIX + str(artifact.id),
)
for artifact_name, artifact in step.inputs.items():
dot.edge(
self.ARTIFACT_PREFIX + str(artifact.id),
self.STEP_PREFIX + str(step.id),
)
with tempfile.NamedTemporaryFile(delete=False, suffix=".html") as f:
dot.render(filename=f.name, format="png", view=True, cleanup=True)
return dot
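A usage sketch, assuming a run fetched through ZenML's post-execution API; the pipeline name is hypothetical and the exact module paths may differ between ZenML versions:
from zenml.post_execution import get_pipeline
from zenml.integrations.graphviz.visualizers.pipeline_run_dag_visualizer import (
    PipelineRunDagVisualizer,
)

run = get_pipeline("my_pipeline").runs[-1]  # hypothetical pipeline name
# Renders the DAG to a temporary PNG and opens it in the default viewer.
PipelineRunDagVisualizer().visualize(run)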
great_expectations
special
Great Expectations integration for ZenML.
The Great Expectations integration enables you to use Great Expectations to profile and validate your data.
GreatExpectationsIntegration (Integration)
Definition of Great Expectations integration for ZenML.
Source code in zenml/integrations/great_expectations/__init__.py
class GreatExpectationsIntegration(Integration):
"""Definition of Great Expectations integration for ZenML."""
NAME = GREAT_EXPECTATIONS
REQUIREMENTS = [
"great-expectations~=0.15.11",
]
@staticmethod
def activate() -> None:
"""Activate the Great Expectations integration."""
from zenml.integrations.great_expectations import materializers # noqa
@classmethod
def flavors(cls) -> List[Type[Flavor]]:
"""Declare the stack component flavors for the Great Expectations integration.
Returns:
List of stack component flavors for this integration.
"""
from zenml.integrations.great_expectations.flavors import (
GreatExpectationsDataValidatorFlavor,
)
return [GreatExpectationsDataValidatorFlavor]
activate()
staticmethod
Activate the Great Expectations integration.
Source code in zenml/integrations/great_expectations/__init__.py
@staticmethod
def activate() -> None:
"""Activate the Great Expectations integration."""
from zenml.integrations.great_expectations import materializers # noqa
flavors()
classmethod
Declare the stack component flavors for the Great Expectations integration.
Returns:
Type | Description |
---|---|
List[Type[zenml.stack.flavor.Flavor]] | List of stack component flavors for this integration. |
Source code in zenml/integrations/great_expectations/__init__.py
@classmethod
def flavors(cls) -> List[Type[Flavor]]:
"""Declare the stack component flavors for the Great Expectations integration.
Returns:
List of stack component flavors for this integration.
"""
from zenml.integrations.great_expectations.flavors import (
GreatExpectationsDataValidatorFlavor,
)
return [GreatExpectationsDataValidatorFlavor]
data_validators
special
Initialization of the Great Expectations data validator for ZenML.
ge_data_validator
Implementation of the Great Expectations data validator.
GreatExpectationsDataValidator (BaseDataValidator)
Great Expectations data validator stack component.
Source code in zenml/integrations/great_expectations/data_validators/ge_data_validator.py
class GreatExpectationsDataValidator(BaseDataValidator):
"""Great Expectations data validator stack component."""
_context: Optional[BaseDataContext] = None
_context_config: Optional[Dict[str, Any]] = None
@property
def config(self) -> GreatExpectationsDataValidatorConfig:
"""Returns the `GreatExpectationsDataValidatorConfig` config.
Returns:
The configuration.
"""
return cast(GreatExpectationsDataValidatorConfig, self._config)
@classmethod
def get_data_context(cls) -> BaseDataContext:
"""Get the Great Expectations data context managed by ZenML.
Call this method to retrieve the data context managed by ZenML
through the active Great Expectations data validator stack component.
Returns:
A Great Expectations data context managed by ZenML as configured
through the active data validator stack component.
"""
data_validator = cast(
"GreatExpectationsDataValidator", cls.get_active_data_validator()
)
return data_validator.data_context
@property
def context_config(self) -> Optional[Dict[str, Any]]:
"""Get the Great Expectations data context configuration.
The first time the context config is loaded from the stack component
config, it is converted from JSON/YAML string format to a dict.
Raises:
ValueError: If the context_config value is not valid JSON/YAML or
if the GE configuration extracted from it fails GE validation.
Returns:
A dictionary with the GE data context configuration.
"""
# If the context config is already loaded, return it
if self._context_config is not None:
return self._context_config
# Otherwise, load it from the stack component config
context_config = self.config.context_config
if context_config is None:
return None
if isinstance(context_config, dict):
self._context_config = context_config
return self._context_config
# If the context config is a string, try to parse it as JSON/YAML
try:
context_config_dict = yaml.safe_load(context_config)
except yaml.parser.ParserError as e:
raise ValueError(
f"Malformed `context_config` value. Only JSON and YAML "
f"formats are supported: {str(e)}"
)
# Validate that the context config is a valid GE config
try:
context_config = DataContextConfig(**context_config_dict)
BaseDataContext(project_config=context_config)
except Exception as e:
raise ValueError(f"Invalid `context_config` value: {str(e)}")
self._context_config = cast(Dict[str, Any], context_config_dict)
return self._context_config
@property
def local_path(self) -> Optional[str]:
"""Return a local path where this component stores information.
If an existing local GE data context is used, it is
interpreted as a local path that needs to be accessible in
all runtime environments.
Returns:
The local path where this component stores information.
"""
return self.config.context_root_dir
def get_store_config(self, class_name: str, prefix: str) -> Dict[str, Any]:
"""Generate a Great Expectations store configuration.
Args:
class_name: The store class name
prefix: The path prefix for the ZenML store configuration
Returns:
A dictionary with the GE store configuration.
"""
return {
"class_name": class_name,
"store_backend": {
"module_name": ZenMLArtifactStoreBackend.__module__,
"class_name": ZenMLArtifactStoreBackend.__name__,
"prefix": f"{str(self.id)}/{prefix}",
},
}
def get_data_docs_config(
self, prefix: str, local: bool = False
) -> Dict[str, Any]:
"""Generate Great Expectations data docs configuration.
Args:
prefix: The path prefix for the ZenML data docs configuration
local: Whether the data docs site is local or remote.
Returns:
A dictionary with the GE data docs site configuration.
"""
if local:
store_backend = {
"class_name": "TupleFilesystemStoreBackend",
"base_directory": f"{self.root_directory}/{prefix}",
}
else:
store_backend = {
"module_name": ZenMLArtifactStoreBackend.__module__,
"class_name": ZenMLArtifactStoreBackend.__name__,
"prefix": f"{str(self.id)}/{prefix}",
}
return {
"class_name": "SiteBuilder",
"store_backend": store_backend,
"site_index_builder": {
"class_name": "DefaultSiteIndexBuilder",
},
}
@property
def data_context(self) -> BaseDataContext:
"""Returns the Great Expectations data context configured for this component.
Returns:
The Great Expectations data context configured for this component.
"""
if not self._context:
expectations_store_name = "zenml_expectations_store"
validations_store_name = "zenml_validations_store"
checkpoint_store_name = "zenml_checkpoint_store"
profiler_store_name = "zenml_profiler_store"
evaluation_parameter_store_name = "evaluation_parameter_store"
zenml_context_config = dict(
stores={
expectations_store_name: self.get_store_config(
"ExpectationsStore", "expectations"
),
validations_store_name: self.get_store_config(
"ValidationsStore", "validations"
),
checkpoint_store_name: self.get_store_config(
"CheckpointStore", "checkpoints"
),
profiler_store_name: self.get_store_config(
"ProfilerStore", "profilers"
),
evaluation_parameter_store_name: {
"class_name": "EvaluationParameterStore"
},
},
expectations_store_name=expectations_store_name,
validations_store_name=validations_store_name,
checkpoint_store_name=checkpoint_store_name,
profiler_store_name=profiler_store_name,
evaluation_parameter_store_name=evaluation_parameter_store_name,
data_docs_sites={
"zenml_artifact_store": self.get_data_docs_config(
"data_docs"
)
},
)
configure_zenml_stores = self.config.configure_zenml_stores
if self.config.context_root_dir:
# initialize the local data context, if a local path was
# configured
self._context = DataContext(self.config.context_root_dir)
else:
# create an in-memory data context configuration that is not
# backed by a local YAML file (see https://docs.greatexpectations.io/docs/guides/setup/configuring_data_contexts/how_to_instantiate_a_data_context_without_a_yml_file/).
if self.context_config:
context_config = DataContextConfig(**self.context_config)
else:
context_config = DataContextConfig(**zenml_context_config)
# skip adding the stores after initialization, as they are
# already baked in the initial configuration
configure_zenml_stores = False
self._context = BaseDataContext(project_config=context_config)
if configure_zenml_stores:
self._context.config.expectations_store_name = (
expectations_store_name
)
self._context.config.validations_store_name = (
validations_store_name
)
self._context.config.checkpoint_store_name = (
checkpoint_store_name
)
self._context.config.profiler_store_name = profiler_store_name
self._context.config.evaluation_parameter_store_name = (
evaluation_parameter_store_name
)
for store_name, store_config in zenml_context_config[ # type: ignore[attr-defined]
"stores"
].items():
self._context.add_store(
store_name=store_name,
store_config=store_config,
)
for site_name, site_config in zenml_context_config[ # type: ignore[attr-defined]
"data_docs_sites"
].items():
self._context.config.data_docs_sites[
site_name
] = site_config
if self.config.configure_local_docs:
client = Client(skip_client_check=True) # type: ignore[call-arg]
artifact_store = client.active_stack.artifact_store
if artifact_store.flavor != "local":
self._context.config.data_docs_sites[
"zenml_local"
] = self.get_data_docs_config("data_docs", local=True)
return self._context
@property
def root_directory(self) -> str:
"""Returns path to the root directory for all local files concerning this data validator.
Returns:
Path to the root directory.
"""
path = os.path.join(
io_utils.get_global_config_directory(),
self.flavor,
str(self.id),
)
if not os.path.exists(path):
fileio.makedirs(path)
return path
def data_profiling(
self,
dataset: pd.DataFrame,
comparison_dataset: Optional[Any] = None,
profile_list: Optional[Sequence[str]] = None,
expectation_suite_name: Optional[str] = None,
data_asset_name: Optional[str] = None,
profiler_kwargs: Optional[Dict[str, Any]] = None,
overwrite_existing_suite: bool = True,
**kwargs: Any,
) -> ExpectationSuite:
"""Infer a Great Expectation Expectation Suite from a given dataset.
This Great Expectations specific data profiling method implementation
builds an Expectation Suite automatically by running a
UserConfigurableProfiler on an input dataset [as covered in the official
GE documentation](https://docs.greatexpectations.io/docs/guides/expectations/how_to_create_and_edit_expectations_with_a_profiler).
Args:
dataset: The dataset from which the expectation suite will be
inferred.
comparison_dataset: Optional dataset used to generate data
comparison (i.e. data drift) profiles. Not supported by the
Great Expectations data validator.
profile_list: Optional list identifying the categories of data
profiles to be generated. Not supported by the Great Expectations
data validator.
expectation_suite_name: The name of the expectation suite to create
or update. If not supplied, a unique name will be generated from
the current pipeline and step name, if running in the context of
a pipeline step.
data_asset_name: The name of the data asset to use to identify the
dataset in the Great Expectations docs.
profiler_kwargs: A dictionary of custom keyword arguments to pass to
the profiler.
overwrite_existing_suite: Whether to overwrite an existing
expectation suite, if one exists with that name.
kwargs: Additional keyword arguments (unused).
Returns:
The inferred Expectation Suite.
Raises:
ValueError: if an `expectation_suite_name` value is not supplied and
a name for the expectation suite cannot be generated from the
current step name and pipeline name.
"""
context = self.data_context
if comparison_dataset is not None:
logger.warning(
"A comparison dataset is not required by Great Expectations "
"to do data profiling. Silently ignoring the supplied dataset "
)
if not expectation_suite_name:
try:
# get pipeline name and step name
step_env = cast(
StepEnvironment, Environment()[STEP_ENVIRONMENT_NAME]
)
pipeline_name = step_env.pipeline_name
step_name = step_env.step_name
expectation_suite_name = f"{pipeline_name}_{step_name}"
except KeyError:
raise ValueError(
"A expectation suite name is required when not running in "
"the context of a pipeline step."
)
suite_exists = False
if context.expectations_store.has_key( # noqa
ExpectationSuiteIdentifier(expectation_suite_name)
):
suite_exists = True
suite = context.get_expectation_suite(expectation_suite_name)
if not overwrite_existing_suite:
logger.info(
f"Expectation Suite `{expectation_suite_name}` "
f"already exists and `overwrite_existing_suite` is not set "
f"in the step configuration. Skipping re-running the "
f"profiler."
)
return suite
batch_request = create_batch_request(context, dataset, data_asset_name)
try:
if suite_exists:
validator = context.get_validator(
batch_request=batch_request,
expectation_suite_name=expectation_suite_name,
)
else:
validator = context.get_validator(
batch_request=batch_request,
create_expectation_suite_with_name=expectation_suite_name,
)
profiler = UserConfigurableProfiler(
profile_dataset=validator, **profiler_kwargs
)
suite = profiler.build_suite()
context.save_expectation_suite(
expectation_suite=suite,
expectation_suite_name=expectation_suite_name,
)
context.build_data_docs()
finally:
context.delete_datasource(batch_request.datasource_name)
return suite
def data_validation(
self,
dataset: pd.DataFrame,
comparison_dataset: Optional[Any] = None,
check_list: Optional[Sequence[str]] = None,
expectation_suite_name: Optional[str] = None,
data_asset_name: Optional[str] = None,
action_list: Optional[List[Dict[str, Any]]] = None,
**kwargs: Any,
) -> CheckpointResult:
"""Great Expectations data validation.
This Great Expectations specific data validation method
implementation validates an input dataset against an Expectation Suite
(the GE definition of a profile) [as covered in the official GE
documentation](https://docs.greatexpectations.io/docs/guides/validation/how_to_validate_data_by_running_a_checkpoint).
Args:
dataset: The dataset to validate.
comparison_dataset: Optional dataset used to run data
comparison (i.e. data drift) checks. Not supported by the
Great Expectations data validator.
check_list: Optional list identifying the data validation checks to
be performed. Not supported by the Great Expectations data
validator.
expectation_suite_name: The name of the expectation suite to use to
validate the dataset. A value must be provided.
data_asset_name: The name of the data asset to use to identify the
dataset in the Great Expectations docs.
action_list: A list of additional Great Expectations actions to run after
the validation check.
kwargs: Additional keyword arguments (unused).
Returns:
The Great Expectations validation (checkpoint) result.
Raises:
ValueError: if the `expectation_suite_name` argument is omitted.
"""
if not expectation_suite_name:
raise ValueError("Missing expectation_suite_name argument value.")
if comparison_dataset is not None:
logger.warning(
"A comparison dataset is not required by Great Expectations "
"to do data validation. Silently ignoring the supplied dataset "
)
try:
# get pipeline name, step name and run id
step_env = cast(
StepEnvironment, Environment()[STEP_ENVIRONMENT_NAME]
)
run_id = step_env.pipeline_run_id
step_name = step_env.step_name
except KeyError:
# if not running inside a pipeline step, use random values
run_id = f"pipeline_{random_str(5)}"
step_name = f"step_{random_str(5)}"
context = self.data_context
checkpoint_name = f"{run_id}_{step_name}"
batch_request = create_batch_request(context, dataset, data_asset_name)
action_list = action_list or [
{
"name": "store_validation_result",
"action": {"class_name": "StoreValidationResultAction"},
},
{
"name": "store_evaluation_params",
"action": {"class_name": "StoreEvaluationParametersAction"},
},
{
"name": "update_data_docs",
"action": {"class_name": "UpdateDataDocsAction"},
},
]
checkpoint_config = {
"name": checkpoint_name,
"run_name_template": f"{run_id}",
"config_version": 1,
"class_name": "Checkpoint",
"expectation_suite_name": expectation_suite_name,
"action_list": action_list,
}
context.add_checkpoint(**checkpoint_config)
try:
results = context.run_checkpoint(
checkpoint_name=checkpoint_name,
validations=[{"batch_request": batch_request}],
)
finally:
context.delete_datasource(batch_request.datasource_name)
context.delete_checkpoint(checkpoint_name)
return results
config: GreatExpectationsDataValidatorConfig
property
readonly
Returns the GreatExpectationsDataValidatorConfig config.
Returns:
Type | Description |
---|---|
GreatExpectationsDataValidatorConfig | The configuration. |
context_config: Optional[Dict[str, Any]]
property
readonly
Get the Great Expectations data context configuration.
The first time the context config is loaded from the stack component config, it is converted from JSON/YAML string format to a dict.
Exceptions:
Type | Description |
---|---|
ValueError | If the context_config value is not valid JSON/YAML or if the GE configuration extracted from it fails GE validation. |
Returns:
Type | Description |
---|---|
Optional[Dict[str, Any]] | A dictionary with the GE data context configuration. |
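Because JSON is a subset of YAML, `yaml.safe_load` accepts either format for the `context_config` value; a quick illustration (the keys mirror two of the GE config fields used elsewhere in this class):
import yaml

as_json = '{"stores": {}, "expectations_store_name": "zenml_expectations_store"}'
as_yaml = "stores: {}\nexpectations_store_name: zenml_expectations_store"
# Both strings parse to the same dictionary.
assert yaml.safe_load(as_json) == yaml.safe_load(as_yaml)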
data_context: BaseDataContext
property
readonly
Returns the Great Expectations data context configured for this component.
Returns:
Type | Description |
---|---|
BaseDataContext | The Great Expectations data context configured for this component. |
local_path: Optional[str]
property
readonly
Return a local path where this component stores information.
If an existing local GE data context is used, it is interpreted as a local path that needs to be accessible in all runtime environments.
Returns:
Type | Description |
---|---|
Optional[str] | The local path where this component stores information. |
root_directory: str
property
readonly
Returns path to the root directory for all local files concerning this data validator.
Returns:
Type | Description |
---|---|
str | Path to the root directory. |
data_profiling(self, dataset, comparison_dataset=None, profile_list=None, expectation_suite_name=None, data_asset_name=None, profiler_kwargs=None, overwrite_existing_suite=True, **kwargs)
Infer a Great Expectations Expectation Suite from a given dataset.
This Great Expectations specific data profiling method implementation builds an Expectation Suite automatically by running a UserConfigurableProfiler on an input dataset as covered in the official GE documentation.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
dataset | DataFrame | The dataset from which the expectation suite will be inferred. | required |
comparison_dataset | Optional[Any] | Optional dataset used to generate data comparison (i.e. data drift) profiles. Not supported by the Great Expectations data validator. | None |
profile_list | Optional[Sequence[str]] | Optional list identifying the categories of data profiles to be generated. Not supported by the Great Expectations data validator. | None |
expectation_suite_name | Optional[str] | The name of the expectation suite to create or update. If not supplied, a unique name will be generated from the current pipeline and step name, if running in the context of a pipeline step. | None |
data_asset_name | Optional[str] | The name of the data asset to use to identify the dataset in the Great Expectations docs. | None |
profiler_kwargs | Optional[Dict[str, Any]] | A dictionary of custom keyword arguments to pass to the profiler. | None |
overwrite_existing_suite | bool | Whether to overwrite an existing expectation suite, if one exists with that name. | True |
kwargs | Any | Additional keyword arguments (unused). | {} |
Returns:
Type | Description |
---|---|
ExpectationSuite | The inferred Expectation Suite. |
Exceptions:
Type | Description |
---|---|
ValueError | If an `expectation_suite_name` value is not supplied and a name for the expectation suite cannot be generated from the current step name and pipeline name. |
Source code in zenml/integrations/great_expectations/data_validators/ge_data_validator.py
def data_profiling(
self,
dataset: pd.DataFrame,
comparison_dataset: Optional[Any] = None,
profile_list: Optional[Sequence[str]] = None,
expectation_suite_name: Optional[str] = None,
data_asset_name: Optional[str] = None,
profiler_kwargs: Optional[Dict[str, Any]] = None,
overwrite_existing_suite: bool = True,
**kwargs: Any,
) -> ExpectationSuite:
"""Infer a Great Expectation Expectation Suite from a given dataset.
This Great Expectations specific data profiling method implementation
builds an Expectation Suite automatically by running a
UserConfigurableProfiler on an input dataset [as covered in the official
GE documentation](https://docs.greatexpectations.io/docs/guides/expectations/how_to_create_and_edit_expectations_with_a_profiler).
Args:
dataset: The dataset from which the expectation suite will be
inferred.
comparison_dataset: Optional dataset used to generate data
comparison (i.e. data drift) profiles. Not supported by the
Great Expectations data validator.
profile_list: Optional list identifying the categories of data
profiles to be generated. Not supported by the Great Expectations
data validator.
expectation_suite_name: The name of the expectation suite to create
or update. If not supplied, a unique name will be generated from
the current pipeline and step name, if running in the context of
a pipeline step.
data_asset_name: The name of the data asset to use to identify the
dataset in the Great Expectations docs.
profiler_kwargs: A dictionary of custom keyword arguments to pass to
the profiler.
overwrite_existing_suite: Whether to overwrite an existing
expectation suite, if one exists with that name.
kwargs: Additional keyword arguments (unused).
Returns:
The inferred Expectation Suite.
Raises:
ValueError: if an `expectation_suite_name` value is not supplied and
a name for the expectation suite cannot be generated from the
current step name and pipeline name.
"""
context = self.data_context
if comparison_dataset is not None:
logger.warning(
"A comparison dataset is not required by Great Expectations "
"to do data profiling. Silently ignoring the supplied dataset "
)
if not expectation_suite_name:
try:
# get pipeline name and step name
step_env = cast(
StepEnvironment, Environment()[STEP_ENVIRONMENT_NAME]
)
pipeline_name = step_env.pipeline_name
step_name = step_env.step_name
expectation_suite_name = f"{pipeline_name}_{step_name}"
except KeyError:
raise ValueError(
"A expectation suite name is required when not running in "
"the context of a pipeline step."
)
suite_exists = False
if context.expectations_store.has_key( # noqa
ExpectationSuiteIdentifier(expectation_suite_name)
):
suite_exists = True
suite = context.get_expectation_suite(expectation_suite_name)
if not overwrite_existing_suite:
logger.info(
f"Expectation Suite `{expectation_suite_name}` "
f"already exists and `overwrite_existing_suite` is not set "
f"in the step configuration. Skipping re-running the "
f"profiler."
)
return suite
batch_request = create_batch_request(context, dataset, data_asset_name)
try:
if suite_exists:
validator = context.get_validator(
batch_request=batch_request,
expectation_suite_name=expectation_suite_name,
)
else:
validator = context.get_validator(
batch_request=batch_request,
create_expectation_suite_with_name=expectation_suite_name,
)
profiler = UserConfigurableProfiler(
profile_dataset=validator, **profiler_kwargs
)
suite = profiler.build_suite()
context.save_expectation_suite(
expectation_suite=suite,
expectation_suite_name=expectation_suite_name,
)
context.build_data_docs()
finally:
context.delete_datasource(batch_request.datasource_name)
return suite
data_validation(self, dataset, comparison_dataset=None, check_list=None, expectation_suite_name=None, data_asset_name=None, action_list=None, **kwargs)
Great Expectations data validation.
This Great Expectations specific data validation method implementation validates an input dataset against an Expectation Suite (the GE definition of a profile) as covered in the official GE documentation.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
dataset | DataFrame | The dataset to validate. | required |
comparison_dataset | Optional[Any] | Optional dataset used to run data comparison (i.e. data drift) checks. Not supported by the Great Expectations data validator. | None |
check_list | Optional[Sequence[str]] | Optional list identifying the data validation checks to be performed. Not supported by the Great Expectations data validator. | None |
expectation_suite_name | Optional[str] | The name of the expectation suite to use to validate the dataset. A value must be provided. | None |
data_asset_name | Optional[str] | The name of the data asset to use to identify the dataset in the Great Expectations docs. | None |
action_list | Optional[List[Dict[str, Any]]] | A list of additional Great Expectations actions to run after the validation check. | None |
kwargs | Any | Additional keyword arguments (unused). | {} |
Returns:
Type | Description |
---|---|
CheckpointResult | The Great Expectations validation (checkpoint) result. |
Exceptions:
Type | Description |
---|---|
ValueError | If the `expectation_suite_name` argument is omitted. |
Source code in zenml/integrations/great_expectations/data_validators/ge_data_validator.py
def data_validation(
self,
dataset: pd.DataFrame,
comparison_dataset: Optional[Any] = None,
check_list: Optional[Sequence[str]] = None,
expectation_suite_name: Optional[str] = None,
data_asset_name: Optional[str] = None,
action_list: Optional[List[Dict[str, Any]]] = None,
**kwargs: Any,
) -> CheckpointResult:
"""Great Expectations data validation.
This Great Expectations specific data validation method
implementation validates an input dataset against an Expectation Suite
(the GE definition of a profile) [as covered in the official GE
documentation](https://docs.greatexpectations.io/docs/guides/validation/how_to_validate_data_by_running_a_checkpoint).
Args:
dataset: The dataset to validate.
comparison_dataset: Optional dataset used to run data
comparison (i.e. data drift) checks. Not supported by the
Great Expectations data validator.
check_list: Optional list identifying the data validation checks to
be performed. Not supported by the Great Expectations data
validator.
expectation_suite_name: The name of the expectation suite to use to
validate the dataset. A value must be provided.
data_asset_name: The name of the data asset to use to identify the
dataset in the Great Expectations docs.
action_list: A list of additional Great Expectations actions to run after
the validation check.
kwargs: Additional keyword arguments (unused).
Returns:
The Great Expectations validation (checkpoint) result.
Raises:
ValueError: if the `expectation_suite_name` argument is omitted.
"""
if not expectation_suite_name:
raise ValueError("Missing expectation_suite_name argument value.")
if comparison_dataset is not None:
logger.warning(
"A comparison dataset is not required by Great Expectations "
"to do data validation. Silently ignoring the supplied dataset "
)
try:
# get pipeline name, step name and run id
step_env = cast(
StepEnvironment, Environment()[STEP_ENVIRONMENT_NAME]
)
run_id = step_env.pipeline_run_id
step_name = step_env.step_name
except KeyError:
# if not running inside a pipeline step, use random values
run_id = f"pipeline_{random_str(5)}"
step_name = f"step_{random_str(5)}"
context = self.data_context
checkpoint_name = f"{run_id}_{step_name}"
batch_request = create_batch_request(context, dataset, data_asset_name)
action_list = action_list or [
{
"name": "store_validation_result",
"action": {"class_name": "StoreValidationResultAction"},
},
{
"name": "store_evaluation_params",
"action": {"class_name": "StoreEvaluationParametersAction"},
},
{
"name": "update_data_docs",
"action": {"class_name": "UpdateDataDocsAction"},
},
]
checkpoint_config = {
"name": checkpoint_name,
"run_name_template": f"{run_id}",
"config_version": 1,
"class_name": "Checkpoint",
"expectation_suite_name": expectation_suite_name,
"action_list": action_list,
}
context.add_checkpoint(**checkpoint_config)
try:
results = context.run_checkpoint(
checkpoint_name=checkpoint_name,
validations=[{"batch_request": batch_request}],
)
finally:
context.delete_datasource(batch_request.datasource_name)
context.delete_checkpoint(checkpoint_name)
return results
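For orientation, here is a minimal sketch of calling this method directly from custom step code. The DataFrame contents and suite name are illustrative assumptions, and the named suite must already exist in the expectation store:
import pandas as pd
from zenml.integrations.great_expectations.data_validators import (
    GreatExpectationsDataValidator,
)

# Illustrative data; any pd.DataFrame can be validated this way.
df = pd.DataFrame({"age": [25, 32, 47], "income": [40000, 55000, 72000]})
data_validator = GreatExpectationsDataValidator.get_active_data_validator()
results = data_validator.data_validation(
    df,
    expectation_suite_name="example_suite",  # assumed to exist already
)
print(results.success)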
get_data_context()
classmethod
Get the Great Expectations data context managed by ZenML.
Call this method to retrieve the data context managed by ZenML through the active Great Expectations data validator stack component.
Returns:
Type | Description
---|---
BaseDataContext | A Great Expectations data context managed by ZenML as configured through the active data validator stack component.
Source code in zenml/integrations/great_expectations/data_validators/ge_data_validator.py
@classmethod
def get_data_context(cls) -> BaseDataContext:
"""Get the Great Expectations data context managed by ZenML.
Call this method to retrieve the data context managed by ZenML
through the active Great Expectations data validator stack component.
Returns:
A Great Expectations data context managed by ZenML as configured
through the active data validator stack component.
"""
data_validator = cast(
"GreatExpectationsDataValidator", cls.get_active_data_validator()
)
return data_validator.data_context
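As a quick usage illustration (assuming an active stack with a Great Expectations data validator component), the returned object behaves like any other GE data context:
from zenml.integrations.great_expectations.data_validators import (
    GreatExpectationsDataValidator,
)

context = GreatExpectationsDataValidator.get_data_context()
# The standard Great Expectations context API is available, e.g.:
print(context.list_expectation_suite_names())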
get_data_docs_config(self, prefix, local=False)
Generate Great Expectations data docs configuration.
Parameters:
Name | Type | Description | Default
---|---|---|---
prefix | str | The path prefix for the ZenML data docs configuration | required
local | bool | Whether the data docs site is local or remote. | False
Returns:
Type | Description
---|---
Dict[str, Any] | A dictionary with the GE data docs site configuration.
Source code in zenml/integrations/great_expectations/data_validators/ge_data_validator.py
def get_data_docs_config(
self, prefix: str, local: bool = False
) -> Dict[str, Any]:
"""Generate Great Expectations data docs configuration.
Args:
prefix: The path prefix for the ZenML data docs configuration
local: Whether the data docs site is local or remote.
Returns:
A dictionary with the GE data docs site configuration.
"""
if local:
store_backend = {
"class_name": "TupleFilesystemStoreBackend",
"base_directory": f"{self.root_directory}/{prefix}",
}
else:
store_backend = {
"module_name": ZenMLArtifactStoreBackend.__module__,
"class_name": ZenMLArtifactStoreBackend.__name__,
"prefix": f"{str(self.id)}/{prefix}",
}
return {
"class_name": "SiteBuilder",
"store_backend": store_backend,
"site_index_builder": {
"class_name": "DefaultSiteIndexBuilder",
},
}
get_store_config(self, class_name, prefix)
Generate a Great Expectations store configuration.
Parameters:
Name | Type | Description | Default
---|---|---|---
class_name | str | The store class name | required
prefix | str | The path prefix for the ZenML store configuration | required
Returns:
Type | Description
---|---
Dict[str, Any] | A dictionary with the GE store configuration.
Source code in zenml/integrations/great_expectations/data_validators/ge_data_validator.py
def get_store_config(self, class_name: str, prefix: str) -> Dict[str, Any]:
"""Generate a Great Expectations store configuration.
Args:
class_name: The store class name
prefix: The path prefix for the ZenML store configuration
Returns:
A dictionary with the GE store configuration.
"""
return {
"class_name": class_name,
"store_backend": {
"module_name": ZenMLArtifactStoreBackend.__module__,
"class_name": ZenMLArtifactStoreBackend.__name__,
"prefix": f"{str(self.id)}/{prefix}",
},
}
flavors
special
Great Expectations integration flavors.
great_expectations_data_validator_flavor
Great Expectations data validator flavor.
GreatExpectationsDataValidatorConfig (BaseDataValidatorConfig)
pydantic-model
Config for the Great Expectations data validator.
Attributes:
Name | Type | Description
---|---|---
context_root_dir | Optional[str] | location of an already initialized Great Expectations data context. If configured, the data validator will only be usable with local orchestrators.
context_config | Optional[Dict[str, Any]] | in-line Great Expectations data context configuration.
configure_zenml_stores | bool | if set, ZenML will automatically configure stores that use the Artifact Store as a backend. If neither `context_root_dir` nor `context_config` are set, this is the default behavior.
configure_local_docs | bool | configure a local data docs site where Great Expectations docs are generated and can be visualized locally.
Source code in zenml/integrations/great_expectations/flavors/great_expectations_data_validator_flavor.py
class GreatExpectationsDataValidatorConfig(BaseDataValidatorConfig):
"""Config for the Great Expectations data validator.
Attributes:
context_root_dir: location of an already initialized Great Expectations
data context. If configured, the data validator will only be usable
with local orchestrators.
context_config: in-line Great Expectations data context configuration.
configure_zenml_stores: if set, ZenML will automatically configure
stores that use the Artifact Store as a backend. If neither
`context_root_dir` nor `context_config` are set, this is the default
behavior.
configure_local_docs: configure a local data docs site where Great
Expectations docs are generated and can be visualized locally.
"""
context_root_dir: Optional[str] = None
context_config: Optional[Dict[str, Any]] = None
configure_zenml_stores: bool = False
configure_local_docs: bool = True
@validator("context_root_dir")
def _ensure_valid_context_root_dir(
cls, context_root_dir: Optional[str] = None
) -> Optional[str]:
"""Ensures that the root directory is an absolute path and points to an existing path.
Args:
context_root_dir: The context_root_dir value to validate.
Returns:
The context_root_dir if it is valid.
Raises:
ValueError: If the context_root_dir is not valid.
"""
if context_root_dir:
context_root_dir = os.path.abspath(context_root_dir)
if not fileio.exists(context_root_dir):
raise ValueError(
f"The Great Expectations context_root_dir value doesn't "
f"point to an existing data context path: {context_root_dir}"
)
return context_root_dir
@property
def is_local(self) -> bool:
"""Checks if this stack component is running locally.
This designation is used to determine if the stack component can be
shared with other users or if it is only usable on the local host.
Returns:
True if this config is for a local component, False otherwise.
"""
# If an existing local GE data context is used, it is
# interpreted as a local path that needs to be accessible in
# all runtime environments.
return self.context_root_dir is not None
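A brief sketch of the two main ways this config might be instantiated; the local path is an illustrative assumption and must point to an existing data context, or the validator below raises a ValueError:
from zenml.integrations.great_expectations.flavors import (
    GreatExpectationsDataValidatorConfig,
)

# Default: ZenML configures GE stores backed by the active Artifact Store.
managed = GreatExpectationsDataValidatorConfig()
assert not managed.is_local

# Alternative: reuse an already initialized local data context (example path).
local = GreatExpectationsDataValidatorConfig(
    context_root_dir="/path/to/great_expectations",
)
assert local.is_local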
is_local: bool
property
readonly
Checks if this stack component is running locally.
This designation is used to determine if the stack component can be shared with other users or if it is only usable on the local host.
Returns:
Type | Description
---|---
bool | True if this config is for a local component, False otherwise.
GreatExpectationsDataValidatorFlavor (BaseDataValidatorFlavor)
Great Expectations data validator flavor.
Source code in zenml/integrations/great_expectations/flavors/great_expectations_data_validator_flavor.py
class GreatExpectationsDataValidatorFlavor(BaseDataValidatorFlavor):
"""Great Expectations data validator flavor."""
@property
def name(self) -> str:
"""Name of the flavor.
Returns:
The name of the flavor.
"""
return GREAT_EXPECTATIONS_DATA_VALIDATOR_FLAVOR
@property
def config_class(self) -> Type[GreatExpectationsDataValidatorConfig]:
"""Returns `GreatExpectationsDataValidatorConfig` config class.
Returns:
The config class.
"""
return GreatExpectationsDataValidatorConfig
@property
def implementation_class(self) -> Type["GreatExpectationsDataValidator"]:
"""Implementation class for this flavor.
Returns:
The implementation class.
"""
from zenml.integrations.great_expectations.data_validators import (
GreatExpectationsDataValidator,
)
return GreatExpectationsDataValidator
config_class: Type[zenml.integrations.great_expectations.flavors.great_expectations_data_validator_flavor.GreatExpectationsDataValidatorConfig]
property
readonly
Returns the `GreatExpectationsDataValidatorConfig` config class.
Returns:
Type | Description
---|---
Type[zenml.integrations.great_expectations.flavors.great_expectations_data_validator_flavor.GreatExpectationsDataValidatorConfig] | The config class.
implementation_class: Type[GreatExpectationsDataValidator]
property
readonly
Implementation class for this flavor.
Returns:
Type | Description
---|---
Type[GreatExpectationsDataValidator] | The implementation class.
name: str
property
readonly
Name of the flavor.
Returns:
Type | Description
---|---
str | The name of the flavor.
ge_store_backend
Great Expectations store plugin for ZenML.
ZenMLArtifactStoreBackend (TupleStoreBackend)
Great Expectations store backend that uses the active ZenML Artifact Store as a store.
Source code in zenml/integrations/great_expectations/ge_store_backend.py
class ZenMLArtifactStoreBackend(TupleStoreBackend): # type: ignore[misc]
"""Great Expectations store backend that uses the active ZenML Artifact Store as a store."""
def __init__(
self,
prefix: str = "",
**kwargs: Any,
) -> None:
"""Create a Great Expectations ZenML store backend instance.
Args:
prefix: Subpath prefix to use for this store backend.
kwargs: Additional keyword arguments passed by the Great Expectations
core. These are transparently passed to the `TupleStoreBackend`
constructor.
"""
super().__init__(**kwargs)
client = Client(skip_client_check=True) # type: ignore[call-arg]
artifact_store = client.active_stack.artifact_store
self.root_path = os.path.join(artifact_store.path, "great_expectations")
# extract the protocol used in the artifact store root path
protocols = [
scheme
for scheme in artifact_store.config.SUPPORTED_SCHEMES
if self.root_path.startswith(scheme)
]
if protocols:
self.proto = protocols[0]
else:
self.proto = ""
if prefix:
if self.platform_specific_separator:
prefix = prefix.strip(os.sep)
prefix = prefix.strip("/")
self.prefix = prefix
# Initialize with store_backend_id if not part of an HTMLSiteStore
if not self._suppress_store_backend_id:
_ = self.store_backend_id
self._config = {
"prefix": prefix,
"module_name": self.__class__.__module__,
"class_name": self.__class__.__name__,
}
self._config.update(kwargs)
filter_properties_dict(
properties=self._config, clean_falsy=True, inplace=True
)
def _build_object_path(
self, key: Tuple[str, ...], is_prefix: bool = False
) -> str:
"""Build a filepath corresponding to an object key.
Args:
key: Great Expectation object key.
is_prefix: If True, the key will be interpreted as a prefix instead
of a full key identifier.
Returns:
The file path pointing to where the object is stored.
"""
if not isinstance(key, tuple):
key = key.to_tuple() # type: ignore[attr-defined]
if not is_prefix:
object_relative_path = self._convert_key_to_filepath(key)
elif key:
object_relative_path = os.path.join(*key)
else:
object_relative_path = ""
if self.prefix:
object_key = os.path.join(self.prefix, object_relative_path)
else:
object_key = object_relative_path
return os.path.join(self.root_path, object_key)
def _get(self, key: Tuple[str, ...]) -> str:
"""Get the value of an object from the store.
Args:
key: object key identifier.
Raises:
InvalidKeyError: if the key doesn't point to an existing object.
Returns:
str: the object's contents
"""
filepath: str = self._build_object_path(key)
if fileio.exists(filepath):
contents = io_utils.read_file_contents_as_string(filepath).rstrip(
"\n"
)
else:
raise InvalidKeyError(
f"Unable to retrieve object from {self.__class__.__name__} with "
f"the following Key: {str(filepath)}"
)
return contents
def _set(self, key: Tuple[str, ...], value: str, **kwargs: Any) -> str:
"""Set the value of an object in the store.
Args:
key: object key identifier.
value: object value to set.
kwargs: additional keyword arguments (ignored).
Returns:
The file path where the object was stored.
"""
filepath: str = self._build_object_path(key)
if not io_utils.is_remote(filepath):
parent_dir = str(Path(filepath).parent)
os.makedirs(parent_dir, exist_ok=True)
with fileio.open(filepath, "wb") as outfile:
if isinstance(value, str):
outfile.write(value.encode("utf-8"))
else:
outfile.write(value)
return filepath
def _move(
self,
source_key: Tuple[str, ...],
dest_key: Tuple[str, ...],
**kwargs: Any,
) -> None:
"""Associate an object with a different key in the store.
Args:
source_key: current object key identifier.
dest_key: new object key identifier.
kwargs: additional keyword arguments (ignored).
"""
source_path = self._build_object_path(source_key)
dest_path = self._build_object_path(dest_key)
if fileio.exists(source_path):
if not io_utils.is_remote(dest_path):
parent_dir = str(Path(dest_path).parent)
os.makedirs(parent_dir, exist_ok=True)
fileio.rename(source_path, dest_path, overwrite=True)
def list_keys(self, prefix: Tuple[str, ...] = ()) -> List[Tuple[str, ...]]:
"""List the keys of all objects identified by a partial key.
Args:
prefix: partial object key identifier.
Returns:
List of keys identifying all objects present in the store that
match the input partial key.
"""
key_list = []
list_path = self._build_object_path(prefix, is_prefix=True)
root_path = self._build_object_path(tuple(), is_prefix=True)
for root, dirs, files in fileio.walk(list_path):
for file_ in files:
filepath = os.path.relpath(
os.path.join(str(root), str(file_)), root_path
)
if self.filepath_prefix and not filepath.startswith(
self.filepath_prefix
):
continue
elif self.filepath_suffix and not filepath.endswith(
self.filepath_suffix
):
continue
key = self._convert_filepath_to_key(filepath)
if key and not self.is_ignored_key(key):
key_list.append(key)
return key_list
def remove_key(self, key: Tuple[str, ...]) -> bool:
"""Delete an object from the store.
Args:
key: object key identifier.
Returns:
True if the object existed in the store and was removed, otherwise
False.
"""
filepath: str = self._build_object_path(key)
if fileio.exists(filepath):
fileio.remove(filepath)
if not io_utils.is_remote(filepath):
parent_dir = str(Path(filepath).parent)
self.rrmdir(self.root_path, str(parent_dir))
return True
return False
def _has_key(self, key: Tuple[str, ...]) -> bool:
"""Check if an object is present in the store.
Args:
key: object key identifier.
Returns:
True if the object is present in the store, otherwise False.
"""
filepath: str = self._build_object_path(key)
result = fileio.exists(filepath)
return result
def get_url_for_key(
self, key: Tuple[str, ...], protocol: Optional[str] = None
) -> str:
"""Get the URL of an object in the store.
Args:
key: object key identifier.
protocol: optional protocol to use instead of the store protocol.
Returns:
The URL of the object in the store.
"""
filepath = self._build_object_path(key)
if not protocol and not io_utils.is_remote(filepath):
protocol = "file:"
if protocol:
filepath = filepath.replace(self.proto, f"{protocol}//", 1)
return filepath
def get_public_url_for_key(
self, key: str, protocol: Optional[str] = None
) -> str:
"""Get the public URL of an object in the store.
Args:
key: object key identifier.
protocol: optional protocol to use instead of the store protocol.
Returns:
The public URL where the object can be accessed.
Raises:
StoreBackendError: if a `base_public_path` attribute was not
configured for the store.
"""
if not self.base_public_path:
raise StoreBackendError(
f"Error: No base_public_path was configured! A public URL was "
f"requested but `base_public_path` was not configured for the "
f"{self.__class__.__name__}"
)
filepath = self._convert_key_to_filepath(key)
public_url = self.base_public_path + filepath.replace(self.proto, "")
return cast(str, public_url)
@staticmethod
def rrmdir(start_path: str, end_path: str) -> None:
"""Recursively removes empty dirs between start_path and end_path inclusive.
Args:
start_path: Directory to use as a starting point.
end_path: Directory to use as a destination point.
"""
while not os.listdir(end_path) and start_path != end_path:
os.rmdir(end_path)
end_path = os.path.dirname(end_path)
@property
def config(self) -> Dict[str, Any]:
"""Get the store configuration.
Returns:
The store configuration.
"""
return self._config
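For context, a hedged sketch of a Great Expectations store configuration that wires this backend in, mirroring the dict built by get_store_config() above; the prefix value is an illustrative assumption:
from zenml.integrations.great_expectations.ge_store_backend import (
    ZenMLArtifactStoreBackend,
)

store_config = {
    "class_name": "ExpectationsStore",
    "store_backend": {
        "module_name": ZenMLArtifactStoreBackend.__module__,
        "class_name": ZenMLArtifactStoreBackend.__name__,
        "prefix": "example-component-id/expectations",  # illustrative prefix
    },
}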
config: Dict[str, Any]
property
readonly
Get the store configuration.
Returns:
Type | Description
---|---
Dict[str, Any] | The store configuration.
__init__(self, prefix='', **kwargs)
special
Create a Great Expectations ZenML store backend instance.
Parameters:
Name | Type | Description | Default
---|---|---|---
prefix | str | Subpath prefix to use for this store backend. | ''
kwargs | Any | Additional keyword arguments passed by the Great Expectations core. These are transparently passed to the `TupleStoreBackend` constructor. | {}
Source code in zenml/integrations/great_expectations/ge_store_backend.py
def __init__(
self,
prefix: str = "",
**kwargs: Any,
) -> None:
"""Create a Great Expectations ZenML store backend instance.
Args:
prefix: Subpath prefix to use for this store backend.
kwargs: Additional keyword arguments passed by the Great Expectations
core. These are transparently passed to the `TupleStoreBackend`
constructor.
"""
super().__init__(**kwargs)
client = Client(skip_client_check=True) # type: ignore[call-arg]
artifact_store = client.active_stack.artifact_store
self.root_path = os.path.join(artifact_store.path, "great_expectations")
# extract the protocol used in the artifact store root path
protocols = [
scheme
for scheme in artifact_store.config.SUPPORTED_SCHEMES
if self.root_path.startswith(scheme)
]
if protocols:
self.proto = protocols[0]
else:
self.proto = ""
if prefix:
if self.platform_specific_separator:
prefix = prefix.strip(os.sep)
prefix = prefix.strip("/")
self.prefix = prefix
# Initialize with store_backend_id if not part of an HTMLSiteStore
if not self._suppress_store_backend_id:
_ = self.store_backend_id
self._config = {
"prefix": prefix,
"module_name": self.__class__.__module__,
"class_name": self.__class__.__name__,
}
self._config.update(kwargs)
filter_properties_dict(
properties=self._config, clean_falsy=True, inplace=True
)
get_public_url_for_key(self, key, protocol=None)
Get the public URL of an object in the store.
Parameters:
Name | Type | Description | Default
---|---|---|---
key | str | object key identifier. | required
protocol | Optional[str] | optional protocol to use instead of the store protocol. | None
Returns:
Type | Description
---|---
str | The public URL where the object can be accessed.
Exceptions:
Type | Description
---|---
StoreBackendError | if a `base_public_path` attribute was not configured for the store.
Source code in zenml/integrations/great_expectations/ge_store_backend.py
def get_public_url_for_key(
self, key: str, protocol: Optional[str] = None
) -> str:
"""Get the public URL of an object in the store.
Args:
key: object key identifier.
protocol: optional protocol to use instead of the store protocol.
Returns:
The public URL where the object can be accessed.
Raises:
StoreBackendError: if a `base_public_path` attribute was not
configured for the store.
"""
if not self.base_public_path:
raise StoreBackendError(
f"Error: No base_public_path was configured! A public URL was "
f"requested but `base_public_path` was not configured for the "
f"{self.__class__.__name__}"
)
filepath = self._convert_key_to_filepath(key)
public_url = self.base_public_path + filepath.replace(self.proto, "")
return cast(str, public_url)
get_url_for_key(self, key, protocol=None)
Get the URL of an object in the store.
Parameters:
Name | Type | Description | Default
---|---|---|---
key | Tuple[str, ...] | object key identifier. | required
protocol | Optional[str] | optional protocol to use instead of the store protocol. | None
Returns:
Type | Description
---|---
str | The URL of the object in the store.
Source code in zenml/integrations/great_expectations/ge_store_backend.py
def get_url_for_key(
self, key: Tuple[str, ...], protocol: Optional[str] = None
) -> str:
"""Get the URL of an object in the store.
Args:
key: object key identifier.
protocol: optional protocol to use instead of the store protocol.
Returns:
The URL of the object in the store.
"""
filepath = self._build_object_path(key)
if not protocol and not io_utils.is_remote(filepath):
protocol = "file:"
if protocol:
filepath = filepath.replace(self.proto, f"{protocol}//", 1)
return filepath
list_keys(self, prefix=())
List the keys of all objects identified by a partial key.
Parameters:
Name | Type | Description | Default
---|---|---|---
prefix | Tuple[str, ...] | partial object key identifier. | ()
Returns:
Type | Description
---|---
List[Tuple[str, ...]] | List of keys identifying all objects present in the store that match the input partial key.
Source code in zenml/integrations/great_expectations/ge_store_backend.py
def list_keys(self, prefix: Tuple[str, ...] = ()) -> List[Tuple[str, ...]]:
"""List the keys of all objects identified by a partial key.
Args:
prefix: partial object key identifier.
Returns:
List of keys identifying all objects present in the store that
match the input partial key.
"""
key_list = []
list_path = self._build_object_path(prefix, is_prefix=True)
root_path = self._build_object_path(tuple(), is_prefix=True)
for root, dirs, files in fileio.walk(list_path):
for file_ in files:
filepath = os.path.relpath(
os.path.join(str(root), str(file_)), root_path
)
if self.filepath_prefix and not filepath.startswith(
self.filepath_prefix
):
continue
elif self.filepath_suffix and not filepath.endswith(
self.filepath_suffix
):
continue
key = self._convert_filepath_to_key(filepath)
if key and not self.is_ignored_key(key):
key_list.append(key)
return key_list
remove_key(self, key)
Delete an object from the store.
Parameters:
Name | Type | Description | Default
---|---|---|---
key | Tuple[str, ...] | object key identifier. | required
Returns:
Type | Description
---|---
bool | True if the object existed in the store and was removed, otherwise False.
Source code in zenml/integrations/great_expectations/ge_store_backend.py
def remove_key(self, key: Tuple[str, ...]) -> bool:
"""Delete an object from the store.
Args:
key: object key identifier.
Returns:
True if the object existed in the store and was removed, otherwise
False.
"""
filepath: str = self._build_object_path(key)
if fileio.exists(filepath):
fileio.remove(filepath)
if not io_utils.is_remote(filepath):
parent_dir = str(Path(filepath).parent)
self.rrmdir(self.root_path, str(parent_dir))
return True
return False
rrmdir(start_path, end_path)
staticmethod
Recursively removes empty dirs between start_path and end_path inclusive.
Parameters:
Name | Type | Description | Default
---|---|---|---
start_path | str | Directory to use as a starting point. | required
end_path | str | Directory to use as a destination point. | required
Source code in zenml/integrations/great_expectations/ge_store_backend.py
@staticmethod
def rrmdir(start_path: str, end_path: str) -> None:
"""Recursively removes empty dirs between start_path and end_path inclusive.
Args:
start_path: Directory to use as a starting point.
end_path: Directory to use as a destination point.
"""
while not os.listdir(end_path) and start_path != end_path:
os.rmdir(end_path)
end_path = os.path.dirname(end_path)
materializers
special
Materializers for Great Expectation serializable objects.
ge_materializer
Implementation of the Great Expectations materializers.
GreatExpectationsMaterializer (BaseMaterializer)
Materializer to read/write Great Expectation objects.
Source code in zenml/integrations/great_expectations/materializers/ge_materializer.py
class GreatExpectationsMaterializer(BaseMaterializer):
"""Materializer to read/write Great Expectation objects."""
ASSOCIATED_TYPES = (
ExpectationSuite,
CheckpointResult,
)
ASSOCIATED_ARTIFACT_TYPES = (DataAnalysisArtifact,)
@staticmethod
def preprocess_checkpoint_result_dict(
artifact_dict: Dict[str, Any]
) -> None:
"""Pre-processes a GE checkpoint dict before it is used to de-serialize a GE CheckpointResult object.
The GE CheckpointResult object is not fully de-serializable
due to some missing code in the GE codebase. We need to compensate
for this by manually converting some of the attributes to
their correct data types.
Args:
artifact_dict: A dict containing the GE checkpoint result.
"""
def preprocess_run_result(key: str, value: Any) -> Any:
if key == "validation_result":
return ExpectationSuiteValidationResult(**value)
return value
artifact_dict["checkpoint_config"] = CheckpointConfig(
**artifact_dict["checkpoint_config"]
)
validation_dict = {}
for result_ident, results in artifact_dict["run_results"].items():
validation_ident = (
ValidationResultIdentifier.from_fixed_length_tuple(
result_ident.split("::")[1].split("/")
)
)
validation_results = {
result_name: preprocess_run_result(result_name, result)
for result_name, result in results.items()
}
validation_dict[validation_ident] = validation_results
artifact_dict["run_results"] = validation_dict
def handle_input(self, data_type: Type[Any]) -> SerializableDictDot:
"""Reads and returns a Great Expectations object.
Args:
data_type: The type of the data to read.
Returns:
A loaded Great Expectations object.
"""
super().handle_input(data_type)
filepath = os.path.join(self.artifact.uri, ARTIFACT_FILENAME)
artifact_dict = yaml_utils.read_json(filepath)
data_type = import_class_by_path(artifact_dict.pop("data_type"))
if data_type is CheckpointResult:
self.preprocess_checkpoint_result_dict(artifact_dict)
return data_type(**artifact_dict)
def handle_return(self, obj: SerializableDictDot) -> None:
"""Writes a Great Expectations object.
Args:
obj: A Great Expectations object.
"""
super().handle_return(obj)
filepath = os.path.join(self.artifact.uri, ARTIFACT_FILENAME)
artifact_dict = obj.to_json_dict()
artifact_type = type(obj)
artifact_dict[
"data_type"
] = f"{artifact_type.__module__}.{artifact_type.__name__}"
yaml_utils.write_json(filepath, artifact_dict)
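This materializer is picked up automatically whenever a step returns one of the associated types; a minimal sketch (the step body is illustrative, not part of the API above):
from great_expectations.core import ExpectationSuite
from zenml.steps import step

@step
def make_suite() -> ExpectationSuite:
    # The GreatExpectationsMaterializer serializes this return value to JSON
    # in the artifact store; no manual wiring is needed.
    return ExpectationSuite(expectation_suite_name="example_suite")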
handle_input(self, data_type)
Reads and returns a Great Expectations object.
Parameters:
Name | Type | Description | Default
---|---|---|---
data_type | Type[Any] | The type of the data to read. | required
Returns:
Type | Description
---|---
SerializableDictDot | A loaded Great Expectations object.
Source code in zenml/integrations/great_expectations/materializers/ge_materializer.py
def handle_input(self, data_type: Type[Any]) -> SerializableDictDot:
"""Reads and returns a Great Expectations object.
Args:
data_type: The type of the data to read.
Returns:
A loaded Great Expectations object.
"""
super().handle_input(data_type)
filepath = os.path.join(self.artifact.uri, ARTIFACT_FILENAME)
artifact_dict = yaml_utils.read_json(filepath)
data_type = import_class_by_path(artifact_dict.pop("data_type"))
if data_type is CheckpointResult:
self.preprocess_checkpoint_result_dict(artifact_dict)
return data_type(**artifact_dict)
handle_return(self, obj)
Writes a Great Expectations object.
Parameters:
Name | Type | Description | Default
---|---|---|---
obj | SerializableDictDot | A Great Expectations object. | required
Source code in zenml/integrations/great_expectations/materializers/ge_materializer.py
def handle_return(self, obj: SerializableDictDot) -> None:
"""Writes a Great Expectations object.
Args:
obj: A Great Expectations object.
"""
super().handle_return(obj)
filepath = os.path.join(self.artifact.uri, ARTIFACT_FILENAME)
artifact_dict = obj.to_json_dict()
artifact_type = type(obj)
artifact_dict[
"data_type"
] = f"{artifact_type.__module__}.{artifact_type.__name__}"
yaml_utils.write_json(filepath, artifact_dict)
preprocess_checkpoint_result_dict(artifact_dict)
staticmethod
Pre-processes a GE checkpoint dict before it is used to de-serialize a GE CheckpointResult object.
The GE CheckpointResult object is not fully de-serializable due to some missing code in the GE codebase. We need to compensate for this by manually converting some of the attributes to their correct data types.
Parameters:
Name | Type | Description | Default
---|---|---|---
artifact_dict | Dict[str, Any] | A dict containing the GE checkpoint result. | required
Source code in zenml/integrations/great_expectations/materializers/ge_materializer.py
@staticmethod
def preprocess_checkpoint_result_dict(
artifact_dict: Dict[str, Any]
) -> None:
"""Pre-processes a GE checkpoint dict before it is used to de-serialize a GE CheckpointResult object.
The GE CheckpointResult object is not fully de-serializable
due to some missing code in the GE codebase. We need to compensate
for this by manually converting some of the attributes to
their correct data types.
Args:
artifact_dict: A dict containing the GE checkpoint result.
"""
def preprocess_run_result(key: str, value: Any) -> Any:
if key == "validation_result":
return ExpectationSuiteValidationResult(**value)
return value
artifact_dict["checkpoint_config"] = CheckpointConfig(
**artifact_dict["checkpoint_config"]
)
validation_dict = {}
for result_ident, results in artifact_dict["run_results"].items():
validation_ident = (
ValidationResultIdentifier.from_fixed_length_tuple(
result_ident.split("::")[1].split("/")
)
)
validation_results = {
result_name: preprocess_run_result(result_name, result)
for result_name, result in results.items()
}
validation_dict[validation_ident] = validation_results
artifact_dict["run_results"] = validation_dict
steps
special
Great Expectations data profiling and validation standard steps.
ge_profiler
Great Expectations data profiling standard step.
GreatExpectationsProfilerParameters (BaseParameters)
pydantic-model
Parameters class for a Great Expectations profiler step.
Attributes:
Name | Type | Description
---|---|---
expectation_suite_name | str | The name of the expectation suite to create or update.
data_asset_name | Optional[str] | The name of the data asset to run the expectation suite on.
profiler_kwargs | Optional[Dict[str, Any]] | A dictionary of keyword arguments to pass to the profiler.
overwrite_existing_suite | bool | Whether to overwrite an existing expectation suite.
Source code in zenml/integrations/great_expectations/steps/ge_profiler.py
class GreatExpectationsProfilerParameters(BaseParameters):
"""Parameters class for a Great Expectations profiler step.
Attributes:
expectation_suite_name: The name of the expectation suite to create
or update.
data_asset_name: The name of the data asset to run the expectation suite on.
profiler_kwargs: A dictionary of keyword arguments to pass to the profiler.
overwrite_existing_suite: Whether to overwrite an existing expectation suite.
"""
expectation_suite_name: str
data_asset_name: Optional[str] = None
profiler_kwargs: Optional[Dict[str, Any]] = Field(default_factory=dict)
overwrite_existing_suite: bool = True
GreatExpectationsProfilerStep (BaseStep)
Standard Great Expectations profiling step implementation.
Use this standard Great Expectations profiling step to build an Expectation Suite automatically by running a UserConfigurableProfiler on an input dataset as covered in the official GE documentation.
Source code in zenml/integrations/great_expectations/steps/ge_profiler.py
class GreatExpectationsProfilerStep(BaseStep):
"""Standard Great Expectations profiling step implementation.
Use this standard Great Expectations profiling step to build an Expectation
Suite automatically by running a UserConfigurableProfiler on an input
dataset [as covered in the official GE documentation](https://docs.greatexpectations.io/docs/guides/expectations/how_to_create_and_edit_expectations_with_a_profiler).
"""
def entrypoint( # type: ignore[override]
self,
dataset: pd.DataFrame,
params: GreatExpectationsProfilerParameters,
) -> ExpectationSuite:
"""Standard Great Expectations data profiling step entrypoint.
Args:
dataset: The dataset from which the expectation suite will be inferred.
params: The parameters for the step.
Returns:
The generated Great Expectations suite.
"""
data_validator = (
GreatExpectationsDataValidator.get_active_data_validator()
)
return data_validator.data_profiling(
dataset,
expectation_suite_name=params.expectation_suite_name,
data_asset_name=params.data_asset_name,
profiler_kwargs=params.profiler_kwargs,
overwrite_existing_suite=params.overwrite_existing_suite,
)
PARAMETERS_CLASS (BaseParameters)
pydantic-model
Parameters class for a Great Expectations profiler step.
Attributes:
Name | Type | Description
---|---|---
expectation_suite_name | str | The name of the expectation suite to create or update.
data_asset_name | Optional[str] | The name of the data asset to run the expectation suite on.
profiler_kwargs | Optional[Dict[str, Any]] | A dictionary of keyword arguments to pass to the profiler.
overwrite_existing_suite | bool | Whether to overwrite an existing expectation suite.
Source code in zenml/integrations/great_expectations/steps/ge_profiler.py
class GreatExpectationsProfilerParameters(BaseParameters):
"""Parameters class for a Great Expectations profiler step.
Attributes:
expectation_suite_name: The name of the expectation suite to create
or update.
data_asset_name: The name of the data asset to run the expectation suite on.
profiler_kwargs: A dictionary of keyword arguments to pass to the profiler.
overwrite_existing_suite: Whether to overwrite an existing expectation suite.
"""
expectation_suite_name: str
data_asset_name: Optional[str] = None
profiler_kwargs: Optional[Dict[str, Any]] = Field(default_factory=dict)
overwrite_existing_suite: bool = True
entrypoint(self, dataset, params)
Standard Great Expectations data profiling step entrypoint.
Parameters:
Name | Type | Description | Default
---|---|---|---
dataset | DataFrame | The dataset from which the expectation suite will be inferred. | required
params | GreatExpectationsProfilerParameters | The parameters for the step. | required
Returns:
Type | Description
---|---
ExpectationSuite | The generated Great Expectations suite.
Source code in zenml/integrations/great_expectations/steps/ge_profiler.py
def entrypoint( # type: ignore[override]
self,
dataset: pd.DataFrame,
params: GreatExpectationsProfilerParameters,
) -> ExpectationSuite:
"""Standard Great Expectations data profiling step entrypoint.
Args:
dataset: The dataset from which the expectation suite will be inferred.
params: The parameters for the step.
Returns:
The generated Great Expectations suite.
"""
data_validator = (
GreatExpectationsDataValidator.get_active_data_validator()
)
return data_validator.data_profiling(
dataset,
expectation_suite_name=params.expectation_suite_name,
data_asset_name=params.data_asset_name,
profiler_kwargs=params.profiler_kwargs,
overwrite_existing_suite=params.overwrite_existing_suite,
)
great_expectations_profiler_step(step_name, params)
Shortcut function to create a new instance of the GreatExpectationsProfilerStep step.
The returned GreatExpectationsProfilerStep can be used in a pipeline to infer data validation rules from an input pd.DataFrame dataset and return them as an Expectation Suite. The Expectation Suite is also persisted in the Great Expectations expectation store.
Parameters:
Name | Type | Description | Default
---|---|---|---
step_name | str | The name of the step | required
params | GreatExpectationsProfilerParameters | The parameters for the step | required
Returns:
Type | Description
---|---
BaseStep | a GreatExpectationsProfilerStep step instance
Source code in zenml/integrations/great_expectations/steps/ge_profiler.py
def great_expectations_profiler_step(
step_name: str,
params: GreatExpectationsProfilerParameters,
) -> BaseStep:
"""Shortcut function to create a new instance of the GreatExpectationsProfilerStep step.
The returned GreatExpectationsProfilerStep can be used in a pipeline to
infer data validation rules from an input pd.DataFrame dataset and return
them as an Expectation Suite. The Expectation Suite is also persisted in the
Great Expectations expectation store.
Args:
step_name: The name of the step
params: The parameters for the step
Returns:
a GreatExpectationsProfilerStep step instance
"""
return clone_step(GreatExpectationsProfilerStep, step_name)(params=params)
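A short usage sketch, with assumed step, suite, and asset names (import paths assume the usual re-exports from the steps sub-module):
from zenml.integrations.great_expectations.steps import (
    GreatExpectationsProfilerParameters,
    great_expectations_profiler_step,
)

# The suite named below is created or updated in the expectation store.
profiler = great_expectations_profiler_step(
    step_name="profile_data",
    params=GreatExpectationsProfilerParameters(
        expectation_suite_name="example_suite",
        data_asset_name="example_asset",
    ),
)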
ge_validator
Great Expectations data validation standard step.
GreatExpectationsValidatorParameters (BaseParameters)
pydantic-model
Parameters class for a Great Expectations checkpoint step.
Attributes:
Name | Type | Description
---|---|---
expectation_suite_name | str | The name of the expectation suite to use to validate the dataset.
data_asset_name | Optional[str] | The name of the data asset to use to identify the dataset in the Great Expectations docs.
action_list | Optional[List[Dict[str, Any]]] | A list of additional Great Expectations actions to run after the validation check.
exit_on_error | bool | Set this flag to raise an error and exit the pipeline early if the validation fails.
Source code in zenml/integrations/great_expectations/steps/ge_validator.py
class GreatExpectationsValidatorParameters(BaseParameters):
"""Parameters class for a Great Expectations checkpoint step.
Attributes:
expectation_suite_name: The name of the expectation suite to use to
validate the dataset.
data_asset_name: The name of the data asset to use to identify the
dataset in the Great Expectations docs.
action_list: A list of additional Great Expectations actions to run
after the validation check.
exit_on_error: Set this flag to raise an error and exit the pipeline
early if the validation fails.
"""
expectation_suite_name: str
data_asset_name: Optional[str] = None
action_list: Optional[List[Dict[str, Any]]] = None
exit_on_error: bool = False
GreatExpectationsValidatorStep (BaseStep)
Standard Great Expectations data validation step implementation.
Use this standard Great Expectations data validation step to run an existing Expectation Suite on an input dataset as covered in the official GE documentation.
Source code in zenml/integrations/great_expectations/steps/ge_validator.py
class GreatExpectationsValidatorStep(BaseStep):
"""Standard Great Expectations data validation step implementation.
Use this standard Great Expectations data validation step to run an
existing Expectation Suite on an input dataset [as covered in the official GE documentation](https://docs.greatexpectations.io/docs/guides/validation/how_to_validate_data_by_running_a_checkpoint).
"""
def entrypoint( # type: ignore[override]
self,
dataset: pd.DataFrame,
condition: bool,
params: GreatExpectationsValidatorParameters,
) -> CheckpointResult:
"""Standard Great Expectations data validation step entrypoint.
Args:
dataset: The dataset to run the expectation suite on.
condition: This dummy argument can be used as a condition to enforce
that this step is only run after another step has completed. This
is useful for example if the Expectation Suite used to validate
the data is computed in a `GreatExpectationsProfilerStep` that
is part of the same pipeline.
params: The parameters for the step.
Returns:
The Great Expectations validation (checkpoint) result.
Raises:
RuntimeError: if the step is configured to exit on error and the
data validation failed.
"""
data_validator = (
GreatExpectationsDataValidator.get_active_data_validator()
)
results = data_validator.data_validation(
dataset,
expectation_suite_name=params.expectation_suite_name,
data_asset_name=params.data_asset_name,
action_list=params.action_list,
)
if params.exit_on_error and not results.success:
raise RuntimeError(
"The Great Expectations validation failed. Check "
"the logs or the Great Expectations data docs for more "
"information."
)
return results
PARAMETERS_CLASS (BaseParameters)
pydantic-model
Parameters class for a Great Expectations checkpoint step.
Attributes:
Name | Type | Description
---|---|---
expectation_suite_name | str | The name of the expectation suite to use to validate the dataset.
data_asset_name | Optional[str] | The name of the data asset to use to identify the dataset in the Great Expectations docs.
action_list | Optional[List[Dict[str, Any]]] | A list of additional Great Expectations actions to run after the validation check.
exit_on_error | bool | Set this flag to raise an error and exit the pipeline early if the validation fails.
Source code in zenml/integrations/great_expectations/steps/ge_validator.py
class GreatExpectationsValidatorParameters(BaseParameters):
"""Parameters class for a Great Expectations checkpoint step.
Attributes:
expectation_suite_name: The name of the expectation suite to use to
validate the dataset.
data_asset_name: The name of the data asset to use to identify the
dataset in the Great Expectations docs.
action_list: A list of additional Great Expectations actions to run
after the validation check.
exit_on_error: Set this flag to raise an error and exit the pipeline
early if the validation fails.
"""
expectation_suite_name: str
data_asset_name: Optional[str] = None
action_list: Optional[List[Dict[str, Any]]] = None
exit_on_error: bool = False
entrypoint(self, dataset, condition, params)
Standard Great Expectations data validation step entrypoint.
Parameters:
Name | Type | Description | Default
---|---|---|---
dataset | DataFrame | The dataset to run the expectation suite on. | required
condition | bool | This dummy argument can be used as a condition to enforce that this step is only run after another step has completed. This is useful for example if the Expectation Suite used to validate the data is computed in a `GreatExpectationsProfilerStep` that is part of the same pipeline. | required
params | GreatExpectationsValidatorParameters | The parameters for the step. | required
Returns:
Type | Description
---|---
CheckpointResult | The Great Expectations validation (checkpoint) result.
Exceptions:
Type | Description
---|---
RuntimeError | if the step is configured to exit on error and the data validation failed.
Source code in zenml/integrations/great_expectations/steps/ge_validator.py
def entrypoint( # type: ignore[override]
self,
dataset: pd.DataFrame,
condition: bool,
params: GreatExpectationsValidatorParameters,
) -> CheckpointResult:
"""Standard Great Expectations data validation step entrypoint.
Args:
dataset: The dataset to run the expectation suite on.
condition: This dummy argument can be used as a condition to enforce
that this step is only run after another step has completed. This
is useful for example if the Expectation Suite used to validate
the data is computed in a `GreatExpectationsProfilerStep` that
is part of the same pipeline.
params: The parameters for the step.
Returns:
The Great Expectations validation (checkpoint) result.
Raises:
RuntimeError: if the step is configured to exit on error and the
data validation failed.
"""
data_validator = (
GreatExpectationsDataValidator.get_active_data_validator()
)
results = data_validator.data_validation(
dataset,
expectation_suite_name=params.expectation_suite_name,
data_asset_name=params.data_asset_name,
action_list=params.action_list,
)
if params.exit_on_error and not results.success:
raise RuntimeError(
"The Great Expectations validation failed. Check "
"the logs or the Great Expectations data docs for more "
"information."
)
return results
great_expectations_validator_step(step_name, params)
Shortcut function to create a new instance of the GreatExpectationsValidatorStep step.
The returned GreatExpectationsValidatorStep can be used in a pipeline to validate an input pd.DataFrame dataset and return the result as a Great Expectations CheckpointResult object. The validation results are also persisted in the Great Expectations validation store.
Parameters:
Name | Type | Description | Default
---|---|---|---
step_name | str | The name of the step | required
params | GreatExpectationsValidatorParameters | The parameters for the step | required
Returns:
Type | Description
---|---
BaseStep | a GreatExpectationsValidatorStep step instance
Source code in zenml/integrations/great_expectations/steps/ge_validator.py
def great_expectations_validator_step(
step_name: str,
params: GreatExpectationsValidatorParameters,
) -> BaseStep:
"""Shortcut function to create a new instance of the GreatExpectationsValidatorStep step.
The returned GreatExpectationsValidatorStep can be used in a pipeline to
validate an input pd.DataFrame dataset and return the result as a Great
Expectations CheckpointResult object. The validation results are also
persisted in the Great Expectations validation store.
Args:
step_name: The name of the step
params: The parameters for the step
Returns:
a GreatExpectationsValidatorStep step instance
"""
return clone_step(GreatExpectationsValidatorStep, step_name)(params=params)
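A matching usage sketch, again with assumed names (import paths assume the usual re-exports from the steps sub-module):
from zenml.integrations.great_expectations.steps import (
    GreatExpectationsValidatorParameters,
    great_expectations_validator_step,
)

validator = great_expectations_validator_step(
    step_name="validate_data",
    params=GreatExpectationsValidatorParameters(
        expectation_suite_name="example_suite",
        exit_on_error=True,  # fail the pipeline run if validation fails
    ),
)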
utils
Great Expectations integration utility functions.
create_batch_request(context, dataset, data_asset_name)
Create a temporary runtime GE batch request from a dataset step artifact.
Parameters:
Name | Type | Description | Default
---|---|---|---
context | BaseDataContext | Great Expectations data context. | required
dataset | DataFrame | Input dataset. | required
data_asset_name | Optional[str] | Optional custom name for the data asset. | required
Returns:
Type | Description
---|---
RuntimeBatchRequest | A Great Expectations runtime batch request.
Source code in zenml/integrations/great_expectations/utils.py
def create_batch_request(
context: BaseDataContext,
dataset: pd.DataFrame,
data_asset_name: Optional[str],
) -> RuntimeBatchRequest:
"""Create a temporary runtime GE batch request from a dataset step artifact.
Args:
context: Great Expectations data context.
dataset: Input dataset.
data_asset_name: Optional custom name for the data asset.
Returns:
A Great Expectations runtime batch request.
"""
try:
# get pipeline name, step name and run id
step_env = cast(StepEnvironment, Environment()[STEP_ENVIRONMENT_NAME])
pipeline_name = step_env.pipeline_name
run_id = step_env.pipeline_run_id
step_name = step_env.step_name
except KeyError:
# if not running inside a pipeline step, use random values
pipeline_name = f"pipeline_{random_str(5)}"
run_id = f"pipeline_{random_str(5)}"
step_name = f"step_{random_str(5)}"
datasource_name = f"{run_id}_{step_name}"
data_connector_name = datasource_name
data_asset_name = data_asset_name or f"{pipeline_name}_{step_name}"
batch_identifier = "default"
datasource_config = {
"name": datasource_name,
"class_name": "Datasource",
"module_name": "great_expectations.datasource",
"execution_engine": {
"module_name": "great_expectations.execution_engine",
"class_name": "PandasExecutionEngine",
},
"data_connectors": {
data_connector_name: {
"class_name": "RuntimeDataConnector",
"batch_identifiers": [batch_identifier],
},
},
}
context.add_datasource(**datasource_config)
batch_request = RuntimeBatchRequest(
datasource_name=datasource_name,
data_connector_name=data_connector_name,
data_asset_name=data_asset_name,
runtime_parameters={"batch_data": dataset},
batch_identifiers={batch_identifier: batch_identifier},
)
return batch_request
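A hedged sketch of calling this helper outside a pipeline step (the DataFrame and asset name are illustrative); note that it registers a temporary datasource that callers are expected to clean up, as data_validation() does above:
import pandas as pd
from zenml.integrations.great_expectations.data_validators import (
    GreatExpectationsDataValidator,
)
from zenml.integrations.great_expectations.utils import create_batch_request

df = pd.DataFrame({"value": [1, 2, 3]})  # illustrative data
context = GreatExpectationsDataValidator.get_data_context()
batch_request = create_batch_request(context, df, data_asset_name="example_asset")
# ... run a checkpoint with the batch request, then clean up:
context.delete_datasource(batch_request.datasource_name)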
visualizers
special
Great Expectations visualizers for expectation suites and validation results.
ge_visualizer
Great Expectations visualizers for expectation suites and validation results.
GreatExpectationsVisualizer (BaseVisualizer)
The implementation of a Great Expectations Visualizer.
Source code in zenml/integrations/great_expectations/visualizers/ge_visualizer.py
class GreatExpectationsVisualizer(BaseVisualizer):
"""The implementation of a Great Expectations Visualizer."""
def visualize(self, object: StepView, *args: Any, **kwargs: Any) -> None:
"""Method to visualize a Great Expectations resource.
Args:
object: StepView fetched from run.get_step().
*args: Additional arguments.
**kwargs: Additional keyword arguments.
"""
for artifact_view in object.outputs.values():
# filter out anything but Great Expectations data analysis artifacts
if (
artifact_view.type == DataAnalysisArtifact.__name__
and artifact_view.data_type.startswith("great_expectations.")
):
artifact = artifact_view.read()
if isinstance(artifact, CheckpointResult):
result = cast(CheckpointResult, artifact)
identifier = next(iter(result.run_results.keys()))
else:
suite = cast(ExpectationSuite, artifact)
identifier = ExpectationSuiteIdentifier(
suite.expectation_suite_name
)
context = GreatExpectationsDataValidator.get_data_context()
context.open_data_docs(identifier)
visualize(self, object, *args, **kwargs)
Method to visualize a Great Expectations resource.
Parameters:
Name | Type | Description | Default
---|---|---|---
object | StepView | StepView fetched from run.get_step(). | required
*args | Any | Additional arguments. | ()
**kwargs | Any | Additional keyword arguments. | {}
Source code in zenml/integrations/great_expectations/visualizers/ge_visualizer.py
def visualize(self, object: StepView, *args: Any, **kwargs: Any) -> None:
"""Method to visualize a Great Expectations resource.
Args:
object: StepView fetched from run.get_step().
*args: Additional arguments.
**kwargs: Additional keyword arguments.
"""
for artifact_view in object.outputs.values():
# filter out anything but Great Expectations data analysis artifacts
if (
artifact_view.type == DataAnalysisArtifact.__name__
and artifact_view.data_type.startswith("great_expectations.")
):
artifact = artifact_view.read()
if isinstance(artifact, CheckpointResult):
result = cast(CheckpointResult, artifact)
identifier = next(iter(result.run_results.keys()))
else:
suite = cast(ExpectationSuite, artifact)
identifier = ExpectationSuiteIdentifier(
suite.expectation_suite_name
)
context = GreatExpectationsDataValidator.get_data_context()
context.open_data_docs(identifier)
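A usage sketch with hypothetical pipeline and step names; the exact post-execution accessors may differ across ZenML versions:
from zenml.integrations.great_expectations.visualizers import (
    GreatExpectationsVisualizer,
)
from zenml.post_execution import get_pipeline

pipeline = get_pipeline("validation_pipeline")  # hypothetical pipeline name
step_view = pipeline.runs[-1].get_step("validator")  # hypothetical step name
GreatExpectationsVisualizer().visualize(step_view)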
huggingface
special
Initialization of the Huggingface integration.
HuggingfaceIntegration (Integration)
Definition of Huggingface integration for ZenML.
Source code in zenml/integrations/huggingface/__init__.py
class HuggingfaceIntegration(Integration):
"""Definition of Huggingface integration for ZenML."""
NAME = HUGGINGFACE
REQUIREMENTS = ["transformers", "datasets"]
@classmethod
def activate(cls) -> None:
"""Activates the integration."""
from zenml.integrations.huggingface import materializers # noqa
activate()
classmethod
Activates the integration.
Source code in zenml/integrations/huggingface/__init__.py
@classmethod
def activate(cls) -> None:
"""Activates the integration."""
from zenml.integrations.huggingface import materializers # noqa
materializers
special
Initialization of Huggingface materializers.
huggingface_datasets_materializer
Implementation of the Huggingface datasets materializer.
HFDatasetMaterializer (BaseMaterializer)
Materializer to read data to and from huggingface datasets.
Source code in zenml/integrations/huggingface/materializers/huggingface_datasets_materializer.py
class HFDatasetMaterializer(BaseMaterializer):
"""Materializer to read data to and from huggingface datasets."""
ASSOCIATED_TYPES = (Dataset, DatasetDict)
ASSOCIATED_ARTIFACT_TYPES = (DataArtifact,)
def handle_input(self, data_type: Type[Any]) -> Dataset:
"""Reads Dataset.
Args:
data_type: The type of the dataset to read.
Returns:
The dataset read from the specified dir.
"""
super().handle_input(data_type)
return load_from_disk(
os.path.join(self.artifact.uri, DEFAULT_DATASET_DIR)
)
def handle_return(self, ds: Type[Any]) -> None:
"""Writes a Dataset to the specified dir.
Args:
ds: The Dataset to write.
"""
super().handle_return(ds)
temp_dir = TemporaryDirectory()
ds.save_to_disk(temp_dir.name)
io_utils.copy_dir(
temp_dir.name, os.path.join(self.artifact.uri, DEFAULT_DATASET_DIR)
)
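Since the materializer is registered for the Dataset and DatasetDict types, returning either from a step is enough to trigger it; a minimal sketch with illustrative data:
import pandas as pd
from datasets import Dataset
from zenml.steps import step

@step
def load_data() -> Dataset:
    # Persisted via save_to_disk and restored via load_from_disk by the
    # HFDatasetMaterializer; no manual serialization is needed.
    return Dataset.from_pandas(pd.DataFrame({"text": ["hello", "world"]}))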
handle_input(self, data_type)
Reads Dataset.
Parameters:
Name | Type | Description | Default
---|---|---|---
data_type | Type[Any] | The type of the dataset to read. | required
Returns:
Type | Description
---|---
Dataset | The dataset read from the specified dir.
Source code in zenml/integrations/huggingface/materializers/huggingface_datasets_materializer.py
def handle_input(self, data_type: Type[Any]) -> Dataset:
"""Reads Dataset.
Args:
data_type: The type of the dataset to read.
Returns:
The dataset read from the specified dir.
"""
super().handle_input(data_type)
return load_from_disk(
os.path.join(self.artifact.uri, DEFAULT_DATASET_DIR)
)
handle_return(self, ds)
Writes a Dataset to the specified dir.
Parameters:
Name | Type | Description | Default
---|---|---|---
ds | Type[Any] | The Dataset to write. | required
Source code in zenml/integrations/huggingface/materializers/huggingface_datasets_materializer.py
def handle_return(self, ds: Type[Any]) -> None:
"""Writes a Dataset to the specified dir.
Args:
ds: The Dataset to write.
"""
super().handle_return(ds)
temp_dir = TemporaryDirectory()
ds.save_to_disk(temp_dir.name)
io_utils.copy_dir(
temp_dir.name, os.path.join(self.artifact.uri, DEFAULT_DATASET_DIR)
)
huggingface_pt_model_materializer
Implementation of the Huggingface PyTorch model materializer.
HFPTModelMaterializer (BaseMaterializer)
Materializer to read torch model to and from huggingface pretrained model.
Source code in zenml/integrations/huggingface/materializers/huggingface_pt_model_materializer.py
class HFPTModelMaterializer(BaseMaterializer):
"""Materializer to read torch model to and from huggingface pretrained model."""
ASSOCIATED_TYPES = (PreTrainedModel,)
ASSOCIATED_ARTIFACT_TYPES = (ModelArtifact,)
def handle_input(self, data_type: Type[Any]) -> PreTrainedModel:
"""Reads HFModel.
Args:
data_type: The type of the model to read.
Returns:
The model read from the specified dir.
"""
super().handle_input(data_type)
config = AutoConfig.from_pretrained(
os.path.join(self.artifact.uri, DEFAULT_PT_MODEL_DIR)
)
architecture = config.architectures[0]
model_cls = getattr(
importlib.import_module("transformers"), architecture
)
return model_cls.from_pretrained(
os.path.join(self.artifact.uri, DEFAULT_PT_MODEL_DIR)
)
def handle_return(self, model: Type[Any]) -> None:
"""Writes a Model to the specified dir.
Args:
model: The Torch Model to write.
"""
super().handle_return(model)
temp_dir = TemporaryDirectory()
model.save_pretrained(temp_dir.name)
io_utils.copy_dir(
temp_dir.name,
os.path.join(self.artifact.uri, DEFAULT_PT_MODEL_DIR),
)
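Likewise for PyTorch models: returning a PreTrainedModel from a step triggers this materializer. The checkpoint name below is an illustrative assumption and requires network access to download:
from transformers import AutoModelForSequenceClassification, PreTrainedModel
from zenml.steps import step

@step
def load_model() -> PreTrainedModel:
    # Stored via save_pretrained; restored using the architecture recorded
    # in the model config, as shown in handle_input above.
    return AutoModelForSequenceClassification.from_pretrained(
        "distilbert-base-uncased"
    )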
handle_input(self, data_type)
Reads HFModel.
Parameters:
Name | Type | Description | Default
---|---|---|---
data_type | Type[Any] | The type of the model to read. | required
Returns:
Type | Description
---|---
PreTrainedModel | The model read from the specified dir.
Source code in zenml/integrations/huggingface/materializers/huggingface_pt_model_materializer.py
def handle_input(self, data_type: Type[Any]) -> PreTrainedModel:
"""Reads HFModel.
Args:
data_type: The type of the model to read.
Returns:
The model read from the specified dir.
"""
super().handle_input(data_type)
config = AutoConfig.from_pretrained(
os.path.join(self.artifact.uri, DEFAULT_PT_MODEL_DIR)
)
architecture = config.architectures[0]
model_cls = getattr(
importlib.import_module("transformers"), architecture
)
return model_cls.from_pretrained(
os.path.join(self.artifact.uri, DEFAULT_PT_MODEL_DIR)
)
handle_return(self, model)
Writes a Model to the specified dir.
Parameters:
Name | Type | Description | Default
---|---|---|---
model | Type[Any] | The Torch Model to write. | required
Source code in zenml/integrations/huggingface/materializers/huggingface_pt_model_materializer.py
def handle_return(self, model: Type[Any]) -> None:
"""Writes a Model to the specified dir.
Args:
model: The Torch Model to write.
"""
super().handle_return(model)
temp_dir = TemporaryDirectory()
model.save_pretrained(temp_dir.name)
io_utils.copy_dir(
temp_dir.name,
os.path.join(self.artifact.uri, DEFAULT_PT_MODEL_DIR),
)
huggingface_tf_model_materializer
Implementation of the Huggingface TF model materializer.
HFTFModelMaterializer (BaseMaterializer)
Materializer to read Tensorflow model to and from huggingface pretrained model.
Source code in zenml/integrations/huggingface/materializers/huggingface_tf_model_materializer.py
class HFTFModelMaterializer(BaseMaterializer):
"""Materializer to read Tensorflow model to and from huggingface pretrained model."""
ASSOCIATED_TYPES = (TFPreTrainedModel,)
ASSOCIATED_ARTIFACT_TYPES = (ModelArtifact,)
def handle_input(self, data_type: Type[Any]) -> TFPreTrainedModel:
"""Reads HFModel.
Args:
data_type: The type of the model to read.
Returns:
The model read from the specified dir.
"""
super().handle_input(data_type)
config = AutoConfig.from_pretrained(
os.path.join(self.artifact.uri, DEFAULT_TF_MODEL_DIR)
)
architecture = "TF" + config.architectures[0]
model_cls = getattr(
importlib.import_module("transformers"), architecture
)
return model_cls.from_pretrained(
os.path.join(self.artifact.uri, DEFAULT_TF_MODEL_DIR)
)
def handle_return(self, model: Type[Any]) -> None:
"""Writes a Model to the specified dir.
Args:
model: The TF Model to write.
"""
super().handle_return(model)
temp_dir = TemporaryDirectory()
model.save_pretrained(temp_dir.name)
io_utils.copy_dir(
temp_dir.name,
os.path.join(self.artifact.uri, DEFAULT_TF_MODEL_DIR),
)
handle_input(self, data_type)
Reads HFModel.
Parameters:
Name | Type | Description | Default
---|---|---|---
data_type | Type[Any] | The type of the model to read. | required
Returns:
Type | Description
---|---
TFPreTrainedModel | The model read from the specified dir.
Source code in zenml/integrations/huggingface/materializers/huggingface_tf_model_materializer.py
def handle_input(self, data_type: Type[Any]) -> TFPreTrainedModel:
"""Reads HFModel.
Args:
data_type: The type of the model to read.
Returns:
The model read from the specified dir.
"""
super().handle_input(data_type)
config = AutoConfig.from_pretrained(
os.path.join(self.artifact.uri, DEFAULT_TF_MODEL_DIR)
)
architecture = "TF" + config.architectures[0]
model_cls = getattr(
importlib.import_module("transformers"), architecture
)
return model_cls.from_pretrained(
os.path.join(self.artifact.uri, DEFAULT_TF_MODEL_DIR)
)
handle_return(self, model)
Writes a Model to the specified dir.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model |
Type[Any] |
The TF Model to write. |
required |
Source code in zenml/integrations/huggingface/materializers/huggingface_tf_model_materializer.py
def handle_return(self, model: Type[Any]) -> None:
"""Writes a Model to the specified dir.
Args:
model: The TF Model to write.
"""
super().handle_return(model)
temp_dir = TemporaryDirectory()
model.save_pretrained(temp_dir.name)
io_utils.copy_dir(
temp_dir.name,
os.path.join(self.artifact.uri, DEFAULT_TF_MODEL_DIR),
)
huggingface_tokenizer_materializer
Implementation of the Huggingface tokenizer materializer.
HFTokenizerMaterializer (BaseMaterializer)
Materializer to read and write tokenizers in the Hugging Face pretrained tokenizer format.
Source code in zenml/integrations/huggingface/materializers/huggingface_tokenizer_materializer.py
class HFTokenizerMaterializer(BaseMaterializer):
"""Materializer to read tokenizer to and from huggingface tokenizer."""
ASSOCIATED_TYPES = (PreTrainedTokenizerBase,)
ASSOCIATED_ARTIFACT_TYPES = (ModelArtifact,)
def handle_input(self, data_type: Type[Any]) -> PreTrainedTokenizerBase:
"""Reads Tokenizer.
Args:
data_type: The type of the tokenizer to read.
Returns:
The tokenizer read from the specified dir.
"""
super().handle_input(data_type)
return AutoTokenizer.from_pretrained(
os.path.join(self.artifact.uri, DEFAULT_TOKENIZER_DIR)
)
def handle_return(self, tokenizer: Type[Any]) -> None:
"""Writes a Tokenizer to the specified dir.
Args:
tokenizer: The HFTokenizer to write.
"""
super().handle_return(tokenizer)
temp_dir = TemporaryDirectory()
tokenizer.save_pretrained(temp_dir.name)
io_utils.copy_dir(
temp_dir.name,
os.path.join(self.artifact.uri, DEFAULT_TOKENIZER_DIR),
)
handle_input(self, data_type)
Reads Tokenizer.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
data_type |
Type[Any] |
The type of the tokenizer to read. |
required |
Returns:
Type | Description |
---|---|
PreTrainedTokenizerBase |
The tokenizer read from the specified dir. |
Source code in zenml/integrations/huggingface/materializers/huggingface_tokenizer_materializer.py
def handle_input(self, data_type: Type[Any]) -> PreTrainedTokenizerBase:
"""Reads Tokenizer.
Args:
data_type: The type of the tokenizer to read.
Returns:
The tokenizer read from the specified dir.
"""
super().handle_input(data_type)
return AutoTokenizer.from_pretrained(
os.path.join(self.artifact.uri, DEFAULT_TOKENIZER_DIR)
)
handle_return(self, tokenizer)
Writes a Tokenizer to the specified dir.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
tokenizer |
Type[Any] |
The HFTokenizer to write. |
required |
Source code in zenml/integrations/huggingface/materializers/huggingface_tokenizer_materializer.py
def handle_return(self, tokenizer: Type[Any]) -> None:
"""Writes a Tokenizer to the specified dir.
Args:
tokenizer: The HFTokenizer to write.
"""
super().handle_return(tokenizer)
temp_dir = TemporaryDirectory()
tokenizer.save_pretrained(temp_dir.name)
io_utils.copy_dir(
temp_dir.name,
os.path.join(self.artifact.uri, DEFAULT_TOKENIZER_DIR),
)
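The same pattern applies here; a hedged sketch with an illustrative checkpoint name:
from transformers import AutoTokenizer, PreTrainedTokenizerBase
from zenml.steps import step


@step
def load_tokenizer() -> PreTrainedTokenizerBase:
    # Returning a PreTrainedTokenizerBase routes the artifact through
    # the tokenizer materializer documented above.
    return AutoTokenizer.from_pretrained("distilbert-base-uncased")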
integration
Base and meta classes for ZenML integrations.
Integration
Base class for integration in ZenML.
Source code in zenml/integrations/integration.py
class Integration(metaclass=IntegrationMeta):
"""Base class for integration in ZenML."""
NAME = "base_integration"
REQUIREMENTS: List[str] = []
SYSTEM_REQUIREMENTS: Dict[str, str] = {}
@classmethod
def check_installation(cls) -> bool:
"""Method to check whether the required packages are installed.
Returns:
True if all required packages are installed, False otherwise.
"""
try:
for requirement, command in cls.SYSTEM_REQUIREMENTS.items():
result = shutil.which(command)
if result is None:
logger.debug(
"Unable to find the required packages for %s on your "
"system. Please install the packages on your system "
"and try again.",
requirement,
)
return False
for r in cls.REQUIREMENTS:
pkg_resources.get_distribution(r)
logger.debug(
f"Integration {cls.NAME} is installed correctly with "
f"requirements {cls.REQUIREMENTS}."
)
return True
except pkg_resources.DistributionNotFound as e:
logger.debug(
f"Unable to find required package '{e.req}' for "
f"integration {cls.NAME}."
)
return False
except pkg_resources.VersionConflict as e:
logger.debug(
f"VersionConflict error when loading installation {cls.NAME}: "
f"{str(e)}"
)
return False
@classmethod
def activate(cls) -> None:
"""Abstract method to activate the integration."""
@classmethod
def flavors(cls) -> List[Type[Flavor]]:
"""Abstract method to declare new stack component flavors."""
activate()
classmethod
Abstract method to activate the integration.
Source code in zenml/integrations/integration.py
@classmethod
def activate(cls) -> None:
"""Abstract method to activate the integration."""
check_installation()
classmethod
Method to check whether the required packages are installed.
Returns:
Type | Description |
---|---|
bool |
True if all required packages are installed, False otherwise. |
Source code in zenml/integrations/integration.py
@classmethod
def check_installation(cls) -> bool:
"""Method to check whether the required packages are installed.
Returns:
True if all required packages are installed, False otherwise.
"""
try:
for requirement, command in cls.SYSTEM_REQUIREMENTS.items():
result = shutil.which(command)
if result is None:
logger.debug(
"Unable to find the required packages for %s on your "
"system. Please install the packages on your system "
"and try again.",
requirement,
)
return False
for r in cls.REQUIREMENTS:
pkg_resources.get_distribution(r)
logger.debug(
f"Integration {cls.NAME} is installed correctly with "
f"requirements {cls.REQUIREMENTS}."
)
return True
except pkg_resources.DistributionNotFound as e:
logger.debug(
f"Unable to find required package '{e.req}' for "
f"integration {cls.NAME}."
)
return False
except pkg_resources.VersionConflict as e:
logger.debug(
f"VersionConflict error when loading installation {cls.NAME}: "
f"{str(e)}"
)
return False
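A minimal sketch of how these hooks compose, using a hypothetical integration (the NAME and requirement values are assumptions for illustration): declaring REQUIREMENTS and SYSTEM_REQUIREMENTS is all check_installation() needs.
from typing import Dict, List

from zenml.integrations.integration import Integration


class MyToolIntegration(Integration):
    """Hypothetical integration, for illustration only."""

    NAME = "my_tool"  # assumption: not a real ZenML integration name
    # Python packages, each verified via pkg_resources.get_distribution:
    REQUIREMENTS: List[str] = ["requests>=2.0"]
    # System binaries, each verified via shutil.which:
    SYSTEM_REQUIREMENTS: Dict[str, str] = {"git": "git"}


print(MyToolIntegration.check_installation())  # True iff all requirements resolve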
flavors()
classmethod
Abstract method to declare new stack component flavors.
Source code in zenml/integrations/integration.py
@classmethod
def flavors(cls) -> List[Type[Flavor]]:
"""Abstract method to declare new stack component flavors."""
IntegrationMeta (type)
Metaclass responsible for registering different Integration subclasses.
Source code in zenml/integrations/integration.py
class IntegrationMeta(type):
"""Metaclass responsible for registering different Integration subclasses."""
def __new__(
mcs, name: str, bases: Tuple[Type[Any], ...], dct: Dict[str, Any]
) -> "IntegrationMeta":
"""Hook into creation of an Integration class.
Args:
name: The name of the class being created.
bases: The base classes of the class being created.
dct: The dictionary of attributes of the class being created.
Returns:
The newly created class.
"""
cls = cast(Type["Integration"], super().__new__(mcs, name, bases, dct))
if name != "Integration":
integration_registry.register_integration(cls.NAME, cls)
return cls
__new__(mcs, name, bases, dct)
special
staticmethod
Hook into creation of an Integration class.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
name |
str |
The name of the class being created. |
required |
bases |
Tuple[Type[Any], ...] |
The base classes of the class being created. |
required |
dct |
Dict[str, Any] |
The dictionary of attributes of the class being created. |
required |
Returns:
Type | Description |
---|---|
IntegrationMeta |
The newly created class. |
Source code in zenml/integrations/integration.py
def __new__(
mcs, name: str, bases: Tuple[Type[Any], ...], dct: Dict[str, Any]
) -> "IntegrationMeta":
"""Hook into creation of an Integration class.
Args:
name: The name of the class being created.
bases: The base classes of the class being created.
dct: The dictionary of attributes of the class being created.
Returns:
The newly created class.
"""
cls = cast(Type["Integration"], super().__new__(mcs, name, bases, dct))
if name != "Integration":
integration_registry.register_integration(cls.NAME, cls)
return cls
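Because the metaclass runs at class-creation time, merely defining the subclass above registers it; it can then be looked up by name. A sketch, assuming the registry exposes its integrations mapping as in this version of ZenML:
from zenml.integrations.registry import integration_registry

# Defining MyToolIntegration already called
# integration_registry.register_integration("my_tool", MyToolIntegration),
# so the class is retrievable by name:
integration_cls = integration_registry.integrations["my_tool"]
print(integration_cls.check_installation())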
kserve
special
Initialization of the KServe integration for ZenML.
The KServe integration allows you to use the KServe model serving platform to implement continuous model deployment.
KServeIntegration (Integration)
Definition of KServe integration for ZenML.
Source code in zenml/integrations/kserve/__init__.py
class KServeIntegration(Integration):
"""Definition of KServe integration for ZenML."""
NAME = KSERVE
REQUIREMENTS = [
"kserve==0.9.0",
"torch-model-archiver",
]
@classmethod
def activate(cls) -> None:
"""Activate the Seldon Core integration."""
from zenml.integrations.kserve import model_deployers # noqa
from zenml.integrations.kserve import secret_schemas # noqa
from zenml.integrations.kserve import services # noqa
from zenml.integrations.kserve import steps # noqa
@classmethod
def flavors(cls) -> List[Type[Flavor]]:
"""Declare the stack component flavors for KServe.
Returns:
List of stack component flavors for this integration.
"""
from zenml.integrations.kserve.flavors import KServeModelDeployerFlavor
return [KServeModelDeployerFlavor]
activate()
classmethod
Activate the KServe integration.
Source code in zenml/integrations/kserve/__init__.py
@classmethod
def activate(cls) -> None:
"""Activate the Seldon Core integration."""
from zenml.integrations.kserve import model_deployers # noqa
from zenml.integrations.kserve import secret_schemas # noqa
from zenml.integrations.kserve import services # noqa
from zenml.integrations.kserve import steps # noqa
flavors()
classmethod
Declare the stack component flavors for KServe.
Returns:
Type | Description |
---|---|
List[Type[zenml.stack.flavor.Flavor]] |
List of stack component flavors for this integration. |
Source code in zenml/integrations/kserve/__init__.py
@classmethod
def flavors(cls) -> List[Type[Flavor]]:
"""Declare the stack component flavors for KServe.
Returns:
List of stack component flavors for this integration.
"""
from zenml.integrations.kserve.flavors import KServeModelDeployerFlavor
return [KServeModelDeployerFlavor]
constants
KServe constants.
custom_deployer
special
Initialization of ZenML custom deployer.
zenml_custom_model
Implements a custom model for the KServe integration.
ZenMLCustomModel (Model)
Custom model class for ZenML and KServe.
This class implements a custom model for the KServe integration and serves as the main entry point for custom code execution.
Attributes:
Name | Type | Description |
---|---|---|
model_name |
The name of the model. |
|
model_uri |
The URI of the model. |
|
predict_func |
The predict function of the model. |
Source code in zenml/integrations/kserve/custom_deployer/zenml_custom_model.py
class ZenMLCustomModel(kserve.Model): # type: ignore[misc]
"""Custom model class for ZenML and Kserve.
This class is used to implement a custom model for the Kserve integration,
which is used as the main entry point for custom code execution.
Attributes:
model_name: The name of the model.
model_uri: The URI of the model.
predict_func: The predict function of the model.
"""
def __init__(
self,
model_name: str,
model_uri: str,
predict_func: str,
):
"""Initializes a ZenMLCustomModel object.
Args:
model_name: The name of the model.
model_uri: The URI of the model.
predict_func: The predict function of the model.
"""
super().__init__(model_name)
self.name = model_name
self.model_uri = model_uri
self.predict_func = import_class_by_path(predict_func)
self.model = None
self.ready = False
def load(self) -> bool:
"""Load the model.
This function loads the model into memory and sets the ready flag to True.
The model is loaded using the materializer, by saving the information of
the artifact to a YAML file in the same path as the model artifacts at
the preparing time and loading it again at the prediction time by
the materializer.
Returns:
True if the model was loaded successfully, False otherwise.
"""
try:
from zenml.utils.materializer_utils import load_model_from_metadata
self.model = load_model_from_metadata(self.model_uri)
except Exception as e:
logger.error("Failed to load model: {}".format(e))
return False
self.ready = True
return self.ready
def predict(self, request: Dict[str, Any]) -> Dict[str, Any]:
"""Predict the given request.
The main predict function of the model. This function is called by the
KServe server when a request is received. Then inside this function,
the user-defined predict function is called.
Args:
request: The request to predict in a dictionary. e.g. {"instances": []}
Returns:
The prediction dictionary.
Raises:
RuntimeError: If the predict function could not be called.
NotImplementedError: If no predict function is set.
TypeError: If the request is not a dictionary.
"""
if self.predict_func is not None:
try:
prediction = {
"predictions": self.predict_func(
self.model, request["instances"]
)
}
except RuntimeError as err:
raise RuntimeError("Failed to predict: {}".format(err))
if isinstance(prediction, dict):
return prediction
else:
raise TypeError(
f"Prediction is not a dictionary. Expecting a dictionary but got {type(prediction)}"
)
else:
raise NotImplementedError("Predict function is not implemented")
__init__(self, model_name, model_uri, predict_func)
special
Initializes a ZenMLCustomModel object.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model_name |
str |
The name of the model. |
required |
model_uri |
str |
The URI of the model. |
required |
predict_func |
str |
The predict function of the model. |
required |
Source code in zenml/integrations/kserve/custom_deployer/zenml_custom_model.py
def __init__(
self,
model_name: str,
model_uri: str,
predict_func: str,
):
"""Initializes a ZenMLCustomModel object.
Args:
model_name: The name of the model.
model_uri: The URI of the model.
predict_func: The predict function of the model.
"""
super().__init__(model_name)
self.name = model_name
self.model_uri = model_uri
self.predict_func = import_class_by_path(predict_func)
self.model = None
self.ready = False
load(self)
Load the model.
This function loads the model into memory and sets the ready flag to True.
The model is loaded via the materializer: the artifact information is saved to a YAML file alongside the model artifacts at preparation time and read back by the materializer at prediction time.
Returns:
Type | Description |
---|---|
bool |
True if the model was loaded successfully, False otherwise. |
Source code in zenml/integrations/kserve/custom_deployer/zenml_custom_model.py
def load(self) -> bool:
"""Load the model.
This function loads the model into memory and sets the ready flag to True.
The model is loaded using the materializer, by saving the information of
the artifact to a YAML file in the same path as the model artifacts at
the preparing time and loading it again at the prediction time by
the materializer.
Returns:
True if the model was loaded successfully, False otherwise.
"""
try:
from zenml.utils.materializer_utils import load_model_from_metadata
self.model = load_model_from_metadata(self.model_uri)
except Exception as e:
logger.error("Failed to load model: {}".format(e))
return False
self.ready = True
return self.ready
predict(self, request)
Predict the given request.
This is the main predict entry point of the model. The KServe server calls it when a request is received, and it in turn invokes the user-defined predict function.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
request |
Dict[str, Any] |
The request to predict in a dictionary. e.g. {"instances": []} |
required |
Returns:
Type | Description |
---|---|
Dict[str, Any] |
The prediction dictionary. |
Exceptions:
Type | Description |
---|---|
RuntimeError |
If the predict function could not be called. |
NotImplementedError |
If no predict function is set. |
TypeError |
If the request is not a dictionary. |
Source code in zenml/integrations/kserve/custom_deployer/zenml_custom_model.py
def predict(self, request: Dict[str, Any]) -> Dict[str, Any]:
"""Predict the given request.
The main predict function of the model. This function is called by the
KServe server when a request is received. Then inside this function,
the user-defined predict function is called.
Args:
request: The request to predict in a dictionary. e.g. {"instances": []}
Returns:
The prediction dictionary.
Raises:
RuntimeError: If the predict function could not be called.
NotImplementedError: If no predict function is set.
TypeError: If the request is not a dictionary.
"""
if self.predict_func is not None:
try:
prediction = {
"predictions": self.predict_func(
self.model, request["instances"]
)
}
except RuntimeError as err:
raise RuntimeError("Failed to predict: {}".format(err))
if isinstance(prediction, dict):
return prediction
else:
raise TypeError(
f"Prediction is not a dictionary. Expecting a dictionary but got {type(prediction)}"
)
else:
raise NotImplementedError("Predict function is not implemented")
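For orientation, a hedged sketch of the request/response contract: a dictionary with an "instances" list goes in, and a dictionary with a "predictions" key comes out. The model URI and the dotted path to a user-defined predict(model, instances) function are hypothetical.
from zenml.integrations.kserve.custom_deployer.zenml_custom_model import (
    ZenMLCustomModel,
)

model = ZenMLCustomModel(
    model_name="my-model",                     # illustrative
    model_uri="s3://my-bucket/path/to/model",  # illustrative
    predict_func="my_module.my_predict",       # hypothetical dotted path
)
if model.load():
    response = model.predict({"instances": [[1.0, 2.0, 3.0]]})
    # response has the shape {"predictions": [...]}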
flavors
special
KServe integration flavors.
kserve_model_deployer_flavor
KServe model deployer flavor.
KServeModelDeployerConfig (BaseModelDeployerConfig)
pydantic-model
Configuration for the KServeModelDeployer.
Attributes:
Name | Type | Description |
---|---|---|
kubernetes_context |
Optional[str] |
the Kubernetes context to use to contact the remote KServe installation. If not specified, the current configuration is used. Depending on where the KServe model deployer is being used, this can be either a locally active context or an in-cluster Kubernetes configuration (if running inside a pod). |
kubernetes_namespace |
Optional[str] |
the Kubernetes namespace where the KServe inference service CRDs are provisioned and managed by ZenML. If not specified, the namespace set in the current configuration is used. Depending on where the KServe model deployer is being used, this can be either the current namespace configured in the locally active context or the namespace in the context of which the pod is running (if running inside a pod). |
base_url |
str |
the base URL of the Kubernetes ingress used to expose the KServe inference services. |
secret |
Optional[str] |
the name of the secret containing the credentials for the KServe inference services. |
Source code in zenml/integrations/kserve/flavors/kserve_model_deployer_flavor.py
class KServeModelDeployerConfig(BaseModelDeployerConfig):
"""Configuration for the KServeModelDeployer.
Attributes:
kubernetes_context: the Kubernetes context to use to contact the remote
KServe installation. If not specified, the current
configuration is used. Depending on where the KServe model deployer
is being used, this can be either a locally active context or an
in-cluster Kubernetes configuration (if running inside a pod).
kubernetes_namespace: the Kubernetes namespace where the KServe
inference service CRDs are provisioned and managed by ZenML. If not
specified, the namespace set in the current configuration is used.
Depending on where the KServe model deployer is being used, this can
be either the current namespace configured in the locally active
context or the namespace in the context of which the pod is running
(if running inside a pod).
base_url: the base URL of the Kubernetes ingress used to expose the
KServe inference services.
secret: the name of the secret containing the credentials for the
KServe inference services.
"""
kubernetes_context: Optional[str]
kubernetes_namespace: Optional[str]
base_url: str # TODO: unused?
secret: Optional[str]
custom_domain: Optional[str] # TODO: unused?
KServeModelDeployerFlavor (BaseModelDeployerFlavor)
Flavor for the KServe model deployer.
Source code in zenml/integrations/kserve/flavors/kserve_model_deployer_flavor.py
class KServeModelDeployerFlavor(BaseModelDeployerFlavor):
"""Flavor for the KServe model deployer."""
@property
def name(self) -> str:
"""Name of the flavor.
Returns:
Name of the flavor.
"""
return KSERVE_MODEL_DEPLOYER_FLAVOR
@property
def config_class(self) -> Type[KServeModelDeployerConfig]:
"""Returns `KServeModelDeployerConfig` config class.
Returns:
The config class.
"""
return KServeModelDeployerConfig
@property
def implementation_class(self) -> Type["KServeModelDeployer"]:
"""Implementation class for this flavor.
Returns:
The implementation class.
"""
from zenml.integrations.kserve.model_deployers import (
KServeModelDeployer,
)
return KServeModelDeployer
config_class: Type[zenml.integrations.kserve.flavors.kserve_model_deployer_flavor.KServeModelDeployerConfig]
property
readonly
Returns the KServeModelDeployerConfig config class.
Returns:
Type | Description |
---|---|
Type[zenml.integrations.kserve.flavors.kserve_model_deployer_flavor.KServeModelDeployerConfig] |
The config class. |
implementation_class: Type[KServeModelDeployer]
property
readonly
Implementation class for this flavor.
Returns:
Type | Description |
---|---|
Type[KServeModelDeployer] |
The implementation class. |
name: str
property
readonly
Name of the flavor.
Returns:
Type | Description |
---|---|
str |
Name of the flavor. |
model_deployers
special
Initialization of the KServe Model Deployer.
kserve_model_deployer
Implementation of the KServe Model Deployer.
KServeModelDeployer (BaseModelDeployer)
KServe model deployer stack component implementation.
Source code in zenml/integrations/kserve/model_deployers/kserve_model_deployer.py
class KServeModelDeployer(BaseModelDeployer):
"""KServe model deployer stack component implementation."""
_client: Optional[KServeClient] = None
@property
def config(self) -> KServeModelDeployerConfig:
"""Returns the `KServeModelDeployerConfig` config.
Returns:
The configuration.
"""
return cast(KServeModelDeployerConfig, self._config)
@staticmethod
def get_model_server_info( # type: ignore[override]
service_instance: "KServeDeploymentService",
) -> Dict[str, Optional[str]]:
"""Return implementation specific information on the model server.
Args:
service_instance: KServe deployment service object
Returns:
A dictionary containing the model server information.
"""
return {
"PREDICTION_URL": service_instance.prediction_url,
"PREDICTION_HOSTNAME": service_instance.prediction_hostname,
"MODEL_URI": service_instance.config.model_uri,
"MODEL_NAME": service_instance.config.model_name,
"KSERVE_INFERENCE_SERVICE": service_instance.crd_name,
}
@staticmethod
def get_active_model_deployer() -> "KServeModelDeployer":
"""Get the KServe model deployer registered in the active stack.
Returns:
The KServe model deployer registered in the active stack.
Raises:
TypeError: if the KServe model deployer is not available.
"""
model_deployer = Client( # type: ignore [call-arg]
skip_client_check=True
).active_stack.model_deployer
if not model_deployer or not isinstance(
model_deployer, KServeModelDeployer
):
raise TypeError(
f"The active stack needs to have a KServe model deployer "
f"component registered to be able to deploy models with KServe "
f"You can create a new stack with a KServe model "
f"deployer component or update your existing stack to add this "
f"component, e.g.:\n\n"
f" 'zenml model-deployer register kserve --flavor={KSERVE_MODEL_DEPLOYER_FLAVOR} "
f"--kubernetes_context=context-name --kubernetes_namespace="
f"namespace-name --base_url=https://ingress.cluster.kubernetes'\n"
f" 'zenml stack create stack-name -d kserve ...'\n"
)
return model_deployer
@property
def kserve_client(self) -> KServeClient:
"""Get the KServe client associated with this model deployer.
Returns:
The KServe client.
"""
if not self._client:
self._client = KServeClient(
context=self.config.kubernetes_context,
)
return self._client
def prepare_pipeline_deployment(
self,
deployment: "PipelineDeployment",
stack: "Stack",
) -> None:
"""Build a Docker image and push it to the container registry.
Args:
deployment: The pipeline deployment configuration.
stack: The stack on which the pipeline will be deployed.
"""
docker_image_builder = PipelineDockerImageBuilder()
repo_digest = docker_image_builder.build_and_push_docker_image(
deployment=deployment, stack=stack
)
deployment.add_extra(KSERVE_DOCKER_IMAGE_KEY, repo_digest)
def _set_credentials(self) -> None:
"""Set the credentials for the given service instance.
Raises:
RuntimeError: if the credentials are not available.
"""
secret = self._get_kserve_secret()
if secret:
secret_folder = Path(
GlobalConfiguration().config_directory,
"kserve-storage",
str(self.id),
)
kserve_credentials = {}
file_path = None  # written only if the secret carries a credentials file
# Handle the secrets attributes
for key in secret.content.keys():
content = getattr(secret, key)
if key == "credentials" and content:
fileio.makedirs(str(secret_folder))
file_path = Path(secret_folder, f"{key}.json")
kserve_credentials["credentials_file"] = str(file_path)
with open(file_path, "w") as f:
f.write(content)
file_path.chmod(0o600)
# Handle additional params
else:
kserve_credentials[key] = content
# We need to add the namespace to the kserve_credentials
kserve_credentials["namespace"] = (
self.config.kubernetes_namespace
or utils.get_default_target_namespace()
)
try:
self.kserve_client.set_credentials(**kserve_credentials)
except Exception as e:
raise RuntimeError(
f"Failed to set credentials for KServe model deployer: {e}"
)
finally:
if file_path is not None and file_path.exists():
file_path.unlink()
def deploy_model(
self,
config: ServiceConfig,
replace: bool = False,
timeout: int = DEFAULT_KSERVE_DEPLOYMENT_START_STOP_TIMEOUT,
) -> BaseService:
"""Create a new KServe deployment or update an existing one.
This method has two modes of operation, depending on the `replace`
argument value:
* if `replace` is False, calling this method will create a new KServe
deployment server to reflect the model and other configuration
parameters specified in the supplied KServe deployment `config`.
* if `replace` is True, this method will first attempt to find an
existing KServe deployment that is *equivalent* to the supplied
configuration parameters. Two or more KServe deployments are
considered equivalent if they have the same `pipeline_name`,
`pipeline_step_name` and `model_name` configuration parameters. To
put it differently, two KServe deployments are equivalent if
they serve versions of the same model deployed by the same pipeline
step. If an equivalent KServe deployment is found, it will be
updated in place to reflect the new configuration parameters. This
allows an existing KServe deployment to retain its prediction
URL while performing a rolling update to serve a new model version.
Callers should set `replace` to True if they want a continuous model
deployment workflow that doesn't spin up a new KServe deployment
server for each new model version. If multiple equivalent KServe
deployments are found, the most recently created deployment is selected
to be updated and the others are deleted.
Args:
config: the configuration of the model to be deployed with KServe.
replace: set this flag to True to find and update an equivalent
KServeDeployment server with the new model instead of
starting a new deployment server.
timeout: the timeout in seconds to wait for the KServe server
to be provisioned and successfully started or updated. If set
to 0, the method will return immediately after the KServe
server is provisioned, without waiting for it to fully start.
Returns:
The ZenML KServe deployment service object that can be used to
interact with the remote KServe server.
Raises:
RuntimeError: if the KServe deployment server could not be stopped.
"""
config = cast(KServeDeploymentConfig, config)
service = None
# if the secret is passed in the config, use it to set the credentials
if config.secret_name:
self.config.secret = config.secret_name or self.config.secret
self._set_credentials()
# if replace is True, find equivalent KServe deployments
if replace is True:
equivalent_services = self.find_model_server(
running=False,
pipeline_name=config.pipeline_name,
pipeline_step_name=config.pipeline_step_name,
model_name=config.model_name,
)
for equivalent_service in equivalent_services:
if service is None:
# keep the most recently created service
service = equivalent_service
else:
try:
# delete the older services and don't wait for them to
# be deprovisioned
service.stop()
except RuntimeError as e:
raise RuntimeError(
"Failed to stop the KServe deployment server:\n",
f"{e}\n",
"Please stop it manually and try again.",
)
if service:
# update an equivalent service in place
service.update(config)
logger.info(
f"Updating an existing KServe deployment service: {service}"
)
else:
# create a new service
service = KServeDeploymentService(config=config)
logger.info(f"Creating a new KServe deployment service: {service}")
# start the service which in turn provisions the KServe
# deployment server and waits for it to reach a ready state
service.start(timeout=timeout)
# Add telemetry with metadata that gets the stack metadata and
# differentiates between pure model and custom code deployments
stack = Client().active_stack
stack_metadata = {
component_type.value: component.flavor
for component_type, component in stack.components.items()
}
metadata = {
"store_type": Client().zen_store.type.value,
**stack_metadata,
"is_custom_code_deployment": config.container is not None,
}
track_event(AnalyticsEvent.MODEL_DEPLOYED, metadata=metadata)
return service
def get_kserve_deployments(
self, labels: Dict[str, str]
) -> List[V1beta1InferenceService]:
"""Get a list of KServe deployments that match the supplied labels.
Args:
labels: a dictionary of labels to match against KServe deployments.
Returns:
A list of KServe deployments that match the supplied labels.
Raises:
RuntimeError: if an operational failure is encountered while listing the KServe inference services.
"""
label_selector = (
",".join(f"{k}={v}" for k, v in labels.items()) if labels else None
)
namespace = (
self.config.kubernetes_namespace
or utils.get_default_target_namespace()
)
try:
response = (
self.kserve_client.api_instance.list_namespaced_custom_object(
constants.KSERVE_GROUP,
constants.KSERVE_V1BETA1_VERSION,
namespace,
constants.KSERVE_PLURAL,
label_selector=label_selector,
)
)
except client.rest.ApiException as e:
raise RuntimeError(
"Exception when retrieving KServe inference services\
%s\n"
% e
)
# TODO[CRITICAL]: de-serialize each item into a complete
# V1beta1InferenceService object recursively using the OpenApi
# schema (this doesn't work right now)
inference_services: List[V1beta1InferenceService] = []
for item in response.get("items", []):
snake_case_item = self._camel_to_snake(item)
inference_service = V1beta1InferenceService(**snake_case_item)
inference_services.append(inference_service)
return inference_services
def _camel_to_snake(self, obj: Dict[str, Any]) -> Dict[str, Any]:
"""Convert a camelCase dictionary to snake_case.
Args:
obj: a dictionary with camelCase keys
Returns:
a dictionary with snake_case keys
"""
if isinstance(obj, (str, int, float)):
return obj
if isinstance(obj, dict):
assert obj is not None
new = obj.__class__()
for k, v in obj.items():
new[self._convert_to_snake(k)] = self._camel_to_snake(v)
elif isinstance(obj, (list, set, tuple)):
assert obj is not None
new = obj.__class__(self._camel_to_snake(v) for v in obj)
else:
return obj
return new
def _convert_to_snake(self, k: str) -> str:
return re.sub(r"(?<!^)(?=[A-Z])", "_", k).lower()
def find_model_server(
self,
running: bool = False,
service_uuid: Optional[UUID] = None,
pipeline_name: Optional[str] = None,
pipeline_run_id: Optional[str] = None,
pipeline_step_name: Optional[str] = None,
model_name: Optional[str] = None,
model_uri: Optional[str] = None,
predictor: Optional[str] = None,
) -> List[BaseService]:
"""Find one or more KServe model services that match the given criteria.
Args:
running: If true, only running services will be returned.
service_uuid: The UUID of the service that was originally used
to deploy the model.
pipeline_name: name of the pipeline that the deployed model was part
of.
pipeline_run_id: ID of the pipeline run which the deployed model was
part of.
pipeline_step_name: the name of the pipeline model deployment step
that deployed the model.
model_name: the name of the deployed model.
model_uri: URI of the deployed model.
predictor: the name of the predictor that was used to deploy the model.
Returns:
One or more Service objects representing model servers that match
the input search criteria.
"""
config = KServeDeploymentConfig(
pipeline_name=pipeline_name or "",
pipeline_run_id=pipeline_run_id or "",
pipeline_step_name=pipeline_step_name or "",
model_uri=model_uri or "",
model_name=model_name or "",
predictor=predictor or "",
resources={},
)
labels = config.get_kubernetes_labels()
if service_uuid:
labels["zenml.service_uuid"] = str(service_uuid)
deployments = self.get_kserve_deployments(labels=labels)
services: List[BaseService] = []
for deployment in deployments:
# recreate the KServe deployment service object from the KServe
# deployment resource
service = KServeDeploymentService.create_from_deployment(
deployment=deployment
)
if running and not service.is_running:
# skip non-running services
continue
services.append(service)
return services
def stop_model_server(
self,
uuid: UUID,
timeout: int = DEFAULT_KSERVE_DEPLOYMENT_START_STOP_TIMEOUT,
force: bool = False,
) -> None:
"""Stop a KServe model server.
Args:
uuid: UUID of the model server to stop.
timeout: timeout in seconds to wait for the service to stop.
force: if True, force the service to stop.
Raises:
NotImplementedError: stopping on KServe model servers is not
supported.
"""
raise NotImplementedError(
"Stopping KServe model servers is not implemented. Try "
"deleting the KServe model server instead."
)
def start_model_server(
self,
uuid: UUID,
timeout: int = DEFAULT_KSERVE_DEPLOYMENT_START_STOP_TIMEOUT,
) -> None:
"""Start a KServe model deployment server.
Args:
uuid: UUID of the model server to start.
timeout: timeout in seconds to wait for the service to become
active. If set to 0, the method will return immediately after
provisioning the service, without waiting for it to become
active.
Raises:
NotImplementedError: since we don't support starting KServe
model servers
"""
raise NotImplementedError(
"Starting KServe model servers is not implemented"
)
def delete_model_server(
self,
uuid: UUID,
timeout: int = DEFAULT_KSERVE_DEPLOYMENT_START_STOP_TIMEOUT,
force: bool = False,
) -> None:
"""Delete a KServe model deployment server.
Args:
uuid: UUID of the model server to delete.
timeout: timeout in seconds to wait for the service to stop. If
set to 0, the method will return immediately after
deprovisioning the service, without waiting for it to stop.
force: if True, force the service to stop.
"""
services = self.find_model_server(service_uuid=uuid)
if len(services) == 0:
return
services[0].stop(timeout=timeout, force=force)
def _get_kserve_secret(self) -> Any:
"""Get the secret object for the KServe deployment.
Returns:
The secret object for the KServe deployment.
Raises:
RuntimeError: if the secret object is not found or secrets_manager is not set.
"""
if self.config.secret:
secret_manager = Client( # type: ignore [call-arg]
skip_client_check=True
).active_stack.secrets_manager
if not secret_manager or not isinstance(
secret_manager, BaseSecretsManager
):
raise RuntimeError(
f"The active stack doesn't have a secret manager component. "
f"The ZenML secret specified in the KServe Model "
f"Deployer configuration cannot be fetched: {self.config.secret}."
)
try:
secret = secret_manager.get_secret(self.config.secret)
return secret
except KeyError:
raise RuntimeError(
f"The secret `{self.config.secret}` used for your KServe Model"
f"Deployer configuration does not exist in your secrets "
f"manager `{secret_manager.name}`."
)
return None
config: KServeModelDeployerConfig
property
readonly
Returns the KServeModelDeployerConfig config.
Returns:
Type | Description |
---|---|
KServeModelDeployerConfig |
The configuration. |
kserve_client: KServeClient
property
readonly
Get the KServe client associated with this model deployer.
Returns:
Type | Description |
---|---|
KServeClient |
The KServe client. |
delete_model_server(self, uuid, timeout=300, force=False)
Delete a KServe model deployment server.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
uuid |
UUID |
UUID of the model server to delete. |
required |
timeout |
int |
timeout in seconds to wait for the service to stop. If set to 0, the method will return immediately after deprovisioning the service, without waiting for it to stop. |
300 |
force |
bool |
if True, force the service to stop. |
False |
Source code in zenml/integrations/kserve/model_deployers/kserve_model_deployer.py
def delete_model_server(
self,
uuid: UUID,
timeout: int = DEFAULT_KSERVE_DEPLOYMENT_START_STOP_TIMEOUT,
force: bool = False,
) -> None:
"""Delete a KServe model deployment server.
Args:
uuid: UUID of the model server to delete.
timeout: timeout in seconds to wait for the service to stop. If
set to 0, the method will return immediately after
deprovisioning the service, without waiting for it to stop.
force: if True, force the service to stop.
"""
services = self.find_model_server(service_uuid=uuid)
if len(services) == 0:
return
services[0].stop(timeout=timeout, force=force)
deploy_model(self, config, replace=False, timeout=300)
Create a new KServe deployment or update an existing one.
This method has two modes of operation, depending on the replace argument value:
- if replace is False, calling this method will create a new KServe deployment server to reflect the model and other configuration parameters specified in the supplied KServe deployment config.
- if replace is True, this method will first attempt to find an existing KServe deployment that is equivalent to the supplied configuration parameters. Two or more KServe deployments are considered equivalent if they have the same pipeline_name, pipeline_step_name and model_name configuration parameters. To put it differently, two KServe deployments are equivalent if they serve versions of the same model deployed by the same pipeline step. If an equivalent KServe deployment is found, it will be updated in place to reflect the new configuration parameters. This allows an existing KServe deployment to retain its prediction URL while performing a rolling update to serve a new model version.
Callers should set replace to True if they want a continuous model deployment workflow that doesn't spin up a new KServe deployment server for each new model version. If multiple equivalent KServe deployments are found, the most recently created deployment is selected to be updated and the others are deleted.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
config |
ServiceConfig |
the configuration of the model to be deployed with KServe. |
required |
replace |
bool |
set this flag to True to find and update an equivalent KServeDeployment server with the new model instead of starting a new deployment server. |
False |
timeout |
int |
the timeout in seconds to wait for the KServe server to be provisioned and successfully started or updated. If set to 0, the method will return immediately after the KServe server is provisioned, without waiting for it to fully start. |
300 |
Returns:
Type | Description |
---|---|
BaseService |
The ZenML KServe deployment service object that can be used to interact with the remote KServe server. |
Exceptions:
Type | Description |
---|---|
RuntimeError |
if the KServe deployment server could not be stopped. |
Source code in zenml/integrations/kserve/model_deployers/kserve_model_deployer.py
def deploy_model(
self,
config: ServiceConfig,
replace: bool = False,
timeout: int = DEFAULT_KSERVE_DEPLOYMENT_START_STOP_TIMEOUT,
) -> BaseService:
"""Create a new KServe deployment or update an existing one.
This method has two modes of operation, depending on the `replace`
argument value:
* if `replace` is False, calling this method will create a new KServe
deployment server to reflect the model and other configuration
parameters specified in the supplied KServe deployment `config`.
* if `replace` is True, this method will first attempt to find an
existing KServe deployment that is *equivalent* to the supplied
configuration parameters. Two or more KServe deployments are
considered equivalent if they have the same `pipeline_name`,
`pipeline_step_name` and `model_name` configuration parameters. To
put it differently, two KServe deployments are equivalent if
they serve versions of the same model deployed by the same pipeline
step. If an equivalent KServe deployment is found, it will be
updated in place to reflect the new configuration parameters. This
allows an existing KServe deployment to retain its prediction
URL while performing a rolling update to serve a new model version.
Callers should set `replace` to True if they want a continuous model
deployment workflow that doesn't spin up a new KServe deployment
server for each new model version. If multiple equivalent KServe
deployments are found, the most recently created deployment is selected
to be updated and the others are deleted.
Args:
config: the configuration of the model to be deployed with KServe.
replace: set this flag to True to find and update an equivalent
KServeDeployment server with the new model instead of
starting a new deployment server.
timeout: the timeout in seconds to wait for the KServe server
to be provisioned and successfully started or updated. If set
to 0, the method will return immediately after the KServe
server is provisioned, without waiting for it to fully start.
Returns:
The ZenML KServe deployment service object that can be used to
interact with the remote KServe server.
Raises:
RuntimeError: if the KServe deployment server could not be stopped.
"""
config = cast(KServeDeploymentConfig, config)
service = None
# if the secret is passed in the config, use it to set the credentials
if config.secret_name:
self.config.secret = config.secret_name or self.config.secret
self._set_credentials()
# if replace is True, find equivalent KServe deployments
if replace is True:
equivalent_services = self.find_model_server(
running=False,
pipeline_name=config.pipeline_name,
pipeline_step_name=config.pipeline_step_name,
model_name=config.model_name,
)
for equivalent_service in equivalent_services:
if service is None:
# keep the most recently created service
service = equivalent_service
else:
try:
# delete the older services and don't wait for them to
# be deprovisioned
service.stop()
except RuntimeError as e:
raise RuntimeError(
"Failed to stop the KServe deployment server:\n",
f"{e}\n",
"Please stop it manually and try again.",
)
if service:
# update an equivalent service in place
service.update(config)
logger.info(
f"Updating an existing KServe deployment service: {service}"
)
else:
# create a new service
service = KServeDeploymentService(config=config)
logger.info(f"Creating a new KServe deployment service: {service}")
# start the service which in turn provisions the KServe
# deployment server and waits for it to reach a ready state
service.start(timeout=timeout)
# Add telemetry with metadata that gets the stack metadata and
# differentiates between pure model and custom code deployments
stack = Client().active_stack
stack_metadata = {
component_type.value: component.flavor
for component_type, component in stack.components.items()
}
metadata = {
"store_type": Client().zen_store.type.value,
**stack_metadata,
"is_custom_code_deployment": config.container is not None,
}
track_event(AnalyticsEvent.MODEL_DEPLOYED, metadata=metadata)
return service
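A usage sketch, assuming an active stack with a KServe model deployer and a model already available at the given URI; all literal values are illustrative, and the config class is the KServeDeploymentConfig documented under the services sub-module below.
from zenml.integrations.kserve.model_deployers import KServeModelDeployer
from zenml.integrations.kserve.services import KServeDeploymentConfig

model_deployer = KServeModelDeployer.get_active_model_deployer()
service = model_deployer.deploy_model(
    config=KServeDeploymentConfig(
        model_uri="s3://my-bucket/path/to/model",  # illustrative URI
        model_name="my-model",
        predictor="sklearn",  # one of the KServe predictor names
        pipeline_name="my_pipeline",
        pipeline_step_name="model_deployer_step",
        resources={},
    ),
    replace=True,  # update an equivalent deployment instead of adding one
    timeout=300,
)
print(service.prediction_url)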
find_model_server(self, running=False, service_uuid=None, pipeline_name=None, pipeline_run_id=None, pipeline_step_name=None, model_name=None, model_uri=None, predictor=None)
Find one or more KServe model services that match the given criteria.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
running |
bool |
If true, only running services will be returned. |
False |
service_uuid |
Optional[uuid.UUID] |
The UUID of the service that was originally used to deploy the model. |
None |
pipeline_name |
Optional[str] |
name of the pipeline that the deployed model was part of. |
None |
pipeline_run_id |
Optional[str] |
ID of the pipeline run which the deployed model was part of. |
None |
pipeline_step_name |
Optional[str] |
the name of the pipeline model deployment step that deployed the model. |
None |
model_name |
Optional[str] |
the name of the deployed model. |
None |
model_uri |
Optional[str] |
URI of the deployed model. |
None |
predictor |
Optional[str] |
the name of the predictor that was used to deploy the model. |
None |
Returns:
Type | Description |
---|---|
List[zenml.services.service.BaseService] |
One or more Service objects representing model servers that match the input search criteria. |
Source code in zenml/integrations/kserve/model_deployers/kserve_model_deployer.py
def find_model_server(
self,
running: bool = False,
service_uuid: Optional[UUID] = None,
pipeline_name: Optional[str] = None,
pipeline_run_id: Optional[str] = None,
pipeline_step_name: Optional[str] = None,
model_name: Optional[str] = None,
model_uri: Optional[str] = None,
predictor: Optional[str] = None,
) -> List[BaseService]:
"""Find one or more KServe model services that match the given criteria.
Args:
running: If true, only running services will be returned.
service_uuid: The UUID of the service that was originally used
to deploy the model.
pipeline_name: name of the pipeline that the deployed model was part
of.
pipeline_run_id: ID of the pipeline run which the deployed model was
part of.
pipeline_step_name: the name of the pipeline model deployment step
that deployed the model.
model_name: the name of the deployed model.
model_uri: URI of the deployed model.
predictor: the name of the predictor that was used to deploy the model.
Returns:
One or more Service objects representing model servers that match
the input search criteria.
"""
config = KServeDeploymentConfig(
pipeline_name=pipeline_name or "",
pipeline_run_id=pipeline_run_id or "",
pipeline_step_name=pipeline_step_name or "",
model_uri=model_uri or "",
model_name=model_name or "",
predictor=predictor or "",
resources={},
)
labels = config.get_kubernetes_labels()
if service_uuid:
labels["zenml.service_uuid"] = str(service_uuid)
deployments = self.get_kserve_deployments(labels=labels)
services: List[BaseService] = []
for deployment in deployments:
# recreate the KServe deployment service object from the KServe
# deployment resource
service = KServeDeploymentService.create_from_deployment(
deployment=deployment
)
if running and not service.is_running:
# skip non-running services
continue
services.append(service)
return services
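Continuing the deployment sketch above, the same criteria can later be used to find the service again:
services = model_deployer.find_model_server(
    running=True,
    pipeline_name="my_pipeline",  # illustrative, as above
    pipeline_step_name="model_deployer_step",
    model_name="my-model",
)
for service in services:
    print(service)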
get_active_model_deployer()
staticmethod
Get the KServe model deployer registered in the active stack.
Returns:
Type | Description |
---|---|
KServeModelDeployer |
The KServe model deployer registered in the active stack. |
Exceptions:
Type | Description |
---|---|
TypeError |
if the KServe model deployer is not available. |
Source code in zenml/integrations/kserve/model_deployers/kserve_model_deployer.py
@staticmethod
def get_active_model_deployer() -> "KServeModelDeployer":
"""Get the KServe model deployer registered in the active stack.
Returns:
The KServe model deployer registered in the active stack.
Raises:
TypeError: if the KServe model deployer is not available.
"""
model_deployer = Client( # type: ignore [call-arg]
skip_client_check=True
).active_stack.model_deployer
if not model_deployer or not isinstance(
model_deployer, KServeModelDeployer
):
raise TypeError(
f"The active stack needs to have a KServe model deployer "
f"component registered to be able to deploy models with KServe "
f"You can create a new stack with a KServe model "
f"deployer component or update your existing stack to add this "
f"component, e.g.:\n\n"
f" 'zenml model-deployer register kserve --flavor={KSERVE_MODEL_DEPLOYER_FLAVOR} "
f"--kubernetes_context=context-name --kubernetes_namespace="
f"namespace-name --base_url=https://ingress.cluster.kubernetes'\n"
f" 'zenml stack create stack-name -d kserve ...'\n"
)
return model_deployer
get_kserve_deployments(self, labels)
Get a list of KServe deployments that match the supplied labels.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
labels |
Dict[str, str] |
a dictionary of labels to match against KServe deployments. |
required |
Returns:
Type | Description |
---|---|
List[kserve.models.v1beta1_inference_service.V1beta1InferenceService] |
A list of KServe deployments that match the supplied labels. |
Exceptions:
Type | Description |
---|---|
RuntimeError |
if an operational failure is encountered while listing the KServe inference services. |
Source code in zenml/integrations/kserve/model_deployers/kserve_model_deployer.py
def get_kserve_deployments(
self, labels: Dict[str, str]
) -> List[V1beta1InferenceService]:
"""Get a list of KServe deployments that match the supplied labels.
Args:
labels: a dictionary of labels to match against KServe deployments.
Returns:
A list of KServe deployments that match the supplied labels.
Raises:
RuntimeError: if an operational failure is encountered while listing the KServe inference services.
"""
label_selector = (
",".join(f"{k}={v}" for k, v in labels.items()) if labels else None
)
namespace = (
self.config.kubernetes_namespace
or utils.get_default_target_namespace()
)
try:
response = (
self.kserve_client.api_instance.list_namespaced_custom_object(
constants.KSERVE_GROUP,
constants.KSERVE_V1BETA1_VERSION,
namespace,
constants.KSERVE_PLURAL,
label_selector=label_selector,
)
)
except client.rest.ApiException as e:
raise RuntimeError(
"Exception when retrieving KServe inference services\
%s\n"
% e
)
# TODO[CRITICAL]: de-serialize each item into a complete
# V1beta1InferenceService object recursively using the OpenApi
# schema (this doesn't work right now)
inference_services: List[V1beta1InferenceService] = []
for item in response.get("items", []):
snake_case_item = self._camel_to_snake(item)
inference_service = V1beta1InferenceService(**snake_case_item)
inference_services.append(inference_service)
return inference_services
get_model_server_info(service_instance)
staticmethod
Return implementation specific information on the model server.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
service_instance |
KServeDeploymentService |
KServe deployment service object |
required |
Returns:
Type | Description |
---|---|
Dict[str, Optional[str]] |
A dictionary containing the model server information. |
Source code in zenml/integrations/kserve/model_deployers/kserve_model_deployer.py
@staticmethod
def get_model_server_info( # type: ignore[override]
service_instance: "KServeDeploymentService",
) -> Dict[str, Optional[str]]:
"""Return implementation specific information on the model server.
Args:
service_instance: KServe deployment service object
Returns:
A dictionary containing the model server information.
"""
return {
"PREDICTION_URL": service_instance.prediction_url,
"PREDICTION_HOSTNAME": service_instance.prediction_hostname,
"MODEL_URI": service_instance.config.model_uri,
"MODEL_NAME": service_instance.config.model_name,
"KSERVE_INFERENCE_SERVICE": service_instance.crd_name,
}
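And, still continuing the sketch, a found service can be summarized through this static method:
if services:
    info = KServeModelDeployer.get_model_server_info(services[0])
    print(info["PREDICTION_URL"], info["MODEL_NAME"])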
prepare_pipeline_deployment(self, deployment, stack)
Build a Docker image and push it to the container registry.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
deployment |
PipelineDeployment |
The pipeline deployment configuration. |
required |
stack |
Stack |
The stack on which the pipeline will be deployed. |
required |
Source code in zenml/integrations/kserve/model_deployers/kserve_model_deployer.py
def prepare_pipeline_deployment(
self,
deployment: "PipelineDeployment",
stack: "Stack",
) -> None:
"""Build a Docker image and push it to the container registry.
Args:
deployment: The pipeline deployment configuration.
stack: The stack on which the pipeline will be deployed.
"""
docker_image_builder = PipelineDockerImageBuilder()
repo_digest = docker_image_builder.build_and_push_docker_image(
deployment=deployment, stack=stack
)
deployment.add_extra(KSERVE_DOCKER_IMAGE_KEY, repo_digest)
start_model_server(self, uuid, timeout=300)
Start a KServe model deployment server.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
uuid |
UUID |
UUID of the model server to start. |
required |
timeout |
int |
timeout in seconds to wait for the service to become active. If set to 0, the method will return immediately after provisioning the service, without waiting for it to become active. |
300 |
Exceptions:
Type | Description |
---|---|
NotImplementedError |
since we don't support starting KServe model servers |
Source code in zenml/integrations/kserve/model_deployers/kserve_model_deployer.py
def start_model_server(
self,
uuid: UUID,
timeout: int = DEFAULT_KSERVE_DEPLOYMENT_START_STOP_TIMEOUT,
) -> None:
"""Start a KServe model deployment server.
Args:
uuid: UUID of the model server to start.
timeout: timeout in seconds to wait for the service to become
active. If set to 0, the method will return immediately after
provisioning the service, without waiting for it to become
active.
Raises:
NotImplementedError: since we don't support starting KServe
model servers
"""
raise NotImplementedError(
"Starting KServe model servers is not implemented"
)
stop_model_server(self, uuid, timeout=300, force=False)
Stop a KServe model server.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
uuid |
UUID |
UUID of the model server to stop. |
required |
timeout |
int |
timeout in seconds to wait for the service to stop. |
300 |
force |
bool |
if True, force the service to stop. |
False |
Exceptions:
Type | Description |
---|---|
NotImplementedError |
stopping on KServe model servers is not supported. |
Source code in zenml/integrations/kserve/model_deployers/kserve_model_deployer.py
def stop_model_server(
self,
uuid: UUID,
timeout: int = DEFAULT_KSERVE_DEPLOYMENT_START_STOP_TIMEOUT,
force: bool = False,
) -> None:
"""Stop a KServe model server.
Args:
uuid: UUID of the model server to stop.
timeout: timeout in seconds to wait for the service to stop.
force: if True, force the service to stop.
Raises:
NotImplementedError: stopping on KServe model servers is not
supported.
"""
raise NotImplementedError(
"Stopping KServe model servers is not implemented. Try "
"deleting the KServe model server instead."
)
secret_schemas
special
Initialization of KServe secret schemas.
These are secret schemas that can be used to authenticate KServe to the Artifact Store used to store served ML models.
secret_schemas
Implementation for KServe secret schemas.
KServeAzureSecretSchema (BaseSecretSchema)
pydantic-model
KServe Azure Blob Storage credentials.
Attributes:
Name | Type | Description |
---|---|---|
storage_type |
Literal['Azure'] |
the storage type. Must be set to "GCS" for this schema. |
credentials |
Optional[str] |
the credentials to use. |
Source code in zenml/integrations/kserve/secret_schemas/secret_schemas.py
class KServeAzureSecretSchema(BaseSecretSchema):
"""KServe Azure Blob Storage credentials.
Attributes:
storage_type: the storage type. Must be set to "GCS" for this schema.
credentials: the credentials to use.
"""
TYPE: ClassVar[str] = KSERVE_AZUREBLOB_SECRET_SCHEMA_TYPE
storage_type: Literal["Azure"] = "Azure"
credentials: Optional[str]
KServeGSSecretSchema (BaseSecretSchema)
pydantic-model
KServe GCS credentials.
Attributes:
Name | Type | Description |
---|---|---|
storage_type |
Literal['GCS'] |
the storage type. Must be set to "GCS" for this schema. |
credentials |
Optional[str] |
the credentials to use. |
service_account |
Optional[str] |
the service account. |
Source code in zenml/integrations/kserve/secret_schemas/secret_schemas.py
class KServeGSSecretSchema(BaseSecretSchema):
"""KServe GCS credentials.
Attributes:
storage_type: the storage type. Must be set to "GCS" for this schema.
credentials: the credentials to use.
service_account: the service account.
"""
TYPE: ClassVar[str] = KSERVE_GS_SECRET_SCHEMA_TYPE
storage_type: Literal["GCS"] = "GCS"
credentials: Optional[str]
service_account: Optional[str]
KServeS3SecretSchema (BaseSecretSchema)
pydantic-model
KServe S3 credentials.
Attributes:
Name | Type | Description |
---|---|---|
storage_type |
Literal['S3'] |
the storage type. Must be set to "s3" for this schema. |
credentials |
Optional[str] |
the credentials to use. |
service_account |
Optional[str] |
the name of the service account. |
s3_endpoint |
Optional[str] |
the S3 endpoint. |
s3_region |
Optional[str] |
the S3 region. |
s3_use_https |
Optional[str] |
whether to use HTTPS. |
s3_verify_ssl |
Optional[str] |
whether to verify SSL. |
Source code in zenml/integrations/kserve/secret_schemas/secret_schemas.py
class KServeS3SecretSchema(BaseSecretSchema):
"""KServe S3 credentials.
Attributes:
storage_type: the storage type. Must be set to "S3" for this schema.
credentials: the credentials to use.
service_account: the name of the service account.
s3_endpoint: the S3 endpoint.
s3_region: the S3 region.
s3_use_https: whether to use HTTPS.
s3_verify_ssl: whether to verify SSL.
"""
TYPE: ClassVar[str] = KSERVE_S3_SECRET_SCHEMA_TYPE
storage_type: Literal["S3"] = "S3"
credentials: Optional[str]
service_account: Optional[str]
s3_endpoint: Optional[str]
s3_region: Optional[str]
s3_use_https: Optional[str]
s3_verify_ssl: Optional[str]
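For illustration only, a secret following the S3 schema might be constructed as below, assuming the schema classes are re-exported by the secret_schemas sub-module; all field values are hypothetical placeholders, and the credentials string typically holds the contents of an AWS credentials file.
from zenml.integrations.kserve.secret_schemas import KServeS3SecretSchema

# All values below are illustrative placeholders, not working credentials.
s3_secret = KServeS3SecretSchema(
    name="kserve_secret",  # assumption: secret schemas carry a name
    credentials="[default]\naws_access_key_id=<KEY>\naws_secret_access_key=<SECRET>",
    s3_endpoint="https://s3.us-east-1.amazonaws.com",
    s3_region="us-east-1",
    s3_use_https="1",
    s3_verify_ssl="1",
)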
services
special
Initialization for KServe services.
kserve_deployment
Implementation for the KServe inference service.
KServeDeploymentConfig (ServiceConfig)
pydantic-model
KServe deployment service configuration.
Attributes:
Name | Type | Description |
---|---|---|
model_uri |
str |
URI of the model (or models) to serve. |
model_name |
str |
the name of the model. Multiple versions of the same model should use the same model name. |
secret_name |
Optional[str] |
the name of the secret containing the model. |
predictor |
str |
the KServe predictor used to serve the model. |
replicas |
int |
number of replicas to use for the prediction service. |
resources |
Optional[Dict[str, Any]] |
the Kubernetes resources to allocate for the prediction service. |
container |
Optional[Dict[str, Any]] |
the container to use for the custom prediction services. |
Source code in zenml/integrations/kserve/services/kserve_deployment.py
class KServeDeploymentConfig(ServiceConfig):
"""KServe deployment service configuration.
Attributes:
model_uri: URI of the model (or models) to serve.
model_name: the name of the model. Multiple versions of the same model
should use the same model name.
secret_name: the name of the secret containing the model.
predictor: the KServe predictor used to serve the model.
replicas: number of replicas to use for the prediction service.
resources: the Kubernetes resources to allocate for the prediction service.
container: the container to use for the custom prediction services.
"""
model_uri: str = ""
model_name: str
secret_name: Optional[str]
predictor: str
replicas: int = 1
container: Optional[Dict[str, Any]]
resources: Optional[Dict[str, Any]]
@staticmethod
def sanitize_labels(labels: Dict[str, str]) -> None:
"""Update the label values to be valid Kubernetes labels.
See: https://kubernetes.io/docs/concepts/overview/working-with-objects/labels/#syntax-and-character-set
Args:
labels: The labels to sanitize.
"""
# TODO[MEDIUM]: Move k8s label sanitization to a common module with all K8s utils.
for key, value in labels.items():
# Kubernetes labels must be alphanumeric, no longer than
# 63 characters, and must begin and end with an alphanumeric
# character ([a-z0-9A-Z])
labels[key] = re.sub(r"[^0-9a-zA-Z-_\.]+", "_", value)[:63].strip(
"-_."
)
def get_kubernetes_labels(self) -> Dict[str, str]:
"""Generate the labels for the KServe inference CRD from the service configuration.
These labels are attached to the KServe inference service CRD
and may be used as label selectors in lookup operations.
Returns:
The labels for the KServe inference service CRD.
"""
labels = {"app": "zenml"}
if self.pipeline_name:
labels["zenml.pipeline_name"] = self.pipeline_name
if self.pipeline_run_id:
labels["zenml.pipeline_run_id"] = self.pipeline_run_id
if self.pipeline_step_name:
labels["zenml.pipeline_step_name"] = self.pipeline_step_name
if self.model_name:
labels["zenml.model_name"] = self.model_name
if self.model_uri:
labels["zenml.model_uri"] = self.model_uri
if self.predictor:
labels["zenml.model_type"] = self.predictor
self.sanitize_labels(labels)
return labels
def get_kubernetes_annotations(self) -> Dict[str, str]:
"""Generate the annotations for the KServe inference CRD the service configuration.
The annotations are used to store additional information about the
KServe ZenML service associated with the deployment that is
not available on the labels. One annotation is particularly important
is the serialized Service configuration itself, which is used to
recreate the service configuration from a remote KServe inference
service CRD.
Returns:
The annotations for the KServe inference service CRD.
"""
annotations = {
"zenml.service_config": self.json(),
"zenml.version": __version__,
}
return annotations
@classmethod
def create_from_deployment(
cls, deployment: V1beta1InferenceService
) -> "KServeDeploymentConfig":
"""Recreate a KServe service from a KServe deployment resource.
Args:
deployment: the KServe inference service CRD.
Returns:
The KServe ZenML service configuration corresponding to the given
KServe inference service CRD.
Raises:
ValueError: if the given deployment resource does not contain
the expected annotations or it contains an invalid or
incompatible KServe ZenML service configuration.
"""
config_data = deployment.metadata.get("annotations").get(
"zenml.service_config"
)
if not config_data:
raise ValueError(
f"The given deployment resource does not contain a "
f"'zenml.service_config' annotation: {deployment}"
)
try:
service_config = cls.parse_raw(config_data)
except ValidationError as e:
raise ValueError(
f"The loaded KServe Inference Service resource contains an "
f"invalid or incompatible KServe ZenML service configuration: "
f"{config_data}"
) from e
return service_config
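A minimal configuration sketch; the model URI and resource requests are illustrative values only.
from zenml.integrations.kserve.services.kserve_deployment import (
    KServeDeploymentConfig,
)

config = KServeDeploymentConfig(
    model_uri="s3://my-bucket/models/my-model",  # hypothetical artifact URI
    model_name="my-model",
    predictor="sklearn",  # one of the built-in KServe predictors
    replicas=1,
    resources={"requests": {"cpu": "100m", "memory": "512Mi"}},
)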
create_from_deployment(deployment)
classmethod
Recreate a KServe service from a KServe deployment resource.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
deployment |
V1beta1InferenceService |
the KServe inference service CRD. |
required |
Returns:
Type | Description |
---|---|
KServeDeploymentConfig |
The KServe ZenML service configuration corresponding to the given KServe inference service CRD. |
Exceptions:
Type | Description |
---|---|
ValueError |
if the given deployment resource does not contain the expected annotations or it contains an invalid or incompatible KServe ZenML service configuration. |
Source code in zenml/integrations/kserve/services/kserve_deployment.py
@classmethod
def create_from_deployment(
cls, deployment: V1beta1InferenceService
) -> "KServeDeploymentConfig":
"""Recreate a KServe service from a KServe deployment resource.
Args:
deployment: the KServe inference service CRD.
Returns:
The KServe ZenML service configuration corresponding to the given
KServe inference service CRD.
Raises:
ValueError: if the given deployment resource does not contain
the expected annotations or it contains an invalid or
incompatible KServe ZenML service configuration.
"""
config_data = deployment.metadata.get("annotations").get(
"zenml.service_config"
)
if not config_data:
raise ValueError(
f"The given deployment resource does not contain a "
f"'zenml.service_config' annotation: {deployment}"
)
try:
service_config = cls.parse_raw(config_data)
except ValidationError as e:
raise ValueError(
f"The loaded KServe Inference Service resource contains an "
f"invalid or incompatible KServe ZenML service configuration: "
f"{config_data}"
) from e
return service_config
get_kubernetes_annotations(self)
Generate the annotations for the KServe inference CRD from the service configuration.
The annotations are used to store additional information about the KServe ZenML service associated with the deployment that is not available on the labels. One particularly important annotation is the serialized service configuration itself, which is used to recreate the service configuration from a remote KServe inference service CRD.
Returns:
Type | Description |
---|---|
Dict[str, str] |
The annotations for the KServe inference service CRD. |
Source code in zenml/integrations/kserve/services/kserve_deployment.py
def get_kubernetes_annotations(self) -> Dict[str, str]:
"""Generate the annotations for the KServe inference CRD the service configuration.
The annotations are used to store additional information about the
KServe ZenML service associated with the deployment that is
not available on the labels. One annotation is particularly important
is the serialized Service configuration itself, which is used to
recreate the service configuration from a remote KServe inference
service CRD.
Returns:
The annotations for the KServe inference service CRD.
"""
annotations = {
"zenml.service_config": self.json(),
"zenml.version": __version__,
}
return annotations
get_kubernetes_labels(self)
Generate the labels for the KServe inference CRD from the service configuration.
These labels are attached to the KServe inference service CRD and may be used as label selectors in lookup operations.
Returns:
Type | Description |
---|---|
Dict[str, str] |
The labels for the KServe inference service CRD. |
Source code in zenml/integrations/kserve/services/kserve_deployment.py
def get_kubernetes_labels(self) -> Dict[str, str]:
"""Generate the labels for the KServe inference CRD from the service configuration.
These labels are attached to the KServe inference service CRD
and may be used as label selectors in lookup operations.
Returns:
The labels for the KServe inference service CRD.
"""
labels = {"app": "zenml"}
if self.pipeline_name:
labels["zenml.pipeline_name"] = self.pipeline_name
if self.pipeline_run_id:
labels["zenml.pipeline_run_id"] = self.pipeline_run_id
if self.pipeline_step_name:
labels["zenml.pipeline_step_name"] = self.pipeline_step_name
if self.model_name:
labels["zenml.model_name"] = self.model_name
if self.model_uri:
labels["zenml.model_uri"] = self.model_uri
if self.predictor:
labels["zenml.model_type"] = self.predictor
self.sanitize_labels(labels)
return labels
sanitize_labels(labels)
staticmethod
Update the label values to be valid Kubernetes labels.
See: https://kubernetes.io/docs/concepts/overview/working-with-objects/labels/#syntax-and-character-set
Parameters:
Name | Type | Description | Default |
---|---|---|---|
labels |
Dict[str, str] |
The labels to sanitize. |
required |
Source code in zenml/integrations/kserve/services/kserve_deployment.py
@staticmethod
def sanitize_labels(labels: Dict[str, str]) -> None:
"""Update the label values to be valid Kubernetes labels.
See: https://kubernetes.io/docs/concepts/overview/working-with-objects/labels/#syntax-and-character-set
Args:
labels: The labels to sanitize.
"""
# TODO[MEDIUM]: Move k8s label sanitization to a common module with all K8s utils.
for key, value in labels.items():
# Kubernetes labels must be alphanumeric, no longer than
# 63 characters, and must begin and end with an alphanumeric
# character ([a-z0-9A-Z])
labels[key] = re.sub(r"[^0-9a-zA-Z-_\.]+", "_", value)[:63].strip(
"-_."
)
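Since sanitize_labels mutates the dictionary in place, the effect of the regular expression above can be sketched as follows; the label value is a hypothetical model URI.
labels = {"zenml.model_uri": "s3://my-bucket/models/my-model"}
KServeDeploymentConfig.sanitize_labels(labels)
# Runs of disallowed characters such as "://" and "/" collapse into "_",
# so the value becomes "s3_my-bucket_models_my-model", truncated to 63
# characters and stripped of leading/trailing "-", "_" and ".".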
KServeDeploymentService (BaseService)
pydantic-model
A ZenML service that represents a KServe inference service CRD.
Attributes:
Name | Type | Description |
---|---|---|
config |
KServeDeploymentConfig |
service configuration. |
status |
ServiceStatus |
service status. |
Source code in zenml/integrations/kserve/services/kserve_deployment.py
class KServeDeploymentService(BaseService):
"""A ZenML service that represents a KServe inference service CRD.
Attributes:
config: service configuration.
status: service status.
"""
SERVICE_TYPE = ServiceType(
name="kserve-deployment",
type="model-serving",
flavor="kserve",
description="KServe inference service",
)
config: KServeDeploymentConfig = Field(
default_factory=KServeDeploymentConfig
)
status: ServiceStatus = Field(default_factory=ServiceStatus)
def _get_model_deployer(self) -> "KServeModelDeployer":
"""Get the active KServe model deployer.
Returns:
The active KServeModelDeployer.
Raises:
TypeError: if the current stack has no KServeModelDeployer.
"""
from zenml.integrations.kserve.model_deployers.kserve_model_deployer import (
KServeModelDeployer,
)
try:
model_deployer = KServeModelDeployer.get_active_model_deployer()
except TypeError:
raise TypeError(
"No active KServe model deployer is present in the active "
"stack. Please make sure that a KServe model deployer is "
"present in the active stack."
)
return model_deployer
def _get_client(self) -> KServeClient:
"""Get the KServe client from the active KServe model deployer.
Returns:
The KServe client.
"""
return self._get_model_deployer().kserve_client
def _get_namespace(self) -> Optional[str]:
"""Get the Kubernetes namespace from the active KServe model deployer.
Returns:
The Kubernetes namespace, or None, if the default namespace is
used.
"""
return self._get_model_deployer().config.kubernetes_namespace
def check_status(self) -> Tuple[ServiceState, str]:
"""Check the state of the KServe inference service.
This method checks the current operational state of the external KServe
inference service and translates it into a `ServiceState` value and a printable message.
This method should be overridden by subclasses that implement concrete service tracking functionality.
Returns:
The operational state of the external service and a message
providing additional information about that state (e.g. a
description of the error if one is encountered while checking the
service status).
"""
client = self._get_client()
namespace = self._get_namespace()
name = self.crd_name
try:
deployment = client.get(name=name, namespace=namespace)
except RuntimeError:
return (ServiceState.INACTIVE, "")
# TODO[MEDIUM]: Implement better operational status checking that also
# cover errors
if "status" not in deployment:
return (ServiceState.INACTIVE, "No operational status available")
status = "Unknown"
for condition in deployment["status"].get("conditions", {}):
if condition.get("type", "") == "PredictorReady":
status = condition.get("status", "Unknown")
if status.lower() == "true":
return (
ServiceState.ACTIVE,
f"Inference service '{name}' is available",
)
elif status.lower() == "false":
return (
ServiceState.PENDING_STARTUP,
f"Inference service '{name}' is not available: {condition.get('message', 'Unknown')}",
)
return (
ServiceState.PENDING_STARTUP,
f"Inference service '{name}' still starting up",
)
@property
def crd_name(self) -> str:
"""Get the name of the KServe inference service CRD that uniquely corresponds to this service instance.
Returns:
The name of the KServe inference service CRD.
"""
return (
self._get_kubernetes_labels().get("zenml.model_name")
or f"zenml-{str(self.uuid)[:8]}"
)
def _get_kubernetes_labels(self) -> Dict[str, str]:
"""Generate the labels for the KServe inference service CRD from the service configuration.
Returns:
The labels for the KServe inference service.
"""
labels = self.config.get_kubernetes_labels()
labels["zenml.service_uuid"] = str(self.uuid)
KServeDeploymentConfig.sanitize_labels(labels)
return labels
@classmethod
def create_from_deployment(
cls, deployment: V1beta1InferenceService
) -> "KServeDeploymentService":
"""Recreate the configuration of a KServe Service from a deployed instance.
Args:
deployment: the KServe deployment resource.
Returns:
The KServe service configuration corresponding to the given
KServe deployment resource.
Raises:
ValueError: if the given deployment resource does not contain
the expected annotations or it contains an invalid or
incompatible KServe service configuration.
"""
config = KServeDeploymentConfig.create_from_deployment(deployment)
uuid = deployment.metadata.get("labels").get("zenml.service_uuid")
if not uuid:
raise ValueError(
f"The given deployment resource does not contain a valid "
f"'zenml.service_uuid' label: {deployment}"
)
service = cls(uuid=UUID(uuid), config=config)
service.update_status()
return service
def provision(self) -> None:
"""Provision or update remote KServe deployment instance.
This should then match the current configuration.
"""
client = self._get_client()
namespace = self._get_namespace()
api_version = constants.KSERVE_GROUP + "/" + "v1beta1"
name = self.crd_name
# All supported model specs seem to have the same fields
# so we can use any one of them (see https://kserve.github.io/website/0.8/reference/api/#serving.kserve.io/v1beta1.PredictorExtensionSpec)
if self.config.container is not None:
predictor_kwargs = {
"containers": [
k8s_client.V1Container(
name=self.config.container.get("name"),
image=self.config.container.get("image"),
command=self.config.container.get("command"),
args=self.config.container.get("args"),
env=[
k8s_client.V1EnvVar(
name="STORAGE_URI",
value=self.config.container.get("storage_uri"),
)
],
)
]
}
else:
predictor_kwargs = {
self.config.predictor: V1beta1PredictorExtensionSpec(
storage_uri=self.config.model_uri,
resources=self.config.resources,
)
}
isvc = V1beta1InferenceService(
api_version=api_version,
kind=constants.KSERVE_KIND,
metadata=k8s_client.V1ObjectMeta(
name=name,
namespace=namespace,
labels=self._get_kubernetes_labels(),
annotations=self.config.get_kubernetes_annotations(),
),
spec=V1beta1InferenceServiceSpec(
predictor=V1beta1PredictorSpec(**predictor_kwargs)
),
)
# TODO[HIGH]: better error handling when provisioning KServe instances
try:
client.get(name=name, namespace=namespace)
# update the existing deployment
client.replace(name, isvc, namespace=namespace)
except RuntimeError:
client.create(isvc)
def deprovision(self, force: bool = False) -> None:
"""Deprovisions all resources used by the service.
Args:
force: if True, the service will be deprovisioned even if it is
still in use.
Raises:
ValueError: if the service is still in use and force is False.
"""
client = self._get_client()
namespace = self._get_namespace()
name = self.crd_name
# TODO[HIGH]: catch errors if deleting a KServe instance that is no
# longer available
try:
client.delete(name=name, namespace=namespace)
except RuntimeError:
raise ValueError(
f"Could not delete KServe instance '{name}' from namespace: '{namespace}'."
)
def _get_deployment_logs(
self,
name: str,
follow: bool = False,
tail: Optional[int] = None,
) -> Generator[str, bool, None]:
"""Get the logs of a KServe deployment resource.
Args:
name: the name of the KServe deployment to get logs for.
follow: if True, the logs will be streamed as they are written
tail: only retrieve the last NUM lines of log output.
Returns:
A generator that can be accessed to get the service logs.
Raises:
Exception: if an unknown error occurs while fetching the logs.
Yields:
The logs of the given deployment.
"""
client = self._get_client()
namespace = self._get_namespace()
logger.debug(f"Retrieving logs for InferenceService resource: {name}")
try:
response = client.core_api.list_namespaced_pod(
namespace=namespace,
label_selector=f"zenml.service_uuid={self.uuid}",
)
logger.debug("Kubernetes API response: %s", response)
pods = response.items
if not pods:
raise Exception(
f"The KServe deployment {name} is not currently "
f"running: no Kubernetes pods associated with it were found"
)
pod = pods[0]
pod_name = pod.metadata.name
containers = [c.name for c in pod.spec.containers]
init_containers = [c.name for c in pod.spec.init_containers]
container_statuses = {
c.name: c.started or c.restart_count
for c in pod.status.container_statuses
}
container = "default"
if container not in containers:
container = containers[0]
if not container_statuses[container]:
container = init_containers[0]
logger.info(
f"Retrieving logs for pod: `{pod_name}` and container "
f"`{container}` in namespace `{namespace}`"
)
response = client.core_api.read_namespaced_pod_log(
name=pod_name,
namespace=namespace,
container=container,
follow=follow,
tail_lines=tail,
_preload_content=False,
)
except k8s_client.rest.ApiException as e:
logger.error(
"Exception when fetching logs for InferenceService resource "
"%s: %s",
name,
str(e),
)
raise Exception(
f"Unexpected exception when fetching logs for InferenceService "
f"resource: {name}"
) from e
try:
while True:
line = response.readline().decode("utf-8").rstrip("\n")
if not line:
return
stop = yield line
if stop:
return
finally:
response.release_conn()
def get_logs(
self, follow: bool = False, tail: Optional[int] = None
) -> Generator[str, bool, None]:
"""Retrieve the logs from the remote KServe inference service instance.
Args:
follow: if True, the logs will be streamed as they are written.
tail: only retrieve the last NUM lines of log output.
Returns:
A generator that can be accessed to get the service logs.
"""
return self._get_deployment_logs(
self.crd_name,
follow=follow,
tail=tail,
)
@property
def prediction_url(self) -> Optional[str]:
"""The prediction URI exposed by the prediction service.
Returns:
The prediction URI exposed by the prediction service, or None if
the service is not yet ready.
"""
if not self.is_running:
return None
model_deployer = self._get_model_deployer()
return os.path.join(
model_deployer.config.base_url,
"v1/models",
f"{self.crd_name}:predict",
)
@property
def prediction_hostname(self) -> Optional[str]:
"""The prediction hostname exposed by the prediction service.
Returns:
The prediction hostname exposed by the prediction service status
that will be used in the headers of the prediction request.
"""
if not self.is_running:
return None
namespace = self._get_namespace()
model_deployer = self._get_model_deployer()
custom_domain = model_deployer.config.custom_domain or "example.com"
return f"{self.crd_name}.{namespace}.{custom_domain}"
def predict(self, request: str) -> Any:
"""Make a prediction using the service.
Args:
request: a JSON string containing the request payload
Returns:
The prediction returned by the service, parsed from the JSON response.
Raises:
Exception: if the service is not yet ready.
ValueError: if the prediction_url is not set.
"""
if not self.is_running:
raise Exception(
"KServe prediction service is not running. "
"Please start the service before making predictions."
)
if self.prediction_url is None:
raise ValueError("`self.prediction_url` is not set, cannot post.")
if self.prediction_hostname is None:
raise ValueError(
"`self.prediction_hostname` is not set, cannot post."
)
headers = {"Host": self.prediction_hostname}
if isinstance(request, str):
request = json.loads(request)
else:
raise ValueError("Request must be a json string.")
response = requests.post(
self.prediction_url,
headers=headers,
json={"instances": request},
)
response.raise_for_status()
return response.json()
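In practice, instances of this service are usually not constructed by hand but looked up through the active model deployer. A hedged sketch, assuming a KServe model deployer is registered in the active stack; the lookup values are hypothetical.
from zenml.integrations.kserve.model_deployers.kserve_model_deployer import (
    KServeModelDeployer,
)

model_deployer = KServeModelDeployer.get_active_model_deployer()
services = model_deployer.find_model_server(
    pipeline_name="my_pipeline",
    pipeline_step_name="kserve_model_deployer_step",
    model_name="my-model",
)
if services:
    service = services[0]  # a KServeDeploymentService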
crd_name: str
property
readonly
Get the name of the KServe inference service CRD that uniquely corresponds to this service instance.
Returns:
Type | Description |
---|---|
str |
The name of the KServe inference service CRD. |
prediction_hostname: Optional[str]
property
readonly
The prediction hostname exposed by the prediction service.
Returns:
Type | Description |
---|---|
Optional[str] |
The prediction hostname exposed by the prediction service status that will be used in the headers of the prediction request. |
prediction_url: Optional[str]
property
readonly
The prediction URI exposed by the prediction service.
Returns:
Type | Description |
---|---|
Optional[str] |
The prediction URI exposed by the prediction service, or None if the service is not yet ready. |
check_status(self)
Check the state of the KServe inference service.
This method checks the current operational state of the external KServe
inference service and translates it into a ServiceState
value and a printable message.
This method should be overridden by subclasses that implement concrete service tracking functionality.
Returns:
Type | Description |
---|---|
Tuple[zenml.services.service_status.ServiceState, str] |
The operational state of the external service and a message providing additional information about that state (e.g. a description of the error if one is encountered while checking the service status). |
Source code in zenml/integrations/kserve/services/kserve_deployment.py
def check_status(self) -> Tuple[ServiceState, str]:
"""Check the state of the KServe inference service.
This method checks the current operational state of the external KServe
inference service and translates it into a `ServiceState` value and a printable message.
This method should be overridden by subclasses that implement concrete service tracking functionality.
Returns:
The operational state of the external service and a message
providing additional information about that state (e.g. a
description of the error if one is encountered while checking the
service status).
"""
client = self._get_client()
namespace = self._get_namespace()
name = self.crd_name
try:
deployment = client.get(name=name, namespace=namespace)
except RuntimeError:
return (ServiceState.INACTIVE, "")
# TODO[MEDIUM]: Implement better operational status checking that also
# cover errors
if "status" not in deployment:
return (ServiceState.INACTIVE, "No operational status available")
status = "Unknown"
for condition in deployment["status"].get("conditions", {}):
if condition.get("type", "") == "PredictorReady":
status = condition.get("status", "Unknown")
if status.lower() == "true":
return (
ServiceState.ACTIVE,
f"Inference service '{name}' is available",
)
elif status.lower() == "false":
return (
ServiceState.PENDING_STARTUP,
f"Inference service '{name}' is not available: {condition.get('message', 'Unknown')}",
)
return (
ServiceState.PENDING_STARTUP,
f"Inference service '{name}' still starting up",
)
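The returned tuple lends itself to simple readiness polling; the ServiceState import path below is taken from the return type annotation above, and `service` is assumed to be an existing KServeDeploymentService instance.
from zenml.services.service_status import ServiceState

state, message = service.check_status()
if state == ServiceState.ACTIVE:
    print(f"Inference service is available: {message}")
else:
    print(f"Service not ready ({state}): {message}")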
create_from_deployment(deployment)
classmethod
Recreate the configuration of a KServe Service from a deployed instance.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
deployment |
V1beta1InferenceService |
the KServe deployment resource. |
required |
Returns:
Type | Description |
---|---|
KServeDeploymentService |
The KServe service configuration corresponding to the given KServe deployment resource. |
Exceptions:
Type | Description |
---|---|
ValueError |
if the given deployment resource does not contain the expected annotations or it contains an invalid or incompatible KServe service configuration. |
Source code in zenml/integrations/kserve/services/kserve_deployment.py
@classmethod
def create_from_deployment(
cls, deployment: V1beta1InferenceService
) -> "KServeDeploymentService":
"""Recreate the configuration of a KServe Service from a deployed instance.
Args:
deployment: the KServe deployment resource.
Returns:
The KServe service configuration corresponding to the given
KServe deployment resource.
Raises:
ValueError: if the given deployment resource does not contain
the expected annotations or it contains an invalid or
incompatible KServe service configuration.
"""
config = KServeDeploymentConfig.create_from_deployment(deployment)
uuid = deployment.metadata.get("labels").get("zenml.service_uuid")
if not uuid:
raise ValueError(
f"The given deployment resource does not contain a valid "
f"'zenml.service_uuid' label: {deployment}"
)
service = cls(uuid=UUID(uuid), config=config)
service.update_status()
return service
deprovision(self, force=False)
Deprovisions all resources used by the service.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
force |
bool |
if True, the service will be deprovisioned even if it is still in use. |
False |
Exceptions:
Type | Description |
---|---|
ValueError |
if the service is still in use and force is False. |
Source code in zenml/integrations/kserve/services/kserve_deployment.py
def deprovision(self, force: bool = False) -> None:
"""Deprovisions all resources used by the service.
Args:
force: if True, the service will be deprovisioned even if it is
still in use.
Raises:
ValueError: if the service is still in use and force is False.
"""
client = self._get_client()
namespace = self._get_namespace()
name = self.crd_name
# TODO[HIGH]: catch errors if deleting a KServe instance that is no
# longer available
try:
client.delete(name=name, namespace=namespace)
except RuntimeError:
raise ValueError(
f"Could not delete KServe instance '{name}' from namespace: '{namespace}'."
)
get_logs(self, follow=False, tail=None)
Retrieve the logs from the remote KServe inference service instance.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
follow |
bool |
if True, the logs will be streamed as they are written. |
False |
tail |
Optional[int] |
only retrieve the last NUM lines of log output. |
None |
Returns:
Type | Description |
---|---|
Generator[str, bool, NoneType] |
A generator that can be accessed to get the service logs. |
Source code in zenml/integrations/kserve/services/kserve_deployment.py
def get_logs(
self, follow: bool = False, tail: Optional[int] = None
) -> Generator[str, bool, None]:
"""Retrieve the logs from the remote KServe inference service instance.
Args:
follow: if True, the logs will be streamed as they are written.
tail: only retrieve the last NUM lines of log output.
Returns:
A generator that can be accessed to get the service logs.
"""
return self._get_deployment_logs(
self.crd_name,
follow=follow,
tail=tail,
)
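Because the method returns a generator, the logs can simply be iterated over; for example, to print the last 100 lines without streaming:
for line in service.get_logs(follow=False, tail=100):
    print(line)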
predict(self, request)
Make a prediction using the service.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
request |
str |
a JSON string containing the request payload |
required |
Returns:
Type | Description |
---|---|
Any |
The prediction returned by the service, parsed from the JSON response. |
Exceptions:
Type | Description |
---|---|
Exception |
if the service is not yet ready. |
ValueError |
if the prediction_url is not set. |
Source code in zenml/integrations/kserve/services/kserve_deployment.py
def predict(self, request: str) -> Any:
"""Make a prediction using the service.
Args:
request: a JSON string containing the request payload
Returns:
The prediction returned by the service, parsed from the JSON response.
Raises:
Exception: if the service is not yet ready.
ValueError: if the prediction_url is not set.
"""
if not self.is_running:
raise Exception(
"KServe prediction service is not running. "
"Please start the service before making predictions."
)
if self.prediction_url is None:
raise ValueError("`self.prediction_url` is not set, cannot post.")
if self.prediction_hostname is None:
raise ValueError(
"`self.prediction_hostname` is not set, cannot post."
)
headers = {"Host": self.prediction_hostname}
if isinstance(request, str):
request = json.loads(request)
else:
raise ValueError("Request must be a json string.")
response = requests.post(
self.prediction_url,
headers=headers,
json={"instances": request},
)
response.raise_for_status()
return response.json()
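Note that the request must be a JSON string; anything else raises a ValueError. A short sketch with a hypothetical, model-specific payload shape:
import json

# The instances payload is model-specific; four numeric features are assumed here.
request = json.dumps([[5.1, 3.5, 1.4, 0.2]])
prediction = service.predict(request)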
provision(self)
Provision or update the remote KServe deployment instance.
This should then match the current configuration.
Source code in zenml/integrations/kserve/services/kserve_deployment.py
def provision(self) -> None:
"""Provision or update remote KServe deployment instance.
This should then match the current configuration.
"""
client = self._get_client()
namespace = self._get_namespace()
api_version = constants.KSERVE_GROUP + "/" + "v1beta1"
name = self.crd_name
# All supported model specs seem to have the same fields
# so we can use any one of them (see https://kserve.github.io/website/0.8/reference/api/#serving.kserve.io/v1beta1.PredictorExtensionSpec)
if self.config.container is not None:
predictor_kwargs = {
"containers": [
k8s_client.V1Container(
name=self.config.container.get("name"),
image=self.config.container.get("image"),
command=self.config.container.get("command"),
args=self.config.container.get("args"),
env=[
k8s_client.V1EnvVar(
name="STORAGE_URI",
value=self.config.container.get("storage_uri"),
)
],
)
]
}
else:
predictor_kwargs = {
self.config.predictor: V1beta1PredictorExtensionSpec(
storage_uri=self.config.model_uri,
resources=self.config.resources,
)
}
isvc = V1beta1InferenceService(
api_version=api_version,
kind=constants.KSERVE_KIND,
metadata=k8s_client.V1ObjectMeta(
name=name,
namespace=namespace,
labels=self._get_kubernetes_labels(),
annotations=self.config.get_kubernetes_annotations(),
),
spec=V1beta1InferenceServiceSpec(
predictor=V1beta1PredictorSpec(**predictor_kwargs)
),
)
# TODO[HIGH]: better error handling when provisioning KServe instances
try:
client.get(name=name, namespace=namespace)
# update the existing deployment
client.replace(name, isvc, namespace=namespace)
except RuntimeError:
client.create(isvc)
steps
special
Initialization for KServe steps.
kserve_deployer
Implementation of the KServe Deployer step.
CustomDeployParameters (BaseModel)
pydantic-model
Custom model deployer step extra parameters.
Attributes:
Name | Type | Description |
---|---|---|
predict_function |
str |
Path to Python file containing predict function. |
Source code in zenml/integrations/kserve/steps/kserve_deployer.py
class CustomDeployParameters(BaseModel):
"""Custom model deployer step extra parameters.
Attributes:
predict_function: Path to Python file containing predict function.
"""
predict_function: str
@validator("predict_function")
def predict_function_validate(cls, predict_func_path: str) -> str:
"""Validate predict function.
Args:
predict_func_path: predict function path
Returns:
predict function path
Raises:
ValueError: if predict function path is not valid
TypeError: if predict function path is not a callable function
"""
try:
predict_function = import_class_by_path(predict_func_path)
except AttributeError:
raise ValueError("Predict function can't be found.")
if not callable(predict_function):
raise TypeError("Predict function must be callable.")
return predict_func_path
predict_function_validate(predict_func_path)
classmethod
Validate predict function.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
predict_func_path |
str |
predict function path |
required |
Returns:
Type | Description |
---|---|
str |
predict function path |
Exceptions:
Type | Description |
---|---|
ValueError |
if predict function path is not valid |
TypeError |
if predict function path is not a callable function |
Source code in zenml/integrations/kserve/steps/kserve_deployer.py
@validator("predict_function")
def predict_function_validate(cls, predict_func_path: str) -> str:
"""Validate predict function.
Args:
predict_func_path: predict function path
Returns:
predict function path
Raises:
ValueError: if predict function path is not valid
TypeError: if predict function path is not a callable function
"""
try:
predict_function = import_class_by_path(predict_func_path)
except AttributeError:
raise ValueError("Predict function can't be found.")
if not callable(predict_function):
raise TypeError("Predict function must be callable.")
return predict_func_path
KServeDeployerStepParameters (BaseParameters)
pydantic-model
KServe model deployer step parameters.
Attributes:
Name | Type | Description |
---|---|---|
service_config |
KServeDeploymentConfig |
KServe deployment service configuration. |
custom_deploy_parameters |
Optional[CustomDeployParameters] |
Extra parameters for deploying a custom model. |
torch_serve_parameters |
Optional[TorchServeParameters] |
TorchServe parameters for deploying a PyTorch model. |
timeout |
int |
Timeout for model deployment. |
Source code in zenml/integrations/kserve/steps/kserve_deployer.py
class KServeDeployerStepParameters(BaseParameters):
"""KServe model deployer step parameters.
Attributes:
service_config: KServe deployment service configuration.
custom_deploy_parameters: extra parameters for deploying a custom model.
torch_serve_parameters: TorchServe parameters for deploying a PyTorch model.
timeout: Timeout for model deployment.
"""
service_config: KServeDeploymentConfig
custom_deploy_parameters: Optional[CustomDeployParameters] = None
torch_serve_parameters: Optional[TorchServeParameters] = None
timeout: int = DEFAULT_KSERVE_DEPLOYMENT_START_STOP_TIMEOUT
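A sketch of assembling these parameters for the deployer step; the model name and timeout are illustrative, and the imports assume the module paths shown in the source references above.
from zenml.integrations.kserve.services.kserve_deployment import (
    KServeDeploymentConfig,
)
from zenml.integrations.kserve.steps.kserve_deployer import (
    KServeDeployerStepParameters,
)

params = KServeDeployerStepParameters(
    service_config=KServeDeploymentConfig(
        model_name="my-model",
        predictor="sklearn",
    ),
    timeout=300,
)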
TorchServeParameters (BaseModel)
pydantic-model
KServe PyTorch model deployer step configuration.
Attributes:
Name | Type | Description |
---|---|---|
model_class |
str |
Path to Python file containing model architecture. |
handler |
str |
TorchServe's handler file to handle custom TorchServe inference logic. |
extra_files |
Optional[List[str]] |
Comma separated path to extra dependency files. |
model_version |
Optional[str] |
Model version. |
requirements_file |
Optional[str] |
Path to requirements file. |
torch_config |
Optional[str] |
TorchServe configuration file path. |
Source code in zenml/integrations/kserve/steps/kserve_deployer.py
class TorchServeParameters(BaseModel):
"""KServe PyTorch model deployer step configuration.
Attributes:
model_class: Path to Python file containing model architecture.
handler: TorchServe's handler file to handle custom TorchServe inference
logic.
extra_files: Comma separated path to extra dependency files.
model_version: Model version.
requirements_file: Path to requirements file.
torch_config: TorchServe configuration file path.
"""
model_class: str
handler: str
extra_files: Optional[List[str]] = None
requirements_file: Optional[str] = None
model_version: Optional[str] = "1.0"
torch_config: Optional[str] = None
@validator("model_class")
def model_class_validate(cls, v: str) -> str:
"""Validate model class file path.
Args:
v: model class file path
Returns:
model class file path
Raises:
ValueError: if model class file path is not valid
"""
if not v:
raise ValueError("Model class file path is required.")
if not is_inside_repository(v):
raise ValueError(
"Model class file path must be inside the repository."
)
return v
@validator("handler")
def handler_validate(cls, v: str) -> str:
"""Validate handler.
Args:
v: handler file path
Returns:
handler file path
Raises:
ValueError: if handler file path is not valid
"""
if v:
if v in TORCH_HANDLERS:
return v
elif is_inside_repository(v):
return v
else:
raise ValueError(
"Handler must be one of the TorchServe handlers",
"or a file that exists inside the repository.",
)
else:
raise ValueError("Handler is required.")
@validator("extra_files")
def extra_files_validate(
cls, v: Optional[List[str]]
) -> Optional[List[str]]:
"""Validate extra files.
Args:
v: extra files path
Returns:
extra files path
Raises:
ValueError: if the extra files path is not valid
"""
extra_files = []
if v is not None:
for file_path in v:
if is_inside_repository(file_path):
extra_files.append(file_path)
else:
raise ValueError(
"Extra file path must be inside the repository."
)
return extra_files
return v
@validator("torch_config")
def torch_config_validate(cls, v: Optional[str]) -> Optional[str]:
"""Validate torch config file.
Args:
v: torch config file path
Returns:
torch config file path
Raises:
ValueError: if torch config file path is not valid.
"""
if v:
if is_inside_repository(v):
return v
else:
raise ValueError(
"Torch config file path must be inside the repository."
)
return v
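For PyTorch deployments, the TorchServe parameters must reference files inside the repository or one of TorchServe's built-in handlers. An illustrative sketch with hypothetical repository-relative paths:
from zenml.integrations.kserve.steps.kserve_deployer import TorchServeParameters

torch_params = TorchServeParameters(
    model_class="steps/pytorch_model.py",  # hypothetical path inside the repository
    handler="image_classifier",  # a built-in TorchServe handler
    torch_config="steps/torchserve.properties",  # hypothetical config file path
)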
extra_files_validate(v)
classmethod
Validate extra files.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
v |
Optional[List[str]] |
extra files path |
required |
Returns:
Type | Description |
---|---|
Optional[List[str]] |
extra files path |
Exceptions:
Type | Description |
---|---|
ValueError |
if the extra files path is not valid |
Source code in zenml/integrations/kserve/steps/kserve_deployer.py
@validator("extra_files")
def extra_files_validate(
cls, v: Optional[List[str]]
) -> Optional[List[str]]:
"""Validate extra files.
Args:
v: extra files path
Returns:
extra files path
Raises:
ValueError: if the extra files path is not valid
"""
extra_files = []
if v is not None:
for file_path in v:
if is_inside_repository(file_path):
extra_files.append(file_path)
else:
raise ValueError(
"Extra file path must be inside the repository."
)
return extra_files
return v
handler_validate(v)
classmethod
Validate handler.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
v |
str |
handler file path |
required |
Returns:
Type | Description |
---|---|
str |
handler file path |
Exceptions:
Type | Description |
---|---|
ValueError |
if handler file path is not valid |
Source code in zenml/integrations/kserve/steps/kserve_deployer.py
@validator("handler")
def handler_validate(cls, v: str) -> str:
"""Validate handler.
Args:
v: handler file path
Returns:
handler file path
Raises:
ValueError: if handler file path is not valid
"""
if v:
if v in TORCH_HANDLERS:
return v
elif is_inside_repository(v):
return v
else:
raise ValueError(
"Handler must be one of the TorchServe handlers",
"or a file that exists inside the repository.",
)
else:
raise ValueError("Handler is required.")
model_class_validate(v)
classmethod
Validate model class file path.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
v |
str |
model class file path |
required |
Returns:
Type | Description |
---|---|
str |
model class file path |
Exceptions:
Type | Description |
---|---|
ValueError |
if model class file path is not valid |
Source code in zenml/integrations/kserve/steps/kserve_deployer.py
@validator("model_class")
def model_class_validate(cls, v: str) -> str:
"""Validate model class file path.
Args:
v: model class file path
Returns:
model class file path
Raises:
ValueError: if model class file path is not valid
"""
if not v:
raise ValueError("Model class file path is required.")
if not is_inside_repository(v):
raise ValueError(
"Model class file path must be inside the repository."
)
return v
torch_config_validate(v)
classmethod
Validate torch config file.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
v |
Optional[str] |
torch config file path |
required |
Returns:
Type | Description |
---|---|
Optional[str] |
torch config file path |
Exceptions:
Type | Description |
---|---|
ValueError |
if torch config file path is not valid. |
Source code in zenml/integrations/kserve/steps/kserve_deployer.py
@validator("torch_config")
def torch_config_validate(cls, v: Optional[str]) -> Optional[str]:
"""Validate torch config file.
Args:
v: torch config file path
Returns:
torch config file path
Raises:
ValueError: if torch config file path is not valid.
"""
if v:
if is_inside_repository(v):
return v
else:
raise ValueError(
"Torch config file path must be inside the repository."
)
return v
kserve_custom_model_deployer_step (BaseStep)
KServe custom model deployer pipeline step.
This step can be used in a pipeline to implement the process required to deploy a custom model with KServe.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
deploy_decision |
whether to deploy the model or not |
required | |
params |
parameters for the deployer step |
required | |
model |
the model artifact to deploy |
required | |
context |
the step context |
required |
Exceptions:
Type | Description |
---|---|
ValueError |
if the custom deployer parameters are not defined |
DoesNotExistException |
if no active stack is found |
Returns:
Type | Description |
---|---|
KServeDeploymentService |
KServe deployment service |
PARAMETERS_CLASS (BaseParameters)
pydantic-model
KServe model deployer step parameters.
Attributes:
Name | Type | Description |
---|---|---|
service_config |
KServeDeploymentConfig |
KServe deployment service configuration. |
custom_deploy_parameters |
Optional[CustomDeployParameters] |
Extra parameters for deploying a custom model. |
torch_serve_parameters |
Optional[TorchServeParameters] |
TorchServe parameters for deploying a PyTorch model. |
timeout |
int |
Timeout for model deployment. |
Source code in zenml/integrations/kserve/steps/kserve_deployer.py
class KServeDeployerStepParameters(BaseParameters):
"""KServe model deployer step parameters.
Attributes:
service_config: KServe deployment service configuration.
custom_deploy_parameters: extra parameters for deploying a custom model.
torch_serve_parameters: TorchServe parameters for deploying a PyTorch model.
timeout: Timeout for model deployment.
"""
service_config: KServeDeploymentConfig
custom_deploy_parameters: Optional[CustomDeployParameters] = None
torch_serve_parameters: Optional[TorchServeParameters] = None
timeout: int = DEFAULT_KSERVE_DEPLOYMENT_START_STOP_TIMEOUT
entrypoint(deploy_decision, params, context, model)
staticmethod
KServe custom model deployer pipeline step.
This step can be used in a pipeline to implement the process required to deploy a custom model with KServe.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
deploy_decision |
bool |
whether to deploy the model or not |
required |
params |
KServeDeployerStepParameters |
parameters for the deployer step |
required |
model |
ModelArtifact |
the model artifact to deploy |
required |
context |
StepContext |
the step context |
required |
Exceptions:
Type | Description |
---|---|
ValueError |
if the custom deployer parameters are not defined |
DoesNotExistException |
if no active stack is found |
Returns:
Type | Description |
---|---|
KServeDeploymentService |
KServe deployment service |
Source code in zenml/integrations/kserve/steps/kserve_deployer.py
@step(enable_cache=False)
def kserve_custom_model_deployer_step(
deploy_decision: bool,
params: KServeDeployerStepParameters,
context: StepContext,
model: ModelArtifact,
) -> KServeDeploymentService:
"""KServe custom model deployer pipeline step.
This step can be used in a pipeline to implement the
process required to deploy a custom model with KServe.
Args:
deploy_decision: whether to deploy the model or not
params: parameters for the deployer step
model: the model artifact to deploy
context: the step context
Raises:
ValueError: if the custom deployer parameters are not defined
DoesNotExistException: if no active stack is found
Returns:
KServe deployment service
"""
# verify that a custom deployer is defined
if not params.custom_deploy_parameters:
raise ValueError(
"Custom deploy parameter which contains the path of the",
"custom predict function is required for custom model deployment.",
)
# get the active model deployer
model_deployer = KServeModelDeployer.get_active_model_deployer()
# get pipeline name, step name, run id
step_env = cast(StepEnvironment, Environment()[STEP_ENVIRONMENT_NAME])
pipeline_name = step_env.pipeline_name
pipeline_run_id = step_env.pipeline_run_id
step_name = step_env.step_name
# update the step configuration with the real pipeline runtime information
params.service_config.pipeline_name = pipeline_name
params.service_config.pipeline_run_id = pipeline_run_id
params.service_config.pipeline_step_name = step_name
# fetch existing services with same pipeline name, step name and
# model name
existing_services = model_deployer.find_model_server(
pipeline_name=pipeline_name,
pipeline_step_name=step_name,
model_name=params.service_config.model_name,
)
# even when the deploy decision is negative if an existing model server
# is not running for this pipeline/step, we still have to serve the
# current model, to ensure that a model server is available at all times
if not deploy_decision and existing_services:
logger.info(
f"Skipping model deployment because the model quality does not "
f"meet the criteria. Reusing the last model server deployed by step "
f"'{step_name}' and pipeline '{pipeline_name}' for model "
f"'{params.service_config.model_name}'..."
)
service = cast(KServeDeploymentService, existing_services[0])
# even when the deploy decision is negative, we still need to start
# the previous model server if it is no longer running, to ensure that
# a model server is available at all times
if not service.is_running:
service.start(timeout=params.timeout)
return service
# entrypoint for starting KServe server deployment for custom model
entrypoint_command = [
"python",
"-m",
"zenml.integrations.kserve.custom_deployer.zenml_custom_model",
"--model_name",
params.service_config.model_name,
"--predict_func",
params.custom_deploy_parameters.predict_function,
]
# verify if there is an active stack before starting the service
if not context.stack:
raise DoesNotExistException(
"No active stack is available. "
"Please make sure that you have registered and set a stack."
)
context.stack
docker_image = step_env.step_run_info.pipeline.extra[
KSERVE_DOCKER_IMAGE_KEY
]
# copy the model files to a new specific directory for the deployment
served_model_uri = os.path.join(context.get_output_artifact_uri(), "kserve")
fileio.makedirs(served_model_uri)
io_utils.copy_dir(model.uri, served_model_uri)
# Get the model artifact to extract information about the model
# and how it can be loaded again later in the deployment environment.
artifact = Client().zen_store.list_artifacts(artifact_uri=model.uri)
if not artifact:
raise DoesNotExistException(f"No artifact found at {model.uri}.")
# save the model artifact metadata to the YAML file and copy it to the
# deployment directory
model_metadata_file = save_model_metadata(artifact[0])
fileio.copy(
model_metadata_file,
os.path.join(served_model_uri, MODEL_METADATA_YAML_FILE_NAME),
)
# prepare the service configuration for the deployment
service_config = params.service_config.copy()
service_config.model_uri = served_model_uri
# Prepare container config for custom model deployment
service_config.container = {
"name": service_config.model_name,
"image": docker_image,
"command": entrypoint_command,
"storage_uri": service_config.model_uri,
}
# deploy the service
service = cast(
KServeDeploymentService,
model_deployer.deploy_model(
service_config, replace=True, timeout=params.timeout
),
)
logger.info(
f"KServe deployment service started and reachable at:\n"
f" {service.prediction_url}\n"
f" With the hostname: {service.prediction_hostname}."
)
return service
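Following the usual ZenML pattern of configuring a step instance with its parameters (an assumption about the surrounding pipeline code), the custom deployer step might be set up as below; the predict-function path is a hypothetical module attribute that must be resolvable via import_class_by_path.
from zenml.integrations.kserve.services.kserve_deployment import (
    KServeDeploymentConfig,
)
from zenml.integrations.kserve.steps.kserve_deployer import (
    CustomDeployParameters,
    KServeDeployerStepParameters,
    kserve_custom_model_deployer_step,
)

custom_model_deployer = kserve_custom_model_deployer_step(
    params=KServeDeployerStepParameters(
        service_config=KServeDeploymentConfig(
            model_name="my-model",
            predictor="custom",  # assumption: superseded by the custom container
        ),
        custom_deploy_parameters=CustomDeployParameters(
            predict_function="steps.custom_predict.custom_predict",
        ),
    )
)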
kserve_model_deployer_step (BaseStep)
KServe model deployer pipeline step.
This step can be used in a pipeline to implement continuous deployment for an ML model with KServe.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
deploy_decision |
whether to deploy the model or not |
required | |
params |
parameters for the deployer step |
required | |
model |
the model artifact to deploy |
required | |
context |
the step context |
required |
Returns:
Type | Description |
---|---|
KServeDeploymentService |
KServe deployment service |
PARAMETERS_CLASS (BaseParameters)
pydantic-model
KServe model deployer step parameters.
Attributes:
Name | Type | Description |
---|---|---|
service_config |
KServeDeploymentConfig |
KServe deployment service configuration. |
custom_deploy_parameters |
Optional[CustomDeployParameters] |
Extra parameters for deploying a custom model. |
torch_serve_parameters |
Optional[TorchServeParameters] |
TorchServe parameters for deploying a PyTorch model. |
timeout |
int |
Timeout for model deployment. |
Source code in zenml/integrations/kserve/steps/kserve_deployer.py
class KServeDeployerStepParameters(BaseParameters):
"""KServe model deployer step parameters.
Attributes:
service_config: KServe deployment service configuration.
custom_deploy_parameters: extra parameters for deploying a custom model.
torch_serve_parameters: TorchServe parameters for deploying a PyTorch model.
timeout: Timeout for model deployment.
"""
service_config: KServeDeploymentConfig
custom_deploy_parameters: Optional[CustomDeployParameters] = None
torch_serve_parameters: Optional[TorchServeParameters] = None
timeout: int = DEFAULT_KSERVE_DEPLOYMENT_START_STOP_TIMEOUT
entrypoint(deploy_decision, params, context, model)
staticmethod
KServe model deployer pipeline step.
This step can be used in a pipeline to implement continuous deployment for an ML model with KServe.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
deploy_decision |
bool |
whether to deploy the model or not |
required |
params |
KServeDeployerStepParameters |
parameters for the deployer step |
required |
model |
ModelArtifact |
the model artifact to deploy |
required |
context |
StepContext |
the step context |
required |
Returns:
Type | Description |
---|---|
KServeDeploymentService |
KServe deployment service |
Source code in zenml/integrations/kserve/steps/kserve_deployer.py
@step(enable_cache=False)
def kserve_model_deployer_step(
deploy_decision: bool,
params: KServeDeployerStepParameters,
context: StepContext,
model: ModelArtifact,
) -> KServeDeploymentService:
"""KServe model deployer pipeline step.
This step can be used in a pipeline to implement continuous
deployment for an ML model with KServe.
Args:
deploy_decision: whether to deploy the model or not
params: parameters for the deployer step
model: the model artifact to deploy
context: the step context
Returns:
KServe deployment service
"""
model_deployer = KServeModelDeployer.get_active_model_deployer()
# get pipeline name, step name and run id
step_env = cast(StepEnvironment, Environment()[STEP_ENVIRONMENT_NAME])
pipeline_name = step_env.pipeline_name
pipeline_run_id = step_env.pipeline_run_id
step_name = step_env.step_name
# update the step configuration with the real pipeline runtime information
params.service_config.pipeline_name = pipeline_name
params.service_config.pipeline_run_id = pipeline_run_id
params.service_config.pipeline_step_name = step_name
# fetch existing services with same pipeline name, step name and
# model name
existing_services = model_deployer.find_model_server(
pipeline_name=pipeline_name,
pipeline_step_name=step_name,
model_name=params.service_config.model_name,
)
# even when the deploy decision is negative if an existing model server
# is not running for this pipeline/step, we still have to serve the
# current model, to ensure that a model server is available at all times
if not deploy_decision and existing_services:
logger.info(
f"Skipping model deployment because the model quality does not "
f"meet the criteria. Reusing the last model server deployed by step "
f"'{step_name}' and pipeline '{pipeline_name}' for model "
f"'{params.service_config.model_name}'..."
)
service = cast(KServeDeploymentService, existing_services[0])
# even when the deploy decision is negative, we still need to start
# the previous model server if it is no longer running, to ensure that
# a model server is available at all times
if not service.is_running:
service.start(timeout=params.timeout)
return service
# invoke the KServe model deployer to create a new service
# or update an existing one that was previously deployed for the same
# model
if params.service_config.predictor == "pytorch":
# import the prepare function from the step utils
from zenml.integrations.kserve.steps.kserve_step_utils import (
prepare_torch_service_config,
)
# prepare the service config
service_config = prepare_torch_service_config(
model_uri=model.uri,
output_artifact_uri=context.get_output_artifact_uri(),
params=params,
)
else:
# import the prepare function from the step utils
from zenml.integrations.kserve.steps.kserve_step_utils import (
prepare_service_config,
)
# prepare the service config
service_config = prepare_service_config(
model_uri=model.uri,
output_artifact_uri=context.get_output_artifact_uri(),
params=params,
)
service = cast(
KServeDeploymentService,
model_deployer.deploy_model(
service_config, replace=True, timeout=params.timeout
),
)
logger.info(
f"KServe deployment service started and reachable at:\n"
f" {service.prediction_url}\n"
f" With the hostname: {service.prediction_hostname}."
)
return service
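Wired into a pipeline, the step receives the deployment decision and the trained model artifact from upstream steps. A hedged sketch, assuming hypothetical trainer and evaluator steps and the standard pipeline decorator:
from zenml.integrations.kserve.services.kserve_deployment import (
    KServeDeploymentConfig,
)
from zenml.integrations.kserve.steps.kserve_deployer import (
    KServeDeployerStepParameters,
    kserve_model_deployer_step,
)
from zenml.pipelines import pipeline

@pipeline
def continuous_deployment_pipeline(trainer, evaluator, model_deployer):
    # trainer and evaluator are hypothetical upstream steps
    model = trainer()
    deploy_decision = evaluator(model=model)
    model_deployer(deploy_decision=deploy_decision, model=model)

model_deployer = kserve_model_deployer_step(
    params=KServeDeployerStepParameters(
        service_config=KServeDeploymentConfig(
            model_name="my-model",
            predictor="sklearn",
        ),
        timeout=300,
    )
)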
kserve_step_utils
This module contains the utility functions used by the KServe deployer step.
TorchModelArchiver (BaseModel)
pydantic-model
Model Archiver for PyTorch models.
Attributes:
Name | Type | Description |
---|---|---|
model_name |
str |
Model name. |
model_version |
Optional[str] |
Model version. |
serialized_file |
str |
Serialized model file. |
handler |
str |
TorchServe's handler file to handle custom TorchServe inference logic. |
extra_files |
Optional[List[str]] |
Comma separated path to extra dependency files. |
requirements_file |
Optional[str] |
Path to requirements file. |
export_path |
str |
Path to export model. |
runtime |
Optional[str] |
Runtime of the model. |
force |
Optional[bool] |
Force export of the model. |
archive_format |
Optional[str] |
Archive format. |
Source code in zenml/integrations/kserve/steps/kserve_step_utils.py
class TorchModelArchiver(BaseModel):
"""Model Archiver for PyTorch models.
Attributes:
model_name: Model name.
model_version: Model version.
serialized_file: Serialized model file.
handler: TorchServe's handler file to handle custom TorchServe inference logic.
extra_files: Comma separated path to extra dependency files.
requirements_file: Path to requirements file.
export_path: Path to export model.
runtime: Runtime of the model.
force: Force export of the model.
archive_format: Archive format.
"""
model_name: str
serialized_file: str
model_file: str
handler: str
export_path: str
extra_files: Optional[List[str]] = None
version: Optional[str] = None
requirements_file: Optional[str] = None
runtime: Optional[str] = "python"
force: Optional[bool] = None
archive_format: Optional[str] = "default"
generate_model_deployer_config(model_name, directory)
Generate a model deployer config.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model_name |
str |
the name of the model |
required |
directory |
str |
the directory where the model is stored |
required |
Returns:
Type | Description |
---|---|
str |
The path to the generated TorchServe configuration file. |
Source code in zenml/integrations/kserve/steps/kserve_step_utils.py
def generate_model_deployer_config(
model_name: str,
directory: str,
) -> str:
"""Generate a model deployer config.
Args:
model_name: the name of the model
directory: the directory where the model is stored
Returns:
The path to the generated TorchServe configuration file.
"""
config_lines = [
"inference_address=http://0.0.0.0:8085",
"management_address=http://0.0.0.0:8085",
"metrics_address=http://0.0.0.0:8082",
"grpc_inference_port=7070",
"grpc_management_port=7071",
"enable_metrics_api=true",
"metrics_format=prometheus",
"number_of_netty_threads=4",
"job_queue_size=10",
"enable_envvars_config=true",
"install_py_dep_per_model=true",
"model_store=/mnt/models/model-store",
]
with tempfile.NamedTemporaryFile(
suffix=".properties", mode="w+", dir=directory, delete=False
) as f:
for line in config_lines:
f.write(line + "\n")
f.write(
f'model_snapshot={{"name":"startup.cfg","modelCount":1,"models":{{"{model_name}":{{"1.0":{{"defaultVersion":true,"marName":"{model_name}.mar","minWorkers":1,"maxWorkers":5,"batchSize":1,"maxBatchDelay":10,"responseTimeout":120}}}}}}}}'
)
f.close()
return f.name
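A short usage sketch, assuming only a writable directory; it generates the properties file and prints its contents:
import tempfile

with tempfile.TemporaryDirectory() as tmp_dir:
    config_path = generate_model_deployer_config(
        model_name="mnist",  # hypothetical model name
        directory=tmp_dir,
    )
    with open(config_path) as f:
        print(f.read())  # addresses, ports and the model_snapshot entry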
prepare_service_config(model_uri, output_artifact_uri, params)
Prepare the model files for model serving.
This function ensures that the model files are in the correct format and file structure required by the KServe server implementation used for model serving.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model_uri | str | the URI of the model artifact being served | required |
output_artifact_uri | str | the URI of the output artifact | required |
params | KServeDeployerStepParameters | the KServe deployer step parameters | required |
Returns:
Type | Description |
---|---|
KServeDeploymentConfig | The service configuration, updated to point to the URI of the model ready for serving. |
Exceptions:
Type | Description |
---|---|
RuntimeError | if the model files cannot be prepared. |
Source code in zenml/integrations/kserve/steps/kserve_step_utils.py
def prepare_service_config(
model_uri: str,
output_artifact_uri: str,
params: KServeDeployerStepParameters,
) -> KServeDeploymentConfig:
"""Prepare the model files for model serving.
This function ensures that the model files are in the correct format
and file structure required by the KServe server implementation
used for model serving.
Args:
model_uri: the URI of the model artifact being served
output_artifact_uri: the URI of the output artifact
params: the KServe deployer step parameters
Returns:
The service configuration, updated to point to the URI of the model ready for serving.
Raises:
RuntimeError: if the model files cannot be prepared.
"""
served_model_uri = os.path.join(output_artifact_uri, "kserve")
fileio.makedirs(served_model_uri)
# TODO [ENG-773]: determine how to formalize how models are organized into
# folders and sub-folders depending on the model type/format and the
# KServe protocol used to serve the model.
# TODO [ENG-791]: an auto-detect built-in KServe server implementation
# from the model artifact type
# TODO [ENG-792]: validate the model artifact type against the
# supported built-in KServe server implementations
if params.service_config.predictor == "tensorflow":
# the TensorFlow server expects model artifacts to be
# stored in numbered subdirectories, each representing a model
# version
served_model_uri = os.path.join(
served_model_uri,
params.service_config.predictor,
params.service_config.model_name,
)
fileio.makedirs(served_model_uri)
io_utils.copy_dir(model_uri, os.path.join(served_model_uri, "1"))
elif params.service_config.predictor == "sklearn":
# the sklearn server expects model artifacts to be
# stored in a file called model.joblib
model_uri = os.path.join(model_uri, "model")
if not fileio.exists(model_uri):
raise RuntimeError(
f"Expected sklearn model artifact was not found at "
f"{model_uri}"
)
served_model_uri = os.path.join(
served_model_uri,
params.service_config.predictor,
params.service_config.model_name,
)
fileio.makedirs(served_model_uri)
fileio.copy(model_uri, os.path.join(served_model_uri, "model.joblib"))
else:
# default treatment for all other server implementations is to
# simply reuse the model from the artifact store path where it
# is originally stored
served_model_uri = os.path.join(
served_model_uri,
params.service_config.predictor,
params.service_config.model_name,
)
fileio.makedirs(served_model_uri)
fileio.copy(model_uri, served_model_uri)
service_config = params.service_config.copy()
service_config.model_uri = served_model_uri
return service_config
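For reference, this is how the deployer step shown at the top of this module invokes the function; `model`, `context` and `params` are provided by the running step:
service_config = prepare_service_config(
    model_uri=model.uri,
    output_artifact_uri=context.get_output_artifact_uri(),
    params=params,
)
# service_config.model_uri now points at a path of the form
# "<output_artifact_uri>/kserve/<predictor>/<model_name>"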
prepare_torch_service_config(model_uri, output_artifact_uri, params)
Prepare the PyTorch model files for model serving.
This function ensures that the model files are in the correct format and file structure required by the KServe server implementation used for model serving.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model_uri | str | the URI of the model artifact being served | required |
output_artifact_uri | str | the URI of the output artifact | required |
params | KServeDeployerStepParameters | the KServe deployer step parameters | required |
Returns:
Type | Description |
---|---|
KServeDeploymentConfig | The service configuration, updated to point to the URI of the model ready for serving. |
Exceptions:
Type | Description |
---|---|
RuntimeError | if the model files cannot be prepared. |
Source code in zenml/integrations/kserve/steps/kserve_step_utils.py
def prepare_torch_service_config(
model_uri: str,
output_artifact_uri: str,
params: KServeDeployerStepParameters,
) -> KServeDeploymentConfig:
"""Prepare the PyTorch model files for model serving.
This function ensures that the model files are in the correct format
and file structure required by the KServe server implementation
used for model serving.
Args:
model_uri: the URI of the model artifact being served
output_artifact_uri: the URI of the output artifact
params: the KServe deployer step parameters
Returns:
The service configuration, updated to point to the URI of the model ready for serving.
Raises:
RuntimeError: if the model files cannot be prepared.
"""
deployment_folder_uri = os.path.join(output_artifact_uri, "kserve")
served_model_uri = os.path.join(deployment_folder_uri, "model-store")
config_properties_uri = os.path.join(deployment_folder_uri, "config")
fileio.makedirs(served_model_uri)
fileio.makedirs(config_properties_uri)
if params.torch_serve_parameters is None:
raise RuntimeError("No torch serve parameters provided")
else:
# Create a temporary folder
temp_dir = tempfile.mkdtemp(prefix="zenml-pytorch-temp-")
tmp_model_uri = os.path.join(
str(temp_dir), f"{params.service_config.model_name}.pt"
)
# Copy from artifact store to temporary file
fileio.copy(f"{model_uri}/checkpoint.pt", tmp_model_uri)
torch_archiver_args = TorchModelArchiver(
model_name=params.service_config.model_name,
serialized_file=tmp_model_uri,
model_file=params.torch_serve_parameters.model_class,
handler=params.torch_serve_parameters.handler,
export_path=temp_dir,
version=params.torch_serve_parameters.model_version,
)
manifest = ModelExportUtils.generate_manifest_json(torch_archiver_args)
package_model(torch_archiver_args, manifest=manifest)
# Copy from temporary file to artifact store
archived_model_uri = os.path.join(
temp_dir, f"{params.service_config.model_name}.mar"
)
if not fileio.exists(archived_model_uri):
raise RuntimeError(
f"Expected torch archived model artifact was not found at "
f"{archived_model_uri}"
)
# Copy the torch model archive artifact to the model store
fileio.copy(
archived_model_uri,
os.path.join(
served_model_uri, f"{params.service_config.model_name}.mar"
),
)
# Get or Generate the config file
if params.torch_serve_parameters.torch_config:
# Copy the torch model config to the model store
fileio.copy(
params.torch_serve_parameters.torch_config,
os.path.join(config_properties_uri, "config.properties"),
)
else:
# Generate the config file
config_file_uri = generate_model_deployer_config(
model_name=params.service_config.model_name,
directory=temp_dir,
)
# Copy the torch model config to the model store
fileio.copy(
config_file_uri,
os.path.join(config_properties_uri, "config.properties"),
)
service_config = params.service_config.copy()
service_config.model_uri = deployment_folder_uri
return service_config
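For orientation, the layout this function produces under the output artifact URI looks roughly as follows (sketched as comments; the model name is a placeholder):
# <output_artifact_uri>/kserve/
# ├── model-store/
# │   └── mnist.mar          # the packaged TorchServe model archive
# └── config/
#     └── config.properties  # copied or freshly generated TorchServe config
#
# The returned service config's `model_uri` points at the "kserve/" folder.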
kubeflow
special
Initialization of the Kubeflow integration for ZenML.
The Kubeflow integration sub-module powers an alternative to the local orchestrator. You can enable it by registering the Kubeflow orchestrator with the CLI tool.
KubeflowIntegration (Integration)
Definition of Kubeflow Integration for ZenML.
Source code in zenml/integrations/kubeflow/__init__.py
class KubeflowIntegration(Integration):
"""Definition of Kubeflow Integration for ZenML."""
NAME = KUBEFLOW
REQUIREMENTS = ["kfp==1.8.13"]
@classmethod
def flavors(cls) -> List[Type[Flavor]]:
"""Declare the stack component flavors for the Kubeflow integration.
Returns:
List of stack component flavors for this integration.
"""
from zenml.integrations.kubeflow.flavors import (
KubeflowOrchestratorFlavor,
)
return [KubeflowOrchestratorFlavor]
flavors()
classmethod
Declare the stack component flavors for the Kubeflow integration.
Returns:
Type | Description |
---|---|
List[Type[zenml.stack.flavor.Flavor]] | List of stack component flavors for this integration. |
Source code in zenml/integrations/kubeflow/__init__.py
@classmethod
def flavors(cls) -> List[Type[Flavor]]:
"""Declare the stack component flavors for the Kubeflow integration.
Returns:
List of stack component flavors for this integration.
"""
from zenml.integrations.kubeflow.flavors import (
KubeflowOrchestratorFlavor,
)
return [KubeflowOrchestratorFlavor]
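A small inspection sketch; that flavors can be default-constructed to read their properties is an assumption about the `Flavor` base class:
from zenml.integrations.kubeflow import KubeflowIntegration

for flavor_class in KubeflowIntegration.flavors():
    flavor = flavor_class()  # assumes a no-argument constructor
    print(flavor.name, flavor.implementation_class)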
flavors
special
Kubeflow integration flavors.
kubeflow_orchestrator_flavor
Kubeflow orchestrator flavor.
KubeflowOrchestratorConfig (BaseOrchestratorConfig)
pydantic-model
Configuration for the Kubeflow orchestrator.
Attributes:
Name | Type | Description |
---|---|---|
kubeflow_pipelines_ui_port | int | A local port to which the KFP UI will be forwarded. |
kubeflow_hostname | Optional[str] | The hostname to use to talk to the Kubeflow Pipelines API. If not set, the hostname will be derived from the Kubernetes API proxy. |
kubeflow_namespace | str | The Kubernetes namespace in which Kubeflow Pipelines is deployed. Defaults to `kubeflow`. |
kubernetes_context | Optional[str] | Optional name of a kubernetes context to run pipelines in. If not set, will try to spin up a local K3d cluster. |
synchronous | bool | If `True`, running a pipeline using this orchestrator will block until all steps finished running on KFP. |
skip_local_validations | bool | If `True`, the local validations will be skipped. |
skip_cluster_provisioning | bool | If `True`, the k3d cluster provisioning will be skipped. |
skip_ui_daemon_provisioning | bool | If `True`, provisioning the KFP UI daemon will be skipped. |
Source code in zenml/integrations/kubeflow/flavors/kubeflow_orchestrator_flavor.py
class KubeflowOrchestratorConfig(BaseOrchestratorConfig):
"""Configuration for the Kubeflow orchestrator.
Attributes:
kubeflow_pipelines_ui_port: A local port to which the KFP UI will be
forwarded.
kubeflow_hostname: The hostname to use to talk to the Kubeflow Pipelines
API. If not set, the hostname will be derived from the Kubernetes
API proxy.
kubeflow_namespace: The Kubernetes namespace in which Kubeflow
Pipelines is deployed. Defaults to `kubeflow`.
kubernetes_context: Optional name of a kubernetes context to run
pipelines in. If not set, will try to spin up a local K3d cluster.
synchronous: If `True`, running a pipeline using this orchestrator will
block until all steps finished running on KFP.
skip_local_validations: If `True`, the local validations will be
skipped.
skip_cluster_provisioning: If `True`, the k3d cluster provisioning will
be skipped.
skip_ui_daemon_provisioning: If `True`, provisioning the KFP UI daemon
will be skipped.
"""
kubeflow_pipelines_ui_port: int = DEFAULT_KFP_UI_PORT
kubeflow_hostname: Optional[str] = None
kubeflow_namespace: str = "kubeflow"
kubernetes_context: Optional[str] = None
synchronous: bool = False
skip_local_validations: bool = False
skip_cluster_provisioning: bool = False
skip_ui_daemon_provisioning: bool = False
@property
def is_remote(self) -> bool:
"""Checks if this stack component is running remotely.
This designation is used to determine if the stack component can be
used with a local ZenML database or if it requires a remote ZenML
server.
Returns:
True if this config is for a remote component, False otherwise.
"""
if (
self.kubernetes_context is not None
and not self.kubernetes_context.startswith("k3d-zenml-kubeflow-")
):
return True
return False
@property
def is_local(self) -> bool:
"""Checks if this stack component is running locally.
This designation is used to determine if the stack component can be
shared with other users or if it is only usable on the local host.
Returns:
True if this config is for a local component, False otherwise.
"""
if (
self.kubernetes_context is None
or self.kubernetes_context.startswith("k3d-zenml-kubeflow-")
):
return True
return False
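To illustrate the k3d naming convention these two properties key off, a short sketch with made-up context names, assuming the base config requires no further fields:
local_config = KubeflowOrchestratorConfig(
    kubernetes_context="k3d-zenml-kubeflow-1a2b3c4d"  # hypothetical context
)
remote_config = KubeflowOrchestratorConfig(
    kubernetes_context="my-eks-cluster"  # hypothetical context
)
assert local_config.is_local and not local_config.is_remote
assert remote_config.is_remote and not remote_config.is_local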
is_local: bool
property
readonly
Checks if this stack component is running locally.
This designation is used to determine if the stack component can be shared with other users or if it is only usable on the local host.
Returns:
Type | Description |
---|---|
bool | True if this config is for a local component, False otherwise. |
is_remote: bool
property
readonly
Checks if this stack component is running remotely.
This designation is used to determine if the stack component can be used with a local ZenML database or if it requires a remote ZenML server.
Returns:
Type | Description |
---|---|
bool | True if this config is for a remote component, False otherwise. |
KubeflowOrchestratorFlavor (BaseOrchestratorFlavor)
Kubeflow orchestrator flavor.
Source code in zenml/integrations/kubeflow/flavors/kubeflow_orchestrator_flavor.py
class KubeflowOrchestratorFlavor(BaseOrchestratorFlavor):
"""Kubeflow orchestrator flavor."""
@property
def name(self) -> str:
"""Name of the flavor.
Returns:
The name of the flavor.
"""
return KUBEFLOW_ORCHESTRATOR_FLAVOR
@property
def config_class(self) -> Type[KubeflowOrchestratorConfig]:
"""Returns `KubeflowOrchestratorConfig` config class.
Returns:
The config class.
"""
return KubeflowOrchestratorConfig
@property
def implementation_class(self) -> Type["KubeflowOrchestrator"]:
"""Implementation class for this flavor.
Returns:
The implementation class.
"""
from zenml.integrations.kubeflow.orchestrators import (
KubeflowOrchestrator,
)
return KubeflowOrchestrator
config_class: Type[zenml.integrations.kubeflow.flavors.kubeflow_orchestrator_flavor.KubeflowOrchestratorConfig]
property
readonly
Returns `KubeflowOrchestratorConfig` config class.
Returns:
Type | Description |
---|---|
Type[zenml.integrations.kubeflow.flavors.kubeflow_orchestrator_flavor.KubeflowOrchestratorConfig] | The config class. |
implementation_class: Type[KubeflowOrchestrator]
property
readonly
Implementation class for this flavor.
Returns:
Type | Description |
---|---|
Type[KubeflowOrchestrator] | The implementation class. |
name: str
property
readonly
Name of the flavor.
Returns:
Type | Description |
---|---|
str | The name of the flavor. |
KubeflowOrchestratorSettings (BaseSettings)
pydantic-model
Settings for the Kubeflow orchestrator.
Attributes:
Name | Type | Description |
---|---|---|
client_args | Dict[str, Any] | Arguments to pass when initializing the KFP client. |
user_namespace | Optional[str] | The user namespace to use when creating experiments and runs. |
Source code in zenml/integrations/kubeflow/flavors/kubeflow_orchestrator_flavor.py
class KubeflowOrchestratorSettings(BaseSettings):
"""Settings for the Kubeflow orchestrator.
Attributes:
client_args: Arguments to pass when initializing the KFP client.
user_namespace: The user namespace to use when creating experiments
and runs.
"""
LEVEL: ClassVar[ConfigurationLevel] = ConfigurationLevel.PIPELINE
client_args: Dict[str, Any] = {}
user_namespace: Optional[str] = None
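A sketch of attaching these settings to a pipeline; the `orchestrator.kubeflow` settings key follows ZenML's `<component_type>.<flavor>` convention, and both the client argument and namespace values are placeholders:
from zenml.integrations.kubeflow.flavors import KubeflowOrchestratorSettings
from zenml.pipelines import pipeline

kubeflow_settings = KubeflowOrchestratorSettings(
    client_args={"existing_token": "<token>"},  # forwarded to kfp.Client
    user_namespace="my-namespace",
)

@pipeline(settings={"orchestrator.kubeflow": kubeflow_settings})
def my_pipeline() -> None:
    ...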
orchestrators
special
Initialization of the Kubeflow ZenML orchestrator.
kubeflow_entrypoint_configuration
Implementation of the Kubeflow entrypoint configuration.
KubeflowEntrypointConfiguration (StepEntrypointConfiguration)
Entrypoint configuration for running steps on kubeflow.
This class writes a markdown file that will be displayed in the KFP UI.
Source code in zenml/integrations/kubeflow/orchestrators/kubeflow_entrypoint_configuration.py
class KubeflowEntrypointConfiguration(StepEntrypointConfiguration):
"""Entrypoint configuration for running steps on kubeflow.
This class writes a markdown file that will be displayed in the KFP UI.
"""
@classmethod
def get_entrypoint_options(cls) -> Set[str]:
"""Gets all options required for running with this configuration.
The metadata ui path option expects a path where the markdown file
that will be displayed in the kubeflow UI should be written. The same
path needs to be added as an output artifact called
`mlpipeline-ui-metadata` for the corresponding `kfp.dsl.ContainerOp`.
Returns:
The superclass options as well as an option for the metadata ui
path.
"""
return super().get_entrypoint_options() | {METADATA_UI_PATH_OPTION}
@classmethod
def get_entrypoint_arguments(cls, **kwargs: Any) -> List[str]:
"""Gets all arguments that the entrypoint command should be called with.
Args:
**kwargs: Kwargs, must include the metadata ui path.
Returns:
The superclass arguments as well as arguments for the metadata ui
path.
"""
return super().get_entrypoint_arguments(**kwargs) + [
f"--{METADATA_UI_PATH_OPTION}",
kwargs[METADATA_UI_PATH_OPTION],
]
def get_run_name(self, pipeline_name: str) -> Optional[str]:
"""Returns the Kubeflow pipeline run name.
Args:
pipeline_name: The name of the pipeline.
Returns:
The Kubeflow pipeline run name.
Raises:
RuntimeError: If the run name environment variable is not set.
"""
try:
return os.environ[ENV_ZENML_RUN_NAME]
except KeyError:
raise RuntimeError(
"Unable to read run name from environment variable "
f"{ENV_ZENML_RUN_NAME}."
)
def post_run(
self,
pipeline_name: str,
step_name: str,
execution_info: Optional[data_types.ExecutionInfo] = None,
) -> None:
"""Writes a markdown file that will display information.
This will be about the step execution and input/output artifacts in the
KFP UI.
Args:
pipeline_name: The name of the pipeline.
step_name: The name of the step.
execution_info: The execution info of the step.
"""
if execution_info:
utils.dump_ui_metadata(
execution_info=execution_info,
metadata_ui_path=self.entrypoint_args[METADATA_UI_PATH_OPTION],
)
get_entrypoint_arguments(**kwargs)
classmethod
Gets all arguments that the entrypoint command should be called with.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
**kwargs | Any | Kwargs, must include the metadata ui path. | {} |
Returns:
Type | Description |
---|---|
List[str] | The superclass arguments as well as arguments for the metadata ui path. |
Source code in zenml/integrations/kubeflow/orchestrators/kubeflow_entrypoint_configuration.py
@classmethod
def get_entrypoint_arguments(cls, **kwargs: Any) -> List[str]:
"""Gets all arguments that the entrypoint command should be called with.
Args:
**kwargs: Kwargs, must include the metadata ui path.
Returns:
The superclass arguments as well as arguments for the metadata ui
path.
"""
return super().get_entrypoint_arguments(**kwargs) + [
f"--{METADATA_UI_PATH_OPTION}",
kwargs[METADATA_UI_PATH_OPTION],
]
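To make the resulting command line concrete, a sketch; the literal option string below assumes `METADATA_UI_PATH_OPTION == "metadata_ui_path"`, which is a constant not shown in this module, and the superclass keyword arguments are likewise assumed:
args = KubeflowEntrypointConfiguration.get_entrypoint_arguments(
    step_name="trainer",  # hypothetical step name
    metadata_ui_path="/outputs/mlpipeline-ui-metadata.json",
)
# -> [...superclass arguments..., "--metadata_ui_path",
#     "/outputs/mlpipeline-ui-metadata.json"]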
get_entrypoint_options()
classmethod
Gets all options required for running with this configuration.
The metadata ui path option expects a path where the markdown file that will be displayed in the kubeflow UI should be written. The same path needs to be added as an output artifact called `mlpipeline-ui-metadata` for the corresponding `kfp.dsl.ContainerOp`.
Returns:
Type | Description |
---|---|
Set[str] | The superclass options as well as an option for the metadata ui path. |
Source code in zenml/integrations/kubeflow/orchestrators/kubeflow_entrypoint_configuration.py
@classmethod
def get_entrypoint_options(cls) -> Set[str]:
"""Gets all options required for running with this configuration.
The metadata ui path option expects a path where the markdown file
that will be displayed in the kubeflow UI should be written. The same
path needs to be added as an output artifact called
`mlpipeline-ui-metadata` for the corresponding `kfp.dsl.ContainerOp`.
Returns:
The superclass options as well as an option for the metadata ui
path.
"""
return super().get_entrypoint_options() | {METADATA_UI_PATH_OPTION}
get_run_name(self, pipeline_name)
Returns the Kubeflow pipeline run name.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
pipeline_name | str | The name of the pipeline. | required |
Returns:
Type | Description |
---|---|
Optional[str] | The Kubeflow pipeline run name. |
Exceptions:
Type | Description |
---|---|
RuntimeError | If the run name environment variable is not set. |
Source code in zenml/integrations/kubeflow/orchestrators/kubeflow_entrypoint_configuration.py
def get_run_name(self, pipeline_name: str) -> Optional[str]:
"""Returns the Kubeflow pipeline run name.
Args:
pipeline_name: The name of the pipeline.
Returns:
The Kubeflow pipeline run name.
Raises:
RuntimeError: If the run name environment variable is not set.
"""
try:
return os.environ[ENV_ZENML_RUN_NAME]
except KeyError:
raise RuntimeError(
"Unable to read run name from environment variable "
f"{ENV_ZENML_RUN_NAME}."
)
post_run(self, pipeline_name, step_name, execution_info=None)
Writes a markdown file that will display information.
This will be about the step execution and input/output artifacts in the KFP UI.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
pipeline_name | str | The name of the pipeline. | required |
step_name | str | The name of the step. | required |
execution_info | Optional[tfx.orchestration.portable.data_types.ExecutionInfo] | The execution info of the step. | None |
Source code in zenml/integrations/kubeflow/orchestrators/kubeflow_entrypoint_configuration.py
def post_run(
self,
pipeline_name: str,
step_name: str,
execution_info: Optional[data_types.ExecutionInfo] = None,
) -> None:
"""Writes a markdown file that will display information.
This will be about the step execution and input/output artifacts in the
KFP UI.
Args:
pipeline_name: The name of the pipeline.
step_name: The name of the step.
execution_info: The execution info of the step.
"""
if execution_info:
utils.dump_ui_metadata(
execution_info=execution_info,
metadata_ui_path=self.entrypoint_args[METADATA_UI_PATH_OPTION],
)
kubeflow_orchestrator
Implementation of the Kubeflow orchestrator.
KubeflowOrchestrator (BaseOrchestrator)
Orchestrator responsible for running pipelines using Kubeflow.
Source code in zenml/integrations/kubeflow/orchestrators/kubeflow_orchestrator.py
class KubeflowOrchestrator(BaseOrchestrator):
"""Orchestrator responsible for running pipelines using Kubeflow."""
@property
def config(self) -> KubeflowOrchestratorConfig:
"""Returns the `KubeflowOrchestratorConfig` config.
Returns:
The configuration.
"""
return cast(KubeflowOrchestratorConfig, self._config)
@staticmethod
def _get_k3d_cluster_name(uuid: UUID) -> str:
"""Returns the k3d cluster name corresponding to the orchestrator UUID.
Args:
uuid: The UUID of the orchestrator.
Returns:
The k3d cluster name.
"""
# k3d only allows cluster names with up to 32 characters; use the
# first 8 chars of the orchestrator UUID as identifier
return f"zenml-kubeflow-{str(uuid)[:8]}"
@staticmethod
def _get_k3d_kubernetes_context(uuid: UUID) -> str:
"""Gets the k3d kubernetes context.
Args:
uuid: The UUID of the orchestrator.
Returns:
The name of the kubernetes context associated with the k3d
cluster managed locally by ZenML corresponding to the orchestrator UUID.
"""
return f"k3d-{KubeflowOrchestrator._get_k3d_cluster_name(uuid)}"
@property
def kubernetes_context(self) -> str:
"""Gets the kubernetes context associated with the orchestrator.
This sets the default `kubernetes_context` value to the value that is
used to create the locally managed k3d cluster, if not explicitly set.
Returns:
The kubernetes context associated with the orchestrator.
"""
if self.config.kubernetes_context:
return self.config.kubernetes_context
return self._get_k3d_kubernetes_context(self.id)
def get_kubernetes_contexts(self) -> Tuple[List[str], Optional[str]]:
"""Get the list of configured Kubernetes contexts and the active context.
Returns:
A tuple containing the list of configured Kubernetes contexts and
the active context.
"""
try:
contexts, active_context = k8s_config.list_kube_config_contexts()
except k8s_config.config_exception.ConfigException:
return [], None
context_names = [c["name"] for c in contexts]
active_context_name = active_context["name"]
return context_names, active_context_name
@property
def settings_class(self) -> Optional[Type["BaseSettings"]]:
"""Settings class for the Kubeflow orchestrator.
Returns:
The settings class.
"""
return KubeflowOrchestratorSettings
@property
def validator(self) -> Optional[StackValidator]:
"""Validates that the stack contains a container registry.
Also check that requirements are met for local components.
Returns:
A `StackValidator` instance.
"""
def _validate_local_requirements(stack: "Stack") -> Tuple[bool, str]:
container_registry = stack.container_registry
# should not happen, because the stack validation takes care of
# this, but just in case
assert container_registry is not None
contexts, active_context = self.get_kubernetes_contexts()
if self.kubernetes_context not in contexts:
if not self.is_local:
return False, (
f"Could not find a Kubernetes context named "
f"'{self.kubernetes_context}' in the local Kubernetes "
f"configuration. Please make sure that the Kubernetes "
f"cluster is running and that the kubeconfig file is "
f"configured correctly. To list all configured "
f"contexts, run:\n\n"
f" `kubectl config get-contexts`\n"
)
elif active_context and self.kubernetes_context != active_context:
logger.warning(
f"The Kubernetes context '{self.kubernetes_context}' "
f"configured for the Kubeflow orchestrator is not the "
f"same as the active context in the local Kubernetes "
f"configuration. If this is not deliberate, you should "
f"update the orchestrator's `kubernetes_context` field by "
f"running:\n\n"
f" `zenml orchestrator update {self.name} "
f"--kubernetes_context={active_context}`\n"
f"To list all configured contexts, run:\n\n"
f" `kubectl config get-contexts`\n"
f"To set the active context to be the same as the one "
f"configured in the Kubeflow orchestrator and silence "
f"this warning, run:\n\n"
f" `kubectl config use-context "
f"{self.kubernetes_context}`\n"
)
silence_local_validations_msg = (
f"To silence this warning, set the "
f"`skip_local_validations` attribute to True in the "
f"orchestrator configuration by running:\n\n"
f" 'zenml orchestrator update {self.name} "
f"--skip_local_validations=True'\n"
)
if not self.config.skip_local_validations and not self.is_local:
# if the orchestrator is not running in a local k3d cluster,
# we cannot have any other local components in our stack,
# because we cannot mount the local path into the container.
# This may result in problems when running the pipeline, because
# the local components will not be available inside the
# Kubeflow containers.
# go through all stack components and identify those that
# advertise a local path where they persist information that
# they need to be available when running pipelines.
for stack_comp in stack.components.values():
local_path = stack_comp.local_path
if not local_path:
continue
return False, (
f"The Kubeflow orchestrator is configured to run "
f"pipelines in a remote Kubernetes cluster designated "
f"by the '{self.kubernetes_context}' configuration "
f"context, but the '{stack_comp.name}' "
f"{stack_comp.type.value} is a local stack component "
f"and will not be available in the Kubeflow pipeline "
f"step.\nPlease ensure that you always use non-local "
f"stack components with a remote Kubeflow orchestrator, "
f"otherwise you may run into pipeline execution "
f"problems. You should use a flavor of "
f"{stack_comp.type.value} other than "
f"'{stack_comp.flavor}'.\n"
+ silence_local_validations_msg
)
# if the orchestrator is remote, the container registry must
# also be remote.
if container_registry.config.is_local:
return False, (
f"The Kubeflow orchestrator is configured to run "
f"pipelines in a remote Kubernetes cluster designated "
f"by the '{self.kubernetes_context}' configuration "
f"context, but the '{container_registry.name}' "
f"container registry URI "
f"'{container_registry.config.uri}' "
f"points to a local container registry. Please ensure "
f"that you always use non-local stack components with "
f"a remote Kubeflow orchestrator, otherwise you will "
f"run into problems. You should use a flavor of "
f"container registry other than "
f"'{container_registry.flavor}'.\n"
+ silence_local_validations_msg
)
if not self.config.skip_local_validations and self.is_local:
# if the orchestrator is local, the container registry must
# also be local.
if not container_registry.config.is_local:
return False, (
f"The Kubeflow orchestrator is configured to run "
f"pipelines in a local k3d Kubernetes cluster "
f"designated by the '{self.kubernetes_context}' "
f"configuration context, but the container registry "
f"URI '{container_registry.config.uri}' doesn't "
f"match the expected format 'localhost:$PORT'. "
f"The local Kubeflow orchestrator only works with a "
f"local container registry because it cannot "
f"currently authenticate to external container "
f"registries. You should use a flavor of container "
f"registry other than '{container_registry.flavor}'.\n"
+ silence_local_validations_msg
)
return True, ""
return StackValidator(
required_components={StackComponentType.CONTAINER_REGISTRY},
custom_validation_function=_validate_local_requirements,
)
@property
def is_local(self) -> bool:
"""Checks if the KFP orchestrator is running locally.
Returns:
`True` if the KFP orchestrator is running locally (i.e. in
the local k3d cluster managed by ZenML).
"""
return self.kubernetes_context == self._get_k3d_kubernetes_context(
self.id
)
@property
def root_directory(self) -> str:
"""Path to the root directory for all files concerning this orchestrator.
Returns:
Path to the root directory.
"""
return os.path.join(
io_utils.get_global_config_directory(),
"kubeflow",
str(self.id),
)
@property
def pipeline_directory(self) -> str:
"""Returns path to a directory in which the kubeflow pipeline files are stored.
Returns:
Path to the pipeline directory.
"""
return os.path.join(self.root_directory, "pipelines")
def prepare_pipeline_deployment(
self,
deployment: "PipelineDeployment",
stack: "Stack",
) -> None:
"""Build a Docker image and push it to the container registry.
Args:
deployment: The pipeline deployment configuration.
stack: The stack on which the pipeline will be deployed.
"""
docker_image_builder = PipelineDockerImageBuilder()
repo_digest = docker_image_builder.build_and_push_docker_image(
deployment=deployment, stack=stack
)
deployment.add_extra(ORCHESTRATOR_DOCKER_IMAGE_KEY, repo_digest)
def _configure_container_op(
self, container_op: dsl.ContainerOp, is_scheduled_run: bool
) -> None:
"""Makes changes in place to the configuration of the container op.
Configures persistent mounted volumes for each stack component that
writes to a local path. Adds some labels to the container_op and applies
some functions to it.
Args:
container_op: The kubeflow container operation to configure.
is_scheduled_run: Whether the pipeline is scheduled or a single run.
Raises:
ValueError: If the local path is not in the global config directory.
"""
# Path to a metadata file that will be displayed in the KFP UI
# This metadata file needs to be in a mounted emptyDir to avoid
# sporadic failures with the (not mature) PNS executor
# See these links for more information about limitations of PNS +
# security context:
# https://www.kubeflow.org/docs/components/pipelines/installation/localcluster-deployment/#deploying-kubeflow-pipelines
# https://argoproj.github.io/argo-workflows/empty-dir/
# Once KFP switches to the Emissary executor, this emptyDir mount will
# no longer be necessary, but for now the Emissary executor is still in
# alpha status (https://www.kubeflow.org/docs/components/pipelines/installation/choose-executor/#emissary-executor)
volumes: Dict[str, k8s_client.V1Volume] = {
"/outputs": k8s_client.V1Volume(
name="outputs", empty_dir=k8s_client.V1EmptyDirVolumeSource()
),
}
stack = Client().active_stack
local_stores_path = GlobalConfiguration().local_stores_path
# go through all stack components and identify those that advertise
# a local path where they persist information that they need to be
# available when running pipelines.
has_local_paths = False
for stack_comp in stack.components.values():
local_path = stack_comp.local_path
if not local_path:
continue
# double-check this convention, just in case it wasn't respected
# as documented in `StackComponent.local_path`
if not local_path.startswith(local_stores_path):
raise ValueError(
f"Local path {local_path} for component "
f"{stack_comp.name} is not in the local stores "
f"directory ({local_stores_path})."
)
has_local_paths = True
logger.debug(
"The host path for %s %s (path: %s) will be mounted "
"in the kubeflow pipelines container.",
stack_comp.type.value,
stack_comp.name,
local_path,
)
if has_local_paths or self.is_local:
host_path = k8s_client.V1HostPathVolumeSource(
path=local_stores_path, type="Directory"
)
volumes[local_stores_path] = k8s_client.V1Volume(
name="local-stores",
host_path=host_path,
)
logger.debug(
"Adding host path volume for the local ZenML stores (path: %s) "
"in kubeflow pipelines container.",
local_stores_path,
)
if sys.platform == "win32":
# File permissions are not checked on Windows. This if clause
# prevents mypy from complaining about unused 'type: ignore'
# statements
pass
else:
# Run KFP containers in the context of the local UID/GID
# to ensure that the artifact and metadata stores can be shared
# with the local pipeline runs.
container_op.container.security_context = (
k8s_client.V1SecurityContext(
run_as_user=os.getuid(),
run_as_group=os.getgid(),
)
)
logger.debug(
"Setting security context UID and GID to local user/group "
"in kubeflow pipelines container."
)
container_op.add_pvolumes(volumes)
# Add some pod labels to the container_op
for k, v in KFP_POD_LABELS.items():
container_op.add_pod_label(k, v)
run_name = (
SCHEDULED_RUN_NAME_PLACEHOLDER
if is_scheduled_run
else SINGLE_RUN_RUN_NAME_PLACEHOLDER
)
container_op.container.add_env_variable(
k8s_client.V1EnvVar(
name=ENV_ZENML_RUN_NAME,
value=run_name,
)
)
# Mounts configmap containing Metadata gRPC server configuration.
container_op.apply(utils.mount_config_map_op("metadata-grpc-configmap"))
@staticmethod
def _configure_container_resources(
container_op: dsl.ContainerOp,
resource_settings: "ResourceSettings",
) -> None:
"""Adds resource requirements to the container.
Args:
container_op: The kubeflow container operation to configure.
resource_settings: The resource settings to use for this
container.
"""
if resource_settings.cpu_count is not None:
container_op = container_op.set_cpu_limit(
str(resource_settings.cpu_count)
)
if resource_settings.gpu_count is not None:
container_op = container_op.set_gpu_limit(
resource_settings.gpu_count
)
if resource_settings.memory is not None:
memory_limit = resource_settings.memory[:-1]
container_op = container_op.set_memory_limit(memory_limit)
def prepare_or_run_pipeline(
self,
deployment: "PipelineDeployment",
stack: "Stack",
) -> Any:
"""Creates a kfp yaml file.
This functions as an intermediary representation of the pipeline which
is then deployed to the kubeflow pipelines instance.
How it works:
-------------
Before this method is called the `prepare_pipeline_deployment()`
method builds a docker image that contains the code for the
pipeline, all steps and the context around these files.
Based on this docker image a callable is created which builds
container_ops for each step (`_construct_kfp_pipeline`).
To do this the entrypoint of the docker image is configured to
run the correct step within the docker image. The dependencies
between these container_ops are then also configured onto each
container_op by pointing at the downstream steps.
This callable is then compiled into a kfp yaml file that is used as
the intermediary representation of the kubeflow pipeline.
This file, together with some metadata and runtime configurations, is
then uploaded into the kubeflow pipelines cluster for execution.
Args:
deployment: The pipeline deployment to prepare or run.
stack: The stack the pipeline will run on.
Raises:
RuntimeError: If trying to run a pipeline in a notebook environment.
"""
# First check whether the code is running in a notebook
if Environment.in_notebook():
raise RuntimeError(
"The Kubeflow orchestrator cannot run pipelines in a notebook "
"environment. The reason is that it is non-trivial to create "
"a Docker image of a notebook. Please consider refactoring "
"your notebook cells into separate scripts in a Python module "
"and run the code outside of a notebook when using this "
"orchestrator."
)
image_name = deployment.pipeline.extra[ORCHESTRATOR_DOCKER_IMAGE_KEY]
is_scheduled_run = bool(deployment.schedule)
# Create a callable for future compilation into a dsl.Pipeline.
def _construct_kfp_pipeline() -> None:
"""Create a container_op for each step.
This should contain the name of the docker image and configures the
entrypoint of the docker image to run the step.
Additionally, this gives each container_op information about its
direct downstream steps.
If this callable is passed to the `_create_and_write_workflow()`
method of a KFPCompiler all dsl.ContainerOp instances will be
automatically added to a singular dsl.Pipeline instance.
"""
# Dictionary of container_ops index by the associated step name
step_name_to_container_op: Dict[str, dsl.ContainerOp] = {}
for step_name, step in deployment.steps.items():
# The command will be needed to eventually call the python step
# within the docker container
command = (
KubeflowEntrypointConfiguration.get_entrypoint_command()
)
# The arguments are passed to configure the entrypoint of the
# docker container when the step is called.
metadata_ui_path = "/outputs/mlpipeline-ui-metadata.json"
arguments = (
KubeflowEntrypointConfiguration.get_entrypoint_arguments(
step_name=step_name,
**{METADATA_UI_PATH_OPTION: metadata_ui_path},
)
)
# Create a container_op - the kubeflow equivalent of a step. It
# contains the name of the step, the name of the docker image,
# the command to use to run the step entrypoint
# (e.g. `python -m zenml.entrypoints.step_entrypoint`)
# and the arguments to be passed along with the command. Find
# out more about how these arguments are parsed and used
# in the base entrypoint `run()` method.
container_op = dsl.ContainerOp(
name=step.config.name,
image=image_name,
command=command,
arguments=arguments,
output_artifact_paths={
"mlpipeline-ui-metadata": metadata_ui_path,
},
)
self._configure_container_op(
container_op=container_op, is_scheduled_run=is_scheduled_run
)
if self.requires_resources_in_orchestration_environment(step):
self._configure_container_resources(
container_op=container_op,
resource_settings=step.config.resource_settings,
)
# Find the upstream container ops of the current step and
# configure the current container op to run after them
for upstream_step_name in step.spec.upstream_steps:
upstream_container_op = step_name_to_container_op[
upstream_step_name
]
container_op.after(upstream_container_op)
# Update dictionary of container ops with the current one
step_name_to_container_op[step.config.name] = container_op
# Get a filepath to use to save the finished yaml to
fileio.makedirs(self.pipeline_directory)
pipeline_file_path = os.path.join(
self.pipeline_directory, f"{deployment.run_name}.yaml"
)
# write the argo pipeline yaml
KFPCompiler()._create_and_write_workflow(
pipeline_func=_construct_kfp_pipeline,
pipeline_name=deployment.pipeline.name,
package_path=pipeline_file_path,
)
# using the kfp client uploads the pipeline to kubeflow pipelines and
# runs it there
self._upload_and_run_pipeline(
deployment=deployment,
pipeline_file_path=pipeline_file_path,
)
def _upload_and_run_pipeline(
self,
deployment: "PipelineDeployment",
pipeline_file_path: str,
) -> None:
"""Tries to upload and run a KFP pipeline.
Args:
deployment: The pipeline deployment.
pipeline_file_path: Path to the pipeline definition file.
"""
pipeline_name = deployment.pipeline.name
run_name = deployment.run_name
enable_cache = deployment.pipeline.enable_cache
settings = cast(
Optional[KubeflowOrchestratorSettings],
self.get_settings(deployment),
)
user_namespace = settings.user_namespace if settings else None
try:
logger.info(
"Running in kubernetes context '%s'.",
self.kubernetes_context,
)
# upload the pipeline to Kubeflow and start it
client = self._get_kfp_client(settings=settings)
if deployment.schedule:
try:
experiment = client.get_experiment(
pipeline_name, namespace=user_namespace
)
logger.info(
"A recurring run has already been created with this "
"pipeline. Creating new recurring run now.."
)
except (ValueError, ApiException):
experiment = client.create_experiment(
pipeline_name, namespace=user_namespace
)
logger.info(
"Creating a new recurring run for pipeline '%s'.. ",
pipeline_name,
)
logger.info(
"You can see all recurring runs under the '%s' experiment.",
pipeline_name,
)
interval_seconds = (
deployment.schedule.interval_second.seconds
if deployment.schedule.interval_second
else None
)
result = client.create_recurring_run(
experiment_id=experiment.id,
job_name=run_name,
pipeline_package_path=pipeline_file_path,
enable_caching=enable_cache,
cron_expression=deployment.schedule.cron_expression,
start_time=deployment.schedule.utc_start_time,
end_time=deployment.schedule.utc_end_time,
interval_second=interval_seconds,
no_catchup=not deployment.schedule.catchup,
)
logger.info("Started recurring run with ID '%s'.", result.id)
else:
logger.info(
"No schedule detected. Creating a one-off pipeline run.."
)
result = client.create_run_from_pipeline_package(
pipeline_file_path,
arguments={},
run_name=run_name,
enable_caching=enable_cache,
namespace=user_namespace,
)
logger.info(
"Started one-off pipeline run with ID '%s'.", result.run_id
)
if self.config.synchronous:
# TODO [ENG-698]: Allow configuration of the timeout as a
# setting
client.wait_for_run_completion(
run_id=result.run_id, timeout=1200
)
except urllib3.exceptions.HTTPError as error:
logger.warning(
f"Failed to upload Kubeflow pipeline: %s. "
f"Please make sure your kubernetes config is present and the "
f"{self.kubernetes_context} kubernetes context is configured "
f"correctly.",
error,
)
def _get_kfp_client(
self,
settings: Optional[KubeflowOrchestratorSettings] = None,
) -> kfp.Client:
"""Creates a KFP client instance.
Args:
settings: Optional settings which can be used to
configure the client instance.
Returns:
A KFP client instance.
"""
client_args = {
"kube_context": self.config.kubernetes_context,
}
if settings:
client_args.update(settings.client_args)
# The host and namespace are stack component configurations that refer
# to the Kubeflow deployment. We don't want these overwritten on a
# run by run basis by user settings
client_args["host"] = self.config.kubeflow_hostname
client_args["namespace"] = self.config.kubeflow_namespace
return kfp.Client(**client_args)
@property
def _pid_file_path(self) -> str:
"""Returns path to the daemon PID file.
Returns:
Path to the daemon PID file.
"""
return os.path.join(self.root_directory, "kubeflow_daemon.pid")
@property
def log_file(self) -> str:
"""Path of the daemon log file.
Returns:
Path of the daemon log file.
"""
return os.path.join(self.root_directory, "kubeflow_daemon.log")
@property
def _k3d_cluster_name(self) -> str:
"""Returns the K3D cluster name.
Returns:
The K3D cluster name.
"""
return self._get_k3d_cluster_name(self.id)
def _get_k3d_registry_name(self, port: int) -> str:
"""Returns the K3D registry name.
Args:
port: Port of the registry.
Returns:
The registry name.
"""
return f"k3d-zenml-kubeflow-registry.localhost:{port}"
@property
def _k3d_registry_config_path(self) -> str:
"""Returns the path to the K3D registry config yaml.
Returns:
str: Path to the K3D registry config yaml.
"""
return os.path.join(self.root_directory, "k3d_registry.yaml")
def _get_kfp_ui_daemon_port(self) -> int:
"""Port to use for the KFP UI daemon.
Returns:
Port to use for the KFP UI daemon.
"""
port = self.config.kubeflow_pipelines_ui_port
if port == DEFAULT_KFP_UI_PORT and not networking_utils.port_available(
port
):
# if the user didn't specify a specific port and the default
# port is occupied, fallback to a random open port
port = networking_utils.find_available_port()
return port
def list_manual_setup_steps(
self, container_registry_name: str, container_registry_path: str
) -> None:
"""Logs manual steps needed to setup the Kubeflow local orchestrator.
Args:
container_registry_name: Name of the container registry.
container_registry_path: Path to the container registry.
"""
if not self.is_local:
# Make sure we're not telling users to deploy Kubeflow on their
# remote clusters
logger.warning(
"This Kubeflow orchestrator is configured to use a non-local "
f"Kubernetes context {self.kubernetes_context}. Manually "
f"deploying Kubeflow Pipelines is only possible for local "
f"Kubeflow orchestrators."
)
return
global_config_dir_path = io_utils.get_global_config_directory()
kubeflow_commands = [
f"> k3d cluster create {self._k3d_cluster_name} --image {local_deployment_utils.K3S_IMAGE_NAME} --registry-create {container_registry_name} --registry-config {container_registry_path} --volume {global_config_dir_path}:{global_config_dir_path}\n",
f"> kubectl --context {self.kubernetes_context} apply -k github.com/kubeflow/pipelines/manifests/kustomize/cluster-scoped-resources?ref={KFP_VERSION}&timeout=5m",
f"> kubectl --context {self.kubernetes_context} wait --timeout=60s --for condition=established crd/applications.app.k8s.io",
f"> kubectl --context {self.kubernetes_context} apply -k github.com/kubeflow/pipelines/manifests/kustomize/env/platform-agnostic-pns?ref={KFP_VERSION}&timeout=5m",
f"> kubectl --context {self.kubernetes_context} --namespace kubeflow port-forward svc/ml-pipeline-ui {self.config.kubeflow_pipelines_ui_port}:80",
]
logger.info(
"If you wish to spin up this Kubeflow local orchestrator manually, "
"please enter the following commands:\n"
)
logger.info("\n".join(kubeflow_commands))
@property
def is_provisioned(self) -> bool:
"""Returns if a local k3d cluster for this orchestrator exists.
Returns:
True if a local k3d cluster exists, False otherwise.
"""
if not local_deployment_utils.check_prerequisites(
skip_k3d=self.config.skip_cluster_provisioning or not self.is_local,
skip_kubectl=self.config.skip_cluster_provisioning
and self.config.skip_ui_daemon_provisioning,
):
# if any prerequisites are missing there is certainly no
# local deployment running
return False
return self.is_cluster_provisioned
@property
def is_running(self) -> bool:
"""Checks if the local k3d cluster and UI daemon are both running.
Returns:
True if the local k3d cluster and UI daemon for this orchestrator are both running.
"""
return (
self.is_provisioned
and self.is_cluster_running
and self.is_daemon_running
)
@property
def is_suspended(self) -> bool:
"""Checks if the local k3d cluster and UI daemon are both stopped.
Returns:
True if the cluster and daemon for this orchestrator are both stopped, False otherwise.
"""
return (
self.is_provisioned
and (
self.config.skip_cluster_provisioning
or not self.is_cluster_running
)
and (
self.config.skip_ui_daemon_provisioning
or not self.is_daemon_running
)
)
@property
def is_cluster_provisioned(self) -> bool:
"""Returns if the local k3d cluster for this orchestrator is provisioned.
For remote (i.e. not managed by ZenML) Kubeflow Pipelines installations,
this always returns True.
Returns:
True if the local k3d cluster is provisioned, False otherwise.
"""
if self.config.skip_cluster_provisioning or not self.is_local:
return True
return local_deployment_utils.k3d_cluster_exists(
cluster_name=self._k3d_cluster_name
)
@property
def is_cluster_running(self) -> bool:
"""Returns if the local k3d cluster for this orchestrator is running.
For remote (i.e. not managed by ZenML) Kubeflow Pipelines installations,
this always returns True.
Returns:
True if the local k3d cluster is running, False otherwise.
"""
if self.config.skip_cluster_provisioning or not self.is_local:
return True
return local_deployment_utils.k3d_cluster_running(
cluster_name=self._k3d_cluster_name
)
@property
def is_daemon_running(self) -> bool:
"""Returns if the local Kubeflow UI daemon for this orchestrator is running.
Returns:
True if the daemon is running, False otherwise.
"""
if self.config.skip_ui_daemon_provisioning:
return True
if sys.platform != "win32":
from zenml.utils.daemon import check_if_daemon_is_running
return check_if_daemon_is_running(self._pid_file_path)
else:
return True
def provision(self) -> None:
"""Provisions a local Kubeflow Pipelines deployment.
Raises:
ProvisioningError: If the provisioning fails.
"""
if self.config.skip_cluster_provisioning:
return
if self.is_running:
logger.info(
"Found already existing local Kubeflow Pipelines deployment. "
"If there are any issues with the existing deployment, please "
"run 'zenml stack down --force' to delete it."
)
return
if not local_deployment_utils.check_prerequisites():
raise ProvisioningError(
"Unable to provision local Kubeflow Pipelines deployment: "
"Please install 'k3d' and 'kubectl' and try again."
)
container_registry = Client().active_stack.container_registry
# should not happen, because the stack validation takes care of this,
# but just in case
assert container_registry is not None
fileio.makedirs(self.root_directory)
if not self.is_local:
# don't provision any resources if using a remote KFP installation
return
logger.info("Provisioning local Kubeflow Pipelines deployment...")
container_registry_port = int(
container_registry.config.uri.split(":")[-1]
)
container_registry_name = self._get_k3d_registry_name(
port=container_registry_port
)
local_deployment_utils.write_local_registry_yaml(
yaml_path=self._k3d_registry_config_path,
registry_name=container_registry_name,
registry_uri=container_registry.config.uri,
)
try:
local_deployment_utils.create_k3d_cluster(
cluster_name=self._k3d_cluster_name,
registry_name=container_registry_name,
registry_config_path=self._k3d_registry_config_path,
)
kubernetes_context = self.kubernetes_context
# will never happen, but mypy doesn't know that
assert kubernetes_context is not None
local_deployment_utils.deploy_kubeflow_pipelines(
kubernetes_context=kubernetes_context
)
artifact_store = Client().active_stack.artifact_store
if isinstance(artifact_store, LocalArtifactStore):
local_deployment_utils.add_hostpath_to_kubeflow_pipelines(
kubernetes_context=kubernetes_context,
local_path=artifact_store.path,
)
except Exception as e:
logger.error(e)
logger.error(
"Unable to spin up local Kubeflow Pipelines deployment."
)
self.list_manual_setup_steps(
container_registry_name, self._k3d_registry_config_path
)
self.deprovision()
def deprovision(self) -> None:
"""Deprovisions a local Kubeflow Pipelines deployment."""
if self.config.skip_cluster_provisioning:
return
if (
not self.config.skip_ui_daemon_provisioning
and self.is_daemon_running
):
local_deployment_utils.stop_kfp_ui_daemon(
pid_file_path=self._pid_file_path
)
if self.is_local:
# only delete the k3d cluster if it is a local deployment managed by ZenML
local_deployment_utils.delete_k3d_cluster(
cluster_name=self._k3d_cluster_name
)
logger.info("Local kubeflow pipelines deployment deprovisioned.")
if fileio.exists(self.log_file):
fileio.remove(self.log_file)
def resume(self) -> None:
"""Resumes the local k3d cluster.
Raises:
ProvisioningError: If the k3d cluster is not provisioned.
"""
if self.is_running:
logger.info("Local kubeflow pipelines deployment already running.")
return
if not self.is_provisioned:
raise ProvisioningError(
"Unable to resume local kubeflow pipelines deployment: No "
"resources provisioned for local deployment."
)
kubernetes_context = self.kubernetes_context
# will never happen, but mypy doesn't know that
assert kubernetes_context is not None
if (
not self.config.skip_cluster_provisioning
and self.is_local
and not self.is_cluster_running
):
# only start the k3d cluster for local deployments that are not already running
local_deployment_utils.start_k3d_cluster(
cluster_name=self._k3d_cluster_name
)
local_deployment_utils.wait_until_kubeflow_pipelines_ready(
kubernetes_context=kubernetes_context
)
if not self.is_daemon_running:
local_deployment_utils.start_kfp_ui_daemon(
pid_file_path=self._pid_file_path,
log_file_path=self.log_file,
port=self._get_kfp_ui_daemon_port(),
kubernetes_context=kubernetes_context,
)
def suspend(self) -> None:
"""Suspends the local k3d cluster."""
if not self.is_provisioned:
logger.info("Local kubeflow pipelines deployment not provisioned.")
return
if (
not self.config.skip_ui_daemon_provisioning
and self.is_daemon_running
):
local_deployment_utils.stop_kfp_ui_daemon(
pid_file_path=self._pid_file_path
)
if (
not self.config.skip_cluster_provisioning
and self.is_local
and self.is_cluster_running
):
# only stop the k3d cluster if it is a local deployment managed by ZenML
local_deployment_utils.stop_k3d_cluster(
cluster_name=self._k3d_cluster_name
)
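Putting the lifecycle pieces together, a hedged sketch of driving a local deployment directly through an orchestrator instance; in practice the `zenml stack up` and `zenml stack down` CLI commands wrap these calls:
# `orchestrator` is assumed to be a registered KubeflowOrchestrator without
# an explicit kubernetes_context, i.e. backed by a local k3d cluster.
if not orchestrator.is_provisioned:
    orchestrator.provision()    # create the k3d cluster and deploy KFP
elif orchestrator.is_suspended:
    orchestrator.resume()       # restart the cluster and the UI daemon

print(orchestrator.is_running)  # True once cluster and UI daemon are up

orchestrator.suspend()          # stop the cluster but keep its resources
orchestrator.deprovision()      # delete the k3d cluster entirely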
config: KubeflowOrchestratorConfig
property
readonly
Returns the `KubeflowOrchestratorConfig` config.
Returns:
Type | Description |
---|---|
KubeflowOrchestratorConfig | The configuration. |
is_cluster_provisioned: bool
property
readonly
Returns if the local k3d cluster for this orchestrator is provisioned.
For remote (i.e. not managed by ZenML) Kubeflow Pipelines installations, this always returns True.
Returns:
Type | Description |
---|---|
bool | True if the local k3d cluster is provisioned, False otherwise. |
is_cluster_running: bool
property
readonly
Returns if the local k3d cluster for this orchestrator is running.
For remote (i.e. not managed by ZenML) Kubeflow Pipelines installations, this always returns True.
Returns:
Type | Description |
---|---|
bool | True if the local k3d cluster is running, False otherwise. |
is_daemon_running: bool
property
readonly
Returns if the local Kubeflow UI daemon for this orchestrator is running.
Returns:
Type | Description |
---|---|
bool | True if the daemon is running, False otherwise. |
is_local: bool
property
readonly
Checks if the KFP orchestrator is running locally.
Returns:
Type | Description |
---|---|
bool | `True` if the KFP orchestrator is running locally (i.e. in the local k3d cluster managed by ZenML). |
is_provisioned: bool
property
readonly
Returns if a local k3d cluster for this orchestrator exists.
Returns:
Type | Description |
---|---|
bool | True if a local k3d cluster exists, False otherwise. |
is_running: bool
property
readonly
Checks if the local k3d cluster and UI daemon are both running.
Returns:
Type | Description |
---|---|
bool | True if the local k3d cluster and UI daemon for this orchestrator are both running. |
is_suspended: bool
property
readonly
Checks if the local k3d cluster and UI daemon are both stopped.
Returns:
Type | Description |
---|---|
bool | True if the cluster and daemon for this orchestrator are both stopped, False otherwise. |
kubernetes_context: str
property
readonly
Gets the kubernetes context associated with the orchestrator.
This sets the default `kubernetes_context` value to the value that is used to create the locally managed k3d cluster, if not explicitly set.
Returns:
Type | Description |
---|---|
str | The kubernetes context associated with the orchestrator. |
log_file: str
property
readonly
Path of the daemon log file.
Returns:
Type | Description |
---|---|
str | Path of the daemon log file. |
pipeline_directory: str
property
readonly
Returns path to a directory in which the kubeflow pipeline files are stored.
Returns:
Type | Description |
---|---|
str | Path to the pipeline directory. |
root_directory: str
property
readonly
Path to the root directory for all files concerning this orchestrator.
Returns:
Type | Description |
---|---|
str | Path to the root directory. |
settings_class: Optional[Type[BaseSettings]]
property
readonly
Settings class for the Kubeflow orchestrator.
Returns:
Type | Description |
---|---|
Optional[Type[BaseSettings]] | The settings class. |
validator: Optional[zenml.stack.stack_validator.StackValidator]
property
readonly
Validates that the stack contains a container registry.
Also check that requirements are met for local components.
Returns:
Type | Description |
---|---|
Optional[zenml.stack.stack_validator.StackValidator] | A `StackValidator` instance. |
deprovision(self)
Deprovisions a local Kubeflow Pipelines deployment.
Source code in zenml/integrations/kubeflow/orchestrators/kubeflow_orchestrator.py
def deprovision(self) -> None:
"""Deprovisions a local Kubeflow Pipelines deployment."""
if self.config.skip_cluster_provisioning:
return
if (
not self.config.skip_ui_daemon_provisioning
and self.is_daemon_running
):
local_deployment_utils.stop_kfp_ui_daemon(
pid_file_path=self._pid_file_path
)
if self.is_local:
# only delete the k3d cluster if it is a local deployment managed by ZenML
local_deployment_utils.delete_k3d_cluster(
cluster_name=self._k3d_cluster_name
)
logger.info("Local kubeflow pipelines deployment deprovisioned.")
if fileio.exists(self.log_file):
fileio.remove(self.log_file)
get_kubernetes_contexts(self)
Get the list of configured Kubernetes contexts and the active context.
Returns:
Type | Description |
---|---|
Tuple[List[str], Optional[str]] | A tuple containing the list of configured Kubernetes contexts and the active context. |
Source code in zenml/integrations/kubeflow/orchestrators/kubeflow_orchestrator.py
def get_kubernetes_contexts(self) -> Tuple[List[str], Optional[str]]:
"""Get the list of configured Kubernetes contexts and the active context.
Returns:
A tuple containing the list of configured Kubernetes contexts and
the active context.
"""
try:
contexts, active_context = k8s_config.list_kube_config_contexts()
except k8s_config.config_exception.ConfigException:
return [], None
context_names = [c["name"] for c in contexts]
active_context_name = active_context["name"]
return context_names, active_context_name
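For illustration, a minimal standalone sketch (not part of ZenML) of the same lookup using the kubernetes client library directly:

from kubernetes import config as k8s_config

try:
    contexts, active = k8s_config.list_kube_config_contexts()
except k8s_config.ConfigException:
    # No kubeconfig available, e.g. on a machine without kubectl set up.
    contexts, active = [], None

print([c["name"] for c in contexts])  # e.g. ['k3d-zenml-kubeflow', 'minikube']
print(active["name"] if active else None)  # name of the active context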
list_manual_setup_steps(self, container_registry_name, container_registry_path)
Logs manual steps needed to setup the Kubeflow local orchestrator.
Parameters:
Name | Type | Description | Default
---|---|---|---
container_registry_name | str | Name of the container registry. | required
container_registry_path | str | Path to the container registry. | required
Source code in zenml/integrations/kubeflow/orchestrators/kubeflow_orchestrator.py
def list_manual_setup_steps(
self, container_registry_name: str, container_registry_path: str
) -> None:
"""Logs manual steps needed to setup the Kubeflow local orchestrator.
Args:
container_registry_name: Name of the container registry.
container_registry_path: Path to the container registry.
"""
if not self.is_local:
# Make sure we're not telling users to deploy Kubeflow on their
# remote clusters
logger.warning(
"This Kubeflow orchestrator is configured to use a non-local "
f"Kubernetes context {self.kubernetes_context}. Manually "
f"deploying Kubeflow Pipelines is only possible for local "
f"Kubeflow orchestrators."
)
return
global_config_dir_path = io_utils.get_global_config_directory()
kubeflow_commands = [
f"> k3d cluster create {self._k3d_cluster_name} --image {local_deployment_utils.K3S_IMAGE_NAME} --registry-create {container_registry_name} --registry-config {container_registry_path} --volume {global_config_dir_path}:{global_config_dir_path}\n",
f"> kubectl --context {self.kubernetes_context} apply -k github.com/kubeflow/pipelines/manifests/kustomize/cluster-scoped-resources?ref={KFP_VERSION}&timeout=5m",
f"> kubectl --context {self.kubernetes_context} wait --timeout=60s --for condition=established crd/applications.app.k8s.io",
f"> kubectl --context {self.kubernetes_context} apply -k github.com/kubeflow/pipelines/manifests/kustomize/env/platform-agnostic-pns?ref={KFP_VERSION}&timeout=5m",
f"> kubectl --context {self.kubernetes_context} --namespace kubeflow port-forward svc/ml-pipeline-ui {self.config.kubeflow_pipelines_ui_port}:80",
]
logger.info(
"If you wish to spin up this Kubeflow local orchestrator manually, "
"please enter the following commands:\n"
)
logger.info("\n".join(kubeflow_commands))
prepare_or_run_pipeline(self, deployment, stack)
Creates a kfp yaml file.
This functions as an intermediary representation of the pipeline which is then deployed to the kubeflow pipelines instance.
How it works:
Before this method is called, the prepare_pipeline_deployment() method builds a docker image that contains the code for the pipeline, all steps, and the context around these files. Based on this docker image, a callable is created which builds container_ops for each step (_construct_kfp_pipeline). To do this, the entrypoint of the docker image is configured to run the correct step within the docker image. The dependencies between these container_ops are then also configured onto each container_op by pointing at the downstream steps.
This callable is then compiled into a kfp yaml file that is used as the intermediary representation of the kubeflow pipeline. This file, together with some metadata and runtime configurations, is then uploaded into the kubeflow pipelines cluster for execution.
Parameters:
Name | Type | Description | Default
---|---|---|---
deployment | PipelineDeployment | The pipeline deployment to prepare or run. | required
stack | Stack | The stack the pipeline will run on. | required
Exceptions:
Type | Description
---|---
RuntimeError | If trying to run a pipeline in a notebook environment.
Source code in zenml/integrations/kubeflow/orchestrators/kubeflow_orchestrator.py
def prepare_or_run_pipeline(
self,
deployment: "PipelineDeployment",
stack: "Stack",
) -> Any:
"""Creates a kfp yaml file.
This functions as an intermediary representation of the pipeline which
is then deployed to the kubeflow pipelines instance.
How it works:
-------------
Before this method is called the `prepare_pipeline_deployment()`
method builds a docker image that contains the code for the
pipeline, all steps, and the context around these files.
Based on this docker image a callable is created which builds
container_ops for each step (`_construct_kfp_pipeline`).
To do this the entrypoint of the docker image is configured to
run the correct step within the docker image. The dependencies
between these container_ops are then also configured onto each
container_op by pointing at the downstream steps.
This callable is then compiled into a kfp yaml file that is used as
the intermediary representation of the kubeflow pipeline.
This file, together with some metadata and runtime configurations, is
then uploaded into the kubeflow pipelines cluster for execution.
Args:
deployment: The pipeline deployment to prepare or run.
stack: The stack the pipeline will run on.
Raises:
RuntimeError: If trying to run a pipeline in a notebook environment.
"""
# First check whether the code is running in a notebook
if Environment.in_notebook():
raise RuntimeError(
"The Kubeflow orchestrator cannot run pipelines in a notebook "
"environment. The reason is that it is non-trivial to create "
"a Docker image of a notebook. Please consider refactoring "
"your notebook cells into separate scripts in a Python module "
"and run the code outside of a notebook when using this "
"orchestrator."
)
image_name = deployment.pipeline.extra[ORCHESTRATOR_DOCKER_IMAGE_KEY]
is_scheduled_run = bool(deployment.schedule)
# Create a callable for future compilation into a dsl.Pipeline.
def _construct_kfp_pipeline() -> None:
"""Create a container_op for each step.
This should contain the name of the docker image and configures the
entrypoint of the docker image to run the step.
Additionally, this gives each container_op information about its
direct downstream steps.
If this callable is passed to the `_create_and_write_workflow()`
method of a KFPCompiler all dsl.ContainerOp instances will be
automatically added to a singular dsl.Pipeline instance.
"""
# Dictionary of container_ops index by the associated step name
step_name_to_container_op: Dict[str, dsl.ContainerOp] = {}
for step_name, step in deployment.steps.items():
# The command will be needed to eventually call the python step
# within the docker container
command = (
KubeflowEntrypointConfiguration.get_entrypoint_command()
)
# The arguments are passed to configure the entrypoint of the
# docker container when the step is called.
metadata_ui_path = "/outputs/mlpipeline-ui-metadata.json"
arguments = (
KubeflowEntrypointConfiguration.get_entrypoint_arguments(
step_name=step_name,
**{METADATA_UI_PATH_OPTION: metadata_ui_path},
)
)
# Create a container_op - the kubeflow equivalent of a step. It
# contains the name of the step, the name of the docker image,
# the command to use to run the step entrypoint
# (e.g. `python -m zenml.entrypoints.step_entrypoint`)
# and the arguments to be passed along with the command. Find
# out more about how these arguments are parsed and used
# in the base entrypoint `run()` method.
container_op = dsl.ContainerOp(
name=step.config.name,
image=image_name,
command=command,
arguments=arguments,
output_artifact_paths={
"mlpipeline-ui-metadata": metadata_ui_path,
},
)
self._configure_container_op(
container_op=container_op, is_scheduled_run=is_scheduled_run
)
if self.requires_resources_in_orchestration_environment(step):
self._configure_container_resources(
container_op=container_op,
resource_settings=step.config.resource_settings,
)
# Find the upstream container ops of the current step and
# configure the current container op to run after them
for upstream_step_name in step.spec.upstream_steps:
upstream_container_op = step_name_to_container_op[
upstream_step_name
]
container_op.after(upstream_container_op)
# Update dictionary of container ops with the current one
step_name_to_container_op[step.config.name] = container_op
# Get a filepath to use to save the finished yaml to
fileio.makedirs(self.pipeline_directory)
pipeline_file_path = os.path.join(
self.pipeline_directory, f"{deployment.run_name}.yaml"
)
# write the argo pipeline yaml
KFPCompiler()._create_and_write_workflow(
pipeline_func=_construct_kfp_pipeline,
pipeline_name=deployment.pipeline.name,
package_path=pipeline_file_path,
)
# using the kfp client uploads the pipeline to kubeflow pipelines and
# runs it there
self._upload_and_run_pipeline(
deployment=deployment,
pipeline_file_path=pipeline_file_path,
)
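For illustration, a minimal sketch of the underlying KFP v1 pattern this method relies on: build one dsl.ContainerOp per step, declare dependencies with .after(), and compile the callable into YAML. This is not ZenML's code; the image name and entrypoint are placeholders.

import kfp
from kfp import dsl

def _example_pipeline() -> None:
    step_1 = dsl.ContainerOp(
        name="step_1",
        image="my-pipeline-image:latest",  # placeholder image
        command=["python", "-m", "my_entrypoint"],  # placeholder entrypoint
        arguments=["--step_name", "step_1"],
    )
    step_2 = dsl.ContainerOp(
        name="step_2",
        image="my-pipeline-image:latest",
        command=["python", "-m", "my_entrypoint"],
        arguments=["--step_name", "step_2"],
    )
    # Declare the dependency: step_2 only starts after step_1 completes.
    step_2.after(step_1)

# Compiling the callable produces the intermediary YAML representation.
kfp.compiler.Compiler().compile(_example_pipeline, "pipeline.yaml")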
prepare_pipeline_deployment(self, deployment, stack)
Build a Docker image and push it to the container registry.
Parameters:
Name | Type | Description | Default
---|---|---|---
deployment | PipelineDeployment | The pipeline deployment configuration. | required
stack | Stack | The stack on which the pipeline will be deployed. | required
Source code in zenml/integrations/kubeflow/orchestrators/kubeflow_orchestrator.py
def prepare_pipeline_deployment(
self,
deployment: "PipelineDeployment",
stack: "Stack",
) -> None:
"""Build a Docker image and push it to the container registry.
Args:
deployment: The pipeline deployment configuration.
stack: The stack on which the pipeline will be deployed.
"""
docker_image_builder = PipelineDockerImageBuilder()
repo_digest = docker_image_builder.build_and_push_docker_image(
deployment=deployment, stack=stack
)
deployment.add_extra(ORCHESTRATOR_DOCKER_IMAGE_KEY, repo_digest)
provision(self)
Provisions a local Kubeflow Pipelines deployment.
Exceptions:
Type | Description
---|---
ProvisioningError | If the provisioning fails.
Source code in zenml/integrations/kubeflow/orchestrators/kubeflow_orchestrator.py
def provision(self) -> None:
"""Provisions a local Kubeflow Pipelines deployment.
Raises:
ProvisioningError: If the provisioning fails.
"""
if self.config.skip_cluster_provisioning:
return
if self.is_running:
logger.info(
"Found already existing local Kubeflow Pipelines deployment. "
"If there are any issues with the existing deployment, please "
"run 'zenml stack down --force' to delete it."
)
return
if not local_deployment_utils.check_prerequisites():
raise ProvisioningError(
"Unable to provision local Kubeflow Pipelines deployment: "
"Please install 'k3d' and 'kubectl' and try again."
)
container_registry = Client().active_stack.container_registry
# should not happen, because the stack validation takes care of this,
# but just in case
assert container_registry is not None
fileio.makedirs(self.root_directory)
if not self.is_local:
# don't provision any resources if using a remote KFP installation
return
logger.info("Provisioning local Kubeflow Pipelines deployment...")
container_registry_port = int(
container_registry.config.uri.split(":")[-1]
)
container_registry_name = self._get_k3d_registry_name(
port=container_registry_port
)
local_deployment_utils.write_local_registry_yaml(
yaml_path=self._k3d_registry_config_path,
registry_name=container_registry_name,
registry_uri=container_registry.config.uri,
)
try:
local_deployment_utils.create_k3d_cluster(
cluster_name=self._k3d_cluster_name,
registry_name=container_registry_name,
registry_config_path=self._k3d_registry_config_path,
)
kubernetes_context = self.kubernetes_context
# will never happen, but mypy doesn't know that
assert kubernetes_context is not None
local_deployment_utils.deploy_kubeflow_pipelines(
kubernetes_context=kubernetes_context
)
artifact_store = Client().active_stack.artifact_store
if isinstance(artifact_store, LocalArtifactStore):
local_deployment_utils.add_hostpath_to_kubeflow_pipelines(
kubernetes_context=kubernetes_context,
local_path=artifact_store.path,
)
except Exception as e:
logger.error(e)
logger.error(
"Unable to spin up local Kubeflow Pipelines deployment."
)
self.list_manual_setup_steps(
container_registry_name, self._k3d_registry_config_path
)
self.deprovision()
resume(self)
Resumes the local k3d cluster.
Exceptions:
Type | Description
---|---
ProvisioningError | If the k3d cluster is not provisioned.
Source code in zenml/integrations/kubeflow/orchestrators/kubeflow_orchestrator.py
def resume(self) -> None:
"""Resumes the local k3d cluster.
Raises:
ProvisioningError: If the k3d cluster is not provisioned.
"""
if self.is_running:
logger.info("Local kubeflow pipelines deployment already running.")
return
if not self.is_provisioned:
raise ProvisioningError(
"Unable to resume local kubeflow pipelines deployment: No "
"resources provisioned for local deployment."
)
kubernetes_context = self.kubernetes_context
# will never happen, but mypy doesn't know that
assert kubernetes_context is not None
if (
not self.config.skip_cluster_provisioning
and self.is_local
and not self.is_cluster_running
):
# don't resume any resources if using a remote KFP installation
local_deployment_utils.start_k3d_cluster(
cluster_name=self._k3d_cluster_name
)
local_deployment_utils.wait_until_kubeflow_pipelines_ready(
kubernetes_context=kubernetes_context
)
if not self.is_daemon_running:
local_deployment_utils.start_kfp_ui_daemon(
pid_file_path=self._pid_file_path,
log_file_path=self.log_file,
port=self._get_kfp_ui_daemon_port(),
kubernetes_context=kubernetes_context,
)
suspend(self)
Suspends the local k3d cluster.
Source code in zenml/integrations/kubeflow/orchestrators/kubeflow_orchestrator.py
def suspend(self) -> None:
"""Suspends the local k3d cluster."""
if not self.is_provisioned:
logger.info("Local kubeflow pipelines deployment not provisioned.")
return
if (
not self.config.skip_ui_daemon_provisioning
and self.is_daemon_running
):
local_deployment_utils.stop_kfp_ui_daemon(
pid_file_path=self._pid_file_path
)
if (
not self.config.skip_cluster_provisioning
and self.is_local
and self.is_cluster_running
):
# don't suspend any resources if using a remote KFP installation
local_deployment_utils.stop_k3d_cluster(
cluster_name=self._k3d_cluster_name
)
local_deployment_utils
Utils for the local Kubeflow deployment behaviors.
add_hostpath_to_kubeflow_pipelines(kubernetes_context, local_path)
Patches the Kubeflow Pipelines deployment to mount a local folder.
This folder serves as a hostpath for visualization purposes.
This function reconfigures the Kubeflow pipelines deployment to use a shared local folder to support loading the TensorBoard viewer and other pipeline visualization results from a local artifact store, as described here:
https://github.com/kubeflow/pipelines/blob/master/docs/config/volume-support.md
Parameters:
Name | Type | Description | Default
---|---|---|---
kubernetes_context | str | The kubernetes context on which Kubeflow Pipelines should be patched. | required
local_path | str | The path to the local folder to mount as a hostpath. | required
Source code in zenml/integrations/kubeflow/orchestrators/local_deployment_utils.py
def add_hostpath_to_kubeflow_pipelines(
kubernetes_context: str, local_path: str
) -> None:
"""Patches the Kubeflow Pipelines deployment to mount a local folder.
This folder serves as a hostpath for visualization purposes.
This function reconfigures the Kubeflow pipelines deployment to use a
shared local folder to support loading the TensorBoard viewer and other
pipeline visualization results from a local artifact store, as described
here:
https://github.com/kubeflow/pipelines/blob/master/docs/config/volume-support.md
Args:
kubernetes_context: The kubernetes context on which Kubeflow Pipelines
should be patched.
local_path: The path to the local folder to mount as a hostpath.
"""
logger.info("Patching Kubeflow Pipelines to mount a local folder.")
pod_template = {
"spec": {
"serviceAccountName": "kubeflow-pipelines-viewer",
"containers": [
{
"volumeMounts": [
{
"mountPath": local_path,
"name": "local-artifact-store",
}
]
}
],
"volumes": [
{
"hostPath": {
"path": local_path,
"type": "Directory",
},
"name": "local-artifact-store",
}
],
}
}
pod_template_json = json.dumps(pod_template, indent=2)
config_map_data = {"data": {"viewer-pod-template.json": pod_template_json}}
config_map_data_json = json.dumps(config_map_data, indent=2)
logger.debug(
"Adding host path volume for local path `%s` to kubeflow pipeline"
"viewer pod template configuration.",
local_path,
)
subprocess.check_call(
[
"kubectl",
"--context",
kubernetes_context,
"-n",
"kubeflow",
"patch",
"configmap/ml-pipeline-ui-configmap",
"--type",
"merge",
"-p",
config_map_data_json,
]
)
deployment_patch = {
"spec": {
"template": {
"spec": {
"containers": [
{
"name": "ml-pipeline-ui",
"volumeMounts": [
{
"mountPath": local_path,
"name": "local-artifact-store",
}
],
}
],
"volumes": [
{
"hostPath": {
"path": local_path,
"type": "Directory",
},
"name": "local-artifact-store",
}
],
}
}
}
}
deployment_patch_json = json.dumps(deployment_patch, indent=2)
logger.debug(
"Adding host path volume for local path `%s` to the kubeflow UI",
local_path,
)
subprocess.check_call(
[
"kubectl",
"--context",
kubernetes_context,
"-n",
"kubeflow",
"patch",
"deployment/ml-pipeline-ui",
"--type",
"strategic",
"-p",
deployment_patch_json,
]
)
wait_until_kubeflow_pipelines_ready(kubernetes_context=kubernetes_context)
logger.info("Finished patching Kubeflow Pipelines setup.")
check_prerequisites(skip_k3d=False, skip_kubectl=False)
Checks prerequisites for a local kubeflow pipelines deployment.
It makes sure they are installed.
Parameters:
Name | Type | Description | Default
---|---|---|---
skip_k3d | bool | Whether to skip the check for the k3d command. | False
skip_kubectl | bool | Whether to skip the check for the kubectl command. | False
Returns:
Type | Description
---|---
bool | Whether all prerequisites are installed.
Source code in zenml/integrations/kubeflow/orchestrators/local_deployment_utils.py
def check_prerequisites(
skip_k3d: bool = False, skip_kubectl: bool = False
) -> bool:
"""Checks prerequisites for a local kubeflow pipelines deployment.
It makes sure they are installed.
Args:
skip_k3d: Whether to skip the check for the k3d command.
skip_kubectl: Whether to skip the check for the kubectl command.
Returns:
Whether all prerequisites are installed.
"""
k3d_installed = skip_k3d or shutil.which("k3d") is not None
kubectl_installed = skip_kubectl or shutil.which("kubectl") is not None
logger.debug(
"Local kubeflow deployment prerequisites: K3D - %s, Kubectl - %s",
k3d_installed,
kubectl_installed,
)
return k3d_installed and kubectl_installed
create_k3d_cluster(cluster_name, registry_name, registry_config_path)
Creates a K3D cluster.
Parameters:
Name | Type | Description | Default
---|---|---|---
cluster_name | str | Name of the cluster to create. | required
registry_name | str | Name of the registry to create for this cluster. | required
registry_config_path | str | Path to the registry config file. | required
Source code in zenml/integrations/kubeflow/orchestrators/local_deployment_utils.py
def create_k3d_cluster(
cluster_name: str, registry_name: str, registry_config_path: str
) -> None:
"""Creates a K3D cluster.
Args:
cluster_name: Name of the cluster to create.
registry_name: Name of the registry to create for this cluster.
registry_config_path: Path to the registry config file.
"""
logger.info("Creating local K3D cluster '%s'.", cluster_name)
local_stores_path = GlobalConfiguration().local_stores_path
subprocess.check_call(
[
"k3d",
"cluster",
"create",
cluster_name,
"--image",
K3S_IMAGE_NAME,
"--registry-create",
registry_name,
"--registry-config",
registry_config_path,
"--volume",
f"{local_stores_path}:{local_stores_path}",
]
)
logger.info("Finished K3D cluster creation.")
delete_k3d_cluster(cluster_name)
Deletes a K3D cluster with the given name.
Parameters:
Name | Type | Description | Default
---|---|---|---
cluster_name | str | Name of the cluster to delete. | required
Source code in zenml/integrations/kubeflow/orchestrators/local_deployment_utils.py
def delete_k3d_cluster(cluster_name: str) -> None:
"""Deletes a K3D cluster with the given name.
Args:
cluster_name: Name of the cluster to delete.
"""
subprocess.check_call(["k3d", "cluster", "delete", cluster_name])
logger.info("Deleted local k3d cluster '%s'.", cluster_name)
deploy_kubeflow_pipelines(kubernetes_context)
Deploys Kubeflow Pipelines.
Parameters:
Name | Type | Description | Default
---|---|---|---
kubernetes_context | str | The kubernetes context on which Kubeflow Pipelines should be deployed. | required
Source code in zenml/integrations/kubeflow/orchestrators/local_deployment_utils.py
def deploy_kubeflow_pipelines(kubernetes_context: str) -> None:
"""Deploys Kubeflow Pipelines.
Args:
kubernetes_context: The kubernetes context on which Kubeflow Pipelines
should be deployed.
"""
logger.info("Deploying Kubeflow Pipelines.")
subprocess.check_call(
[
"kubectl",
"--context",
kubernetes_context,
"apply",
"-k",
f"github.com/kubeflow/pipelines/manifests/kustomize/cluster-scoped-resources?ref={KFP_VERSION}&timeout=5m",
]
)
subprocess.check_call(
[
"kubectl",
"--context",
kubernetes_context,
"wait",
"--timeout=60s",
"--for",
"condition=established",
"crd/applications.app.k8s.io",
]
)
subprocess.check_call(
[
"kubectl",
"--context",
kubernetes_context,
"apply",
"-k",
f"github.com/kubeflow/pipelines/manifests/kustomize/env/platform-agnostic-pns?ref={KFP_VERSION}&timeout=5m",
]
)
wait_until_kubeflow_pipelines_ready(kubernetes_context=kubernetes_context)
logger.info("Finished Kubeflow Pipelines setup.")
k3d_cluster_exists(cluster_name)
Checks whether there exists a K3D cluster with the given name.
Parameters:
Name | Type | Description | Default
---|---|---|---
cluster_name | str | Name of the cluster to check. | required
Returns:
Type | Description
---|---
bool | Whether the cluster exists.
Source code in zenml/integrations/kubeflow/orchestrators/local_deployment_utils.py
def k3d_cluster_exists(cluster_name: str) -> bool:
"""Checks whether there exists a K3D cluster with the given name.
Args:
cluster_name: Name of the cluster to check.
Returns:
Whether the cluster exists.
"""
output = subprocess.check_output(
["k3d", "cluster", "list", "--output", "json"]
)
clusters = json.loads(output)
for cluster in clusters:
if cluster["name"] == cluster_name:
return True
return False
k3d_cluster_running(cluster_name)
Checks whether the K3D cluster with the given name is running.
Parameters:
Name | Type | Description | Default
---|---|---|---
cluster_name | str | Name of the cluster to check. | required
Returns:
Type | Description
---|---
bool | Whether the cluster is running.
Source code in zenml/integrations/kubeflow/orchestrators/local_deployment_utils.py
def k3d_cluster_running(cluster_name: str) -> bool:
"""Checks whether the K3D cluster with the given name is running.
Args:
cluster_name: Name of the cluster to check.
Returns:
Whether the cluster is running.
"""
output = subprocess.check_output(
["k3d", "cluster", "list", "--output", "json"]
)
clusters = json.loads(output)
for cluster in clusters:
if cluster["name"] == cluster_name:
server_count: int = cluster["serversCount"]
servers_running: int = cluster["serversRunning"]
return servers_running == server_count
return False
kubeflow_pipelines_ready(kubernetes_context)
Returns whether all Kubeflow Pipelines pods are ready.
Parameters:
Name | Type | Description | Default
---|---|---|---
kubernetes_context | str | The kubernetes context in which the pods should be checked. | required
Returns:
Type | Description
---|---
bool | Whether all Kubeflow Pipelines pods are ready.
Source code in zenml/integrations/kubeflow/orchestrators/local_deployment_utils.py
def kubeflow_pipelines_ready(kubernetes_context: str) -> bool:
"""Returns whether all Kubeflow Pipelines pods are ready.
Args:
kubernetes_context: The kubernetes context in which the pods
should be checked.
Returns:
Whether all Kubeflow Pipelines pods are ready.
"""
try:
subprocess.check_call(
[
"kubectl",
"--context",
kubernetes_context,
"--namespace",
"kubeflow",
"wait",
"--for",
"condition=ready",
"--timeout=0s",
"pods",
"-l",
"application-crd-id=kubeflow-pipelines",
],
stdout=subprocess.DEVNULL,
stderr=subprocess.DEVNULL,
)
return True
except subprocess.CalledProcessError:
return False
start_k3d_cluster(cluster_name)
Starts a K3D cluster with the given name.
Parameters:
Name | Type | Description | Default
---|---|---|---
cluster_name | str | Name of the cluster to start. | required
Source code in zenml/integrations/kubeflow/orchestrators/local_deployment_utils.py
def start_k3d_cluster(cluster_name: str) -> None:
"""Starts a K3D cluster with the given name.
Args:
cluster_name: Name of the cluster to start.
"""
subprocess.check_call(["k3d", "cluster", "start", cluster_name])
logger.info("Started local k3d cluster '%s'.", cluster_name)
start_kfp_ui_daemon(pid_file_path, log_file_path, port, kubernetes_context)
Starts a daemon process that forwards ports.
This is so the Kubeflow Pipelines UI is accessible in the browser.
Parameters:
Name | Type | Description | Default
---|---|---|---
pid_file_path | str | Path where the file with the daemon's process ID should be written. | required
log_file_path | str | Path to a file where the daemon logs should be written. | required
port | int | Port on which the UI should be accessible. | required
kubernetes_context | str | The kubernetes context for the cluster where Kubeflow Pipelines is running. | required
Source code in zenml/integrations/kubeflow/orchestrators/local_deployment_utils.py
def start_kfp_ui_daemon(
pid_file_path: str,
log_file_path: str,
port: int,
kubernetes_context: str,
) -> None:
"""Starts a daemon process that forwards ports.
This is so the Kubeflow Pipelines UI is accessible in the browser.
Args:
pid_file_path: Path where the file with the daemon's process ID should
be written.
log_file_path: Path to a file where the daemon logs should be written.
port: Port on which the UI should be accessible.
kubernetes_context: The kubernetes context for the cluster where
Kubeflow Pipelines is running.
"""
command = [
"kubectl",
"--context",
kubernetes_context,
"--namespace",
"kubeflow",
"port-forward",
"svc/ml-pipeline-ui",
f"{port}:80",
]
if not networking_utils.port_available(port):
modified_command = command.copy()
modified_command[-1] = "PORT:80"
logger.warning(
"Unable to port-forward Kubeflow Pipelines UI to local port %d "
"because the port is occupied. In order to access the Kubeflow "
"Pipelines UI at http://localhost:PORT/, please run '%s' in a "
"separate command line shell (replace PORT with a free port of "
"your choice).",
port,
" ".join(modified_command),
)
elif sys.platform == "win32":
logger.warning(
"Daemon functionality not supported on Windows. "
"In order to access the Kubeflow Pipelines UI at "
"http://localhost:%d/, please run '%s' in a separate command "
"line shell.",
port,
" ".join(command),
)
else:
from zenml.utils import daemon
def _daemon_function() -> None:
"""Port-forwards the Kubeflow Pipelines UI pod."""
subprocess.check_call(command)
daemon.run_as_daemon(
_daemon_function, pid_file=pid_file_path, log_file=log_file_path
)
logger.info(
"Started Kubeflow Pipelines UI daemon (check the daemon logs at %s "
"in case you're not able to view the UI). The Kubeflow Pipelines "
"UI should now be accessible at http://localhost:%d/.",
log_file_path,
port,
)
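The port check above is delegated to networking_utils.port_available. As a rough sketch of what such a check can look like (an assumption, not ZenML's actual implementation): try to bind the port and report whether the bind succeeds.

import socket

def port_available(port: int) -> bool:
    """Return True if nothing is listening on the given local port."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        try:
            sock.bind(("127.0.0.1", port))
            return True
        except OSError:
            return False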
stop_k3d_cluster(cluster_name)
Stops a K3D cluster with the given name.
Parameters:
Name | Type | Description | Default
---|---|---|---
cluster_name | str | Name of the cluster to stop. | required
Source code in zenml/integrations/kubeflow/orchestrators/local_deployment_utils.py
def stop_k3d_cluster(cluster_name: str) -> None:
"""Stops a K3D cluster with the given name.
Args:
cluster_name: Name of the cluster to stop.
"""
subprocess.check_call(["k3d", "cluster", "stop", cluster_name])
logger.info("Stopped local k3d cluster '%s'.", cluster_name)
stop_kfp_ui_daemon(pid_file_path)
Stops the KFP UI daemon process if it is running.
Parameters:
Name | Type | Description | Default
---|---|---|---
pid_file_path | str | Path to the file with the daemon's process ID. | required
Source code in zenml/integrations/kubeflow/orchestrators/local_deployment_utils.py
def stop_kfp_ui_daemon(pid_file_path: str) -> None:
"""Stops the KFP UI daemon process if it is running.
Args:
pid_file_path: Path to the file with the daemon's process ID.
"""
if fileio.exists(pid_file_path):
if sys.platform == "win32":
# Daemon functionality is not supported on Windows, so the PID
# file won't exist. This if clause exists just for mypy to not
# complain about missing functions
pass
else:
from zenml.utils import daemon
daemon.stop_daemon(pid_file_path)
fileio.remove(pid_file_path)
logger.info("Stopped Kubeflow Pipelines UI daemon.")
wait_until_kubeflow_pipelines_ready(kubernetes_context)
Waits until all Kubeflow Pipelines pods are ready.
Parameters:
Name | Type | Description | Default
---|---|---|---
kubernetes_context | str | The kubernetes context in which the pods should be checked. | required
Source code in zenml/integrations/kubeflow/orchestrators/local_deployment_utils.py
def wait_until_kubeflow_pipelines_ready(kubernetes_context: str) -> None:
"""Waits until all Kubeflow Pipelines pods are ready.
Args:
kubernetes_context: The kubernetes context in which the pods
should be checked.
"""
logger.info(
"Waiting for all Kubeflow Pipelines pods to be ready (this might "
"take a few minutes)."
)
while True:
logger.info("Current pod status:")
subprocess.check_call(
[
"kubectl",
"--context",
kubernetes_context,
"--namespace",
"kubeflow",
"get",
"pods",
]
)
if kubeflow_pipelines_ready(kubernetes_context=kubernetes_context):
break
logger.info("One or more pods not ready yet, waiting for 30 seconds...")
time.sleep(30)
write_local_registry_yaml(yaml_path, registry_name, registry_uri)
Writes a K3D registry config file.
Parameters:
Name | Type | Description | Default
---|---|---|---
yaml_path | str | Path where the config file should be written to. | required
registry_name | str | Name of the registry. | required
registry_uri | str | URI of the registry. | required
Source code in zenml/integrations/kubeflow/orchestrators/local_deployment_utils.py
def write_local_registry_yaml(
yaml_path: str, registry_name: str, registry_uri: str
) -> None:
"""Writes a K3D registry config file.
Args:
yaml_path: Path where the config file should be written to.
registry_name: Name of the registry.
registry_uri: URI of the registry.
"""
yaml_content = {
"mirrors": {registry_uri: {"endpoint": [f"http://{registry_name}"]}}
}
yaml_utils.write_yaml(yaml_path, yaml_content)
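Illustrative usage with placeholder values:

write_local_registry_yaml(
    yaml_path="/tmp/k3d_registry.yaml",
    registry_name="k3d-zenml-registry",
    registry_uri="localhost:5001",
)
# The written file then contains a k3d registry mirror config of the form:
# mirrors:
#   localhost:5001:
#     endpoint:
#     - http://k3d-zenml-registry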
utils
Utils for ZenML Kubeflow orchestrators implementation.
dump_ui_metadata(execution_info, metadata_ui_path)
Dumps a KFP UI metadata JSON file for visualization purposes.
For general components we just render a simple Markdown file for exec_properties/inputs/outputs.
Parameters:
Name | Type | Description | Default
---|---|---|---
execution_info | ExecutionInfo | Runtime execution info for this component, including materialized inputs/outputs/execution properties and id. | required
metadata_ui_path | str | Path to dump UI metadata. | required
Source code in zenml/integrations/kubeflow/orchestrators/utils.py
def dump_ui_metadata(
execution_info: data_types.ExecutionInfo,
metadata_ui_path: str,
) -> None:
"""Dump KFP UI metadata json file for visualization purpose.
For general components we just render a simple Markdown file for
exec_properties/inputs/outputs.
Args:
execution_info: runtime execution info for this component, including
materialized inputs/outputs/execution properties and id.
metadata_ui_path: path to dump ui metadata.
"""
node = execution_info.pipeline_node
if not node:
return
exec_properties_list = [
"**{}**: {}".format(
_sanitize_underscore(name), _sanitize_underscore(exec_property)
)
for name, exec_property in execution_info.exec_properties.items()
]
src_str_exec_properties = "# Execution properties:\n{}".format(
"\n\n".join(exec_properties_list) or "No execution property."
)
def _dump_input_populated_artifacts(
node_inputs: MutableMapping[str, InputSpec],
name_to_artifacts: Dict[str, List[artifact.Artifact]],
) -> List[str]:
"""Dump artifacts markdown string for inputs.
Args:
node_inputs: maps from input name to input spec proto.
name_to_artifacts: maps from input key to list of populated artifacts.
Returns:
A list of dumped markdown string, each of which represents a channel.
"""
rendered_list = []
for name, spec in node_inputs.items():
# Need to look for materialized artifacts in the execution decision.
rendered_artifacts = "".join(
[
_render_artifact_as_mdstr(single_artifact)
for single_artifact in name_to_artifacts.get(name, [])
]
)
# There must be at least one channel in an input, and all channels in
# an input share the same artifact type.
artifact_type = spec.channels[0].artifact_query.type.name
rendered_list.append(
"## {name}\n\n**Type**: {channel_type}\n\n{artifacts}".format(
name=_sanitize_underscore(name),
channel_type=_sanitize_underscore(artifact_type),
artifacts=rendered_artifacts,
)
)
return rendered_list
def _dump_output_populated_artifacts(
node_outputs: MutableMapping[str, OutputSpec],
name_to_artifacts: Dict[str, List[artifact.Artifact]],
) -> List[str]:
"""Dump artifacts markdown string for outputs.
Args:
node_outputs: maps from output name to output spec proto.
name_to_artifacts: maps from output key to list of populated
artifacts.
Returns:
A list of dumped markdown string, each of which represents a channel.
"""
rendered_list = []
for name, spec in node_outputs.items():
# Need to look for materialized artifacts in the execution decision.
rendered_artifacts = "".join(
[
_render_artifact_as_mdstr(single_artifact)
for single_artifact in name_to_artifacts.get(name, [])
]
)
# There must be at least one channel in an input, and all channels
# in an input share the same artifact type.
artifact_type = spec.artifact_spec.type.name
rendered_list.append(
"## {name}\n\n**Type**: {channel_type}\n\n{artifacts}".format(
name=_sanitize_underscore(name),
channel_type=_sanitize_underscore(artifact_type),
artifacts=rendered_artifacts,
)
)
return rendered_list
src_str_inputs = "# Inputs:\n{}".format(
"".join(
_dump_input_populated_artifacts(
node_inputs=node.inputs.inputs,
name_to_artifacts=execution_info.input_dict or {},
)
)
or "No input."
)
src_str_outputs = "# Outputs:\n{}".format(
"".join(
_dump_output_populated_artifacts(
node_outputs=node.outputs.outputs,
name_to_artifacts=execution_info.output_dict or {},
)
)
or "No output."
)
outputs = [
{
"storage": "inline",
"source": "{exec_properties}\n\n{inputs}\n\n{outputs}".format(
exec_properties=src_str_exec_properties,
inputs=src_str_inputs,
outputs=src_str_outputs,
),
"type": "markdown",
}
]
# Add TensorBoard view for ModelRun outputs.
for name, spec in node.outputs.outputs.items():
if (
spec.artifact_spec.type.name
== standard_artifacts.ModelRun.TYPE_NAME
or spec.artifact_spec.type.name == ModelArtifact.TYPE_NAME
):
output_model = execution_info.output_dict[name][0]
source = output_model.uri
# For local artifact repository, use a path that is relative to
# the point where the local artifact folder is mounted as a volume
artifact_store = Client().active_stack.artifact_store
if isinstance(artifact_store, LocalArtifactStore):
source = os.path.relpath(source, artifact_store.path)
source = f"volume://local-artifact-store/{source}"
# Add TensorBoard view.
tensorboard_output = {
"type": "tensorboard",
"source": source,
}
outputs.append(tensorboard_output)
metadata_dict = {"outputs": outputs}
with open(metadata_ui_path, "w") as f:
json.dump(metadata_dict, f)
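For reference, the resulting mlpipeline-ui-metadata.json has roughly this shape (values below are illustrative placeholders): one inline markdown output summarizing execution properties, inputs and outputs, plus one tensorboard entry per ModelRun output.

example_metadata = {
    "outputs": [
        {
            "storage": "inline",
            "source": "# Execution properties:\n...\n\n# Inputs:\n...\n\n# Outputs:\n...",
            "type": "markdown",
        },
        {
            "type": "tensorboard",
            "source": "volume://local-artifact-store/path/to/model_run",
        },
    ]
}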
mount_config_map_op(config_map_name)
Mounts all key-value pairs found in the named Kubernetes ConfigMap.
All key-value pairs in the ConfigMap are mounted as environment variables.
Parameters:
Name | Type | Description | Default
---|---|---|---
config_map_name | str | The name of the ConfigMap resource. | required
Returns:
Type | Description
---|---
Callable[[kfp.dsl._container_op.ContainerOp], NoneType] | An OpFunc for mounting the ConfigMap.
Source code in zenml/integrations/kubeflow/orchestrators/utils.py
def mount_config_map_op(
config_map_name: str,
) -> Callable[[dsl.ContainerOp], None]:
"""Mounts all key-value pairs found in the named Kubernetes ConfigMap.
All key-value pairs in the ConfigMap are mounted as environment variables.
Args:
config_map_name: The name of the ConfigMap resource.
Returns:
An OpFunc for mounting the ConfigMap.
"""
def mount_config_map(container_op: dsl.ContainerOp) -> None:
"""Mounts all key-value pairs found in the Kubernetes ConfigMap.
Args:
container_op: The container op to mount the ConfigMap.
"""
config_map_ref = k8s_client.V1ConfigMapEnvSource(
name=config_map_name, optional=True
)
container_op.container.add_env_from(
k8s_client.V1EnvFromSource(config_map_ref=config_map_ref)
)
return mount_config_map
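Illustrative usage inside a KFP v1 pipeline callable (the image, entrypoint, and ConfigMap name are placeholders): the returned OpFunc is applied to a container_op with ContainerOp.apply.

from kfp import dsl

def _example_pipeline() -> None:
    op = dsl.ContainerOp(
        name="my_step",
        image="my-pipeline-image:latest",  # placeholder
        command=["python", "-m", "my_entrypoint"],  # placeholder
    )
    # Inject all key-value pairs of 'my-config-map' as environment variables.
    op.apply(mount_config_map_op("my-config-map"))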
kubernetes
special
Kubernetes integration for Kubernetes-native orchestration.
The Kubernetes integration sub-module powers an alternative to the local orchestrator. You can enable it by registering the Kubernetes orchestrator with the CLI tool.
KubernetesIntegration (Integration)
Definition of Kubernetes integration for ZenML.
Source code in zenml/integrations/kubernetes/__init__.py
class KubernetesIntegration(Integration):
"""Definition of Kubernetes integration for ZenML."""
NAME = KUBERNETES
REQUIREMENTS = ["kubernetes==18.20.0"]
@classmethod
def flavors(cls) -> List[Type[Flavor]]:
"""Declare the stack component flavors for the Kubernetes integration.
Returns:
List of new stack component flavors.
"""
from zenml.integrations.kubernetes.flavors import (
KubernetesOrchestratorFlavor,
)
return [KubernetesOrchestratorFlavor]
flavors()
classmethod
Declare the stack component flavors for the Kubernetes integration.
Returns:
Type | Description
---|---
List[Type[zenml.stack.flavor.Flavor]] | List of new stack component flavors.
Source code in zenml/integrations/kubernetes/__init__.py
@classmethod
def flavors(cls) -> List[Type[Flavor]]:
"""Declare the stack component flavors for the Kubernetes integration.
Returns:
List of new stack component flavors.
"""
from zenml.integrations.kubernetes.flavors import (
KubernetesOrchestratorFlavor,
)
return [KubernetesOrchestratorFlavor]
flavors
special
Kubernetes integration flavors.
kubernetes_orchestrator_flavor
Kubernetes orchestrator flavor.
KubernetesOrchestratorConfig (BaseOrchestratorConfig)
pydantic-model
Configuration for the Kubernetes orchestrator.
Attributes:
Name | Type | Description
---|---|---
kubernetes_context | Optional[str] | Optional name of a Kubernetes context to run pipelines in. If not set, the current active context will be used. You can find the active context by running kubectl config current-context.
kubernetes_namespace | str | Name of the Kubernetes namespace to be used. If not provided, the zenml namespace will be used.
synchronous | bool | If True, running a pipeline using this orchestrator will block until all steps finish running on Kubernetes.
skip_config_loading | bool | If True, don't load the Kubernetes context and clients. This is only useful for unit testing.
Source code in zenml/integrations/kubernetes/flavors/kubernetes_orchestrator_flavor.py
class KubernetesOrchestratorConfig(BaseOrchestratorConfig):
"""Configuration for the Kubernetes orchestrator.
Attributes:
kubernetes_context: Optional name of a Kubernetes context to run
pipelines in. If not set, the current active context will be used.
You can find the active context by running `kubectl config
current-context`.
kubernetes_namespace: Name of the Kubernetes namespace to be used.
If not provided, `zenml` namespace will be used.
synchronous: If `True`, running a pipeline using this orchestrator will
block until all steps finished running on Kubernetes.
skip_config_loading: If `True`, don't load the Kubernetes context and
clients. This is only useful for unit testing.
"""
kubernetes_context: Optional[str] = None
kubernetes_namespace: str = "zenml"
synchronous: bool = False
skip_config_loading: bool = False
@property
def is_remote(self) -> bool:
"""Checks if this stack component is running remotely.
This designation is used to determine if the stack component can be
used with a local ZenML database or if it requires a remote ZenML
server.
Returns:
True if this config is for a remote component, False otherwise.
"""
return True
is_remote: bool
property
readonly
Checks if this stack component is running remotely.
This designation is used to determine if the stack component can be used with a local ZenML database or if it requires a remote ZenML server.
Returns:
Type | Description
---|---
bool | True if this config is for a remote component, False otherwise.
KubernetesOrchestratorFlavor (BaseOrchestratorFlavor)
Kubernetes orchestrator flavor.
Source code in zenml/integrations/kubernetes/flavors/kubernetes_orchestrator_flavor.py
class KubernetesOrchestratorFlavor(BaseOrchestratorFlavor):
"""Kubernetes orchestrator flavor."""
@property
def name(self) -> str:
"""Name of the flavor.
Returns:
The name of the flavor.
"""
return KUBERNETES_ORCHESTRATOR_FLAVOR
@property
def config_class(self) -> Type[KubernetesOrchestratorConfig]:
"""Returns `KubernetesOrchestratorConfig` config class.
Returns:
The config class.
"""
return KubernetesOrchestratorConfig
@property
def implementation_class(self) -> Type["KubernetesOrchestrator"]:
"""Implementation class for this flavor.
Returns:
The implementation class.
"""
from zenml.integrations.kubernetes.orchestrators import (
KubernetesOrchestrator,
)
return KubernetesOrchestrator
config_class: Type[zenml.integrations.kubernetes.flavors.kubernetes_orchestrator_flavor.KubernetesOrchestratorConfig]
property
readonly
Returns the KubernetesOrchestratorConfig config class.
Returns:
Type | Description
---|---
Type[zenml.integrations.kubernetes.flavors.kubernetes_orchestrator_flavor.KubernetesOrchestratorConfig] | The config class.
implementation_class: Type[KubernetesOrchestrator]
property
readonly
Implementation class for this flavor.
Returns:
Type | Description
---|---
Type[KubernetesOrchestrator] | The implementation class.
name: str
property
readonly
Name of the flavor.
Returns:
Type | Description
---|---
str | The name of the flavor.
orchestrators
special
Kubernetes-native orchestration.
dag_runner
DAG (Directed Acyclic Graph) Runners.
NodeStatus (Enum)
Status of the execution of a node.
Source code in zenml/integrations/kubernetes/orchestrators/dag_runner.py
class NodeStatus(Enum):
"""Status of the execution of a node."""
WAITING = "Waiting"
RUNNING = "Running"
COMPLETED = "Completed"
ThreadedDagRunner
Multi-threaded DAG Runner.
This class expects a DAG of strings in adjacency list representation, as well as a custom run_fn as input, then calls run_fn(node) for each string node in the DAG.
Steps that can be executed in parallel will be started in separate threads.
Source code in zenml/integrations/kubernetes/orchestrators/dag_runner.py
class ThreadedDagRunner:
"""Multi-threaded DAG Runner.
This class expects a DAG of strings in adjacency list representation, as
well as a custom `run_fn` as input, then calls `run_fn(node)` for each
string node in the DAG.
Steps that can be executed in parallel will be started in separate threads.
"""
def __init__(
self, dag: Dict[str, List[str]], run_fn: Callable[[str], Any]
) -> None:
"""Define attributes and initialize all nodes in waiting state.
Args:
dag: Adjacency list representation of a DAG.
E.g.: [(1->2), (1->3), (2->4), (3->4)] should be represented as
`dag={2: [1], 3: [1], 4: [2, 3]}`
run_fn: A function `run_fn(node)` that runs a single node
"""
self.dag = dag
self.reversed_dag = reverse_dag(dag)
self.run_fn = run_fn
self.nodes = dag.keys()
self.node_states = {node: NodeStatus.WAITING for node in self.nodes}
self._lock = threading.Lock()
def _can_run(self, node: str) -> bool:
"""Determine whether a node is ready to be run.
This is the case if the node has not run yet and all of its upstream
node have already completed.
Args:
node: The node.
Returns:
True if the node can run else False.
"""
# Check that node has not run yet.
if not self.node_states[node] == NodeStatus.WAITING:
return False
# Check that all upstream nodes of this node have already completed.
for upstream_node in self.dag[node]:
if not self.node_states[upstream_node] == NodeStatus.COMPLETED:
return False
return True
def _run_node(self, node: str) -> None:
"""Run a single node.
Calls the user-defined run_fn, then calls `self._finish_node`.
Args:
node: The node.
"""
self.run_fn(node)
self._finish_node(node)
def _run_node_in_thread(self, node: str) -> threading.Thread:
"""Run a single node in a separate thread.
First updates the node status to running.
Then calls self._run_node() in a new thread and returns the thread.
Args:
node: The node.
Returns:
The thread in which the node was run.
"""
# Update node status to running.
assert self.node_states[node] == NodeStatus.WAITING
with self._lock:
self.node_states[node] = NodeStatus.RUNNING
# Run node in new thread.
thread = threading.Thread(target=self._run_node, args=(node,))
thread.start()
return thread
def _finish_node(self, node: str) -> None:
"""Finish a node run.
First updates the node status to completed.
Then starts all other nodes that can now be run and waits for them.
Args:
node: The node.
"""
# Update node status to completed.
assert self.node_states[node] == NodeStatus.RUNNING
with self._lock:
self.node_states[node] = NodeStatus.COMPLETED
# Run downstream nodes.
threads = []
for downstream_node in self.reversed_dag[node]:
if self._can_run(downstream_node):
thread = self._run_node_in_thread(downstream_node)
threads.append(thread)
# Wait for all downstream nodes to complete.
for thread in threads:
thread.join()
def run(self) -> None:
"""Call `self.run_fn` on all nodes in `self.dag`.
The order of execution is determined using topological sort.
Each node is run in a separate thread to enable parallelism.
"""
# Run all nodes that can be started immediately.
# These will, in turn, start other nodes once all of their respective
# upstream nodes have completed.
threads = []
for node in self.nodes:
if self._can_run(node):
thread = self._run_node_in_thread(node)
threads.append(thread)
# Wait till all nodes have completed.
for thread in threads:
thread.join()
# Make sure all nodes were run, otherwise print a warning.
for node in self.nodes:
if self.node_states[node] == NodeStatus.WAITING:
upstream_nodes = self.dag[node]
logger.warning(
f"Node `{node}` was never run, because it was still"
f" waiting for the following nodes: `{upstream_nodes}`."
)
__init__(self, dag, run_fn)
special
Define attributes and initialize all nodes in waiting state.
Parameters:
Name | Type | Description | Default
---|---|---|---
dag | Dict[str, List[str]] | Adjacency list representation of a DAG. E.g.: [(1->2), (1->3), (2->4), (3->4)] should be represented as dag={2: [1], 3: [1], 4: [2, 3]}. | required
run_fn | Callable[[str], Any] | A function run_fn(node) that runs a single node. | required
Source code in zenml/integrations/kubernetes/orchestrators/dag_runner.py
def __init__(
self, dag: Dict[str, List[str]], run_fn: Callable[[str], Any]
) -> None:
"""Define attributes and initialize all nodes in waiting state.
Args:
dag: Adjacency list representation of a DAG.
E.g.: [(1->2), (1->3), (2->4), (3->4)] should be represented as
`dag={2: [1], 3: [1], 4: [2, 3]}`
run_fn: A function `run_fn(node)` that runs a single node
"""
self.dag = dag
self.reversed_dag = reverse_dag(dag)
self.run_fn = run_fn
self.nodes = dag.keys()
self.node_states = {node: NodeStatus.WAITING for node in self.nodes}
self._lock = threading.Lock()
run(self)
Call self.run_fn on all nodes in self.dag.
The order of execution is determined using topological sort. Each node is run in a separate thread to enable parallelism.
Source code in zenml/integrations/kubernetes/orchestrators/dag_runner.py
def run(self) -> None:
"""Call `self.run_fn` on all nodes in `self.dag`.
The order of execution is determined using topological sort.
Each node is run in a separate thread to enable parallelism.
"""
# Run all nodes that can be started immediately.
# These will, in turn, start other nodes once all of their respective
# upstream nodes have completed.
threads = []
for node in self.nodes:
if self._can_run(node):
thread = self._run_node_in_thread(node)
threads.append(thread)
# Wait till all nodes have completed.
for thread in threads:
thread.join()
# Make sure all nodes were run, otherwise print a warning.
for node in self.nodes:
if self.node_states[node] == NodeStatus.WAITING:
upstream_nodes = self.dag[node]
logger.warning(
f"Node `{node}` was never run, because it was still"
f" waiting for the following nodes: `{upstream_nodes}`."
)
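Illustrative usage: keys map to their upstream nodes, so train waits for load and evaluate waits for train. The run_fn here just logs the node name.

dag = {"load": [], "train": ["load"], "evaluate": ["train"]}

def run_fn(node: str) -> None:
    print(f"Running {node}")

ThreadedDagRunner(dag=dag, run_fn=run_fn).run()
# Prints "Running load", "Running train", "Running evaluate" in order.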
reverse_dag(dag)
Reverse a DAG.
Parameters:
Name | Type | Description | Default
---|---|---|---
dag | Dict[str, List[str]] | Adjacency list representation of a DAG. | required
Returns:
Type | Description
---|---
Dict[str, List[str]] | Adjacency list representation of the reversed DAG.
Source code in zenml/integrations/kubernetes/orchestrators/dag_runner.py
def reverse_dag(dag: Dict[str, List[str]]) -> Dict[str, List[str]]:
"""Reverse a DAG.
Args:
dag: Adjacency list representation of a DAG.
Returns:
Adjacency list representation of the reversed DAG.
"""
reversed_dag = defaultdict(list)
# Reverse all edges in the graph.
for node, upstream_nodes in dag.items():
for upstream_node in upstream_nodes:
reversed_dag[upstream_node].append(node)
# Add nodes without incoming edges back in.
for node in dag:
if node not in reversed_dag:
reversed_dag[node] = []
return reversed_dag
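Illustrative example: every edge is flipped, and nodes without incoming edges in the result are kept with an empty list.

reverse_dag({"a": [], "b": ["a"], "c": ["a", "b"]})
# -> {"a": ["b", "c"], "b": ["c"], "c": []}  (as a defaultdict)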
kube_utils
Utilities for Kubernetes related functions.
Internal interface: no backwards compatibility guarantees. Adjusted from https://github.com/tensorflow/tfx/blob/master/tfx/utils/kube_utils.py.
PodPhase (Enum)
Phase of the Kubernetes pod.
Pod phases are defined in https://kubernetes.io/docs/concepts/workloads/pods/pod-lifecycle/#pod-phase.
Source code in zenml/integrations/kubernetes/orchestrators/kube_utils.py
class PodPhase(enum.Enum):
"""Phase of the Kubernetes pod.
Pod phases are defined in
https://kubernetes.io/docs/concepts/workloads/pods/pod-lifecycle/#pod-phase.
"""
PENDING = "Pending"
RUNNING = "Running"
SUCCEEDED = "Succeeded"
FAILED = "Failed"
UNKNOWN = "Unknown"
create_edit_service_account(core_api, rbac_api, service_account_name, namespace, cluster_role_binding_name='zenml-edit')
Create a new Kubernetes service account with "edit" rights.
Parameters:
Name | Type | Description | Default
---|---|---|---
core_api | CoreV1Api | Client of Core V1 API of Kubernetes API. | required
rbac_api | RbacAuthorizationV1Api | Client of Rbac Authorization V1 API of Kubernetes API. | required
service_account_name | str | Name of the service account. | required
namespace | str | Kubernetes namespace. Defaults to "default". | required
cluster_role_binding_name | str | Name of the cluster role binding. Defaults to "zenml-edit". | 'zenml-edit'
Source code in zenml/integrations/kubernetes/orchestrators/kube_utils.py
def create_edit_service_account(
core_api: k8s_client.CoreV1Api,
rbac_api: k8s_client.RbacAuthorizationV1Api,
service_account_name: str,
namespace: str,
cluster_role_binding_name: str = "zenml-edit",
) -> None:
"""Create a new Kubernetes service account with "edit" rights.
Args:
core_api: Client of Core V1 API of Kubernetes API.
rbac_api: Client of Rbac Authorization V1 API of Kubernetes API.
service_account_name: Name of the service account.
namespace: Kubernetes namespace. Defaults to "default".
cluster_role_binding_name: Name of the cluster role binding.
Defaults to "zenml-edit".
"""
crb_manifest = build_cluster_role_binding_manifest_for_service_account(
name=cluster_role_binding_name,
role_name="edit",
service_account_name=service_account_name,
namespace=namespace,
)
_if_not_exists(rbac_api.create_cluster_role_binding)(body=crb_manifest)
sa_manifest = build_service_account_manifest(
name=service_account_name, namespace=namespace
)
_if_not_exists(core_api.create_namespaced_service_account)(
namespace=namespace,
body=sa_manifest,
)
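Illustrative usage with placeholder names (requires a reachable cluster and a loaded kubeconfig):

from kubernetes import client as k8s_client, config as k8s_config

k8s_config.load_kube_config()
create_edit_service_account(
    core_api=k8s_client.CoreV1Api(),
    rbac_api=k8s_client.RbacAuthorizationV1Api(),
    service_account_name="zenml-service-account",
    namespace="zenml",
)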
create_mysql_deployment(core_api, apps_api, deployment_name, namespace, storage_capacity='10Gi', volume_name='mysql-pv-volume', volume_claim_name='mysql-pv-claim')
Create a Kubernetes deployment with a MySQL database running on it.
Parameters:
Name | Type | Description | Default
---|---|---|---
core_api | CoreV1Api | Client of Core V1 API of Kubernetes API. | required
apps_api | AppsV1Api | Client of Apps V1 API of Kubernetes API. | required
namespace | str | Kubernetes namespace. Defaults to "default". | required
storage_capacity | str | Storage capacity of the database. Defaults to "10Gi". | '10Gi'
deployment_name | str | Name of the deployment. Defaults to "mysql". | required
volume_name | str | Name of the persistent volume. Defaults to "mysql-pv-volume". | 'mysql-pv-volume'
volume_claim_name | str | Name of the persistent volume claim. Defaults to "mysql-pv-claim". | 'mysql-pv-claim'
Source code in zenml/integrations/kubernetes/orchestrators/kube_utils.py
def create_mysql_deployment(
core_api: k8s_client.CoreV1Api,
apps_api: k8s_client.AppsV1Api,
deployment_name: str,
namespace: str,
storage_capacity: str = "10Gi",
volume_name: str = "mysql-pv-volume",
volume_claim_name: str = "mysql-pv-claim",
) -> None:
"""Create a Kubernetes deployment with a MySQL database running on it.
Args:
core_api: Client of Core V1 API of Kubernetes API.
apps_api: Client of Apps V1 API of Kubernetes API.
namespace: Kubernetes namespace. Defaults to "default".
storage_capacity: Storage capacity of the database.
Defaults to `"10Gi"`.
deployment_name: Name of the deployment. Defaults to "mysql".
volume_name: Name of the persistent volume.
Defaults to `"mysql-pv-volume"`.
volume_claim_name: Name of the persistent volume claim.
Defaults to `"mysql-pv-claim"`.
"""
pvc_manifest = build_persistent_volume_claim_manifest(
name=volume_claim_name,
namespace=namespace,
storage_request=storage_capacity,
)
_if_not_exists(core_api.create_namespaced_persistent_volume_claim)(
namespace=namespace,
body=pvc_manifest,
)
pv_manifest = build_persistent_volume_manifest(
name=volume_name, storage_capacity=storage_capacity
)
_if_not_exists(core_api.create_persistent_volume)(body=pv_manifest)
deployment_manifest = build_mysql_deployment_manifest(
name=deployment_name,
namespace=namespace,
pv_claim_name=volume_claim_name,
)
_if_not_exists(apps_api.create_namespaced_deployment)(
body=deployment_manifest, namespace=namespace
)
service_manifest = build_mysql_service_manifest(
name=deployment_name, namespace=namespace
)
_if_not_exists(core_api.create_namespaced_service)(
namespace=namespace, body=service_manifest
)
create_namespace(core_api, namespace)
Create a Kubernetes namespace.
Parameters:
Name | Type | Description | Default
---|---|---|---
core_api | CoreV1Api | Client of Core V1 API of Kubernetes API. | required
namespace | str | Kubernetes namespace. Defaults to "default". | required
Source code in zenml/integrations/kubernetes/orchestrators/kube_utils.py
def create_namespace(core_api: k8s_client.CoreV1Api, namespace: str) -> None:
"""Create a Kubernetes namespace.
Args:
core_api: Client of Core V1 API of Kubernetes API.
namespace: Kubernetes namespace. Defaults to "default".
"""
manifest = build_namespace_manifest(namespace)
_if_not_exists(core_api.create_namespace)(body=manifest)
delete_deployment(apps_api, deployment_name, namespace)
Delete a Kubernetes deployment.
Parameters:
Name | Type | Description | Default
---|---|---|---
apps_api | AppsV1Api | Client of Apps V1 API of Kubernetes API. | required
deployment_name | str | Name of the deployment to be deleted. | required
namespace | str | Kubernetes namespace containing the deployment. | required
Source code in zenml/integrations/kubernetes/orchestrators/kube_utils.py
def delete_deployment(
apps_api: k8s_client.AppsV1Api, deployment_name: str, namespace: str
) -> None:
"""Delete a Kubernetes deployment.
Args:
apps_api: Client of Apps V1 API of Kubernetes API.
deployment_name: Name of the deployment to be deleted.
namespace: Kubernetes namespace containing the deployment.
"""
options = k8s_client.V1DeleteOptions()
apps_api.delete_namespaced_deployment(
name=deployment_name,
namespace=namespace,
body=options,
propagation_policy="Foreground",
)
get_pod(core_api, pod_name, namespace)
Get a pod from Kubernetes metadata API.
Parameters:
Name | Type | Description | Default
---|---|---|---
core_api | CoreV1Api | Client of CoreV1Api of Kubernetes API. | required
pod_name | str | The name of the pod. | required
namespace | str | The namespace of the pod. | required
Exceptions:
Type | Description
---|---
RuntimeError | When it sees unexpected errors from Kubernetes API.
Returns:
Type | Description
---|---
Optional[kubernetes.client.models.v1_pod.V1Pod] | The found pod object. None if it's not found.
Source code in zenml/integrations/kubernetes/orchestrators/kube_utils.py
def get_pod(
core_api: k8s_client.CoreV1Api, pod_name: str, namespace: str
) -> Optional[k8s_client.V1Pod]:
"""Get a pod from Kubernetes metadata API.
Args:
core_api: Client of `CoreV1Api` of Kubernetes API.
pod_name: The name of the pod.
namespace: The namespace of the pod.
Raises:
RuntimeError: When it sees unexpected errors from Kubernetes API.
Returns:
The found pod object. None if it's not found.
"""
try:
return core_api.read_namespaced_pod(name=pod_name, namespace=namespace)
except k8s_client.rest.ApiException as e:
if e.status == 404:
return None
raise RuntimeError from e
is_inside_kubernetes()
Check whether we are inside a Kubernetes cluster or on a remote host.
Returns:
Type | Description |
---|---|
bool | True if inside a Kubernetes cluster, else False. |
Source code in zenml/integrations/kubernetes/orchestrators/kube_utils.py
def is_inside_kubernetes() -> bool:
"""Check whether we are inside a Kubernetes cluster or on a remote host.
Returns:
True if inside a Kubernetes cluster, else False.
"""
try:
k8s_config.load_incluster_config()
return True
except k8s_config.ConfigException:
return False
load_kube_config(context=None)
Load the Kubernetes client config.
Depending on the environment (whether it runs inside the Kubernetes cluster or on a remote host), a different location is searched for the config file.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
context | Optional[str] | Name of the Kubernetes context. If not provided, uses the currently active context. | None |
Source code in zenml/integrations/kubernetes/orchestrators/kube_utils.py
def load_kube_config(context: Optional[str] = None) -> None:
"""Load the Kubernetes client config.
Depending on the environment (whether it is inside the running Kubernetes
cluster or remote host), different location will be searched for the config
file.
Args:
context: Name of the Kubernetes context. If not provided, uses the
currently active context.
"""
try:
k8s_config.load_incluster_config()
except k8s_config.ConfigException:
k8s_config.load_kube_config(context=context)
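`load_kube_config` already tries the in-cluster config first and only then falls back to the kubeconfig file; the sketch below just makes the two paths explicit. The context name is a placeholder.

```python
from zenml.integrations.kubernetes.orchestrators import kube_utils

if kube_utils.is_inside_kubernetes():
    # In-cluster: the pod's service account credentials are used.
    kube_utils.load_kube_config()
else:
    # Remote host: select a named context from the local kubeconfig.
    kube_utils.load_kube_config(context="my-cluster")  # placeholder context
```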
pod_failed(pod)
Check if pod status is 'Failed'.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
pod | V1Pod | Kubernetes pod. | required |
Returns:
Type | Description |
---|---|
bool | True if pod status is 'Failed' else False. |
Source code in zenml/integrations/kubernetes/orchestrators/kube_utils.py
def pod_failed(pod: k8s_client.V1Pod) -> bool:
"""Check if pod status is 'Failed'.
Args:
pod: Kubernetes pod.
Returns:
True if pod status is 'Failed' else False.
"""
return pod.status.phase == PodPhase.FAILED.value # type: ignore[no-any-return]
pod_is_done(pod)
Check if pod status is 'Succeeded'.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
pod | V1Pod | Kubernetes pod. | required |
Returns:
Type | Description |
---|---|
bool | True if pod status is 'Succeeded' else False. |
Source code in zenml/integrations/kubernetes/orchestrators/kube_utils.py
def pod_is_done(pod: k8s_client.V1Pod) -> bool:
"""Check if pod status is 'Succeeded'.
Args:
pod: Kubernetes pod.
Returns:
True if pod status is 'Succeeded' else False.
"""
return pod.status.phase == PodPhase.SUCCEEDED.value # type: ignore[no-any-return]
pod_is_not_pending(pod)
Check if pod status is not 'Pending'.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
pod | V1Pod | Kubernetes pod. | required |
Returns:
Type | Description |
---|---|
bool | False if the pod status is 'Pending' else True. |
Source code in zenml/integrations/kubernetes/orchestrators/kube_utils.py
def pod_is_not_pending(pod: k8s_client.V1Pod) -> bool:
"""Check if pod status is not 'Pending'.
Args:
pod: Kubernetes pod.
Returns:
False if the pod status is 'Pending' else True.
"""
return pod.status.phase != PodPhase.PENDING.value # type: ignore[no-any-return]
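These three phase predicates are primarily meant to be passed to `wait_pod` (below) as `exit_condition_lambda`, but they also work for one-off checks. A sketch, assuming a pod named `example-pod` exists in the `default` namespace:

```python
from kubernetes import client as k8s_client

from zenml.integrations.kubernetes.orchestrators import kube_utils

kube_utils.load_kube_config()
core_api = k8s_client.CoreV1Api()

pod = kube_utils.get_pod(core_api, pod_name="example-pod", namespace="default")
if pod is not None:
    print(kube_utils.pod_is_not_pending(pod))  # True once the pod left 'Pending'
    print(kube_utils.pod_is_done(pod))         # True only in phase 'Succeeded'
    print(kube_utils.pod_failed(pod))          # True only in phase 'Failed'
```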
sanitize_pod_name(pod_name)
Sanitize pod names so they conform to Kubernetes pod naming convention.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
pod_name | str | Arbitrary input pod name. | required |
Returns:
Type | Description |
---|---|
str | Sanitized pod name. |
Source code in zenml/integrations/kubernetes/orchestrators/kube_utils.py
def sanitize_pod_name(pod_name: str) -> str:
"""Sanitize pod names so they conform to Kubernetes pod naming convention.
Args:
pod_name: Arbitrary input pod name.
Returns:
Sanitized pod name.
"""
pod_name = re.sub(r"[^a-z0-9-]", "-", pod_name.lower())
pod_name = re.sub(r"^[-]+", "", pod_name)
return re.sub(r"[-]+", "-", pod_name)
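A worked example of the three regex passes (lowercase and replace, strip leading dashes, collapse dash runs); note that trailing dashes are kept:

```python
from zenml.integrations.kubernetes.orchestrators.kube_utils import sanitize_pod_name

print(sanitize_pod_name("My_Pipeline Run #42"))
# 'my-pipeline-run-42': '_', ' ', and '#' each become '-', then the
# resulting '--' run between 'run' and '42' is collapsed to a single '-'.
```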
wait_pod(core_api, pod_name, namespace, exit_condition_lambda, timeout_sec=0, exponential_backoff=False, stream_logs=False)
Wait for a pod to meet an exit condition.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
core_api | CoreV1Api | Client of CoreV1Api of Kubernetes API. | required |
pod_name | str | The name of the pod. | required |
namespace | str | The namespace of the pod. | required |
exit_condition_lambda | Callable[[kubernetes.client.models.v1_pod.V1Pod], bool] | A lambda which will be called periodically to wait for a pod to exit. The function returns True to exit. | required |
timeout_sec | int | Timeout in seconds to wait for pod to reach exit condition, or 0 to wait for an unlimited duration. Defaults to unlimited. | 0 |
exponential_backoff | bool | Whether to use exponential back off for polling. Defaults to False. | False |
stream_logs | bool | Whether to stream the pod logs to zenml.logger.info(). Defaults to False. | False |
Exceptions:
Type | Description |
---|---|
RuntimeError | When the function times out. |
Returns:
Type | Description |
---|---|
V1Pod | The pod object which meets the exit condition. |
Source code in zenml/integrations/kubernetes/orchestrators/kube_utils.py
def wait_pod(
core_api: k8s_client.CoreV1Api,
pod_name: str,
namespace: str,
exit_condition_lambda: Callable[[k8s_client.V1Pod], bool],
timeout_sec: int = 0,
exponential_backoff: bool = False,
stream_logs: bool = False,
) -> k8s_client.V1Pod:
"""Wait for a pod to meet an exit condition.
Args:
core_api: Client of `CoreV1Api` of Kubernetes API.
pod_name: The name of the pod.
namespace: The namespace of the pod.
exit_condition_lambda: A lambda
which will be called periodically to wait for a pod to exit. The
function returns True to exit.
timeout_sec: Timeout in seconds to wait for pod to reach exit
condition, or 0 to wait for an unlimited duration.
Defaults to unlimited.
exponential_backoff: Whether to use exponential back off for polling.
Defaults to False.
stream_logs: Whether to stream the pod logs to
`zenml.logger.info()`. Defaults to False.
Raises:
RuntimeError: when the function times out.
Returns:
The pod object which meets the exit condition.
"""
start_time = datetime.datetime.utcnow()
# Link to exponential back-off algorithm used here:
# https://cloud.google.com/storage/docs/exponential-backoff
backoff_interval = 1
maximum_backoff = 32
logged_lines = 0
while True:
resp = get_pod(core_api, pod_name, namespace)
# Stream logs to `zenml.logger.info()`.
# TODO: can we do this without parsing all logs every time?
if stream_logs and pod_is_not_pending(resp):
response = core_api.read_namespaced_pod_log(
name=pod_name,
namespace=namespace,
)
logs = response.splitlines()
if len(logs) > logged_lines:
for line in logs[logged_lines:]:
logger.info(line)
logged_lines = len(logs)
# Raise an error if the pod failed.
if pod_failed(resp):
raise RuntimeError(f"Pod `{namespace}:{pod_name}` failed.")
# Check if pod is in desired state (e.g. finished / running / ...).
if exit_condition_lambda(resp):
return resp
# Check if wait timed out.
elapse_time = datetime.datetime.utcnow() - start_time
if elapse_time.seconds >= timeout_sec and timeout_sec != 0:
raise RuntimeError(
f"Waiting for pod `{namespace}:{pod_name}` timed out after "
f"{timeout_sec} seconds."
)
# Wait (using exponential backoff).
time.sleep(backoff_interval)
if exponential_backoff and backoff_interval < maximum_backoff:
backoff_interval *= 2
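A usage sketch, assuming a pod `zenml-run` in the `default` namespace that eventually terminates; `pod_is_done` serves as the exit condition and a ten-minute timeout guards against waiting forever:

```python
from kubernetes import client as k8s_client

from zenml.integrations.kubernetes.orchestrators import kube_utils

kube_utils.load_kube_config()
core_api = k8s_client.CoreV1Api()

pod = kube_utils.wait_pod(
    core_api=core_api,
    pod_name="zenml-run",      # placeholder pod name
    namespace="default",
    exit_condition_lambda=kube_utils.pod_is_done,
    timeout_sec=600,           # raises RuntimeError after 10 minutes
    exponential_backoff=True,  # poll at 1s, 2s, 4s, ... capped at 32s
    stream_logs=True,          # forward the pod's logs to ZenML's logger
)
print(pod.status.phase)  # 'Succeeded'
```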
kubernetes_orchestrator
Kubernetes-native orchestrator.
KubernetesOrchestrator (BaseOrchestrator)
Orchestrator for running ZenML pipelines using native Kubernetes.
Source code in zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator.py
class KubernetesOrchestrator(BaseOrchestrator):
"""Orchestrator for running ZenML pipelines using native Kubernetes."""
_k8s_core_api: k8s_client.CoreV1Api = None
_k8s_batch_api: k8s_client.BatchV1beta1Api = None
_k8s_rbac_api: k8s_client.RbacAuthorizationV1Api = None
def __init__(self, *args: Any, **kwargs: Any) -> None:
"""Initialize the class and the Kubernetes clients.
Args:
*args: The positional arguments to pass to the Pydantic object.
**kwargs: The keyword arguments to pass to the Pydantic object.
"""
super().__init__(*args, **kwargs)
self._initialize_k8s_clients()
def _initialize_k8s_clients(self) -> None:
"""Initialize the Kubernetes clients."""
if self.config.skip_config_loading:
return
kube_utils.load_kube_config(context=self.config.kubernetes_context)
self._k8s_core_api = k8s_client.CoreV1Api()
self._k8s_batch_api = k8s_client.BatchV1beta1Api()
self._k8s_rbac_api = k8s_client.RbacAuthorizationV1Api()
@property
def config(self) -> KubernetesOrchestratorConfig:
"""Returns the `KubernetesOrchestratorConfig` config.
Returns:
The configuration.
"""
return cast(KubernetesOrchestratorConfig, self._config)
def get_kubernetes_contexts(self) -> Tuple[List[str], str]:
"""Get list of configured Kubernetes contexts and the active context.
Raises:
RuntimeError: if the Kubernetes configuration cannot be loaded.
Returns:
context_name: List of configured Kubernetes contexts
active_context_name: Name of the active Kubernetes context.
"""
try:
contexts, active_context = k8s_config.list_kube_config_contexts()
except k8s_config.config_exception.ConfigException as e:
raise RuntimeError(
"Could not load the Kubernetes configuration"
) from e
context_names = [c["name"] for c in contexts]
active_context_name = active_context["name"]
return context_names, active_context_name
@property
def validator(self) -> Optional[StackValidator]:
"""Defines the validator that checks whether the stack is valid.
Returns:
Stack validator.
"""
def _validate_local_requirements(stack: "Stack") -> Tuple[bool, str]:
"""Validates that the stack contains no local components.
Args:
stack: The stack.
Returns:
Whether the stack is valid or not.
An explanation why the stack is invalid, if applicable.
"""
container_registry = stack.container_registry
# should not happen, because the stack validation takes care of
# this, but just in case
assert container_registry is not None
if not self.config.skip_config_loading:
contexts, active_context = self.get_kubernetes_contexts()
if self.config.kubernetes_context not in contexts:
return False, (
f"Could not find a Kubernetes context named "
f"'{self.config.kubernetes_context}' in the local Kubernetes "
f"configuration. Please make sure that the Kubernetes "
f"cluster is running and that the kubeconfig file is "
f"configured correctly. To list all configured "
f"contexts, run:\n\n"
f" `kubectl config get-contexts`\n"
)
if self.config.kubernetes_context != active_context:
logger.warning(
f"The Kubernetes context '{self.config.kubernetes_context}' "
f"configured for the Kubernetes orchestrator is not "
f"the same as the active context in the local "
f"Kubernetes configuration. If this is not deliberate,"
f" you should update the orchestrator's "
f"`kubernetes_context` field by running:\n\n"
f" `zenml orchestrator update {self.name} "
f"--kubernetes_context={active_context}`\n"
f"To list all configured contexts, run:\n\n"
f" `kubectl config get-contexts`\n"
f"To set the active context to be the same as the one "
f"configured in the Kubernetes orchestrator and "
f"silence this warning, run:\n\n"
f" `kubectl config use-context "
f"{self.config.kubernetes_context}`\n"
)
# Check that all stack components are non-local.
for stack_component in stack.components.values():
if stack_component.local_path is not None:
return False, (
f"The Kubernetes orchestrator currently only supports "
f"remote stacks, but the '{stack_component.name}' "
f"{stack_component.type.value} is a local component. "
f"Please make sure to only use non-local stack "
f"components with a Kubernetes orchestrator."
)
# if the orchestrator is remote, the container registry must
# also be remote.
if container_registry.config.is_local:
return False, (
f"The Kubernetes orchestrator requires a remote container "
f"registry, but the '{container_registry.name}' container "
f"registry of your active stack points to a local URI "
f"'{container_registry.config.uri}'. Please make sure "
f"stacks with a Kubernetes orchestrator always contain "
f"remote container registries."
)
return True, ""
return StackValidator(
required_components={StackComponentType.CONTAINER_REGISTRY},
custom_validation_function=_validate_local_requirements,
)
def prepare_pipeline_deployment(
self,
deployment: "PipelineDeployment",
stack: "Stack",
) -> None:
"""Build a Docker image and push it to the container registry.
Args:
deployment: The pipeline deployment configuration.
stack: The stack on which the pipeline will be deployed.
"""
docker_image_builder = PipelineDockerImageBuilder()
repo_digest = docker_image_builder.build_and_push_docker_image(
deployment=deployment, stack=stack
)
deployment.add_extra(ORCHESTRATOR_DOCKER_IMAGE_KEY, repo_digest)
def prepare_or_run_pipeline(
self,
deployment: "PipelineDeployment",
stack: "Stack",
) -> Any:
"""Runs the pipeline in Kubernetes.
Args:
deployment: The pipeline deployment to prepare or run.
stack: The stack the pipeline will run on.
Raises:
RuntimeError: If trying to run from a Jupyter notebook.
"""
# First check whether the code is running in a notebook.
if Environment.in_notebook():
raise RuntimeError(
"The Kubernetes orchestrator cannot run pipelines in a notebook "
"environment. The reason is that it is non-trivial to create "
"a Docker image of a notebook. Please consider refactoring "
"your notebook cells into separate scripts in a Python module "
"and run the code outside of a notebook when using this "
"orchestrator."
)
for step in deployment.steps.values():
if self.requires_resources_in_orchestration_environment(step):
logger.warning(
"Specifying step resources is not yet supported for "
"the Kubernetes orchestrator, ignoring resource "
"configuration for step %s.",
step.config.name,
)
run_name = deployment.run_name
pipeline_name = deployment.pipeline.name
pod_name = kube_utils.sanitize_pod_name(run_name)
# Get Docker image name (for all pods).
image_name = deployment.pipeline.extra[ORCHESTRATOR_DOCKER_IMAGE_KEY]
# Build entrypoint command and args for the orchestrator pod.
# This will internally also build the command/args for all step pods.
command = (
KubernetesOrchestratorEntrypointConfiguration.get_entrypoint_command()
)
args = KubernetesOrchestratorEntrypointConfiguration.get_entrypoint_arguments(
run_name=run_name,
image_name=image_name,
kubernetes_namespace=self.config.kubernetes_namespace,
)
# Authorize pod to run Kubernetes commands inside the cluster.
service_account_name = "zenml-service-account"
kube_utils.create_edit_service_account(
core_api=self._k8s_core_api,
rbac_api=self._k8s_rbac_api,
service_account_name=service_account_name,
namespace=self.config.kubernetes_namespace,
)
# Schedule as CRON job if CRON schedule is given.
if deployment.schedule:
if not deployment.schedule.cron_expression:
raise RuntimeError(
"The Kubernetes orchestrator only supports scheduling via "
"CRON jobs, but the run was configured with a manual "
"schedule. Use `Schedule(cron_expression=...)` instead."
)
cron_expression = deployment.schedule.cron_expression
cron_job_manifest = build_cron_job_manifest(
cron_expression=cron_expression,
run_name=run_name,
pod_name=pod_name,
pipeline_name=pipeline_name,
image_name=image_name,
command=command,
args=args,
service_account_name=service_account_name,
)
self._k8s_batch_api.create_namespaced_cron_job(
body=cron_job_manifest,
namespace=self.config.kubernetes_namespace,
)
logger.info(
f"Scheduling Kubernetes run `{pod_name}` with CRON expression "
f'`"{cron_expression}"`.'
)
return
# Create and run the orchestrator pod.
pod_manifest = build_pod_manifest(
run_name=run_name,
pod_name=pod_name,
pipeline_name=pipeline_name,
image_name=image_name,
command=command,
args=args,
service_account_name=service_account_name,
)
self._k8s_core_api.create_namespaced_pod(
namespace=self.config.kubernetes_namespace,
body=pod_manifest,
)
# Wait for the orchestrator pod to finish and stream logs.
if self.config.synchronous:
logger.info("Waiting for Kubernetes orchestrator pod...")
kube_utils.wait_pod(
core_api=self._k8s_core_api,
pod_name=pod_name,
namespace=self.config.kubernetes_namespace,
exit_condition_lambda=kube_utils.pod_is_done,
stream_logs=True,
)
else:
logger.info(
f"Orchestration started asynchronously in pod "
f"`{self.config.kubernetes_namespace}:{pod_name}`. "
f"Run the following command to inspect the logs: "
f"`kubectl logs {pod_name} -n {self.config.kubernetes_namespace}`."
)
config: KubernetesOrchestratorConfig
property
readonly
Returns the KubernetesOrchestratorConfig config.
Returns:
Type | Description |
---|---|
KubernetesOrchestratorConfig | The configuration. |
validator: Optional[zenml.stack.stack_validator.StackValidator]
property
readonly
Defines the validator that checks whether the stack is valid.
Returns:
Type | Description |
---|---|
Optional[zenml.stack.stack_validator.StackValidator] | Stack validator. |
__init__(self, *args, **kwargs)
special
Initialize the class and the Kubernetes clients.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
*args | Any | The positional arguments to pass to the Pydantic object. | () |
**kwargs | Any | The keyword arguments to pass to the Pydantic object. | {} |
Source code in zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator.py
def __init__(self, *args: Any, **kwargs: Any) -> None:
"""Initialize the class and the Kubernetes clients.
Args:
*args: The positional arguments to pass to the Pydantic object.
**kwargs: The keyword arguments to pass to the Pydantic object.
"""
super().__init__(*args, **kwargs)
self._initialize_k8s_clients()
get_kubernetes_contexts(self)
Get list of configured Kubernetes contexts and the active context.
Exceptions:
Type | Description |
---|---|
RuntimeError | If the Kubernetes configuration cannot be loaded. |
Returns:
Name | Description |
---|---|
context_names | List of configured Kubernetes contexts. |
active_context_name | Name of the active Kubernetes context. |
Source code in zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator.py
def get_kubernetes_contexts(self) -> Tuple[List[str], str]:
"""Get list of configured Kubernetes contexts and the active context.
Raises:
RuntimeError: if the Kubernetes configuration cannot be loaded.
Returns:
context_name: List of configured Kubernetes contexts
active_context_name: Name of the active Kubernetes context.
"""
try:
contexts, active_context = k8s_config.list_kube_config_contexts()
except k8s_config.config_exception.ConfigException as e:
raise RuntimeError(
"Could not load the Kubernetes configuration"
) from e
context_names = [c["name"] for c in contexts]
active_context_name = active_context["name"]
return context_names, active_context_name
prepare_or_run_pipeline(self, deployment, stack)
Runs the pipeline in Kubernetes.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
deployment | PipelineDeployment | The pipeline deployment to prepare or run. | required |
stack | Stack | The stack the pipeline will run on. | required |
Exceptions:
Type | Description |
---|---|
RuntimeError | If trying to run from a Jupyter notebook. |
Source code in zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator.py
def prepare_or_run_pipeline(
self,
deployment: "PipelineDeployment",
stack: "Stack",
) -> Any:
"""Runs the pipeline in Kubernetes.
Args:
deployment: The pipeline deployment to prepare or run.
stack: The stack the pipeline will run on.
Raises:
RuntimeError: If trying to run from a Jupyter notebook.
"""
# First check whether the code is running in a notebook.
if Environment.in_notebook():
raise RuntimeError(
"The Kubernetes orchestrator cannot run pipelines in a notebook "
"environment. The reason is that it is non-trivial to create "
"a Docker image of a notebook. Please consider refactoring "
"your notebook cells into separate scripts in a Python module "
"and run the code outside of a notebook when using this "
"orchestrator."
)
for step in deployment.steps.values():
if self.requires_resources_in_orchestration_environment(step):
logger.warning(
"Specifying step resources is not yet supported for "
"the Kubernetes orchestrator, ignoring resource "
"configuration for step %s.",
step.config.name,
)
run_name = deployment.run_name
pipeline_name = deployment.pipeline.name
pod_name = kube_utils.sanitize_pod_name(run_name)
# Get Docker image name (for all pods).
image_name = deployment.pipeline.extra[ORCHESTRATOR_DOCKER_IMAGE_KEY]
# Build entrypoint command and args for the orchestrator pod.
# This will internally also build the command/args for all step pods.
command = (
KubernetesOrchestratorEntrypointConfiguration.get_entrypoint_command()
)
args = KubernetesOrchestratorEntrypointConfiguration.get_entrypoint_arguments(
run_name=run_name,
image_name=image_name,
kubernetes_namespace=self.config.kubernetes_namespace,
)
# Authorize pod to run Kubernetes commands inside the cluster.
service_account_name = "zenml-service-account"
kube_utils.create_edit_service_account(
core_api=self._k8s_core_api,
rbac_api=self._k8s_rbac_api,
service_account_name=service_account_name,
namespace=self.config.kubernetes_namespace,
)
# Schedule as CRON job if CRON schedule is given.
if deployment.schedule:
if not deployment.schedule.cron_expression:
raise RuntimeError(
"The Kubernetes orchestrator only supports scheduling via "
"CRON jobs, but the run was configured with a manual "
"schedule. Use `Schedule(cron_expression=...)` instead."
)
cron_expression = deployment.schedule.cron_expression
cron_job_manifest = build_cron_job_manifest(
cron_expression=cron_expression,
run_name=run_name,
pod_name=pod_name,
pipeline_name=pipeline_name,
image_name=image_name,
command=command,
args=args,
service_account_name=service_account_name,
)
self._k8s_batch_api.create_namespaced_cron_job(
body=cron_job_manifest,
namespace=self.config.kubernetes_namespace,
)
logger.info(
f"Scheduling Kubernetes run `{pod_name}` with CRON expression "
f'`"{cron_expression}"`.'
)
return
# Create and run the orchestrator pod.
pod_manifest = build_pod_manifest(
run_name=run_name,
pod_name=pod_name,
pipeline_name=pipeline_name,
image_name=image_name,
command=command,
args=args,
service_account_name=service_account_name,
)
self._k8s_core_api.create_namespaced_pod(
namespace=self.config.kubernetes_namespace,
body=pod_manifest,
)
# Wait for the orchestrator pod to finish and stream logs.
if self.config.synchronous:
logger.info("Waiting for Kubernetes orchestrator pod...")
kube_utils.wait_pod(
core_api=self._k8s_core_api,
pod_name=pod_name,
namespace=self.config.kubernetes_namespace,
exit_condition_lambda=kube_utils.pod_is_done,
stream_logs=True,
)
else:
logger.info(
f"Orchestration started asynchronously in pod "
f"`{self.config.kubernetes_namespace}:{pod_name}`. "
f"Run the following command to inspect the logs: "
f"`kubectl logs {pod_name} -n {self.config.kubernetes_namespace}`."
)
prepare_pipeline_deployment(self, deployment, stack)
Build a Docker image and push it to the container registry.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
deployment | PipelineDeployment | The pipeline deployment configuration. | required |
stack | Stack | The stack on which the pipeline will be deployed. | required |
Source code in zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator.py
def prepare_pipeline_deployment(
self,
deployment: "PipelineDeployment",
stack: "Stack",
) -> None:
"""Build a Docker image and push it to the container registry.
Args:
deployment: The pipeline deployment configuration.
stack: The stack on which the pipeline will be deployed.
"""
docker_image_builder = PipelineDockerImageBuilder()
repo_digest = docker_image_builder.build_and_push_docker_image(
deployment=deployment, stack=stack
)
deployment.add_extra(ORCHESTRATOR_DOCKER_IMAGE_KEY, repo_digest)
kubernetes_orchestrator_entrypoint
Entrypoint of the Kubernetes master/orchestrator pod.
main()
Entrypoint of the k8s master/orchestrator pod.
Source code in zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator_entrypoint.py
def main() -> None:
"""Entrypoint of the k8s master/orchestrator pod."""
# Log to the container's stdout so it can be streamed by the client.
logger.info("Kubernetes orchestrator pod started.")
# Parse / extract args.
args = parse_args()
# Get Kubernetes Core API for running kubectl commands later.
kube_utils.load_kube_config()
core_api = k8s_client.CoreV1Api()
# Patch run name (only needed for CRON scheduling)
run_name = patch_run_name_for_cron_scheduling(args.run_name)
config_dict = yaml_utils.read_yaml(DOCKER_IMAGE_DEPLOYMENT_CONFIG_FILE)
deployment_config = PipelineDeployment.parse_obj(config_dict)
pipeline_dag = {}
step_name_to_pipeline_step_name = {}
for name_in_pipeline, step in deployment_config.steps.items():
step_name_to_pipeline_step_name[step.config.name] = name_in_pipeline
pipeline_dag[step.config.name] = step.spec.upstream_steps
step_command = (
KubernetesStepEntrypointConfiguration.get_entrypoint_command()
)
def run_step_on_kubernetes(step_name: str) -> None:
"""Run a pipeline step in a separate Kubernetes pod.
Args:
step_name: Name of the step.
"""
# Define Kubernetes pod name.
pod_name = f"{run_name}-{step_name}"
pod_name = kube_utils.sanitize_pod_name(pod_name)
pipeline_step_name = step_name_to_pipeline_step_name[step_name]
step_args = (
KubernetesStepEntrypointConfiguration.get_entrypoint_arguments(
step_name=pipeline_step_name, run_name=run_name
)
)
# Define Kubernetes pod manifest.
pod_manifest = build_pod_manifest(
pod_name=pod_name,
run_name=run_name,
pipeline_name=deployment_config.pipeline.name,
image_name=args.image_name,
command=step_command,
args=step_args,
)
# Create and run pod.
core_api.create_namespaced_pod(
namespace=args.kubernetes_namespace,
body=pod_manifest,
)
# Wait for pod to finish.
logger.info(f"Waiting for pod of step `{step_name}` to start...")
kube_utils.wait_pod(
core_api=core_api,
pod_name=pod_name,
namespace=args.kubernetes_namespace,
exit_condition_lambda=kube_utils.pod_is_done,
stream_logs=True,
)
logger.info(f"Pod of step `{step_name}` completed.")
ThreadedDagRunner(dag=pipeline_dag, run_fn=run_step_on_kubernetes).run()
logger.info("Orchestration pod completed.")
parse_args()
Parse entrypoint arguments.
Returns:
Type | Description |
---|---|
Namespace | Parsed args. |
Source code in zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator_entrypoint.py
def parse_args() -> argparse.Namespace:
"""Parse entrypoint arguments.
Returns:
Parsed args.
"""
parser = argparse.ArgumentParser()
parser.add_argument("--run_name", type=str, required=True)
parser.add_argument("--image_name", type=str, required=True)
parser.add_argument("--kubernetes_namespace", type=str, required=True)
return parser.parse_args()
patch_run_name_for_cron_scheduling(run_name)
Adjust run name according to the Kubernetes orchestrator pod name.
This is required for scheduling via CRON jobs, since each job would otherwise have the same run name, which ZenML does not support.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
run_name | str | Initial run name. | required |
Returns:
Type | Description |
---|---|
str | New unique run name. |
Source code in zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator_entrypoint.py
def patch_run_name_for_cron_scheduling(run_name: str) -> str:
"""Adjust run name according to the Kubernetes orchestrator pod name.
This is required for scheduling via CRON jobs, since each job would
otherwise have the same run name, which zenml does not support.
Args:
run_name: Initial run name.
Returns:
New unique run name.
"""
# Get name of the orchestrator pod.
host_name = socket.gethostname()
# If we are not running as CRON job, we don't need to do anything.
if host_name == kube_utils.sanitize_pod_name(run_name):
return run_name
# Otherwise, define new run_name.
job_id = host_name.split("-")[-1]
run_name = f"{run_name}-{job_id}"
return run_name
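The string manipulation is easiest to see on a concrete value. Inside a CronJob-spawned pod the hostname carries a job-specific suffix, and that suffix makes the run name unique (the hostname below is hypothetical):

```python
run_name = "zenml-run"
host_name = "zenml-run-27654321"  # hypothetical CronJob pod hostname

# The hostname differs from the sanitized run name, so the last
# dash-separated segment is treated as the job id and appended.
job_id = host_name.split("-")[-1]  # '27654321'
print(f"{run_name}-{job_id}")      # 'zenml-run-27654321'
```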
kubernetes_orchestrator_entrypoint_configuration
Entrypoint configuration for the Kubernetes master/orchestrator pod.
KubernetesOrchestratorEntrypointConfiguration
Entrypoint configuration for the k8s master/orchestrator pod.
Source code in zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator_entrypoint_configuration.py
class KubernetesOrchestratorEntrypointConfiguration:
"""Entrypoint configuration for the k8s master/orchestrator pod."""
@classmethod
def get_entrypoint_options(cls) -> Set[str]:
"""Gets all the options required for running this entrypoint.
Returns:
Entrypoint options.
"""
options = {
RUN_NAME_OPTION,
IMAGE_NAME_OPTION,
NAMESPACE_OPTION,
}
return options
@classmethod
def get_entrypoint_command(cls) -> List[str]:
"""Returns a command that runs the entrypoint module.
Returns:
Entrypoint command.
"""
command = [
"python",
"-m",
"zenml.integrations.kubernetes.orchestrators.kubernetes_orchestrator_entrypoint",
]
return command
@classmethod
def get_entrypoint_arguments(
cls,
run_name: str,
image_name: str,
kubernetes_namespace: str,
) -> List[str]:
"""Gets all arguments that the entrypoint command should be called with.
Args:
run_name: Name of the ZenML run.
image_name: Name of the Docker image.
kubernetes_namespace: Name of the Kubernetes namespace.
Returns:
List of entrypoint arguments.
"""
args = [
f"--{RUN_NAME_OPTION}",
run_name,
f"--{IMAGE_NAME_OPTION}",
image_name,
f"--{NAMESPACE_OPTION}",
kubernetes_namespace,
]
return args
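Combined, the three classmethods yield the full container invocation for the orchestrator pod. A sketch with placeholder values (the printed option names assume the `*_OPTION` constants match the flags parsed by `parse_args` above):

```python
from zenml.integrations.kubernetes.orchestrators.kubernetes_orchestrator_entrypoint_configuration import (
    KubernetesOrchestratorEntrypointConfiguration as EntrypointConfig,
)

command = EntrypointConfig.get_entrypoint_command()
args = EntrypointConfig.get_entrypoint_arguments(
    run_name="my-run",                            # placeholder values
    image_name="gcr.io/my-project/zenml:latest",
    kubernetes_namespace="zenml",
)
print(command + args)
# ['python', '-m', 'zenml.integrations.kubernetes.orchestrators.kubernetes_orchestrator_entrypoint',
#  '--run_name', 'my-run', '--image_name', 'gcr.io/my-project/zenml:latest',
#  '--kubernetes_namespace', 'zenml']
```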
get_entrypoint_arguments(run_name, image_name, kubernetes_namespace)
classmethod
Gets all arguments that the entrypoint command should be called with.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
run_name | str | Name of the ZenML run. | required |
image_name | str | Name of the Docker image. | required |
kubernetes_namespace | str | Name of the Kubernetes namespace. | required |
Returns:
Type | Description |
---|---|
List[str] | List of entrypoint arguments. |
Source code in zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator_entrypoint_configuration.py
@classmethod
def get_entrypoint_arguments(
cls,
run_name: str,
image_name: str,
kubernetes_namespace: str,
) -> List[str]:
"""Gets all arguments that the entrypoint command should be called with.
Args:
run_name: Name of the ZenML run.
image_name: Name of the Docker image.
kubernetes_namespace: Name of the Kubernetes namespace.
Returns:
List of entrypoint arguments.
"""
args = [
f"--{RUN_NAME_OPTION}",
run_name,
f"--{IMAGE_NAME_OPTION}",
image_name,
f"--{NAMESPACE_OPTION}",
kubernetes_namespace,
]
return args
get_entrypoint_command()
classmethod
Returns a command that runs the entrypoint module.
Returns:
Type | Description |
---|---|
List[str] | Entrypoint command. |
Source code in zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator_entrypoint_configuration.py
@classmethod
def get_entrypoint_command(cls) -> List[str]:
"""Returns a command that runs the entrypoint module.
Returns:
Entrypoint command.
"""
command = [
"python",
"-m",
"zenml.integrations.kubernetes.orchestrators.kubernetes_orchestrator_entrypoint",
]
return command
get_entrypoint_options()
classmethod
Gets all the options required for running this entrypoint.
Returns:
Type | Description |
---|---|
Set[str] | Entrypoint options. |
Source code in zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator_entrypoint_configuration.py
@classmethod
def get_entrypoint_options(cls) -> Set[str]:
"""Gets all the options required for running this entrypoint.
Returns:
Entrypoint options.
"""
options = {
RUN_NAME_OPTION,
IMAGE_NAME_OPTION,
NAMESPACE_OPTION,
}
return options
kubernetes_step_entrypoint_configuration
Entrypoint configuration for the Kubernetes worker/step pods.
KubernetesStepEntrypointConfiguration (StepEntrypointConfiguration)
Entrypoint configuration for running steps on Kubernetes.
Source code in zenml/integrations/kubernetes/orchestrators/kubernetes_step_entrypoint_configuration.py
class KubernetesStepEntrypointConfiguration(StepEntrypointConfiguration):
"""Entrypoint configuration for running steps on Kubernetes."""
@classmethod
def get_entrypoint_options(cls) -> Set[str]:
"""Gets all options required for running with this configuration.
Returns:
The superclass options as well as an option for the run name.
"""
return super().get_entrypoint_options() | {RUN_NAME_OPTION}
@classmethod
def get_entrypoint_arguments(cls, **kwargs: Any) -> List[str]:
"""Gets all arguments that the entrypoint command should be called with.
Args:
**kwargs: Kwargs, must include the run name.
Returns:
The superclass arguments as well as arguments for the run name.
"""
return super().get_entrypoint_arguments(**kwargs) + [
f"--{RUN_NAME_OPTION}",
kwargs[RUN_NAME_OPTION],
]
def get_run_name(self, pipeline_name: str) -> Optional[str]:
"""Returns the ZenML run name.
Args:
pipeline_name: Name of the ZenML pipeline (unused).
Returns:
ZenML run name.
"""
job_id: str = self.entrypoint_args[RUN_NAME_OPTION]
return job_id
get_entrypoint_arguments(**kwargs)
classmethod
Gets all arguments that the entrypoint command should be called with.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
**kwargs | Any | Kwargs, must include the run name. | {} |
Returns:
Type | Description |
---|---|
List[str] | The superclass arguments as well as arguments for the run name. |
Source code in zenml/integrations/kubernetes/orchestrators/kubernetes_step_entrypoint_configuration.py
@classmethod
def get_entrypoint_arguments(cls, **kwargs: Any) -> List[str]:
"""Gets all arguments that the entrypoint command should be called with.
Args:
**kwargs: Kwargs, must include the run name.
Returns:
The superclass arguments as well as arguments for the run name.
"""
return super().get_entrypoint_arguments(**kwargs) + [
f"--{RUN_NAME_OPTION}",
kwargs[RUN_NAME_OPTION],
]
get_entrypoint_options()
classmethod
Gets all options required for running with this configuration.
Returns:
Type | Description |
---|---|
Set[str] | The superclass options as well as an option for the run name. |
Source code in zenml/integrations/kubernetes/orchestrators/kubernetes_step_entrypoint_configuration.py
@classmethod
def get_entrypoint_options(cls) -> Set[str]:
"""Gets all options required for running with this configuration.
Returns:
The superclass options as well as an option for the run name.
"""
return super().get_entrypoint_options() | {RUN_NAME_OPTION}
get_run_name(self, pipeline_name)
Returns the ZenML run name.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
pipeline_name | str | Name of the ZenML pipeline (unused). | required |
Returns:
Type | Description |
---|---|
Optional[str] | ZenML run name. |
Source code in zenml/integrations/kubernetes/orchestrators/kubernetes_step_entrypoint_configuration.py
def get_run_name(self, pipeline_name: str) -> Optional[str]:
"""Returns the ZenML run name.
Args:
pipeline_name: Name of the ZenML pipeline (unused).
Returns:
ZenML run name.
"""
job_id: str = self.entrypoint_args[RUN_NAME_OPTION]
return job_id
manifest_utils
Utility functions for building manifests for k8s pods.
build_cluster_role_binding_manifest_for_service_account(name, role_name, service_account_name, namespace='default')
Build a manifest for a cluster role binding of a service account.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
name | str | Name of the cluster role binding. | required |
role_name | str | Name of the role. | required |
service_account_name | str | Name of the service account. | required |
namespace | str | Kubernetes namespace. Defaults to "default". | 'default' |
Returns:
Type | Description |
---|---|
Dict[str, Any] | Manifest for a cluster role binding of a service account. |
Source code in zenml/integrations/kubernetes/orchestrators/manifest_utils.py
def build_cluster_role_binding_manifest_for_service_account(
name: str,
role_name: str,
service_account_name: str,
namespace: str = "default",
) -> Dict[str, Any]:
"""Build a manifest for a cluster role binding of a service account.
Args:
name: Name of the cluster role binding.
role_name: Name of the role.
service_account_name: Name of the service account.
namespace: Kubernetes namespace. Defaults to "default".
Returns:
Manifest for a cluster role binding of a service account.
"""
return {
"apiVersion": "rbac.authorization.k8s.io/v1",
"kind": "ClusterRoleBinding",
"metadata": {"name": name},
"subjects": [
{
"kind": "ServiceAccount",
"name": service_account_name,
"namespace": namespace,
}
],
"roleRef": {
"kind": "ClusterRole",
"name": role_name,
"apiGroup": "rbac.authorization.k8s.io",
},
}
build_cron_job_manifest(cron_expression, pod_name, run_name, pipeline_name, image_name, command, args, service_account_name=None)
Create a manifest for launching a pod as scheduled CRON job.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
cron_expression | str | CRON job schedule expression, e.g. "* * * * *". | required |
pod_name | str | Name of the pod. | required |
run_name | str | Name of the ZenML run. | required |
pipeline_name | str | Name of the ZenML pipeline. | required |
image_name | str | Name of the Docker image. | required |
command | List[str] | Command to execute the entrypoint in the pod. | required |
args | List[str] | Arguments provided to the entrypoint command. | required |
service_account_name | Optional[str] | Optional name of a service account. Can be used to assign certain roles to a pod, e.g., to allow it to run Kubernetes commands from within the cluster. | None |
Returns:
Type | Description |
---|---|
Dict[str, Any] | CRON job manifest. |
Source code in zenml/integrations/kubernetes/orchestrators/manifest_utils.py
def build_cron_job_manifest(
cron_expression: str,
pod_name: str,
run_name: str,
pipeline_name: str,
image_name: str,
command: List[str],
args: List[str],
service_account_name: Optional[str] = None,
) -> Dict[str, Any]:
"""Create a manifest for launching a pod as scheduled CRON job.
Args:
cron_expression: CRON job schedule expression, e.g. "* * * * *".
pod_name: Name of the pod.
run_name: Name of the ZenML run.
pipeline_name: Name of the ZenML pipeline.
image_name: Name of the Docker image.
command: Command to execute the entrypoint in the pod.
args: Arguments provided to the entrypoint command.
service_account_name: Optional name of a service account.
Can be used to assign certain roles to a pod, e.g., to allow it to
run Kubernetes commands from within the cluster.
Returns:
CRON job manifest.
"""
pod_manifest = build_pod_manifest(
pod_name=pod_name,
run_name=run_name,
pipeline_name=pipeline_name,
image_name=image_name,
command=command,
args=args,
service_account_name=service_account_name,
)
return {
"apiVersion": "batch/v1beta1",
"kind": "CronJob",
"metadata": pod_manifest["metadata"],
"spec": {
"schedule": cron_expression,
"jobTemplate": {
"metadata": pod_manifest["metadata"],
"spec": {"template": {"spec": pod_manifest["spec"]}},
},
},
}
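A sketch that schedules a run every five minutes; the image, command, and names are placeholders:

```python
from zenml.integrations.kubernetes.orchestrators.manifest_utils import (
    build_cron_job_manifest,
)

manifest = build_cron_job_manifest(
    cron_expression="*/5 * * * *",  # standard five-field CRON syntax
    pod_name="zenml-scheduled-run",
    run_name="zenml-scheduled-run",
    pipeline_name="my_pipeline",
    image_name="gcr.io/my-project/zenml:latest",  # placeholder image
    command=["python", "-m", "my_entrypoint"],    # placeholder command
    args=[],
    service_account_name="zenml-service-account",
)
print(manifest["spec"]["schedule"])  # '*/5 * * * *'
```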
build_mysql_deployment_manifest(name='mysql', namespace='default', port=3306, pv_claim_name='mysql-pv-claim')
Build a manifest for deploying a MySQL database.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
name | str | Name of the deployment. Defaults to "mysql". | 'mysql' |
namespace | str | Kubernetes namespace. Defaults to "default". | 'default' |
port | int | Port where MySQL is running. Defaults to 3306. | 3306 |
pv_claim_name | str | Name of the required persistent volume claim. Defaults to "mysql-pv-claim". | 'mysql-pv-claim' |
Returns:
Type | Description |
---|---|
Dict[str, Any] | Manifest for deploying a MySQL database. |
Source code in zenml/integrations/kubernetes/orchestrators/manifest_utils.py
def build_mysql_deployment_manifest(
name: str = "mysql",
namespace: str = "default",
port: int = 3306,
pv_claim_name: str = "mysql-pv-claim",
) -> Dict[str, Any]:
"""Build a manifest for deploying a MySQL database.
Args:
name: Name of the deployment. Defaults to "mysql".
namespace: Kubernetes namespace. Defaults to "default".
port: Port where MySQL is running. Defaults to 3306.
pv_claim_name: Name of the required persistent volume claim.
Defaults to `"mysql-pv-claim"`.
Returns:
Manifest for deploying a MySQL database.
"""
return {
"apiVersion": "apps/v1",
"kind": "Deployment",
"metadata": {"name": name, "namespace": namespace},
"spec": {
"selector": {
"matchLabels": {
"app": name,
},
},
"strategy": {
"type": "Recreate",
},
"template": {
"metadata": {
"labels": {"app": name},
},
"spec": {
"containers": [
{
"image": "gcr.io/ml-pipeline/mysql:5.6",
"name": name,
"env": [
{
"name": "MYSQL_ALLOW_EMPTY_PASSWORD",
"value": '"true"',
}
],
"ports": [{"containerPort": port, "name": name}],
"volumeMounts": [
{
"name": "mysql-persistent-storage",
"mountPath": "/var/lib/mysql",
}
],
}
],
"volumes": [
{
"name": "mysql-persistent-storage",
"persistentVolumeClaim": {
"claimName": pv_claim_name
},
}
],
},
},
},
}
build_mysql_service_manifest(name='mysql', namespace='default', port=3306)
Build a manifest for a service relating to a deployed MySQL database.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
name | str | Name of the service. Defaults to "mysql". | 'mysql' |
namespace | str | Kubernetes namespace. Defaults to "default". | 'default' |
port | int | Port where MySQL is running. Defaults to 3306. | 3306 |
Returns:
Type | Description |
---|---|
Dict[str, Any] | Manifest for the MySQL service. |
Source code in zenml/integrations/kubernetes/orchestrators/manifest_utils.py
def build_mysql_service_manifest(
name: str = "mysql",
namespace: str = "default",
port: int = 3306,
) -> Dict[str, Any]:
"""Build a manifest for a service relating to a deployed MySQL database.
Args:
name: Name of the service. Defaults to "mysql".
namespace: Kubernetes namespace. Defaults to "default".
port: Port where MySQL is running. Defaults to 3306.
Returns:
Manifest for the MySQL service.
"""
return {
"apiVersion": "v1",
"kind": "Service",
"metadata": {
"name": name,
"namespace": namespace,
},
"spec": {
"selector": {"app": "mysql"},
"clusterIP": "None",
"ports": [{"port": port}],
},
}
build_namespace_manifest(namespace)
Build the manifest for a new namespace.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
namespace | str | Kubernetes namespace. | required |
Returns:
Type | Description |
---|---|
Dict[str, Any] | Manifest of the new namespace. |
Source code in zenml/integrations/kubernetes/orchestrators/manifest_utils.py
def build_namespace_manifest(namespace: str) -> Dict[str, Any]:
"""Build the manifest for a new namespace.
Args:
namespace: Kubernetes namespace.
Returns:
Manifest of the new namespace.
"""
return {
"apiVersion": "v1",
"kind": "Namespace",
"metadata": {
"name": namespace,
},
}
build_persistent_volume_claim_manifest(name, namespace='default', storage_request='10Gi')
Build a manifest for a persistent volume claim.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
name | str | Name of the persistent volume claim. | required |
namespace | str | Kubernetes namespace. Defaults to "default". | 'default' |
storage_request | str | Size of the storage to request. Defaults to "10Gi". | '10Gi' |
Returns:
Type | Description |
---|---|
Dict[str, Any] | Manifest for a persistent volume claim. |
Source code in zenml/integrations/kubernetes/orchestrators/manifest_utils.py
def build_persistent_volume_claim_manifest(
name: str,
namespace: str = "default",
storage_request: str = "10Gi",
) -> Dict[str, Any]:
"""Build a manifest for a persistent volume claim.
Args:
name: Name of the persistent volume claim.
namespace: Kubernetes namespace. Defaults to "default".
storage_request: Size of the storage to request. Defaults to `"10Gi"`.
Returns:
Manifest for a persistent volume claim.
"""
return {
"apiVersion": "v1",
"kind": "PersistentVolumeClaim",
"metadata": {
"name": name,
"namespace": namespace,
},
"spec": {
"storageClassName": "manual",
"accessModes": ["ReadWriteOnce"],
"resources": {
"requests": {
"storage": storage_request,
}
},
},
}
build_persistent_volume_manifest(name, namespace='default', storage_capacity='10Gi', path='/mnt/data')
Build a manifest for a persistent volume.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
name | str | Name of the persistent volume. | required |
namespace | str | Kubernetes namespace. Defaults to "default". | 'default' |
storage_capacity | str | Storage capacity of the volume. Defaults to "10Gi". | '10Gi' |
path | str | Path where the volume is mounted. Defaults to "/mnt/data". | '/mnt/data' |
Returns:
Type | Description |
---|---|
Dict[str, Any] | Manifest for a persistent volume. |
Source code in zenml/integrations/kubernetes/orchestrators/manifest_utils.py
def build_persistent_volume_manifest(
name: str,
namespace: str = "default",
storage_capacity: str = "10Gi",
path: str = "/mnt/data",
) -> Dict[str, Any]:
"""Build a manifest for a persistent volume.
Args:
name: Name of the persistent volume.
namespace: Kubernetes namespace. Defaults to "default".
storage_capacity: Storage capacity of the volume. Defaults to `"10Gi"`.
path: Path where the volume is mounted. Defaults to `"/mnt/data"`.
Returns:
Manifest for a persistent volume.
"""
return {
"apiVersion": "v1",
"kind": "PersistentVolume",
"metadata": {
"name": name,
"namespace": namespace,
"labels": {"type": "local"},
},
"spec": {
"storageClassName": "manual",
"capacity": {"storage": storage_capacity},
"accessModes": ["ReadWriteOnce"],
"hostPath": {"path": path},
},
}
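The volume and claim manifests are built to pair up: both use the `manual` storage class and `ReadWriteOnce` access mode, so a claim requesting no more than the volume's capacity can bind to it. A minimal sketch using the defaults:

```python
from zenml.integrations.kubernetes.orchestrators.manifest_utils import (
    build_persistent_volume_claim_manifest,
    build_persistent_volume_manifest,
)

# A 10Gi hostPath volume backed by /mnt/data on the node ...
pv = build_persistent_volume_manifest(name="mysql-pv-volume")
# ... and a claim that can bind to it via the shared 'manual' storage class.
pvc = build_persistent_volume_claim_manifest(name="mysql-pv-claim")

assert pv["spec"]["storageClassName"] == pvc["spec"]["storageClassName"]
```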
build_pod_manifest(pod_name, run_name, pipeline_name, image_name, command, args, service_account_name=None)
Build a Kubernetes pod manifest for a ZenML run or step.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
pod_name | str | Name of the pod. | required |
run_name | str | Name of the ZenML run. | required |
pipeline_name | str | Name of the ZenML pipeline. | required |
image_name | str | Name of the Docker image. | required |
command | List[str] | Command to execute the entrypoint in the pod. | required |
args | List[str] | Arguments provided to the entrypoint command. | required |
service_account_name | Optional[str] | Optional name of a service account. Can be used to assign certain roles to a pod, e.g., to allow it to run Kubernetes commands from within the cluster. | None |
Returns:
Type | Description |
---|---|
Dict[str, Any] | Pod manifest. |
Source code in zenml/integrations/kubernetes/orchestrators/manifest_utils.py
def build_pod_manifest(
pod_name: str,
run_name: str,
pipeline_name: str,
image_name: str,
command: List[str],
args: List[str],
service_account_name: Optional[str] = None,
) -> Dict[str, Any]:
"""Build a Kubernetes pod manifest for a ZenML run or step.
Args:
pod_name: Name of the pod.
run_name: Name of the ZenML run.
pipeline_name: Name of the ZenML pipeline.
image_name: Name of the Docker image.
command: Command to execute the entrypoint in the pod.
args: Arguments provided to the entrypoint command.
service_account_name: Optional name of a service account.
Can be used to assign certain roles to a pod, e.g., to allow it to
run Kubernetes commands from within the cluster.
Returns:
Pod manifest.
"""
manifest = {
"apiVersion": "v1",
"kind": "Pod",
"metadata": {
"name": pod_name,
"labels": {
"run": run_name,
"pipeline": pipeline_name,
},
},
"spec": {
"restartPolicy": "Never",
"containers": [
{
"name": "main",
"image": image_name,
"command": command,
"args": args,
"env": [
{
"name": ENV_ZENML_ENABLE_REPO_INIT_WARNINGS,
"value": "False",
}
],
}
],
},
}
if service_account_name is not None:
spec = cast(Dict[str, Any], manifest["spec"]) # mypy stupid
spec["serviceAccountName"] = service_account_name
return manifest
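A sketch that builds and submits a pod manifest, mirroring what the orchestrator does in `prepare_or_run_pipeline`; all names and the image are placeholders:

```python
from kubernetes import client as k8s_client

from zenml.integrations.kubernetes.orchestrators import kube_utils
from zenml.integrations.kubernetes.orchestrators.manifest_utils import (
    build_pod_manifest,
)

manifest = build_pod_manifest(
    pod_name=kube_utils.sanitize_pod_name("My Run"),  # -> 'my-run'
    run_name="My Run",
    pipeline_name="my_pipeline",
    image_name="gcr.io/my-project/zenml:latest",  # placeholder image
    command=["python", "-m", "my_entrypoint"],    # placeholder command
    args=[],
)

kube_utils.load_kube_config()
k8s_client.CoreV1Api().create_namespaced_pod(
    namespace="default", body=manifest
)
```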
build_service_account_manifest(name, namespace='default')
Build the manifest for a service account.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
name | str | Name of the service account. | required |
namespace | str | Kubernetes namespace. Defaults to "default". | 'default' |
Returns:
Type | Description |
---|---|
Dict[str, Any] | Manifest for a service account. |
Source code in zenml/integrations/kubernetes/orchestrators/manifest_utils.py
def build_service_account_manifest(
name: str, namespace: str = "default"
) -> Dict[str, Any]:
"""Build the manifest for a service account.
Args:
name: Name of the service account.
namespace: Kubernetes namespace. Defaults to "default".
Returns:
Manifest for a service account.
"""
    return {
        "apiVersion": "v1",
        "kind": "ServiceAccount",
        "metadata": {
            "name": name,
            "namespace": namespace,
        },
    }
label_studio
special
Initialization of the Label Studio integration.
LabelStudioIntegration (Integration)
Definition of Label Studio integration for ZenML.
Source code in zenml/integrations/label_studio/__init__.py
class LabelStudioIntegration(Integration):
"""Definition of Label Studio integration for ZenML."""
NAME = LABEL_STUDIO
REQUIREMENTS = ["label-studio", "label-studio-sdk"]
@classmethod
def flavors(cls) -> List[Type[Flavor]]:
"""Declare the stack component flavors for the Label Studio integration.
Returns:
List of stack component flavors for this integration.
"""
from zenml.integrations.label_studio.flavors import (
LabelStudioAnnotatorFlavor,
)
return [LabelStudioAnnotatorFlavor]
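A quick sketch of inspecting the declared flavors (assuming the integration's requirements are installed, since the flavor module is imported lazily):

```python
from zenml.integrations.label_studio import LabelStudioIntegration

for flavor_class in LabelStudioIntegration.flavors():
    print(flavor_class().name)  # the Label Studio annotator flavor name
```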
flavors()
classmethod
Declare the stack component flavors for the Label Studio integration.
Returns:
Type | Description |
---|---|
List[Type[zenml.stack.flavor.Flavor]] | List of stack component flavors for this integration. |
Source code in zenml/integrations/label_studio/__init__.py
@classmethod
def flavors(cls) -> List[Type[Flavor]]:
"""Declare the stack component flavors for the Label Studio integration.
Returns:
List of stack component flavors for this integration.
"""
from zenml.integrations.label_studio.flavors import (
LabelStudioAnnotatorFlavor,
)
return [LabelStudioAnnotatorFlavor]
annotators
special
Initialization of the Label Studio annotators submodule.
label_studio_annotator
Implementation of the Label Studio annotation integration.
LabelStudioAnnotator (BaseAnnotator, AuthenticationMixin)
Class to interact with the Label Studio annotation interface.
Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
class LabelStudioAnnotator(BaseAnnotator, AuthenticationMixin):
"""Class to interact with the Label Studio annotation interface."""
@property
def config(self) -> LabelStudioAnnotatorConfig:
"""Returns the `LabelStudioAnnotatorConfig` config.
Returns:
The configuration.
"""
return cast(LabelStudioAnnotatorConfig, self._config)
@property
def validator(self) -> Optional["StackValidator"]:
"""Validates that the stack contains a cloud artifact store.
Returns:
StackValidator: Validator for the stack.
"""
def _ensure_cloud_artifact_stores(stack: Stack) -> Tuple[bool, str]:
# For now this only works on cloud artifact stores.
return (
stack.artifact_store.flavor
in [
AZURE_ARTIFACT_STORE_FLAVOR,
GCP_ARTIFACT_STORE_FLAVOR,
S3_ARTIFACT_STORE_FLAVOR,
],
"Only cloud artifact stores are currently supported",
)
return StackValidator(
required_components={StackComponentType.SECRETS_MANAGER},
custom_validation_function=_ensure_cloud_artifact_stores,
)
def get_url(self) -> str:
"""Gets the top-level URL of the annotation interface.
Returns:
The URL of the annotation interface.
"""
return f"http://localhost:{self.config.port}"
def get_url_for_dataset(self, dataset_name: str) -> str:
"""Gets the URL of the annotation interface for the given dataset.
Args:
dataset_name: The name of the dataset.
Returns:
The URL of the annotation interface.
"""
project_id = self.get_id_from_name(dataset_name)
return f"{self.get_url()}/projects/{project_id}/"
def get_id_from_name(self, dataset_name: str) -> Optional[int]:
"""Gets the ID of the given dataset.
Args:
dataset_name: The name of the dataset.
Returns:
The ID of the dataset.
"""
projects = self.get_datasets()
for project in projects:
if project.get_params()["title"] == dataset_name:
return cast(int, project.get_params()["id"])
return None
def get_datasets(self) -> List[Any]:
"""Gets the datasets currently available for annotation.
Returns:
A list of datasets.
"""
datasets = self._get_client().get_projects()
return cast(List[Any], datasets)
def get_dataset_names(self) -> List[str]:
"""Gets the names of the datasets.
Returns:
A list of dataset names.
"""
return [
dataset.get_params()["title"] for dataset in self.get_datasets()
]
def get_dataset_stats(self, dataset_name: str) -> Tuple[int, int]:
"""Gets the statistics of the given dataset.
Args:
dataset_name: The name of the dataset.
Returns:
A tuple containing (labeled_task_count, unlabeled_task_count) for
the dataset.
Raises:
IndexError: If the dataset does not exist.
"""
for project in self.get_datasets():
if dataset_name in project.get_params()["title"]:
labeled_task_count = len(project.get_labeled_tasks())
unlabeled_task_count = len(project.get_unlabeled_tasks())
return (labeled_task_count, unlabeled_task_count)
raise IndexError(
f"Dataset {dataset_name} not found. Please use "
f"`zenml annotator dataset list` to list all available datasets."
)
def launch(self, url: Optional[str]) -> None:
"""Launches the annotation interface.
Args:
url: The URL of the annotation interface.
"""
if not url:
url = self.get_url()
if self._connection_available():
webbrowser.open(url, new=1, autoraise=True)
else:
logger.warning(
"Could not launch annotation interface"
"because the connection could not be established."
)
def _get_client(self) -> Client:
"""Gets Label Studio client.
Returns:
Label Studio client.
Raises:
ValueError: when unable to access the Label Studio API key.
"""
secret = self.get_authentication_secret(ArbitrarySecretSchema)
if not secret:
raise ValueError(
f"Unable to access predefined secret '{secret}' to access Label Studio API key."
)
api_key = secret.content["api_key"]
return Client(url=self.get_url(), api_key=api_key)
def _connection_available(self) -> bool:
"""Checks if the connection to the annotation server is available.
Returns:
True if the connection is available, False otherwise.
"""
try:
result = self._get_client().check_connection()
return result.get("status") == "UP" # type: ignore[no-any-return]
# TODO: [HIGH] refactor to use a more specific exception
except Exception:
logger.error(
"Connection error: No connection was able to be established to the Label Studio backend."
)
return False
def add_dataset(self, **kwargs: Any) -> Any:
"""Registers a dataset for annotation.
Args:
**kwargs: Additional keyword arguments to pass to the Label Studio client.
Returns:
A Label Studio Project object.
Raises:
ValueError: if 'dataset_name' and 'label_config' aren't provided.
"""
dataset_name = kwargs.get("dataset_name")
label_config = kwargs.get("label_config")
if not dataset_name:
raise ValueError("`dataset_name` keyword argument is required.")
elif not label_config:
raise ValueError("`label_config` keyword argument is required.")
return self._get_client().start_project(
title=dataset_name,
label_config=label_config,
)
def delete_dataset(self, **kwargs: Any) -> None:
"""Deletes a dataset from the annotation interface.
Args:
**kwargs: Additional keyword arguments to pass to the Label Studio
client.
Raises:
NotImplementedError: If the deletion of a dataset is not supported.
"""
raise NotImplementedError("Awaiting Label Studio release.")
# TODO: Awaiting a new Label Studio version to be released with this method
# ls = self._get_client()
# dataset_name = kwargs.get("dataset_name")
# if not dataset_name:
# raise ValueError("`dataset_name` keyword argument is required.")
# dataset_id = self.get_id_from_name(dataset_name)
# if not dataset_id:
# raise ValueError(
# f"Dataset name '{dataset_name}' has no corresponding `dataset_id` in Label Studio."
# )
# ls.delete_project(dataset_id)
def get_dataset(self, **kwargs: Any) -> Any:
"""Gets the dataset with the given name.
Args:
**kwargs: Additional keyword arguments to pass to the Label Studio client.
Returns:
The LabelStudio Dataset object (a 'Project') for the given name.
Raises:
ValueError: If the dataset name is not provided or if the dataset
does not exist.
"""
# TODO: check for and raise error if client unavailable
dataset_name = kwargs.get("dataset_name")
if not dataset_name:
raise ValueError("`dataset_name` keyword argument is required.")
dataset_id = self.get_id_from_name(dataset_name)
if not dataset_id:
raise ValueError(
f"Dataset name '{dataset_name}' has no corresponding `dataset_id` in Label Studio."
)
return self._get_client().get_project(dataset_id)
def get_converted_dataset(
self, dataset_name: str, output_format: str
) -> Dict[Any, Any]:
"""Extract annotated tasks in a specific converted format.
Args:
dataset_name: Name of the dataset.
output_format: Output format.
Returns:
A dictionary containing the converted dataset.
"""
project = self.get_dataset(dataset_name=dataset_name)
return project.export_tasks(export_type=output_format) # type: ignore[no-any-return]
def get_labeled_data(self, **kwargs: Any) -> Any:
"""Gets the labeled data for the given dataset.
Args:
**kwargs: Additional keyword arguments to pass to the Label Studio client.
Returns:
The labeled data.
Raises:
ValueError: If the dataset name is not provided or if the dataset
does not exist.
"""
dataset_name = kwargs.get("dataset_name")
if not dataset_name:
raise ValueError("`dataset_name` keyword argument is required.")
dataset_id = self.get_id_from_name(dataset_name)
if not dataset_id:
raise ValueError(
f"Dataset name '{dataset_name}' has no corresponding `dataset_id` in Label Studio."
)
return self._get_client().get_project(dataset_id).get_labeled_tasks()
def get_unlabeled_data(self, **kwargs: str) -> Any:
"""Gets the unlabeled data for the given dataset.
Args:
**kwargs: Additional keyword arguments to pass to the Label Studio client.
Returns:
The unlabeled data.
Raises:
ValueError: If the dataset name is not provided or if the dataset
does not exist.
"""
dataset_name = kwargs.get("dataset_name")
if not dataset_name:
raise ValueError("`dataset_name` keyword argument is required.")
dataset_id = self.get_id_from_name(dataset_name)
if not dataset_id:
raise ValueError(
f"Dataset name '{dataset_name}' has no corresponding `dataset_id` in Label Studio."
)
return self._get_client().get_project(dataset_id).get_unlabeled_tasks()
def register_dataset_for_annotation(
self,
params: LabelStudioDatasetRegistrationParameters,
) -> Any:
"""Registers a dataset for annotation.
Args:
params: Parameters for the dataset.
Returns:
A Label Studio Project object.
"""
project_id = self.get_id_from_name(params.dataset_name)
if project_id:
dataset = self._get_client().get_project(project_id)
else:
dataset = self.add_dataset(
dataset_name=params.dataset_name,
label_config=params.label_config,
)
return dataset
def _get_azure_import_storage_sources(
self, dataset_id: int
) -> List[Dict[str, Any]]:
"""Gets a list of all Azure import storage sources.
Args:
dataset_id: Id of the dataset.
Returns:
A list of Azure import storage sources.
Raises:
ConnectionError: If the connection to the Label Studio backend is unavailable.
"""
# TODO: check if client actually is connected etc
query_url = f"/api/storages/azure?project={dataset_id}"
response = self._get_client().make_request(method="GET", url=query_url)
if response.status_code == 200:
return cast(List[Dict[str, Any]], response.json())
else:
raise ConnectionError(
f"Unable to get list of import storage sources. Client raised HTTP error {response.status_code}."
)
def _get_gcs_import_storage_sources(
self, dataset_id: int
) -> List[Dict[str, Any]]:
"""Gets a list of all Google Cloud Storage import storage sources.
Args:
dataset_id: Id of the dataset.
Returns:
A list of Google Cloud Storage import storage sources.
Raises:
ConnectionError: If the connection to the Label Studio backend is unavailable.
"""
# TODO: check if client actually is connected etc
query_url = f"/api/storages/gcs?project={dataset_id}"
response = self._get_client().make_request(method="GET", url=query_url)
if response.status_code == 200:
return cast(List[Dict[str, Any]], response.json())
else:
raise ConnectionError(
f"Unable to get list of import storage sources. Client raised HTTP error {response.status_code}."
)
def _get_s3_import_storage_sources(
self, dataset_id: int
) -> List[Dict[str, Any]]:
"""Gets a list of all AWS S3 import storage sources.
Args:
dataset_id: Id of the dataset.
Returns:
A list of AWS S3 import storage sources.
Raises:
ConnectionError: If the connection to the Label Studio backend is unavailable.
"""
# TODO: check if client actually is connected etc
query_url = f"/api/storages/s3?project={dataset_id}"
response = self._get_client().make_request(method="GET", url=query_url)
if response.status_code == 200:
return cast(List[Dict[str, Any]], response.json())
else:
raise ConnectionError(
f"Unable to get list of import storage sources. Client raised HTTP error {response.status_code}."
)
def _storage_source_already_exists(
self,
uri: str,
params: LabelStudioDatasetSyncParameters,
dataset: Project,
) -> bool:
"""Returns whether a storage source already exists.
Args:
uri: URI of the storage source.
params: Parameters for the dataset.
dataset: Label Studio dataset.
Returns:
True if the storage source already exists, False otherwise.
Raises:
NotImplementedError: If the storage source type is not supported.
"""
# TODO: check we are already connected
dataset_id = int(dataset.get_params()["id"])
if params.storage_type == "azure":
storage_sources = self._get_azure_import_storage_sources(dataset_id)
elif params.storage_type == "gcs":
storage_sources = self._get_gcs_import_storage_sources(dataset_id)
elif params.storage_type == "s3":
storage_sources = self._get_s3_import_storage_sources(dataset_id)
else:
raise NotImplementedError(
f"Storage type '{params.storage_type}' not implemented."
)
return any(
(
source.get("presign") == params.presign
and source.get("bucket") == uri
and source.get("regex_filter") == params.regex_filter
and source.get("use_blob_urls") == params.use_blob_urls
and source.get("title") == dataset.get_params()["title"]
and source.get("description") == params.description
and source.get("presign_ttl") == params.presign_ttl
and source.get("project") == dataset_id
)
for source in storage_sources
)
def get_parsed_label_config(self, dataset_id: int) -> Dict[str, Any]:
"""Returns the parsed Label Studio label config for a dataset.
Args:
dataset_id: Id of the dataset.
Returns:
A dictionary containing the parsed label config.
Raises:
ValueError: If no dataset is found for the given id.
"""
# TODO: check if client actually is connected etc
dataset = self._get_client().get_project(dataset_id)
if dataset:
return cast(Dict[str, Any], dataset.parsed_label_config)
raise ValueError("No dataset found for the given id.")
def connect_and_sync_external_storage(
self,
uri: str,
params: LabelStudioDatasetSyncParameters,
dataset: Project,
) -> Optional[Dict[str, Any]]:
"""Syncs the external storage for the given project.
Args:
uri: URI of the storage source.
params: Parameters for the dataset.
dataset: Label Studio dataset.
Returns:
A dictionary containing the sync result.
Raises:
ValueError: If the storage type is not supported.
"""
# TODO: check if proposed storage source has differing / new data
# if self._storage_source_already_exists(uri, config, dataset):
# return None
storage_connection_args = {
"prefix": params.prefix,
"regex_filter": params.regex_filter,
"use_blob_urls": params.use_blob_urls,
"presign": params.presign,
"presign_ttl": params.presign_ttl,
"title": dataset.get_params()["title"],
"description": params.description,
}
if params.storage_type == "azure":
if not params.azure_account_name or not params.azure_account_key:
logger.warning(
"Authentication credentials for Azure aren't fully "
"provided. Please update the storage synchronization "
"settings in the Label Studio web UI as per your needs."
)
storage = dataset.connect_azure_import_storage(
container=uri,
account_name=params.azure_account_name,
account_key=params.azure_account_key,
**storage_connection_args,
)
elif params.storage_type == "gcs":
if not params.google_application_credentials:
logger.warning(
"Authentication credentials for Google Cloud Storage "
"aren't fully provided. Please update the storage "
"synchronization settings in the Label Studio web UI as "
"per your needs."
)
storage = dataset.connect_google_import_storage(
bucket=uri,
google_application_credentials=params.google_application_credentials,
**storage_connection_args,
)
elif params.storage_type == "s3":
if not params.aws_access_key_id or not params.aws_secret_access_key:
logger.warning(
"Authentication credentials for S3 aren't fully provided."
"Please update the storage synchronization settings in the "
" Label Studio web UI as per your needs."
)
storage = dataset.connect_s3_import_storage(
bucket=uri,
aws_access_key_id=params.aws_access_key_id,
aws_secret_access_key=params.aws_secret_access_key,
aws_session_token=params.aws_session_token,
region_name=params.s3_region_name,
s3_endpoint=params.s3_endpoint,
**storage_connection_args,
)
else:
raise ValueError(
f"Invalid storage type. '{params.storage_type}' is not supported by ZenML's Label Studio integration. Please choose between 'azure', 'gcs' and 'aws'."
)
synced_storage = self._get_client().sync_storage(
storage_id=storage["id"], storage_type=storage["type"]
)
return cast(Dict[str, Any], synced_storage)
@property
def root_directory(self) -> str:
"""Returns path to the root directory.
Returns:
Path to the root directory.
"""
return os.path.join(
io_utils.get_global_config_directory(),
"annotators",
str(self.id),
)
@property
def _pid_file_path(self) -> str:
"""Returns path to the daemon PID file.
Returns:
Path to the daemon PID file.
"""
return os.path.join(self.root_directory, "label_studio_daemon.pid")
@property
def _log_file(self) -> str:
"""Path of the daemon log file.
Returns:
Path to the daemon log file.
"""
return os.path.join(self.root_directory, "label_studio_daemon.log")
@property
def is_provisioned(self) -> bool:
"""If the component provisioned resources to run locally.
Returns:
True if the component provisioned resources to run locally.
"""
return fileio.exists(self.root_directory)
@property
def is_running(self) -> bool:
"""If the component is running locally.
Returns:
True if the component is running locally, False otherwise.
"""
if sys.platform != "win32":
from zenml.utils.daemon import check_if_daemon_is_running
if not check_if_daemon_is_running(self._pid_file_path):
return False
else:
# Daemon functionality is not supported on Windows, so the PID
# file won't exist. This if clause exists just for mypy to not
# complain about missing functions
pass
return True
def provision(self) -> None:
"""Spins up the annotation server backend."""
fileio.makedirs(self.root_directory)
def deprovision(self) -> None:
"""Spins down the annotation server backend."""
if fileio.exists(self._log_file):
fileio.remove(self._log_file)
def resume(self) -> None:
"""Resumes the annotation interface."""
if self.is_running:
logger.info("Local kubeflow pipelines deployment already running.")
return
self.start_annotator_daemon()
def suspend(self) -> None:
"""Suspends the annotation interface."""
if not self.is_running:
logger.info("Local annotation server is not running.")
return
self.stop_annotator_daemon()
def start_annotator_daemon(self) -> None:
"""Starts the annotation server backend.
Raises:
ProvisioningError: If the annotation server backend is already
running or the port is already occupied.
"""
command = [
"label-studio",
"start",
"--no-browser",
"--port",
f"{self.config.port}",
]
if sys.platform == "win32":
logger.warning(
"Daemon functionality not supported on Windows. "
"In order to access the Label Studio server locally, "
"please run '%s' in a separate command line shell.",
" ".join(command),
)
elif not networking_utils.port_available(self.config.port):
raise ProvisioningError(
f"Unable to port-forward Label Studio to local "
f"port {self.config.port} because the port is occupied. In order to "
f"access Label Studio locally, please "
f"change the configuration to use an available "
f"port or stop the other process currently using the port."
)
else:
from zenml.utils import daemon
def _daemon_function() -> None:
"""Forwards the port of the Kubeflow Pipelines Metadata pod ."""
subprocess.check_call(command)
daemon.run_as_daemon(
_daemon_function,
pid_file=self._pid_file_path,
log_file=self._log_file,
)
logger.info(
"Started Label Studio daemon (check the daemon"
"logs at `%s` in case you're not able to access the annotation "
f"interface). Please visit `{self.get_url()}/` to use the Label Studio interface.",
self._log_file,
)
def stop_annotator_daemon(self) -> None:
"""Stops the annotation server backend."""
if fileio.exists(self._pid_file_path):
if sys.platform == "win32":
# Daemon functionality is not supported on Windows, so the PID
# file won't exist. This if clause exists just for mypy to not
# complain about missing functions
pass
else:
from zenml.utils import daemon
daemon.stop_daemon(self._pid_file_path)
fileio.remove(self._pid_file_path)
config: LabelStudioAnnotatorConfig
property
readonly
Returns the LabelStudioAnnotatorConfig config.
Returns:
Type | Description |
---|---|
LabelStudioAnnotatorConfig |
The configuration. |
is_provisioned: bool
property
readonly
If the component provisioned resources to run locally.
Returns:
Type | Description |
---|---|
bool |
True if the component provisioned resources to run locally. |
is_running: bool
property
readonly
If the component is running locally.
Returns:
Type | Description |
---|---|
bool |
True if the component is running locally, False otherwise. |
root_directory: str
property
readonly
Returns path to the root directory.
Returns:
Type | Description |
---|---|
str |
Path to the root directory. |
validator: Optional[StackValidator]
property
readonly
Validates that the stack contains a cloud artifact store.
Returns:
Type | Description |
---|---|
StackValidator |
Validator for the stack. |
add_dataset(self, **kwargs)
Registers a dataset for annotation.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
**kwargs |
Any |
Additional keyword arguments to pass to the Label Studio client. |
{} |
Returns:
Type | Description |
---|---|
Any |
A Label Studio Project object. |
Exceptions:
Type | Description |
---|---|
ValueError |
if 'dataset_name' or 'label_config' isn't provided. |
Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def add_dataset(self, **kwargs: Any) -> Any:
"""Registers a dataset for annotation.
Args:
**kwargs: Additional keyword arguments to pass to the Label Studio client.
Returns:
A Label Studio Project object.
Raises:
ValueError: if 'dataset_name' or 'label_config' isn't provided.
"""
dataset_name = kwargs.get("dataset_name")
label_config = kwargs.get("label_config")
if not dataset_name:
raise ValueError("`dataset_name` keyword argument is required.")
elif not label_config:
raise ValueError("`label_config` keyword argument is required.")
return self._get_client().start_project(
title=dataset_name,
label_config=label_config,
)
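A minimal usage sketch of add_dataset from inside a ZenML step. The step name, dataset name and label config below are hypothetical, and the sketch assumes the active stack contains a Label Studio annotator (import paths as in ZenML releases of this generation):
from zenml.steps import StepContext, step

@step(enable_cache=False)
def create_annotation_project(context: StepContext) -> str:
    """Creates a (hypothetical) text classification project."""
    annotator = context.stack.annotator  # assumes a Label Studio annotator
    label_config = """
    <View>
      <Text name="text" value="$text"/>
      <Choices name="sentiment" toName="text">
        <Choice value="Positive"/>
        <Choice value="Negative"/>
      </Choices>
    </View>
    """
    project = annotator.add_dataset(
        dataset_name="sentiment_dataset",
        label_config=label_config,
    )
    return str(project.get_params()["title"])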
connect_and_sync_external_storage(self, uri, params, dataset)
Syncs the external storage for the given project.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
uri |
str |
URI of the storage source. |
required |
params |
LabelStudioDatasetSyncParameters |
Parameters for the dataset. |
required |
dataset |
Project |
Label Studio dataset. |
required |
Returns:
Type | Description |
---|---|
Optional[Dict[str, Any]] |
A dictionary containing the sync result. |
Exceptions:
Type | Description |
---|---|
ValueError |
If the storage type is not supported. |
Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def connect_and_sync_external_storage(
self,
uri: str,
params: LabelStudioDatasetSyncParameters,
dataset: Project,
) -> Optional[Dict[str, Any]]:
"""Syncs the external storage for the given project.
Args:
uri: URI of the storage source.
params: Parameters for the dataset.
dataset: Label Studio dataset.
Returns:
A dictionary containing the sync result.
Raises:
ValueError: If the storage type is not supported.
"""
# TODO: check if proposed storage source has differing / new data
# if self._storage_source_already_exists(uri, config, dataset):
# return None
storage_connection_args = {
"prefix": params.prefix,
"regex_filter": params.regex_filter,
"use_blob_urls": params.use_blob_urls,
"presign": params.presign,
"presign_ttl": params.presign_ttl,
"title": dataset.get_params()["title"],
"description": params.description,
}
if params.storage_type == "azure":
if not params.azure_account_name or not params.azure_account_key:
logger.warning(
"Authentication credentials for Azure aren't fully "
"provided. Please update the storage synchronization "
"settings in the Label Studio web UI as per your needs."
)
storage = dataset.connect_azure_import_storage(
container=uri,
account_name=params.azure_account_name,
account_key=params.azure_account_key,
**storage_connection_args,
)
elif params.storage_type == "gcs":
if not params.google_application_credentials:
logger.warning(
"Authentication credentials for Google Cloud Storage "
"aren't fully provided. Please update the storage "
"synchronization settings in the Label Studio web UI as "
"per your needs."
)
storage = dataset.connect_google_import_storage(
bucket=uri,
google_application_credentials=params.google_application_credentials,
**storage_connection_args,
)
elif params.storage_type == "s3":
if not params.aws_access_key_id or not params.aws_secret_access_key:
logger.warning(
"Authentication credentials for S3 aren't fully provided."
"Please update the storage synchronization settings in the "
" Label Studio web UI as per your needs."
)
storage = dataset.connect_s3_import_storage(
bucket=uri,
aws_access_key_id=params.aws_access_key_id,
aws_secret_access_key=params.aws_secret_access_key,
aws_session_token=params.aws_session_token,
region_name=params.s3_region_name,
s3_endpoint=params.s3_endpoint,
**storage_connection_args,
)
else:
raise ValueError(
f"Invalid storage type. '{params.storage_type}' is not supported by ZenML's Label Studio integration. Please choose between 'azure', 'gcs' and 'aws'."
)
synced_storage = self._get_client().sync_storage(
storage_id=storage["id"], storage_type=storage["type"]
)
return cast(Dict[str, Any], synced_storage)
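A sketch of connecting and syncing an S3 bucket, assuming `annotator` and `dataset` come from your active stack and a prior get_dataset call; the bucket, region and config type are placeholders:
from typing import Any, Dict, Optional

from zenml.integrations.label_studio.annotators import LabelStudioAnnotator
from zenml.integrations.label_studio.steps.label_studio_standard_steps import (
    LabelStudioDatasetSyncParameters,
)

def sync_s3_bucket(
    annotator: LabelStudioAnnotator, dataset: Any, bucket: str
) -> Optional[Dict[str, Any]]:
    """Connects a bucket as import storage and triggers a sync (sketch)."""
    params = LabelStudioDatasetSyncParameters(
        storage_type="s3",
        label_config_type="image_classification",  # placeholder task type
        s3_region_name="us-east-1",  # placeholder region
    )
    # Credentials are left unset here, so the method logs a warning and
    # expects them to be completed in the Label Studio web UI.
    return annotator.connect_and_sync_external_storage(
        uri=bucket, params=params, dataset=dataset
    )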
delete_dataset(self, **kwargs)
Deletes a dataset from the annotation interface.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
**kwargs |
Any |
Additional keyword arguments to pass to the Label Studio client. |
{} |
Exceptions:
Type | Description |
---|---|
NotImplementedError |
If the deletion of a dataset is not supported. |
Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def delete_dataset(self, **kwargs: Any) -> None:
"""Deletes a dataset from the annotation interface.
Args:
**kwargs: Additional keyword arguments to pass to the Label Studio
client.
Raises:
NotImplementedError: If the deletion of a dataset is not supported.
"""
raise NotImplementedError("Awaiting Label Studio release.")
# TODO: Awaiting a new Label Studio version to be released with this method
# ls = self._get_client()
# dataset_name = kwargs.get("dataset_name")
# if not dataset_name:
# raise ValueError("`dataset_name` keyword argument is required.")
# dataset_id = self.get_id_from_name(dataset_name)
# if not dataset_id:
# raise ValueError(
# f"Dataset name '{dataset_name}' has no corresponding `dataset_id` in Label Studio."
# )
# ls.delete_project(dataset_id)
deprovision(self)
Spins down the annotation server backend.
Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def deprovision(self) -> None:
"""Spins down the annotation server backend."""
if fileio.exists(self._log_file):
fileio.remove(self._log_file)
get_converted_dataset(self, dataset_name, output_format)
Extract annotated tasks in a specific converted format.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
dataset_name |
str |
Name of the dataset. |
required |
output_format |
str |
Output format. |
required |
Returns:
Type | Description |
---|---|
Dict[Any, Any] |
A dictionary containing the converted dataset. |
Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def get_converted_dataset(
self, dataset_name: str, output_format: str
) -> Dict[Any, Any]:
"""Extract annotated tasks in a specific converted format.
Args:
dataset_name: Name of the dataset.
output_format: Output format.
Returns:
A dictionary containing the converted dataset.
"""
project = self.get_dataset(dataset_name=dataset_name)
return project.export_tasks(export_type=output_format) # type: ignore[no-any-return]
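A short sketch exporting annotations in a converted format; the dataset name is a placeholder, and the valid output_format values are those defined by Label Studio's export types (e.g. "JSON", "CSV", "COCO"):
from typing import Any, Dict

from zenml.integrations.label_studio.annotators import LabelStudioAnnotator

def export_coco(annotator: LabelStudioAnnotator) -> Dict[Any, Any]:
    """Exports all annotated tasks of a placeholder dataset as COCO."""
    return annotator.get_converted_dataset(
        dataset_name="image_annotation_project",  # placeholder name
        output_format="COCO",
    )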
get_dataset(self, **kwargs)
Gets the dataset with the given name.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
**kwargs |
Any |
Additional keyword arguments to pass to the Label Studio client. |
{} |
Returns:
Type | Description |
---|---|
Any |
The LabelStudio Dataset object (a 'Project') for the given name. |
Exceptions:
Type | Description |
---|---|
ValueError |
If the dataset name is not provided or if the dataset does not exist. |
Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def get_dataset(self, **kwargs: Any) -> Any:
"""Gets the dataset with the given name.
Args:
**kwargs: Additional keyword arguments to pass to the Label Studio client.
Returns:
The LabelStudio Dataset object (a 'Project') for the given name.
Raises:
ValueError: If the dataset name is not provided or if the dataset
does not exist.
"""
# TODO: check for and raise error if client unavailable
dataset_name = kwargs.get("dataset_name")
if not dataset_name:
raise ValueError("`dataset_name` keyword argument is required.")
dataset_id = self.get_id_from_name(dataset_name)
if not dataset_id:
raise ValueError(
f"Dataset name '{dataset_name}' has no corresponding `dataset_id` in Label Studio."
)
return self._get_client().get_project(dataset_id)
get_dataset_names(self)
Gets the names of the datasets.
Returns:
Type | Description |
---|---|
List[str] |
A list of dataset names. |
Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def get_dataset_names(self) -> List[str]:
"""Gets the names of the datasets.
Returns:
A list of dataset names.
"""
return [
dataset.get_params()["title"] for dataset in self.get_datasets()
]
get_dataset_stats(self, dataset_name)
Gets the statistics of the given dataset.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
dataset_name |
str |
The name of the dataset. |
required |
Returns:
Type | Description |
---|---|
Tuple[int, int] |
A tuple containing (labeled_task_count, unlabeled_task_count) for the dataset. |
Exceptions:
Type | Description |
---|---|
IndexError |
If the dataset does not exist. |
Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def get_dataset_stats(self, dataset_name: str) -> Tuple[int, int]:
"""Gets the statistics of the given dataset.
Args:
dataset_name: The name of the dataset.
Returns:
A tuple containing (labeled_task_count, unlabeled_task_count) for
the dataset.
Raises:
IndexError: If the dataset does not exist.
"""
for project in self.get_datasets():
if dataset_name in project.get_params()["title"]:
labeled_task_count = len(project.get_labeled_tasks())
unlabeled_task_count = len(project.get_unlabeled_tasks())
return (labeled_task_count, unlabeled_task_count)
raise IndexError(
f"Dataset {dataset_name} not found. Please use "
f"`zenml annotator dataset list` to list all available datasets."
)
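A small sketch that turns the returned (labeled, unlabeled) tuple into a progress report:
from zenml.integrations.label_studio.annotators import LabelStudioAnnotator

def report_progress(annotator: LabelStudioAnnotator, dataset_name: str) -> None:
    """Prints how much of a dataset has been labeled so far."""
    labeled, unlabeled = annotator.get_dataset_stats(dataset_name)
    total = labeled + unlabeled
    if total:
        print(f"{dataset_name}: {labeled}/{total} tasks labeled "
              f"({100 * labeled / total:.1f}%)")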
get_datasets(self)
Gets the datasets currently available for annotation.
Returns:
Type | Description |
---|---|
List[Any] |
A list of datasets. |
Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def get_datasets(self) -> List[Any]:
"""Gets the datasets currently available for annotation.
Returns:
A list of datasets.
"""
datasets = self._get_client().get_projects()
return cast(List[Any], datasets)
get_id_from_name(self, dataset_name)
Gets the ID of the given dataset.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
dataset_name |
str |
The name of the dataset. |
required |
Returns:
Type | Description |
---|---|
Optional[int] |
The ID of the dataset. |
Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def get_id_from_name(self, dataset_name: str) -> Optional[int]:
"""Gets the ID of the given dataset.
Args:
dataset_name: The name of the dataset.
Returns:
The ID of the dataset.
"""
projects = self.get_datasets()
for project in projects:
if project.get_params()["title"] == dataset_name:
return cast(int, project.get_params()["id"])
return None
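Since get_id_from_name returns None for unknown titles, it doubles as an existence check, as in this sketch:
from zenml.integrations.label_studio.annotators import LabelStudioAnnotator

def dataset_exists(annotator: LabelStudioAnnotator, dataset_name: str) -> bool:
    """Returns whether a dataset with the given title is registered."""
    return annotator.get_id_from_name(dataset_name) is not None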
get_labeled_data(self, **kwargs)
Gets the labeled data for the given dataset.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
**kwargs |
Any |
Additional keyword arguments to pass to the Label Studio client. |
{} |
Returns:
Type | Description |
---|---|
Any |
The labeled data. |
Exceptions:
Type | Description |
---|---|
ValueError |
If the dataset name is not provided or if the dataset does not exist. |
Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def get_labeled_data(self, **kwargs: Any) -> Any:
"""Gets the labeled data for the given dataset.
Args:
**kwargs: Additional keyword arguments to pass to the Label Studio client.
Returns:
The labeled data.
Raises:
ValueError: If the dataset name is not provided or if the dataset
does not exist.
"""
dataset_name = kwargs.get("dataset_name")
if not dataset_name:
raise ValueError("`dataset_name` keyword argument is required.")
dataset_id = self.get_id_from_name(dataset_name)
if not dataset_id:
raise ValueError(
f"Dataset name '{dataset_name}' has no corresponding `dataset_id` in Label Studio."
)
return self._get_client().get_project(dataset_id).get_labeled_tasks()
get_parsed_label_config(self, dataset_id)
Returns the parsed Label Studio label config for a dataset.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
dataset_id |
int |
Id of the dataset. |
required |
Returns:
Type | Description |
---|---|
Dict[str, Any] |
A dictionary containing the parsed label config. |
Exceptions:
Type | Description |
---|---|
ValueError |
If no dataset is found for the given id. |
Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def get_parsed_label_config(self, dataset_id: int) -> Dict[str, Any]:
"""Returns the parsed Label Studio label config for a dataset.
Args:
dataset_id: Id of the dataset.
Returns:
A dictionary containing the parsed label config.
Raises:
ValueError: If no dataset is found for the given id.
"""
# TODO: check if client actually is connected etc
dataset = self._get_client().get_project(dataset_id)
if dataset:
return cast(Dict[str, Any], dataset.parsed_label_config)
raise ValueError("No dataset found for the given id.")
get_unlabeled_data(self, **kwargs)
Gets the unlabeled data for the given dataset.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
**kwargs |
str |
Additional keyword arguments to pass to the Label Studio client. |
{} |
Returns:
Type | Description |
---|---|
Any |
The unlabeled data. |
Exceptions:
Type | Description |
---|---|
ValueError |
If the dataset name is not provided or if the dataset does not exist. |
Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def get_unlabeled_data(self, **kwargs: str) -> Any:
"""Gets the unlabeled data for the given dataset.
Args:
**kwargs: Additional keyword arguments to pass to the Label Studio client.
Returns:
The unlabeled data.
Raises:
ValueError: If the dataset name is not provided or if the dataset
does not exist.
"""
dataset_name = kwargs.get("dataset_name")
if not dataset_name:
raise ValueError("`dataset_name` keyword argument is required.")
dataset_id = self.get_id_from_name(dataset_name)
if not dataset_id:
raise ValueError(
f"Dataset name '{dataset_name}' has no corresponding `dataset_id` in Label Studio."
)
return self._get_client().get_project(dataset_id).get_unlabeled_tasks()
get_url(self)
Gets the top-level URL of the annotation interface.
Returns:
Type | Description |
---|---|
str |
The URL of the annotation interface. |
Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def get_url(self) -> str:
"""Gets the top-level URL of the annotation interface.
Returns:
The URL of the annotation interface.
"""
return f"http://localhost:{self.config.port}"
get_url_for_dataset(self, dataset_name)
Gets the URL of the annotation interface for the given dataset.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
dataset_name |
str |
The name of the dataset. |
required |
Returns:
Type | Description |
---|---|
str |
The URL of the annotation interface. |
Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def get_url_for_dataset(self, dataset_name: str) -> str:
"""Gets the URL of the annotation interface for the given dataset.
Args:
dataset_name: The name of the dataset.
Returns:
The URL of the annotation interface.
"""
project_id = self.get_id_from_name(dataset_name)
return f"{self.get_url()}/projects/{project_id}/"
launch(self, url)
Launches the annotation interface.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
url |
Optional[str] |
The URL of the annotation interface. |
required |
Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def launch(self, url: Optional[str]) -> None:
"""Launches the annotation interface.
Args:
url: The URL of the annotation interface.
"""
if not url:
url = self.get_url()
if self._connection_available():
webbrowser.open(url, new=1, autoraise=True)
else:
logger.warning(
"Could not launch annotation interface"
"because the connection could not be established."
)
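A sketch combining launch with get_url_for_dataset to jump straight to one dataset's annotation view:
from zenml.integrations.label_studio.annotators import LabelStudioAnnotator

def open_dataset_in_browser(
    annotator: LabelStudioAnnotator, dataset_name: str
) -> None:
    """Opens the annotation UI for the given dataset in the default browser."""
    annotator.launch(url=annotator.get_url_for_dataset(dataset_name))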
provision(self)
Spins up the annotation server backend.
Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def provision(self) -> None:
"""Spins up the annotation server backend."""
fileio.makedirs(self.root_directory)
register_dataset_for_annotation(self, params)
Registers a dataset for annotation.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
params |
LabelStudioDatasetRegistrationParameters |
Parameters for the dataset. |
required |
Returns:
Type | Description |
---|---|
Any |
A Label Studio Project object. |
Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def register_dataset_for_annotation(
self,
params: LabelStudioDatasetRegistrationParameters,
) -> Any:
"""Registers a dataset for annotation.
Args:
params: Parameters for the dataset.
Returns:
A Label Studio Project object.
"""
project_id = self.get_id_from_name(params.dataset_name)
if project_id:
dataset = self._get_client().get_project(project_id)
else:
dataset = self.add_dataset(
dataset_name=params.dataset_name,
label_config=params.label_config,
)
return dataset
resume(self)
Resumes the annotation interface.
Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def resume(self) -> None:
"""Resumes the annotation interface."""
if self.is_running:
logger.info("Local kubeflow pipelines deployment already running.")
return
self.start_annotator_daemon()
start_annotator_daemon(self)
Starts the annotation server backend.
Exceptions:
Type | Description |
---|---|
ProvisioningError |
If the annotation server backend is already running or the port is already occupied. |
Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def start_annotator_daemon(self) -> None:
"""Starts the annotation server backend.
Raises:
ProvisioningError: If the annotation server backend is already
running or the port is already occupied.
"""
command = [
"label-studio",
"start",
"--no-browser",
"--port",
f"{self.config.port}",
]
if sys.platform == "win32":
logger.warning(
"Daemon functionality not supported on Windows. "
"In order to access the Label Studio server locally, "
"please run '%s' in a separate command line shell.",
" ".join(command),
)
elif not networking_utils.port_available(self.config.port):
raise ProvisioningError(
f"Unable to port-forward Label Studio to local "
f"port {self.config.port} because the port is occupied. In order to "
f"access Label Studio locally, please "
f"change the configuration to use an available "
f"port or stop the other process currently using the port."
)
else:
from zenml.utils import daemon
def _daemon_function() -> None:
"""Forwards the port of the Kubeflow Pipelines Metadata pod ."""
subprocess.check_call(command)
daemon.run_as_daemon(
_daemon_function,
pid_file=self._pid_file_path,
log_file=self._log_file,
)
logger.info(
"Started Label Studio daemon (check the daemon"
"logs at `%s` in case you're not able to access the annotation "
f"interface). Please visit `{self.get_url()}/` to use the Label Studio interface.",
self._log_file,
)
stop_annotator_daemon(self)
Stops the annotation server backend.
Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def stop_annotator_daemon(self) -> None:
"""Stops the annotation server backend."""
if fileio.exists(self._pid_file_path):
if sys.platform == "win32":
# Daemon functionality is not supported on Windows, so the PID
# file won't exist. This if clause exists just for mypy to not
# complain about missing functions
pass
else:
from zenml.utils import daemon
daemon.stop_daemon(self._pid_file_path)
fileio.remove(self._pid_file_path)
suspend(self)
Suspends the annotation interface.
Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def suspend(self) -> None:
"""Suspends the annotation interface."""
if not self.is_running:
logger.info("Local annotation server is not running.")
return
self.stop_annotator_daemon()
flavors
special
Label Studio integration flavors.
label_studio_annotator_flavor
Label Studio annotator flavor.
LabelStudioAnnotatorConfig (BaseAnnotatorConfig, AuthenticationConfigMixin)
pydantic-model
Config for the Label Studio annotator.
Attributes:
Name | Type | Description |
---|---|---|
port |
int |
The port to use for the annotation interface. |
Source code in zenml/integrations/label_studio/flavors/label_studio_annotator_flavor.py
class LabelStudioAnnotatorConfig(
BaseAnnotatorConfig, AuthenticationConfigMixin
):
"""Config for the Label Studio annotator.
Attributes:
port: The port to use for the annotation interface.
"""
port: int = DEFAULT_LABEL_STUDIO_PORT
LabelStudioAnnotatorFlavor (BaseAnnotatorFlavor)
Label Studio annotator flavor.
Source code in zenml/integrations/label_studio/flavors/label_studio_annotator_flavor.py
class LabelStudioAnnotatorFlavor(BaseAnnotatorFlavor):
"""Label Studio annotator flavor."""
@property
def name(self) -> str:
"""Name of the flavor.
Returns:
The name of the flavor.
"""
return LABEL_STUDIO_ANNOTATOR_FLAVOR
@property
def config_class(self) -> Type[LabelStudioAnnotatorConfig]:
"""Returns `LabelStudioAnnotatorConfig` config class.
Returns:
The config class.
"""
return LabelStudioAnnotatorConfig
@property
def implementation_class(self) -> Type["LabelStudioAnnotator"]:
"""Implementation class for this flavor.
Returns:
The implementation class.
"""
from zenml.integrations.label_studio.annotators import (
LabelStudioAnnotator,
)
return LabelStudioAnnotator
config_class: Type[zenml.integrations.label_studio.flavors.label_studio_annotator_flavor.LabelStudioAnnotatorConfig]
property
readonly
Returns LabelStudioAnnotatorConfig config class.
Returns:
Type | Description |
---|---|
Type[zenml.integrations.label_studio.flavors.label_studio_annotator_flavor.LabelStudioAnnotatorConfig] |
The config class. |
implementation_class: Type[LabelStudioAnnotator]
property
readonly
Implementation class for this flavor.
Returns:
Type | Description |
---|---|
Type[LabelStudioAnnotator] |
The implementation class. |
name: str
property
readonly
Name of the flavor.
Returns:
Type | Description |
---|---|
str |
The name of the flavor. |
label_config_generators
special
Initialization of the Label Studio config generators submodule.
label_config_generators
Implementation of label config generators for Label Studio.
generate_basic_object_detection_bounding_boxes_label_config(labels)
Generates a Label Studio config for object detection with bounding boxes.
This is based on the basic config example shown at https://labelstud.io/templates/image_bbox.html.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
labels |
List[str] |
A list of labels to be used in the label config. |
required |
Returns:
Type | Description |
---|---|
Tuple[str, str] |
A tuple of the generated label config and the label config type. |
Exceptions:
Type | Description |
---|---|
ValueError |
If no labels are provided. |
Source code in zenml/integrations/label_studio/label_config_generators/label_config_generators.py
def generate_basic_object_detection_bounding_boxes_label_config(
labels: List[str],
) -> Tuple[str, str]:
"""Generates a Label Studio config for object detection with bounding boxes.
This is based on the basic config example shown at
https://labelstud.io/templates/image_bbox.html.
Args:
labels: A list of labels to be used in the label config.
Returns:
A tuple of the generated label config and the label config type.
Raises:
ValueError: If no labels are provided.
"""
if not labels:
raise ValueError("No labels provided")
label_config_type = AnnotationTasks.OBJECT_DETECTION_BOUNDING_BOXES
label_config_start = """<View>
<Image name="image" value="$image"/>
<RectangleLabels name="label" toName="image">
"""
label_config_choices = "".join(
f"<Label value='{label}' />\n" for label in labels
)
label_config_end = "</RectangleLabels>\n</View>"
label_config = label_config_start + label_config_choices + label_config_end
return (
label_config,
label_config_type,
)
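A usage sketch for the bounding-box generator, assuming the function is re-exported from the submodule as laid out above; the labels are placeholders:
from zenml.integrations.label_studio.label_config_generators import (
    generate_basic_object_detection_bounding_boxes_label_config,
)

label_config, task_type = (
    generate_basic_object_detection_bounding_boxes_label_config(["cat", "dog"])
)
# label_config contains one <Label value='...'/> entry per class inside a
# <RectangleLabels> block; task_type identifies the annotation task and can
# be reused as the `label_config_type` of the sync parameters documented below.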
generate_image_classification_label_config(labels)
Generates a Label Studio label config for image classification.
This is based on the basic config example shown at https://labelstud.io/templates/image_classification.html.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
labels |
List[str] |
A list of labels to be used in the label config. |
required |
Returns:
Type | Description |
---|---|
Tuple[str, str] |
A tuple of the generated label config and the label config type. |
Exceptions:
Type | Description |
---|---|
ValueError |
If no labels are provided. |
Source code in zenml/integrations/label_studio/label_config_generators/label_config_generators.py
def generate_image_classification_label_config(
labels: List[str],
) -> Tuple[str, str]:
"""Generates a Label Studio label config for image classification.
This is based on the basic config example shown at
https://labelstud.io/templates/image_classification.html.
Args:
labels: A list of labels to be used in the label config.
Returns:
A tuple of the generated label config and the label config type.
Raises:
ValueError: If no labels are provided.
"""
if not labels:
raise ValueError("No labels provided")
label_config_type = AnnotationTasks.IMAGE_CLASSIFICATION
label_config_start = """<View>
<Image name="image" value="$image"/>
<Choices name="choice" toName="image">
"""
label_config_choices = "".join(
f"<Choice value='{label}' />\n" for label in labels
)
label_config_end = "</Choices>\n</View>"
label_config = label_config_start + label_config_choices + label_config_end
return (
label_config,
label_config_type,
)
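Similarly for image classification, a sketch feeding the generated config into the registration parameters documented further below; the class labels and dataset name are placeholders:
from zenml.integrations.label_studio.label_config_generators import (
    generate_image_classification_label_config,
)
from zenml.integrations.label_studio.steps.label_studio_standard_steps import (
    LabelStudioDatasetRegistrationParameters,
)

label_config, _ = generate_image_classification_label_config(
    ["positive", "negative"]
)
registration_params = LabelStudioDatasetRegistrationParameters(
    label_config=label_config,
    dataset_name="sentiment_images",  # placeholder name
)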
label_studio_utils
Utility functions for the Label Studio annotator integration.
convert_pred_filenames_to_task_ids(preds, tasks, filename_reference, storage_type)
Converts a list of predictions from local file references to task ids.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
preds |
List[Dict[str, Any]] |
List of predictions. |
required |
tasks |
List[Dict[str, Any]] |
List of tasks. |
required |
filename_reference |
str |
Name of the file reference in the predictions. |
required |
storage_type |
str |
Storage type of the predictions. |
required |
Returns:
Type | Description |
---|---|
List[Dict[str, Any]] |
List of predictions using task ids as reference. |
Source code in zenml/integrations/label_studio/label_studio_utils.py
def convert_pred_filenames_to_task_ids(
preds: List[Dict[str, Any]],
tasks: List[Dict[str, Any]],
filename_reference: str,
storage_type: str,
) -> List[Dict[str, Any]]:
"""Converts a list of predictions from local file references to task id.
Args:
preds: List of predictions.
tasks: List of tasks.
filename_reference: Name of the file reference in the predictions.
storage_type: Storage type of the predictions.
Returns:
List of predictions using task ids as reference.
"""
filename_id_mapping = {
os.path.basename(urlparse(task["data"][filename_reference]).path): task[
"id"
]
for task in tasks
}
# GCS and S3 URL-encode filenames containing spaces, requiring this
# separate encoding step
if storage_type in {"gcs", "s3"}:
preds = [
{"filename": quote(pred["filename"]), "result": pred["result"]}
for pred in preds
]
return [
{
"task": int(
filename_id_mapping[os.path.basename(pred["filename"])]
),
"result": pred["result"],
}
for pred in preds
]
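A tiny worked example of the conversion; the S3 URLs and task id are made up. Note how the space in the local filename is quoted to match the URL-encoded task reference:
from zenml.integrations.label_studio.label_studio_utils import (
    convert_pred_filenames_to_task_ids,
)

preds = [{"filename": "s3://bucket/data/img 1.png", "result": []}]
tasks = [{"id": 7, "data": {"image": "s3://bucket/data/img%201.png"}}]
converted = convert_pred_filenames_to_task_ids(
    preds, tasks, filename_reference="image", storage_type="s3"
)
assert converted == [{"task": 7, "result": []}]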
get_file_extension(path_str)
Return the file extension of the given filename.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
path_str |
str |
Path to the file. |
required |
Returns:
Type | Description |
---|---|
str |
File extension. |
Source code in zenml/integrations/label_studio/label_studio_utils.py
def get_file_extension(path_str: str) -> str:
"""Return the file extension of the given filename.
Args:
path_str: Path to the file.
Returns:
File extension.
"""
return os.path.splitext(urlparse(path_str).path)[1]
is_azure_url(url)
Return whether the given URL is an Azure URL.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
url |
str |
URL to check. |
required |
Returns:
Type | Description |
---|---|
bool |
True if the URL is an Azure URL, False otherwise. |
Source code in zenml/integrations/label_studio/label_studio_utils.py
def is_azure_url(url: str) -> bool:
"""Return whether the given URL is an Azure URL.
Args:
url: URL to check.
Returns:
True if the URL is an Azure URL, False otherwise.
"""
return "blob.core.windows.net" in urlparse(url).netloc
is_gcs_url(url)
Return whether the given URL is a GCS URL.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
url |
str |
URL to check. |
required |
Returns:
Type | Description |
---|---|
bool |
True if the URL is a GCS URL, False otherwise. |
Source code in zenml/integrations/label_studio/label_studio_utils.py
def is_gcs_url(url: str) -> bool:
"""Return whether the given URL is an GCS URL.
Args:
url: URL to check.
Returns:
True if the URL is a GCS URL, False otherwise.
"""
return "storage.googleapis.com" in urlparse(url).netloc
is_s3_url(url)
Return whether the given URL is an S3 URL.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
url |
str |
URL to check. |
required |
Returns:
Type | Description |
---|---|
bool |
True if the URL is an S3 URL, False otherwise. |
Source code in zenml/integrations/label_studio/label_studio_utils.py
def is_s3_url(url: str) -> bool:
"""Return whether the given URL is an S3 URL.
Args:
url: URL to check.
Returns:
True if the URL is an S3 URL, False otherwise.
"""
return "s3.amazonaws" in urlparse(url).netloc
steps
special
Standard steps to be used with the Label Studio annotator integration.
label_studio_standard_steps
Implementation of standard steps for the Label Studio annotator integration.
LabelStudioDatasetRegistrationParameters (BaseParameters)
pydantic-model
Step parameters when registering a dataset with Label Studio.
Attributes:
Name | Type | Description |
---|---|---|
label_config |
str |
The label config to use for the annotation interface. |
dataset_name |
str |
Name of the dataset to register. |
Source code in zenml/integrations/label_studio/steps/label_studio_standard_steps.py
class LabelStudioDatasetRegistrationParameters(BaseParameters):
"""Step parameters when registering a dataset with Label Studio.
Attributes:
label_config: The label config to use for the annotation interface.
dataset_name: Name of the dataset to register.
"""
label_config: str
dataset_name: str
LabelStudioDatasetSyncParameters (BaseParameters)
pydantic-model
Step parameters when syncing data to Label Studio.
Attributes:
Name | Type | Description |
---|---|---|
storage_type |
str |
The type of storage to sync to. |
label_config_type |
str |
The type of label config to use. |
prefix |
Optional[str] |
Specify the prefix within the cloud store to import your data from. |
regex_filter |
Optional[str] |
Specify a regex filter to filter the files to import. |
use_blob_urls |
Optional[bool] |
Specify whether your data is raw image or video data, or JSON tasks. |
presign |
Optional[bool] |
Specify whether or not to create presigned URLs. |
presign_ttl |
Optional[int] |
Specify how long to keep presigned URLs active. |
description |
Optional[str] |
Specify a description for the dataset. |
azure_account_name |
Optional[str] |
Specify the Azure account name to use for the storage. |
azure_account_key |
Optional[str] |
Specify the Azure account key to use for the storage. |
google_application_credentials |
Optional[str] |
Specify the Google application credentials to use for the storage. |
aws_access_key_id |
Optional[str] |
Specify the AWS access key ID to use for the storage. |
aws_secret_access_key |
Optional[str] |
Specify the AWS secret access key to use for the storage. |
aws_session_token |
Optional[str] |
Specify the AWS session token to use for the storage. |
s3_region_name |
Optional[str] |
Specify the S3 region name to use for the storage. |
s3_endpoint |
Optional[str] |
Specify the S3 endpoint to use for the storage. |
Source code in zenml/integrations/label_studio/steps/label_studio_standard_steps.py
class LabelStudioDatasetSyncParameters(BaseParameters):
"""Step parameters when syncing data to Label Studio.
Attributes:
storage_type: The type of storage to sync to.
label_config_type: The type of label config to use.
prefix: Specify the prefix within the cloud store to import your data
from.
regex_filter: Specify a regex filter to filter the files to import.
use_blob_urls: Specify whether your data is raw image or video data, or
JSON tasks.
presign: Specify whether or not to create presigned URLs.
presign_ttl: Specify how long to keep presigned URLs active.
description: Specify a description for the dataset.
azure_account_name: Specify the Azure account name to use for the
storage.
azure_account_key: Specify the Azure account key to use for the
storage.
google_application_credentials: Specify the Google application
credentials to use for the storage.
aws_access_key_id: Specify the AWS access key ID to use for the
storage.
aws_secret_access_key: Specify the AWS secret access key to use for the
storage.
aws_session_token: Specify the AWS session token to use for the
storage.
s3_region_name: Specify the S3 region name to use for the storage.
s3_endpoint: Specify the S3 endpoint to use for the storage.
"""
storage_type: str
label_config_type: str
prefix: Optional[str] = None
regex_filter: Optional[str] = ".*"
use_blob_urls: Optional[bool] = True
presign: Optional[bool] = True
presign_ttl: Optional[int] = 1
description: Optional[str] = ""
# credentials specific to the main cloud providers
azure_account_name: Optional[str]
azure_account_key: Optional[str]
google_application_credentials: Optional[str]
aws_access_key_id: Optional[str]
aws_secret_access_key: Optional[str]
aws_session_token: Optional[str]
s3_region_name: Optional[str]
s3_endpoint: Optional[str]
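A sketch instantiating the sync parameters for GCS; the credentials path and config type are placeholders, and unset optional fields keep the defaults shown above (e.g. presign=True, regex_filter=".*"):
from zenml.integrations.label_studio.steps.label_studio_standard_steps import (
    LabelStudioDatasetSyncParameters,
)

sync_params = LabelStudioDatasetSyncParameters(
    storage_type="gcs",
    label_config_type="image_classification",  # placeholder task type
    google_application_credentials="/path/to/credentials.json",  # placeholder
)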
get_labeled_data (BaseStep)
Gets labeled data from the dataset.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
dataset_name |
Name of the dataset. |
required | |
context |
The StepContext. |
required |
Returns:
Type | Description |
---|---|
List of labeled data. |
Exceptions:
Type | Description |
---|---|
TypeError |
If you are trying to use it with an annotator that is not Label Studio. |
StackComponentInterfaceError |
If no active annotator could be found. |
entrypoint(dataset_name, context)
staticmethod
Gets labeled data from the dataset.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
dataset_name |
str |
Name of the dataset. |
required |
context |
StepContext |
The StepContext. |
required |
Returns:
Type | Description |
---|---|
List |
List of labeled data. |
Exceptions:
Type | Description |
---|---|
TypeError |
If you are trying to use it with an annotator that is not Label Studio. |
StackComponentInterfaceError |
If no active annotator could be found. |
Source code in zenml/integrations/label_studio/steps/label_studio_standard_steps.py
@step(enable_cache=False)
def get_labeled_data(dataset_name: str, context: StepContext) -> List: # type: ignore[type-arg]
"""Gets labeled data from the dataset.
Args:
dataset_name: Name of the dataset.
context: The StepContext.
Returns:
List of labeled data.
Raises:
TypeError: If you are trying to use it with an annotator that is not
Label Studio.
StackComponentInterfaceError: If no active annotator could be found.
"""
# TODO [MEDIUM]: have this check for new data *since the last time this step ran*
annotator = context.stack.annotator # type: ignore[union-attr]
if not annotator:
raise StackComponentInterfaceError("No active annotator.")
from zenml.integrations.label_studio.annotators.label_studio_annotator import (
LabelStudioAnnotator,
)
if not isinstance(annotator, LabelStudioAnnotator):
raise TypeError(
"This step can only be used with the Label Studio annotator."
)
if annotator._connection_available():
dataset = annotator.get_dataset(dataset_name=dataset_name)
return dataset.get_labeled_tasks() # type: ignore[no-any-return]
raise StackComponentInterfaceError(
"Unable to connect to annotator stack component."
)
get_or_create_dataset (BaseStep)
Gets preexisting dataset or creates a new one.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
params |
Step parameters. |
required | |
context |
Step context. |
required |
Returns:
Type | Description |
---|---|
The dataset name. |
Exceptions:
Type | Description |
---|---|
TypeError |
If you are trying to use it with an annotator that is not Label Studio. |
StackComponentInterfaceError |
If no active annotator could be found. |
PARAMETERS_CLASS (BaseParameters)
pydantic-model
Step parameters when registering a dataset with Label Studio.
Attributes:
Name | Type | Description |
---|---|---|
label_config |
str |
The label config to use for the annotation interface. |
dataset_name |
str |
Name of the dataset to register. |
Source code in zenml/integrations/label_studio/steps/label_studio_standard_steps.py
class LabelStudioDatasetRegistrationParameters(BaseParameters):
"""Step parameters when registering a dataset with Label Studio.
Attributes:
label_config: The label config to use for the annotation interface.
dataset_name: Name of the dataset to register.
"""
label_config: str
dataset_name: str
entrypoint(params, context)
staticmethod
Gets preexisting dataset or creates a new one.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
params |
LabelStudioDatasetRegistrationParameters |
Step parameters. |
required |
context |
StepContext |
Step context. |
required |
Returns:
Type | Description |
---|---|
str |
The dataset name. |
Exceptions:
Type | Description |
---|---|
TypeError |
If you are trying to use it with an annotator that is not Label Studio. |
StackComponentInterfaceError |
If no active annotator could be found. |
Source code in zenml/integrations/label_studio/steps/label_studio_standard_steps.py
@step(enable_cache=False)
def get_or_create_dataset(
params: LabelStudioDatasetRegistrationParameters,
context: StepContext,
) -> str:
"""Gets preexisting dataset or creates a new one.
Args:
params: Step parameters.
context: Step context.
Returns:
The dataset name.
Raises:
TypeError: If you are trying to use it with an annotator that is not
Label Studio.
StackComponentInterfaceError: If no active annotator could be found.
"""
annotator = context.stack.annotator # type: ignore[union-attr]
from zenml.integrations.label_studio.annotators.label_studio_annotator import (
LabelStudioAnnotator,
)
if not isinstance(annotator, LabelStudioAnnotator):
raise TypeError(
"This step can only be used with the Label Studio annotator."
)
if annotator and annotator._connection_available():
for dataset in annotator.get_datasets():
if dataset.get_params()["title"] == params.dataset_name:
return cast(str, dataset.get_params()["title"])
dataset = annotator.register_dataset_for_annotation(params)
return cast(str, dataset.get_params()["title"])
raise StackComponentInterfaceError("No active annotator.")
# if annotator and annotator._connection_available():
# preexisting_dataset_list = [
# dataset
# for dataset in annotator.get_datasets()
# if dataset.get_params()["title"] == config.dataset_name
# ]
# if (
# not preexisting_dataset_list
# and annotator
# and annotator._connection_available()
# ):
# registered_dataset = annotator.register_dataset_for_annotation(
# config
# )
# elif preexisting_dataset_list:
# return cast(str, preexisting_dataset_list[0].get_params()["title"])
# else:
# raise StackComponentInterfaceError("No active annotator.")
# return cast(str, registered_dataset.get_params()["title"])
# else:
# raise StackComponentInterfaceError("No active annotator.")
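A sketch wiring this step into a pipeline together with get_labeled_data. The pipeline decorator and step-instantiation style follow the classic ZenML API of this generation, and the label config is a placeholder, so treat the exact wiring as an assumption that may differ across versions:
from zenml.pipelines import pipeline

from zenml.integrations.label_studio.steps.label_studio_standard_steps import (
    LabelStudioDatasetRegistrationParameters,
    get_labeled_data,
    get_or_create_dataset,
)

@pipeline
def annotation_pipeline(get_or_create, get_labeled):
    dataset_name = get_or_create()
    get_labeled(dataset_name)

annotation_pipeline(
    get_or_create=get_or_create_dataset(
        params=LabelStudioDatasetRegistrationParameters(
            label_config="<View>...</View>",  # placeholder config
            dataset_name="my_dataset",
        )
    ),
    get_labeled=get_labeled_data(),
).run()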
sync_new_data_to_label_studio (BaseStep)
Syncs new data to Label Studio.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
uri |
The URI of the data to sync. |
required | |
dataset_name |
The name of the dataset to sync to. |
required | |
predictions |
The predictions to sync. |
required | |
params |
The parameters for the sync. |
required | |
context |
The StepContext. |
required |
Exceptions:
Type | Description |
---|---|
TypeError |
If you are trying to use it with an annotator that is not Label Studio. |
ValueError |
if you are trying to sync from outside ZenML. |
StackComponentInterfaceError |
If no active annotator could be found. |
PARAMETERS_CLASS (BaseParameters)
pydantic-model
Step parameters when syncing data to Label Studio.
Attributes:
Name | Type | Description |
---|---|---|
storage_type |
str |
The type of storage to sync to. |
label_config_type |
str |
The type of label config to use. |
prefix |
Optional[str] |
Specify the prefix within the cloud store to import your data from. |
regex_filter |
Optional[str] |
Specify a regex filter to filter the files to import. |
use_blob_urls |
Optional[bool] |
Specify whether your data is raw image or video data, or JSON tasks. |
presign |
Optional[bool] |
Specify whether or not to create presigned URLs. |
presign_ttl |
Optional[int] |
Specify how long to keep presigned URLs active. |
description |
Optional[str] |
Specify a description for the dataset. |
azure_account_name |
Optional[str] |
Specify the Azure account name to use for the storage. |
azure_account_key |
Optional[str] |
Specify the Azure account key to use for the storage. |
google_application_credentials |
Optional[str] |
Specify the Google application credentials to use for the storage. |
aws_access_key_id |
Optional[str] |
Specify the AWS access key ID to use for the storage. |
aws_secret_access_key |
Optional[str] |
Specify the AWS secret access key to use for the storage. |
aws_session_token |
Optional[str] |
Specify the AWS session token to use for the storage. |
s3_region_name |
Optional[str] |
Specify the S3 region name to use for the storage. |
s3_endpoint |
Optional[str] |
Specify the S3 endpoint to use for the storage. |
Source code in zenml/integrations/label_studio/steps/label_studio_standard_steps.py
class LabelStudioDatasetSyncParameters(BaseParameters):
"""Step parameters when syncing data to Label Studio.
Attributes:
storage_type: The type of storage to sync to.
label_config_type: The type of label config to use.
prefix: Specify the prefix within the cloud store to import your data
from.
regex_filter: Specify a regex filter to filter the files to import.
use_blob_urls: Specify whether your data is raw image or video data, or
JSON tasks.
presign: Specify whether or not to create presigned URLs.
presign_ttl: Specify how long to keep presigned URLs active.
description: Specify a description for the dataset.
azure_account_name: Specify the Azure account name to use for the
storage.
azure_account_key: Specify the Azure account key to use for the
storage.
google_application_credentials: Specify the Google application
credentials to use for the storage.
aws_access_key_id: Specify the AWS access key ID to use for the
storage.
aws_secret_access_key: Specify the AWS secret access key to use for the
storage.
aws_session_token: Specify the AWS session token to use for the
storage.
s3_region_name: Specify the S3 region name to use for the storage.
s3_endpoint: Specify the S3 endpoint to use for the storage.
"""
storage_type: str
label_config_type: str
prefix: Optional[str] = None
regex_filter: Optional[str] = ".*"
use_blob_urls: Optional[bool] = True
presign: Optional[bool] = True
presign_ttl: Optional[int] = 1
description: Optional[str] = ""
# credentials specific to the main cloud providers
azure_account_name: Optional[str]
azure_account_key: Optional[str]
google_application_credentials: Optional[str]
aws_access_key_id: Optional[str]
aws_secret_access_key: Optional[str]
aws_session_token: Optional[str]
s3_region_name: Optional[str]
s3_endpoint: Optional[str]
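Since only `storage_type` and `label_config_type` are required, a parameters object can stay small. Below is a minimal, hypothetical sketch for an S3-backed sync; the label config type, regex and region values are placeholders, and cloud credentials are filled in by the sync step itself from the stack's secrets.

```python
from zenml.integrations.label_studio.steps.label_studio_standard_steps import (
    LabelStudioDatasetSyncParameters,
)

# Hypothetical values: only storage_type and label_config_type are
# required. AWS credentials are injected by the sync step from the
# stack's secrets manager, so they are not set here.
s3_sync_params = LabelStudioDatasetSyncParameters(
    storage_type="s3",
    label_config_type="image_classification",  # placeholder config type
    regex_filter=r".*\.png",
    s3_region_name="eu-west-1",  # placeholder region
)
```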
entrypoint(uri, dataset_name, predictions, params, context)
staticmethod
Syncs new data to Label Studio.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| uri | str | The URI of the data to sync. | required |
| dataset_name | str | The name of the dataset to sync to. | required |
| predictions | List[Dict[str, Any]] | The predictions to sync. | required |
| params | LabelStudioDatasetSyncParameters | The parameters for the sync. | required |
| context | StepContext | The StepContext. | required |

Exceptions:

| Type | Description |
|---|---|
| TypeError | If you are trying to use it with an annotator that is not Label Studio. |
| ValueError | If you are trying to sync from outside ZenML. |
| StackComponentInterfaceError | If no active annotator could be found. |
Source code in zenml/integrations/label_studio/steps/label_studio_standard_steps.py
@step(enable_cache=False)
def sync_new_data_to_label_studio(
uri: str,
dataset_name: str,
predictions: List[Dict[str, Any]],
params: LabelStudioDatasetSyncParameters,
context: StepContext,
) -> None:
"""Syncs new data to Label Studio.
Args:
uri: The URI of the data to sync.
dataset_name: The name of the dataset to sync to.
predictions: The predictions to sync.
params: The parameters for the sync.
context: The StepContext.
Raises:
TypeError: If you are trying to use it with an annotator that is not
Label Studio.
ValueError: if you are trying to sync from outside ZenML.
StackComponentInterfaceError: If no active annotator could be found.
"""
annotator = context.stack.annotator # type: ignore[union-attr]
artifact_store = context.stack.artifact_store # type: ignore[union-attr]
secrets_manager = context.stack.secrets_manager # type: ignore[union-attr]
if not annotator or not artifact_store or not secrets_manager:
raise StackComponentInterfaceError(
"An active annotator, artifact store and secrets manager are required to run this step."
)
from zenml.integrations.label_studio.annotators.label_studio_annotator import (
LabelStudioAnnotator,
)
if not isinstance(annotator, LabelStudioAnnotator):
raise TypeError(
"This step can only be used with the Label Studio annotator."
)
# TODO: check that annotator is connected before querying it
dataset = annotator.get_dataset(dataset_name=dataset_name)
if not uri.startswith(artifact_store.path):
raise ValueError(
"ZenML only currently supports syncing data passed from other ZenML steps and via the Artifact Store."
)
# removes the initial forward slash from the prefix attribute by slicing
params.prefix = urlparse(uri).path.lstrip("/")
base_uri = urlparse(uri).netloc
# gets the secret used for authentication
if params.storage_type == "azure":
if not isinstance(artifact_store, AuthenticationMixin):
raise TypeError(
"The artifact store must inherit from "
f"{AuthenticationMixin.__name__} to work with a Label Studio "
f"`{params.storage_type}` storage."
)
azure_secret = artifact_store.get_authentication_secret(
expected_schema_type=AzureSecretSchema
)
if not azure_secret:
raise ValueError(
"Missing secret to authenticate cloud storage for Label Studio."
)
params.azure_account_name = azure_secret.account_name
params.azure_account_key = azure_secret.account_key
elif params.storage_type == "gcs":
if not isinstance(artifact_store, AuthenticationMixin):
raise TypeError(
"The artifact store must inherit from "
f"{AuthenticationMixin.__name__} to work with a Label Studio "
f"`{params.storage_type}` storage."
)
gcp_secret = artifact_store.get_authentication_secret(
expected_schema_type=GCPSecretSchema
)
if not gcp_secret:
raise ValueError(
"Missing secret to authenticate cloud storage for Label Studio."
)
params.google_application_credentials = gcp_secret.token
elif params.storage_type == "s3":
aws_secret = secrets_manager.get_secret(LABEL_STUDIO_AWS_SECRET_NAME)
if not isinstance(aws_secret, AWSSecretSchema):
raise TypeError(
f"The secret `{LABEL_STUDIO_AWS_SECRET_NAME}` needs to be "
f"an `aws` schema secret."
)
params.aws_access_key_id = aws_secret.aws_access_key_id
params.aws_secret_access_key = aws_secret.aws_secret_access_key
params.aws_session_token = aws_secret.aws_session_token
if annotator and annotator._connection_available():
# TODO: get existing (CHECK!) or create the sync connection
annotator.connect_and_sync_external_storage(
uri=base_uri,
params=params,
dataset=dataset,
)
if predictions:
filename_reference = TASK_TO_FILENAME_REFERENCE_MAPPING[
params.label_config_type
]
preds_with_task_ids = convert_pred_filenames_to_task_ids(
predictions,
dataset.tasks,
filename_reference,
params.storage_type,
)
# TODO: filter out any predictions that exist + have already been
# made (maybe?). Only pass in preds for tasks without pre-annotations.
dataset.create_predictions(preds_with_task_ids)
else:
raise StackComponentInterfaceError("No active annotator.")
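A wiring sketch for this step follows. `load_image_batch` and `predict` are hypothetical upstream steps that produce the URI, dataset name and predictions, and the parameter values are placeholders; the `StepContext` is injected automatically, so it is not passed explicitly.

```python
from zenml.pipelines import pipeline
from zenml.integrations.label_studio.steps.label_studio_standard_steps import (
    LabelStudioDatasetSyncParameters,
    sync_new_data_to_label_studio,
)

@pipeline
def label_studio_sync_pipeline(loader, predictor, sync):
    uri, dataset_name = loader()
    predictions = predictor(uri)
    sync(uri=uri, dataset_name=dataset_name, predictions=predictions)

# `load_image_batch` and `predict` are hypothetical steps defined elsewhere.
pipeline_instance = label_studio_sync_pipeline(
    loader=load_image_batch(),
    predictor=predict(),
    sync=sync_new_data_to_label_studio(
        params=LabelStudioDatasetSyncParameters(
            storage_type="s3",
            label_config_type="image_classification",  # placeholder
        )
    ),
)
pipeline_instance.run()
```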
lightgbm
special
Initialization of the LightGBM integration.
LightGBMIntegration (Integration)
Definition of lightgbm integration for ZenML.
Source code in zenml/integrations/lightgbm/__init__.py
class LightGBMIntegration(Integration):
"""Definition of lightgbm integration for ZenML."""
NAME = LIGHTGBM
REQUIREMENTS = ["lightgbm>=1.0.0"]
@classmethod
def activate(cls) -> None:
"""Activates the integration."""
from zenml.integrations.lightgbm import materializers # noqa
activate()
classmethod
Activates the integration.
Source code in zenml/integrations/lightgbm/__init__.py
@classmethod
def activate(cls) -> None:
"""Activates the integration."""
from zenml.integrations.lightgbm import materializers # noqa
materializers
special
Initialization of the LightGBM materializers.
lightgbm_booster_materializer
Implementation of the LightGBM booster materializer.
LightGBMBoosterMaterializer (BaseMaterializer)
Materializer to read data to and from lightgbm.Booster.
Source code in zenml/integrations/lightgbm/materializers/lightgbm_booster_materializer.py
class LightGBMBoosterMaterializer(BaseMaterializer):
"""Materializer to read data to and from lightgbm.Booster."""
ASSOCIATED_TYPES = (lgb.Booster,)
ASSOCIATED_ARTIFACT_TYPES = (ModelArtifact,)
def handle_input(self, data_type: Type[Any]) -> lgb.Booster:
"""Reads a lightgbm Booster model from a serialized JSON file.
Args:
data_type: A lightgbm Booster type.
Returns:
A lightgbm Booster object.
"""
super().handle_input(data_type)
filepath = os.path.join(self.artifact.uri, DEFAULT_FILENAME)
# Create a temporary folder
temp_dir = tempfile.mkdtemp(prefix="zenml-temp-")
temp_file = os.path.join(str(temp_dir), DEFAULT_FILENAME)
# Copy from artifact store to temporary file
fileio.copy(filepath, temp_file)
booster = lgb.Booster(model_file=temp_file)
# Cleanup and return
fileio.rmtree(temp_dir)
return booster
def handle_return(self, booster: lgb.Booster) -> None:
"""Creates a JSON serialization for a lightgbm Booster model.
Args:
booster: A lightgbm Booster model.
"""
super().handle_return(booster)
filepath = os.path.join(self.artifact.uri, DEFAULT_FILENAME)
# Make a temporary phantom artifact
with tempfile.NamedTemporaryFile(
mode="w", suffix=".txt", delete=False
) as f:
booster.save_model(f.name)
# Copy it into artifact store
fileio.copy(f.name, filepath)
# Close and remove the temporary file
f.close()
fileio.remove(f.name)
handle_input(self, data_type)
Reads a lightgbm Booster model from a serialized JSON file.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| data_type | Type[Any] | A lightgbm Booster type. | required |

Returns:

| Type | Description |
|---|---|
| Booster | A lightgbm Booster object. |
Source code in zenml/integrations/lightgbm/materializers/lightgbm_booster_materializer.py
def handle_input(self, data_type: Type[Any]) -> lgb.Booster:
"""Reads a lightgbm Booster model from a serialized JSON file.
Args:
data_type: A lightgbm Booster type.
Returns:
A lightgbm Booster object.
"""
super().handle_input(data_type)
filepath = os.path.join(self.artifact.uri, DEFAULT_FILENAME)
# Create a temporary folder
temp_dir = tempfile.mkdtemp(prefix="zenml-temp-")
temp_file = os.path.join(str(temp_dir), DEFAULT_FILENAME)
# Copy from artifact store to temporary file
fileio.copy(filepath, temp_file)
booster = lgb.Booster(model_file=temp_file)
# Cleanup and return
fileio.rmtree(temp_dir)
return booster
handle_return(self, booster)
Creates a JSON serialization for a lightgbm Booster model.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| booster | Booster | A lightgbm Booster model. | required |
Source code in zenml/integrations/lightgbm/materializers/lightgbm_booster_materializer.py
def handle_return(self, booster: lgb.Booster) -> None:
"""Creates a JSON serialization for a lightgbm Booster model.
Args:
booster: A lightgbm Booster model.
"""
super().handle_return(booster)
filepath = os.path.join(self.artifact.uri, DEFAULT_FILENAME)
# Make a temporary phantom artifact
with tempfile.NamedTemporaryFile(
mode="w", suffix=".txt", delete=False
) as f:
booster.save_model(f.name)
# Copy it into artifact store
fileio.copy(f.name, filepath)
# Close and remove the temporary file
f.close()
fileio.remove(f.name)
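With the integration activated, this materializer is picked up automatically for any step that returns an `lgb.Booster`. A minimal sketch, with placeholder hyperparameters:

```python
import lightgbm as lgb
from zenml.steps import step

# Once the LightGBM integration is activated, a step returning an
# `lgb.Booster` is stored via LightGBMBoosterMaterializer automatically.
@step
def train_booster(train_set: lgb.Dataset) -> lgb.Booster:
    params = {"objective": "regression"}  # placeholder hyperparameters
    return lgb.train(params, train_set, num_boost_round=10)
```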
lightgbm_dataset_materializer
Implementation of the LightGBM materializer.
LightGBMDatasetMaterializer (BaseMaterializer)
Materializer to read data to and from lightgbm.Dataset.
Source code in zenml/integrations/lightgbm/materializers/lightgbm_dataset_materializer.py
class LightGBMDatasetMaterializer(BaseMaterializer):
"""Materializer to read data to and from lightgbm.Dataset."""
ASSOCIATED_TYPES = (lgb.Dataset,)
ASSOCIATED_ARTIFACT_TYPES = (DataArtifact,)
def handle_input(self, data_type: Type[Any]) -> lgb.Dataset:
"""Reads a lightgbm.Dataset binary file and loads it.
Args:
data_type: A lightgbm.Dataset type.
Returns:
A lightgbm.Dataset object.
"""
super().handle_input(data_type)
filepath = os.path.join(self.artifact.uri, DEFAULT_FILENAME)
# Create a temporary folder
temp_dir = tempfile.mkdtemp(prefix="zenml-temp-")
temp_file = os.path.join(str(temp_dir), DEFAULT_FILENAME)
# Copy from artifact store to temporary file
fileio.copy(filepath, temp_file)
matrix = lgb.Dataset(temp_file, free_raw_data=False)
# No clean up this time because matrix is lazy loaded
return matrix
def handle_return(self, matrix: lgb.Dataset) -> None:
"""Creates a binary serialization for a lightgbm.Dataset object.
Args:
matrix: A lightgbm.Dataset object.
"""
super().handle_return(matrix)
filepath = os.path.join(self.artifact.uri, DEFAULT_FILENAME)
# Make a temporary phantom artifact
temp_dir = tempfile.mkdtemp(prefix="zenml-temp-")
temp_file = os.path.join(str(temp_dir), DEFAULT_FILENAME)
matrix.save_binary(temp_file)
# Copy it into artifact store
fileio.copy(temp_file, filepath)
fileio.rmtree(temp_dir)
handle_input(self, data_type)
Reads a lightgbm.Dataset binary file and loads it.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| data_type | Type[Any] | A lightgbm.Dataset type. | required |

Returns:

| Type | Description |
|---|---|
| Dataset | A lightgbm.Dataset object. |
Source code in zenml/integrations/lightgbm/materializers/lightgbm_dataset_materializer.py
def handle_input(self, data_type: Type[Any]) -> lgb.Dataset:
"""Reads a lightgbm.Dataset binary file and loads it.
Args:
data_type: A lightgbm.Dataset type.
Returns:
A lightgbm.Dataset object.
"""
super().handle_input(data_type)
filepath = os.path.join(self.artifact.uri, DEFAULT_FILENAME)
# Create a temporary folder
temp_dir = tempfile.mkdtemp(prefix="zenml-temp-")
temp_file = os.path.join(str(temp_dir), DEFAULT_FILENAME)
# Copy from artifact store to temporary file
fileio.copy(filepath, temp_file)
matrix = lgb.Dataset(temp_file, free_raw_data=False)
# No clean up this time because matrix is lazy loaded
return matrix
handle_return(self, matrix)
Creates a binary serialization for a lightgbm.Dataset object.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| matrix | Dataset | A lightgbm.Dataset object. | required |
Source code in zenml/integrations/lightgbm/materializers/lightgbm_dataset_materializer.py
def handle_return(self, matrix: lgb.Dataset) -> None:
"""Creates a binary serialization for a lightgbm.Dataset object.
Args:
matrix: A lightgbm.Dataset object.
"""
super().handle_return(matrix)
filepath = os.path.join(self.artifact.uri, DEFAULT_FILENAME)
# Make a temporary phantom artifact
temp_dir = tempfile.mkdtemp(prefix="zenml-temp-")
temp_file = os.path.join(str(temp_dir), DEFAULT_FILENAME)
matrix.save_binary(temp_file)
# Copy it into artifact store
fileio.copy(temp_file, filepath)
fileio.rmtree(temp_dir)
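As with the booster materializer, returning an `lgb.Dataset` from a step is enough for this materializer to take over. A small sketch with synthetic data; `free_raw_data=False` mirrors how `handle_input` loads the binary file back:

```python
import lightgbm as lgb
import numpy as np
from zenml.steps import step

# The random data here is purely illustrative; the materializer handles
# binary (de)serialization of the returned Dataset between steps.
@step
def make_dataset() -> lgb.Dataset:
    features = np.random.rand(100, 4)
    labels = np.random.rand(100)
    return lgb.Dataset(features, label=labels, free_raw_data=False)
```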
mlflow
special
Initialization for the ZenML MLflow integration.
The MLflow integration currently enables you to use MLflow tracking as a convenient way to visualize your experiment runs within the MLflow UI.
MlflowIntegration (Integration)
Definition of MLflow integration for ZenML.
Source code in zenml/integrations/mlflow/__init__.py
class MlflowIntegration(Integration):
"""Definition of MLflow integration for ZenML."""
NAME = MLFLOW
REQUIREMENTS = [
"mlflow>=1.2.0,<1.26.0",
"mlserver>=0.5.3",
"mlserver-mlflow>=0.5.3",
]
@classmethod
def activate(cls) -> None:
"""Activate the MLflow integration."""
from zenml.integrations.mlflow import services # noqa
@classmethod
def flavors(cls) -> List[Type[Flavor]]:
"""Declare the stack component flavors for the MLflow integration.
Returns:
List of stack component flavors for this integration.
"""
from zenml.integrations.mlflow.flavors import (
MLFlowExperimentTrackerFlavor,
MLFlowModelDeployerFlavor,
)
return [MLFlowModelDeployerFlavor, MLFlowExperimentTrackerFlavor]
activate()
classmethod
Activate the MLflow integration.
Source code in zenml/integrations/mlflow/__init__.py
@classmethod
def activate(cls) -> None:
"""Activate the MLflow integration."""
from zenml.integrations.mlflow import services # noqa
flavors()
classmethod
Declare the stack component flavors for the MLflow integration.
Returns:

| Type | Description |
|---|---|
| List[Type[zenml.stack.flavor.Flavor]] | List of stack component flavors for this integration. |
Source code in zenml/integrations/mlflow/__init__.py
@classmethod
def flavors(cls) -> List[Type[Flavor]]:
"""Declare the stack component flavors for the MLflow integration.
Returns:
List of stack component flavors for this integration.
"""
from zenml.integrations.mlflow.flavors import (
MLFlowExperimentTrackerFlavor,
MLFlowModelDeployerFlavor,
)
return [MLFlowModelDeployerFlavor, MLFlowExperimentTrackerFlavor]
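A quick way to confirm which stack component flavors the integration contributes is to iterate over `flavors()`; this sketch assumes the integration's requirements are installed in the active environment.

```python
from zenml.integrations.mlflow import MlflowIntegration

# List the stack component flavors contributed by the MLflow
# integration (expected: the model deployer and experiment tracker
# flavors documented below).
for flavor_class in MlflowIntegration.flavors():
    print(flavor_class().name)
```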
experiment_trackers
special
Initialization of the MLflow experiment tracker.
mlflow_experiment_tracker
Implementation of the MLflow experiment tracker for ZenML.
MLFlowExperimentTracker (BaseExperimentTracker)
Track experiments using MLflow.
Source code in zenml/integrations/mlflow/experiment_trackers/mlflow_experiment_tracker.py
class MLFlowExperimentTracker(BaseExperimentTracker):
"""Track experiments using MLflow."""
def __init__(self, *args: Any, **kwargs: Any) -> None:
"""Initialize the experiment tracker and validate the tracking uri.
Args:
*args: Variable length argument list.
**kwargs: Arbitrary keyword arguments.
"""
super().__init__(*args, **kwargs)
self._ensure_valid_tracking_uri()
def _ensure_valid_tracking_uri(self) -> None:
"""Ensures that the tracking uri is a valid mlflow tracking uri.
Raises:
ValueError: If the tracking uri is not valid.
"""
tracking_uri = self.config.tracking_uri
if tracking_uri:
valid_schemes = DATABASE_ENGINES + ["http", "https", "file"]
if not any(
tracking_uri.startswith(scheme) for scheme in valid_schemes
) and not is_databricks_tracking_uri(tracking_uri):
raise ValueError(
f"MLflow tracking uri does not start with one of the valid "
f"schemes {valid_schemes} or its value is not set to "
f"'databricks'. See "
f"https://www.mlflow.org/docs/latest/tracking.html#where-runs-are-recorded "
f"for more information."
)
@property
def config(self) -> MLFlowExperimentTrackerConfig:
"""Returns the `MLFlowExperimentTrackerConfig` config.
Returns:
The configuration.
"""
return cast(MLFlowExperimentTrackerConfig, self._config)
@property
def local_path(self) -> Optional[str]:
"""Path to the local directory where the MLflow artifacts are stored.
Returns:
None if configured with a remote tracking URI, otherwise the
path to the local MLflow artifact store directory.
"""
tracking_uri = self.get_tracking_uri()
if is_remote_mlflow_tracking_uri(tracking_uri):
return None
else:
assert tracking_uri.startswith("file:")
return tracking_uri[5:]
@property
def validator(self) -> Optional["StackValidator"]:
"""Checks the stack has a `LocalArtifactStore` if no tracking uri was specified.
Returns:
An optional `StackValidator`.
"""
if self.config.tracking_uri:
# user specified a tracking uri, do nothing
return None
else:
# try to fall back to a tracking uri inside the zenml artifact
# store. this only works in case of a local artifact store, so we
# make sure to prevent stack with other artifact stores for now
return StackValidator(
custom_validation_function=lambda stack: (
isinstance(stack.artifact_store, LocalArtifactStore),
"MLflow experiment tracker without a specified tracking "
"uri only works with a local artifact store.",
)
)
@property
def settings_class(self) -> Optional[Type["BaseSettings"]]:
"""Settings class for the Mlflow experiment tracker.
Returns:
The settings class.
"""
return MLFlowExperimentTrackerSettings
@staticmethod
def _local_mlflow_backend() -> str:
"""Gets the local MLflow backend inside the ZenML artifact repository directory.
Returns:
The MLflow tracking URI for the local MLflow backend.
"""
client = Client(skip_client_check=True) # type: ignore[call-arg]
artifact_store = client.active_stack.artifact_store
local_mlflow_backend_uri = os.path.join(artifact_store.path, "mlruns")
if not os.path.exists(local_mlflow_backend_uri):
os.makedirs(local_mlflow_backend_uri)
return "file:" + local_mlflow_backend_uri
def get_tracking_uri(self) -> str:
"""Returns the configured tracking URI or a local fallback.
Returns:
The tracking URI.
"""
return self.config.tracking_uri or self._local_mlflow_backend()
def prepare_step_run(self, info: "StepRunInfo") -> None:
"""Sets the MLflow tracking uri and credentials.
Args:
info: Info about the step that will be executed.
"""
self.configure_mlflow()
settings = cast(
MLFlowExperimentTrackerSettings,
self.get_settings(info) or MLFlowExperimentTrackerSettings(),
)
experiment_name = settings.experiment_name or info.pipeline.name
experiment = self._set_active_experiment(experiment_name)
run_id = self.get_run_id(
experiment_name=experiment_name, run_name=info.run_name
)
tags = settings.tags.copy()
tags.update(self._get_internal_tags())
mlflow.start_run(
run_id=run_id,
run_name=info.run_name,
experiment_id=experiment.experiment_id,
tags=tags,
)
if settings.nested:
mlflow.start_run(run_name=info.config.name, nested=True, tags=tags)
def cleanup_step_run(self, info: "StepRunInfo") -> None:
"""Stops active MLflow runs and resets the MLflow tracking uri.
Args:
info: Info about the step that was executed.
"""
mlflow_utils.stop_zenml_mlflow_runs()
mlflow.set_tracking_uri("")
def configure_mlflow(self) -> None:
"""Configures the MLflow tracking URI and any additional credentials."""
tracking_uri = self.get_tracking_uri()
mlflow.set_tracking_uri(tracking_uri)
if is_databricks_tracking_uri(tracking_uri):
if self.config.databricks_host:
os.environ[DATABRICKS_HOST] = self.config.databricks_host
if self.config.tracking_username:
os.environ[DATABRICKS_USERNAME] = self.config.tracking_username
if self.config.tracking_password:
os.environ[DATABRICKS_PASSWORD] = self.config.tracking_password
if self.config.tracking_token:
os.environ[DATABRICKS_TOKEN] = self.config.tracking_token
else:
if self.config.tracking_username:
os.environ[
MLFLOW_TRACKING_USERNAME
] = self.config.tracking_username
if self.config.tracking_password:
os.environ[
MLFLOW_TRACKING_PASSWORD
] = self.config.tracking_password
if self.config.tracking_token:
os.environ[MLFLOW_TRACKING_TOKEN] = self.config.tracking_token
os.environ[MLFLOW_TRACKING_INSECURE_TLS] = (
"true" if self.config.tracking_insecure_tls else "false"
)
def get_run_id(self, experiment_name: str, run_name: str) -> Optional[str]:
"""Gets the if of a run with the given name and experiment.
Args:
experiment_name: Name of the experiment in which to search for the
run.
run_name: Name of the run to search.
Returns:
The id of the run if it exists.
"""
self.configure_mlflow()
experiment_name = self._adjust_experiment_name(experiment_name)
runs = mlflow.search_runs(
experiment_names=[experiment_name],
filter_string=f'tags.mlflow.runName = "{run_name}"',
output_format="list",
)
if not runs:
return None
run: Run = runs[0]
if mlflow_utils.is_zenml_run(run):
return cast(str, run.info.run_id)
else:
return None
def _set_active_experiment(self, experiment_name: str) -> Experiment:
"""Sets the active MLflow experiment.
If no experiment with this name exists, it is created and then
activated.
Args:
experiment_name: Name of the experiment to activate.
Raises:
RuntimeError: If the experiment creation or activation failed.
Returns:
The experiment.
"""
experiment_name = self._adjust_experiment_name(experiment_name)
mlflow.set_experiment(experiment_name=experiment_name)
experiment = mlflow.get_experiment_by_name(experiment_name)
if not experiment:
raise RuntimeError("Failed to set active mlflow experiment.")
return experiment
def _adjust_experiment_name(self, experiment_name: str) -> str:
"""Prepends a slash to the experiment name if using Databricks.
Databricks requires the experiment name to be an absolute path within
the Databricks workspace.
Args:
experiment_name: The experiment name.
Returns:
The potentially adjusted experiment name.
"""
tracking_uri = self.get_tracking_uri()
if (
tracking_uri
and is_databricks_tracking_uri(tracking_uri)
and not experiment_name.startswith("/")
):
return f"/{experiment_name}"
else:
return experiment_name
@staticmethod
def _get_internal_tags() -> Dict[str, Any]:
"""Gets ZenML internal tags for MLflow runs.
Returns:
Internal tags.
"""
return {mlflow_utils.ZENML_TAG_KEY: zenml.__version__}
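In practice the tracker is not called directly; it hooks into steps via `prepare_step_run` and `cleanup_step_run`. A usage sketch, assuming an MLflow experiment tracker registered under the hypothetical name "mlflow_tracker" in the active stack:

```python
import mlflow
from zenml.steps import step

# Inside a step bound to the tracker, plain MLflow logging calls land
# in the run that `prepare_step_run` opened for this step.
@step(experiment_tracker="mlflow_tracker")
def train() -> float:
    accuracy = 0.92  # placeholder metric
    mlflow.log_metric("accuracy", accuracy)
    return accuracy
```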
config: MLFlowExperimentTrackerConfig
property
readonly
Returns the `MLFlowExperimentTrackerConfig` config.

Returns:

| Type | Description |
|---|---|
| MLFlowExperimentTrackerConfig | The configuration. |
local_path: Optional[str]
property
readonly
Path to the local directory where the MLflow artifacts are stored.
Returns:

| Type | Description |
|---|---|
| Optional[str] | None if configured with a remote tracking URI, otherwise the path to the local MLflow artifact store directory. |
settings_class: Optional[Type[BaseSettings]]
property
readonly
Settings class for the Mlflow experiment tracker.
Returns:

| Type | Description |
|---|---|
| Optional[Type[BaseSettings]] | The settings class. |
validator: Optional[StackValidator]
property
readonly
Checks the stack has a `LocalArtifactStore` if no tracking uri was specified.

Returns:

| Type | Description |
|---|---|
| Optional[StackValidator] | An optional `StackValidator`. |
__init__(self, *args, **kwargs)
special
Initialize the experiment tracker and validate the tracking uri.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `*args` | Any | Variable length argument list. | () |
| `**kwargs` | Any | Arbitrary keyword arguments. | {} |
Source code in zenml/integrations/mlflow/experiment_trackers/mlflow_experiment_tracker.py
def __init__(self, *args: Any, **kwargs: Any) -> None:
"""Initialize the experiment tracker and validate the tracking uri.
Args:
*args: Variable length argument list.
**kwargs: Arbitrary keyword arguments.
"""
super().__init__(*args, **kwargs)
self._ensure_valid_tracking_uri()
cleanup_step_run(self, info)
Stops active MLflow runs and resets the MLflow tracking uri.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| info | StepRunInfo | Info about the step that was executed. | required |
Source code in zenml/integrations/mlflow/experiment_trackers/mlflow_experiment_tracker.py
def cleanup_step_run(self, info: "StepRunInfo") -> None:
"""Stops active MLflow runs and resets the MLflow tracking uri.
Args:
info: Info about the step that was executed.
"""
mlflow_utils.stop_zenml_mlflow_runs()
mlflow.set_tracking_uri("")
configure_mlflow(self)
Configures the MLflow tracking URI and any additional credentials.
Source code in zenml/integrations/mlflow/experiment_trackers/mlflow_experiment_tracker.py
def configure_mlflow(self) -> None:
"""Configures the MLflow tracking URI and any additional credentials."""
tracking_uri = self.get_tracking_uri()
mlflow.set_tracking_uri(tracking_uri)
if is_databricks_tracking_uri(tracking_uri):
if self.config.databricks_host:
os.environ[DATABRICKS_HOST] = self.config.databricks_host
if self.config.tracking_username:
os.environ[DATABRICKS_USERNAME] = self.config.tracking_username
if self.config.tracking_password:
os.environ[DATABRICKS_PASSWORD] = self.config.tracking_password
if self.config.tracking_token:
os.environ[DATABRICKS_TOKEN] = self.config.tracking_token
else:
if self.config.tracking_username:
os.environ[
MLFLOW_TRACKING_USERNAME
] = self.config.tracking_username
if self.config.tracking_password:
os.environ[
MLFLOW_TRACKING_PASSWORD
] = self.config.tracking_password
if self.config.tracking_token:
os.environ[MLFLOW_TRACKING_TOKEN] = self.config.tracking_token
os.environ[MLFLOW_TRACKING_INSECURE_TLS] = (
"true" if self.config.tracking_insecure_tls else "false"
)
get_run_id(self, experiment_name, run_name)
Gets the ID of a run with the given name and experiment.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| experiment_name | str | Name of the experiment in which to search for the run. | required |
| run_name | str | Name of the run to search. | required |

Returns:

| Type | Description |
|---|---|
| Optional[str] | The ID of the run if it exists. |
Source code in zenml/integrations/mlflow/experiment_trackers/mlflow_experiment_tracker.py
def get_run_id(self, experiment_name: str, run_name: str) -> Optional[str]:
"""Gets the if of a run with the given name and experiment.
Args:
experiment_name: Name of the experiment in which to search for the
run.
run_name: Name of the run to search.
Returns:
The id of the run if it exists.
"""
self.configure_mlflow()
experiment_name = self._adjust_experiment_name(experiment_name)
runs = mlflow.search_runs(
experiment_names=[experiment_name],
filter_string=f'tags.mlflow.runName = "{run_name}"',
output_format="list",
)
if not runs:
return None
run: Run = runs[0]
if mlflow_utils.is_zenml_run(run):
return cast(str, run.info.run_id)
else:
return None
get_tracking_uri(self)
Returns the configured tracking URI or a local fallback.
Returns:

| Type | Description |
|---|---|
| str | The tracking URI. |
Source code in zenml/integrations/mlflow/experiment_trackers/mlflow_experiment_tracker.py
def get_tracking_uri(self) -> str:
"""Returns the configured tracking URI or a local fallback.
Returns:
The tracking URI.
"""
return self.config.tracking_uri or self._local_mlflow_backend()
prepare_step_run(self, info)
Sets the MLflow tracking uri and credentials.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| info | StepRunInfo | Info about the step that will be executed. | required |
Source code in zenml/integrations/mlflow/experiment_trackers/mlflow_experiment_tracker.py
def prepare_step_run(self, info: "StepRunInfo") -> None:
"""Sets the MLflow tracking uri and credentials.
Args:
info: Info about the step that will be executed.
"""
self.configure_mlflow()
settings = cast(
MLFlowExperimentTrackerSettings,
self.get_settings(info) or MLFlowExperimentTrackerSettings(),
)
experiment_name = settings.experiment_name or info.pipeline.name
experiment = self._set_active_experiment(experiment_name)
run_id = self.get_run_id(
experiment_name=experiment_name, run_name=info.run_name
)
tags = settings.tags.copy()
tags.update(self._get_internal_tags())
mlflow.start_run(
run_id=run_id,
run_name=info.run_name,
experiment_id=experiment.experiment_id,
tags=tags,
)
if settings.nested:
mlflow.start_run(run_name=info.config.name, nested=True, tags=tags)
flavors
special
MLFlow integration flavors.
mlflow_experiment_tracker_flavor
MLFlow experiment tracker flavor.
MLFlowExperimentTrackerConfig (BaseExperimentTrackerConfig)
pydantic-model
Config for the MLflow experiment tracker.
Attributes:

| Name | Type | Description |
|---|---|---|
| tracking_uri | Optional[str] | The uri of the mlflow tracking server. If no uri is set, your stack must contain a `LocalArtifactStore` and ZenML will point MLflow to a subdirectory of your artifact store instead. |
| tracking_username | Optional[str] | Username for authenticating with the MLflow tracking server. When a remote tracking uri is specified, either `tracking_token` or `tracking_username` and `tracking_password` must be specified. |
| tracking_password | Optional[str] | Password for authenticating with the MLflow tracking server. When a remote tracking uri is specified, either `tracking_token` or `tracking_username` and `tracking_password` must be specified. |
| tracking_token | Optional[str] | Token for authenticating with the MLflow tracking server. When a remote tracking uri is specified, either `tracking_token` or `tracking_username` and `tracking_password` must be specified. |
| tracking_insecure_tls | bool | Skips verification of TLS connection to the MLflow tracking server if set to `True`. |
| databricks_host | Optional[str] | The host of the Databricks workspace with the MLflow managed server to connect to. This is only required if the `tracking_uri` value is set to `"databricks"`. |
Source code in zenml/integrations/mlflow/flavors/mlflow_experiment_tracker_flavor.py
class MLFlowExperimentTrackerConfig(BaseExperimentTrackerConfig):
"""Config for the MLflow experiment tracker.
Attributes:
tracking_uri: The uri of the mlflow tracking server. If no uri is set,
your stack must contain a `LocalArtifactStore` and ZenML will
point MLflow to a subdirectory of your artifact store instead.
tracking_username: Username for authenticating with the MLflow
tracking server. When a remote tracking uri is specified,
either `tracking_token` or `tracking_username` and
`tracking_password` must be specified.
tracking_password: Password for authenticating with the MLflow
tracking server. When a remote tracking uri is specified,
either `tracking_token` or `tracking_username` and
`tracking_password` must be specified.
tracking_token: Token for authenticating with the MLflow
tracking server. When a remote tracking uri is specified,
either `tracking_token` or `tracking_username` and
`tracking_password` must be specified.
tracking_insecure_tls: Skips verification of TLS connection to the
MLflow tracking server if set to `True`.
databricks_host: The host of the Databricks workspace with the MLflow
managed server to connect to. This is only required if
`tracking_uri` value is set to `"databricks"`.
"""
tracking_uri: Optional[str] = None
tracking_username: Optional[str] = SecretField()
tracking_password: Optional[str] = SecretField()
tracking_token: Optional[str] = SecretField()
tracking_insecure_tls: bool = False
databricks_host: Optional[str] = None
@root_validator(skip_on_failure=True)
def _ensure_authentication_if_necessary(
cls, values: Dict[str, Any]
) -> Dict[str, Any]:
"""Ensures that credentials or a token for authentication exist.
We make this check when running MLflow tracking with a remote backend.
Args:
values: The values to validate.
Returns:
The validated values.
Raises:
ValueError: If neither credentials nor a token are provided.
"""
tracking_uri = values.get("tracking_uri")
if tracking_uri:
if is_databricks_tracking_uri(tracking_uri):
# If the tracking uri is "databricks", then we need the databricks
# host to be set.
databricks_host = values.get("databricks_host")
if not databricks_host:
raise ValueError(
"MLflow experiment tracking with a Databricks MLflow "
"managed tracking server requires the `databricks_host` "
"to be set in your stack component. To update your "
"component, run `zenml experiment-tracker update "
"<NAME> --databricks_host=DATABRICKS_HOST` "
"and specify the hostname of your Databricks workspace."
)
if is_remote_mlflow_tracking_uri(tracking_uri):
# we need either username + password or a token to authenticate to
# the remote backend
basic_auth = values.get("tracking_username") and values.get(
"tracking_password"
)
token_auth = values.get("tracking_token")
if not (basic_auth or token_auth):
raise ValueError(
f"MLflow experiment tracking with a remote backend "
f"{tracking_uri} is only possible when specifying either "
f"username and password or an authentication token in your "
f"stack component. To update your component, run the "
f"following command: `zenml experiment-tracker update "
f"<NAME> --tracking_username=MY_USERNAME "
f"--tracking_password=MY_PASSWORD "
f"--tracking_token=MY_TOKEN` and specify either your "
f"username and password or token."
)
return values
@property
def is_local(self) -> bool:
"""Checks if this stack component is running locally.
This designation is used to determine if the stack component can be
shared with other users or if it is only usable on the local host.
Returns:
True if this config is for a local component, False otherwise.
"""
if not self.tracking_uri or not is_remote_mlflow_tracking_uri(
self.tracking_uri
):
return True
return False
is_local: bool
property
readonly
Checks if this stack component is running locally.
This designation is used to determine if the stack component can be shared with other users or if it is only usable on the local host.
Returns:

| Type | Description |
|---|---|
| bool | True if this config is for a local component, False otherwise. |
MLFlowExperimentTrackerFlavor (BaseExperimentTrackerFlavor)
Class for the `MLFlowExperimentTrackerFlavor`.
Source code in zenml/integrations/mlflow/flavors/mlflow_experiment_tracker_flavor.py
class MLFlowExperimentTrackerFlavor(BaseExperimentTrackerFlavor):
"""Class for the `MLFlowExperimentTrackerFlavor`."""
@property
def name(self) -> str:
"""Name of the flavor.
Returns:
The name of the flavor.
"""
return MLFLOW_MODEL_EXPERIMENT_TRACKER_FLAVOR
@property
def config_class(self) -> Type[MLFlowExperimentTrackerConfig]:
"""Returns `MLFlowExperimentTrackerConfig` config class.
Returns:
The config class.
"""
return MLFlowExperimentTrackerConfig
@property
def implementation_class(self) -> Type["MLFlowExperimentTracker"]:
"""Implementation class for this flavor.
Returns:
The implementation class.
"""
from zenml.integrations.mlflow.experiment_trackers import (
MLFlowExperimentTracker,
)
return MLFlowExperimentTracker
config_class: Type[zenml.integrations.mlflow.flavors.mlflow_experiment_tracker_flavor.MLFlowExperimentTrackerConfig]
property
readonly
Returns the `MLFlowExperimentTrackerConfig` config class.

Returns:

| Type | Description |
|---|---|
| Type[zenml.integrations.mlflow.flavors.mlflow_experiment_tracker_flavor.MLFlowExperimentTrackerConfig] | The config class. |
implementation_class: Type[MLFlowExperimentTracker]
property
readonly
Implementation class for this flavor.
Returns:

| Type | Description |
|---|---|
| Type[MLFlowExperimentTracker] | The implementation class. |
name: str
property
readonly
Name of the flavor.
Returns:

| Type | Description |
|---|---|
| str | The name of the flavor. |
MLFlowExperimentTrackerSettings (BaseSettings)
pydantic-model
Settings for the MLflow experiment tracker.
Attributes:

| Name | Type | Description |
|---|---|---|
| experiment_name | Optional[str] | The MLflow experiment name. |
| nested | bool | If `True`, will create a nested sub-run for the step. |
| tags | Dict[str, Any] | Tags for the Mlflow run. |
Source code in zenml/integrations/mlflow/flavors/mlflow_experiment_tracker_flavor.py
class MLFlowExperimentTrackerSettings(BaseSettings):
"""Settings for the MLflow experiment tracker.
Attributes:
experiment_name: The MLflow experiment name.
nested: If `True`, will create a nested sub-run for the step.
tags: Tags for the Mlflow run.
"""
experiment_name: Optional[str] = None
nested: bool = False
tags: Dict[str, Any] = {}
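These settings can be attached per step. A sketch, again assuming a tracker registered as "mlflow_tracker"; the settings key follows the `experiment_tracker.<flavor>` convention, and the experiment name and tags are placeholders:

```python
from zenml.integrations.mlflow.flavors.mlflow_experiment_tracker_flavor import (
    MLFlowExperimentTrackerSettings,
)
from zenml.steps import step

# Placeholder settings: a custom experiment name, a nested sub-run for
# the step, and one illustrative tag.
mlflow_settings = MLFlowExperimentTrackerSettings(
    experiment_name="my_experiment",
    nested=True,
    tags={"team": "ml-platform"},
)

@step(
    experiment_tracker="mlflow_tracker",
    settings={"experiment_tracker.mlflow": mlflow_settings},
)
def evaluate() -> None:
    ...
```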
is_databricks_tracking_uri(tracking_uri)
Checks whether the given tracking uri is a Databricks tracking uri.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| tracking_uri | str | The tracking uri to check. | required |

Returns:

| Type | Description |
|---|---|
| bool | `True` if the tracking uri is a Databricks tracking uri, `False` otherwise. |
Source code in zenml/integrations/mlflow/flavors/mlflow_experiment_tracker_flavor.py
def is_databricks_tracking_uri(tracking_uri: str) -> bool:
"""Checks whether the given tracking uri is a Databricks tracking uri.
Args:
tracking_uri: The tracking uri to check.
Returns:
`True` if the tracking uri is a Databricks tracking uri, `False`
otherwise.
"""
return tracking_uri == "databricks"
is_remote_mlflow_tracking_uri(tracking_uri)
Checks whether the given tracking uri is remote or not.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| tracking_uri | str | The tracking uri to check. | required |

Returns:

| Type | Description |
|---|---|
| bool | `True` if the tracking uri is remote, `False` otherwise. |
Source code in zenml/integrations/mlflow/flavors/mlflow_experiment_tracker_flavor.py
def is_remote_mlflow_tracking_uri(tracking_uri: str) -> bool:
"""Checks whether the given tracking uri is remote or not.
Args:
tracking_uri: The tracking uri to check.
Returns:
`True` if the tracking uri is remote, `False` otherwise.
"""
return any(
tracking_uri.startswith(prefix) for prefix in ["http://", "https://"]
) or is_databricks_tracking_uri(tracking_uri)
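Taken together, the two helpers classify a tracking URI as follows (the hostnames below are examples):

```python
from zenml.integrations.mlflow.flavors.mlflow_experiment_tracker_flavor import (
    is_databricks_tracking_uri,
    is_remote_mlflow_tracking_uri,
)

# Only the literal "databricks" or http(s) URIs count as remote;
# file-based URIs are considered local.
assert is_databricks_tracking_uri("databricks")
assert is_remote_mlflow_tracking_uri("https://mlflow.example.com")
assert not is_remote_mlflow_tracking_uri("file:/tmp/mlruns")
```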
mlflow_model_deployer_flavor
MLFlow model deployer flavor.
MLFlowModelDeployerConfig (BaseModelDeployerConfig)
pydantic-model
Configuration for the MLflow model deployer.
Attributes:

| Name | Type | Description |
|---|---|---|
| service_path | str | The path where the local MLflow deployment service configuration, PID and log files are stored. |
Source code in zenml/integrations/mlflow/flavors/mlflow_model_deployer_flavor.py
class MLFlowModelDeployerConfig(BaseModelDeployerConfig):
"""Configuration for the MLflow model deployer.
Attributes:
service_path: the path where the local MLflow deployment service
configuration, PID and log files are stored.
"""
service_path: str = ""
@property
def is_local(self) -> bool:
"""Checks if this stack component is running locally.
This designation is used to determine if the stack component can be
shared with other users or if it is only usable on the local host.
Returns:
True if this config is for a local component, False otherwise.
"""
return True
is_local: bool
property
readonly
Checks if this stack component is running locally.
This designation is used to determine if the stack component can be shared with other users or if it is only usable on the local host.
Returns:

| Type | Description |
|---|---|
| bool | True if this config is for a local component, False otherwise. |
MLFlowModelDeployerFlavor (BaseModelDeployerFlavor)
Model deployer flavor for MLFlow models.
Source code in zenml/integrations/mlflow/flavors/mlflow_model_deployer_flavor.py
class MLFlowModelDeployerFlavor(BaseModelDeployerFlavor):
"""Model deployer flavor for MLFlow models."""
@property
def name(self) -> str:
"""Name of the flavor.
Returns:
The name of the flavor.
"""
return MLFLOW_MODEL_DEPLOYER_FLAVOR
@property
def config_class(self) -> Type[MLFlowModelDeployerConfig]:
"""Returns `MLFlowModelDeployerConfig` config class.
Returns:
The config class.
"""
return MLFlowModelDeployerConfig
@property
def implementation_class(self) -> Type["MLFlowModelDeployer"]:
"""Implementation class for this flavor.
Returns:
The implementation class.
"""
from zenml.integrations.mlflow.model_deployers import (
MLFlowModelDeployer,
)
return MLFlowModelDeployer
config_class: Type[zenml.integrations.mlflow.flavors.mlflow_model_deployer_flavor.MLFlowModelDeployerConfig]
property
readonly
Returns the `MLFlowModelDeployerConfig` config class.

Returns:

| Type | Description |
|---|---|
| Type[zenml.integrations.mlflow.flavors.mlflow_model_deployer_flavor.MLFlowModelDeployerConfig] | The config class. |
implementation_class: Type[MLFlowModelDeployer]
property
readonly
Implementation class for this flavor.
Returns:

| Type | Description |
|---|---|
| Type[MLFlowModelDeployer] | The implementation class. |
name: str
property
readonly
Name of the flavor.
Returns:

| Type | Description |
|---|---|
| str | The name of the flavor. |
mlflow_utils
Implementation of utils specific to the MLflow integration.
get_missing_mlflow_experiment_tracker_error()
Returns description of how to add an MLflow experiment tracker to your stack.
Returns:

| Type | Description |
|---|---|
| ValueError | If no MLflow experiment tracker is registered in the active stack. |
Source code in zenml/integrations/mlflow/mlflow_utils.py
def get_missing_mlflow_experiment_tracker_error() -> ValueError:
"""Returns description of how to add an MLflow experiment tracker to your stack.
Returns:
ValueError: If no MLflow experiment tracker is registered in the active stack.
"""
return ValueError(
"The active stack needs to have a MLflow experiment tracker "
"component registered to be able to track experiments using "
"MLflow. You can create a new stack with a MLflow experiment "
"tracker component or update your existing stack to add this "
"component, e.g.:\n\n"
" 'zenml experiment-tracker register mlflow_tracker "
"--type=mlflow'\n"
" 'zenml stack register stack-name -e mlflow_tracker ...'\n"
)
get_tracking_uri()
Gets the MLflow tracking URI from the active experiment tracking stack component.
Returns:

| Type | Description |
|---|---|
| str | MLflow tracking URI. |
Source code in zenml/integrations/mlflow/mlflow_utils.py
def get_tracking_uri() -> str:
"""Gets the MLflow tracking URI from the active experiment tracking stack component.
# noqa: DAR401
Returns:
MLflow tracking URI.
"""
from zenml.integrations.mlflow.experiment_trackers.mlflow_experiment_tracker import (
MLFlowExperimentTracker,
)
tracker = Client().active_stack.experiment_tracker
if tracker is None or not isinstance(tracker, MLFlowExperimentTracker):
raise get_missing_mlflow_experiment_tracker_error()
return tracker.get_tracking_uri()
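One common use of this helper is pointing a local MLflow UI at the same backend ZenML logged to; a small sketch:

```python
from zenml.integrations.mlflow import mlflow_utils

# Fetch the tracking URI of the active stack's MLflow experiment
# tracker and print the command to browse its runs locally.
uri = mlflow_utils.get_tracking_uri()
print(f"Run `mlflow ui --backend-store-uri '{uri}'` to inspect runs.")
```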
is_zenml_run(run)
Checks if a MLflow run is a ZenML run or not.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| run | Run | The run to check. | required |

Returns:

| Type | Description |
|---|---|
| bool | If the run is a ZenML run. |
Source code in zenml/integrations/mlflow/mlflow_utils.py
def is_zenml_run(run: Run) -> bool:
"""Checks if a MLflow run is a ZenML run or not.
Args:
run: The run to check.
Returns:
If the run is a ZenML run.
"""
return ZENML_TAG_KEY in run.data.tags
stop_zenml_mlflow_runs()
Stops active ZenML Mlflow runs.
This function stops all MLflow active runs until no active run exists or a non-ZenML run is active.
Source code in zenml/integrations/mlflow/mlflow_utils.py
def stop_zenml_mlflow_runs() -> None:
"""Stops active ZenML Mlflow runs.
This function stops all MLflow active runs until no active run exists or
a non-ZenML run is active.
"""
active_run = mlflow.active_run()
while active_run:
if is_zenml_run(active_run):
logger.debug("Stopping mlflow run %s.", active_run.info.run_id)
mlflow.end_run()
active_run = mlflow.active_run()
else:
break
model_deployers
special
Initialization of the MLflow model deployers.
mlflow_model_deployer
Implementation of the MLflow model deployer.
MLFlowModelDeployer (BaseModelDeployer)
MLflow implementation of the BaseModelDeployer.
Source code in zenml/integrations/mlflow/model_deployers/mlflow_model_deployer.py
class MLFlowModelDeployer(BaseModelDeployer):
"""MLflow implementation of the BaseModelDeployer."""
_service_path: Optional[str] = None
@property
def config(self) -> MLFlowModelDeployerConfig:
"""Returns the `MLFlowModelDeployerConfig` config.
Returns:
The configuration.
"""
return cast(MLFlowModelDeployerConfig, self._config)
@staticmethod
def get_service_path(id_: UUID) -> str:
"""Get the path where local MLflow service information is stored.
This is where the deployment service configuration, PID and log files
are stored.
Args:
id_: The ID of the MLflow model deployer.
Returns:
The service path.
"""
service_path = os.path.join(
GlobalConfiguration().local_stores_path,
str(id_),
)
create_dir_recursive_if_not_exists(service_path)
return service_path
@property
def local_path(self) -> str:
"""Returns the path to the root directory.
This is where all configurations for MLflow deployment daemon processes
are stored.
If the service path is not set in the config by the user, the path is
set to a local default path according to the component ID.
Returns:
The path to the local service root directory.
"""
if self._service_path is not None:
return self._service_path
if self.config.service_path:
self._service_path = self.config.service_path
else:
self._service_path = self.get_service_path(self.id)
create_dir_recursive_if_not_exists(self._service_path)
return self._service_path
@staticmethod
def get_model_server_info( # type: ignore[override]
service_instance: "MLFlowDeploymentService",
) -> Dict[str, Optional[str]]:
"""Return implementation specific information relevant to the user.
Args:
service_instance: Instance of an MLFlowDeploymentService.
Returns:
A dictionary containing the information.
"""
return {
"PREDICTION_URL": service_instance.endpoint.prediction_url,
"MODEL_URI": service_instance.config.model_uri,
"MODEL_NAME": service_instance.config.model_name,
"SERVICE_PATH": service_instance.status.runtime_path,
"DAEMON_PID": str(service_instance.status.pid),
}
@staticmethod
def get_active_model_deployer() -> "MLFlowModelDeployer":
"""Returns the MLFlowModelDeployer component of the active stack.
Args:
None
Returns:
The MLFlowModelDeployer component of the active stack.
Raises:
TypeError: If the active stack does not contain an MLFlowModelDeployer component.
"""
model_deployer = Client( # type: ignore[call-arg]
skip_client_check=True
).active_stack.model_deployer
if not model_deployer or not isinstance(
model_deployer, MLFlowModelDeployer
):
raise TypeError(
f"The active stack needs to have an MLflow model deployer "
f"component registered to be able to deploy models with MLflow. "
f"You can create a new stack with an MLflow model "
f"deployer component or update your existing stack to add this "
f"component, e.g.:\n\n"
f" 'zenml model-deployer register mlflow --flavor={MLFLOW_MODEL_DEPLOYER_FLAVOR}'\n"
f" 'zenml stack create stack-name -d mlflow ...'\n"
)
return model_deployer
def deploy_model(
self,
config: ServiceConfig,
replace: bool = False,
timeout: int = DEFAULT_SERVICE_START_STOP_TIMEOUT,
) -> BaseService:
"""Create a new MLflow deployment service or update an existing one.
This should serve the supplied model and deployment configuration.
This method has two modes of operation, depending on the `replace`
argument value:
* if `replace` is False, calling this method will create a new MLflow
deployment server to reflect the model and other configuration
parameters specified in the supplied MLflow service `config`.
* if `replace` is True, this method will first attempt to find an
existing MLflow deployment service that is *equivalent* to the
supplied configuration parameters. Two or more MLflow deployment
services are considered equivalent if they have the same
`pipeline_name`, `pipeline_step_name` and `model_name` configuration
parameters. To put it differently, two MLflow deployment services
are equivalent if they serve versions of the same model deployed by
the same pipeline step. If an equivalent MLflow deployment is found,
it will be updated in place to reflect the new configuration
parameters.
Callers should set `replace` to True if they want a continuous model
deployment workflow that doesn't spin up a new MLflow deployment
server for each new model version. If multiple equivalent MLflow
deployment servers are found, one is selected at random to be updated
and the others are deleted.
Args:
config: the configuration of the model to be deployed with MLflow.
replace: set this flag to True to find and update an equivalent
MLflow deployment server with the new model instead of
creating and starting a new deployment server.
timeout: the timeout in seconds to wait for the MLflow server
to be provisioned and successfully started or updated. If set
to 0, the method will return immediately after the MLflow
server is provisioned, without waiting for it to fully start.
Returns:
The ZenML MLflow deployment service object that can be used to
interact with the MLflow model server.
"""
config = cast(MLFlowDeploymentConfig, config)
service = None
# if replace is True, remove all existing services
if replace is True:
existing_services = self.find_model_server(
pipeline_name=config.pipeline_name,
pipeline_step_name=config.pipeline_step_name,
model_name=config.model_name,
)
for existing_service in existing_services:
if service is None:
# keep the most recently created service
service = cast(MLFlowDeploymentService, existing_service)
try:
# delete the older services and don't wait for them to
# be deprovisioned
self._clean_up_existing_service(
existing_service=cast(
MLFlowDeploymentService, existing_service
),
timeout=timeout,
force=True,
)
except RuntimeError:
# ignore errors encountered while stopping old services
pass
if service:
logger.info(
f"Updating an existing MLflow deployment service: {service}"
)
# set the root runtime path with the stack component's UUID
config.root_runtime_path = self.local_path
service.stop(timeout=timeout, force=True)
service.update(config)
service.start(timeout=timeout)
else:
# create a new MLFlowDeploymentService instance
service = self._create_new_service(timeout, config)
logger.info(f"Created a new MLflow deployment service: {service}")
return cast(BaseService, service)
def _clean_up_existing_service(
self,
timeout: int,
force: bool,
existing_service: MLFlowDeploymentService,
) -> None:
# stop the older service
existing_service.stop(timeout=timeout, force=force)
# delete the old configuration file
service_directory_path = existing_service.status.runtime_path or ""
shutil.rmtree(service_directory_path)
# The step will receive a config from the user that mentions the number
# of workers etc. The step implementation will create a new config using
# all values from the user and add values like pipeline name, model_uri.
def _create_new_service(
self, timeout: int, config: MLFlowDeploymentConfig
) -> MLFlowDeploymentService:
"""Creates a new MLFlowDeploymentService.
Args:
timeout: the timeout in seconds to wait for the MLflow server
to be provisioned and successfully started or updated.
config: the configuration of the model to be deployed with MLflow.
Returns:
The MLFlowDeploymentService object that can be used to interact
with the MLflow model server.
"""
# set the root runtime path with the stack component's UUID
config.root_runtime_path = self.local_path
# create a new service for the new model
service = MLFlowDeploymentService(config)
service.start(timeout=timeout)
return service
def find_model_server(
self,
running: bool = False,
service_uuid: Optional[UUID] = None,
pipeline_name: Optional[str] = None,
pipeline_run_id: Optional[str] = None,
pipeline_step_name: Optional[str] = None,
model_name: Optional[str] = None,
model_uri: Optional[str] = None,
model_type: Optional[str] = None,
) -> List[BaseService]:
"""Finds one or more model servers that match the given criteria.
Args:
running: If true, only running services will be returned.
service_uuid: The UUID of the service that was originally used
to deploy the model.
pipeline_name: Name of the pipeline that the deployed model was part
of.
pipeline_run_id: ID of the pipeline run which the deployed model
was part of.
pipeline_step_name: The name of the pipeline model deployment step
that deployed the model.
model_name: Name of the deployed model.
model_uri: URI of the deployed model.
model_type: Type/format of the deployed model. Not used in this
MLflow case.
Returns:
One or more Service objects representing model servers that match
the input search criteria.
Raises:
TypeError: if any of the input arguments are of an invalid type.
"""
services = []
config = MLFlowDeploymentConfig(
model_name=model_name or "",
model_uri=model_uri or "",
pipeline_name=pipeline_name or "",
pipeline_run_id=pipeline_run_id or "",
pipeline_step_name=pipeline_step_name or "",
)
# find all services that match the input criteria
for root, _, files in os.walk(self.local_path):
if service_uuid and Path(root).name != str(service_uuid):
continue
for file in files:
if file == SERVICE_DAEMON_CONFIG_FILE_NAME:
service_config_path = os.path.join(root, file)
logger.debug(
"Loading service daemon configuration from %s",
service_config_path,
)
existing_service_config = None
with open(service_config_path, "r") as f:
existing_service_config = f.read()
existing_service = ServiceRegistry().load_service_from_json(
existing_service_config
)
if not isinstance(
existing_service, MLFlowDeploymentService
):
raise TypeError(
f"Expected service type MLFlowDeploymentService but got "
f"{type(existing_service)} instead"
)
existing_service.update_status()
if self._matches_search_criteria(existing_service, config):
if not running or existing_service.is_running:
services.append(cast(BaseService, existing_service))
return services
def _matches_search_criteria(
self,
existing_service: MLFlowDeploymentService,
config: MLFlowDeploymentConfig,
) -> bool:
"""Returns true if a service matches the input criteria.
If any of the values in the input criteria are None, they are ignored.
This allows listing services just by common pipeline names or step
names, etc.
Args:
existing_service: The materialized Service instance derived from
the config of the older (existing) service
config: The MLFlowDeploymentConfig object passed to the
deploy_model function holding parameters of the new service
to be created.
Returns:
True if the service matches the input criteria.
"""
existing_service_config = existing_service.config
# check if the existing service matches the input criteria
if (
(
not config.pipeline_name
or existing_service_config.pipeline_name == config.pipeline_name
)
and (
not config.model_name
or existing_service_config.model_name == config.model_name
)
and (
not config.pipeline_step_name
or existing_service_config.pipeline_step_name
== config.pipeline_step_name
)
and (
not config.pipeline_run_id
or existing_service_config.pipeline_run_id
== config.pipeline_run_id
)
):
return True
return False
def stop_model_server(
self,
uuid: UUID,
timeout: int = DEFAULT_SERVICE_START_STOP_TIMEOUT,
force: bool = False,
) -> None:
"""Method to stop a model server.
Args:
uuid: UUID of the model server to stop.
timeout: Timeout in seconds to wait for the service to stop.
force: If True, force the service to stop.
"""
# get list of all services
existing_services = self.find_model_server(service_uuid=uuid)
# if the service exists, stop it
if existing_services:
existing_services[0].stop(timeout=timeout, force=force)
def start_model_server(
self, uuid: UUID, timeout: int = DEFAULT_SERVICE_START_STOP_TIMEOUT
) -> None:
"""Method to start a model server.
Args:
uuid: UUID of the model server to start.
timeout: Timeout in seconds to wait for the service to start.
"""
# get list of all services
existing_services = self.find_model_server(service_uuid=uuid)
# if the service exists, start it
if existing_services:
existing_services[0].start(timeout=timeout)
def delete_model_server(
self,
uuid: UUID,
timeout: int = DEFAULT_SERVICE_START_STOP_TIMEOUT,
force: bool = False,
) -> None:
"""Method to delete all configuration of a model server.
Args:
uuid: UUID of the model server to delete.
timeout: Timeout in seconds to wait for the service to stop.
force: If True, force the service to stop.
"""
# get list of all services
existing_services = self.find_model_server(service_uuid=uuid)
# if the service exists, clean it up
if existing_services:
service = cast(MLFlowDeploymentService, existing_services[0])
self._clean_up_existing_service(
existing_service=service, timeout=timeout, force=force
)
config: MLFlowModelDeployerConfig
property
readonly
Returns the `MLFlowModelDeployerConfig` config.

Returns:

| Type | Description |
|---|---|
| MLFlowModelDeployerConfig | The configuration. |
local_path: str
property
readonly
Returns the path to the root directory.
This is where all configurations for MLflow deployment daemon processes are stored.
If the service path is not set in the config by the user, the path is set to a local default path according to the component ID.
Returns:

| Type | Description |
|---|---|
| str | The path to the local service root directory. |
delete_model_server(self, uuid, timeout=10, force=False)
Method to delete all configuration of a model server.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| uuid | UUID | UUID of the model server to delete. | required |
| timeout | int | Timeout in seconds to wait for the service to stop. | 10 |
| force | bool | If True, force the service to stop. | False |
Source code in zenml/integrations/mlflow/model_deployers/mlflow_model_deployer.py
def delete_model_server(
self,
uuid: UUID,
timeout: int = DEFAULT_SERVICE_START_STOP_TIMEOUT,
force: bool = False,
) -> None:
"""Method to delete all configuration of a model server.
Args:
uuid: UUID of the model server to delete.
timeout: Timeout in seconds to wait for the service to stop.
force: If True, force the service to stop.
"""
# get list of all services
existing_services = self.find_model_server(service_uuid=uuid)
# if the service exists, clean it up
if existing_services:
service = cast(MLFlowDeploymentService, existing_services[0])
self._clean_up_existing_service(
existing_service=service, timeout=timeout, force=force
)
deploy_model(self, config, replace=False, timeout=10)
Create a new MLflow deployment service or update an existing one.
This should serve the supplied model and deployment configuration.
This method has two modes of operation, depending on the replace argument value:
- if replace is False, calling this method will create a new MLflow deployment server to reflect the model and other configuration parameters specified in the supplied MLflow service config.
- if replace is True, this method will first attempt to find an existing MLflow deployment service that is equivalent to the supplied configuration parameters. Two or more MLflow deployment services are considered equivalent if they have the same pipeline_name, pipeline_step_name and model_name configuration parameters. To put it differently, two MLflow deployment services are equivalent if they serve versions of the same model deployed by the same pipeline step. If an equivalent MLflow deployment is found, it will be updated in place to reflect the new configuration parameters.
Callers should set replace
to True if they want a continuous model
deployment workflow that doesn't spin up a new MLflow deployment
server for each new model version. If multiple equivalent MLflow
deployment servers are found, one is selected at random to be updated
and the others are deleted.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
config | ServiceConfig | the configuration of the model to be deployed with MLflow. | required |
replace | bool | set this flag to True to find and update an equivalent MLflow deployment server with the new model instead of creating and starting a new deployment server. | False |
timeout | int | the timeout in seconds to wait for the MLflow server to be provisioned and successfully started or updated. If set to 0, the method will return immediately after the MLflow server is provisioned, without waiting for it to fully start. | 10 |
Returns:
Type | Description |
---|---|
BaseService | The ZenML MLflow deployment service object that can be used to interact with the MLflow model server. |
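As an illustrative sketch (not part of the generated reference), a continuous-deployment style call; the import paths follow the source locations quoted below, and the model URI is a placeholder:
from zenml.integrations.mlflow.model_deployers.mlflow_model_deployer import (
    MLFlowModelDeployer,
)
from zenml.integrations.mlflow.services.mlflow_deployment import (
    MLFlowDeploymentConfig,
)

deployer = MLFlowModelDeployer.get_active_model_deployer()
config = MLFlowDeploymentConfig(
    model_name="model",
    model_uri="runs:/<mlflow-run-id>/model",  # placeholder URI
    pipeline_name="training_pipeline",        # used for equivalence matching
    pipeline_step_name="deployer_step",       # used for equivalence matching
    workers=1,
)
# With replace=True, an equivalent server (same pipeline_name,
# pipeline_step_name and model_name) is updated in place instead of a
# new one being created for every model version.
service = deployer.deploy_model(config=config, replace=True, timeout=60)
print(service.is_running)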
Source code in zenml/integrations/mlflow/model_deployers/mlflow_model_deployer.py
def deploy_model(
self,
config: ServiceConfig,
replace: bool = False,
timeout: int = DEFAULT_SERVICE_START_STOP_TIMEOUT,
) -> BaseService:
"""Create a new MLflow deployment service or update an existing one.
This should serve the supplied model and deployment configuration.
This method has two modes of operation, depending on the `replace`
argument value:
* if `replace` is False, calling this method will create a new MLflow
deployment server to reflect the model and other configuration
parameters specified in the supplied MLflow service `config`.
* if `replace` is True, this method will first attempt to find an
existing MLflow deployment service that is *equivalent* to the
supplied configuration parameters. Two or more MLflow deployment
services are considered equivalent if they have the same
`pipeline_name`, `pipeline_step_name` and `model_name` configuration
parameters. To put it differently, two MLflow deployment services
are equivalent if they serve versions of the same model deployed by
the same pipeline step. If an equivalent MLflow deployment is found,
it will be updated in place to reflect the new configuration
parameters.
Callers should set `replace` to True if they want a continuous model
deployment workflow that doesn't spin up a new MLflow deployment
server for each new model version. If multiple equivalent MLflow
deployment servers are found, one is selected at random to be updated
and the others are deleted.
Args:
config: the configuration of the model to be deployed with MLflow.
replace: set this flag to True to find and update an equivalent
MLflow deployment server with the new model instead of
creating and starting a new deployment server.
timeout: the timeout in seconds to wait for the MLflow server
to be provisioned and successfully started or updated. If set
to 0, the method will return immediately after the MLflow
server is provisioned, without waiting for it to fully start.
Returns:
The ZenML MLflow deployment service object that can be used to
interact with the MLflow model server.
"""
config = cast(MLFlowDeploymentConfig, config)
service = None
# if replace is True, remove all existing services
if replace is True:
existing_services = self.find_model_server(
pipeline_name=config.pipeline_name,
pipeline_step_name=config.pipeline_step_name,
model_name=config.model_name,
)
for existing_service in existing_services:
if service is None:
# keep the most recently created service
service = cast(MLFlowDeploymentService, existing_service)
try:
# delete the older services and don't wait for them to
# be deprovisioned
self._clean_up_existing_service(
existing_service=cast(
MLFlowDeploymentService, existing_service
),
timeout=timeout,
force=True,
)
except RuntimeError:
# ignore errors encountered while stopping old services
pass
if service:
logger.info(
f"Updating an existing MLflow deployment service: {service}"
)
# set the root runtime path with the stack component's UUID
config.root_runtime_path = self.local_path
service.stop(timeout=timeout, force=True)
service.update(config)
service.start(timeout=timeout)
else:
# create a new MLFlowDeploymentService instance
service = self._create_new_service(timeout, config)
logger.info(f"Created a new MLflow deployment service: {service}")
return cast(BaseService, service)
find_model_server(self, running=False, service_uuid=None, pipeline_name=None, pipeline_run_id=None, pipeline_step_name=None, model_name=None, model_uri=None, model_type=None)
Finds one or more model servers that match the given criteria.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
running | bool | If true, only running services will be returned. | False |
service_uuid | Optional[uuid.UUID] | The UUID of the service that was originally used to deploy the model. | None |
pipeline_name | Optional[str] | Name of the pipeline that the deployed model was part of. | None |
pipeline_run_id | Optional[str] | ID of the pipeline run which the deployed model was part of. | None |
pipeline_step_name | Optional[str] | The name of the pipeline model deployment step that deployed the model. | None |
model_name | Optional[str] | Name of the deployed model. | None |
model_uri | Optional[str] | URI of the deployed model. | None |
model_type | Optional[str] | Type/format of the deployed model. Not used in this MLflow case. | None |
Returns:
Type | Description |
---|---|
List[zenml.services.service.BaseService] | One or more Service objects representing model servers that match the input search criteria. |
Exceptions:
Type | Description |
---|---|
TypeError | if any of the input arguments are of an invalid type. |
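A short lookup sketch, reusing the deployer handle from the previous example; all filter values are illustrative:
# Only running servers for a given pipeline/step/model combination.
matches = deployer.find_model_server(
    running=True,
    pipeline_name="training_pipeline",
    pipeline_step_name="deployer_step",
    model_name="model",
)
for svc in matches:
    print(svc.uuid, svc.is_running)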
Source code in zenml/integrations/mlflow/model_deployers/mlflow_model_deployer.py
def find_model_server(
self,
running: bool = False,
service_uuid: Optional[UUID] = None,
pipeline_name: Optional[str] = None,
pipeline_run_id: Optional[str] = None,
pipeline_step_name: Optional[str] = None,
model_name: Optional[str] = None,
model_uri: Optional[str] = None,
model_type: Optional[str] = None,
) -> List[BaseService]:
"""Finds one or more model servers that match the given criteria.
Args:
running: If true, only running services will be returned.
service_uuid: The UUID of the service that was originally used
to deploy the model.
pipeline_name: Name of the pipeline that the deployed model was part
of.
pipeline_run_id: ID of the pipeline run which the deployed model
was part of.
pipeline_step_name: The name of the pipeline model deployment step
that deployed the model.
model_name: Name of the deployed model.
model_uri: URI of the deployed model.
model_type: Type/format of the deployed model. Not used in this
MLflow case.
Returns:
One or more Service objects representing model servers that match
the input search criteria.
Raises:
TypeError: if any of the input arguments are of an invalid type.
"""
services = []
config = MLFlowDeploymentConfig(
model_name=model_name or "",
model_uri=model_uri or "",
pipeline_name=pipeline_name or "",
pipeline_run_id=pipeline_run_id or "",
pipeline_step_name=pipeline_step_name or "",
)
# find all services that match the input criteria
for root, _, files in os.walk(self.local_path):
if service_uuid and Path(root).name != str(service_uuid):
continue
for file in files:
if file == SERVICE_DAEMON_CONFIG_FILE_NAME:
service_config_path = os.path.join(root, file)
logger.debug(
"Loading service daemon configuration from %s",
service_config_path,
)
existing_service_config = None
with open(service_config_path, "r") as f:
existing_service_config = f.read()
existing_service = ServiceRegistry().load_service_from_json(
existing_service_config
)
if not isinstance(
existing_service, MLFlowDeploymentService
):
raise TypeError(
f"Expected service type MLFlowDeploymentService but got "
f"{type(existing_service)} instead"
)
existing_service.update_status()
if self._matches_search_criteria(existing_service, config):
if not running or existing_service.is_running:
services.append(cast(BaseService, existing_service))
return services
get_active_model_deployer()
staticmethod
Returns the MLFlowModelDeployer component of the active stack.
Returns:
Type | Description |
---|---|
MLFlowModelDeployer | The MLFlowModelDeployer component of the active stack. |
Exceptions:
Type | Description |
---|---|
TypeError | If the active stack does not contain an MLFlowModelDeployer component. |
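Illustrative usage, assuming an MLflow model deployer is registered in the active stack:
from zenml.integrations.mlflow.model_deployers.mlflow_model_deployer import (
    MLFlowModelDeployer,
)

# Raises TypeError if the active stack has no MLflow model deployer.
model_deployer = MLFlowModelDeployer.get_active_model_deployer()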
Source code in zenml/integrations/mlflow/model_deployers/mlflow_model_deployer.py
@staticmethod
def get_active_model_deployer() -> "MLFlowModelDeployer":
"""Returns the MLFlowModelDeployer component of the active stack.
Args:
None
Returns:
The MLFlowModelDeployer component of the active stack.
Raises:
TypeError: If the active stack does not contain an MLFlowModelDeployer component.
"""
model_deployer = Client( # type: ignore[call-arg]
skip_client_check=True
).active_stack.model_deployer
if not model_deployer or not isinstance(
model_deployer, MLFlowModelDeployer
):
raise TypeError(
f"The active stack needs to have an MLflow model deployer "
f"component registered to be able to deploy models with MLflow. "
f"You can create a new stack with an MLflow model "
f"deployer component or update your existing stack to add this "
f"component, e.g.:\n\n"
f" 'zenml model-deployer register mlflow --flavor={MLFLOW_MODEL_DEPLOYER_FLAVOR}'\n"
f" 'zenml stack create stack-name -d mlflow ...'\n"
)
return model_deployer
get_model_server_info(service_instance)
staticmethod
Return implementation specific information relevant to the user.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
service_instance | MLFlowDeploymentService | Instance of an MLFlowDeploymentService | required |
Returns:
Type | Description |
---|---|
Dict[str, Optional[str]] | A dictionary containing the information. |
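A sketch of inspecting a deployed server, assuming `service` is an MLFlowDeploymentService obtained via find_model_server():
from zenml.integrations.mlflow.model_deployers.mlflow_model_deployer import (
    MLFlowModelDeployer,
)

# `service` obtained e.g. via deployer.find_model_server(...)[0]
info = MLFlowModelDeployer.get_model_server_info(service)
print(info["PREDICTION_URL"])  # the local scoring endpoint
print(info["DAEMON_PID"])      # PID of the local daemon process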
Source code in zenml/integrations/mlflow/model_deployers/mlflow_model_deployer.py
@staticmethod
def get_model_server_info( # type: ignore[override]
service_instance: "MLFlowDeploymentService",
) -> Dict[str, Optional[str]]:
"""Return implementation specific information relevant to the user.
Args:
service_instance: Instance of an MLFlowDeploymentService
Returns:
A dictionary containing the information.
"""
return {
"PREDICTION_URL": service_instance.endpoint.prediction_url,
"MODEL_URI": service_instance.config.model_uri,
"MODEL_NAME": service_instance.config.model_name,
"SERVICE_PATH": service_instance.status.runtime_path,
"DAEMON_PID": str(service_instance.status.pid),
}
get_service_path(id_)
staticmethod
Get the path where local MLflow service information is stored.
This is where the deployment service configuration, PID and log files are stored.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
id_ | UUID | The ID of the MLflow model deployer. | required |
Returns:
Type | Description |
---|---|
str | The service path. |
Source code in zenml/integrations/mlflow/model_deployers/mlflow_model_deployer.py
@staticmethod
def get_service_path(id_: UUID) -> str:
"""Get the path where local MLflow service information is stored.
This is where the deployment service configuration, PID and log files
are stored.
Args:
id_: The ID of the MLflow model deployer.
Returns:
The service path.
"""
service_path = os.path.join(
GlobalConfiguration().local_stores_path,
str(id_),
)
create_dir_recursive_if_not_exists(service_path)
return service_path
start_model_server(self, uuid, timeout=10)
Method to start a model server.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
uuid | UUID | UUID of the model server to start. | required |
timeout | int | Timeout in seconds to wait for the service to start. | 10 |
Source code in zenml/integrations/mlflow/model_deployers/mlflow_model_deployer.py
def start_model_server(
self, uuid: UUID, timeout: int = DEFAULT_SERVICE_START_STOP_TIMEOUT
) -> None:
"""Method to start a model server.
Args:
uuid: UUID of the model server to start.
timeout: Timeout in seconds to wait for the service to start.
"""
# get list of all services
existing_services = self.find_model_server(service_uuid=uuid)
# if the service exists, start it
if existing_services:
existing_services[0].start(timeout=timeout)
stop_model_server(self, uuid, timeout=10, force=False)
Method to stop a model server.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
uuid | UUID | UUID of the model server to stop. | required |
timeout | int | Timeout in seconds to wait for the service to stop. | 10 |
force | bool | If True, force the service to stop. | False |
Source code in zenml/integrations/mlflow/model_deployers/mlflow_model_deployer.py
def stop_model_server(
self,
uuid: UUID,
timeout: int = DEFAULT_SERVICE_START_STOP_TIMEOUT,
force: bool = False,
) -> None:
"""Method to stop a model server.
Args:
uuid: UUID of the model server to stop.
timeout: Timeout in seconds to wait for the service to stop.
force: If True, force the service to stop.
"""
# get list of all services
existing_services = self.find_model_server(service_uuid=uuid)
# if the service exists, stop it
if existing_services:
existing_services[0].stop(timeout=timeout, force=force)
services
special
Initialization of the MLflow Service.
mlflow_deployment
Implementation of the MLflow deployment functionality.
MLFlowDeploymentConfig (LocalDaemonServiceConfig)
pydantic-model
MLflow model deployment configuration.
Attributes:
Name | Type | Description |
---|---|---|
model_uri | str | URI of the MLflow model to serve |
model_name | str | the name of the model |
workers | int | number of workers to use for the prediction service |
mlserver | bool | set to True to use the MLflow MLServer backend (see https://github.com/SeldonIO/MLServer). If False, the MLflow built-in scoring server will be used. |
Source code in zenml/integrations/mlflow/services/mlflow_deployment.py
class MLFlowDeploymentConfig(LocalDaemonServiceConfig):
"""MLflow model deployment configuration.
Attributes:
model_uri: URI of the MLflow model to serve
model_name: the name of the model
workers: number of workers to use for the prediction service
mlserver: set to True to use the MLflow MLServer backend (see
https://github.com/SeldonIO/MLServer). If False, the
MLflow built-in scoring server will be used.
"""
model_uri: str
model_name: str
workers: int = 1
mlserver: bool = False
MLFlowDeploymentEndpoint (LocalDaemonServiceEndpoint)
pydantic-model
A service endpoint exposed by the MLflow deployment daemon.
Attributes:
Name | Type | Description |
---|---|---|
config | MLFlowDeploymentEndpointConfig | service endpoint configuration |
monitor | HTTPEndpointHealthMonitor | optional service endpoint health monitor |
Source code in zenml/integrations/mlflow/services/mlflow_deployment.py
class MLFlowDeploymentEndpoint(LocalDaemonServiceEndpoint):
"""A service endpoint exposed by the MLflow deployment daemon.
Attributes:
config: service endpoint configuration
monitor: optional service endpoint health monitor
"""
config: MLFlowDeploymentEndpointConfig
monitor: HTTPEndpointHealthMonitor
@property
def prediction_url(self) -> Optional[str]:
"""Gets the prediction URL for the endpoint.
Returns:
the prediction URL for the endpoint
"""
uri = self.status.uri
if not uri:
return None
return os.path.join(uri, self.config.prediction_url_path)
prediction_url: Optional[str]
property
readonly
Gets the prediction URL for the endpoint.
Returns:
Type | Description |
---|---|
Optional[str] | the prediction URL for the endpoint |
MLFlowDeploymentEndpointConfig (LocalDaemonServiceEndpointConfig)
pydantic-model
MLflow daemon service endpoint configuration.
Attributes:
Name | Type | Description |
---|---|---|
prediction_url_path | str | URI subpath for prediction requests |
Source code in zenml/integrations/mlflow/services/mlflow_deployment.py
class MLFlowDeploymentEndpointConfig(LocalDaemonServiceEndpointConfig):
"""MLflow daemon service endpoint configuration.
Attributes:
prediction_url_path: URI subpath for prediction requests
"""
prediction_url_path: str
MLFlowDeploymentService (LocalDaemonService)
pydantic-model
MLflow deployment service used to start a local prediction server for MLflow models.
Attributes:
Name | Type | Description |
---|---|---|
SERVICE_TYPE | ClassVar[zenml.services.service_type.ServiceType] | a service type descriptor with information describing the MLflow deployment service class |
config | MLFlowDeploymentConfig | service configuration |
endpoint | MLFlowDeploymentEndpoint | optional service endpoint |
Source code in zenml/integrations/mlflow/services/mlflow_deployment.py
class MLFlowDeploymentService(LocalDaemonService):
"""MLflow deployment service used to start a local prediction server for MLflow models.
Attributes:
SERVICE_TYPE: a service type descriptor with information describing
the MLflow deployment service class
config: service configuration
endpoint: optional service endpoint
"""
SERVICE_TYPE = ServiceType(
name="mlflow-deployment",
type="model-serving",
flavor="mlflow",
description="MLflow prediction service",
)
config: MLFlowDeploymentConfig
endpoint: MLFlowDeploymentEndpoint
def __init__(
self,
config: Union[MLFlowDeploymentConfig, Dict[str, Any]],
**attrs: Any,
) -> None:
"""Initialize the MLflow deployment service.
Args:
config: service configuration
attrs: additional attributes to set on the service
"""
# ensure that the endpoint is created before the service is initialized
# TODO [ENG-700]: implement a service factory or builder for MLflow
# deployment services
if (
isinstance(config, MLFlowDeploymentConfig)
and "endpoint" not in attrs
):
if config.mlserver:
prediction_url_path = MLSERVER_PREDICTION_URL_PATH
healthcheck_uri_path = MLSERVER_HEALTHCHECK_URL_PATH
use_head_request = False
else:
prediction_url_path = MLFLOW_PREDICTION_URL_PATH
healthcheck_uri_path = MLFLOW_HEALTHCHECK_URL_PATH
use_head_request = True
endpoint = MLFlowDeploymentEndpoint(
config=MLFlowDeploymentEndpointConfig(
protocol=ServiceEndpointProtocol.HTTP,
prediction_url_path=prediction_url_path,
),
monitor=HTTPEndpointHealthMonitor(
config=HTTPEndpointHealthMonitorConfig(
healthcheck_uri_path=healthcheck_uri_path,
use_head_request=use_head_request,
)
),
)
attrs["endpoint"] = endpoint
super().__init__(config=config, **attrs)
def run(self) -> None:
"""Start the service."""
logger.info(
"Starting MLflow prediction service as blocking "
"process... press CTRL+C once to stop it."
)
self.endpoint.prepare_for_start()
try:
serve_kwargs: Dict[str, Any] = {}
# MLflow version 1.26 introduces an additional mandatory
# `timeout` argument to the `PyFuncBackend.serve` function
if int(MLFLOW_VERSION.split(".")[1]) >= 26:
serve_kwargs["timeout"] = None
backend = PyFuncBackend(
config={},
no_conda=True,
workers=self.config.workers,
install_mlflow=False,
)
backend.serve(
model_uri=self.config.model_uri,
port=self.endpoint.status.port,
host="localhost",
enable_mlserver=self.config.mlserver,
**serve_kwargs,
)
except KeyboardInterrupt:
logger.info(
"MLflow prediction service stopped. Resuming normal execution."
)
@property
def prediction_url(self) -> Optional[str]:
"""Get the URI where the prediction service is answering requests.
Returns:
The URI where the prediction service can be contacted to process
HTTP/REST inference requests, or None, if the service isn't running.
"""
if not self.is_running:
return None
return self.endpoint.prediction_url
def predict(self, request: "NDArray[Any]") -> "NDArray[Any]":
"""Make a prediction using the service.
Args:
request: a numpy array representing the request
Returns:
A numpy array representing the prediction returned by the service.
Raises:
Exception: if the service is not running
ValueError: if the prediction endpoint is unknown.
"""
if not self.is_running:
raise Exception(
"MLflow prediction service is not running. "
"Please start the service before making predictions."
)
if self.endpoint.prediction_url is not None:
response = requests.post(
self.endpoint.prediction_url,
json={"instances": request.tolist()},
)
else:
raise ValueError("No endpoint known for prediction.")
response.raise_for_status()
return np.array(response.json())
prediction_url: Optional[str]
property
readonly
Get the URI where the prediction service is answering requests.
Returns:
Type | Description |
---|---|
Optional[str] | The URI where the prediction service can be contacted to process HTTP/REST inference requests, or None, if the service isn't running. |
__init__(self, config, **attrs)
special
Initialize the MLflow deployment service.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
config | Union[zenml.integrations.mlflow.services.mlflow_deployment.MLFlowDeploymentConfig, Dict[str, Any]] | service configuration | required |
attrs | Any | additional attributes to set on the service | {} |
Source code in zenml/integrations/mlflow/services/mlflow_deployment.py
def __init__(
self,
config: Union[MLFlowDeploymentConfig, Dict[str, Any]],
**attrs: Any,
) -> None:
"""Initialize the MLflow deployment service.
Args:
config: service configuration
attrs: additional attributes to set on the service
"""
# ensure that the endpoint is created before the service is initialized
# TODO [ENG-700]: implement a service factory or builder for MLflow
# deployment services
if (
isinstance(config, MLFlowDeploymentConfig)
and "endpoint" not in attrs
):
if config.mlserver:
prediction_url_path = MLSERVER_PREDICTION_URL_PATH
healthcheck_uri_path = MLSERVER_HEALTHCHECK_URL_PATH
use_head_request = False
else:
prediction_url_path = MLFLOW_PREDICTION_URL_PATH
healthcheck_uri_path = MLFLOW_HEALTHCHECK_URL_PATH
use_head_request = True
endpoint = MLFlowDeploymentEndpoint(
config=MLFlowDeploymentEndpointConfig(
protocol=ServiceEndpointProtocol.HTTP,
prediction_url_path=prediction_url_path,
),
monitor=HTTPEndpointHealthMonitor(
config=HTTPEndpointHealthMonitorConfig(
healthcheck_uri_path=healthcheck_uri_path,
use_head_request=use_head_request,
)
),
)
attrs["endpoint"] = endpoint
super().__init__(config=config, **attrs)
predict(self, request)
Make a prediction using the service.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
request | NDArray[Any] | a numpy array representing the request | required |
Returns:
Type | Description |
---|---|
NDArray[Any] | A numpy array representing the prediction returned by the service. |
Exceptions:
Type | Description |
---|---|
Exception | if the service is not running |
ValueError | if the prediction endpoint is unknown. |
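For example, a hedged sketch of a prediction call, assuming `service` is a running MLFlowDeploymentService and the model accepts a 2D numeric array:
import numpy as np

request = np.array([[5.1, 3.5, 1.4, 0.2]])  # shape depends on your model
prediction = service.predict(request)
print(prediction)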
Source code in zenml/integrations/mlflow/services/mlflow_deployment.py
def predict(self, request: "NDArray[Any]") -> "NDArray[Any]":
"""Make a prediction using the service.
Args:
request: a numpy array representing the request
Returns:
A numpy array representing the prediction returned by the service.
Raises:
Exception: if the service is not running
ValueError: if the prediction endpoint is unknown.
"""
if not self.is_running:
raise Exception(
"MLflow prediction service is not running. "
"Please start the service before making predictions."
)
if self.endpoint.prediction_url is not None:
response = requests.post(
self.endpoint.prediction_url,
json={"instances": request.tolist()},
)
else:
raise ValueError("No endpoint known for prediction.")
response.raise_for_status()
return np.array(response.json())
run(self)
Start the service.
Source code in zenml/integrations/mlflow/services/mlflow_deployment.py
def run(self) -> None:
"""Start the service."""
logger.info(
"Starting MLflow prediction service as blocking "
"process... press CTRL+C once to stop it."
)
self.endpoint.prepare_for_start()
try:
serve_kwargs: Dict[str, Any] = {}
# MLflow version 1.26 introduces an additional mandatory
# `timeout` argument to the `PyFuncBackend.serve` function
if int(MLFLOW_VERSION.split(".")[1]) >= 26:
serve_kwargs["timeout"] = None
backend = PyFuncBackend(
config={},
no_conda=True,
workers=self.config.workers,
install_mlflow=False,
)
backend.serve(
model_uri=self.config.model_uri,
port=self.endpoint.status.port,
host="localhost",
enable_mlserver=self.config.mlserver,
**serve_kwargs,
)
except KeyboardInterrupt:
logger.info(
"MLflow prediction service stopped. Resuming normal execution."
)
steps
special
Initialization of the MLflow standard interface steps.
mlflow_deployer
Implementation of the MLflow model deployer pipeline step.
MLFlowDeployerParameters (BaseParameters)
pydantic-model
Model deployer step parameters for MLflow.
Attributes:
Name | Type | Description |
---|---|---|
model_name | str | the name of the MLflow model logged in the MLflow artifact store for the current pipeline. |
experiment_name | Optional[str] | Name of the MLflow experiment in which the model was logged. |
run_name | Optional[str] | Name of the MLflow run in which the model was logged. |
workers | int | number of workers to use for the prediction service |
mlserver | bool | set to True to use the MLflow MLServer backend (see https://github.com/SeldonIO/MLServer). If False, the MLflow built-in scoring server will be used. |
timeout | int | the number of seconds to wait for the service to start/stop. |
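A minimal parameters object for the deployer step, with illustrative values; the import path follows the source location quoted below:
from zenml.integrations.mlflow.steps.mlflow_deployer import (
    MLFlowDeployerParameters,
)

params = MLFlowDeployerParameters(
    model_name="model",   # must match the artifact path logged to MLflow
    workers=2,
    mlserver=False,       # use the MLflow built-in scoring server
    timeout=60,
)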
Source code in zenml/integrations/mlflow/steps/mlflow_deployer.py
class MLFlowDeployerParameters(BaseParameters):
"""Model deployer step parameters for MLflow.
Attributes:
model_name: the name of the MLflow model logged in the MLflow artifact
store for the current pipeline.
experiment_name: Name of the MLflow experiment in which the model was
logged.
run_name: Name of the MLflow run in which the model was logged.
workers: number of workers to use for the prediction service
mlserver: set to True to use the MLflow MLServer backend (see
https://github.com/SeldonIO/MLServer). If False, the
MLflow built-in scoring server will be used.
timeout: the number of seconds to wait for the service to start/stop.
"""
model_name: str = "model"
experiment_name: Optional[str] = None
run_name: Optional[str] = None
workers: int = 1
mlserver: bool = False
timeout: int = DEFAULT_SERVICE_START_STOP_TIMEOUT
mlflow_model_deployer_step (BaseStep)
Model deployer pipeline step for MLflow.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
deploy_decision | bool | whether to deploy the model or not | required |
model | ModelArtifact | the model artifact to deploy | required |
params | MLFlowDeployerParameters | parameters for the deployer step | required |
Returns:
Type | Description |
---|---|
MLFlowDeploymentService | MLflow deployment service |
PARAMETERS_CLASS (BaseParameters)
pydantic-model
Model deployer step parameters for MLflow.
Attributes:
Name | Type | Description |
---|---|---|
model_name | str | the name of the MLflow model logged in the MLflow artifact store for the current pipeline. |
experiment_name | Optional[str] | Name of the MLflow experiment in which the model was logged. |
run_name | Optional[str] | Name of the MLflow run in which the model was logged. |
workers | int | number of workers to use for the prediction service |
mlserver | bool | set to True to use the MLflow MLServer backend (see https://github.com/SeldonIO/MLServer). If False, the MLflow built-in scoring server will be used. |
timeout | int | the number of seconds to wait for the service to start/stop. |
Source code in zenml/integrations/mlflow/steps/mlflow_deployer.py
class MLFlowDeployerParameters(BaseParameters):
"""Model deployer step parameters for MLflow.
Attributes:
model_name: the name of the MLflow model logged in the MLflow artifact
store for the current pipeline.
experiment_name: Name of the MLflow experiment in which the model was
logged.
run_name: Name of the MLflow run in which the model was logged.
workers: number of workers to use for the prediction service
mlserver: set to True to use the MLflow MLServer backend (see
https://github.com/SeldonIO/MLServer). If False, the
MLflow built-in scoring server will be used.
timeout: the number of seconds to wait for the service to start/stop.
"""
model_name: str = "model"
experiment_name: Optional[str] = None
run_name: Optional[str] = None
workers: int = 1
mlserver: bool = False
timeout: int = DEFAULT_SERVICE_START_STOP_TIMEOUT
entrypoint(deploy_decision, model, params)
staticmethod
Model deployer pipeline step for MLflow.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
deploy_decision | bool | whether to deploy the model or not | required |
model | ModelArtifact | the model artifact to deploy | required |
params | MLFlowDeployerParameters | parameters for the deployer step | required |
Returns:
Type | Description |
---|---|
MLFlowDeploymentService | MLflow deployment service |
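A loose sketch of wiring this step into a pipeline with the legacy pipeline API; `my_trainer_step` and `my_evaluator_step` are hypothetical user-defined steps, and the exact wiring is an assumption rather than a documented recipe:
from zenml.pipelines import pipeline
from zenml.integrations.mlflow.steps.mlflow_deployer import (
    MLFlowDeployerParameters,
    mlflow_model_deployer_step,
)

@pipeline
def continuous_deployment_pipeline(trainer, evaluator, deployer):
    model = trainer()
    deploy_decision = evaluator(model=model)
    deployer(deploy_decision=deploy_decision, model=model)

pipeline_instance = continuous_deployment_pipeline(
    trainer=my_trainer_step(),      # hypothetical user step
    evaluator=my_evaluator_step(),  # hypothetical user step
    deployer=mlflow_model_deployer_step(
        params=MLFlowDeployerParameters(model_name="model", timeout=60)
    ),
)
pipeline_instance.run()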
Source code in zenml/integrations/mlflow/steps/mlflow_deployer.py
@step(enable_cache=False)
def mlflow_model_deployer_step(
deploy_decision: bool,
model: ModelArtifact,
params: MLFlowDeployerParameters,
) -> MLFlowDeploymentService:
"""Model deployer pipeline step for MLflow.
# noqa: DAR401
Args:
deploy_decision: whether to deploy the model or not
model: the model artifact to deploy
params: parameters for the deployer step
Returns:
MLflow deployment service
"""
model_deployer = MLFlowModelDeployer.get_active_model_deployer()
# fetch the MLflow artifacts logged during the pipeline run
experiment_tracker = Client( # type: ignore[call-arg]
skip_client_check=True
).active_stack.experiment_tracker
if not isinstance(experiment_tracker, MLFlowExperimentTracker):
raise get_missing_mlflow_experiment_tracker_error()
# get pipeline name, step name and run id
step_env = cast(StepEnvironment, Environment()[STEP_ENVIRONMENT_NAME])
pipeline_name = step_env.pipeline_name
run_id = step_env.pipeline_run_id
step_name = step_env.step_name
client = MlflowClient()
mlflow_run_id = experiment_tracker.get_run_id(
experiment_name=params.experiment_name or pipeline_name,
run_name=params.run_name or run_id,
)
model_uri = ""
if mlflow_run_id and client.list_artifacts(
mlflow_run_id, params.model_name
):
model_uri = artifact_utils.get_artifact_uri(
run_id=mlflow_run_id, artifact_path=params.model_name
)
# fetch existing services with same pipeline name, step name and model name
existing_services = model_deployer.find_model_server(
pipeline_name=pipeline_name,
pipeline_step_name=step_name,
model_name=params.model_name,
)
# create a config for the new model service
predictor_cfg = MLFlowDeploymentConfig(
model_name=params.model_name or "",
model_uri=model_uri,
workers=params.workers,
mlserver=params.mlserver,
pipeline_name=pipeline_name,
pipeline_run_id=run_id,
pipeline_step_name=step_name,
)
# Creating a new service with inactive state and status by default
service = MLFlowDeploymentService(predictor_cfg)
if existing_services:
service = cast(MLFlowDeploymentService, existing_services[0])
# check for conditions to deploy the model
if not model_uri:
# an MLflow model was not trained in the current run, so we simply reuse
# the currently running service created for the same model, if any
if not existing_services:
logger.warning(
f"An MLflow model with name `{params.model_name}` was not "
f"logged in the current pipeline run and no running MLflow "
f"model server was found. Please ensure that your pipeline "
f"includes a step with a MLflow experiment configured that "
"trains a model and logs it to MLflow. This could also happen "
"if the current pipeline run did not log an MLflow model "
f"because the training step was cached."
)
# return an inactive service just because we have to return
# something
return service
logger.info(
f"An MLflow model with name `{params.model_name}` was not "
f"trained in the current pipeline run. Reusing the existing "
f"MLflow model server."
)
if not service.is_running:
service.start(params.timeout)
# return the existing service
return service
# even when the deploy decision is negative, if an existing model server
# is not running for this pipeline/step, we still have to serve the
# current model, to ensure that a model server is available at all times
if not deploy_decision and existing_services:
logger.info(
f"Skipping model deployment because the model quality does not "
f"meet the criteria. Reusing last model server deployed by step "
f"'{step_name}' and pipeline '{pipeline_name}' for model "
f"'{params.model_name}'..."
)
# even when the deploy decision is negative, we still need to start
# the previous model server if it is no longer running, to ensure
# that a model server is available at all times
if not service.is_running:
service.start(params.timeout)
return service
# create a new model deployment and replace an old one if it exists
new_service = cast(
MLFlowDeploymentService,
model_deployer.deploy_model(
replace=True,
config=predictor_cfg,
timeout=params.timeout,
),
)
logger.info(
f"MLflow deployment service started and reachable at:\n"
f" {new_service.prediction_url}\n"
)
return new_service
mlflow_deployer_step(enable_cache=True, name=None)
Creates a pipeline step to deploy a given ML model with a local MLflow prediction server.
The returned step can be used in a pipeline to implement continuous deployment for an MLflow model.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
enable_cache | bool | Specify whether caching is enabled for this step. If no value is passed, caching is enabled by default | True |
name | Optional[str] | Name of the step. | None |
Returns:
Type | Description |
---|---|
Type[zenml.steps.base_step.BaseStep] | an MLflow model deployer pipeline step |
Source code in zenml/integrations/mlflow/steps/mlflow_deployer.py
def mlflow_deployer_step(
enable_cache: bool = True,
name: Optional[str] = None,
) -> Type[BaseStep]:
"""Creates a pipeline step to deploy a given ML model with a local MLflow prediction server.
The returned step can be used in a pipeline to implement continuous
deployment for an MLflow model.
Args:
enable_cache: Specify whether caching is enabled for this step. If no
value is passed, caching is enabled by default
name: Name of the step.
Returns:
an MLflow model deployer pipeline step
"""
logger.warning(
"The `mlflow_deployer_step` function is deprecated. Please "
"use the built-in `mlflow_model_deployer_step` step instead."
)
return mlflow_model_deployer_step
neural_prophet
special
Initialization of the Neural Prophet integration.
NeuralProphetIntegration (Integration)
Definition of NeuralProphet integration for ZenML.
Source code in zenml/integrations/neural_prophet/__init__.py
class NeuralProphetIntegration(Integration):
"""Definition of NeuralProphet integration for ZenML."""
NAME = NEURAL_PROPHET
REQUIREMENTS = ["neuralprophet>=0.3.2"]
@classmethod
def activate(cls) -> None:
"""Activates the integration."""
from zenml.integrations.neural_prophet import materializers # noqa
activate()
classmethod
Activates the integration.
Source code in zenml/integrations/neural_prophet/__init__.py
@classmethod
def activate(cls) -> None:
"""Activates the integration."""
from zenml.integrations.neural_prophet import materializers # noqa
materializers
special
Initialization of the Neural Prophet materializer.
neural_prophet_materializer
Implementation of the Neural Prophet materializer.
NeuralProphetMaterializer (BaseMaterializer)
Materializer to read/write NeuralProphet models.
Source code in zenml/integrations/neural_prophet/materializers/neural_prophet_materializer.py
class NeuralProphetMaterializer(BaseMaterializer):
"""Materializer to read/write NeuralProphet models."""
ASSOCIATED_TYPES = (NeuralProphet,)
ASSOCIATED_ARTIFACT_TYPES = (ModelArtifact,)
def handle_input(self, data_type: Type[Any]) -> NeuralProphet:
"""Reads and returns a NeuralProphet model.
Args:
data_type: The type of the model to load.
Returns:
A loaded NeuralProphet model.
"""
super().handle_input(data_type)
return torch.load( # type: ignore[no-untyped-call]
os.path.join(self.artifact.uri, DEFAULT_FILENAME)
) # noqa
def handle_return(self, model: NeuralProphet) -> None:
"""Writes a NeuralProphet model.
Args:
model: A NeuralProphet model object.
"""
super().handle_return(model)
torch.save(model, os.path.join(self.artifact.uri, DEFAULT_FILENAME))
handle_input(self, data_type)
Reads and returns a NeuralProphet model.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
data_type | Type[Any] | The type of the model to load. | required |
Returns:
Type | Description |
---|---|
NeuralProphet | A loaded NeuralProphet model. |
Source code in zenml/integrations/neural_prophet/materializers/neural_prophet_materializer.py
def handle_input(self, data_type: Type[Any]) -> NeuralProphet:
"""Reads and returns a NeuralProphet model.
Args:
data_type: The type of the model to load.
Returns:
A loaded NeuralProphet model.
"""
super().handle_input(data_type)
return torch.load( # type: ignore[no-untyped-call]
os.path.join(self.artifact.uri, DEFAULT_FILENAME)
) # noqa
handle_return(self, model)
Writes a NeuralProphet model.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model | NeuralProphet | A NeuralProphet model object. | required |
Source code in zenml/integrations/neural_prophet/materializers/neural_prophet_materializer.py
def handle_return(self, model: NeuralProphet) -> None:
"""Writes a NeuralProphet model.
Args:
model: A NeuralProphet model object.
"""
super().handle_return(model)
torch.save(model, os.path.join(self.artifact.uri, DEFAULT_FILENAME))
pillow
special
Initialization of the Pillow integration.
PillowIntegration (Integration)
Definition of Pillow integration for ZenML.
Source code in zenml/integrations/pillow/__init__.py
class PillowIntegration(Integration):
"""Definition of Pillow integration for ZenML."""
NAME = PILLOW
REQUIREMENTS = ["Pillow>=9.2.0"]
@classmethod
def activate(cls) -> None:
"""Activates the integration."""
from zenml.integrations.pillow import materializers # noqa
activate()
classmethod
Activates the integration.
Source code in zenml/integrations/pillow/__init__.py
@classmethod
def activate(cls) -> None:
"""Activates the integration."""
from zenml.integrations.pillow import materializers # noqa
materializers
special
Initialization of the Pillow materializer.
pillow_image_materializer
Materializer for Pillow Image objects.
PillowImageMaterializer (BaseMaterializer)
Materializer for Image.Image objects.
This materializer takes a PIL image object and returns a PIL image object. It handles all the source image formats supported by PIL as listed here: https://pillow.readthedocs.io/en/stable/handbook/image-file-formats.html.
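As a sketch, with the Pillow integration active a step can return a PIL image directly and this materializer is picked up automatically; the step below is hypothetical:
from PIL import Image
from zenml.steps import step

@step
def generate_image() -> Image.Image:
    # The returned image is written to the artifact store by
    # PillowImageMaterializer.
    return Image.new("RGB", (64, 64), color="white")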
Source code in zenml/integrations/pillow/materializers/pillow_image_materializer.py
class PillowImageMaterializer(BaseMaterializer):
"""Materializer for Image.Image objects.
This materializer takes a PIL image object and returns a PIL image object.
It handles all the source image formats supported by PIL as listed here:
https://pillow.readthedocs.io/en/stable/handbook/image-file-formats.html.
"""
ASSOCIATED_TYPES = (Image.Image,)
ASSOCIATED_ARTIFACT_TYPES = (DataArtifact,)
def handle_input(self, data_type: Type[Image.Image]) -> Image.Image:
"""Read from artifact store.
Args:
data_type: An Image.Image type.
Returns:
An Image.Image object.
"""
super().handle_input(data_type)
files = io_utils.find_files(
self.artifact.uri, f"{DEFAULT_IMAGE_FILENAME}.*"
)
filepath = [file for file in files if not fileio.isdir(file)][0]
# # FAILING OPTION 1: temporary directory
# # create a temporary folder
temp_dir = tempfile.TemporaryDirectory(prefix="zenml-temp-")
temp_file = os.path.join(
temp_dir.name,
f"{DEFAULT_IMAGE_FILENAME}{os.path.splitext(filepath)[1]}",
)
# copy from artifact store to temporary file
fileio.copy(filepath, temp_file)
return Image.open(temp_file)
def handle_return(self, image: Image.Image) -> None:
"""Write to artifact store.
Args:
image: An Image.Image object.
"""
# # FAILING OPTION 1: temporary directory
super().handle_return(image)
temp_dir = tempfile.TemporaryDirectory(prefix="zenml-temp-")
file_extension = image.format or DEFAULT_IMAGE_EXTENSION
full_filename = f"{DEFAULT_IMAGE_FILENAME}.{file_extension}"
temp_image_path = os.path.join(temp_dir.name, full_filename)
# save the image in a temporary directory
image.save(temp_image_path)
# copy the saved image to the artifact store
artifact_store_path = os.path.join(self.artifact.uri, full_filename)
io_utils.copy(temp_image_path, artifact_store_path, overwrite=True) # type: ignore[attr-defined]
temp_dir.cleanup()
handle_input(self, data_type)
Read from artifact store.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
data_type | Type[PIL.Image.Image] | An Image.Image type. | required |
Returns:
Type | Description |
---|---|
Image | An Image.Image object. |
Source code in zenml/integrations/pillow/materializers/pillow_image_materializer.py
def handle_input(self, data_type: Type[Image.Image]) -> Image.Image:
"""Read from artifact store.
Args:
data_type: An Image.Image type.
Returns:
An Image.Image object.
"""
super().handle_input(data_type)
files = io_utils.find_files(
self.artifact.uri, f"{DEFAULT_IMAGE_FILENAME}.*"
)
filepath = [file for file in files if not fileio.isdir(file)][0]
# # FAILING OPTION 1: temporary directory
# # create a temporary folder
temp_dir = tempfile.TemporaryDirectory(prefix="zenml-temp-")
temp_file = os.path.join(
temp_dir.name,
f"{DEFAULT_IMAGE_FILENAME}{os.path.splitext(filepath)[1]}",
)
# copy from artifact store to temporary file
fileio.copy(filepath, temp_file)
return Image.open(temp_file)
handle_return(self, image)
Write to artifact store.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
image | Image | An Image.Image object. | required |
Source code in zenml/integrations/pillow/materializers/pillow_image_materializer.py
def handle_return(self, image: Image.Image) -> None:
"""Write to artifact store.
Args:
image: An Image.Image object.
"""
# # FAILING OPTION 1: temporary directory
super().handle_return(image)
temp_dir = tempfile.TemporaryDirectory(prefix="zenml-temp-")
file_extension = image.format or DEFAULT_IMAGE_EXTENSION
full_filename = f"{DEFAULT_IMAGE_FILENAME}.{file_extension}"
temp_image_path = os.path.join(temp_dir.name, full_filename)
# save the image in a temporary directory
image.save(temp_image_path)
# copy the saved image to the artifact store
artifact_store_path = os.path.join(self.artifact.uri, full_filename)
io_utils.copy(temp_image_path, artifact_store_path, overwrite=True) # type: ignore[attr-defined]
temp_dir.cleanup()
plotly
special
Initialization of the Plotly integration.
PlotlyIntegration (Integration)
Definition of Plotly integration for ZenML.
Source code in zenml/integrations/plotly/__init__.py
class PlotlyIntegration(Integration):
"""Definition of Plotly integration for ZenML."""
NAME = PLOTLY
REQUIREMENTS = ["plotly>=5.4.0"]
visualizers
special
Initialization of the Plotly Visualizer.
pipeline_lineage_visualizer
Implementation of the Plotly Pipeline Lineage Visualizer.
PipelineLineageVisualizer (BaseVisualizer)
Visualize the lineage of runs in a pipeline using plotly.
Source code in zenml/integrations/plotly/visualizers/pipeline_lineage_visualizer.py
class PipelineLineageVisualizer(BaseVisualizer):
"""Visualize the lineage of runs in a pipeline using plotly."""
@abstractmethod
def visualize(
self, object: PipelineView, *args: Any, **kwargs: Any
) -> Figure:
"""Creates a pipeline lineage diagram using plotly.
Args:
object: The pipeline view to visualize.
*args: Additional arguments to pass to the visualization.
**kwargs: Additional keyword arguments to pass to the visualization.
Returns:
A plotly figure.
"""
logger.warning(
"This integration is not completed yet. Results might be unexpected."
)
category_dict = {}
dimensions = ["run"]
for run in object.runs:
category_dict[run.name] = {"run": run.name}
for step in run.steps:
category_dict[run.name].update(
{
step.entrypoint_name: str(step.id),
}
)
if step.entrypoint_name not in dimensions:
dimensions.append(f"{step.entrypoint_name}")
category_df = pd.DataFrame.from_dict(category_dict, orient="index")
category_df = category_df.reset_index()
fig = px.parallel_categories(
category_df,
dimensions,
color=None,
labels="status",
)
fig.show()
return fig
visualize(self, object, *args, **kwargs)
Creates a pipeline lineage diagram using plotly.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
object | PipelineView | The pipeline view to visualize. | required |
*args | Any | Additional arguments to pass to the visualization. | () |
**kwargs | Any | Additional keyword arguments to pass to the visualization. | {} |
Returns:
Type | Description |
---|---|
Figure | A plotly figure. |
Source code in zenml/integrations/plotly/visualizers/pipeline_lineage_visualizer.py
@abstractmethod
def visualize(
self, object: PipelineView, *args: Any, **kwargs: Any
) -> Figure:
"""Creates a pipeline lineage diagram using plotly.
Args:
object: The pipeline view to visualize.
*args: Additional arguments to pass to the visualization.
**kwargs: Additional keyword arguments to pass to the visualization.
Returns:
A plotly figure.
"""
logger.warning(
"This integration is not completed yet. Results might be unexpected."
)
category_dict = {}
dimensions = ["run"]
for run in object.runs:
category_dict[run.name] = {"run": run.name}
for step in run.steps:
category_dict[run.name].update(
{
step.entrypoint_name: str(step.id),
}
)
if step.entrypoint_name not in dimensions:
dimensions.append(f"{step.entrypoint_name}")
category_df = pd.DataFrame.from_dict(category_dict, orient="index")
category_df = category_df.reset_index()
fig = px.parallel_categories(
category_df,
dimensions,
color=None,
labels="status",
)
fig.show()
return fig
pytorch
special
Initialization of the PyTorch integration.
PytorchIntegration (Integration)
Definition of PyTorch integration for ZenML.
Source code in zenml/integrations/pytorch/__init__.py
class PytorchIntegration(Integration):
"""Definition of PyTorch integration for ZenML."""
NAME = PYTORCH
REQUIREMENTS = ["torch"]
@classmethod
def activate(cls) -> None:
"""Activates the integration."""
from zenml.integrations.pytorch import materializers # noqa
activate()
classmethod
Activates the integration.
Source code in zenml/integrations/pytorch/__init__.py
@classmethod
def activate(cls) -> None:
"""Activates the integration."""
from zenml.integrations.pytorch import materializers # noqa
materializers
special
Initialization of the PyTorch Materializer.
pytorch_dataloader_materializer
Implementation of the PyTorch DataLoader materializer.
PyTorchDataLoaderMaterializer (BaseMaterializer)
Materializer to read/write PyTorch dataloaders.
Source code in zenml/integrations/pytorch/materializers/pytorch_dataloader_materializer.py
class PyTorchDataLoaderMaterializer(BaseMaterializer):
"""Materializer to read/write PyTorch dataloaders."""
ASSOCIATED_TYPES = (DataLoader,)
ASSOCIATED_ARTIFACT_TYPES = (DataArtifact,)
def handle_input(self, data_type: Type[Any]) -> DataLoader[Any]:
"""Reads and returns a PyTorch dataloader.
Args:
data_type: The type of the dataloader to load.
Returns:
A loaded PyTorch dataloader.
"""
super().handle_input(data_type)
with fileio.open(
os.path.join(self.artifact.uri, DEFAULT_FILENAME), "rb"
) as f:
return cast(DataLoader[Any], torch.load(f)) # type: ignore[no-untyped-call] # noqa
def handle_return(self, dataloader: DataLoader[Any]) -> None:
"""Writes a PyTorch dataloader.
Args:
dataloader: A torch.utils.DataLoader or a dict to pass into dataloader.save
"""
super().handle_return(dataloader)
# Save entire dataloader to artifact directory
with fileio.open(
os.path.join(self.artifact.uri, DEFAULT_FILENAME), "wb"
) as f:
torch.save(dataloader, f)
handle_input(self, data_type)
Reads and returns a PyTorch dataloader.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
data_type | Type[Any] | The type of the dataloader to load. | required |
Returns:
Type | Description |
---|---|
torch.utils.data.dataloader.DataLoader[Any] | A loaded PyTorch dataloader. |
Source code in zenml/integrations/pytorch/materializers/pytorch_dataloader_materializer.py
def handle_input(self, data_type: Type[Any]) -> DataLoader[Any]:
"""Reads and returns a PyTorch dataloader.
Args:
data_type: The type of the dataloader to load.
Returns:
A loaded PyTorch dataloader.
"""
super().handle_input(data_type)
with fileio.open(
os.path.join(self.artifact.uri, DEFAULT_FILENAME), "rb"
) as f:
return cast(DataLoader[Any], torch.load(f)) # type: ignore[no-untyped-call] # noqa
handle_return(self, dataloader)
Writes a PyTorch dataloader.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
dataloader | torch.utils.data.dataloader.DataLoader[Any] | A torch.utils.DataLoader or a dict to pass into dataloader.save | required |
Source code in zenml/integrations/pytorch/materializers/pytorch_dataloader_materializer.py
def handle_return(self, dataloader: DataLoader[Any]) -> None:
"""Writes a PyTorch dataloader.
Args:
dataloader: A torch.utils.DataLoader or a dict to pass into dataloader.save
"""
super().handle_return(dataloader)
# Save entire dataloader to artifact directory
with fileio.open(
os.path.join(self.artifact.uri, DEFAULT_FILENAME), "wb"
) as f:
torch.save(dataloader, f)
pytorch_module_materializer
Implementation of the PyTorch Module materializer.
PyTorchModuleMaterializer (BaseMaterializer)
Materializer to read/write Pytorch models.
Inspired by the guide: https://pytorch.org/tutorials/beginner/saving_loading_models.html
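Illustrative sketch: a training step can return a torch.nn.Module and have both the full model and its state_dict checkpoint persisted by this materializer; the step is hypothetical:
import torch.nn as nn
from zenml.steps import step

@step
def train_model() -> nn.Module:
    # Returned module is saved as a full model plus a state_dict checkpoint.
    return nn.Linear(4, 1)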
Source code in zenml/integrations/pytorch/materializers/pytorch_module_materializer.py
class PyTorchModuleMaterializer(BaseMaterializer):
"""Materializer to read/write Pytorch models.
Inspired by the guide:
https://pytorch.org/tutorials/beginner/saving_loading_models.html
"""
ASSOCIATED_TYPES = (Module,)
ASSOCIATED_ARTIFACT_TYPES = (ModelArtifact,)
def handle_input(self, data_type: Type[Any]) -> Module:
"""Reads and returns a PyTorch model.
Only loads the model, not the checkpoint.
Args:
data_type: The type of the model to load.
Returns:
A loaded pytorch model.
"""
super().handle_input(data_type)
with fileio.open(
os.path.join(self.artifact.uri, DEFAULT_FILENAME), "rb"
) as f:
return torch.load(f) # type: ignore[no-untyped-call] # noqa
def handle_return(self, model: Module) -> None:
"""Writes a PyTorch model, as a model and a checkpoint.
Args:
model: A torch.nn.Module or a dict to pass into model.save
"""
super().handle_return(model)
# Save entire model to artifact directory, This is the default behavior
# for loading model in development phase (training, evaluation)
with fileio.open(
os.path.join(self.artifact.uri, DEFAULT_FILENAME), "wb"
) as f:
torch.save(model, f)
# Also save model checkpoint to artifact directory,
# This is the default behavior for loading model in production phase (inference)
if isinstance(model, Module):
with fileio.open(
os.path.join(self.artifact.uri, CHECKPOINT_FILENAME), "wb"
) as f:
torch.save(model.state_dict(), f)
handle_input(self, data_type)
Reads and returns a PyTorch model.
Only loads the model, not the checkpoint.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
data_type | Type[Any] | The type of the model to load. | required |
Returns:
Type | Description |
---|---|
Module | A loaded pytorch model. |
Source code in zenml/integrations/pytorch/materializers/pytorch_module_materializer.py
def handle_input(self, data_type: Type[Any]) -> Module:
"""Reads and returns a PyTorch model.
Only loads the model, not the checkpoint.
Args:
data_type: The type of the model to load.
Returns:
A loaded pytorch model.
"""
super().handle_input(data_type)
with fileio.open(
os.path.join(self.artifact.uri, DEFAULT_FILENAME), "rb"
) as f:
return torch.load(f) # type: ignore[no-untyped-call] # noqa
handle_return(self, model)
Writes a PyTorch model, as a model and a checkpoint.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model | Module | A torch.nn.Module or a dict to pass into model.save | required |
Source code in zenml/integrations/pytorch/materializers/pytorch_module_materializer.py
def handle_return(self, model: Module) -> None:
"""Writes a PyTorch model, as a model and a checkpoint.
Args:
model: A torch.nn.Module or a dict to pass into model.save
"""
super().handle_return(model)
# Save entire model to artifact directory, This is the default behavior
# for loading model in development phase (training, evaluation)
with fileio.open(
os.path.join(self.artifact.uri, DEFAULT_FILENAME), "wb"
) as f:
torch.save(model, f)
# Also save model checkpoint to artifact directory,
# This is the default behavior for loading model in production phase (inference)
if isinstance(model, Module):
with fileio.open(
os.path.join(self.artifact.uri, CHECKPOINT_FILENAME), "wb"
) as f:
torch.save(model.state_dict(), f)
pytorch_lightning
special
Initialization of the PyTorch Lightning integration.
PytorchLightningIntegration (Integration)
Definition of PyTorch Lightning integration for ZenML.
Source code in zenml/integrations/pytorch_lightning/__init__.py
class PytorchLightningIntegration(Integration):
"""Definition of PyTorch Lightning integration for ZenML."""
NAME = PYTORCH_L
REQUIREMENTS = ["pytorch_lightning"]
@classmethod
def activate(cls) -> None:
"""Activates the integration."""
from zenml.integrations.pytorch_lightning import materializers # noqa
activate()
classmethod
Activates the integration.
Source code in zenml/integrations/pytorch_lightning/__init__.py
@classmethod
def activate(cls) -> None:
"""Activates the integration."""
from zenml.integrations.pytorch_lightning import materializers # noqa
materializers
special
Initialization of the PyTorch Lightning Materializer.
pytorch_lightning_materializer
Implementation of the PyTorch Lightning Materializer.
PyTorchLightningMaterializer (BaseMaterializer)
Materializer to read/write PyTorch models.
Source code in zenml/integrations/pytorch_lightning/materializers/pytorch_lightning_materializer.py
class PyTorchLightningMaterializer(BaseMaterializer):
"""Materializer to read/write PyTorch models."""
ASSOCIATED_TYPES = (Trainer,)
ASSOCIATED_ARTIFACT_TYPES = (ModelArtifact,)
def handle_input(self, data_type: Type[Any]) -> Trainer:
"""Reads and returns a PyTorch Lightning trainer.
Args:
data_type: The type of the trainer to load.
Returns:
A PyTorch Lightning trainer object.
"""
super().handle_input(data_type)
return Trainer(
resume_from_checkpoint=os.path.join(
self.artifact.uri, CHECKPOINT_NAME
)
)
def handle_return(self, trainer: Trainer) -> None:
"""Writes a PyTorch Lightning trainer.
Args:
trainer: A PyTorch Lightning trainer object.
"""
super().handle_return(trainer)
trainer.save_checkpoint(
os.path.join(self.artifact.uri, CHECKPOINT_NAME)
)
handle_input(self, data_type)
Reads and returns a PyTorch Lightning trainer.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
data_type | Type[Any] | The type of the trainer to load. | required |
Returns:
Type | Description |
---|---|
Trainer | A PyTorch Lightning trainer object. |
Source code in zenml/integrations/pytorch_lightning/materializers/pytorch_lightning_materializer.py
def handle_input(self, data_type: Type[Any]) -> Trainer:
"""Reads and returns a PyTorch Lightning trainer.
Args:
data_type: The type of the trainer to load.
Returns:
A PyTorch Lightning trainer object.
"""
super().handle_input(data_type)
return Trainer(
resume_from_checkpoint=os.path.join(
self.artifact.uri, CHECKPOINT_NAME
)
)
handle_return(self, trainer)
Writes a PyTorch Lightning trainer.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
trainer | Trainer | A PyTorch Lightning trainer object. | required |
Source code in zenml/integrations/pytorch_lightning/materializers/pytorch_lightning_materializer.py
def handle_return(self, trainer: Trainer) -> None:
"""Writes a PyTorch Lightning trainer.
Args:
trainer: A PyTorch Lightning trainer object.
"""
super().handle_return(trainer)
trainer.save_checkpoint(
os.path.join(self.artifact.uri, CHECKPOINT_NAME)
)
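Because `Trainer` is listed in `ASSOCIATED_TYPES`, any pipeline step that returns a `Trainer` is persisted through this materializer. A hedged sketch (the `@step` usage and step body are illustrative, assuming a configured artifact store):
```
from pytorch_lightning import Trainer
from zenml.steps import step

@step
def train() -> Trainer:
    """Return the trainer so its checkpoint lands in the artifact store."""
    trainer = Trainer(max_epochs=1)
    # trainer.fit(model, datamodule) would run here in a real pipeline
    return trainer  # persisted via trainer.save_checkpoint() by the materializer
```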
registry
Implementation of a registry to track ZenML integrations.
IntegrationRegistry
Registry to keep track of ZenML Integrations.
Source code in zenml/integrations/registry.py
class IntegrationRegistry(object):
"""Registry to keep track of ZenML Integrations."""
def __init__(self) -> None:
"""Initializing the integration registry."""
self._integrations: Dict[str, Type["Integration"]] = {}
@property
def integrations(self) -> Dict[str, Type["Integration"]]:
"""Method to get integrations dictionary.
Returns:
A dict of integration key to type of `Integration`.
"""
return self._integrations
@integrations.setter
def integrations(self, i: Any) -> None:
"""Setter method for the integrations property.
Args:
i: Value to set the integrations property to.
Raises:
IntegrationError: If you try to manually set the integrations property.
"""
raise IntegrationError(
"Please do not manually change the integrations within the "
"registry. If you would like to register a new integration "
"manually, please use "
"`integration_registry.register_integration()`."
)
def register_integration(
self, key: str, type_: Type["Integration"]
) -> None:
"""Method to register an integration with a given name.
Args:
key: Name of the integration.
type_: Type of the integration.
"""
self._integrations[key] = type_
def activate_integrations(self) -> None:
"""Method to activate the integrations with are registered in the registry."""
for name, integration in self._integrations.items():
if integration.check_installation():
integration.activate()
logger.debug(f"Integration `{name}` is activated.")
else:
logger.debug(f"Integration `{name}` could not be activated.")
@property
def list_integration_names(self) -> List[str]:
"""Get a list of all possible integrations.
Returns:
A list of all possible integrations.
"""
return [name for name in self._integrations]
def select_integration_requirements(
self, integration_name: Optional[str] = None
) -> List[str]:
"""Select the requirements for a given integration or all integrations.
Args:
integration_name: Name of the integration to check.
Returns:
List of requirements for the integration.
Raises:
KeyError: If the integration is not found.
"""
if integration_name:
if integration_name in self.list_integration_names:
return self._integrations[integration_name].REQUIREMENTS
else:
raise KeyError(
f"Version {integration_name} does not exist. "
f"Currently the following integrations are implemented. "
f"{self.list_integration_names}"
)
else:
return [
requirement
for name in self.list_integration_names
for requirement in self._integrations[name].REQUIREMENTS
]
def is_installed(self, integration_name: Optional[str] = None) -> bool:
"""Checks if all requirements for an integration are installed.
Args:
integration_name: Name of the integration to check.
Returns:
True if all requirements are installed, False otherwise.
Raises:
KeyError: If the integration is not found.
"""
if integration_name in self.list_integration_names:
return self._integrations[integration_name].check_installation()
elif not integration_name:
all_installed = [
self._integrations[item].check_installation()
for item in self.list_integration_names
]
return all(all_installed)
else:
raise KeyError(
f"Integration '{integration_name}' not found. "
f"Currently the following integrations are available: "
f"{self.list_integration_names}"
)
def get_installed_integrations(self) -> List[str]:
"""Returns list of installed integrations.
Returns:
List of installed integrations.
"""
return [
name
for name, integration in integration_registry.integrations.items()
if integration.check_installation()
]
integrations: Dict[str, Type[Integration]]
property
writable
Method to get integrations dictionary.
Returns:
Type | Description |
---|---|
Dict[str, Type[Integration]] |
A dict of integration key to type of `Integration`. |
list_integration_names: List[str]
property
readonly
Get a list of all possible integrations.
Returns:
Type | Description |
---|---|
List[str] |
A list of all possible integrations. |
__init__(self)
special
Initializing the integration registry.
Source code in zenml/integrations/registry.py
def __init__(self) -> None:
"""Initializing the integration registry."""
self._integrations: Dict[str, Type["Integration"]] = {}
activate_integrations(self)
Method to activate the integrations which are registered in the registry.
Source code in zenml/integrations/registry.py
def activate_integrations(self) -> None:
"""Method to activate the integrations with are registered in the registry."""
for name, integration in self._integrations.items():
if integration.check_installation():
integration.activate()
logger.debug(f"Integration `{name}` is activated.")
else:
logger.debug(f"Integration `{name}` could not be activated.")
get_installed_integrations(self)
Returns list of installed integrations.
Returns:
Type | Description |
---|---|
List[str] |
List of installed integrations. |
Source code in zenml/integrations/registry.py
def get_installed_integrations(self) -> List[str]:
"""Returns list of installed integrations.
Returns:
List of installed integrations.
"""
return [
name
for name, integration in integration_registry.integrations.items()
if integration.check_installation()
]
is_installed(self, integration_name=None)
Checks if all requirements for an integration are installed.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
integration_name |
Optional[str] |
Name of the integration to check. |
None |
Returns:
Type | Description |
---|---|
bool |
True if all requirements are installed, False otherwise. |
Exceptions:
Type | Description |
---|---|
KeyError |
If the integration is not found. |
Source code in zenml/integrations/registry.py
def is_installed(self, integration_name: Optional[str] = None) -> bool:
"""Checks if all requirements for an integration are installed.
Args:
integration_name: Name of the integration to check.
Returns:
True if all requirements are installed, False otherwise.
Raises:
KeyError: If the integration is not found.
"""
if integration_name in self.list_integration_names:
return self._integrations[integration_name].check_installation()
elif not integration_name:
all_installed = [
self._integrations[item].check_installation()
for item in self.list_integration_names
]
return all(all_installed)
else:
raise KeyError(
f"Integration '{integration_name}' not found. "
f"Currently the following integrations are available: "
f"{self.list_integration_names}"
)
register_integration(self, key, type_)
Method to register an integration with a given name.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
key |
str |
Name of the integration. |
required |
type_ |
Type[Integration] |
Type of the integration. |
required |
Source code in zenml/integrations/registry.py
def register_integration(
self, key: str, type_: Type["Integration"]
) -> None:
"""Method to register an integration with a given name.
Args:
key: Name of the integration.
type_: Type of the integration.
"""
self._integrations[key] = type_
select_integration_requirements(self, integration_name=None)
Select the requirements for a given integration or all integrations.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
integration_name |
Optional[str] |
Name of the integration to check. |
None |
Returns:
Type | Description |
---|---|
List[str] |
List of requirements for the integration. |
Exceptions:
Type | Description |
---|---|
KeyError |
If the integration is not found. |
Source code in zenml/integrations/registry.py
def select_integration_requirements(
self, integration_name: Optional[str] = None
) -> List[str]:
"""Select the requirements for a given integration or all integrations.
Args:
integration_name: Name of the integration to check.
Returns:
List of requirements for the integration.
Raises:
KeyError: If the integration is not found.
"""
if integration_name:
if integration_name in self.list_integration_names:
return self._integrations[integration_name].REQUIREMENTS
else:
raise KeyError(
f"Version {integration_name} does not exist. "
f"Currently the following integrations are implemented. "
f"{self.list_integration_names}"
)
else:
return [
requirement
for name in self.list_integration_names
for requirement in self._integrations[name].REQUIREMENTS
]
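A short usage sketch of the registry's query methods, using the module-level `integration_registry` singleton referenced in the source above:
```
from zenml.integrations.registry import integration_registry

# All integrations ZenML knows about vs. those actually installed.
print(integration_registry.list_integration_names)
print(integration_registry.get_installed_integrations())

# Requirements for a single integration ("s3" is just an example key).
if integration_registry.is_installed("s3"):
    print(integration_registry.select_integration_requirements("s3"))
```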
s3
special
Initialization of the S3 integration.
The S3 integration allows the use of cloud artifact stores and file operations on S3 buckets.
S3Integration (Integration)
Definition of S3 integration for ZenML.
Source code in zenml/integrations/s3/__init__.py
class S3Integration(Integration):
"""Definition of S3 integration for ZenML."""
NAME = S3
REQUIREMENTS = ["s3fs==2022.3.0"]
@classmethod
def flavors(cls) -> List[Type[Flavor]]:
"""Declare the stack component flavors for the s3 integration.
Returns:
List of stack component flavors for this integration.
"""
from zenml.integrations.s3.flavors import S3ArtifactStoreFlavor
return [S3ArtifactStoreFlavor]
flavors()
classmethod
Declare the stack component flavors for the s3 integration.
Returns:
Type | Description |
---|---|
List[Type[zenml.stack.flavor.Flavor]] |
List of stack component flavors for this integration. |
Source code in zenml/integrations/s3/__init__.py
@classmethod
def flavors(cls) -> List[Type[Flavor]]:
"""Declare the stack component flavors for the s3 integration.
Returns:
List of stack component flavors for this integration.
"""
from zenml.integrations.s3.flavors import S3ArtifactStoreFlavor
return [S3ArtifactStoreFlavor]
artifact_stores
special
Initialization of the S3 Artifact Store.
s3_artifact_store
Implementation of the S3 Artifact Store.
S3ArtifactStore (BaseArtifactStore, AuthenticationMixin)
Artifact Store for S3 based artifacts.
Source code in zenml/integrations/s3/artifact_stores/s3_artifact_store.py
class S3ArtifactStore(BaseArtifactStore, AuthenticationMixin):
"""Artifact Store for S3 based artifacts."""
_filesystem: Optional[s3fs.S3FileSystem] = None
@property
def config(self) -> S3ArtifactStoreConfig:
"""Get the config of this artifact store.
Returns:
The config of this artifact store.
"""
return cast(S3ArtifactStoreConfig, self._config)
def _get_credentials(
self,
) -> Tuple[Optional[str], Optional[str], Optional[str]]:
"""Gets authentication credentials.
If an authentication secret is configured, the secret values are
returned. Otherwise, we fall back to the plain text component
attributes.
Returns:
Tuple (key, secret, token) of credentials used to authenticate with
the S3 filesystem.
"""
secret = self.get_authentication_secret(
expected_schema_type=AWSSecretSchema
)
if secret:
return (
secret.aws_access_key_id,
secret.aws_secret_access_key,
secret.aws_session_token,
)
else:
return self.config.key, self.config.secret, self.config.token
@property
def filesystem(self) -> s3fs.S3FileSystem:
"""The s3 filesystem to access this artifact store.
Returns:
The s3 filesystem.
"""
if not self._filesystem:
key, secret, token = self._get_credentials()
self._filesystem = s3fs.S3FileSystem(
key=key,
secret=secret,
token=token,
client_kwargs=self.config.client_kwargs,
config_kwargs=self.config.config_kwargs,
s3_additional_kwargs=self.config.s3_additional_kwargs,
)
return self._filesystem
def open(self, path: PathType, mode: str = "r") -> Any:
"""Open a file at the given path.
Args:
path: Path of the file to open.
mode: Mode in which to open the file. Currently, only
'rb' and 'wb' to read and write binary files are supported.
Returns:
A file-like object.
"""
return self.filesystem.open(path=path, mode=mode)
def copyfile(
self, src: PathType, dst: PathType, overwrite: bool = False
) -> None:
"""Copy a file.
Args:
src: The path to copy from.
dst: The path to copy to.
overwrite: If a file already exists at the destination, this
method will overwrite it if overwrite=`True` and
raise a FileExistsError otherwise.
Raises:
FileExistsError: If a file already exists at the destination
and overwrite is not set to `True`.
"""
if not overwrite and self.filesystem.exists(dst):
raise FileExistsError(
f"Unable to copy to destination '{convert_to_str(dst)}', "
f"file already exists. Set `overwrite=True` to copy anyway."
)
# TODO [ENG-151]: Check if it works with overwrite=True or if we need to
# manually remove it first
self.filesystem.copy(path1=src, path2=dst)
def exists(self, path: PathType) -> bool:
"""Check whether a path exists.
Args:
path: The path to check.
Returns:
True if the path exists, False otherwise.
"""
return self.filesystem.exists(path=path) # type: ignore[no-any-return]
def glob(self, pattern: PathType) -> List[PathType]:
"""Return all paths that match the given glob pattern.
The glob pattern may include:
- '*' to match any number of characters
- '?' to match a single character
- '[...]' to match one of the characters inside the brackets
- '**' as the full name of a path component to match to search
in subdirectories of any depth (e.g. '/some_dir/**/some_file')
Args:
pattern: The glob pattern to match, see details above.
Returns:
A list of paths that match the given glob pattern.
"""
return [f"s3://{path}" for path in self.filesystem.glob(path=pattern)]
def isdir(self, path: PathType) -> bool:
"""Check whether a path is a directory.
Args:
path: The path to check.
Returns:
True if the path is a directory, False otherwise.
"""
return self.filesystem.isdir(path=path) # type: ignore[no-any-return]
def listdir(self, path: PathType) -> List[PathType]:
"""Return a list of files in a directory.
Args:
path: The path to list.
Returns:
A list of paths that are files in the given directory.
"""
# remove s3 prefix if given, so we can remove the directory later as
# this method is expected to only return filenames
path = convert_to_str(path)
if path.startswith("s3://"):
path = path[5:]
def _extract_basename(file_dict: Dict[str, Any]) -> str:
"""Extracts the basename from a file info dict returned by the S3 filesystem.
Args:
file_dict: A file info dict returned by the S3 filesystem.
Returns:
The basename of the file.
"""
file_path = cast(str, file_dict["Key"])
base_name = file_path[len(path) :]
return base_name.lstrip("/")
return [
_extract_basename(dict_)
for dict_ in self.filesystem.listdir(path=path)
# s3fs.listdir also returns the root directory, so we filter
# it out here
if _extract_basename(dict_)
]
def makedirs(self, path: PathType) -> None:
"""Create a directory at the given path.
If needed also create missing parent directories.
Args:
path: The path to create.
"""
self.filesystem.makedirs(path=path, exist_ok=True)
def mkdir(self, path: PathType) -> None:
"""Create a directory at the given path.
Args:
path: The path to create.
"""
self.filesystem.makedir(path=path)
def remove(self, path: PathType) -> None:
"""Remove the file at the given path.
Args:
path: The path of the file to remove.
"""
self.filesystem.rm_file(path=path)
def rename(
self, src: PathType, dst: PathType, overwrite: bool = False
) -> None:
"""Rename source file to destination file.
Args:
src: The path of the file to rename.
dst: The path to rename the source file to.
overwrite: If a file already exists at the destination, this
method will overwrite it if overwrite=`True` and
raise a FileExistsError otherwise.
Raises:
FileExistsError: If a file already exists at the destination
and overwrite is not set to `True`.
"""
if not overwrite and self.filesystem.exists(dst):
raise FileExistsError(
f"Unable to rename file to '{convert_to_str(dst)}', "
f"file already exists. Set `overwrite=True` to rename anyway."
)
# TODO [ENG-152]: Check if it works with overwrite=True or if we need
# to manually remove it first
self.filesystem.rename(path1=src, path2=dst)
def rmtree(self, path: PathType) -> None:
"""Remove the given directory.
Args:
path: The path of the directory to remove.
"""
self.filesystem.delete(path=path, recursive=True)
def stat(self, path: PathType) -> Dict[str, Any]:
"""Return stat info for the given path.
Args:
path: The path to get stat info for.
Returns:
A dictionary containing the stat info.
"""
return self.filesystem.stat(path=path) # type: ignore[no-any-return]
def walk(
self,
top: PathType,
topdown: bool = True,
onerror: Optional[Callable[..., None]] = None,
) -> Iterable[Tuple[PathType, List[PathType], List[PathType]]]:
"""Return an iterator that walks the contents of the given directory.
Args:
top: Path of directory to walk.
topdown: Unused argument to conform to interface.
onerror: Unused argument to conform to interface.
Yields:
An Iterable of Tuples, each of which contain the path of the current
directory path, a list of directories inside the current directory
and a list of files inside the current directory.
"""
# TODO [ENG-153]: Additional params
for directory, subdirectories, files in self.filesystem.walk(path=top):
yield f"s3://{directory}", subdirectories, files
config: S3ArtifactStoreConfig
property
readonly
Get the config of this artifact store.
Returns:
Type | Description |
---|---|
S3ArtifactStoreConfig |
The config of this artifact store. |
filesystem: S3FileSystem
property
readonly
The s3 filesystem to access this artifact store.
Returns:
Type | Description |
---|---|
S3FileSystem |
The s3 filesystem. |
copyfile(self, src, dst, overwrite=False)
Copy a file.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
src |
Union[bytes, str] |
The path to copy from. |
required |
dst |
Union[bytes, str] |
The path to copy to. |
required |
overwrite |
bool |
If a file already exists at the destination, this method will overwrite it if overwrite=`True` and raise a FileExistsError otherwise. |
False |
Exceptions:
Type | Description |
---|---|
FileExistsError |
If a file already exists at the destination and overwrite is not set to `True`. |
Source code in zenml/integrations/s3/artifact_stores/s3_artifact_store.py
def copyfile(
self, src: PathType, dst: PathType, overwrite: bool = False
) -> None:
"""Copy a file.
Args:
src: The path to copy from.
dst: The path to copy to.
overwrite: If a file already exists at the destination, this
method will overwrite it if overwrite=`True` and
raise a FileExistsError otherwise.
Raises:
FileExistsError: If a file already exists at the destination
and overwrite is not set to `True`.
"""
if not overwrite and self.filesystem.exists(dst):
raise FileExistsError(
f"Unable to copy to destination '{convert_to_str(dst)}', "
f"file already exists. Set `overwrite=True` to copy anyway."
)
# TODO [ENG-151]: Check if it works with overwrite=True or if we need to
# manually remove it first
self.filesystem.copy(path1=src, path2=dst)
exists(self, path)
Check whether a path exists.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
path |
Union[bytes, str] |
The path to check. |
required |
Returns:
Type | Description |
---|---|
bool |
True if the path exists, False otherwise. |
Source code in zenml/integrations/s3/artifact_stores/s3_artifact_store.py
def exists(self, path: PathType) -> bool:
"""Check whether a path exists.
Args:
path: The path to check.
Returns:
True if the path exists, False otherwise.
"""
return self.filesystem.exists(path=path) # type: ignore[no-any-return]
glob(self, pattern)
Return all paths that match the given glob pattern.
The glob pattern may include: '*' to match any number of characters, '?' to match a single character, '[...]' to match one of the characters inside the brackets, and '**' as the full name of a path component to search in subdirectories of any depth (e.g. '/some_dir/**/some_file').
Parameters:
Name | Type | Description | Default |
---|---|---|---|
pattern |
Union[bytes, str] |
The glob pattern to match, see details above. |
required |
Returns:
Type | Description |
---|---|
List[Union[bytes, str]] |
A list of paths that match the given glob pattern. |
Source code in zenml/integrations/s3/artifact_stores/s3_artifact_store.py
def glob(self, pattern: PathType) -> List[PathType]:
"""Return all paths that match the given glob pattern.
The glob pattern may include:
- '*' to match any number of characters
- '?' to match a single character
- '[...]' to match one of the characters inside the brackets
- '**' as the full name of a path component to match to search
in subdirectories of any depth (e.g. '/some_dir/**/some_file')
Args:
pattern: The glob pattern to match, see details above.
Returns:
A list of paths that match the given glob pattern.
"""
return [f"s3://{path}" for path in self.filesystem.glob(path=pattern)]
isdir(self, path)
Check whether a path is a directory.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
path |
Union[bytes, str] |
The path to check. |
required |
Returns:
Type | Description |
---|---|
bool |
True if the path is a directory, False otherwise. |
Source code in zenml/integrations/s3/artifact_stores/s3_artifact_store.py
def isdir(self, path: PathType) -> bool:
"""Check whether a path is a directory.
Args:
path: The path to check.
Returns:
True if the path is a directory, False otherwise.
"""
return self.filesystem.isdir(path=path) # type: ignore[no-any-return]
listdir(self, path)
Return a list of files in a directory.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
path |
Union[bytes, str] |
The path to list. |
required |
Returns:
Type | Description |
---|---|
List[Union[bytes, str]] |
A list of paths that are files in the given directory. |
Source code in zenml/integrations/s3/artifact_stores/s3_artifact_store.py
def listdir(self, path: PathType) -> List[PathType]:
"""Return a list of files in a directory.
Args:
path: The path to list.
Returns:
A list of paths that are files in the given directory.
"""
# remove s3 prefix if given, so we can remove the directory later as
# this method is expected to only return filenames
path = convert_to_str(path)
if path.startswith("s3://"):
path = path[5:]
def _extract_basename(file_dict: Dict[str, Any]) -> str:
"""Extracts the basename from a file info dict returned by the S3 filesystem.
Args:
file_dict: A file info dict returned by the S3 filesystem.
Returns:
The basename of the file.
"""
file_path = cast(str, file_dict["Key"])
base_name = file_path[len(path) :]
return base_name.lstrip("/")
return [
_extract_basename(dict_)
for dict_ in self.filesystem.listdir(path=path)
# s3fs.listdir also returns the root directory, so we filter
# it out here
if _extract_basename(dict_)
]
makedirs(self, path)
Create a directory at the given path.
If needed also create missing parent directories.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
path |
Union[bytes, str] |
The path to create. |
required |
Source code in zenml/integrations/s3/artifact_stores/s3_artifact_store.py
def makedirs(self, path: PathType) -> None:
"""Create a directory at the given path.
If needed also create missing parent directories.
Args:
path: The path to create.
"""
self.filesystem.makedirs(path=path, exist_ok=True)
mkdir(self, path)
Create a directory at the given path.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
path |
Union[bytes, str] |
The path to create. |
required |
Source code in zenml/integrations/s3/artifact_stores/s3_artifact_store.py
def mkdir(self, path: PathType) -> None:
"""Create a directory at the given path.
Args:
path: The path to create.
"""
self.filesystem.makedir(path=path)
open(self, path, mode='r')
Open a file at the given path.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
path |
Union[bytes, str] |
Path of the file to open. |
required |
mode |
str |
Mode in which to open the file. Currently, only 'rb' and 'wb' to read and write binary files are supported. |
'r' |
Returns:
Type | Description |
---|---|
Any |
A file-like object. |
Source code in zenml/integrations/s3/artifact_stores/s3_artifact_store.py
def open(self, path: PathType, mode: str = "r") -> Any:
"""Open a file at the given path.
Args:
path: Path of the file to open.
mode: Mode in which to open the file. Currently, only
'rb' and 'wb' to read and write binary files are supported.
Returns:
A file-like object.
"""
return self.filesystem.open(path=path, mode=mode)
remove(self, path)
Remove the file at the given path.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
path |
Union[bytes, str] |
The path of the file to remove. |
required |
Source code in zenml/integrations/s3/artifact_stores/s3_artifact_store.py
def remove(self, path: PathType) -> None:
"""Remove the file at the given path.
Args:
path: The path of the file to remove.
"""
self.filesystem.rm_file(path=path)
rename(self, src, dst, overwrite=False)
Rename source file to destination file.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
src |
Union[bytes, str] |
The path of the file to rename. |
required |
dst |
Union[bytes, str] |
The path to rename the source file to. |
required |
overwrite |
bool |
If a file already exists at the destination, this method will overwrite it if overwrite=`True` and raise a FileExistsError otherwise. |
False |
Exceptions:
Type | Description |
---|---|
FileExistsError |
If a file already exists at the destination and overwrite is not set to `True`. |
Source code in zenml/integrations/s3/artifact_stores/s3_artifact_store.py
def rename(
self, src: PathType, dst: PathType, overwrite: bool = False
) -> None:
"""Rename source file to destination file.
Args:
src: The path of the file to rename.
dst: The path to rename the source file to.
overwrite: If a file already exists at the destination, this
method will overwrite it if overwrite=`True` and
raise a FileExistsError otherwise.
Raises:
FileExistsError: If a file already exists at the destination
and overwrite is not set to `True`.
"""
if not overwrite and self.filesystem.exists(dst):
raise FileExistsError(
f"Unable to rename file to '{convert_to_str(dst)}', "
f"file already exists. Set `overwrite=True` to rename anyway."
)
# TODO [ENG-152]: Check if it works with overwrite=True or if we need
# to manually remove it first
self.filesystem.rename(path1=src, path2=dst)
rmtree(self, path)
Remove the given directory.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
path |
Union[bytes, str] |
The path of the directory to remove. |
required |
Source code in zenml/integrations/s3/artifact_stores/s3_artifact_store.py
def rmtree(self, path: PathType) -> None:
"""Remove the given directory.
Args:
path: The path of the directory to remove.
"""
self.filesystem.delete(path=path, recursive=True)
stat(self, path)
Return stat info for the given path.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
path |
Union[bytes, str] |
The path to get stat info for. |
required |
Returns:
Type | Description |
---|---|
Dict[str, Any] |
A dictionary containing the stat info. |
Source code in zenml/integrations/s3/artifact_stores/s3_artifact_store.py
def stat(self, path: PathType) -> Dict[str, Any]:
"""Return stat info for the given path.
Args:
path: The path to get stat info for.
Returns:
A dictionary containing the stat info.
"""
return self.filesystem.stat(path=path) # type: ignore[no-any-return]
walk(self, top, topdown=True, onerror=None)
Return an iterator that walks the contents of the given directory.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
top |
Union[bytes, str] |
Path of directory to walk. |
required |
topdown |
bool |
Unused argument to conform to interface. |
True |
onerror |
Optional[Callable[..., NoneType]] |
Unused argument to conform to interface. |
None |
Yields:
Type | Description |
---|---|
Iterable[Tuple[Union[bytes, str], List[Union[bytes, str]], List[Union[bytes, str]]]] |
An Iterable of Tuples, each of which contain the path of the current directory path, a list of directories inside the current directory and a list of files inside the current directory. |
Source code in zenml/integrations/s3/artifact_stores/s3_artifact_store.py
def walk(
self,
top: PathType,
topdown: bool = True,
onerror: Optional[Callable[..., None]] = None,
) -> Iterable[Tuple[PathType, List[PathType], List[PathType]]]:
"""Return an iterator that walks the contents of the given directory.
Args:
top: Path of directory to walk.
topdown: Unused argument to conform to interface.
onerror: Unused argument to conform to interface.
Yields:
An Iterable of Tuples, each of which contain the path of the current
directory path, a list of directories inside the current directory
and a list of files inside the current directory.
"""
# TODO [ENG-153]: Additional params
for directory, subdirectories, files in self.filesystem.walk(path=top):
yield f"s3://{directory}", subdirectories, files
flavors
special
Amazon S3 integration flavors.
s3_artifact_store_flavor
Amazon S3 artifact store flavor.
S3ArtifactStoreConfig (BaseArtifactStoreConfig, AuthenticationConfigMixin)
pydantic-model
Configuration for the S3 Artifact Store.
All attributes of this class except `path` will be passed to the
`s3fs.S3FileSystem` initialization. See
https://s3fs.readthedocs.io/en/latest/ for more information on how
to use those configuration options to connect to any S3-compatible storage.
When you want to register an S3ArtifactStore from the CLI and need to pass
`client_kwargs`, `config_kwargs` or `s3_additional_kwargs`, you should pass
them as a JSON string:
zenml artifact-store register my_s3_store --flavor=s3 --path=s3://my_bucket --client_kwargs='{"endpoint_url": "http://my-s3-endpoint"}'
Source code in zenml/integrations/s3/flavors/s3_artifact_store_flavor.py
class S3ArtifactStoreConfig(BaseArtifactStoreConfig, AuthenticationConfigMixin):
"""Configuration for the S3 Artifact Store.
All attributes of this class except `path` will be passed to the
`s3fs.S3FileSystem` initialization. See
[here](https://s3fs.readthedocs.io/en/latest/) for more information on how
to use those configuration options to connect to any S3-compatible storage.
When you want to register an S3ArtifactStore from the CLI and need to pass
`client_kwargs`, `config_kwargs` or `s3_additional_kwargs`, you should pass
them as a json string:
```
zenml artifact-store register my_s3_store --flavor=s3 \
--path=s3://my_bucket --client_kwargs='{"endpoint_url": "http://my-s3-endpoint"}'
```
"""
SUPPORTED_SCHEMES: ClassVar[Set[str]] = {"s3://"}
key: Optional[str] = SecretField()
secret: Optional[str] = SecretField()
token: Optional[str] = SecretField()
client_kwargs: Optional[Dict[str, Any]] = None
config_kwargs: Optional[Dict[str, Any]] = None
s3_additional_kwargs: Optional[Dict[str, Any]] = None
@validator(
"client_kwargs", "config_kwargs", "s3_additional_kwargs", pre=True
)
def _convert_json_string(
cls, value: Union[None, str, Dict[str, Any]]
) -> Optional[Dict[str, Any]]:
"""Converts potential JSON strings passed via the CLI to dictionaries.
Args:
value: The value to convert.
Returns:
The converted value.
Raises:
TypeError: If the value is not a `str`, `Dict` or `None`.
ValueError: If the value is an invalid json string or a json string
that does not decode into a dictionary.
"""
if isinstance(value, str):
try:
dict_ = json.loads(value)
except json.JSONDecodeError as e:
raise ValueError(f"Invalid json string '{value}'") from e
if not isinstance(dict_, Dict):
raise ValueError(
f"Json string '{value}' did not decode into a dictionary."
)
return dict_
elif isinstance(value, Dict) or value is None:
return value
else:
raise TypeError(f"{value} is not a json string or a dictionary.")
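The validator above mirrors plain `json.loads` semantics; a quick sketch of what happens to a CLI-supplied value (the value below is a placeholder):
```
import json

raw = '{"endpoint_url": "http://my-s3-endpoint"}'
client_kwargs = json.loads(raw)          # -> {'endpoint_url': 'http://my-s3-endpoint'}
assert isinstance(client_kwargs, dict)   # non-dict JSON raises a ValueError in the validator
```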
S3ArtifactStoreFlavor (BaseArtifactStoreFlavor)
Flavor of the S3 artifact store.
Source code in zenml/integrations/s3/flavors/s3_artifact_store_flavor.py
class S3ArtifactStoreFlavor(BaseArtifactStoreFlavor):
"""Flavor of the S3 artifact store."""
@property
def name(self) -> str:
"""Name of the flavor.
Returns:
The name of the flavor.
"""
return S3_ARTIFACT_STORE_FLAVOR
@property
def config_class(self) -> Type[S3ArtifactStoreConfig]:
"""The config class of the flavor.
Returns:
The config class of the flavor.
"""
return S3ArtifactStoreConfig
@property
def implementation_class(self) -> Type["S3ArtifactStore"]:
"""Implementation class for this flavor.
Returns:
The implementation class for this flavor.
"""
from zenml.integrations.s3.artifact_stores import S3ArtifactStore
return S3ArtifactStore
config_class: Type[zenml.integrations.s3.flavors.s3_artifact_store_flavor.S3ArtifactStoreConfig]
property
readonly
The config class of the flavor.
Returns:
Type | Description |
---|---|
Type[zenml.integrations.s3.flavors.s3_artifact_store_flavor.S3ArtifactStoreConfig] |
The config class of the flavor. |
implementation_class: Type[S3ArtifactStore]
property
readonly
Implementation class for this flavor.
Returns:
Type | Description |
---|---|
Type[S3ArtifactStore] |
The implementation class for this flavor. |
name: str
property
readonly
Name of the flavor.
Returns:
Type | Description |
---|---|
str |
The name of the flavor. |
scipy
special
Initialization of the Scipy integration.
ScipyIntegration (Integration)
Definition of scipy integration for ZenML.
Source code in zenml/integrations/scipy/__init__.py
class ScipyIntegration(Integration):
"""Definition of scipy integration for ZenML."""
NAME = SCIPY
REQUIREMENTS = ["scipy"]
@classmethod
def activate(cls) -> None:
"""Activates the integration."""
from zenml.integrations.scipy import materializers # noqa
activate()
classmethod
Activates the integration.
Source code in zenml/integrations/scipy/__init__.py
@classmethod
def activate(cls) -> None:
"""Activates the integration."""
from zenml.integrations.scipy import materializers # noqa
materializers
special
Initialization of the Scipy materializers.
sparse_materializer
Implementation of the Scipy Sparse Materializer.
SparseMaterializer (BaseMaterializer)
Materializer to read and write scipy sparse matrices.
Source code in zenml/integrations/scipy/materializers/sparse_materializer.py
class SparseMaterializer(BaseMaterializer):
"""Materializer to read and write scipy sparse matrices."""
ASSOCIATED_TYPES = (spmatrix,)
ASSOCIATED_ARTIFACT_TYPES = (DataArtifact,)
def handle_input(self, data_type: Type[Any]) -> spmatrix:
"""Reads spmatrix from npz file.
Args:
data_type: The type of the spmatrix to load.
Returns:
A spmatrix object.
"""
super().handle_input(data_type)
with fileio.open(
os.path.join(self.artifact.uri, DATA_FILENAME), "rb"
) as f:
mat = load_npz(f)
return mat
def handle_return(self, mat: spmatrix) -> None:
"""Writes a spmatrix to the artifact store as a npz file.
Args:
mat: The spmatrix to write.
"""
super().handle_return(mat)
with fileio.open(
os.path.join(self.artifact.uri, DATA_FILENAME), "wb"
) as f:
save_npz(f, mat)
handle_input(self, data_type)
Reads spmatrix from npz file.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
data_type |
Type[Any] |
The type of the spmatrix to load. |
required |
Returns:
Type | Description |
---|---|
spmatrix |
A spmatrix object. |
Source code in zenml/integrations/scipy/materializers/sparse_materializer.py
def handle_input(self, data_type: Type[Any]) -> spmatrix:
"""Reads spmatrix from npz file.
Args:
data_type: The type of the spmatrix to load.
Returns:
A spmatrix object.
"""
super().handle_input(data_type)
with fileio.open(
os.path.join(self.artifact.uri, DATA_FILENAME), "rb"
) as f:
mat = load_npz(f)
return mat
handle_return(self, mat)
Writes a spmatrix to the artifact store as a npz file.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
mat |
spmatrix |
The spmatrix to write. |
required |
Source code in zenml/integrations/scipy/materializers/sparse_materializer.py
def handle_return(self, mat: spmatrix) -> None:
"""Writes a spmatrix to the artifact store as a npz file.
Args:
mat: The spmatrix to write.
"""
super().handle_return(mat)
with fileio.open(
os.path.join(self.artifact.uri, DATA_FILENAME), "wb"
) as f:
save_npz(f, mat)
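As with the other materializers, a step that returns a scipy sparse matrix is enough to trigger this materializer. A hedged sketch (assuming ZenML's `@step` decorator and a configured artifact store):
```
from scipy.sparse import csr_matrix
from zenml.steps import step

@step
def make_matrix() -> csr_matrix:
    """Return a tiny sparse matrix; it is stored as an .npz file."""
    return csr_matrix([[0, 1], [1, 0]])
```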
seldon
special
Initialization of the Seldon integration.
The Seldon Core integration allows you to use the Seldon Core model serving platform to implement continuous model deployment.
SeldonIntegration (Integration)
Definition of Seldon Core integration for ZenML.
Source code in zenml/integrations/seldon/__init__.py
class SeldonIntegration(Integration):
"""Definition of Seldon Core integration for ZenML."""
NAME = SELDON
REQUIREMENTS = [
"kubernetes==18.20.0",
"seldon-core==1.14.1",
]
@classmethod
def activate(cls) -> None:
"""Activate the Seldon Core integration."""
from zenml.integrations.seldon import secret_schemas # noqa
from zenml.integrations.seldon import services # noqa
@classmethod
def flavors(cls) -> List[Type[Flavor]]:
"""Declare the stack component flavors for the Seldon Core.
Returns:
List of stack component flavors for this integration.
"""
from zenml.integrations.seldon.flavors import SeldonModelDeployerFlavor
return [SeldonModelDeployerFlavor]
activate()
classmethod
Activate the Seldon Core integration.
Source code in zenml/integrations/seldon/__init__.py
@classmethod
def activate(cls) -> None:
"""Activate the Seldon Core integration."""
from zenml.integrations.seldon import secret_schemas # noqa
from zenml.integrations.seldon import services # noqa
flavors()
classmethod
Declare the stack component flavors for the Seldon Core.
Returns:
Type | Description |
---|---|
List[Type[zenml.stack.flavor.Flavor]] |
List of stack component flavors for this integration. |
Source code in zenml/integrations/seldon/__init__.py
@classmethod
def flavors(cls) -> List[Type[Flavor]]:
"""Declare the stack component flavors for the Seldon Core.
Returns:
List of stack component flavors for this integration.
"""
from zenml.integrations.seldon.flavors import SeldonModelDeployerFlavor
return [SeldonModelDeployerFlavor]
constants
Seldon constants.
custom_deployer
special
Initialization of ZenML custom deployer.
zenml_custom_model
Implements a custom model for the Seldon integration.
ZenMLCustomModel
Custom model class for ZenML and Seldon.
This class is used to implement a custom model for the Seldon Core integration, which is used as the main entry point for custom code execution.
Attributes:
Name | Type | Description |
---|---|---|
name |
The name of the model. |
|
model_uri |
The URI of the model. |
|
predict_func |
The predict function of the model. |
Source code in zenml/integrations/seldon/custom_deployer/zenml_custom_model.py
class ZenMLCustomModel:
"""Custom model class for ZenML and Seldon.
This class is used to implement a custom model for the Seldon Core integration,
which is used as the main entry point for custom code execution.
Attributes:
name: The name of the model.
model_uri: The URI of the model.
predict_func: The predict function of the model.
"""
def __init__(
self,
model_name: str,
model_uri: str,
predict_func: str,
):
"""Initializes a ZenMLCustomModel object.
Args:
model_name: The name of the model.
model_uri: The URI of the model.
predict_func: The predict function of the model.
"""
self.name = model_name
self.model_uri = model_uri
self.predict_func = import_class_by_path(predict_func)
self.model = None
self.ready = False
def load(self) -> bool:
"""Load the model.
This function loads the model into memory and sets the ready flag to True.
The model is loaded using the materializer, by saving the information of
the artifact to a file at the preparing time and loading it again at the
prediction time by the materializer.
Returns:
True if the model was loaded successfully, False otherwise.
"""
try:
from zenml.utils.materializer_utils import load_model_from_metadata
self.model = load_model_from_metadata(self.model_uri)
except Exception as e:
logger.error("Failed to load model: {}".format(e))
return False
self.ready = True
return self.ready
def predict(
self,
X: Array_Like,
features_names: Optional[List[str]],
**kwargs: Any,
) -> Array_Like:
"""Predict the given request.
The main predict function of the model. This function is called by the
Seldon Core server when a request is received. Then inside this function,
the user-defined predict function is called.
Args:
X: The request to predict in a dictionary.
features_names: The names of the features.
**kwargs: Additional arguments.
Returns:
The prediction dictionary.
Raises:
Exception: If function could not be called.
NotImplementedError: If the model is not ready.
TypeError: If the request is not a dictionary.
"""
if self.predict_func is not None:
try:
prediction = {"predictions": self.predict_func(self.model, X)}
except Exception as e:
raise Exception("Failed to predict: {}".format(e))
if isinstance(prediction, dict):
return prediction
else:
raise TypeError(
f"Prediction is not a dictionary. Expected dict type but got {type(prediction)}"
)
else:
raise NotImplementedError("Predict function is not implemented")
__init__(self, model_name, model_uri, predict_func)
special
Initializes a ZenMLCustomModel object.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model_name |
str |
The name of the model. |
required |
model_uri |
str |
The URI of the model. |
required |
predict_func |
str |
The predict function of the model. |
required |
Source code in zenml/integrations/seldon/custom_deployer/zenml_custom_model.py
def __init__(
self,
model_name: str,
model_uri: str,
predict_func: str,
):
"""Initializes a ZenMLCustomModel object.
Args:
model_name: The name of the model.
model_uri: The URI of the model.
predict_func: The predict function of the model.
"""
self.name = model_name
self.model_uri = model_uri
self.predict_func = import_class_by_path(predict_func)
self.model = None
self.ready = False
load(self)
Load the model.
This function loads the model into memory and sets the ready flag to True. The model is loaded using the materializer, by saving the information of the artifact to a file at the preparing time and loading it again at the prediction time by the materializer.
Returns:
Type | Description |
---|---|
bool |
True if the model was loaded successfully, False otherwise. |
Source code in zenml/integrations/seldon/custom_deployer/zenml_custom_model.py
def load(self) -> bool:
"""Load the model.
This function loads the model into memory and sets the ready flag to True.
The model is loaded using the materializer, by saving the information of
the artifact to a file at the preparing time and loading it again at the
prediction time by the materializer.
Returns:
True if the model was loaded successfully, False otherwise.
"""
try:
from zenml.utils.materializer_utils import load_model_from_metadata
self.model = load_model_from_metadata(self.model_uri)
except Exception as e:
logger.error("Failed to load model: {}".format(e))
return False
self.ready = True
return self.ready
predict(self, X, features_names, **kwargs)
Predict the given request.
The main predict function of the model. This function is called by the Seldon Core server when a request is received. Then inside this function, the user-defined predict function is called.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
X |
Union[numpy.ndarray, List[Any], str, bytes, Dict[str, Any]] |
The request to predict in a dictionary. |
required |
features_names |
Optional[List[str]] |
The names of the features. |
required |
**kwargs |
Any |
Additional arguments. |
{} |
Returns:
Type | Description |
---|---|
Union[numpy.ndarray, List[Any], str, bytes, Dict[str, Any]] |
The prediction dictionary. |
Exceptions:
Type | Description |
---|---|
Exception |
If function could not be called. |
NotImplementedError |
If the model is not ready. |
TypeError |
If the request is not a dictionary. |
Source code in zenml/integrations/seldon/custom_deployer/zenml_custom_model.py
def predict(
self,
X: Array_Like,
features_names: Optional[List[str]],
**kwargs: Any,
) -> Array_Like:
"""Predict the given request.
The main predict function of the model. This function is called by the
Seldon Core server when a request is received. Then inside this function,
the user-defined predict function is called.
Args:
X: The request to predict in a dictionary.
features_names: The names of the features.
**kwargs: Additional arguments.
Returns:
The prediction dictionary.
Raises:
Exception: If function could not be called.
NotImplementedError: If the model is not ready.
TypeError: If the request is not a dictionary.
"""
if self.predict_func is not None:
try:
prediction = {"predictions": self.predict_func(self.model, X)}
except Exception as e:
raise Exception("Failed to predict: {}".format(e))
if isinstance(prediction, dict):
return prediction
else:
raise TypeError(
f"Prediction is not a dictionary. Expected dict type but got {type(prediction)}"
)
else:
raise NotImplementedError("Predict function is not implemented")
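The `predict_func` attribute is a dotted import path to a user-defined callable that receives the loaded model and the request. An illustrative signature (the function name and body are hypothetical, not part of the ZenML API):
```
from typing import Any, List

def custom_predict(model: Any, request: Any) -> List[Any]:
    """Run inference with the loaded model on the incoming request."""
    # Assumes the loaded model exposes a .predict() method.
    return model.predict(request)
```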
flavors
special
Seldon integration flavors.
seldon_model_deployer_flavor
Seldon model deployer flavor.
SeldonModelDeployerConfig (BaseModelDeployerConfig)
pydantic-model
Config for the Seldon Model Deployer.
Attributes:
Name | Type | Description |
---|---|---|
kubernetes_context |
Optional[str] |
the Kubernetes context to use to contact the remote Seldon Core installation. If not specified, the current configuration is used. Depending on where the Seldon model deployer is being used, this can be either a locally active context or an in-cluster Kubernetes configuration (if running inside a pod). |
kubernetes_namespace |
Optional[str] |
the Kubernetes namespace where the Seldon Core deployment servers are provisioned and managed by ZenML. If not specified, the namespace set in the current configuration is used. Depending on where the Seldon model deployer is being used, this can be either the current namespace configured in the locally active context or the namespace in the context of which the pod is running (if running inside a pod). |
base_url |
str |
the base URL of the Kubernetes ingress used to expose the Seldon Core deployment servers. |
secret |
Optional[str] |
the name of a ZenML secret containing the credentials used by Seldon Core storage initializers to authenticate to the Artifact Store (i.e. the storage backend where models are stored - see https://docs.seldon.io/projects/seldon-core/en/latest/servers/overview.html#handling-credentials). |
Source code in zenml/integrations/seldon/flavors/seldon_model_deployer_flavor.py
class SeldonModelDeployerConfig(BaseModelDeployerConfig):
"""Config for the Seldon Model Deployer.
Attributes:
kubernetes_context: the Kubernetes context to use to contact the remote
Seldon Core installation. If not specified, the current
configuration is used. Depending on where the Seldon model deployer
is being used, this can be either a locally active context or an
in-cluster Kubernetes configuration (if running inside a pod).
kubernetes_namespace: the Kubernetes namespace where the Seldon Core
deployment servers are provisioned and managed by ZenML. If not
specified, the namespace set in the current configuration is used.
Depending on where the Seldon model deployer is being used, this can
be either the current namespace configured in the locally active
context or the namespace in the context of which the pod is running
(if running inside a pod).
base_url: the base URL of the Kubernetes ingress used to expose the
Seldon Core deployment servers.
secret: the name of a ZenML secret containing the credentials used by
Seldon Core storage initializers to authenticate to the Artifact
Store (i.e. the storage backend where models are stored - see
https://docs.seldon.io/projects/seldon-core/en/latest/servers/overview.html#handling-credentials).
"""
kubernetes_context: Optional[str]
kubernetes_namespace: Optional[str]
base_url: str # TODO: unused?
secret: Optional[str]
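For orientation, these config fields map one-to-one onto the CLI registration flags shown later in the `get_active_model_deployer` error message (values below are placeholders):
```
# Illustrative instantiation only; stack components are normally
# registered via the CLI rather than constructed directly.
config = SeldonModelDeployerConfig(
    kubernetes_context="my-context",
    kubernetes_namespace="zenml-seldon",
    base_url="https://ingress.cluster.kubernetes",
    secret="my_seldon_secret",
)
```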
SeldonModelDeployerFlavor (BaseModelDeployerFlavor)
Seldon Core model deployer flavor.
Source code in zenml/integrations/seldon/flavors/seldon_model_deployer_flavor.py
class SeldonModelDeployerFlavor(BaseModelDeployerFlavor):
"""Seldon Core model deployer flavor."""
@property
def name(self) -> str:
"""Name of the flavor.
Returns:
The name of the flavor.
"""
return SELDON_MODEL_DEPLOYER_FLAVOR
@property
def config_class(self) -> Type[SeldonModelDeployerConfig]:
"""Returns `SeldonModelDeployerConfig` config class.
Returns:
The config class.
"""
return SeldonModelDeployerConfig
@property
def implementation_class(self) -> Type["SeldonModelDeployer"]:
"""Implementation class for this flavor.
Returns:
The implementation class.
"""
from zenml.integrations.seldon.model_deployers import (
SeldonModelDeployer,
)
return SeldonModelDeployer
config_class: Type[zenml.integrations.seldon.flavors.seldon_model_deployer_flavor.SeldonModelDeployerConfig]
property
readonly
Returns the `SeldonModelDeployerConfig` config class.
Returns:
Type | Description |
---|---|
Type[zenml.integrations.seldon.flavors.seldon_model_deployer_flavor.SeldonModelDeployerConfig] |
The config class. |
implementation_class: Type[SeldonModelDeployer]
property
readonly
Implementation class for this flavor.
Returns:
Type | Description |
---|---|
Type[SeldonModelDeployer] |
The implementation class. |
name: str
property
readonly
Name of the flavor.
Returns:
Type | Description |
---|---|
str |
The name of the flavor. |
model_deployers
special
Initialization of the Seldon Model Deployer.
seldon_model_deployer
Implementation of the Seldon Model Deployer.
SeldonModelDeployer (BaseModelDeployer)
Seldon Core model deployer stack component implementation.
Source code in zenml/integrations/seldon/model_deployers/seldon_model_deployer.py
class SeldonModelDeployer(BaseModelDeployer):
"""Seldon Core model deployer stack component implementation."""
_client: Optional[SeldonClient] = None
@property
def config(self) -> SeldonModelDeployerConfig:
"""Returns the `SeldonModelDeployerConfig` config.
Returns:
The configuration.
"""
return cast(SeldonModelDeployerConfig, self._config)
@staticmethod
def get_model_server_info( # type: ignore[override]
service_instance: "SeldonDeploymentService",
) -> Dict[str, Optional[str]]:
"""Return implementation specific information that might be relevant to the user.
Args:
service_instance: Instance of a SeldonDeploymentService
Returns:
Model server information.
"""
return {
"PREDICTION_URL": service_instance.prediction_url,
"MODEL_URI": service_instance.config.model_uri,
"MODEL_NAME": service_instance.config.model_name,
"SELDON_DEPLOYMENT": service_instance.seldon_deployment_name,
}
@staticmethod
def get_active_model_deployer() -> "SeldonModelDeployer":
"""Get the Seldon Core model deployer registered in the active stack.
Returns:
The Seldon Core model deployer registered in the active stack.
Raises:
TypeError: if the Seldon Core model deployer is not available.
"""
model_deployer = Client( # type: ignore [call-arg]
skip_client_check=True
).active_stack.model_deployer
if not model_deployer or not isinstance(
model_deployer, SeldonModelDeployer
):
raise TypeError(
f"The active stack needs to have a Seldon Core model deployer "
f"component registered to be able to deploy models with Seldon "
f"Core. You can create a new stack with a Seldon Core model "
f"deployer component or update your existing stack to add this "
f"component, e.g.:\n\n"
f" 'zenml model-deployer register seldon --flavor={SELDON_MODEL_DEPLOYER_FLAVOR} "
f"--kubernetes_context=context-name --kubernetes_namespace="
f"namespace-name --base_url=https://ingress.cluster.kubernetes'\n"
f" 'zenml stack create stack-name -d seldon ...'\n"
)
return model_deployer
@property
def seldon_client(self) -> SeldonClient:
"""Get the Seldon Core client associated with this model deployer.
Returns:
The Seldon Core client.
"""
if not self._client:
self._client = SeldonClient(
context=self.config.kubernetes_context,
namespace=self.config.kubernetes_namespace,
)
return self._client
@property
def kubernetes_secret_name(self) -> Optional[str]:
"""Get the Kubernetes secret name associated with this model deployer.
If a secret is configured for this model deployer, a corresponding
Kubernetes secret is created in the remote cluster to be used
by Seldon Core storage initializers to authenticate to the Artifact
Store. This method returns the unique name that is used for this secret.
Returns:
The Seldon Core Kubernetes secret name, or None if no secret is
configured.
"""
if not self.config.secret:
return None
return (
re.sub(
r"[^0-9a-zA-Z-]+",
"-",
f"zenml-seldon-core-{self.config.secret}",
)
.strip("-")
.lower()
)
def prepare_pipeline_deployment(
self,
deployment: "PipelineDeployment",
stack: "Stack",
) -> None:
"""Build a Docker image and push it to the container registry.
Args:
deployment: The pipeline deployment configuration.
stack: The stack on which the pipeline will be deployed.
"""
docker_image_builder = PipelineDockerImageBuilder()
repo_digest = docker_image_builder.build_and_push_docker_image(
deployment=deployment, stack=stack
)
deployment.add_extra(SELDON_DOCKER_IMAGE_KEY, repo_digest)
def _create_or_update_kubernetes_secret(self) -> Optional[str]:
"""Create or update a Kubernetes secret.
Uses the information stored in the ZenML secret configured for the model deployer.
Returns:
The name of the Kubernetes secret that was created or updated, or
None if no secret was configured.
Raises:
RuntimeError: if the secret cannot be created or updated.
"""
# if a ZenML secret was configured in the model deployer,
# create a Kubernetes secret as a means to pass this information
# to the Seldon Core deployment
if self.config.secret:
secret_manager = Client( # type: ignore [call-arg]
skip_client_check=True
).active_stack.secrets_manager
if not secret_manager or not isinstance(
secret_manager, BaseSecretsManager
):
raise RuntimeError(
f"The active stack doesn't have a secret manager component. "
f"The ZenML secret specified in the Seldon Core Model "
f"Deployer configuration cannot be fetched: {self.config.secret}."
)
try:
zenml_secret = secret_manager.get_secret(self.config.secret)
except KeyError:
raise RuntimeError(
f"The ZenML secret '{self.config.secret}' specified in the "
f"Seldon Core Model Deployer configuration was not found "
f"in the active stack's secret manager."
)
# should never happen, just making mypy happy
assert self.kubernetes_secret_name is not None
self.seldon_client.create_or_update_secret(
self.kubernetes_secret_name, zenml_secret
)
return self.kubernetes_secret_name
def _delete_kubernetes_secret(self) -> None:
"""Delete the Kubernetes secret associated with this model deployer.
Do this if no Seldon Core deployments are using it.
"""
if self.kubernetes_secret_name:
# fetch all the Seldon Core deployments that are currently
# configured to use this secret
services = self.find_model_server()
for service in services:
config = cast(SeldonDeploymentConfig, service.config)
if config.secret_name == self.kubernetes_secret_name:
return
self.seldon_client.delete_secret(self.kubernetes_secret_name)
def deploy_model(
self,
config: ServiceConfig,
replace: bool = False,
timeout: int = DEFAULT_SELDON_DEPLOYMENT_START_STOP_TIMEOUT,
) -> BaseService:
"""Create a new Seldon Core deployment or update an existing one.
# noqa: DAR402
This should serve the supplied model and deployment configuration.
This method has two modes of operation, depending on the `replace`
argument value:
* if `replace` is False, calling this method will create a new Seldon
Core deployment server to reflect the model and other configuration
parameters specified in the supplied Seldon deployment `config`.
* if `replace` is True, this method will first attempt to find an
existing Seldon Core deployment that is *equivalent* to the supplied
configuration parameters. Two or more Seldon Core deployments are
considered equivalent if they have the same `pipeline_name`,
`pipeline_step_name` and `model_name` configuration parameters. To
put it differently, two Seldon Core deployments are equivalent if
they serve versions of the same model deployed by the same pipeline
step. If an equivalent Seldon Core deployment is found, it will be
updated in place to reflect the new configuration parameters. This
allows an existing Seldon Core deployment to retain its prediction
URL while performing a rolling update to serve a new model version.
Callers should set `replace` to True if they want a continuous model
deployment workflow that doesn't spin up a new Seldon Core deployment
server for each new model version. If multiple equivalent Seldon Core
deployments are found, the most recently created deployment is selected
to be updated and the others are deleted.
Args:
config: the configuration of the model to be deployed with Seldon Core.
replace: set this flag to True to find and update an equivalent
Seldon Core deployment server with the new model instead of
starting a new deployment server.
timeout: the timeout in seconds to wait for the Seldon Core server
to be provisioned and successfully started or updated. If set
to 0, the method will return immediately after the Seldon Core
server is provisioned, without waiting for it to fully start.
Returns:
The ZenML Seldon Core deployment service object that can be used to
interact with the remote Seldon Core server.
Raises:
SeldonClientError: if a Seldon Core client error is encountered
while provisioning the Seldon Core deployment server.
RuntimeError: if `timeout` is set to a positive value that is
exceeded while waiting for the Seldon Core deployment server
to start, or if an operational failure is encountered before
it reaches a ready state.
"""
config = cast(SeldonDeploymentConfig, config)
service = None
# if a custom Kubernetes secret is not explicitly specified in the
# SeldonDeploymentConfig, try to create one from the ZenML secret
# configured for the model deployer
config.secret_name = (
config.secret_name or self._create_or_update_kubernetes_secret()
)
# if replace is True, find equivalent Seldon Core deployments
if replace is True:
equivalent_services = self.find_model_server(
running=False,
pipeline_name=config.pipeline_name,
pipeline_step_name=config.pipeline_step_name,
model_name=config.model_name,
)
for equivalent_service in equivalent_services:
if service is None:
# keep the most recently created service
service = equivalent_service
else:
try:
# delete the older services and don't wait for them to
# be deprovisioned
service.stop()
except RuntimeError:
# ignore errors encountered while stopping old services
pass
if service:
# update an equivalent service in place
service.update(config)
logger.info(
f"Updating an existing Seldon deployment service: {service}"
)
else:
# create a new service
service = SeldonDeploymentService(config=config)
logger.info(f"Creating a new Seldon deployment service: {service}")
# start the service which in turn provisions the Seldon Core
# deployment server and waits for it to reach a ready state
service.start(timeout=timeout)
# Add telemetry with metadata that gets the stack metadata and
# differentiates between pure model and custom code deployments
stack = Client().active_stack
stack_metadata = {
component_type.value: component.flavor
for component_type, component in stack.components.items()
}
metadata = {
"store_type": Client().zen_store.type.value,
**stack_metadata,
"is_custom_code_deployment": config.is_custom_deployment,
}
track_event(AnalyticsEvent.MODEL_DEPLOYED, metadata=metadata)
return service
def find_model_server(
self,
running: bool = False,
service_uuid: Optional[UUID] = None,
pipeline_name: Optional[str] = None,
pipeline_run_id: Optional[str] = None,
pipeline_step_name: Optional[str] = None,
model_name: Optional[str] = None,
model_uri: Optional[str] = None,
model_type: Optional[str] = None,
) -> List[BaseService]:
"""Find one or more Seldon Core model services that match the given criteria.
The Seldon Core deployment services that meet the search criteria are
returned sorted in descending order of their creation time (i.e. more
recent deployments first).
Args:
running: if true, only running services will be returned.
service_uuid: the UUID of the Seldon Core service that was originally used
to create the Seldon Core deployment resource.
pipeline_name: name of the pipeline that the deployed model was part
of.
pipeline_run_id: ID of the pipeline run which the deployed model was
part of.
pipeline_step_name: the name of the pipeline model deployment step
that deployed the model.
model_name: the name of the deployed model.
model_uri: URI of the deployed model.
model_type: the Seldon Core server implementation used to serve
the model
Returns:
One or more Seldon Core service objects representing Seldon Core
model servers that match the input search criteria.
"""
# Use a Seldon deployment service configuration to compute the labels
config = SeldonDeploymentConfig(
pipeline_name=pipeline_name or "",
pipeline_run_id=pipeline_run_id or "",
pipeline_step_name=pipeline_step_name or "",
model_name=model_name or "",
model_uri=model_uri or "",
implementation=model_type or "",
)
labels = config.get_seldon_deployment_labels()
if service_uuid:
# the service UUID is not a label covered by the Seldon
# deployment service configuration, so we need to add it
# separately
labels["zenml.service_uuid"] = str(service_uuid)
deployments = self.seldon_client.find_deployments(labels=labels)
# sort the deployments in descending order of their creation time
deployments.sort(
key=lambda deployment: datetime.strptime(
deployment.metadata.creationTimestamp,
"%Y-%m-%dT%H:%M:%SZ",
)
if deployment.metadata.creationTimestamp
else datetime.min,
reverse=True,
)
services: List[BaseService] = []
for deployment in deployments:
# recreate the Seldon deployment service object from the Seldon
# deployment resource
service = SeldonDeploymentService.create_from_deployment(
deployment=deployment
)
if running and not service.is_running:
# skip non-running services
continue
services.append(service)
return services
def stop_model_server(
self,
uuid: UUID,
timeout: int = DEFAULT_SELDON_DEPLOYMENT_START_STOP_TIMEOUT,
force: bool = False,
) -> None:
"""Stop a Seldon Core model server.
Args:
uuid: UUID of the model server to stop.
timeout: timeout in seconds to wait for the service to stop.
force: if True, force the service to stop.
Raises:
NotImplementedError: stopping Seldon Core model servers is not
supported.
"""
raise NotImplementedError(
"Stopping Seldon Core model servers is not implemented. Try "
"deleting the Seldon Core model server instead."
)
def start_model_server(
self,
uuid: UUID,
timeout: int = DEFAULT_SELDON_DEPLOYMENT_START_STOP_TIMEOUT,
) -> None:
"""Start a Seldon Core model deployment server.
Args:
uuid: UUID of the model server to start.
timeout: timeout in seconds to wait for the service to become
active. If set to 0, the method will return immediately after
provisioning the service, without waiting for it to become
active.
Raises:
NotImplementedError: since we don't support starting Seldon Core
model servers
"""
raise NotImplementedError(
"Starting Seldon Core model servers is not implemented"
)
def delete_model_server(
self,
uuid: UUID,
timeout: int = DEFAULT_SELDON_DEPLOYMENT_START_STOP_TIMEOUT,
force: bool = False,
) -> None:
"""Delete a Seldon Core model deployment server.
Args:
uuid: UUID of the model server to delete.
timeout: timeout in seconds to wait for the service to stop. If
set to 0, the method will return immediately after
deprovisioning the service, without waiting for it to stop.
force: if True, force the service to stop.
"""
services = self.find_model_server(service_uuid=uuid)
if len(services) == 0:
return
services[0].stop(timeout=timeout, force=force)
# if this is the last Seldon Core model server, delete the Kubernetes
# secret used to store the authentication information for the Seldon
# Core model server storage initializer
self._delete_kubernetes_secret()
config: SeldonModelDeployerConfig
property
readonly
Returns the SeldonModelDeployerConfig config.
Returns:
Type | Description |
---|---|
SeldonModelDeployerConfig |
The configuration. |
kubernetes_secret_name: Optional[str]
property
readonly
Get the Kubernetes secret name associated with this model deployer.
If a secret is configured for this model deployer, a corresponding Kubernetes secret is created in the remote cluster to be used by Seldon Core storage initializers to authenticate to the Artifact Store. This method returns the unique name that is used for this secret.
Returns:
Type | Description |
---|---|
Optional[str] |
The Seldon Core Kubernetes secret name, or None if no secret is configured. |
seldon_client: SeldonClient
property
readonly
Get the Seldon Core client associated with this model deployer.
Returns:
Type | Description |
---|---|
SeldonClient |
The Seldon Core client. |
delete_model_server(self, uuid, timeout=300, force=False)
Delete a Seldon Core model deployment server.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
uuid |
UUID |
UUID of the model server to delete. |
required |
timeout |
int |
timeout in seconds to wait for the service to stop. If set to 0, the method will return immediately after deprovisioning the service, without waiting for it to stop. |
300 |
force |
bool |
if True, force the service to stop. |
False |
Source code in zenml/integrations/seldon/model_deployers/seldon_model_deployer.py
def delete_model_server(
self,
uuid: UUID,
timeout: int = DEFAULT_SELDON_DEPLOYMENT_START_STOP_TIMEOUT,
force: bool = False,
) -> None:
"""Delete a Seldon Core model deployment server.
Args:
uuid: UUID of the model server to delete.
timeout: timeout in seconds to wait for the service to stop. If
set to 0, the method will return immediately after
deprovisioning the service, without waiting for it to stop.
force: if True, force the service to stop.
"""
services = self.find_model_server(service_uuid=uuid)
if len(services) == 0:
return
services[0].stop(timeout=timeout, force=force)
# if this is the last Seldon Core model server, delete the Kubernetes
# secret used to store the authentication information for the Seldon
# Core model server storage initializer
self._delete_kubernetes_secret()
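As a brief usage sketch (the UUID below is a placeholder, and the import path is assumed to match the module layout shown above), deleting a model server by its service UUID looks like this:
from uuid import UUID
from zenml.integrations.seldon.model_deployers import SeldonModelDeployer
model_deployer = SeldonModelDeployer.get_active_model_deployer()
# placeholder UUID standing in for a real model server's service UUID
server_uuid = UUID("00000000-0000-0000-0000-000000000000")
# stops the server and, if it was the last one, removes the Kubernetes
# secret used by the storage initializer
model_deployer.delete_model_server(uuid=server_uuid, timeout=300, force=False)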
deploy_model(self, config, replace=False, timeout=300)
Create a new Seldon Core deployment or update an existing one.
This should serve the supplied model and deployment configuration.
This method has two modes of operation, depending on the replace
argument value:
* if `replace` is False, calling this method will create a new Seldon Core deployment server to reflect the model and other configuration parameters specified in the supplied Seldon deployment `config`.
* if `replace` is True, this method will first attempt to find an existing Seldon Core deployment that is *equivalent* to the supplied configuration parameters. Two or more Seldon Core deployments are considered equivalent if they have the same `pipeline_name`, `pipeline_step_name` and `model_name` configuration parameters. To put it differently, two Seldon Core deployments are equivalent if they serve versions of the same model deployed by the same pipeline step. If an equivalent Seldon Core deployment is found, it will be updated in place to reflect the new configuration parameters. This allows an existing Seldon Core deployment to retain its prediction URL while performing a rolling update to serve a new model version.
Callers should set `replace` to True if they want a continuous model
to True if they want a continuous model
deployment workflow that doesn't spin up a new Seldon Core deployment
server for each new model version. If multiple equivalent Seldon Core
deployments are found, the most recently created deployment is selected
to be updated and the others are deleted.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
config |
ServiceConfig |
the configuration of the model to be deployed with Seldon Core. |
required |
replace |
bool |
set this flag to True to find and update an equivalent Seldon Core deployment server with the new model instead of starting a new deployment server. |
False |
timeout |
int |
the timeout in seconds to wait for the Seldon Core server to be provisioned and successfully started or updated. If set to 0, the method will return immediately after the Seldon Core server is provisioned, without waiting for it to fully start. |
300 |
Returns:
Type | Description |
---|---|
BaseService |
The ZenML Seldon Core deployment service object that can be used to interact with the remote Seldon Core server. |
Exceptions:
Type | Description |
---|---|
SeldonClientError |
if a Seldon Core client error is encountered while provisioning the Seldon Core deployment server. |
RuntimeError |
if `timeout` is set to a positive value that is exceeded while waiting for the Seldon Core deployment server to start, or if an operational failure is encountered before it reaches a ready state. |
Source code in zenml/integrations/seldon/model_deployers/seldon_model_deployer.py
def deploy_model(
self,
config: ServiceConfig,
replace: bool = False,
timeout: int = DEFAULT_SELDON_DEPLOYMENT_START_STOP_TIMEOUT,
) -> BaseService:
"""Create a new Seldon Core deployment or update an existing one.
# noqa: DAR402
This should serve the supplied model and deployment configuration.
This method has two modes of operation, depending on the `replace`
argument value:
* if `replace` is False, calling this method will create a new Seldon
Core deployment server to reflect the model and other configuration
parameters specified in the supplied Seldon deployment `config`.
* if `replace` is True, this method will first attempt to find an
existing Seldon Core deployment that is *equivalent* to the supplied
configuration parameters. Two or more Seldon Core deployments are
considered equivalent if they have the same `pipeline_name`,
`pipeline_step_name` and `model_name` configuration parameters. To
put it differently, two Seldon Core deployments are equivalent if
they serve versions of the same model deployed by the same pipeline
step. If an equivalent Seldon Core deployment is found, it will be
updated in place to reflect the new configuration parameters. This
allows an existing Seldon Core deployment to retain its prediction
URL while performing a rolling update to serve a new model version.
Callers should set `replace` to True if they want a continuous model
deployment workflow that doesn't spin up a new Seldon Core deployment
server for each new model version. If multiple equivalent Seldon Core
deployments are found, the most recently created deployment is selected
to be updated and the others are deleted.
Args:
config: the configuration of the model to be deployed with Seldon Core.
replace: set this flag to True to find and update an equivalent
Seldon Core deployment server with the new model instead of
starting a new deployment server.
timeout: the timeout in seconds to wait for the Seldon Core server
to be provisioned and successfully started or updated. If set
to 0, the method will return immediately after the Seldon Core
server is provisioned, without waiting for it to fully start.
Returns:
The ZenML Seldon Core deployment service object that can be used to
interact with the remote Seldon Core server.
Raises:
SeldonClientError: if a Seldon Core client error is encountered
while provisioning the Seldon Core deployment server.
RuntimeError: if `timeout` is set to a positive value that is
exceeded while waiting for the Seldon Core deployment server
to start, or if an operational failure is encountered before
it reaches a ready state.
"""
config = cast(SeldonDeploymentConfig, config)
service = None
# if a custom Kubernetes secret is not explicitly specified in the
# SeldonDeploymentConfig, try to create one from the ZenML secret
# configured for the model deployer
config.secret_name = (
config.secret_name or self._create_or_update_kubernetes_secret()
)
# if replace is True, find equivalent Seldon Core deployments
if replace is True:
equivalent_services = self.find_model_server(
running=False,
pipeline_name=config.pipeline_name,
pipeline_step_name=config.pipeline_step_name,
model_name=config.model_name,
)
for equivalent_service in equivalent_services:
if service is None:
# keep the most recently created service
service = equivalent_service
else:
try:
# delete the older services and don't wait for them to
# be deprovisioned
service.stop()
except RuntimeError:
# ignore errors encountered while stopping old services
pass
if service:
# update an equivalent service in place
service.update(config)
logger.info(
f"Updating an existing Seldon deployment service: {service}"
)
else:
# create a new service
service = SeldonDeploymentService(config=config)
logger.info(f"Creating a new Seldon deployment service: {service}")
# start the service which in turn provisions the Seldon Core
# deployment server and waits for it to reach a ready state
service.start(timeout=timeout)
# Add telemetry with metadata that gets the stack metadata and
# differentiates between pure model and custom code deployments
stack = Client().active_stack
stack_metadata = {
component_type.value: component.flavor
for component_type, component in stack.components.items()
}
metadata = {
"store_type": Client().zen_store.type.value,
**stack_metadata,
"is_custom_code_deployment": config.is_custom_deployment,
}
track_event(AnalyticsEvent.MODEL_DEPLOYED, metadata=metadata)
return service
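A minimal sketch of a continuous-deployment call, assuming the import paths match the module layout shown in this section and using placeholder configuration values:
from zenml.integrations.seldon.model_deployers import SeldonModelDeployer
from zenml.integrations.seldon.services import SeldonDeploymentConfig
model_deployer = SeldonModelDeployer.get_active_model_deployer()
# placeholder values; a real pipeline would fill these in from its context
config = SeldonDeploymentConfig(
    model_uri="s3://my-bucket/model",
    model_name="my-model",
    implementation="SKLEARN_SERVER",
    pipeline_name="training_pipeline",
    pipeline_step_name="model_deployer_step",
)
# replace=True rolls the new model version into an equivalent existing
# deployment (same pipeline, step and model name) instead of creating a
# second server, preserving the prediction URL
service = model_deployer.deploy_model(config=config, replace=True, timeout=300)
print(service.prediction_url)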
find_model_server(self, running=False, service_uuid=None, pipeline_name=None, pipeline_run_id=None, pipeline_step_name=None, model_name=None, model_uri=None, model_type=None)
Find one or more Seldon Core model services that match the given criteria.
The Seldon Core deployment services that meet the search criteria are returned sorted in descending order of their creation time (i.e. more recent deployments first).
Parameters:
Name | Type | Description | Default |
---|---|---|---|
running |
bool |
if true, only running services will be returned. |
False |
service_uuid |
Optional[uuid.UUID] |
the UUID of the Seldon Core service that was originally used to create the Seldon Core deployment resource. |
None |
pipeline_name |
Optional[str] |
name of the pipeline that the deployed model was part of. |
None |
pipeline_run_id |
Optional[str] |
ID of the pipeline run which the deployed model was part of. |
None |
pipeline_step_name |
Optional[str] |
the name of the pipeline model deployment step that deployed the model. |
None |
model_name |
Optional[str] |
the name of the deployed model. |
None |
model_uri |
Optional[str] |
URI of the deployed model. |
None |
model_type |
Optional[str] |
the Seldon Core server implementation used to serve the model |
None |
Returns:
Type | Description |
---|---|
List[zenml.services.service.BaseService] |
One or more Seldon Core service objects representing Seldon Core model servers that match the input search criteria. |
Source code in zenml/integrations/seldon/model_deployers/seldon_model_deployer.py
def find_model_server(
self,
running: bool = False,
service_uuid: Optional[UUID] = None,
pipeline_name: Optional[str] = None,
pipeline_run_id: Optional[str] = None,
pipeline_step_name: Optional[str] = None,
model_name: Optional[str] = None,
model_uri: Optional[str] = None,
model_type: Optional[str] = None,
) -> List[BaseService]:
"""Find one or more Seldon Core model services that match the given criteria.
The Seldon Core deployment services that meet the search criteria are
returned sorted in descending order of their creation time (i.e. more
recent deployments first).
Args:
running: if true, only running services will be returned.
service_uuid: the UUID of the Seldon Core service that was originally used
to create the Seldon Core deployment resource.
pipeline_name: name of the pipeline that the deployed model was part
of.
pipeline_run_id: ID of the pipeline run which the deployed model was
part of.
pipeline_step_name: the name of the pipeline model deployment step
that deployed the model.
model_name: the name of the deployed model.
model_uri: URI of the deployed model.
model_type: the Seldon Core server implementation used to serve
the model
Returns:
One or more Seldon Core service objects representing Seldon Core
model servers that match the input search criteria.
"""
# Use a Seldon deployment service configuration to compute the labels
config = SeldonDeploymentConfig(
pipeline_name=pipeline_name or "",
pipeline_run_id=pipeline_run_id or "",
pipeline_step_name=pipeline_step_name or "",
model_name=model_name or "",
model_uri=model_uri or "",
implementation=model_type or "",
)
labels = config.get_seldon_deployment_labels()
if service_uuid:
# the service UUID is not a label covered by the Seldon
# deployment service configuration, so we need to add it
# separately
labels["zenml.service_uuid"] = str(service_uuid)
deployments = self.seldon_client.find_deployments(labels=labels)
# sort the deployments in descending order of their creation time
deployments.sort(
key=lambda deployment: datetime.strptime(
deployment.metadata.creationTimestamp,
"%Y-%m-%dT%H:%M:%SZ",
)
if deployment.metadata.creationTimestamp
else datetime.min,
reverse=True,
)
services: List[BaseService] = []
for deployment in deployments:
# recreate the Seldon deployment service object from the Seldon
# deployment resource
service = SeldonDeploymentService.create_from_deployment(
deployment=deployment
)
if running and not service.is_running:
# skip non-running services
continue
services.append(service)
return services
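For example (pipeline, step and model names below are placeholders), a prediction-service loader could fetch the most recent running server for a given pipeline step:
from zenml.integrations.seldon.model_deployers import SeldonModelDeployer
model_deployer = SeldonModelDeployer.get_active_model_deployer()
# results are sorted most-recent-first, so the first match is the newest
services = model_deployer.find_model_server(
    running=True,
    pipeline_name="training_pipeline",
    pipeline_step_name="model_deployer_step",
    model_name="my-model",
)
if services:
    print(f"Prediction URL: {services[0].prediction_url}")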
get_active_model_deployer()
staticmethod
Get the Seldon Core model deployer registered in the active stack.
Returns:
Type | Description |
---|---|
SeldonModelDeployer |
The Seldon Core model deployer registered in the active stack. |
Exceptions:
Type | Description |
---|---|
TypeError |
if the Seldon Core model deployer is not available. |
Source code in zenml/integrations/seldon/model_deployers/seldon_model_deployer.py
@staticmethod
def get_active_model_deployer() -> "SeldonModelDeployer":
"""Get the Seldon Core model deployer registered in the active stack.
Returns:
The Seldon Core model deployer registered in the active stack.
Raises:
TypeError: if the Seldon Core model deployer is not available.
"""
model_deployer = Client( # type: ignore [call-arg]
skip_client_check=True
).active_stack.model_deployer
if not model_deployer or not isinstance(
model_deployer, SeldonModelDeployer
):
raise TypeError(
f"The active stack needs to have a Seldon Core model deployer "
f"component registered to be able to deploy models with Seldon "
f"Core. You can create a new stack with a Seldon Core model "
f"deployer component or update your existing stack to add this "
f"component, e.g.:\n\n"
f" 'zenml model-deployer register seldon --flavor={SELDON_MODEL_DEPLOYER_FLAVOR} "
f"--kubernetes_context=context-name --kubernetes_namespace="
f"namespace-name --base_url=https://ingress.cluster.kubernetes'\n"
f" 'zenml stack create stack-name -d seldon ...'\n"
)
return model_deployer
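A minimal sketch of guarding against a missing Seldon Core component in the active stack:
from zenml.integrations.seldon.model_deployers import SeldonModelDeployer
try:
    model_deployer = SeldonModelDeployer.get_active_model_deployer()
except TypeError as err:
    # the active stack has no Seldon Core model deployer registered
    print(f"Seldon Core is not available: {err}")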
get_model_server_info(service_instance)
staticmethod
Return implementation specific information that might be relevant to the user.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
service_instance |
SeldonDeploymentService |
Instance of a SeldonDeploymentService |
required |
Returns:
Type | Description |
---|---|
Dict[str, Optional[str]] |
Model server information. |
Source code in zenml/integrations/seldon/model_deployers/seldon_model_deployer.py
@staticmethod
def get_model_server_info( # type: ignore[override]
service_instance: "SeldonDeploymentService",
) -> Dict[str, Optional[str]]:
"""Return implementation specific information that might be relevant to the user.
Args:
service_instance: Instance of a SeldonDeploymentService
Returns:
Model server information.
"""
return {
"PREDICTION_URL": service_instance.prediction_url,
"MODEL_URI": service_instance.config.model_uri,
"MODEL_NAME": service_instance.config.model_name,
"SELDON_DEPLOYMENT": service_instance.seldon_deployment_name,
}
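Illustrative use, combined with find_model_server (a sketch that assumes at least one running server exists):
from zenml.integrations.seldon.model_deployers import SeldonModelDeployer
model_deployer = SeldonModelDeployer.get_active_model_deployer()
services = model_deployer.find_model_server(running=True)
if services:
    info = SeldonModelDeployer.get_model_server_info(services[0])
    # keys: PREDICTION_URL, MODEL_URI, MODEL_NAME, SELDON_DEPLOYMENT
    print(info["PREDICTION_URL"])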
prepare_pipeline_deployment(self, deployment, stack)
Build a Docker image and push it to the container registry.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
deployment |
PipelineDeployment |
The pipeline deployment configuration. |
required |
stack |
Stack |
The stack on which the pipeline will be deployed. |
required |
Source code in zenml/integrations/seldon/model_deployers/seldon_model_deployer.py
def prepare_pipeline_deployment(
self,
deployment: "PipelineDeployment",
stack: "Stack",
) -> None:
"""Build a Docker image and push it to the container registry.
Args:
deployment: The pipeline deployment configuration.
stack: The stack on which the pipeline will be deployed.
"""
docker_image_builder = PipelineDockerImageBuilder()
repo_digest = docker_image_builder.build_and_push_docker_image(
deployment=deployment, stack=stack
)
deployment.add_extra(SELDON_DOCKER_IMAGE_KEY, repo_digest)
start_model_server(self, uuid, timeout=300)
Start a Seldon Core model deployment server.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
uuid |
UUID |
UUID of the model server to start. |
required |
timeout |
int |
timeout in seconds to wait for the service to become active. If set to 0, the method will return immediately after provisioning the service, without waiting for it to become active. |
300 |
Exceptions:
Type | Description |
---|---|
NotImplementedError |
since we don't support starting Seldon Core model servers |
Source code in zenml/integrations/seldon/model_deployers/seldon_model_deployer.py
def start_model_server(
self,
uuid: UUID,
timeout: int = DEFAULT_SELDON_DEPLOYMENT_START_STOP_TIMEOUT,
) -> None:
"""Start a Seldon Core model deployment server.
Args:
uuid: UUID of the model server to start.
timeout: timeout in seconds to wait for the service to become
active. If set to 0, the method will return immediately after
provisioning the service, without waiting for it to become
active.
Raises:
NotImplementedError: since we don't support starting Seldon Core
model servers
"""
raise NotImplementedError(
"Starting Seldon Core model servers is not implemented"
)
stop_model_server(self, uuid, timeout=300, force=False)
Stop a Seldon Core model server.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
uuid |
UUID |
UUID of the model server to stop. |
required |
timeout |
int |
timeout in seconds to wait for the service to stop. |
300 |
force |
bool |
if True, force the service to stop. |
False |
Exceptions:
Type | Description |
---|---|
NotImplementedError |
stopping Seldon Core model servers is not supported. |
Source code in zenml/integrations/seldon/model_deployers/seldon_model_deployer.py
def stop_model_server(
self,
uuid: UUID,
timeout: int = DEFAULT_SELDON_DEPLOYMENT_START_STOP_TIMEOUT,
force: bool = False,
) -> None:
"""Stop a Seldon Core model server.
Args:
uuid: UUID of the model server to stop.
timeout: timeout in seconds to wait for the service to stop.
force: if True, force the service to stop.
Raises:
NotImplementedError: stopping Seldon Core model servers is not
supported.
"""
raise NotImplementedError(
"Stopping Seldon Core model servers is not implemented. Try "
"deleting the Seldon Core model server instead."
)
secret_schemas
special
Initialization for the Seldon secret schemas.
These are secret schemas that can be used to authenticate Seldon to the Artifact Store used to store served ML models.
secret_schemas
Implementation for Seldon secret schemas.
SeldonAzureSecretSchema (BaseSecretSchema)
pydantic-model
Seldon Azure Blob Storage credentials.
Based on: https://rclone.org/azureblob/
Attributes:
Name | Type | Description |
---|---|---|
rclone_config_azureblob_type |
Literal['azureblob'] |
the rclone config type. Must be set to "azureblob" for this schema. |
rclone_config_azureblob_account |
Optional[str] |
storage Account Name. Leave blank to use SAS URL or MSI. |
rclone_config_azureblob_key |
Optional[str] |
storage Account Key. Leave blank to use SAS URL or MSI. |
rclone_config_azureblob_sas_url |
Optional[str] |
SAS URL for container level access only. Leave blank if using account/key or MSI. |
rclone_config_azureblob_use_msi |
bool |
use a managed service identity to authenticate (only works in Azure). |
Source code in zenml/integrations/seldon/secret_schemas/secret_schemas.py
class SeldonAzureSecretSchema(BaseSecretSchema):
"""Seldon Azure Blob Storage credentials.
Based on: https://rclone.org/azureblob/
Attributes:
rclone_config_azureblob_type: the rclone config type. Must be set to
"azureblob" for this schema.
rclone_config_azureblob_account: storage Account Name. Leave blank to
use SAS URL or MSI.
rclone_config_azureblob_key: storage Account Key. Leave blank to
use SAS URL or MSI.
rclone_config_azureblob_sas_url: SAS URL for container level access
only. Leave blank if using account/key or MSI.
rclone_config_azureblob_use_msi: use a managed service identity to
authenticate (only works in Azure).
"""
TYPE: ClassVar[str] = SELDON_AZUREBLOB_SECRET_SCHEMA_TYPE
rclone_config_azureblob_type: Literal["azureblob"] = "azureblob"
rclone_config_azureblob_account: Optional[str]
rclone_config_azureblob_key: Optional[str]
rclone_config_azureblob_sas_url: Optional[str]
rclone_config_azureblob_use_msi: bool = False
SeldonGSSecretSchema (BaseSecretSchema)
pydantic-model
Seldon GCS credentials.
Based on: https://rclone.org/googlecloudstorage/
Attributes:
Name | Type | Description |
---|---|---|
rclone_config_gs_type |
Literal['google cloud storage'] |
the rclone config type. Must be set to "google cloud storage" for this schema. |
rclone_config_gs_client_id |
Optional[str] |
OAuth client id. |
rclone_config_gs_client_secret |
Optional[str] |
OAuth client secret. |
rclone_config_gs_token |
Optional[str] |
OAuth Access Token as a JSON blob. |
rclone_config_gs_project_number |
Optional[str] |
project number. |
rclone_config_gs_service_account_credentials |
Optional[str] |
service account credentials JSON blob. |
rclone_config_gs_anonymous |
bool |
access public buckets and objects without credentials. Set to True if you just want to download files and don't configure credentials. |
rclone_config_gs_auth_url |
Optional[str] |
auth server URL. |
Source code in zenml/integrations/seldon/secret_schemas/secret_schemas.py
class SeldonGSSecretSchema(BaseSecretSchema):
"""Seldon GCS credentials.
Based on: https://rclone.org/googlecloudstorage/
Attributes:
rclone_config_gs_type: the rclone config type. Must be set to "google
cloud storage" for this schema.
rclone_config_gs_client_id: OAuth client id.
rclone_config_gs_client_secret: OAuth client secret.
rclone_config_gs_token: OAuth Access Token as a JSON blob.
rclone_config_gs_project_number: project number.
rclone_config_gs_service_account_credentials: service account
credentials JSON blob.
rclone_config_gs_anonymous: access public buckets and objects without
credentials. Set to True if you just want to download files and
don't configure credentials.
rclone_config_gs_auth_url: auth server URL.
"""
TYPE: ClassVar[str] = SELDON_GS_SECRET_SCHEMA_TYPE
rclone_config_gs_type: Literal[
"google cloud storage"
] = "google cloud storage"
rclone_config_gs_client_id: Optional[str]
rclone_config_gs_client_secret: Optional[str]
rclone_config_gs_project_number: Optional[str]
rclone_config_gs_service_account_credentials: Optional[str]
rclone_config_gs_anonymous: bool = False
rclone_config_gs_token: Optional[str]
rclone_config_gs_auth_url: Optional[str]
rclone_config_gs_token_url: Optional[str]
SeldonS3SecretSchema (BaseSecretSchema)
pydantic-model
Seldon S3 credentials.
Based on: https://rclone.org/s3/#amazon-s3
Attributes:
Name | Type | Description |
---|---|---|
rclone_config_s3_type |
Literal['s3'] |
the rclone config type. Must be set to "s3" for this schema. |
rclone_config_s3_provider |
str |
the S3 provider (e.g. aws, ceph, minio). |
rclone_config_s3_env_auth |
bool |
get AWS credentials from EC2/ECS meta data (i.e. with IAM roles configuration). Only applies if access_key_id and secret_access_key are blank. |
rclone_config_s3_access_key_id |
Optional[str] |
AWS Access Key ID. |
rclone_config_s3_secret_access_key |
Optional[str] |
AWS Secret Access Key. |
rclone_config_s3_session_token |
Optional[str] |
AWS Session Token. |
rclone_config_s3_region |
Optional[str] |
region to connect to. |
rclone_config_s3_endpoint |
Optional[str] |
S3 API endpoint. |
Source code in zenml/integrations/seldon/secret_schemas/secret_schemas.py
class SeldonS3SecretSchema(BaseSecretSchema):
"""Seldon S3 credentials.
Based on: https://rclone.org/s3/#amazon-s3
Attributes:
rclone_config_s3_type: the rclone config type. Must be set to "s3" for
this schema.
rclone_config_s3_provider: the S3 provider (e.g. aws, ceph, minio).
rclone_config_s3_env_auth: get AWS credentials from EC2/ECS meta data
(i.e. with IAM roles configuration). Only applies if access_key_id
and secret_access_key are blank.
rclone_config_s3_access_key_id: AWS Access Key ID.
rclone_config_s3_secret_access_key: AWS Secret Access Key.
rclone_config_s3_session_token: AWS Session Token.
rclone_config_s3_region: region to connect to.
rclone_config_s3_endpoint: S3 API endpoint.
"""
TYPE: ClassVar[str] = SELDON_S3_SECRET_SCHEMA_TYPE
rclone_config_s3_type: Literal["s3"] = "s3"
rclone_config_s3_provider: str = "aws"
rclone_config_s3_env_auth: bool = False
rclone_config_s3_access_key_id: Optional[str]
rclone_config_s3_secret_access_key: Optional[str]
rclone_config_s3_session_token: Optional[str]
rclone_config_s3_region: Optional[str]
rclone_config_s3_endpoint: Optional[str]
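As a hedged sketch of filling in the S3 schema (all credential values below are placeholders, and whether the base schema requires a name argument depends on the ZenML version):
from zenml.integrations.seldon.secret_schemas import SeldonS3SecretSchema
# placeholder credentials for illustration only
secret = SeldonS3SecretSchema(
    name="seldon-s3-secret",  # assumption: the base schema takes a secret name
    rclone_config_s3_access_key_id="PLACEHOLDER_KEY_ID",
    rclone_config_s3_secret_access_key="PLACEHOLDER_SECRET_KEY",
    rclone_config_s3_region="us-east-1",
)
# rclone_config_s3_type and rclone_config_s3_provider default to "s3"/"aws"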
seldon_client
Implementation of the Seldon client for ZenML.
SeldonClient
A client for interacting with Seldon Deployments.
Source code in zenml/integrations/seldon/seldon_client.py
class SeldonClient:
"""A client for interacting with Seldon Deployments."""
def __init__(self, context: Optional[str], namespace: Optional[str]):
"""Initialize a Seldon Core client.
Args:
context: the Kubernetes context to use.
namespace: the Kubernetes namespace to use.
"""
self._context = context
self._namespace = namespace
self._initialize_k8s_clients()
def _initialize_k8s_clients(self) -> None:
"""Initialize the Kubernetes clients.
Raises:
SeldonClientError: if Kubernetes configuration could not be loaded
"""
try:
k8s_config.load_incluster_config()
if not self._namespace:
# load the namespace in the context of which the
# current pod is running
self._namespace = open(
"/var/run/secrets/kubernetes.io/serviceaccount/namespace"
).read()
except k8s_config.config_exception.ConfigException:
if not self._namespace:
raise SeldonClientError(
"The Kubernetes namespace must be explicitly "
"configured when running outside of a cluster."
)
try:
k8s_config.load_kube_config(
context=self._context, persist_config=False
)
except k8s_config.config_exception.ConfigException as e:
raise SeldonClientError(
"Could not load the Kubernetes configuration"
) from e
self._core_api = k8s_client.CoreV1Api()
self._custom_objects_api = k8s_client.CustomObjectsApi()
@staticmethod
def sanitize_labels(labels: Dict[str, str]) -> None:
"""Update the label values to be valid Kubernetes labels.
See:
https://kubernetes.io/docs/concepts/overview/working-with-objects/labels/#syntax-and-character-set
Args:
labels: the labels to sanitize.
"""
for key, value in labels.items():
# Kubernetes labels must be alphanumeric, no longer than
# 63 characters, and must begin and end with an alphanumeric
# character ([a-z0-9A-Z])
labels[key] = re.sub(r"[^0-9a-zA-Z-_\.]+", "_", value)[:63].strip(
"-_."
)
@property
def namespace(self) -> str:
"""Returns the Kubernetes namespace in use by the client.
Returns:
The Kubernetes namespace in use by the client.
Raises:
RuntimeError: if the namespace has not been configured.
"""
if not self._namespace:
# shouldn't happen if the client is initialized, but we need to
# appease the mypy type checker
raise RuntimeError("The Kubernetes namespace is not configured")
return self._namespace
def create_deployment(
self,
deployment: SeldonDeployment,
poll_timeout: int = 0,
) -> SeldonDeployment:
"""Create a Seldon Core deployment resource.
Args:
deployment: the Seldon Core deployment resource to create
poll_timeout: the maximum time to wait for the deployment to become
available or to fail. If set to 0, the function will return
immediately without checking the deployment status. If a timeout
occurs and the deployment is still pending creation, it will
be returned anyway and no exception will be raised.
Returns:
the created Seldon Core deployment resource with updated status.
Raises:
SeldonDeploymentExistsError: if a deployment with the same name
already exists.
SeldonClientError: if an unknown error occurs during the creation of
the deployment.
"""
try:
logger.debug(f"Creating SeldonDeployment resource: {deployment}")
# mark the deployment as managed by ZenML, to differentiate
# between deployments that are created by ZenML and those that
# are not
deployment.mark_as_managed_by_zenml()
body_deploy = deployment.dict(exclude_none=True)
response = self._custom_objects_api.create_namespaced_custom_object(
group="machinelearning.seldon.io",
version="v1",
namespace=self._namespace,
plural="seldondeployments",
body=body_deploy,
_request_timeout=poll_timeout or None,
)
logger.debug("Seldon Core API response: %s", response)
except k8s_client.rest.ApiException as e:
logger.error(
"Exception when creating SeldonDeployment resource: %s", str(e)
)
if e.status == 409:
raise SeldonDeploymentExistsError(
f"A deployment with the name {deployment.name} "
f"already exists in namespace {self._namespace}"
)
raise SeldonClientError(
"Exception when creating SeldonDeployment resource"
) from e
created_deployment = self.get_deployment(name=deployment.name)
while poll_timeout > 0 and created_deployment.is_pending():
time.sleep(5)
poll_timeout -= 5
created_deployment = self.get_deployment(name=deployment.name)
return created_deployment
def delete_deployment(
self,
name: str,
force: bool = False,
poll_timeout: int = 0,
) -> None:
"""Delete a Seldon Core deployment resource managed by ZenML.
Args:
name: the name of the Seldon Core deployment resource to delete.
force: if True, the deployment deletion will be forced (the graceful
period will be set to zero).
poll_timeout: the maximum time to wait for the deployment to be
deleted. If set to 0, the function will return immediately
without checking the deployment status. If a timeout
occurs and the deployment still exists, this method will
return and no exception will be raised.
Raises:
SeldonClientError: if an unknown error occurs during the deployment
removal.
"""
try:
logger.debug(f"Deleting SeldonDeployment resource: {name}")
# call `get_deployment` to check that the deployment exists
# and is managed by ZenML. It will raise
# a SeldonDeploymentNotFoundError otherwise
self.get_deployment(name=name)
response = self._custom_objects_api.delete_namespaced_custom_object(
group="machinelearning.seldon.io",
version="v1",
namespace=self._namespace,
plural="seldondeployments",
name=name,
_request_timeout=poll_timeout or None,
grace_period_seconds=0 if force else None,
)
logger.debug("Seldon Core API response: %s", response)
except k8s_client.rest.ApiException as e:
logger.error(
"Exception when deleting SeldonDeployment resource %s: %s",
name,
str(e),
)
raise SeldonClientError(
f"Exception when deleting SeldonDeployment resource {name}"
) from e
while poll_timeout > 0:
try:
self.get_deployment(name=name)
except SeldonDeploymentNotFoundError:
return
time.sleep(5)
poll_timeout -= 5
def update_deployment(
self,
deployment: SeldonDeployment,
poll_timeout: int = 0,
) -> SeldonDeployment:
"""Update a Seldon Core deployment resource.
Args:
deployment: the Seldon Core deployment resource to update
poll_timeout: the maximum time to wait for the deployment to become
available or to fail. If set to 0, the function will return
immediately without checking the deployment status. If a timeout
occurs and the deployment is still pending creation, it will
be returned anyway and no exception will be raised.
Returns:
the updated Seldon Core deployment resource with updated status.
Raises:
SeldonClientError: if an unknown error occurs while updating the
deployment.
"""
try:
logger.debug(
f"Updating SeldonDeployment resource: {deployment.name}"
)
# mark the deployment as managed by ZenML, to differentiate
# between deployments that are created by ZenML and those that
# are not
deployment.mark_as_managed_by_zenml()
# call `get_deployment` to check that the deployment exists
# and is managed by ZenML. It will raise
# a SeldonDeploymentNotFoundError otherwise
self.get_deployment(name=deployment.name)
response = self._custom_objects_api.patch_namespaced_custom_object(
group="machinelearning.seldon.io",
version="v1",
namespace=self._namespace,
plural="seldondeployments",
name=deployment.name,
body=deployment.dict(exclude_none=True),
_request_timeout=poll_timeout or None,
)
logger.debug("Seldon Core API response: %s", response)
except k8s_client.rest.ApiException as e:
logger.error(
"Exception when updating SeldonDeployment resource: %s", str(e)
)
raise SeldonClientError(
"Exception when creating SeldonDeployment resource"
) from e
updated_deployment = self.get_deployment(name=deployment.name)
while poll_timeout > 0 and updated_deployment.is_pending():
time.sleep(5)
poll_timeout -= 5
updated_deployment = self.get_deployment(name=deployment.name)
return updated_deployment
def get_deployment(self, name: str) -> SeldonDeployment:
"""Get a ZenML managed Seldon Core deployment resource by name.
Args:
name: the name of the Seldon Core deployment resource to fetch.
Returns:
The Seldon Core deployment resource.
Raises:
SeldonDeploymentNotFoundError: if the deployment resource cannot
be found or is not managed by ZenML.
SeldonClientError: if an unknown error occurs while fetching
the deployment.
"""
try:
logger.debug(f"Retrieving SeldonDeployment resource: {name}")
response = self._custom_objects_api.get_namespaced_custom_object(
group="machinelearning.seldon.io",
version="v1",
namespace=self._namespace,
plural="seldondeployments",
name=name,
)
logger.debug("Seldon Core API response: %s", response)
try:
deployment = SeldonDeployment(**response)
except ValidationError as e:
logger.error(
"Invalid Seldon Core deployment resource: %s\n%s",
str(e),
str(response),
)
raise SeldonDeploymentNotFoundError(
f"SeldonDeployment resource {name} could not be parsed"
)
# Only Seldon deployments managed by ZenML are returned
if not deployment.is_managed_by_zenml():
raise SeldonDeploymentNotFoundError(
f"Seldon Deployment {name} is not managed by ZenML"
)
return deployment
except k8s_client.rest.ApiException as e:
if e.status == 404:
raise SeldonDeploymentNotFoundError(
f"SeldonDeployment resource not found: {name}"
) from e
logger.error(
"Exception when fetching SeldonDeployment resource %s: %s",
name,
str(e),
)
raise SeldonClientError(
f"Unexpected exception when fetching SeldonDeployment "
f"resource: {name}"
) from e
def find_deployments(
self,
name: Optional[str] = None,
labels: Optional[Dict[str, str]] = None,
fields: Optional[Dict[str, str]] = None,
) -> List[SeldonDeployment]:
"""Find all ZenML-managed Seldon Core deployment resources matching the given criteria.
Args:
name: optional name of the deployment resource to find.
fields: optional selector to restrict the list of returned
Seldon deployments by their fields. Defaults to everything.
labels: optional selector to restrict the list of returned
Seldon deployments by their labels. Defaults to everything.
Returns:
List of Seldon Core deployments that match the given criteria.
Raises:
SeldonClientError: if an unknown error occurs while fetching
the deployments.
"""
fields = fields or {}
labels = labels or {}
# always filter results to only include Seldon deployments managed
# by ZenML
labels["app"] = "zenml"
if name:
fields = {"metadata.name": name}
field_selector = (
",".join(f"{k}={v}" for k, v in fields.items()) if fields else None
)
label_selector = (
",".join(f"{k}={v}" for k, v in labels.items()) if labels else None
)
try:
logger.debug(
f"Searching SeldonDeployment resources with label selector "
f"'{labels or ''}' and field selector '{fields or ''}'"
)
response = self._custom_objects_api.list_namespaced_custom_object(
group="machinelearning.seldon.io",
version="v1",
namespace=self._namespace,
plural="seldondeployments",
field_selector=field_selector,
label_selector=label_selector,
)
logger.debug(
"Seldon Core API returned %s items", len(response["items"])
)
deployments = []
for item in response.get("items") or []:
try:
deployments.append(SeldonDeployment(**item))
except ValidationError as e:
logger.error(
"Invalid Seldon Core deployment resource: %s\n%s",
str(e),
str(item),
)
return deployments
except k8s_client.rest.ApiException as e:
logger.error(
"Exception when searching SeldonDeployment resources with "
"label selector '%s' and field selector '%s': %s",
label_selector or "",
field_selector or "",
str(e),
)
raise SeldonClientError(
f"Unexpected exception when searching SeldonDeployment "
f"with labels '{labels or ''}' and field '{fields or ''}'"
) from e
def get_deployment_logs(
self,
name: str,
follow: bool = False,
tail: Optional[int] = None,
) -> Generator[str, bool, None]:
"""Get the logs of a Seldon Core deployment resource.
Args:
name: the name of the Seldon Core deployment to get logs for.
follow: if True, the logs will be streamed as they are written
tail: only retrieve the last NUM lines of log output.
Returns:
A generator that can be accessed to get the service logs.
Yields:
The next log line.
Raises:
SeldonClientError: if an unknown error occurs while fetching
the logs.
"""
logger.debug(f"Retrieving logs for SeldonDeployment resource: {name}")
try:
response = self._core_api.list_namespaced_pod(
namespace=self._namespace,
label_selector=f"seldon-deployment-id={name}",
)
logger.debug("Kubernetes API response: %s", response)
pods = response.items
if not pods:
raise SeldonClientError(
f"The Seldon Core deployment {name} is not currently "
f"running: no Kubernetes pods associated with it were found"
)
pod = pods[0]
pod_name = pod.metadata.name
containers = [c.name for c in pod.spec.containers]
init_containers = [c.name for c in pod.spec.init_containers]
container_statuses = {
c.name: c.started or c.restart_count
for c in pod.status.container_statuses
}
container = "default"
if container not in containers:
container = containers[0]
# some containers might not be running yet and have no logs to show,
# so we need to filter them out
if not container_statuses[container]:
container = init_containers[0]
logger.info(
f"Retrieving logs for pod: `{pod_name}` and container "
f"`{container}` in namespace `{self._namespace}`"
)
response = self._core_api.read_namespaced_pod_log(
name=pod_name,
namespace=self._namespace,
container=container,
follow=follow,
tail_lines=tail,
_preload_content=False,
)
except k8s_client.rest.ApiException as e:
logger.error(
"Exception when fetching logs for SeldonDeployment resource "
"%s: %s",
name,
str(e),
)
raise SeldonClientError(
f"Unexpected exception when fetching logs for SeldonDeployment "
f"resource: {name}"
) from e
try:
while True:
line = response.readline().decode("utf-8").rstrip("\n")
if not line:
return
stop = yield line
if stop:
return
finally:
response.release_conn()
def create_or_update_secret(
self,
name: str,
secret: BaseSecretSchema,
) -> None:
"""Create or update a Kubernetes Secret resource.
Uses the information contained in a ZenML secret.
Args:
name: the name of the Secret resource to create.
secret: a ZenML secret with key-values that should be
stored in the Secret resource.
Raises:
SeldonClientError: if an unknown error occurs during the creation of
the secret.
k8s_client.rest.ApiException: unexpected error.
"""
try:
logger.debug(f"Creating Secret resource: {name}")
secret_data = {
k.upper(): base64.b64encode(str(v).encode("utf-8")).decode(
"ascii"
)
for k, v in secret.content.items()
if v is not None
}
secret = k8s_client.V1Secret(
metadata=k8s_client.V1ObjectMeta(
name=name,
labels={"app": "zenml"},
),
type="Opaque",
data=secret_data,
)
try:
# check if the secret is already present
self._core_api.read_namespaced_secret(
name=name,
namespace=self._namespace,
)
# if we got this far, the secret is already present, update it
# in place
response = self._core_api.replace_namespaced_secret(
name=name,
namespace=self._namespace,
body=secret,
)
except k8s_client.rest.ApiException as e:
if e.status != 404:
# if an error other than 404 is raised here, treat it
# as an unexpected error
raise
response = self._core_api.create_namespaced_secret(
namespace=self._namespace,
body=secret,
)
logger.debug("Kubernetes API response: %s", response)
except k8s_client.rest.ApiException as e:
logger.error("Exception when creating Secret resource: %s", str(e))
raise SeldonClientError(
"Exception when creating Secret resource"
) from e
def delete_secret(
self,
name: str,
) -> None:
"""Delete a Kubernetes Secret resource managed by ZenML.
Args:
name: the name of the Kubernetes Secret resource to delete.
Raises:
SeldonClientError: if an unknown error occurs during the removal
of the secret.
"""
try:
logger.debug(f"Deleting Secret resource: {name}")
response = self._core_api.delete_namespaced_secret(
name=name,
namespace=self._namespace,
)
logger.debug("Kubernetes API response: %s", response)
except k8s_client.rest.ApiException as e:
if e.status == 404:
# the secret is no longer present, nothing to do
return
logger.error(
"Exception when deleting Secret resource %s: %s",
name,
str(e),
)
raise SeldonClientError(
f"Exception when deleting Secret resource {name}"
) from e
namespace: str
property
readonly
Returns the Kubernetes namespace in use by the client.
Returns:
Type | Description |
---|---|
str |
The Kubernetes namespace in use by the client. |
Exceptions:
Type | Description |
---|---|
RuntimeError |
if the namespace has not been configured. |
__init__(self, context, namespace)
special
Initialize a Seldon Core client.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
context |
Optional[str] |
the Kubernetes context to use. |
required |
namespace |
Optional[str] |
the Kubernetes namespace to use. |
required |
Source code in zenml/integrations/seldon/seldon_client.py
def __init__(self, context: Optional[str], namespace: Optional[str]):
"""Initialize a Seldon Core client.
Args:
context: the Kubernetes context to use.
namespace: the Kubernetes namespace to use.
"""
self._context = context
self._namespace = namespace
self._initialize_k8s_clients()
create_deployment(self, deployment, poll_timeout=0)
Create a Seldon Core deployment resource.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
deployment |
SeldonDeployment |
the Seldon Core deployment resource to create |
required |
poll_timeout |
int |
the maximum time to wait for the deployment to become available or to fail. If set to 0, the function will return immediately without checking the deployment status. If a timeout occurs and the deployment is still pending creation, it will be returned anyway and no exception will be raised. |
0 |
Returns:
Type | Description |
---|---|
SeldonDeployment |
the created Seldon Core deployment resource with updated status. |
Exceptions:
Type | Description |
---|---|
SeldonDeploymentExistsError |
if a deployment with the same name already exists. |
SeldonClientError |
if an unknown error occurs during the creation of the deployment. |
Source code in zenml/integrations/seldon/seldon_client.py
def create_deployment(
self,
deployment: SeldonDeployment,
poll_timeout: int = 0,
) -> SeldonDeployment:
"""Create a Seldon Core deployment resource.
Args:
deployment: the Seldon Core deployment resource to create
poll_timeout: the maximum time to wait for the deployment to become
available or to fail. If set to 0, the function will return
immediately without checking the deployment status. If a timeout
occurs and the deployment is still pending creation, it will
be returned anyway and no exception will be raised.
Returns:
the created Seldon Core deployment resource with updated status.
Raises:
SeldonDeploymentExistsError: if a deployment with the same name
already exists.
SeldonClientError: if an unknown error occurs during the creation of
the deployment.
"""
try:
logger.debug(f"Creating SeldonDeployment resource: {deployment}")
# mark the deployment as managed by ZenML, to differentiate
# between deployments that are created by ZenML and those that
# are not
deployment.mark_as_managed_by_zenml()
body_deploy = deployment.dict(exclude_none=True)
response = self._custom_objects_api.create_namespaced_custom_object(
group="machinelearning.seldon.io",
version="v1",
namespace=self._namespace,
plural="seldondeployments",
body=body_deploy,
_request_timeout=poll_timeout or None,
)
logger.debug("Seldon Core API response: %s", response)
except k8s_client.rest.ApiException as e:
logger.error(
"Exception when creating SeldonDeployment resource: %s", str(e)
)
if e.status == 409:
raise SeldonDeploymentExistsError(
f"A deployment with the name {deployment.name} "
f"already exists in namespace {self._namespace}"
)
raise SeldonClientError(
"Exception when creating SeldonDeployment resource"
) from e
created_deployment = self.get_deployment(name=deployment.name)
while poll_timeout > 0 and created_deployment.is_pending():
time.sleep(5)
poll_timeout -= 5
created_deployment = self.get_deployment(name=deployment.name)
return created_deployment
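A rough sketch of driving the client directly. The context and namespace names are placeholders, and the SeldonDeployment.build helper and its parameter names are assumptions based on the surrounding code, not a confirmed API:
from zenml.integrations.seldon.seldon_client import SeldonClient, SeldonDeployment
client = SeldonClient(context="my-k8s-context", namespace="zenml-seldon")
# assumption: SeldonDeployment exposes a build helper for constructing
# the custom resource; parameter names here are illustrative
deployment = SeldonDeployment.build(
    name="my-model-server",
    model_uri="s3://my-bucket/model",
    model_name="my-model",
    implementation="SKLEARN_SERVER",
)
# poll for up to 5 minutes; if still pending, the resource is returned anyway
created = client.create_deployment(deployment=deployment, poll_timeout=300)
print(created.is_pending())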
create_or_update_secret(self, name, secret)
Create or update a Kubernetes Secret resource.
Uses the information contained in a ZenML secret.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
name |
str |
the name of the Secret resource to create. |
required |
secret |
BaseSecretSchema |
a ZenML secret with key-values that should be stored in the Secret resource. |
required |
Exceptions:
Type | Description |
---|---|
SeldonClientError |
if an unknown error occurs during the creation of the secret. |
k8s_client.rest.ApiException |
unexpected error. |
Source code in zenml/integrations/seldon/seldon_client.py
def create_or_update_secret(
self,
name: str,
secret: BaseSecretSchema,
) -> None:
"""Create or update a Kubernetes Secret resource.
Uses the information contained in a ZenML secret.
Args:
name: the name of the Secret resource to create.
secret: a ZenML secret with key-values that should be
stored in the Secret resource.
Raises:
SeldonClientError: if an unknown error occurs during the creation of
the secret.
k8s_client.rest.ApiException: unexpected error.
"""
try:
logger.debug(f"Creating Secret resource: {name}")
secret_data = {
k.upper(): base64.b64encode(str(v).encode("utf-8")).decode(
"ascii"
)
for k, v in secret.content.items()
if v is not None
}
secret = k8s_client.V1Secret(
metadata=k8s_client.V1ObjectMeta(
name=name,
labels={"app": "zenml"},
),
type="Opaque",
data=secret_data,
)
try:
# check if the secret is already present
self._core_api.read_namespaced_secret(
name=name,
namespace=self._namespace,
)
# if we got this far, the secret is already present, update it
# in place
response = self._core_api.replace_namespaced_secret(
name=name,
namespace=self._namespace,
body=secret,
)
except k8s_client.rest.ApiException as e:
if e.status != 404:
# if an error other than 404 is raised here, treat it
# as an unexpected error
raise
response = self._core_api.create_namespaced_secret(
namespace=self._namespace,
body=secret,
)
logger.debug("Kubernetes API response: %s", response)
except k8s_client.rest.ApiException as e:
logger.error("Exception when creating Secret resource: %s", str(e))
raise SeldonClientError(
"Exception when creating Secret resource"
) from e
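To make the transformation above concrete, here is a standalone sketch that mirrors how the secret contents are turned into Kubernetes Secret data (keys upper-cased, values base64-encoded, None values dropped):
import base64
content = {"rclone_config_s3_region": "us-east-1", "rclone_config_s3_endpoint": None}
secret_data = {
    k.upper(): base64.b64encode(str(v).encode("utf-8")).decode("ascii")
    for k, v in content.items()
    if v is not None
}
print(secret_data)  # {'RCLONE_CONFIG_S3_REGION': 'dXMtZWFzdC0x'}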
delete_deployment(self, name, force=False, poll_timeout=0)
Delete a Seldon Core deployment resource managed by ZenML.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
name |
str |
the name of the Seldon Core deployment resource to delete. |
required |
force |
bool |
if True, the deployment deletion will be forced (the graceful period will be set to zero). |
False |
poll_timeout |
int |
the maximum time to wait for the deployment to be deleted. If set to 0, the function will return immediately without checking the deployment status. If a timeout occurs and the deployment still exists, this method will return and no exception will be raised. |
0 |
Exceptions:
Type | Description |
---|---|
SeldonClientError |
if an unknown error occurs during the deployment removal. |
Source code in zenml/integrations/seldon/seldon_client.py
def delete_deployment(
self,
name: str,
force: bool = False,
poll_timeout: int = 0,
) -> None:
"""Delete a Seldon Core deployment resource managed by ZenML.
Args:
name: the name of the Seldon Core deployment resource to delete.
force: if True, the deployment deletion will be forced (the graceful
period will be set to zero).
poll_timeout: the maximum time to wait for the deployment to be
deleted. If set to 0, the function will return immediately
without checking the deployment status. If a timeout
occurs and the deployment still exists, this method will
return and no exception will be raised.
Raises:
SeldonClientError: if an unknown error occurs during the deployment
removal.
"""
try:
logger.debug(f"Deleting SeldonDeployment resource: {name}")
# call `get_deployment` to check that the deployment exists
# and is managed by ZenML. It will raise
# a SeldonDeploymentNotFoundError otherwise
self.get_deployment(name=name)
response = self._custom_objects_api.delete_namespaced_custom_object(
group="machinelearning.seldon.io",
version="v1",
namespace=self._namespace,
plural="seldondeployments",
name=name,
_request_timeout=poll_timeout or None,
grace_period_seconds=0 if force else None,
)
logger.debug("Seldon Core API response: %s", response)
except k8s_client.rest.ApiException as e:
logger.error(
"Exception when deleting SeldonDeployment resource %s: %s",
name,
str(e),
)
raise SeldonClientError(
f"Exception when deleting SeldonDeployment resource {name}"
) from e
while poll_timeout > 0:
try:
self.get_deployment(name=name)
except SeldonDeploymentNotFoundError:
return
time.sleep(5)
poll_timeout -= 5
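A short usage sketch (the context, namespace and deployment name are placeholders):
from zenml.integrations.seldon.seldon_client import SeldonClient
client = SeldonClient(context="my-k8s-context", namespace="zenml-seldon")
# force=True sets the grace period to zero; poll up to 60s for removal
client.delete_deployment(name="my-model-server", force=True, poll_timeout=60)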
delete_secret(self, name)
Delete a Kubernetes Secret resource managed by ZenML.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
name |
str |
the name of the Kubernetes Secret resource to delete. |
required |
Exceptions:
Type | Description |
---|---|
SeldonClientError |
if an unknown error occurs during the removal of the secret. |
Source code in zenml/integrations/seldon/seldon_client.py
def delete_secret(
self,
name: str,
) -> None:
"""Delete a Kubernetes Secret resource managed by ZenML.
Args:
name: the name of the Kubernetes Secret resource to delete.
Raises:
SeldonClientError: if an unknown error occurs during the removal
of the secret.
"""
try:
logger.debug(f"Deleting Secret resource: {name}")
response = self._core_api.delete_namespaced_secret(
name=name,
namespace=self._namespace,
)
logger.debug("Kubernetes API response: %s", response)
except k8s_client.rest.ApiException as e:
if e.status == 404:
# the secret is no longer present, nothing to do
return
logger.error(
"Exception when deleting Secret resource %s: %s",
name,
str(e),
)
raise SeldonClientError(
f"Exception when deleting Secret resource {name}"
) from e
find_deployments(self, name=None, labels=None, fields=None)
Find all ZenML-managed Seldon Core deployment resources matching the given criteria.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| name | Optional[str] | optional name of the deployment resource to find. | None |
| fields | Optional[Dict[str, str]] | optional selector to restrict the list of returned Seldon deployments by their fields. Defaults to everything. | None |
| labels | Optional[Dict[str, str]] | optional selector to restrict the list of returned Seldon deployments by their labels. Defaults to everything. | None |
Returns:
| Type | Description |
|---|---|
| List[zenml.integrations.seldon.seldon_client.SeldonDeployment] | List of Seldon Core deployments that match the given criteria. |
Exceptions:
| Type | Description |
|---|---|
| SeldonClientError | if an unknown error occurs while fetching the deployments. |
Source code in zenml/integrations/seldon/seldon_client.py
def find_deployments(
self,
name: Optional[str] = None,
labels: Optional[Dict[str, str]] = None,
fields: Optional[Dict[str, str]] = None,
) -> List[SeldonDeployment]:
"""Find all ZenML-managed Seldon Core deployment resources matching the given criteria.
Args:
name: optional name of the deployment resource to find.
fields: optional selector to restrict the list of returned
Seldon deployments by their fields. Defaults to everything.
labels: optional selector to restrict the list of returned
Seldon deployments by their labels. Defaults to everything.
Returns:
List of Seldon Core deployments that match the given criteria.
Raises:
SeldonClientError: if an unknown error occurs while fetching
the deployments.
"""
fields = fields or {}
labels = labels or {}
# always filter results to only include Seldon deployments managed
# by ZenML
labels["app"] = "zenml"
if name:
fields = {"metadata.name": name}
field_selector = (
",".join(f"{k}={v}" for k, v in fields.items()) if fields else None
)
label_selector = (
",".join(f"{k}={v}" for k, v in labels.items()) if labels else None
)
try:
logger.debug(
f"Searching SeldonDeployment resources with label selector "
f"'{labels or ''}' and field selector '{fields or ''}'"
)
response = self._custom_objects_api.list_namespaced_custom_object(
group="machinelearning.seldon.io",
version="v1",
namespace=self._namespace,
plural="seldondeployments",
field_selector=field_selector,
label_selector=label_selector,
)
logger.debug(
"Seldon Core API returned %s items", len(response["items"])
)
deployments = []
for item in response.get("items") or []:
try:
deployments.append(SeldonDeployment(**item))
except ValidationError as e:
logger.error(
"Invalid Seldon Core deployment resource: %s\n%s",
str(e),
str(item),
)
return deployments
except k8s_client.rest.ApiException as e:
logger.error(
"Exception when searching SeldonDeployment resources with "
"label selector '%s' and field selector '%s': %s",
label_selector or "",
field_selector or "",
)
raise SeldonClientError(
f"Unexpected exception when searching SeldonDeployment "
f"with labels '{labels or ''}' and field '{fields or ''}'"
) from e
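As a hedged usage sketch (assuming `client` is an initialized `SeldonClient`; the pipeline name is a placeholder), note that the `app=zenml` label filter is always added by the method itself:
```python
# List ZenML-managed deployments created by a given pipeline and print
# their names and current states.
deployments = client.find_deployments(
    labels={"zenml.pipeline_name": "training_pipeline"}
)
for deployment in deployments:
    print(deployment.name, deployment.state)
```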
get_deployment(self, name)
Get a ZenML managed Seldon Core deployment resource by name.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| name | str | the name of the Seldon Core deployment resource to fetch. | required |
Returns:
| Type | Description |
|---|---|
| SeldonDeployment | The Seldon Core deployment resource. |
Exceptions:
| Type | Description |
|---|---|
| SeldonDeploymentNotFoundError | if the deployment resource cannot be found or is not managed by ZenML. |
| SeldonClientError | if an unknown error occurs while fetching the deployment. |
Source code in zenml/integrations/seldon/seldon_client.py
def get_deployment(self, name: str) -> SeldonDeployment:
"""Get a ZenML managed Seldon Core deployment resource by name.
Args:
name: the name of the Seldon Core deployment resource to fetch.
Returns:
The Seldon Core deployment resource.
Raises:
SeldonDeploymentNotFoundError: if the deployment resource cannot
be found or is not managed by ZenML.
SeldonClientError: if an unknown error occurs while fetching
the deployment.
"""
try:
logger.debug(f"Retrieving SeldonDeployment resource: {name}")
response = self._custom_objects_api.get_namespaced_custom_object(
group="machinelearning.seldon.io",
version="v1",
namespace=self._namespace,
plural="seldondeployments",
name=name,
)
logger.debug("Seldon Core API response: %s", response)
try:
deployment = SeldonDeployment(**response)
except ValidationError as e:
logger.error(
"Invalid Seldon Core deployment resource: %s\n%s",
str(e),
str(response),
)
raise SeldonDeploymentNotFoundError(
f"SeldonDeployment resource {name} could not be parsed"
) from e
# Only Seldon deployments managed by ZenML are returned
if not deployment.is_managed_by_zenml():
raise SeldonDeploymentNotFoundError(
f"Seldon Deployment {name} is not managed by ZenML"
)
return deployment
except k8s_client.rest.ApiException as e:
if e.status == 404:
raise SeldonDeploymentNotFoundError(
f"SeldonDeployment resource not found: {name}"
) from e
logger.error(
"Exception when fetching SeldonDeployment resource %s: %s",
name,
str(e),
)
raise SeldonClientError(
f"Unexpected exception when fetching SeldonDeployment "
f"resource: {name}"
) from e
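A minimal sketch of fetching a single deployment and distinguishing "not found" from other errors (assuming `client` is an initialized `SeldonClient`; the name is a placeholder):
```python
from zenml.integrations.seldon.seldon_client import (
    SeldonDeploymentNotFoundError,
)

try:
    deployment = client.get_deployment(name="zenml-model")
    # state is derived from the CRD status; see SeldonDeploymentStatusState
    print(deployment.state, deployment.get_pending_message())
except SeldonDeploymentNotFoundError:
    print("No such deployment, or it is not managed by ZenML")
```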
get_deployment_logs(self, name, follow=False, tail=None)
Get the logs of a Seldon Core deployment resource.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| name | str | the name of the Seldon Core deployment to get logs for. | required |
| follow | bool | if True, the logs will be streamed as they are written. | False |
| tail | Optional[int] | only retrieve the last NUM lines of log output. | None |
Returns:
| Type | Description |
|---|---|
| Generator[str, bool, NoneType] | A generator that can be accessed to get the service logs. |
Yields:
| Type | Description |
|---|---|
| Generator[str, bool, NoneType] | The next log line. |
Exceptions:
| Type | Description |
|---|---|
| SeldonClientError | if an unknown error occurs while fetching the logs. |
Source code in zenml/integrations/seldon/seldon_client.py
def get_deployment_logs(
self,
name: str,
follow: bool = False,
tail: Optional[int] = None,
) -> Generator[str, bool, None]:
"""Get the logs of a Seldon Core deployment resource.
Args:
name: the name of the Seldon Core deployment to get logs for.
follow: if True, the logs will be streamed as they are written
tail: only retrieve the last NUM lines of log output.
Returns:
A generator that can be accessed to get the service logs.
Yields:
The next log line.
Raises:
SeldonClientError: if an unknown error occurs while fetching
the logs.
"""
logger.debug(f"Retrieving logs for SeldonDeployment resource: {name}")
try:
response = self._core_api.list_namespaced_pod(
namespace=self._namespace,
label_selector=f"seldon-deployment-id={name}",
)
logger.debug("Kubernetes API response: %s", response)
pods = response.items
if not pods:
raise SeldonClientError(
f"The Seldon Core deployment {name} is not currently "
f"running: no Kubernetes pods associated with it were found"
)
pod = pods[0]
pod_name = pod.metadata.name
containers = [c.name for c in pod.spec.containers]
init_containers = [c.name for c in pod.spec.init_containers]
container_statuses = {
c.name: c.started or c.restart_count
for c in pod.status.container_statuses
}
container = "default"
if container not in containers:
container = containers[0]
# some containers might not be running yet and have no logs to show,
# so we need to filter them out
if not container_statuses[container]:
container = init_containers[0]
logger.info(
f"Retrieving logs for pod: `{pod_name}` and container "
f"`{container}` in namespace `{self._namespace}`"
)
response = self._core_api.read_namespaced_pod_log(
name=pod_name,
namespace=self._namespace,
container=container,
follow=follow,
tail_lines=tail,
_preload_content=False,
)
except k8s_client.rest.ApiException as e:
logger.error(
"Exception when fetching logs for SeldonDeployment resource "
"%s: %s",
name,
str(e),
)
raise SeldonClientError(
f"Unexpected exception when fetching logs for SeldonDeployment "
f"resource: {name}"
) from e
try:
while True:
line = response.readline().decode("utf-8").rstrip("\n")
if not line:
return
stop = yield line
if stop:
return
finally:
response.release_conn()
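A sketch of consuming the log stream: print the first 20 lines, then signal the generator to stop by sending True into it (this is what the `bool` in `Generator[str, bool, None]` is for). Assumes `client` is an initialized `SeldonClient`; "zenml-model" is a placeholder name.
```python
logs = client.get_deployment_logs(name="zenml-model", follow=True)
try:
    for i, line in enumerate(logs):
        print(line)
        if i >= 19:
            # resuming the generator with True makes it return, which
            # raises StopIteration here and releases the connection
            logs.send(True)
except StopIteration:
    pass
```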
sanitize_labels(labels)
staticmethod
Update the label values to be valid Kubernetes labels.
See: https://kubernetes.io/docs/concepts/overview/working-with-objects/labels/#syntax-and-character-set
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| labels | Dict[str, str] | the labels to sanitize. | required |
Source code in zenml/integrations/seldon/seldon_client.py
@staticmethod
def sanitize_labels(labels: Dict[str, str]) -> None:
"""Update the label values to be valid Kubernetes labels.
See:
https://kubernetes.io/docs/concepts/overview/working-with-objects/labels/#syntax-and-character-set
Args:
labels: the labels to sanitize.
"""
for key, value in labels.items():
# Kubernetes labels must be alphanumeric, no longer than
# 63 characters, and must begin and end with an alphanumeric
# character ([a-z0-9A-Z])
labels[key] = re.sub(r"[^0-9a-zA-Z-_\.]+", "_", value)[:63].strip(
"-_."
)
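Because Kubernetes label values may not contain characters like `:` or `/`, values such as model URIs are rewritten in place. A small sketch of the effect:
```python
from zenml.integrations.seldon.seldon_client import SeldonClient

labels = {"zenml.model_uri": "s3://bucket/path/to/model"}
SeldonClient.sanitize_labels(labels)  # mutates the dict in place
print(labels)  # {'zenml.model_uri': 's3_bucket_path_to_model'}
```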
update_deployment(self, deployment, poll_timeout=0)
Update a Seldon Core deployment resource.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| deployment | SeldonDeployment | the Seldon Core deployment resource to update. | required |
| poll_timeout | int | the maximum time to wait for the deployment to become available or to fail. If set to 0, the function will return immediately without checking the deployment status. If a timeout occurs and the deployment is still pending creation, it will be returned anyway and no exception will be raised. | 0 |
Returns:
| Type | Description |
|---|---|
| SeldonDeployment | the updated Seldon Core deployment resource with updated status. |
Exceptions:
| Type | Description |
|---|---|
| SeldonClientError | if an unknown error occurs while updating the deployment. |
Source code in zenml/integrations/seldon/seldon_client.py
def update_deployment(
self,
deployment: SeldonDeployment,
poll_timeout: int = 0,
) -> SeldonDeployment:
"""Update a Seldon Core deployment resource.
Args:
deployment: the Seldon Core deployment resource to update
poll_timeout: the maximum time to wait for the deployment to become
available or to fail. If set to 0, the function will return
immediately without checking the deployment status. If a timeout
occurs and the deployment is still pending creation, it will
be returned anyway and no exception will be raised.
Returns:
the updated Seldon Core deployment resource with updated status.
Raises:
SeldonClientError: if an unknown error occurs while updating the
deployment.
"""
try:
logger.debug(
f"Updating SeldonDeployment resource: {deployment.name}"
)
# mark the deployment as managed by ZenML, to differentiate
# between deployments that are created by ZenML and those that
# are not
deployment.mark_as_managed_by_zenml()
# call `get_deployment` to check that the deployment exists
# and is managed by ZenML. It will raise
# a SeldonDeploymentNotFoundError otherwise
self.get_deployment(name=deployment.name)
response = self._custom_objects_api.patch_namespaced_custom_object(
group="machinelearning.seldon.io",
version="v1",
namespace=self._namespace,
plural="seldondeployments",
name=deployment.name,
body=deployment.dict(exclude_none=True),
_request_timeout=poll_timeout or None,
)
logger.debug("Seldon Core API response: %s", response)
except k8s_client.rest.ApiException as e:
logger.error(
"Exception when updating SeldonDeployment resource: %s", str(e)
)
raise SeldonClientError(
"Exception when creating SeldonDeployment resource"
) from e
updated_deployment = self.get_deployment(name=deployment.name)
while poll_timeout > 0 and updated_deployment.is_pending():
time.sleep(5)
poll_timeout -= 5
updated_deployment = self.get_deployment(name=deployment.name)
return updated_deployment
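For example, a sketch that scales an existing ZenML-managed deployment and waits up to 60 seconds for it to become available again (assuming `client` is an initialized `SeldonClient`; the name is a placeholder):
```python
deployment = client.get_deployment(name="zenml-model")
deployment.spec.replicas = 3
deployment.spec.predictors[0].replicas = 3
updated = client.update_deployment(deployment, poll_timeout=60)
print(updated.state)  # e.g. Available, Creating or Failed
```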
SeldonClientError (Exception)
Base exception class for all exceptions raised by the SeldonClient.
Source code in zenml/integrations/seldon/seldon_client.py
class SeldonClientError(Exception):
"""Base exception class for all exceptions raised by the SeldonClient."""
SeldonClientTimeout (SeldonClientError)
Raised when the Seldon client timed out while waiting for a resource to reach the expected status.
Source code in zenml/integrations/seldon/seldon_client.py
class SeldonClientTimeout(SeldonClientError):
"""Raised when the Seldon client timed out while waiting for a resource to reach the expected status."""
SeldonDeployment (BaseModel)
pydantic-model
A Seldon Core deployment CRD.
This is a Pydantic representation of some of the fields in the Seldon Core CRD (documented here: https://docs.seldon.io/projects/seldon-core/en/latest/reference/seldon-deployment.html).
Note that not all fields are represented, only those that are relevant to the ZenML integration. The fields that are not represented are silently ignored when the Seldon Deployment is created or updated from an external SeldonDeployment CRD representation.
Attributes:
| Name | Type | Description |
|---|---|---|
| kind | str | Kubernetes kind field. |
| apiVersion | str | Kubernetes apiVersion field. |
| metadata | SeldonDeploymentMetadata | Kubernetes metadata field. |
| spec | SeldonDeploymentSpec | Seldon Deployment spec entry. |
| status | Optional[zenml.integrations.seldon.seldon_client.SeldonDeploymentStatus] | Seldon Deployment status. |
Source code in zenml/integrations/seldon/seldon_client.py
class SeldonDeployment(BaseModel):
"""A Seldon Core deployment CRD.
This is a Pydantic representation of some of the fields in the Seldon Core
CRD (documented here:
https://docs.seldon.io/projects/seldon-core/en/latest/reference/seldon-deployment.html).
Note that not all fields are represented, only those that are relevant to
the ZenML integration. The fields that are not represented are silently
ignored when the Seldon Deployment is created or updated from an external
SeldonDeployment CRD representation.
Attributes:
kind: Kubernetes kind field.
apiVersion: Kubernetes apiVersion field.
metadata: Kubernetes metadata field.
spec: Seldon Deployment spec entry.
status: Seldon Deployment status.
"""
kind: str = Field(SELDON_DEPLOYMENT_KIND, const=True)
apiVersion: str = Field(SELDON_DEPLOYMENT_API_VERSION, const=True)
metadata: SeldonDeploymentMetadata = Field(
default_factory=SeldonDeploymentMetadata
)
spec: SeldonDeploymentSpec = Field(default_factory=SeldonDeploymentSpec)
status: Optional[SeldonDeploymentStatus]
def __str__(self) -> str:
"""Returns a string representation of the Seldon Deployment.
Returns:
A string representation of the Seldon Deployment.
"""
return json.dumps(self.dict(exclude_none=True), indent=4)
@classmethod
def build(
cls,
name: Optional[str] = None,
model_uri: Optional[str] = None,
model_name: Optional[str] = None,
implementation: Optional[str] = None,
secret_name: Optional[str] = None,
labels: Optional[Dict[str, str]] = None,
annotations: Optional[Dict[str, str]] = None,
is_custom_deployment: Optional[bool] = False,
spec: Optional[Dict[Any, Any]] = None,
) -> "SeldonDeployment":
"""Build a basic Seldon Deployment object.
Args:
name: The name of the Seldon Deployment. If not explicitly passed,
a unique name is autogenerated.
model_uri: The URI of the model.
model_name: The name of the model.
implementation: The implementation of the model.
secret_name: The name of the Kubernetes secret containing
environment variable values (e.g. with credentials for the
artifact store) to use with the deployment service.
labels: A dictionary of labels to apply to the Seldon Deployment.
annotations: A dictionary of annotations to apply to the Seldon
Deployment.
spec: A Kubernetes pod spec to use for the Seldon Deployment.
is_custom_deployment: Whether the Seldon Deployment is a custom or a built-in one.
Returns:
A minimal SeldonDeployment object built from the provided
parameters.
"""
if not name:
name = f"zenml-{time.time()}"
if labels is None:
labels = {}
if annotations is None:
annotations = {}
if is_custom_deployment:
predictors = [
SeldonDeploymentPredictor(
name=model_name or "",
graph=SeldonDeploymentPredictiveUnit(
name="classifier",
type=SeldonDeploymentPredictiveUnitType.MODEL,
),
componentSpecs=[
SeldonDeploymentComponentSpecs(
spec=spec
# TODO [HIGH]: Add support for other component types (e.g. graph)
)
],
)
]
else:
predictors = [
SeldonDeploymentPredictor(
name=model_name or "",
graph=SeldonDeploymentPredictiveUnit(
name="classifier",
type=SeldonDeploymentPredictiveUnitType.MODEL,
modelUri=model_uri or "",
implementation=implementation or "",
envSecretRefName=secret_name,
),
)
]
return SeldonDeployment(
metadata=SeldonDeploymentMetadata(
name=name, labels=labels, annotations=annotations
),
spec=SeldonDeploymentSpec(name=name, predictors=predictors),
)
def is_managed_by_zenml(self) -> bool:
"""Checks if this Seldon Deployment is managed by ZenML.
The convention used to differentiate between SeldonDeployment instances
that are managed by ZenML and those that are not is to set the `app`
label value to `zenml`.
Returns:
True if the Seldon Deployment is managed by ZenML, False
otherwise.
"""
return self.metadata.labels.get("app") == "zenml"
def mark_as_managed_by_zenml(self) -> None:
"""Marks this Seldon Deployment as managed by ZenML.
The convention used to differentiate between SeldonDeployment instances
that are managed by ZenML and those that are not is to set the `app`
label value to `zenml`.
"""
self.metadata.labels["app"] = "zenml"
@property
def name(self) -> str:
"""Returns the name of this Seldon Deployment.
This is just a shortcut for `self.metadata.name`.
Returns:
The name of this Seldon Deployment.
"""
return self.metadata.name
@property
def state(self) -> SeldonDeploymentStatusState:
"""The state of the Seldon Deployment.
Returns:
The state of the Seldon Deployment.
"""
if not self.status:
return SeldonDeploymentStatusState.UNKNOWN
return self.status.state
def is_pending(self) -> bool:
"""Checks if the Seldon Deployment is in a pending state.
Returns:
True if the Seldon Deployment is pending, False otherwise.
"""
return self.state == SeldonDeploymentStatusState.CREATING
def is_available(self) -> bool:
"""Checks if the Seldon Deployment is in an available state.
Returns:
True if the Seldon Deployment is available, False otherwise.
"""
return self.state == SeldonDeploymentStatusState.AVAILABLE
def is_failed(self) -> bool:
"""Checks if the Seldon Deployment is in a failed state.
Returns:
True if the Seldon Deployment is failed, False otherwise.
"""
return self.state == SeldonDeploymentStatusState.FAILED
def get_error(self) -> Optional[str]:
"""Get a message describing the error, if in an error state.
Returns:
A message describing the error, if in an error state, otherwise
None.
"""
if self.status and self.is_failed():
return self.status.description
return None
def get_pending_message(self) -> Optional[str]:
"""Get a message describing the pending conditions of the Seldon Deployment.
Returns:
A message describing the pending condition of the Seldon
Deployment, or None, if no conditions are pending.
"""
if not self.status or not self.status.conditions:
return None
ready_condition_message = [
c.message
for c in self.status.conditions
if c.type == "Ready" and not c.status
]
if not ready_condition_message:
return None
return ready_condition_message[0]
class Config:
"""Pydantic configuration class."""
# validate attribute assignments
validate_assignment = True
# Ignore extra attributes from the CRD that are not reflected here
extra = "ignore"
name: str
property
readonly
Returns the name of this Seldon Deployment.
This is just a shortcut for `self.metadata.name`.
Returns:
| Type | Description |
|---|---|
| str | The name of this Seldon Deployment. |
state: SeldonDeploymentStatusState
property
readonly
The state of the Seldon Deployment.
Returns:
| Type | Description |
|---|---|
| SeldonDeploymentStatusState | The state of the Seldon Deployment. |
Config
Pydantic configuration class.
Source code in zenml/integrations/seldon/seldon_client.py
class Config:
"""Pydantic configuration class."""
# validate attribute assignments
validate_assignment = True
# Ignore extra attributes from the CRD that are not reflected here
extra = "ignore"
__str__(self)
special
Returns a string representation of the Seldon Deployment.
Returns:
| Type | Description |
|---|---|
| str | A string representation of the Seldon Deployment. |
Source code in zenml/integrations/seldon/seldon_client.py
def __str__(self) -> str:
"""Returns a string representation of the Seldon Deployment.
Returns:
A string representation of the Seldon Deployment.
"""
return json.dumps(self.dict(exclude_none=True), indent=4)
build(name=None, model_uri=None, model_name=None, implementation=None, secret_name=None, labels=None, annotations=None, is_custom_deployment=False, spec=None)
classmethod
Build a basic Seldon Deployment object.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| name | Optional[str] | The name of the Seldon Deployment. If not explicitly passed, a unique name is autogenerated. | None |
| model_uri | Optional[str] | The URI of the model. | None |
| model_name | Optional[str] | The name of the model. | None |
| implementation | Optional[str] | The implementation of the model. | None |
| secret_name | Optional[str] | The name of the Kubernetes secret containing environment variable values (e.g. with credentials for the artifact store) to use with the deployment service. | None |
| labels | Optional[Dict[str, str]] | A dictionary of labels to apply to the Seldon Deployment. | None |
| annotations | Optional[Dict[str, str]] | A dictionary of annotations to apply to the Seldon Deployment. | None |
| spec | Optional[Dict[Any, Any]] | A Kubernetes pod spec to use for the Seldon Deployment. | None |
| is_custom_deployment | Optional[bool] | Whether the Seldon Deployment is a custom or a built-in one. | False |
Returns:
| Type | Description |
|---|---|
| SeldonDeployment | A minimal SeldonDeployment object built from the provided parameters. |
Source code in zenml/integrations/seldon/seldon_client.py
@classmethod
def build(
cls,
name: Optional[str] = None,
model_uri: Optional[str] = None,
model_name: Optional[str] = None,
implementation: Optional[str] = None,
secret_name: Optional[str] = None,
labels: Optional[Dict[str, str]] = None,
annotations: Optional[Dict[str, str]] = None,
is_custom_deployment: Optional[bool] = False,
spec: Optional[Dict[Any, Any]] = None,
) -> "SeldonDeployment":
"""Build a basic Seldon Deployment object.
Args:
name: The name of the Seldon Deployment. If not explicitly passed,
a unique name is autogenerated.
model_uri: The URI of the model.
model_name: The name of the model.
implementation: The implementation of the model.
secret_name: The name of the Kubernetes secret containing
environment variable values (e.g. with credentials for the
artifact store) to use with the deployment service.
labels: A dictionary of labels to apply to the Seldon Deployment.
annotations: A dictionary of annotations to apply to the Seldon
Deployment.
spec: A Kubernetes pod spec to use for the Seldon Deployment.
is_custom_deployment: Whether the Seldon Deployment is a custom or a built-in one.
Returns:
A minimal SeldonDeployment object built from the provided
parameters.
"""
if not name:
name = f"zenml-{time.time()}"
if labels is None:
labels = {}
if annotations is None:
annotations = {}
if is_custom_deployment:
predictors = [
SeldonDeploymentPredictor(
name=model_name or "",
graph=SeldonDeploymentPredictiveUnit(
name="classifier",
type=SeldonDeploymentPredictiveUnitType.MODEL,
),
componentSpecs=[
SeldonDeploymentComponentSpecs(
spec=spec
# TODO [HIGH]: Add support for other component types (e.g. graph)
)
],
)
]
else:
predictors = [
SeldonDeploymentPredictor(
name=model_name or "",
graph=SeldonDeploymentPredictiveUnit(
name="classifier",
type=SeldonDeploymentPredictiveUnitType.MODEL,
modelUri=model_uri or "",
implementation=implementation or "",
envSecretRefName=secret_name,
),
)
]
return SeldonDeployment(
metadata=SeldonDeploymentMetadata(
name=name, labels=labels, annotations=annotations
),
spec=SeldonDeploymentSpec(name=name, predictors=predictors),
)
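A minimal sketch of building a built-in (non-custom) deployment; the model URI, secret name and `SKLEARN_SERVER` implementation are placeholder values, not the only valid choices:
```python
from zenml.integrations.seldon.seldon_client import SeldonDeployment

deployment = SeldonDeployment.build(
    name="my-sklearn-model",
    model_uri="s3://my-bucket/model",        # placeholder URI
    model_name="my-model",
    implementation="SKLEARN_SERVER",         # a Seldon Core pre-packaged server
    secret_name="my-artifact-store-secret",  # placeholder secret name
)
deployment.mark_as_managed_by_zenml()
print(deployment)  # JSON representation via __str__
```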
get_error(self)
Get a message describing the error, if in an error state.
Returns:
| Type | Description |
|---|---|
| Optional[str] | A message describing the error, if in an error state, otherwise None. |
Source code in zenml/integrations/seldon/seldon_client.py
def get_error(self) -> Optional[str]:
"""Get a message describing the error, if in an error state.
Returns:
A message describing the error, if in an error state, otherwise
None.
"""
if self.status and self.is_failed():
return self.status.description
return None
get_pending_message(self)
Get a message describing the pending conditions of the Seldon Deployment.
Returns:
| Type | Description |
|---|---|
| Optional[str] | A message describing the pending condition of the Seldon Deployment, or None, if no conditions are pending. |
Source code in zenml/integrations/seldon/seldon_client.py
def get_pending_message(self) -> Optional[str]:
"""Get a message describing the pending conditions of the Seldon Deployment.
Returns:
A message describing the pending condition of the Seldon
Deployment, or None, if no conditions are pending.
"""
if not self.status or not self.status.conditions:
return None
ready_condition_message = [
c.message
for c in self.status.conditions
if c.type == "Ready" and not c.status
]
if not ready_condition_message:
return None
return ready_condition_message[0]
is_available(self)
Checks if the Seldon Deployment is in an available state.
Returns:
| Type | Description |
|---|---|
| bool | True if the Seldon Deployment is available, False otherwise. |
Source code in zenml/integrations/seldon/seldon_client.py
def is_available(self) -> bool:
"""Checks if the Seldon Deployment is in an available state.
Returns:
True if the Seldon Deployment is available, False otherwise.
"""
return self.state == SeldonDeploymentStatusState.AVAILABLE
is_failed(self)
Checks if the Seldon Deployment is in a failed state.
Returns:
| Type | Description |
|---|---|
| bool | True if the Seldon Deployment is failed, False otherwise. |
Source code in zenml/integrations/seldon/seldon_client.py
def is_failed(self) -> bool:
"""Checks if the Seldon Deployment is in a failed state.
Returns:
True if the Seldon Deployment is failed, False otherwise.
"""
return self.state == SeldonDeploymentStatusState.FAILED
is_managed_by_zenml(self)
Checks if this Seldon Deployment is managed by ZenML.
The convention used to differentiate between SeldonDeployment instances that are managed by ZenML and those that are not is to set the `app` label value to `zenml`.
Returns:
| Type | Description |
|---|---|
| bool | True if the Seldon Deployment is managed by ZenML, False otherwise. |
Source code in zenml/integrations/seldon/seldon_client.py
def is_managed_by_zenml(self) -> bool:
"""Checks if this Seldon Deployment is managed by ZenML.
The convention used to differentiate between SeldonDeployment instances
that are managed by ZenML and those that are not is to set the `app`
label value to `zenml`.
Returns:
True if the Seldon Deployment is managed by ZenML, False
otherwise.
"""
return self.metadata.labels.get("app") == "zenml"
is_pending(self)
Checks if the Seldon Deployment is in a pending state.
Returns:
| Type | Description |
|---|---|
| bool | True if the Seldon Deployment is pending, False otherwise. |
Source code in zenml/integrations/seldon/seldon_client.py
def is_pending(self) -> bool:
"""Checks if the Seldon Deployment is in a pending state.
Returns:
True if the Seldon Deployment is pending, False otherwise.
"""
return self.state == SeldonDeploymentStatusState.CREATING
mark_as_managed_by_zenml(self)
Marks this Seldon Deployment as managed by ZenML.
The convention used to differentiate between SeldonDeployment instances that are managed by ZenML and those that are not is to set the `app` label value to `zenml`.
Source code in zenml/integrations/seldon/seldon_client.py
def mark_as_managed_by_zenml(self) -> None:
"""Marks this Seldon Deployment as managed by ZenML.
The convention used to differentiate between SeldonDeployment instances
that are managed by ZenML and those that are not is to set the `app`
label value to `zenml`.
"""
self.metadata.labels["app"] = "zenml"
SeldonDeploymentComponentSpecs (BaseModel)
pydantic-model
Component specs for a Seldon Deployment.
Attributes:
| Name | Type | Description |
|---|---|---|
| spec | Optional[Dict[str, Any]] | the component spec. |
Source code in zenml/integrations/seldon/seldon_client.py
class SeldonDeploymentComponentSpecs(BaseModel):
"""Component specs for a Seldon Deployment.
Attributes:
spec: the component spec.
"""
spec: Optional[Dict[str, Any]]
# TODO [HIGH]: Add graph field to ComponentSpecs. graph: Optional[SeldonDeploymentPredictiveUnit]
class Config:
"""Pydantic configuration class."""
# validate attribute assignments
validate_assignment = True
# Ignore extra attributes from the CRD that are not reflected here
extra = "ignore"
Config
Pydantic configuration class.
Source code in zenml/integrations/seldon/seldon_client.py
class Config:
"""Pydantic configuration class."""
# validate attribute assignments
validate_assignment = True
# Ignore extra attributes from the CRD that are not reflected here
extra = "ignore"
SeldonDeploymentExistsError (SeldonClientError)
Raised when a SeldonDeployment resource cannot be created because a resource with the same name already exists.
Source code in zenml/integrations/seldon/seldon_client.py
class SeldonDeploymentExistsError(SeldonClientError):
"""Raised when a SeldonDeployment resource cannot be created because a resource with the same name already exists."""
SeldonDeploymentMetadata (BaseModel)
pydantic-model
Metadata for a Seldon Deployment.
Attributes:
| Name | Type | Description |
|---|---|---|
| name | str | the name of the Seldon Deployment. |
| labels | Dict[str, str] | Kubernetes labels for the Seldon Deployment. |
| annotations | Dict[str, str] | Kubernetes annotations for the Seldon Deployment. |
| creationTimestamp | Optional[str] | the creation timestamp of the Seldon Deployment. |
Source code in zenml/integrations/seldon/seldon_client.py
class SeldonDeploymentMetadata(BaseModel):
"""Metadata for a Seldon Deployment.
Attributes:
name: the name of the Seldon Deployment.
labels: Kubernetes labels for the Seldon Deployment.
annotations: Kubernetes annotations for the Seldon Deployment.
creationTimestamp: the creation timestamp of the Seldon Deployment.
"""
name: str
labels: Dict[str, str] = Field(default_factory=dict)
annotations: Dict[str, str] = Field(default_factory=dict)
creationTimestamp: Optional[str]
class Config:
"""Pydantic configuration class."""
# validate attribute assignments
validate_assignment = True
# Ignore extra attributes from the CRD that are not reflected here
extra = "ignore"
Config
Pydantic configuration class.
Source code in zenml/integrations/seldon/seldon_client.py
class Config:
"""Pydantic configuration class."""
# validate attribute assignments
validate_assignment = True
# Ignore extra attributes from the CRD that are not reflected here
extra = "ignore"
SeldonDeploymentNotFoundError (SeldonClientError)
Raised when a particular SeldonDeployment resource is not found or is not managed by ZenML.
Source code in zenml/integrations/seldon/seldon_client.py
class SeldonDeploymentNotFoundError(SeldonClientError):
"""Raised when a particular SeldonDeployment resource is not found or is not managed by ZenML."""
SeldonDeploymentPredictiveUnit (BaseModel)
pydantic-model
Seldon Deployment predictive unit.
Attributes:
| Name | Type | Description |
|---|---|---|
| name | str | the name of the predictive unit. |
| type | Optional[zenml.integrations.seldon.seldon_client.SeldonDeploymentPredictiveUnitType] | predictive unit type. |
| implementation | Optional[str] | the Seldon Core implementation used to serve the model. |
| modelUri | Optional[str] | URI of the model (or models) to serve. |
| serviceAccountName | Optional[str] | the name of the service account to associate with the predictive unit container. |
| envSecretRefName | Optional[str] | the name of a Kubernetes secret that contains environment variables (e.g. credentials) to be configured for the predictive unit container. |
| children | List[zenml.integrations.seldon.seldon_client.SeldonDeploymentPredictiveUnit] | a list of child predictive units that together make up the model serving graph. |
Source code in zenml/integrations/seldon/seldon_client.py
class SeldonDeploymentPredictiveUnit(BaseModel):
"""Seldon Deployment predictive unit.
Attributes:
name: the name of the predictive unit.
type: predictive unit type.
implementation: the Seldon Core implementation used to serve the model.
modelUri: URI of the model (or models) to serve.
serviceAccountName: the name of the service account to associate with
the predictive unit container.
envSecretRefName: the name of a Kubernetes secret that contains
environment variables (e.g. credentials) to be configured for the
predictive unit container.
children: a list of child predictive units that together make up the
model serving graph.
"""
name: str
type: Optional[
SeldonDeploymentPredictiveUnitType
] = SeldonDeploymentPredictiveUnitType.MODEL
implementation: Optional[str]
modelUri: Optional[str]
serviceAccountName: Optional[str]
envSecretRefName: Optional[str]
children: List["SeldonDeploymentPredictiveUnit"] = Field(
default_factory=list
)
class Config:
"""Pydantic configuration class."""
# validate attribute assignments
validate_assignment = True
# Ignore extra attributes from the CRD that are not reflected here
extra = "ignore"
Config
Pydantic configuration class.
Source code in zenml/integrations/seldon/seldon_client.py
class Config:
"""Pydantic configuration class."""
# validate attribute assignments
validate_assignment = True
# Ignore extra attributes from the CRD that are not reflected here
extra = "ignore"
SeldonDeploymentPredictiveUnitType (StrEnum)
Predictive unit types for a Seldon Deployment.
Source code in zenml/integrations/seldon/seldon_client.py
class SeldonDeploymentPredictiveUnitType(StrEnum):
"""Predictive unit types for a Seldon Deployment."""
UNKNOWN_TYPE = "UNKNOWN_TYPE"
ROUTER = "ROUTER"
COMBINER = "COMBINER"
MODEL = "MODEL"
TRANSFORMER = "TRANSFORMER"
OUTPUT_TRANSFORMER = "OUTPUT_TRANSFORMER"
SeldonDeploymentPredictor (BaseModel)
pydantic-model
Seldon Deployment predictor.
Attributes:
| Name | Type | Description |
|---|---|---|
| name | str | the name of the predictor. |
| replicas | int | the number of pod replicas for the predictor. |
| graph | Optional[zenml.integrations.seldon.seldon_client.SeldonDeploymentPredictiveUnit] | the serving graph composed of one or more predictive units. |
Source code in zenml/integrations/seldon/seldon_client.py
class SeldonDeploymentPredictor(BaseModel):
"""Seldon Deployment predictor.
Attributes:
name: the name of the predictor.
replicas: the number of pod replicas for the predictor.
graph: the serving graph composed of one or more predictive units.
"""
name: str
replicas: int = 1
graph: Optional[SeldonDeploymentPredictiveUnit] = Field(
default_factory=SeldonDeploymentPredictiveUnit
)
componentSpecs: Optional[List[SeldonDeploymentComponentSpecs]]
class Config:
"""Pydantic configuration class."""
# validate attribute assignments
validate_assignment = True
# Ignore extra attributes from the CRD that are not reflected here
extra = "ignore"
Config
Pydantic configuration class.
Source code in zenml/integrations/seldon/seldon_client.py
class Config:
"""Pydantic configuration class."""
# validate attribute assignments
validate_assignment = True
# Ignore extra attributes from the CRD that are not reflected here
extra = "ignore"
SeldonDeploymentSpec (BaseModel)
pydantic-model
Spec for a Seldon Deployment.
Attributes:
| Name | Type | Description |
|---|---|---|
| name | str | the name of the Seldon Deployment. |
| protocol | Optional[str] | the API protocol used for the Seldon Deployment. |
| predictors | List[zenml.integrations.seldon.seldon_client.SeldonDeploymentPredictor] | a list of predictors that make up the serving graph. |
| replicas | int | the default number of pod replicas used for the predictors. |
Source code in zenml/integrations/seldon/seldon_client.py
class SeldonDeploymentSpec(BaseModel):
"""Spec for a Seldon Deployment.
Attributes:
name: the name of the Seldon Deployment.
protocol: the API protocol used for the Seldon Deployment.
predictors: a list of predictors that make up the serving graph.
replicas: the default number of pod replicas used for the predictors.
"""
name: str
protocol: Optional[str]
predictors: List[SeldonDeploymentPredictor]
replicas: int = 1
class Config:
"""Pydantic configuration class."""
# validate attribute assignments
validate_assignment = True
# Ignore extra attributes from the CRD that are not reflected here
extra = "ignore"
Config
Pydantic configuration class.
Source code in zenml/integrations/seldon/seldon_client.py
class Config:
"""Pydantic configuration class."""
# validate attribute assignments
validate_assignment = True
# Ignore extra attributes from the CRD that are not reflected here
extra = "ignore"
SeldonDeploymentStatus (BaseModel)
pydantic-model
The status of a Seldon Deployment.
Attributes:
| Name | Type | Description |
|---|---|---|
| state | SeldonDeploymentStatusState | the current state of the Seldon Deployment. |
| description | Optional[str] | a human-readable description of the current state. |
| replicas | Optional[int] | the current number of running pod replicas. |
| address | Optional[zenml.integrations.seldon.seldon_client.SeldonDeploymentStatusAddress] | the address where the Seldon Deployment API can be accessed. |
| conditions | List[zenml.integrations.seldon.seldon_client.SeldonDeploymentStatusCondition] | the list of Kubernetes conditions for the Seldon Deployment. |
Source code in zenml/integrations/seldon/seldon_client.py
class SeldonDeploymentStatus(BaseModel):
"""The status of a Seldon Deployment.
Attributes:
state: the current state of the Seldon Deployment.
description: a human-readable description of the current state.
replicas: the current number of running pod replicas
address: the address where the Seldon Deployment API can be accessed.
conditions: the list of Kubernetes conditions for the Seldon Deployment.
"""
state: SeldonDeploymentStatusState = SeldonDeploymentStatusState.UNKNOWN
description: Optional[str]
replicas: Optional[int]
address: Optional[SeldonDeploymentStatusAddress]
conditions: List[SeldonDeploymentStatusCondition]
class Config:
"""Pydantic configuration class."""
# validate attribute assignments
validate_assignment = True
# Ignore extra attributes from the CRD that are not reflected here
extra = "ignore"
Config
Pydantic configuration class.
Source code in zenml/integrations/seldon/seldon_client.py
class Config:
"""Pydantic configuration class."""
# validate attribute assignments
validate_assignment = True
# Ignore extra attributes from the CRD that are not reflected here
extra = "ignore"
SeldonDeploymentStatusAddress (BaseModel)
pydantic-model
The status address for a Seldon Deployment.
Attributes:
| Name | Type | Description |
|---|---|---|
| url | str | the URL where the Seldon Deployment API can be accessed internally. |
Source code in zenml/integrations/seldon/seldon_client.py
class SeldonDeploymentStatusAddress(BaseModel):
"""The status address for a Seldon Deployment.
Attributes:
url: the URL where the Seldon Deployment API can be accessed internally.
"""
url: str
SeldonDeploymentStatusCondition (BaseModel)
pydantic-model
The Kubernetes status condition entry for a Seldon Deployment.
Attributes:
| Name | Type | Description |
|---|---|---|
| type | str | Type of runtime condition. |
| status | bool | Status of the condition. |
| reason | Optional[str] | Brief CamelCase string containing reason for the condition's last transition. |
| message | Optional[str] | Human-readable message indicating details about last transition. |
Source code in zenml/integrations/seldon/seldon_client.py
class SeldonDeploymentStatusCondition(BaseModel):
"""The Kubernetes status condition entry for a Seldon Deployment.
Attributes:
type: Type of runtime condition.
status: Status of the condition.
reason: Brief CamelCase string containing reason for the condition's
last transition.
message: Human-readable message indicating details about last
transition.
"""
type: str
status: bool
reason: Optional[str]
message: Optional[str]
SeldonDeploymentStatusState (StrEnum)
Possible state values for a Seldon Deployment.
Source code in zenml/integrations/seldon/seldon_client.py
class SeldonDeploymentStatusState(StrEnum):
"""Possible state values for a Seldon Deployment."""
UNKNOWN = "Unknown"
AVAILABLE = "Available"
CREATING = "Creating"
FAILED = "Failed"
create_seldon_core_custom_spec(model_uri, custom_docker_image, secret_name, command, container_registry_secret_name=None)
Create a custom pod spec for the seldon core container.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| model_uri | Optional[str] | The URI of the model to load. | required |
| custom_docker_image | Optional[str] | The docker image to use. | required |
| secret_name | Optional[str] | The name of the secret to use. | required |
| command | Optional[List[str]] | The command to run in the container. | required |
| container_registry_secret_name | Optional[str] | The name of the secret to use for docker image pull. | None |
Returns:
| Type | Description |
|---|---|
| V1PodSpec | A pod spec for the seldon core container. |
Source code in zenml/integrations/seldon/seldon_client.py
def create_seldon_core_custom_spec(
model_uri: Optional[str],
custom_docker_image: Optional[str],
secret_name: Optional[str],
command: Optional[List[str]],
container_registry_secret_name: Optional[str] = None,
) -> k8s_client.V1PodSpec:
"""Create a custom pod spec for the seldon core container.
Args:
model_uri: The URI of the model to load.
custom_docker_image: The docker image to use.
secret_name: The name of the secret to use.
command: The command to run in the container.
container_registry_secret_name: The name of the secret to use for docker image pull.
Returns:
A pod spec for the seldon core container.
"""
volume = k8s_client.V1Volume(
name="classifier-provision-location",
empty_dir={},
)
init_container = k8s_client.V1Container(
name="classifier-model-initializer",
image="seldonio/rclone-storage-initializer:1.14.0-dev",
image_pull_policy="IfNotPresent",
args=[model_uri, "/mnt/models"],
volume_mounts=[
k8s_client.V1VolumeMount(
name="classifier-provision-location", mount_path="/mnt/models"
)
],
env_from=[
k8s_client.V1EnvFromSource(
secret_ref=k8s_client.V1SecretEnvSource(
name=secret_name, optional=False
)
)
],
)
image_pull_secret = k8s_client.V1LocalObjectReference(
name=container_registry_secret_name
)
container = k8s_client.V1Container(
name="classifier",
image=custom_docker_image,
image_pull_policy="IfNotPresent",
command=command,
volume_mounts=[
k8s_client.V1VolumeMount(
name="classifier-provision-location",
mount_path="/mnt/models",
read_only=True,
)
],
ports=[
k8s_client.V1ContainerPort(container_port=5000),
k8s_client.V1ContainerPort(container_port=9000),
],
)
if image_pull_secret:
spec = k8s_client.V1PodSpec(
volumes=[
volume,
],
init_containers=[
init_container,
],
image_pull_secrets=[image_pull_secret],
containers=[container],
)
else:
spec = k8s_client.V1PodSpec(
volumes=[
volume,
],
init_containers=[
init_container,
],
containers=[container],
)
return api.sanitize_for_serialization(spec)
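A sketch of combining this helper with `SeldonDeployment.build` for a custom deployment; the image, URI, secret names and command are all placeholder values:
```python
from zenml.integrations.seldon.seldon_client import (
    SeldonDeployment,
    create_seldon_core_custom_spec,
)

pod_spec = create_seldon_core_custom_spec(
    model_uri="s3://my-bucket/model",                    # placeholder
    custom_docker_image="my-registry/my-server:latest",  # placeholder
    secret_name="my-artifact-store-secret",              # placeholder
    command=["python", "serve.py"],                      # placeholder
)
deployment = SeldonDeployment.build(
    name="my-custom-model",
    model_name="my-model",
    is_custom_deployment=True,
    spec=pod_spec,
)
```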
services
special
Initialization for Seldon services.
seldon_deployment
Implementation for the Seldon Deployment step.
SeldonDeploymentConfig (ServiceConfig)
pydantic-model
Seldon Core deployment service configuration.
Attributes:
| Name | Type | Description |
|---|---|---|
| model_uri | str | URI of the model (or models) to serve. |
| model_name | str | the name of the model. Multiple versions of the same model should use the same model name. |
| implementation | str | the Seldon Core implementation used to serve the model. |
| replicas | int | number of replicas to use for the prediction service. |
| secret_name | Optional[str] | the name of a Kubernetes secret containing additional configuration parameters for the Seldon Core deployment (e.g. credentials to access the Artifact Store). |
| model_metadata | Dict[str, Any] | optional model metadata information (see https://docs.seldon.io/projects/seldon-core/en/latest/reference/apis/metadata.html). |
| extra_args | Dict[str, Any] | additional arguments to pass to the Seldon Core deployment resource configuration. |
| is_custom_deployment | Optional[bool] | whether the deployment is a custom deployment. |
| spec | Optional[Dict[Any, Any]] | custom Kubernetes resource specification for the Seldon Core deployment. |
Source code in zenml/integrations/seldon/services/seldon_deployment.py
class SeldonDeploymentConfig(ServiceConfig):
"""Seldon Core deployment service configuration.
Attributes:
model_uri: URI of the model (or models) to serve.
model_name: the name of the model. Multiple versions of the same model
should use the same model name.
implementation: the Seldon Core implementation used to serve the model.
replicas: number of replicas to use for the prediction service.
secret_name: the name of a Kubernetes secret containing additional
configuration parameters for the Seldon Core deployment (e.g.
credentials to access the Artifact Store).
model_metadata: optional model metadata information (see
https://docs.seldon.io/projects/seldon-core/en/latest/reference/apis/metadata.html).
extra_args: additional arguments to pass to the Seldon Core deployment
resource configuration.
is_custom_deployment: whether the deployment is a custom deployment
spec: custom Kubernetes resource specification for the Seldon Core
"""
model_uri: str = ""
model_name: str = "default"
# TODO [ENG-775]: have an enum of all supported Seldon Core implementations
implementation: str
replicas: int = 1
secret_name: Optional[str]
model_metadata: Dict[str, Any] = Field(default_factory=dict)
extra_args: Dict[str, Any] = Field(default_factory=dict)
is_custom_deployment: Optional[bool] = False
spec: Optional[Dict[Any, Any]] = Field(default_factory=dict)
def get_seldon_deployment_labels(self) -> Dict[str, str]:
"""Generate labels for the Seldon Core deployment from the service configuration.
These labels are attached to the Seldon Core deployment resource
and may be used as label selectors in lookup operations.
Returns:
The labels for the Seldon Core deployment.
"""
labels = {}
if self.pipeline_name:
labels["zenml.pipeline_name"] = self.pipeline_name
if self.pipeline_run_id:
labels["zenml.pipeline_run_id"] = self.pipeline_run_id
if self.pipeline_step_name:
labels["zenml.pipeline_step_name"] = self.pipeline_step_name
if self.model_name:
labels["zenml.model_name"] = self.model_name
if self.model_uri:
labels["zenml.model_uri"] = self.model_uri
if self.implementation:
labels["zenml.model_type"] = self.implementation
SeldonClient.sanitize_labels(labels)
return labels
def get_seldon_deployment_annotations(self) -> Dict[str, str]:
"""Generate annotations for the Seldon Core deployment from the service configuration.
The annotations are used to store additional information about the
Seldon Core service that is associated with the deployment that is
not available in the labels. One annotation particularly important
is the serialized Service configuration itself, which is used to
recreate the service configuration from a remote Seldon deployment.
Returns:
The annotations for the Seldon Core deployment.
"""
annotations = {
"zenml.service_config": self.json(),
"zenml.version": __version__,
}
return annotations
@classmethod
def create_from_deployment(
cls, deployment: SeldonDeployment
) -> "SeldonDeploymentConfig":
"""Recreate the configuration of a Seldon Core Service from a deployed instance.
Args:
deployment: the Seldon Core deployment resource.
Returns:
The Seldon Core service configuration corresponding to the given
Seldon Core deployment resource.
Raises:
ValueError: if the given deployment resource does not contain
the expected annotations or it contains an invalid or
incompatible Seldon Core service configuration.
"""
config_data = deployment.metadata.annotations.get(
"zenml.service_config"
)
if not config_data:
raise ValueError(
f"The given deployment resource does not contain a "
f"'zenml.service_config' annotation: {deployment}"
)
try:
service_config = cls.parse_raw(config_data)
except ValidationError as e:
raise ValueError(
f"The loaded Seldon Core deployment resource contains an "
f"invalid or incompatible Seldon Core service configuration: "
f"{config_data}"
) from e
return service_config
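A minimal configuration sketch; all values are placeholders, and `implementation` must name a Seldon Core server implementation:
```python
from zenml.integrations.seldon.services.seldon_deployment import (
    SeldonDeploymentConfig,
)

config = SeldonDeploymentConfig(
    model_uri="s3://my-bucket/model",  # placeholder
    model_name="my-model",
    implementation="SKLEARN_SERVER",   # placeholder implementation
    replicas=2,
)
# labels derived from the configuration, sanitized for Kubernetes
print(config.get_seldon_deployment_labels())
```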
create_from_deployment(deployment)
classmethod
Recreate the configuration of a Seldon Core Service from a deployed instance.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| deployment | SeldonDeployment | the Seldon Core deployment resource. | required |
Returns:
| Type | Description |
|---|---|
| SeldonDeploymentConfig | The Seldon Core service configuration corresponding to the given Seldon Core deployment resource. |
Exceptions:
| Type | Description |
|---|---|
| ValueError | if the given deployment resource does not contain the expected annotations or it contains an invalid or incompatible Seldon Core service configuration. |
Source code in zenml/integrations/seldon/services/seldon_deployment.py
@classmethod
def create_from_deployment(
cls, deployment: SeldonDeployment
) -> "SeldonDeploymentConfig":
"""Recreate the configuration of a Seldon Core Service from a deployed instance.
Args:
deployment: the Seldon Core deployment resource.
Returns:
The Seldon Core service configuration corresponding to the given
Seldon Core deployment resource.
Raises:
ValueError: if the given deployment resource does not contain
the expected annotations or it contains an invalid or
incompatible Seldon Core service configuration.
"""
config_data = deployment.metadata.annotations.get(
"zenml.service_config"
)
if not config_data:
raise ValueError(
f"The given deployment resource does not contain a "
f"'zenml.service_config' annotation: {deployment}"
)
try:
service_config = cls.parse_raw(config_data)
except ValidationError as e:
raise ValueError(
f"The loaded Seldon Core deployment resource contains an "
f"invalid or incompatible Seldon Core service configuration: "
f"{config_data}"
) from e
return service_config
get_seldon_deployment_annotations(self)
Generate annotations for the Seldon Core deployment from the service configuration.
The annotations are used to store additional information about the Seldon Core service that is associated with the deployment that is not available in the labels. One annotation particularly important is the serialized Service configuration itself, which is used to recreate the service configuration from a remote Seldon deployment.
Returns:
| Type | Description |
|---|---|
| Dict[str, str] | The annotations for the Seldon Core deployment. |
Source code in zenml/integrations/seldon/services/seldon_deployment.py
def get_seldon_deployment_annotations(self) -> Dict[str, str]:
"""Generate annotations for the Seldon Core deployment from the service configuration.
The annotations are used to store additional information about the
Seldon Core service that is associated with the deployment that is
not available in the labels. One annotation particularly important
is the serialized Service configuration itself, which is used to
recreate the service configuration from a remote Seldon deployment.
Returns:
The annotations for the Seldon Core deployment.
"""
annotations = {
"zenml.service_config": self.json(),
"zenml.version": __version__,
}
return annotations
get_seldon_deployment_labels(self)
Generate labels for the Seldon Core deployment from the service configuration.
These labels are attached to the Seldon Core deployment resource and may be used as label selectors in lookup operations.
Returns:
| Type | Description |
|---|---|
| Dict[str, str] | The labels for the Seldon Core deployment. |
Source code in zenml/integrations/seldon/services/seldon_deployment.py
def get_seldon_deployment_labels(self) -> Dict[str, str]:
"""Generate labels for the Seldon Core deployment from the service configuration.
These labels are attached to the Seldon Core deployment resource
and may be used as label selectors in lookup operations.
Returns:
The labels for the Seldon Core deployment.
"""
labels = {}
if self.pipeline_name:
labels["zenml.pipeline_name"] = self.pipeline_name
if self.pipeline_run_id:
labels["zenml.pipeline_run_id"] = self.pipeline_run_id
if self.pipeline_step_name:
labels["zenml.pipeline_step_name"] = self.pipeline_step_name
if self.model_name:
labels["zenml.model_name"] = self.model_name
if self.model_uri:
labels["zenml.model_uri"] = self.model_uri
if self.implementation:
labels["zenml.model_type"] = self.implementation
SeldonClient.sanitize_labels(labels)
return labels
SeldonDeploymentService (BaseService)
pydantic-model
A service that represents a Seldon Core deployment server.
Attributes:
| Name | Type | Description |
|---|---|---|
| config | SeldonDeploymentConfig | service configuration. |
| status | SeldonDeploymentServiceStatus | service status. |
Source code in zenml/integrations/seldon/services/seldon_deployment.py
class SeldonDeploymentService(BaseService):
"""A service that represents a Seldon Core deployment server.
Attributes:
config: service configuration.
status: service status.
"""
SERVICE_TYPE = ServiceType(
name="seldon-deployment",
type="model-serving",
flavor="seldon",
description="Seldon Core prediction service",
)
config: SeldonDeploymentConfig = Field(
default_factory=SeldonDeploymentConfig
)
status: SeldonDeploymentServiceStatus = Field(
default_factory=SeldonDeploymentServiceStatus
)
def _get_client(self) -> SeldonClient:
"""Get the Seldon Core client from the active Seldon Core model deployer.
Returns:
The Seldon Core client.
"""
from zenml.integrations.seldon.model_deployers.seldon_model_deployer import (
SeldonModelDeployer,
)
model_deployer = SeldonModelDeployer.get_active_model_deployer()
return model_deployer.seldon_client
def check_status(self) -> Tuple[ServiceState, str]:
"""Check the the current operational state of the Seldon Core deployment.
Returns:
The operational state of the Seldon Core deployment and a message
providing additional information about that state (e.g. a
description of the error, if one is encountered).
"""
client = self._get_client()
name = self.seldon_deployment_name
try:
deployment = client.get_deployment(name=name)
except SeldonDeploymentNotFoundError:
return (ServiceState.INACTIVE, "")
if deployment.is_available():
return (
ServiceState.ACTIVE,
f"Seldon Core deployment '{name}' is available",
)
if deployment.is_failed():
return (
ServiceState.ERROR,
f"Seldon Core deployment '{name}' failed: "
f"{deployment.get_error()}",
)
pending_message = deployment.get_pending_message() or ""
return (
ServiceState.PENDING_STARTUP,
"Seldon Core deployment is being created: " + pending_message,
)
@property
def seldon_deployment_name(self) -> str:
"""Get the name of the Seldon Core deployment.
It uniquely corresponds to this service instance.
Returns:
The name of the Seldon Core deployment.
"""
return f"zenml-{str(self.uuid)}"
def _get_seldon_deployment_labels(self) -> Dict[str, str]:
"""Generate the labels for the Seldon Core deployment from the service configuration.
Returns:
The labels for the Seldon Core deployment.
"""
labels = self.config.get_seldon_deployment_labels()
labels["zenml.service_uuid"] = str(self.uuid)
SeldonClient.sanitize_labels(labels)
return labels
@classmethod
def create_from_deployment(
cls, deployment: SeldonDeployment
) -> "SeldonDeploymentService":
"""Recreate a Seldon Core service from a Seldon Core deployment resource.
It then updates its operational status.
Args:
deployment: the Seldon Core deployment resource.
Returns:
The Seldon Core service corresponding to the given
Seldon Core deployment resource.
Raises:
ValueError: if the given deployment resource does not contain
the expected service_uuid label.
"""
config = SeldonDeploymentConfig.create_from_deployment(deployment)
uuid = deployment.metadata.labels.get("zenml.service_uuid")
if not uuid:
raise ValueError(
f"The given deployment resource does not contain a valid "
f"'zenml.service_uuid' label: {deployment}"
)
service = cls(uuid=UUID(uuid), config=config)
service.update_status()
return service
def provision(self) -> None:
"""Provision or update remote Seldon Core deployment instance.
This brings the remote deployment in line with the current configuration.
"""
client = self._get_client()
name = self.seldon_deployment_name
deployment = SeldonDeployment.build(
name=name,
model_uri=self.config.model_uri,
model_name=self.config.model_name,
implementation=self.config.implementation,
secret_name=self.config.secret_name,
labels=self._get_seldon_deployment_labels(),
annotations=self.config.get_seldon_deployment_annotations(),
is_custom_deployment=self.config.is_custom_deployment,
spec=self.config.spec,
)
deployment.spec.replicas = self.config.replicas
deployment.spec.predictors[0].replicas = self.config.replicas
# check if the Seldon deployment already exists
try:
client.get_deployment(name=name)
# update the existing deployment
client.update_deployment(deployment)
except SeldonDeploymentNotFoundError:
# create the deployment
client.create_deployment(deployment=deployment)
def deprovision(self, force: bool = False) -> None:
"""Deprovision the remote Seldon Core deployment instance.
Args:
force: if True, the remote deployment instance will be
forcefully deprovisioned.
"""
client = self._get_client()
name = self.seldon_deployment_name
try:
client.delete_deployment(name=name, force=force)
except SeldonDeploymentNotFoundError:
pass
def get_logs(
self,
follow: bool = False,
tail: Optional[int] = None,
) -> Generator[str, bool, None]:
"""Get the logs of a Seldon Core model deployment.
Args:
follow: if True, the logs will be streamed as they are written
tail: only retrieve the last NUM lines of log output.
Returns:
A generator that can be accessed to get the service logs.
"""
return self._get_client().get_deployment_logs(
self.seldon_deployment_name,
follow=follow,
tail=tail,
)
@property
def prediction_url(self) -> Optional[str]:
"""The prediction URI exposed by the prediction service.
Returns:
The prediction URI exposed by the prediction service, or None if
the service is not yet ready.
"""
from zenml.integrations.seldon.model_deployers.seldon_model_deployer import (
SeldonModelDeployer,
)
if not self.is_running:
return None
namespace = self._get_client().namespace
model_deployer = SeldonModelDeployer.get_active_model_deployer()
return os.path.join(
model_deployer.config.base_url,
"seldon",
namespace,
self.seldon_deployment_name,
"api/v0.1/predictions",
)
def predict(self, request: str) -> Any:
"""Make a prediction using the service.
Args:
request: a JSON string containing the request data (e.g. a serialized numpy array)
Returns:
The prediction returned by the service, parsed from the JSON response.
Raises:
Exception: if the service is not yet ready.
ValueError: if the prediction_url is not set.
"""
if not self.is_running:
raise Exception(
"Seldon prediction service is not running. "
"Please start the service before making predictions."
)
if self.prediction_url is None:
raise ValueError("`self.prediction_url` is not set, cannot post.")
if isinstance(request, str):
request = json.loads(request)
else:
raise ValueError("Request must be a json string.")
response = requests.post(
self.prediction_url,
json={"data": {"ndarray": request}},
)
response.raise_for_status()
return response.json()
prediction_url: Optional[str]
property
readonly
The prediction URI exposed by the prediction service.
Returns:
Type | Description |
---|---|
Optional[str] |
The prediction URI exposed by the prediction service, or None if the service is not yet ready. |
seldon_deployment_name: str
property
readonly
Get the name of the Seldon Core deployment.
The name uniquely corresponds to this service instance.
Returns:
Type | Description |
---|---|
str |
The name of the Seldon Core deployment. |
check_status(self)
Check the current operational state of the Seldon Core deployment.
Returns:
Type | Description |
---|---|
Tuple[zenml.services.service_status.ServiceState, str] |
The operational state of the Seldon Core deployment and a message providing additional information about that state (e.g. a description of the error, if one is encountered). |
Source code in zenml/integrations/seldon/services/seldon_deployment.py
def check_status(self) -> Tuple[ServiceState, str]:
"""Check the the current operational state of the Seldon Core deployment.
Returns:
The operational state of the Seldon Core deployment and a message
providing additional information about that state (e.g. a
description of the error, if one is encountered).
"""
client = self._get_client()
name = self.seldon_deployment_name
try:
deployment = client.get_deployment(name=name)
except SeldonDeploymentNotFoundError:
return (ServiceState.INACTIVE, "")
if deployment.is_available():
return (
ServiceState.ACTIVE,
f"Seldon Core deployment '{name}' is available",
)
if deployment.is_failed():
return (
ServiceState.ERROR,
f"Seldon Core deployment '{name}' failed: "
f"{deployment.get_error()}",
)
pending_message = deployment.get_pending_message() or ""
return (
ServiceState.PENDING_STARTUP,
"Seldon Core deployment is being created: " + pending_message,
)
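A usage sketch (assuming a SeldonDeploymentService instance named service; the polling helper below is hypothetical, not part of the source) that waits for the deployment to become active:
import time

from zenml.services.service_status import ServiceState


def wait_until_active(service, timeout: int = 300, interval: int = 5) -> None:
    """Poll check_status() until the deployment is ACTIVE, failed, or timed out."""
    deadline = time.time() + timeout
    while time.time() < deadline:
        state, message = service.check_status()
        if state == ServiceState.ACTIVE:
            return
        if state == ServiceState.ERROR:
            raise RuntimeError(f"Seldon Core deployment failed: {message}")
        time.sleep(interval)
    raise TimeoutError("Seldon Core deployment did not become active in time.")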
create_from_deployment(deployment)
classmethod
Recreate a Seldon Core service from a Seldon Core deployment resource.
The service's operational status is updated after recreation.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
deployment |
SeldonDeployment |
the Seldon Core deployment resource. |
required |
Returns:
Type | Description |
---|---|
SeldonDeploymentService |
The Seldon Core service corresponding to the given Seldon Core deployment resource. |
Exceptions:
Type | Description |
---|---|
ValueError |
if the given deployment resource does not contain the expected service_uuid label. |
Source code in zenml/integrations/seldon/services/seldon_deployment.py
@classmethod
def create_from_deployment(
cls, deployment: SeldonDeployment
) -> "SeldonDeploymentService":
"""Recreate a Seldon Core service from a Seldon Core deployment resource.
The service's operational status is updated after recreation.
Args:
deployment: the Seldon Core deployment resource.
Returns:
The Seldon Core service corresponding to the given
Seldon Core deployment resource.
Raises:
ValueError: if the given deployment resource does not contain
the expected service_uuid label.
"""
config = SeldonDeploymentConfig.create_from_deployment(deployment)
uuid = deployment.metadata.labels.get("zenml.service_uuid")
if not uuid:
raise ValueError(
f"The given deployment resource does not contain a valid "
f"'zenml.service_uuid' label: {deployment}"
)
service = cls(uuid=UUID(uuid), config=config)
service.update_status()
return service
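A sketch of recovering a service object from a deployment already running in the cluster (assumes an active Seldon model deployer is configured; the deployment name below is illustrative of the "zenml-<uuid>" naming scheme used by seldon_deployment_name):
from zenml.integrations.seldon.model_deployers.seldon_model_deployer import (
    SeldonModelDeployer,
)
from zenml.integrations.seldon.services.seldon_deployment import (
    SeldonDeploymentService,
)

client = SeldonModelDeployer.get_active_model_deployer().seldon_client
deployment = client.get_deployment(
    name="zenml-00000000-0000-0000-0000-000000000000"  # illustrative name
)
service = SeldonDeploymentService.create_from_deployment(deployment)
print(service.status)  # operational status was refreshed by update_status()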
deprovision(self, force=False)
Deprovision the remote Seldon Core deployment instance.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
force |
bool |
if True, the remote deployment instance will be forcefully deprovisioned. |
False |
Source code in zenml/integrations/seldon/services/seldon_deployment.py
def deprovision(self, force: bool = False) -> None:
"""Deprovision the remote Seldon Core deployment instance.
Args:
force: if True, the remote deployment instance will be
forcefully deprovisioned.
"""
client = self._get_client()
name = self.seldon_deployment_name
try:
client.delete_deployment(name=name, force=force)
except SeldonDeploymentNotFoundError:
pass
get_logs(self, follow=False, tail=None)
Get the logs of a Seldon Core model deployment.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
follow |
bool |
if True, the logs will be streamed as they are written |
False |
tail |
Optional[int] |
only retrieve the last NUM lines of log output. |
None |
Returns:
Type | Description |
---|---|
Generator[str, bool, NoneType] |
A generator that can be accessed to get the service logs. |
Source code in zenml/integrations/seldon/services/seldon_deployment.py
def get_logs(
self,
follow: bool = False,
tail: Optional[int] = None,
) -> Generator[str, bool, None]:
"""Get the logs of a Seldon Core model deployment.
Args:
follow: if True, the logs will be streamed as they are written
tail: only retrieve the last NUM lines of log output.
Returns:
A generator that can be accessed to get the service logs.
"""
return self._get_client().get_deployment_logs(
self.seldon_deployment_name,
follow=follow,
tail=tail,
)
predict(self, request)
Make a prediction using the service.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
request |
str |
a JSON string containing the request data (e.g. a serialized numpy array) |
required |
Returns:
Type | Description |
---|---|
Any |
The prediction returned by the service, parsed from the JSON response. |
Exceptions:
Type | Description |
---|---|
Exception |
if the service is not yet ready. |
ValueError |
if the prediction_url is not set. |
Source code in zenml/integrations/seldon/services/seldon_deployment.py
def predict(self, request: str) -> Any:
"""Make a prediction using the service.
Args:
request: a JSON string containing the request data (e.g. a serialized numpy array)
Returns:
The prediction returned by the service, parsed from the JSON response.
Raises:
Exception: if the service is not yet ready.
ValueError: if the prediction_url is not set.
"""
if not self.is_running:
raise Exception(
"Seldon prediction service is not running. "
"Please start the service before making predictions."
)
if self.prediction_url is None:
raise ValueError("`self.prediction_url` is not set, cannot post.")
if isinstance(request, str):
request = json.loads(request)
else:
raise ValueError("Request must be a json string.")
response = requests.post(
self.prediction_url,
json={"data": {"ndarray": request}},
)
response.raise_for_status()
return response.json()
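Usage sketch (assumes service is a running SeldonDeploymentService): the request must be a JSON string, which predict() wraps into Seldon's ndarray payload:
import json

import numpy as np

# Serialize a numpy array into the JSON string that predict() expects.
request = json.dumps(np.array([[1.0, 2.0, 3.0, 4.0]]).tolist())
response = service.predict(request)
# The response is the parsed JSON body returned by the Seldon endpoint,
# typically of the form {"data": {"ndarray": [...]}}.
print(response)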
provision(self)
Provision or update remote Seldon Core deployment instance.
This brings the remote deployment in line with the current configuration.
Source code in zenml/integrations/seldon/services/seldon_deployment.py
def provision(self) -> None:
"""Provision or update remote Seldon Core deployment instance.
This brings the remote deployment in line with the current configuration.
"""
client = self._get_client()
name = self.seldon_deployment_name
deployment = SeldonDeployment.build(
name=name,
model_uri=self.config.model_uri,
model_name=self.config.model_name,
implementation=self.config.implementation,
secret_name=self.config.secret_name,
labels=self._get_seldon_deployment_labels(),
annotations=self.config.get_seldon_deployment_annotations(),
is_custom_deployment=self.config.is_custom_deployment,
spec=self.config.spec,
)
deployment.spec.replicas = self.config.replicas
deployment.spec.predictors[0].replicas = self.config.replicas
# check if the Seldon deployment already exists
try:
client.get_deployment(name=name)
# update the existing deployment
client.update_deployment(deployment)
except SeldonDeploymentNotFoundError:
# create the deployment
client.create_deployment(deployment=deployment)
SeldonDeploymentServiceStatus (ServiceStatus)
pydantic-model
Seldon Core deployment service status.
Source code in zenml/integrations/seldon/services/seldon_deployment.py
class SeldonDeploymentServiceStatus(ServiceStatus):
"""Seldon Core deployment service status."""
steps
special
Initialization for Seldon steps.
seldon_deployer
Implementation of the Seldon Deployer step.
CustomDeployParameters (BaseModel)
pydantic-model
Custom model deployer step extra parameters.
Attributes:
Name | Type | Description |
---|---|---|
predict_function |
str |
Path to Python file containing predict function. |
Exceptions:
Type | Description |
---|---|
ValueError |
If predict_function is not specified. |
TypeError |
If predict_function is not a callable function. |
Returns:
Type | Description |
---|---|
predict_function |
Path to Python file containing predict function. |
Source code in zenml/integrations/seldon/steps/seldon_deployer.py
class CustomDeployParameters(BaseModel):
"""Custom model deployer step extra parameters.
Attributes:
predict_function: Path to Python file containing predict function.
Raises:
ValueError: If predict_function is not specified.
TypeError: If predict_function is not a callable function.
Returns:
predict_function: Path to Python file containing predict function.
"""
predict_function: str
@validator("predict_function")
def predict_function_validate(cls, predict_func_path: str) -> str:
"""Validate predict function.
Args:
predict_func_path: predict function path
Returns:
predict function path
Raises:
ValueError: if predict function path is not valid
TypeError: if predict function path is not a callable function
"""
try:
predict_function = import_class_by_path(predict_func_path)
except AttributeError:
raise ValueError("Predict function can't be found.")
if not callable(predict_function):
raise TypeError("Predict function must be callable.")
return predict_func_path
predict_function_validate(predict_func_path)
classmethod
Validate predict function.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
predict_func_path |
str |
predict function path |
required |
Returns:
Type | Description |
---|---|
str |
predict function path |
Exceptions:
Type | Description |
---|---|
ValueError |
if predict function path is not valid |
TypeError |
if predict function path is not a callable function |
Source code in zenml/integrations/seldon/steps/seldon_deployer.py
@validator("predict_function")
def predict_function_validate(cls, predict_func_path: str) -> str:
"""Validate predict function.
Args:
predict_func_path: predict function path
Returns:
predict function path
Raises:
ValueError: if predict function path is not valid
TypeError: if predict function path is not a callable function
"""
try:
predict_function = import_class_by_path(predict_func_path)
except AttributeError:
raise ValueError("Predict function can't be found.")
if not callable(predict_function):
raise TypeError("Predict function must be callable.")
return predict_func_path
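A sketch of the validation behavior (the dotted path below is hypothetical):
from zenml.integrations.seldon.steps.seldon_deployer import (
    CustomDeployParameters,
)

# Succeeds only if "my_package.my_module.predict" can be imported and
# resolves to a callable; otherwise the validator raises ValueError
# (not importable) or TypeError (not callable).
params = CustomDeployParameters(
    predict_function="my_package.my_module.predict"
)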
SeldonDeployerStepParameters (BaseParameters)
pydantic-model
Seldon model deployer step parameters.
Attributes:
Name | Type | Description |
---|---|---|
service_config |
SeldonDeploymentConfig |
Seldon Core deployment service configuration. |
secrets |
a list of ZenML secrets containing additional configuration parameters for the Seldon Core deployment (e.g. credentials to access the Artifact Store where the models are stored). If supplied, the information fetched from these secrets is passed to the Seldon Core deployment server as a list of environment variables. |
Source code in zenml/integrations/seldon/steps/seldon_deployer.py
class SeldonDeployerStepParameters(BaseParameters):
"""Seldon model deployer step parameters.
Attributes:
service_config: Seldon Core deployment service configuration.
secrets: a list of ZenML secrets containing additional configuration
parameters for the Seldon Core deployment (e.g. credentials to
access the Artifact Store where the models are stored). If supplied,
the information fetched from these secrets is passed to the Seldon
Core deployment server as a list of environment variables.
"""
service_config: SeldonDeploymentConfig
custom_deploy_parameters: Optional[CustomDeployParameters] = None
timeout: int = DEFAULT_SELDON_DEPLOYMENT_START_STOP_TIMEOUT
seldon_custom_model_deployer_step (BaseStep)
Seldon Core custom model deployer pipeline step.
This step can be used in a pipeline to implement the process required to deploy a custom model with Seldon Core.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
deploy_decision |
whether to deploy the model or not |
required | |
params |
parameters for the deployer step |
required | |
model |
the model artifact to deploy |
required | |
context |
the step context |
required |
Exceptions:
Type | Description |
---|---|
ValueError |
if the custom deployer is not defined |
DoesNotExistException |
if an entity does not exist raise an exception |
Returns:
Type | Description |
---|---|
Seldon Core deployment service |
PARAMETERS_CLASS (BaseParameters)
pydantic-model
Seldon model deployer step parameters.
Attributes:
Name | Type | Description |
---|---|---|
service_config |
SeldonDeploymentConfig |
Seldon Core deployment service configuration. |
secrets |
a list of ZenML secrets containing additional configuration parameters for the Seldon Core deployment (e.g. credentials to access the Artifact Store where the models are stored). If supplied, the information fetched from these secrets is passed to the Seldon Core deployment server as a list of environment variables. |
Source code in zenml/integrations/seldon/steps/seldon_deployer.py
class SeldonDeployerStepParameters(BaseParameters):
"""Seldon model deployer step parameters.
Attributes:
service_config: Seldon Core deployment service configuration.
secrets: a list of ZenML secrets containing additional configuration
parameters for the Seldon Core deployment (e.g. credentials to
access the Artifact Store where the models are stored). If supplied,
the information fetched from these secrets is passed to the Seldon
Core deployment server as a list of environment variables.
"""
service_config: SeldonDeploymentConfig
custom_deploy_parameters: Optional[CustomDeployParameters] = None
timeout: int = DEFAULT_SELDON_DEPLOYMENT_START_STOP_TIMEOUT
entrypoint(deploy_decision, params, context, model)
staticmethod
Seldon Core custom model deployer pipeline step.
This step can be used in a pipeline to implement the process required to deploy a custom model with Seldon Core.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
deploy_decision |
bool |
whether to deploy the model or not |
required |
params |
SeldonDeployerStepParameters |
parameters for the deployer step |
required |
model |
ModelArtifact |
the model artifact to deploy |
required |
context |
StepContext |
the step context |
required |
Exceptions:
Type | Description |
---|---|
ValueError |
if the custom deployer is not defined |
DoesNotExistException |
if an entity does not exist raise an exception |
Returns:
Type | Description |
---|---|
SeldonDeploymentService |
Seldon Core deployment service |
Source code in zenml/integrations/seldon/steps/seldon_deployer.py
@step(enable_cache=False)
def seldon_custom_model_deployer_step(
deploy_decision: bool,
params: SeldonDeployerStepParameters,
context: StepContext,
model: ModelArtifact,
) -> SeldonDeploymentService:
"""Seldon Core custom model deployer pipeline step.
This step can be used in a pipeline to implement the process
required to deploy a custom model with Seldon Core.
Args:
deploy_decision: whether to deploy the model or not
params: parameters for the deployer step
model: the model artifact to deploy
context: the step context
Raises:
ValueError: if the custom deployer is not defined
DoesNotExistException: if an entity does not exist raise an exception
Returns:
Seldon Core deployment service
"""
# verify that a custom deployer is defined
if not params.custom_deploy_parameters:
raise ValueError(
"Custom deploy parameters are required as part of the step "
"configuration: they provide the path of the custom predict "
"function."
)
# get the active model deployer
model_deployer = SeldonModelDeployer.get_active_model_deployer()
# get pipeline name, step name, run id
step_env = cast(StepEnvironment, Environment()[STEP_ENVIRONMENT_NAME])
pipeline_name = step_env.pipeline_name
pipeline_run_id = step_env.pipeline_run_id
step_name = step_env.step_name
# update the step configuration with the real pipeline runtime information
params.service_config.pipeline_name = pipeline_name
params.service_config.pipeline_run_id = pipeline_run_id
params.service_config.pipeline_step_name = step_name
params.service_config.is_custom_deployment = True
# fetch existing services with the same pipeline name, step name and
# model name
existing_services = model_deployer.find_model_server(
pipeline_name=pipeline_name,
pipeline_step_name=step_name,
model_name=params.service_config.model_name,
)
# even when the deploy decision is negative if an existing model server
# is not running for this pipeline/step, we still have to serve the
# current model, to ensure that a model server is available at all times
if not deploy_decision and existing_services:
logger.info(
f"Skipping model deployment because the model quality does not"
f" meet the criteria. Reusing the last model server deployed by step "
f"'{step_name}' and pipeline '{pipeline_name}' for model "
f"'{params.service_config.model_name}'..."
)
service = cast(SeldonDeploymentService, existing_services[0])
# even when the deployment decision is negative, we still need to start
# the previous model server if it is no longer running, to ensure that
# a model server is available at all times
if not service.is_running:
service.start(timeout=params.timeout)
return service
# entrypoint for starting Seldon microservice deployment for custom model
entrypoint_command = [
"python",
"-m",
"zenml.integrations.seldon.custom_deployer.zenml_custom_model",
"--model_name",
params.service_config.model_name,
"--predict_func",
params.custom_deploy_parameters.predict_function,
]
# verify if there is an active stack before starting the service
if not context.stack:
raise DoesNotExistException(
"No active stack is available. "
"Please make sure that you have registered and set a stack."
)
docker_image = step_env.step_run_info.pipeline.extra[
SELDON_DOCKER_IMAGE_KEY
]
# copy the model files to new specific directory for the deployment
served_model_uri = os.path.join(context.get_output_artifact_uri(), "seldon")
fileio.makedirs(served_model_uri)
io_utils.copy_dir(model.uri, served_model_uri)
# Get the model artifact to extract information about the model
# and how it can be loaded again later in the deployment environment.
artifact = Client().zen_store.list_artifacts(artifact_uri=model.uri)
if not artifact:
raise DoesNotExistException("No artifact found at {}".format(model.uri))
# save the model artifact metadata to the YAML file and copy it to the
# deployment directory
model_metadata_file = save_model_metadata(artifact[0])
fileio.copy(
model_metadata_file,
os.path.join(served_model_uri, MODEL_METADATA_YAML_FILE_NAME),
)
# prepare the service configuration for the deployment
service_config = params.service_config.copy()
service_config.model_uri = served_model_uri
# create the specification for the custom deployment
service_config.spec = create_seldon_core_custom_spec(
model_uri=service_config.model_uri,
custom_docker_image=docker_image,
secret_name=model_deployer.kubernetes_secret_name,
command=entrypoint_command,
)
# deploy the service
service = cast(
SeldonDeploymentService,
model_deployer.deploy_model(
service_config, replace=True, timeout=params.timeout
),
)
logger.info(
f"Seldon Core deployment service started and reachable at:\n"
f" {service.prediction_url}\n"
)
return service
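A minimal wiring sketch (the module path, model name and timeout are illustrative; this assumes the usual pattern of configuring the step with its parameters class before using it in a pipeline definition):
from zenml.integrations.seldon.services.seldon_deployment import (
    SeldonDeploymentConfig,
)
from zenml.integrations.seldon.steps.seldon_deployer import (
    CustomDeployParameters,
    SeldonDeployerStepParameters,
    seldon_custom_model_deployer_step,
)

# Configure the step; the resulting step instance is then used inside a
# pipeline definition alongside a deployment-decision step.
custom_deployment = seldon_custom_model_deployer_step(
    params=SeldonDeployerStepParameters(
        service_config=SeldonDeploymentConfig(
            model_name="my-custom-model",  # hypothetical model name
            replicas=1,
        ),
        custom_deploy_parameters=CustomDeployParameters(
            predict_function="my_package.my_module.predict"  # hypothetical path
        ),
        timeout=240,
    )
)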
seldon_model_deployer_step (BaseStep)
Seldon Core model deployer pipeline step.
This step can be used in a pipeline to implement continuous deployment for a ML model with Seldon Core.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
deploy_decision |
whether to deploy the model or not |
required | |
params |
parameters for the deployer step |
required | |
model |
the model artifact to deploy |
required | |
context |
the step context |
required |
Returns:
Type | Description |
---|---|
Seldon Core deployment service |
PARAMETERS_CLASS (BaseParameters)
pydantic-model
Seldon model deployer step parameters.
Attributes:
Name | Type | Description |
---|---|---|
service_config |
SeldonDeploymentConfig |
Seldon Core deployment service configuration. |
secrets |
a list of ZenML secrets containing additional configuration parameters for the Seldon Core deployment (e.g. credentials to access the Artifact Store where the models are stored). If supplied, the information fetched from these secrets is passed to the Seldon Core deployment server as a list of environment variables. |
Source code in zenml/integrations/seldon/steps/seldon_deployer.py
class SeldonDeployerStepParameters(BaseParameters):
"""Seldon model deployer step parameters.
Attributes:
service_config: Seldon Core deployment service configuration.
secrets: a list of ZenML secrets containing additional configuration
parameters for the Seldon Core deployment (e.g. credentials to
access the Artifact Store where the models are stored). If supplied,
the information fetched from these secrets is passed to the Seldon
Core deployment server as a list of environment variables.
"""
service_config: SeldonDeploymentConfig
custom_deploy_parameters: Optional[CustomDeployParameters] = None
timeout: int = DEFAULT_SELDON_DEPLOYMENT_START_STOP_TIMEOUT
entrypoint(deploy_decision, params, context, model)
staticmethod
Seldon Core model deployer pipeline step.
This step can be used in a pipeline to implement continuous deployment for a ML model with Seldon Core.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
deploy_decision |
bool |
whether to deploy the model or not |
required |
params |
SeldonDeployerStepParameters |
parameters for the deployer step |
required |
model |
ModelArtifact |
the model artifact to deploy |
required |
context |
StepContext |
the step context |
required |
Returns:
Type | Description |
---|---|
SeldonDeploymentService |
Seldon Core deployment service |
Source code in zenml/integrations/seldon/steps/seldon_deployer.py
@step(enable_cache=False)
def seldon_model_deployer_step(
deploy_decision: bool,
params: SeldonDeployerStepParameters,
context: StepContext,
model: ModelArtifact,
) -> SeldonDeploymentService:
"""Seldon Core model deployer pipeline step.
This step can be used in a pipeline to implement continuous
deployment for a ML model with Seldon Core.
Args:
deploy_decision: whether to deploy the model or not
params: parameters for the deployer step
model: the model artifact to deploy
context: the step context
Returns:
Seldon Core deployment service
"""
model_deployer = SeldonModelDeployer.get_active_model_deployer()
# get pipeline name, step name and run id
step_env = cast(StepEnvironment, Environment()[STEP_ENVIRONMENT_NAME])
pipeline_name = step_env.pipeline_name
pipeline_run_id = step_env.pipeline_run_id
step_name = step_env.step_name
# update the step configuration with the real pipeline runtime information
params.service_config.pipeline_name = pipeline_name
params.service_config.pipeline_run_id = pipeline_run_id
params.service_config.pipeline_step_name = step_name
def prepare_service_config(model_uri: str) -> SeldonDeploymentConfig:
"""Prepare the model files for model serving.
This creates and returns a Seldon service configuration for the model.
This function ensures that the model files are in the correct format
and file structure required by the Seldon Core server implementation
used for model serving.
Args:
model_uri: the URI of the model artifact being served
Returns:
The Seldon deployment service configuration for the model.
Raises:
RuntimeError: if the model files were not found
"""
served_model_uri = os.path.join(
context.get_output_artifact_uri(), "seldon"
)
fileio.makedirs(served_model_uri)
# TODO [ENG-773]: determine how to formalize how models are organized into
# folders and sub-folders depending on the model type/format and the
# Seldon Core protocol used to serve the model.
# TODO [ENG-791]: auto-detect built-in Seldon server implementation
# from the model artifact type
# TODO [ENG-792]: validate the model artifact type against the
# supported built-in Seldon server implementations
if params.service_config.implementation == "TENSORFLOW_SERVER":
# the TensorFlow server expects model artifacts to be
# stored in numbered subdirectories, each representing a model
# version
io_utils.copy_dir(model_uri, os.path.join(served_model_uri, "1"))
elif params.service_config.implementation == "SKLEARN_SERVER":
# the sklearn server expects model artifacts to be
# stored in a file called model.joblib
model_uri = os.path.join(model_uri, "model")
if not fileio.exists(model_uri):
raise RuntimeError(
f"Expected sklearn model artifact was not found at "
f"{model_uri}"
)
fileio.copy(
model_uri, os.path.join(served_model_uri, "model.joblib")
)
else:
# default treatment for all other server implementations is to
# simply reuse the model from the artifact store path where it
# is originally stored
served_model_uri = model_uri
service_config = params.service_config.copy()
service_config.model_uri = served_model_uri
return service_config
# fetch existing services with same pipeline name, step name and
# model name
existing_services = model_deployer.find_model_server(
pipeline_name=pipeline_name,
pipeline_step_name=step_name,
model_name=params.service_config.model_name,
)
# even when the deploy decision is negative, if an existing model server
# is not running for this pipeline/step, we still have to serve the
# current model, to ensure that a model server is available at all times
if not deploy_decision and existing_services:
logger.info(
f"Skipping model deployment because the model quality does not "
f"meet the criteria. Reusing last model server deployed by step "
f"'{step_name}' and pipeline '{pipeline_name}' for model "
f"'{params.service_config.model_name}'..."
)
service = cast(SeldonDeploymentService, existing_services[0])
# even when the deploy decision is negative, we still need to start
# the previous model server if it is no longer running, to ensure that
# a model server is available at all times
if not service.is_running:
service.start(timeout=params.timeout)
return service
# invoke the Seldon Core model deployer to create a new service
# or update an existing one that was previously deployed for the same
# model
service_config = prepare_service_config(model.uri)
service = cast(
SeldonDeploymentService,
model_deployer.deploy_model(
service_config, replace=True, timeout=params.timeout
),
)
logger.info(
f"Seldon deployment service started and reachable at:\n"
f" {service.prediction_url}\n"
)
return service
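And a corresponding sketch for the built-in deployer step (values are illustrative; the TENSORFLOW_SERVER implementation triggers the versioned-subdirectory layout described above):
from zenml.integrations.seldon.services.seldon_deployment import (
    SeldonDeploymentConfig,
)
from zenml.integrations.seldon.steps.seldon_deployer import (
    SeldonDeployerStepParameters,
    seldon_model_deployer_step,
)

deployer = seldon_model_deployer_step(
    params=SeldonDeployerStepParameters(
        service_config=SeldonDeploymentConfig(
            model_name="mnist",  # hypothetical model name
            replicas=1,
            implementation="TENSORFLOW_SERVER",
        ),
        timeout=120,
    )
)
# In a pipeline: deployer(deploy_decision=decision_step_output, model=model_artifact)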
sklearn
special
Initialization of the sklearn integration.
SklearnIntegration (Integration)
Definition of sklearn integration for ZenML.
Source code in zenml/integrations/sklearn/__init__.py
class SklearnIntegration(Integration):
"""Definition of sklearn integration for ZenML."""
NAME = SKLEARN
REQUIREMENTS = ["scikit-learn"]
@classmethod
def activate(cls) -> None:
"""Activates the integration."""
from zenml.integrations.sklearn import materializers # noqa
activate()
classmethod
Activates the integration.
Source code in zenml/integrations/sklearn/__init__.py
@classmethod
def activate(cls) -> None:
"""Activates the integration."""
from zenml.integrations.sklearn import materializers # noqa
materializers
special
Initialization of the sklearn materializer.
sklearn_materializer
Implementation of the sklearn materializer.
SklearnMaterializer (BaseMaterializer)
Materializer to read data to and from sklearn.
Source code in zenml/integrations/sklearn/materializers/sklearn_materializer.py
class SklearnMaterializer(BaseMaterializer):
"""Materializer to read data to and from sklearn."""
ASSOCIATED_TYPES = (
BaseEstimator,
ClassifierMixin,
ClusterMixin,
BiclusterMixin,
OutlierMixin,
RegressorMixin,
MetaEstimatorMixin,
MultiOutputMixin,
DensityMixin,
TransformerMixin,
)
ASSOCIATED_ARTIFACT_TYPES = (ModelArtifact,)
def handle_input(
self, data_type: Type[Any]
) -> Union[
BaseEstimator,
ClassifierMixin,
ClusterMixin,
BiclusterMixin,
OutlierMixin,
RegressorMixin,
MetaEstimatorMixin,
MultiOutputMixin,
DensityMixin,
TransformerMixin,
]:
"""Reads a base sklearn model from a pickle file.
Args:
data_type: The type of the model.
Returns:
The model.
"""
super().handle_input(data_type)
filepath = os.path.join(self.artifact.uri, DEFAULT_FILENAME)
with fileio.open(filepath, "rb") as fid:
clf = pickle.load(fid)
return clf
def handle_return(
self,
clf: Union[
BaseEstimator,
ClassifierMixin,
ClusterMixin,
BiclusterMixin,
OutlierMixin,
RegressorMixin,
MetaEstimatorMixin,
MultiOutputMixin,
DensityMixin,
TransformerMixin,
],
) -> None:
"""Creates a pickle for a sklearn model.
Args:
clf: A sklearn model.
"""
super().handle_return(clf)
filepath = os.path.join(self.artifact.uri, DEFAULT_FILENAME)
with fileio.open(filepath, "wb") as fid:
pickle.dump(clf, fid)
handle_input(self, data_type)
Reads a base sklearn model from a pickle file.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
data_type |
Type[Any] |
The type of the model. |
required |
Returns:
Type | Description |
---|---|
Union[sklearn.base.BaseEstimator, sklearn.base.ClassifierMixin, sklearn.base.ClusterMixin, sklearn.base.BiclusterMixin, sklearn.base.OutlierMixin, sklearn.base.RegressorMixin, sklearn.base.MetaEstimatorMixin, sklearn.base.MultiOutputMixin, sklearn.base.DensityMixin, sklearn.base.TransformerMixin] |
The model. |
Source code in zenml/integrations/sklearn/materializers/sklearn_materializer.py
def handle_input(
self, data_type: Type[Any]
) -> Union[
BaseEstimator,
ClassifierMixin,
ClusterMixin,
BiclusterMixin,
OutlierMixin,
RegressorMixin,
MetaEstimatorMixin,
MultiOutputMixin,
DensityMixin,
TransformerMixin,
]:
"""Reads a base sklearn model from a pickle file.
Args:
data_type: The type of the model.
Returns:
The model.
"""
super().handle_input(data_type)
filepath = os.path.join(self.artifact.uri, DEFAULT_FILENAME)
with fileio.open(filepath, "rb") as fid:
clf = pickle.load(fid)
return clf
handle_return(self, clf)
Creates a pickle for a sklearn model.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
clf |
Union[sklearn.base.BaseEstimator, sklearn.base.ClassifierMixin, sklearn.base.ClusterMixin, sklearn.base.BiclusterMixin, sklearn.base.OutlierMixin, sklearn.base.RegressorMixin, sklearn.base.MetaEstimatorMixin, sklearn.base.MultiOutputMixin, sklearn.base.DensityMixin, sklearn.base.TransformerMixin] |
A sklearn model. |
required |
Source code in zenml/integrations/sklearn/materializers/sklearn_materializer.py
def handle_return(
self,
clf: Union[
BaseEstimator,
ClassifierMixin,
ClusterMixin,
BiclusterMixin,
OutlierMixin,
RegressorMixin,
MetaEstimatorMixin,
MultiOutputMixin,
DensityMixin,
TransformerMixin,
],
) -> None:
"""Creates a pickle for a sklearn model.
Args:
clf: A sklearn model.
"""
super().handle_return(clf)
filepath = os.path.join(self.artifact.uri, DEFAULT_FILENAME)
with fileio.open(filepath, "wb") as fid:
pickle.dump(clf, fid)
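Because the ASSOCIATED_TYPES cover the sklearn mixin hierarchy, any estimator returned from a step is pickled automatically; a minimal sketch:
from sklearn.linear_model import LogisticRegression

from zenml.steps import step


@step
def trainer() -> LogisticRegression:
    """ZenML resolves SklearnMaterializer for this return type, since
    LogisticRegression is a BaseEstimator/ClassifierMixin (one of the
    ASSOCIATED_TYPES above)."""
    model = LogisticRegression()
    # model.fit(X_train, y_train)  # training data omitted in this sketch
    return model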
slack
special
Slack integration for alerter components.
SlackIntegration (Integration)
Definition of a Slack integration for ZenML.
Implemented using Slack SDK.
Source code in zenml/integrations/slack/__init__.py
class SlackIntegration(Integration):
"""Definition of a Slack integration for ZenML.
Implemented using [Slack SDK](https://pypi.org/project/slack-sdk/).
"""
NAME = SLACK
REQUIREMENTS = ["slack-sdk>=3.16.1", "aiohttp>=3.8.1"]
@classmethod
def flavors(cls) -> List[Type[Flavor]]:
"""Declare the stack component flavors for the Slack integration.
Returns:
List of new flavors defined by the Slack integration.
"""
from zenml.integrations.slack.flavors import SlackAlerterFlavor
return [SlackAlerterFlavor]
flavors()
classmethod
Declare the stack component flavors for the Slack integration.
Returns:
Type | Description |
---|---|
List[Type[zenml.stack.flavor.Flavor]] |
List of new flavors defined by the Slack integration. |
Source code in zenml/integrations/slack/__init__.py
@classmethod
def flavors(cls) -> List[Type[Flavor]]:
"""Declare the stack component flavors for the Slack integration.
Returns:
List of new flavors defined by the Slack integration.
"""
from zenml.integrations.slack.flavors import SlackAlerterFlavor
return [SlackAlerterFlavor]
alerters
special
Alerter components defined by the Slack integration.
slack_alerter
Implementation of the Slack flavor of the alerter component.
SlackAlerter (BaseAlerter)
Send messages to Slack channels.
Source code in zenml/integrations/slack/alerters/slack_alerter.py
class SlackAlerter(BaseAlerter):
"""Send messages to Slack channels."""
@property
def config(self) -> SlackAlerterConfig:
"""Returns the `SlackAlerterConfig` config.
Returns:
The configuration.
"""
return cast(SlackAlerterConfig, self._config)
def _get_channel_id(
self, params: Optional[BaseAlerterStepParameters]
) -> str:
"""Get the Slack channel ID to be used by post/ask.
Args:
params: Optional parameters.
Returns:
ID of the Slack channel to be used.
Raises:
RuntimeError: if params is not of type `BaseAlerterStepParameters`.
ValueError: if a Slack channel was neither defined in the config
nor in the Slack alerter component.
"""
if not isinstance(params, BaseAlerterStepParameters):
raise RuntimeError(
"The config object must be of type `BaseAlerterStepParameters`."
)
if (
isinstance(params, SlackAlerterParameters)
and hasattr(params, "slack_channel_id")
and params.slack_channel_id is not None
):
return params.slack_channel_id
if self.config.default_slack_channel_id is not None:
return self.config.default_slack_channel_id
raise ValueError(
"Neither the `SlackAlerterConfig.slack_channel_id` in the runtime "
"configuration, nor the `default_slack_channel_id` in the alerter "
"stack component is specified. Please specify at least one."
)
def _get_approve_msg_options(
self, params: Optional[BaseAlerterStepParameters]
) -> List[str]:
"""Define which messages will lead to approval during ask().
Args:
params: Optional parameters.
Returns:
Set of messages that lead to approval in alerter.ask().
"""
if (
isinstance(params, SlackAlerterParameters)
and hasattr(params, "approve_msg_options")
and params.approve_msg_options is not None
):
return params.approve_msg_options
return DEFAULT_APPROVE_MSG_OPTIONS
def _get_disapprove_msg_options(
self, params: Optional[BaseAlerterStepParameters]
) -> List[str]:
"""Define which messages will lead to disapproval during ask().
Args:
params: Optional parameters.
Returns:
Set of messages that lead to disapproval in alerter.ask().
"""
if (
isinstance(params, SlackAlerterParameters)
and hasattr(params, "disapprove_msg_options")
and params.disapprove_msg_options is not None
):
return params.disapprove_msg_options
return DEFAULT_DISAPPROVE_MSG_OPTIONS
def post(
self, message: str, params: Optional[BaseAlerterStepParameters]
) -> bool:
"""Post a message to a Slack channel.
Args:
message: Message to be posted.
params: Optional parameters.
Returns:
True if operation succeeded, else False
"""
slack_channel_id = self._get_channel_id(params=params)
client = WebClient(token=self.config.slack_token)
try:
response = client.chat_postMessage(
channel=slack_channel_id,
text=message,
)
return True
except SlackApiError as error:
response = error.response["error"]
logger.error(f"SlackAlerter.post() failed: {response}")
return False
def ask(
self, message: str, params: Optional[BaseAlerterStepParameters]
) -> bool:
"""Post a message to a Slack channel and wait for approval.
Args:
message: Initial message to be posted.
params: Optional parameters.
Returns:
True if a user approved the operation, else False
"""
rtm = RTMClient(token=self.config.slack_token)
slack_channel_id = self._get_channel_id(params=params)
approved = False # will be modified by handle()
@RTMClient.run_on(event="hello") # type: ignore
def post_initial_message(**payload: Any) -> None:
"""Post an initial message in a channel and start listening.
Args:
payload: payload of the received Slack event.
"""
web_client = payload["web_client"]
web_client.chat_postMessage(channel=slack_channel_id, text=message)
@RTMClient.run_on(event="message") # type: ignore
def handle(**payload: Any) -> None:
"""Listen / handle messages posted in the channel.
Args:
payload: payload of the received Slack event.
"""
event = payload["data"]
if event["channel"] == slack_channel_id:
# approve request (return True)
if event["text"] in self._get_approve_msg_options(params):
print(f"User {event['user']} approved on slack.")
nonlocal approved
approved = True
rtm.stop() # type: ignore
# disapprove request (return False)
elif event["text"] in self._get_disapprove_msg_options(params):
print(f"User {event['user']} disapproved on slack.")
rtm.stop() # type:ignore
# start another thread until `rtm.stop()` is called in handle()
rtm.start()
return approved
config: SlackAlerterConfig
property
readonly
Returns the SlackAlerterConfig config.
Returns:
Type | Description |
---|---|
SlackAlerterConfig |
The configuration. |
ask(self, message, params)
Post a message to a Slack channel and wait for approval.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
message |
str |
Initial message to be posted. |
required |
params |
Optional[zenml.steps.step_interfaces.base_alerter_step.BaseAlerterStepParameters] |
Optional parameters. |
required |
Returns:
Type | Description |
---|---|
bool |
True if a user approved the operation, else False |
Source code in zenml/integrations/slack/alerters/slack_alerter.py
def ask(
self, message: str, params: Optional[BaseAlerterStepParameters]
) -> bool:
"""Post a message to a Slack channel and wait for approval.
Args:
message: Initial message to be posted.
params: Optional parameters.
Returns:
True if a user approved the operation, else False
"""
rtm = RTMClient(token=self.config.slack_token)
slack_channel_id = self._get_channel_id(params=params)
approved = False # will be modified by handle()
@RTMClient.run_on(event="hello") # type: ignore
def post_initial_message(**payload: Any) -> None:
"""Post an initial message in a channel and start listening.
Args:
payload: payload of the received Slack event.
"""
web_client = payload["web_client"]
web_client.chat_postMessage(channel=slack_channel_id, text=message)
@RTMClient.run_on(event="message") # type: ignore
def handle(**payload: Any) -> None:
"""Listen / handle messages posted in the channel.
Args:
payload: payload of the received Slack event.
"""
event = payload["data"]
if event["channel"] == slack_channel_id:
# approve request (return True)
if event["text"] in self._get_approve_msg_options(params):
print(f"User {event['user']} approved on slack.")
nonlocal approved
approved = True
rtm.stop() # type: ignore
# disapprove request (return False)
elif event["text"] in self._get_disapprove_msg_options(params):
print(f"User {event['user']} disapproved on slack.")
rtm.stop() # type:ignore
# start another thread until `rtm.stop()` is called in handle()
rtm.start()
return approved
post(self, message, params)
Post a message to a Slack channel.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
message |
str |
Message to be posted. |
required |
params |
Optional[zenml.steps.step_interfaces.base_alerter_step.BaseAlerterStepParameters] |
Optional parameters. |
required |
Returns:
Type | Description |
---|---|
bool |
True if operation succeeded, else False |
Source code in zenml/integrations/slack/alerters/slack_alerter.py
def post(
self, message: str, params: Optional[BaseAlerterStepParameters]
) -> bool:
"""Post a message to a Slack channel.
Args:
message: Message to be posted.
params: Optional parameters.
Returns:
True if operation succeeded, else False
"""
slack_channel_id = self._get_channel_id(params=params)
client = WebClient(token=self.config.slack_token)
try:
response = client.chat_postMessage(
channel=slack_channel_id,
text=message,
)
return True
except SlackApiError as error:
response = error.response["error"]
logger.error(f"SlackAlerter.post() failed: {response}")
return False
SlackAlerterParameters (BaseAlerterStepParameters)
pydantic-model
Slack alerter parameters.
Source code in zenml/integrations/slack/alerters/slack_alerter.py
class SlackAlerterParameters(BaseAlerterStepParameters):
"""Slack alerter parameters."""
# The ID of the Slack channel to use for communication.
slack_channel_id: Optional[str] = None
# Set of messages that lead to approval in alerter.ask()
approve_msg_options: Optional[List[str]] = None
# Set of messages that lead to disapproval in alerter.ask()
disapprove_msg_options: Optional[List[str]] = None
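A configuration sketch (the channel ID and keyword lists are illustrative):
from zenml.integrations.slack.alerters.slack_alerter import (
    SlackAlerterParameters,
)

params = SlackAlerterParameters(
    slack_channel_id="C0123456789",  # hypothetical channel ID
    approve_msg_options=["approve", "LGTM"],  # messages counted as approval
    disapprove_msg_options=["decline", "reject"],
)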
flavors
special
Slack integration flavors.
slack_alerter_flavor
Slack alerter flavor.
SlackAlerterConfig (BaseAlerterConfig)
pydantic-model
Slack alerter config.
Attributes:
Name | Type | Description |
---|---|---|
slack_token |
str |
The Slack token tied to the Slack account to be used. |
default_slack_channel_id |
Optional[str] |
The ID of the Slack channel to use for communication if no channel ID is provided in the step config. |
Source code in zenml/integrations/slack/flavors/slack_alerter_flavor.py
class SlackAlerterConfig(BaseAlerterConfig):
"""Slack alerter config.
Attributes:
slack_token: The Slack token tied to the Slack account to be used.
default_slack_channel_id: The ID of the Slack channel to use for
communication if no channel ID is provided in the step config.
"""
slack_token: str = SecretField()
default_slack_channel_id: Optional[str] = None
SlackAlerterFlavor (BaseAlerterFlavor)
Slack alerter flavor.
Source code in zenml/integrations/slack/flavors/slack_alerter_flavor.py
class SlackAlerterFlavor(BaseAlerterFlavor):
"""Slack alerter flavor."""
@property
def name(self) -> str:
"""Name of the flavor.
Returns:
The name of the flavor.
"""
return SLACK_ALERTER_FLAVOR
@property
def config_class(self) -> Type[SlackAlerterConfig]:
"""Returns `SlackAlerterConfig` config class.
Returns:
The config class.
"""
return SlackAlerterConfig
@property
def implementation_class(self) -> Type["SlackAlerter"]:
"""Implementation class for this flavor.
Returns:
The implementation class.
"""
from zenml.integrations.slack.alerters import SlackAlerter
return SlackAlerter
config_class: Type[zenml.integrations.slack.flavors.slack_alerter_flavor.SlackAlerterConfig]
property
readonly
Returns the SlackAlerterConfig config class.
Returns:
Type | Description |
---|---|
Type[zenml.integrations.slack.flavors.slack_alerter_flavor.SlackAlerterConfig] |
The config class. |
implementation_class: Type[SlackAlerter]
property
readonly
Implementation class for this flavor.
Returns:
Type | Description |
---|---|
Type[SlackAlerter] |
The implementation class. |
name: str
property
readonly
Name of the flavor.
Returns:
Type | Description |
---|---|
str |
The name of the flavor. |
steps
special
Built-in steps for the Slack integration.
slack_alerter_ask_step
Step that allows you to send messages to Slack and wait for a response.
slack_alerter_ask_step (BaseStep)
Posts a message to the Slack alerter component and waits for approval.
This can be useful, e.g. to easily get a human in the loop before deploying models.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
params |
Parameters for the Slack alerter. |
required | |
context |
StepContext of the ZenML repository. |
required | |
message |
Initial message to be posted. |
required |
Returns:
Type | Description |
---|---|
True if a user approved the operation, else False. |
Exceptions:
Type | Description |
---|---|
RuntimeError |
If the currently active alerter is not a SlackAlerter. |
PARAMETERS_CLASS (BaseAlerterStepParameters)
pydantic-model
Slack alerter parameters.
Source code in zenml/integrations/slack/steps/slack_alerter_ask_step.py
class SlackAlerterParameters(BaseAlerterStepParameters):
"""Slack alerter parameters."""
# The ID of the Slack channel to use for communication.
slack_channel_id: Optional[str] = None
# Set of messages that lead to approval in alerter.ask()
approve_msg_options: Optional[List[str]] = None
# Set of messages that lead to disapproval in alerter.ask()
disapprove_msg_options: Optional[List[str]] = None
entrypoint(params, context, message)
staticmethod
Posts a message to the Slack alerter component and waits for approval.
This can be useful, e.g. to easily get a human in the loop before deploying models.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
params |
SlackAlerterParameters |
Parameters for the Slack alerter. |
required |
context |
StepContext |
StepContext of the ZenML repository. |
required |
message |
str |
Initial message to be posted. |
required |
Returns:
Type | Description |
---|---|
bool |
True if a user approved the operation, else False. |
Exceptions:
Type | Description |
---|---|
RuntimeError |
If the currently active alerter is not a SlackAlerter. |
Source code in zenml/integrations/slack/steps/slack_alerter_ask_step.py
@step
def slack_alerter_ask_step(
params: SlackAlerterParameters, context: StepContext, message: str
) -> bool:
"""Posts a message to the Slack alerter component and waits for approval.
This can be useful, e.g. to easily get a human in the loop before
deploying models.
Args:
params: Parameters for the Slack alerter.
context: StepContext of the ZenML repository.
message: Initial message to be posted.
Returns:
True if a user approved the operation, else False.
Raises:
RuntimeError: If currently active alerter is not a `SlackAlerter`.
"""
alerter = get_active_alerter(context)
if not isinstance(alerter, SlackAlerter):
# TODO: potential duplicate code for other components
# -> generalize to `check_component_flavor()` utility function?
raise RuntimeError(
"Step `slack_alerter_ask_step` requires an alerter component of "
"flavor `slack`, but the currently active alerter is of type "
f"{type(alerter)}, which is not a subclass of `SlackAlerter`."
)
return alerter.ask(message, params)
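A hedged wiring sketch for a human-in-the-loop gate (pipeline composition is indicated in comments; names are illustrative):
from zenml.integrations.slack.alerters.slack_alerter import (
    SlackAlerterParameters,
)
from zenml.integrations.slack.steps.slack_alerter_ask_step import (
    slack_alerter_ask_step,
)

ask = slack_alerter_ask_step(
    params=SlackAlerterParameters(slack_channel_id="C0123456789")
)
# Inside a pipeline definition, the boolean output can gate deployment:
#   approved = ask(message=message_step_output)
#   deployer(deploy_decision=approved, ...)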
slack_alerter_post_step
Step that allows you to post messages to Slack.
slack_alerter_post_step (BaseStep)
Post a message to the Slack alerter component of the active stack.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
params |
Parameters for the Slack alerter. |
required | |
context |
StepContext of the ZenML repository. |
required | |
message |
Message to be posted. |
required |
Returns:
Type | Description |
---|---|
True if operation succeeded, else False. |
Exceptions:
Type | Description |
---|---|
RuntimeError |
If the currently active alerter is not a SlackAlerter. |
PARAMETERS_CLASS (BaseAlerterStepParameters)
pydantic-model
Slack alerter parameters.
Source code in zenml/integrations/slack/steps/slack_alerter_post_step.py
class SlackAlerterParameters(BaseAlerterStepParameters):
"""Slack alerter parameters."""
# The ID of the Slack channel to use for communication.
slack_channel_id: Optional[str] = None
# Set of messages that lead to approval in alerter.ask()
approve_msg_options: Optional[List[str]] = None
# Set of messages that lead to disapproval in alerter.ask()
disapprove_msg_options: Optional[List[str]] = None
entrypoint(params, context, message)
staticmethod
Post a message to the Slack alerter component of the active stack.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
params |
SlackAlerterParameters |
Parameters for the Slack alerter. |
required |
context |
StepContext |
StepContext of the ZenML repository. |
required |
message |
str |
Message to be posted. |
required |
Returns:
Type | Description |
---|---|
bool |
True if operation succeeded, else False. |
Exceptions:
Type | Description |
---|---|
RuntimeError |
If the currently active alerter is not a SlackAlerter. |
Source code in zenml/integrations/slack/steps/slack_alerter_post_step.py
@step
def slack_alerter_post_step(
params: SlackAlerterParameters, context: StepContext, message: str
) -> bool:
"""Post a message to the Slack alerter component of the active stack.
Args:
params: Parameters for the Slack alerter.
context: StepContext of the ZenML repository.
message: Message to be posted.
Returns:
True if operation succeeded, else False.
Raises:
RuntimeError: If currently active alerter is not a `SlackAlerter`.
"""
alerter = get_active_alerter(context)
if not isinstance(alerter, SlackAlerter):
raise RuntimeError(
"Step `slack_alerter_post_step` requires an alerter component of "
"flavor `slack`, but the currently active alerter is of type "
f"{type(alerter)}, which is not a subclass of `SlackAlerter`."
)
return alerter.post(message, params)
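And the fire-and-forget counterpart (a sketch; the message would normally come from an upstream formatter step):
from zenml.integrations.slack.alerters.slack_alerter import (
    SlackAlerterParameters,
)
from zenml.integrations.slack.steps.slack_alerter_post_step import (
    slack_alerter_post_step,
)

notify = slack_alerter_post_step(
    params=SlackAlerterParameters(slack_channel_id="C0123456789")
)
# In a pipeline: notify(message=report_step_output)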
spark
special
The Spark integration module to enable distributed processing for steps.
SparkIntegration (Integration)
Definition of Spark integration for ZenML.
Source code in zenml/integrations/spark/__init__.py
class SparkIntegration(Integration):
"""Definition of Spark integration for ZenML."""
NAME = SPARK
REQUIREMENTS = ["pyspark==3.2.1"]
@classmethod
def activate(cls) -> None:
"""Activating the corresponding Spark materializers."""
from zenml.integrations.spark import materializers # noqa
@classmethod
def flavors(cls) -> List[Type[Flavor]]:
"""Declare the stack component flavors for the Spark integration.
Returns:
The flavor wrapper for the step operator flavor
"""
from zenml.integrations.spark.flavors import (
KubernetesSparkStepOperatorFlavor,
)
return [KubernetesSparkStepOperatorFlavor]
activate()
classmethod
Activating the corresponding Spark materializers.
Source code in zenml/integrations/spark/__init__.py
@classmethod
def activate(cls) -> None:
"""Activating the corresponding Spark materializers."""
from zenml.integrations.spark import materializers # noqa
flavors()
classmethod
Declare the stack component flavors for the Spark integration.
Returns:
Type | Description |
---|---|
List[Type[zenml.stack.flavor.Flavor]] |
The flavor wrapper for the step operator flavor |
Source code in zenml/integrations/spark/__init__.py
@classmethod
def flavors(cls) -> List[Type[Flavor]]:
"""Declare the stack component flavors for the Spark integration.
Returns:
The flavor wrapper for the step operator flavor
"""
from zenml.integrations.spark.flavors import (
KubernetesSparkStepOperatorFlavor,
)
return [KubernetesSparkStepOperatorFlavor]
flavors
special
Spark integration flavors.
spark_on_kubernetes_step_operator_flavor
Spark on Kubernetes step operator flavor.
KubernetesSparkStepOperatorConfig (SparkStepOperatorConfig)
pydantic-model
Config for the Kubernetes Spark step operator.
Attributes:
Name | Type | Description |
---|---|---|
namespace |
Optional[str] |
the namespace under which the driver and executor pods will run. |
service_account |
Optional[str] |
the service account that will be used by various Spark components (to create and watch the pods). |
Source code in zenml/integrations/spark/flavors/spark_on_kubernetes_step_operator_flavor.py
class KubernetesSparkStepOperatorConfig(SparkStepOperatorConfig):
"""Config for the Kubernetes Spark step operator.
Attributes:
namespace: the namespace under which the driver and executor pods
will run.
service_account: the service account that will be used by various Spark
components (to create and watch the pods).
"""
namespace: Optional[str] = None
service_account: Optional[str] = None
KubernetesSparkStepOperatorFlavor (SparkStepOperatorFlavor)
Flavor for the Kubernetes Spark step operator.
Source code in zenml/integrations/spark/flavors/spark_on_kubernetes_step_operator_flavor.py
class KubernetesSparkStepOperatorFlavor(SparkStepOperatorFlavor):
"""Flavor for the Kubernetes Spark step operator."""
@property
def name(self) -> str:
"""Name of the flavor.
Returns:
The name of the flavor.
"""
return SPARK_KUBERNETES_STEP_OPERATOR
@property
def config_class(self) -> Type[KubernetesSparkStepOperatorConfig]:
"""Returns `KubernetesSparkStepOperatorConfig` config class.
Returns:
The config class.
"""
return KubernetesSparkStepOperatorConfig
@property
def implementation_class(self) -> Type["KubernetesSparkStepOperator"]:
"""Implementation class for this flavor.
Returns:
The implementation class.
"""
from zenml.integrations.spark.step_operators import (
KubernetesSparkStepOperator,
)
return KubernetesSparkStepOperator
config_class: Type[zenml.integrations.spark.flavors.spark_on_kubernetes_step_operator_flavor.KubernetesSparkStepOperatorConfig]
property
readonly
Returns the KubernetesSparkStepOperatorConfig config class.
Returns:
Type | Description |
---|---|
Type[zenml.integrations.spark.flavors.spark_on_kubernetes_step_operator_flavor.KubernetesSparkStepOperatorConfig] |
The config class. |
implementation_class: Type[KubernetesSparkStepOperator]
property
readonly
Implementation class for this flavor.
Returns:
Type | Description |
---|---|
Type[KubernetesSparkStepOperator] |
The implementation class. |
name: str
property
readonly
Name of the flavor.
Returns:
Type | Description |
---|---|
str |
The name of the flavor. |
spark_step_operator_flavor
Spark step operator flavor.
SparkStepOperatorConfig (BaseStepOperatorConfig)
pydantic-model
Spark step operator config.
Attributes:
Name | Type | Description |
---|---|---|
master |
str |
is the master URL for the cluster. You might see different schemes for different cluster managers which are supported by Spark like Mesos, YARN, or Kubernetes. Within the context of this PR, the implementation supports Kubernetes as a cluster manager. |
deploy_mode |
str |
can either be 'cluster' (default) or 'client' and it decides where the driver node of the application will run. |
submit_kwargs |
Optional[Dict[str, Any]] |
is the JSON string of a dict, which will be used to define additional params if required (Spark has quite a lot of different parameters, so including them all in the step operator was not implemented). |
Source code in zenml/integrations/spark/flavors/spark_step_operator_flavor.py
class SparkStepOperatorConfig(BaseStepOperatorConfig):
"""Spark step operator config.
Attributes:
master: the master URL for the cluster. You might see different
schemes for different cluster managers which are supported by Spark
like Mesos, YARN, or Kubernetes. Within the context of this
integration, the implementation supports Kubernetes as a cluster
manager.
deploy_mode: can either be 'cluster' (default) or 'client' and it
decides where the driver node of the application will run.
submit_kwargs: a JSON string of a dict, which will be used to
define additional parameters if required (Spark has a large
number of parameters, so they are not all exposed individually
on the step operator).
"""
master: str
deploy_mode: str = "cluster"
submit_kwargs: Optional[Dict[str, Any]] = None
@validator("submit_kwargs", pre=True)
def _convert_json_string(
cls, value: Union[None, str, Dict[str, Any]]
) -> Optional[Dict[str, Any]]:
"""Converts potential JSON strings passed via the CLI to dictionaries.
Args:
value: The value to convert.
Returns:
The converted value.
Raises:
TypeError: If the value is not a `str`, `Dict` or `None`.
ValueError: If the value is an invalid json string or a json string
that does not decode into a dictionary.
"""
if isinstance(value, str):
try:
dict_ = json.loads(value)
except json.JSONDecodeError as e:
raise ValueError(f"Invalid json string '{value}'") from e
if not isinstance(dict_, Dict):
raise ValueError(
f"Json string '{value}' did not decode into a dictionary."
)
return dict_
elif isinstance(value, Dict) or value is None:
return value
else:
raise TypeError(f"{value} is not a json string or a dictionary.")
SparkStepOperatorFlavor (BaseStepOperatorFlavor)
Spark step operator flavor.
Source code in zenml/integrations/spark/flavors/spark_step_operator_flavor.py
class SparkStepOperatorFlavor(BaseStepOperatorFlavor):
"""Spark step operator flavor."""
@property
def name(self) -> str:
"""Name of the flavor.
Returns:
The name of the flavor.
"""
return "spark"
@property
def config_class(self) -> Type[SparkStepOperatorConfig]:
"""Returns `SparkStepOperatorConfig` config class.
Returns:
The config class.
"""
return SparkStepOperatorConfig
@property
def implementation_class(self) -> Type["SparkStepOperator"]:
"""Implementation class for this flavor.
Returns:
The implementation class.
"""
from zenml.integrations.spark.step_operators.spark_step_operator import (
SparkStepOperator,
)
return SparkStepOperator
config_class: Type[zenml.integrations.spark.flavors.spark_step_operator_flavor.SparkStepOperatorConfig]
property
readonly
Returns SparkStepOperatorConfig config class.
Returns:
Type | Description
---|---
Type[zenml.integrations.spark.flavors.spark_step_operator_flavor.SparkStepOperatorConfig] | The config class.
implementation_class: Type[SparkStepOperator]
property
readonly
Implementation class for this flavor.
Returns:
Type | Description
---|---
Type[SparkStepOperator] | The implementation class.
name: str
property
readonly
Name of the flavor.
Returns:
Type | Description
---|---
str | The name of the flavor.
materializers
special
Spark Materializers.
spark_dataframe_materializer
Implementation of the Spark Dataframe Materializer.
SparkDataFrameMaterializer (BaseMaterializer)
Materializer to read/write Spark dataframes.
Source code in zenml/integrations/spark/materializers/spark_dataframe_materializer.py
class SparkDataFrameMaterializer(BaseMaterializer):
"""Materializer to read/write Spark dataframes."""
ASSOCIATED_TYPES = (DataFrame,)
ASSOCIATED_ARTIFACT_TYPES = (DataArtifact,)
def handle_input(self, data_type: Type[Any]) -> DataFrame:
"""Reads and returns a spark dataframe.
Args:
data_type: The type of the data to read.
Returns:
A loaded spark dataframe.
"""
super().handle_input(data_type)
# Create the Spark session
spark = SparkSession.builder.getOrCreate()
# Read the data
path = os.path.join(self.artifact.uri, DEFAULT_FILEPATH)
return spark.read.parquet(path)
def handle_return(self, df: DataFrame) -> None:
"""Writes a spark dataframe.
Args:
df: A spark dataframe object.
"""
super().handle_return(df)
# Write the dataframe to the artifact store
path = os.path.join(self.artifact.uri, DEFAULT_FILEPATH)
df.write.parquet(path)
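Outside of a pipeline, the round trip this materializer performs is plain parquet I/O. A self-contained sketch; the output path is a placeholder for the artifact URI:
from pyspark.sql import SparkSession

# Self-contained sketch of the parquet round trip; use a fresh path, since
# Spark refuses to overwrite an existing parquet directory by default.
spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([(1, "a"), (2, "b")], schema=["id", "value"])
df.write.parquet("/tmp/example_artifact")             # what handle_return does
loaded = spark.read.parquet("/tmp/example_artifact")  # what handle_input does
loaded.show()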
handle_input(self, data_type)
Reads and returns a spark dataframe.
Parameters:
Name | Type | Description | Default
---|---|---|---
data_type | Type[Any] | The type of the data to read. | required
Returns:
Type | Description
---|---
DataFrame | A loaded spark dataframe.
Source code in zenml/integrations/spark/materializers/spark_dataframe_materializer.py
def handle_input(self, data_type: Type[Any]) -> DataFrame:
"""Reads and returns a spark dataframe.
Args:
data_type: The type of the data to read.
Returns:
A loaded spark dataframe.
"""
super().handle_input(data_type)
# Create the Spark session
spark = SparkSession.builder.getOrCreate()
# Read the data
path = os.path.join(self.artifact.uri, DEFAULT_FILEPATH)
return spark.read.parquet(path)
handle_return(self, df)
Writes a spark dataframe.
Parameters:
Name | Type | Description | Default
---|---|---|---
df | DataFrame | A spark dataframe object. | required
Source code in zenml/integrations/spark/materializers/spark_dataframe_materializer.py
def handle_return(self, df: DataFrame) -> None:
"""Writes a spark dataframe.
Args:
df: A spark dataframe object.
"""
super().handle_return(df)
# Write the dataframe to the artifact store
path = os.path.join(self.artifact.uri, DEFAULT_FILEPATH)
df.write.parquet(path)
spark_model_materializer
Implementation of the Spark Model Materializer.
SparkModelMaterializer (BaseMaterializer)
Materializer to read/write Spark models.
Source code in zenml/integrations/spark/materializers/spark_model_materializer.py
class SparkModelMaterializer(BaseMaterializer):
"""Materializer to read/write Spark models."""
ASSOCIATED_TYPES = (Transformer, Estimator, Model)
ASSOCIATED_ARTIFACT_TYPES = (ModelArtifact,)
def handle_input(
self, model_type: Type[Any]
) -> Union[Transformer, Estimator, Model]: # type: ignore[type-arg]
"""Reads and returns a Spark ML model.
Args:
model_type: The type of the model to read.
Returns:
A loaded spark model.
"""
super().handle_input(model_type)
path = os.path.join(self.artifact.uri, DEFAULT_FILEPATH)
return model_type.load(path) # type: ignore[no-any-return]
def handle_return(
self, model: Union[Transformer, Estimator, Model] # type: ignore[type-arg]
) -> None:
"""Writes a spark model.
Args:
model: A spark model.
"""
super().handle_return(model)
# Write the dataframe to the artifact store
path = os.path.join(self.artifact.uri, DEFAULT_FILEPATH)
model.save(path) # type: ignore[union-attr]
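The save/load cycle relies on the standard pyspark.ml persistence API; a standalone sketch, with a placeholder path standing in for the artifact URI:
from pyspark.ml.classification import LogisticRegression

# Standalone sketch of the save/load cycle; handle_return calls .save() and
# handle_input calls .load() on the annotated model type.
estimator = LogisticRegression()
estimator.save("/tmp/example_model")                      # what handle_return does
restored = LogisticRegression.load("/tmp/example_model")  # what handle_input does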
handle_input(self, model_type)
Reads and returns a Spark ML model.
Parameters:
Name | Type | Description | Default
---|---|---|---
model_type | Type[Any] | The type of the model to read. | required
Returns:
Type | Description
---|---
Union[pyspark.ml.base.Transformer, pyspark.ml.base.Estimator, pyspark.ml.base.Model] | A loaded spark model.
Source code in zenml/integrations/spark/materializers/spark_model_materializer.py
def handle_input(
self, model_type: Type[Any]
) -> Union[Transformer, Estimator, Model]: # type: ignore[type-arg]
"""Reads and returns a Spark ML model.
Args:
model_type: The type of the model to read.
Returns:
A loaded spark model.
"""
super().handle_input(model_type)
path = os.path.join(self.artifact.uri, DEFAULT_FILEPATH)
return model_type.load(path) # type: ignore[no-any-return]
handle_return(self, model)
Writes a spark model.
Parameters:
Name | Type | Description | Default
---|---|---|---
model | Union[pyspark.ml.base.Transformer, pyspark.ml.base.Estimator, pyspark.ml.base.Model] | A spark model. | required
Source code in zenml/integrations/spark/materializers/spark_model_materializer.py
def handle_return(
self, model: Union[Transformer, Estimator, Model] # type: ignore[type-arg]
) -> None:
"""Writes a spark model.
Args:
model: A spark model.
"""
super().handle_return(model)
# Write the dataframe to the artifact store
path = os.path.join(self.artifact.uri, DEFAULT_FILEPATH)
model.save(path) # type: ignore[union-attr]
step_operators
special
Spark Step Operators.
kubernetes_step_operator
Implementation of the Kubernetes Spark Step Operator.
KubernetesSparkStepOperator (SparkStepOperator)
Step operator which runs Steps with Spark on Kubernetes.
Source code in zenml/integrations/spark/step_operators/kubernetes_step_operator.py
class KubernetesSparkStepOperator(SparkStepOperator):
"""Step operator which runs Steps with Spark on Kubernetes."""
@property
def config(self) -> KubernetesSparkStepOperatorConfig:
"""Returns the `KubernetesSparkStepOperatorConfig` config.
Returns:
The configuration.
"""
return cast(KubernetesSparkStepOperatorConfig, self._config)
@property
def validator(self) -> Optional[StackValidator]:
"""Validates the stack.
Returns:
A validator that checks that the stack contains a remote container
registry and a remote artifact store.
"""
def _validate_remote_components(stack: "Stack") -> Tuple[bool, str]:
if stack.artifact_store.config.is_local:
return False, (
"The Spark step operator runs code remotely and "
"needs to write files into the artifact store, but the "
f"artifact store `{stack.artifact_store.name}` of the "
"active stack is local. Please ensure that your stack "
"contains a remote artifact store when using the Spark "
"step operator."
)
container_registry = stack.container_registry
assert container_registry is not None
if container_registry.config.is_local:
return False, (
"The Spark step operator runs code remotely and "
"needs to push/pull Docker images, but the "
f"container registry `{container_registry.name}` of the "
"active stack is local. Please ensure that your stack "
"contains a remote container registry when using the "
"Spark step operator."
)
return True, ""
return StackValidator(
required_components={StackComponentType.CONTAINER_REGISTRY},
custom_validation_function=_validate_remote_components,
)
@property
def application_path(self) -> Any:
"""Provides the application path in the corresponding docker image.
Returns:
The path to the application entrypoint within the docker image
"""
return f"local://{DOCKER_IMAGE_WORKDIR}/{ENTRYPOINT_NAME}"
def prepare_pipeline_deployment(
self,
deployment: "PipelineDeployment",
stack: "Stack",
) -> None:
"""Build a Docker image and push it to the container registry.
Args:
deployment: The pipeline deployment configuration.
stack: The stack on which the pipeline will be deployed.
Raises:
FileExistsError: If the entrypoint file already exists.
"""
steps_to_run = [
step
for step in deployment.steps.values()
if step.config.step_operator == self.name
]
if not steps_to_run:
return
entrypoint_path = os.path.join(get_source_root_path(), ENTRYPOINT_NAME)
try:
fileio.copy(LOCAL_ENTRYPOINT, entrypoint_path, overwrite=False)
except OSError:
raise FileExistsError(
f"The Kubernetes Spark step operator needs to copy the step "
f"entrypoint to {entrypoint_path}, however a file with this "
f"path already exists."
)
try:
# Build and push the image
docker_image_builder = PipelineDockerImageBuilder()
image_digest = docker_image_builder.build_and_push_docker_image(
deployment=deployment, stack=stack
)
finally:
fileio.remove(entrypoint_path)
for step in steps_to_run:
step.config.extra[SPARK_DOCKER_IMAGE_KEY] = image_digest
def _backend_configuration(
self,
spark_config: SparkConf,
step_config: "StepConfiguration",
) -> None:
"""Configures Spark to run on Kubernetes.
This method will build and push a docker image for the drivers and
executors and adjust the config accordingly.
Args:
spark_config: a SparkConf object which collects all the
configuration parameters
step_config: Configuration of the step to run.
"""
docker_image = step_config.extra[SPARK_DOCKER_IMAGE_KEY]
# Adjust the spark configuration
spark_config.set("spark.kubernetes.container.image", docker_image)
if self.config.namespace:
spark_config.set(
"spark.kubernetes.namespace",
self.config.namespace,
)
if self.config.service_account:
spark_config.set(
"spark.kubernetes.authenticate.driver.serviceAccountName",
self.config.service_account,
)
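For reference, the Spark properties that _backend_configuration contributes boil down to the following; the image, namespace, and service account names are placeholders:
from pyspark import SparkConf

# Sketch of the Kubernetes-specific Spark properties set above; all values
# here are placeholders.
spark_config = SparkConf()
spark_config.set(
    "spark.kubernetes.container.image", "my-registry/zenml-step:latest"
)
spark_config.set("spark.kubernetes.namespace", "spark-jobs")
spark_config.set(
    "spark.kubernetes.authenticate.driver.serviceAccountName", "spark-sa"
)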
application_path: Any
property
readonly
Provides the application path in the corresponding docker image.
Returns:
Type | Description
---|---
Any | The path to the application entrypoint within the docker image.
config: KubernetesSparkStepOperatorConfig
property
readonly
Returns the KubernetesSparkStepOperatorConfig config.
Returns:
Type | Description
---|---
KubernetesSparkStepOperatorConfig | The configuration.
validator: Optional[zenml.stack.stack_validator.StackValidator]
property
readonly
Validates the stack.
Returns:
Type | Description
---|---
Optional[zenml.stack.stack_validator.StackValidator] | A validator that checks that the stack contains a remote container registry and a remote artifact store.
prepare_pipeline_deployment(self, deployment, stack)
Build a Docker image and push it to the container registry.
Parameters:
Name | Type | Description | Default
---|---|---|---
deployment | PipelineDeployment | The pipeline deployment configuration. | required
stack | Stack | The stack on which the pipeline will be deployed. | required
Exceptions:
Type | Description
---|---
FileExistsError | If the entrypoint file already exists.
Source code in zenml/integrations/spark/step_operators/kubernetes_step_operator.py
def prepare_pipeline_deployment(
self,
deployment: "PipelineDeployment",
stack: "Stack",
) -> None:
"""Build a Docker image and push it to the container registry.
Args:
deployment: The pipeline deployment configuration.
stack: The stack on which the pipeline will be deployed.
Raises:
FileExistsError: If the entrypoint file already exists.
"""
steps_to_run = [
step
for step in deployment.steps.values()
if step.config.step_operator == self.name
]
if not steps_to_run:
return
entrypoint_path = os.path.join(get_source_root_path(), ENTRYPOINT_NAME)
try:
fileio.copy(LOCAL_ENTRYPOINT, entrypoint_path, overwrite=False)
except OSError:
raise FileExistsError(
f"The Kubernetes Spark step operator needs to copy the step "
f"entrypoint to {entrypoint_path}, however a file with this "
f"path already exists."
)
try:
# Build and push the image
docker_image_builder = PipelineDockerImageBuilder()
image_digest = docker_image_builder.build_and_push_docker_image(
deployment=deployment, stack=stack
)
finally:
fileio.remove(entrypoint_path)
for step in steps_to_run:
step.config.extra[SPARK_DOCKER_IMAGE_KEY] = image_digest
spark_entrypoint_configuration
Spark step operator entrypoint configuration.
SparkEntrypointConfiguration (StepOperatorEntrypointConfiguration)
Entrypoint configuration for the Spark step operator.
Source code in zenml/integrations/spark/step_operators/spark_entrypoint_configuration.py
class SparkEntrypointConfiguration(StepOperatorEntrypointConfiguration):
"""Entrypoint configuration for the Spark step operator."""
def run(self) -> None:
"""Runs the entrypoint configuration.
This prepends the directory containing the source files to the python
path so that spark can find them.
"""
with source_utils.prepend_python_path([DOCKER_IMAGE_WORKDIR]):
super().run()
run(self)
Runs the entrypoint configuration.
This prepends the directory containing the source files to the python path so that spark can find them.
Source code in zenml/integrations/spark/step_operators/spark_entrypoint_configuration.py
def run(self) -> None:
"""Runs the entrypoint configuration.
This prepends the directory containing the source files to the python
path so that spark can find them.
"""
with source_utils.prepend_python_path([DOCKER_IMAGE_WORKDIR]):
super().run()
spark_step_operator
Implementation of the Spark Step Operator.
SparkStepOperator (BaseStepOperator)
Base class for all Spark-related step operators.
Source code in zenml/integrations/spark/step_operators/spark_step_operator.py
class SparkStepOperator(BaseStepOperator):
"""Base class for all Spark-related step operators."""
@property
def config(self) -> SparkStepOperatorConfig:
"""Returns the `SparkStepOperatorConfig` config.
Returns:
The configuration.
"""
return cast(SparkStepOperatorConfig, self._config)
@property
def application_path(self) -> Optional[str]:
"""Optional method for providing the application path.
This is especially critical when using 'spark-submit' as it defines the
path (to the application in the environment where Spark is running)
which is used within the command.
For more information on how to set this property please check:
https://spark.apache.org/docs/latest/submitting-applications.html#advanced-dependency-management
Returns:
The path to the application entrypoint
"""
return None
def _resource_configuration(
self,
spark_config: SparkConf,
resource_settings: "ResourceSettings",
) -> None:
"""Configures Spark to handle the resource settings.
This should serve as the layer between our ResourceSettings
and Spark's own ways of configuring its resources.
Note: This is still work-in-progress. In the future, we would like to
enable much more than executor cores and memory with a dedicated
ResourceSettings object.
Args:
spark_config: a SparkConf object which collects all the
configuration parameters
resource_settings: the resource settings for this step
"""
if resource_settings.cpu_count:
spark_config.set(
"spark.executor.cores",
str(int(resource_settings.cpu_count)),
)
if resource_settings.memory:
# TODO[LOW]: Fix the conversion of the memory unit with a new
# type of resource configuration.
spark_config.set(
"spark.executor.memory",
resource_settings.memory.lower().strip("b"),
)
def _backend_configuration(
self,
spark_config: SparkConf,
step_config: "StepConfiguration",
) -> None:
"""Configures Spark to handle backends like YARN, Mesos or Kubernetes.
Args:
spark_config: a SparkConf object which collects all the
configuration parameters
step_config: Configuration of the step to run.
"""
def _io_configuration(self, spark_config: SparkConf) -> None:
"""Configures Spark to handle different input/output sources.
When you work with the Spark integration, you get materializers
such as SparkDataFrameMaterializer and SparkModelMaterializer. However,
in many cases, these materializers work only if the environment where
Spark is running is configured according to the artifact store.
Take S3 as an example. When you want to save a dataframe to an S3
artifact store, you need to provide configuration parameters such as
"spark.hadoop.fs.s3.impl=org.apache.hadoop.fs.s3a.S3AFileSystem" to
Spark. This method aims to provide these configuration parameters.
Args:
spark_config: a SparkConf object which collects all the
configuration parameters
Raises:
RuntimeError: when the step operator is being used with an S3
artifact store and the artifact store does not have the
required authentication
"""
# Get active artifact store
client = Client()
artifact_store = client.active_stack.artifact_store
from zenml.integrations.s3 import S3_ARTIFACT_STORE_FLAVOR
# If S3, preconfigure the spark session
if artifact_store.flavor == S3_ARTIFACT_STORE_FLAVOR:
(
key,
secret,
_,
) = artifact_store._get_credentials() # type:ignore[attr-defined]
if key and secret:
spark_config.setAll(
[
("spark.hadoop.fs.s3a.fast.upload", "true"),
(
"spark.hadoop.fs.s3.impl",
"org.apache.hadoop.fs.s3a.S3AFileSystem",
),
(
"spark.hadoop.fs.AbstractFileSystem.s3.impl",
"org.apache.hadoop.fs.s3a.S3A",
),
(
"spark.hadoop.fs.s3a.aws.credentials.provider",
"com.amazonaws.auth.DefaultAWSCredentialsProviderChain",
),
("spark.hadoop.fs.s3a.access.key", f"{key}"),
("spark.hadoop.fs.s3a.secret.key", f"{secret}"),
]
)
else:
raise RuntimeError(
"When you use a Spark step operator with an S3 artifact "
"store, please make sure that your artifact store has "
"defined the required credentials, namely the access key "
"and the secret access key."
)
else:
logger.warning(
"In most cases, the Spark step operator requires additional "
"configuration based on the artifact store flavor you are "
"using. That also means, that when you use this step operator "
"with certain artifact store flavor, ZenML can take care of "
"the pre-configuration. However, the artifact store flavor "
f"'{artifact_store.flavor}' featured in this stack is not "
f"known to this step operator and it might require additional "
f"configuration."
)
def _additional_configuration(self, spark_config: SparkConf) -> None:
"""Appends the user-defined configuration parameters.
Args:
spark_config: a SparkConf object which collects all the
configuration parameters
"""
# Add the additional parameters
if self.config.submit_kwargs:
for k, v in self.config.submit_kwargs.items():
spark_config.set(k, v)
def _launch_spark_job(
self, spark_config: SparkConf, entrypoint_command: List[str]
) -> None:
"""Generates and executes a spark-submit command.
Args:
spark_config: a SparkConf object which collects all the
configuration parameters
entrypoint_command: The entrypoint command to run.
Raises:
RuntimeError: if the spark-submit fails
"""
# Base spark-submit command
command = [
f"spark-submit "
f"--master {self.config.master} "
f"--deploy-mode {self.config.deploy_mode}"
]
# Add the configuration parameters
command += [f"--conf {c[0]}={c[1]}" for c in spark_config.getAll()]
# Add the application path
command.append(self.application_path) # type: ignore[arg-type]
# Update the default step operator command to use the spark entrypoint
# configuration
original_args = SparkEntrypointConfiguration._parse_arguments(
entrypoint_command
)
command += SparkEntrypointConfiguration.get_entrypoint_arguments(
**original_args
)
final_command = " ".join(command)
# Execute the spark-submit
process = subprocess.Popen(
final_command,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
universal_newlines=True,
shell=True,
)
stdout, stderr = process.communicate()
if process.returncode != 0:
raise RuntimeError(stderr)
print(stdout)
def launch(
self,
info: "StepRunInfo",
entrypoint_command: List[str],
) -> None:
"""Launches a step on Spark.
Args:
info: Information about the step run.
entrypoint_command: Command that executes the step.
"""
# Start off with an empty configuration
conf = SparkConf()
# Add the resource configuration such as cores, memory.
self._resource_configuration(
spark_config=conf,
resource_settings=info.config.resource_settings,
)
# Add the backend configuration such as namespace, docker images names.
self._backend_configuration(spark_config=conf, step_config=info.config)
# Add the IO configuration for the inputs and the outputs
self._io_configuration(
spark_config=conf,
)
# Add any additional configuration given by the user.
self._additional_configuration(
spark_config=conf,
)
# Generate a spark-submit command given the configuration
self._launch_spark_job(
spark_config=conf,
entrypoint_command=entrypoint_command,
)
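To make the resource mapping concrete, here is a small sketch of what _resource_configuration produces for a step that requests two cores and "4GB" of memory; the values are illustrative:
from pyspark import SparkConf

# Sketch of the ResourceSettings-to-SparkConf mapping: cpu_count feeds
# spark.executor.cores, and a memory string like "4GB" is lowercased and
# stripped of its trailing "b" to produce a Spark-style "4g".
conf = SparkConf()
conf.set("spark.executor.cores", str(int(2)))
conf.set("spark.executor.memory", "4GB".lower().strip("b"))
assert conf.get("spark.executor.memory") == "4g"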
application_path: Optional[str]
property
readonly
Optional method for providing the application path.
This is especially critical when using 'spark-submit' as it defines the path (to the application in the environment where Spark is running) which is used within the command.
For more information on how to set this property please check:
https://spark.apache.org/docs/latest/submitting-applications.html#advanced-dependency-management
Returns:
Type | Description
---|---
Optional[str] | The path to the application entrypoint.
config: SparkStepOperatorConfig
property
readonly
Returns the SparkStepOperatorConfig config.
Returns:
Type | Description
---|---
SparkStepOperatorConfig | The configuration.
launch(self, info, entrypoint_command)
Launches a step on Spark.
Parameters:
Name | Type | Description | Default
---|---|---|---
info | StepRunInfo | Information about the step run. | required
entrypoint_command | List[str] | Command that executes the step. | required
Source code in zenml/integrations/spark/step_operators/spark_step_operator.py
def launch(
self,
info: "StepRunInfo",
entrypoint_command: List[str],
) -> None:
"""Launches a step on Spark.
Args:
info: Information about the step run.
entrypoint_command: Command that executes the step.
"""
# Start off with an empty configuration
conf = SparkConf()
# Add the resource configuration such as cores, memory.
self._resource_configuration(
spark_config=conf,
resource_settings=info.config.resource_settings,
)
# Add the backend configuration such as namespace, docker images names.
self._backend_configuration(spark_config=conf, step_config=info.config)
# Add the IO configuration for the inputs and the outputs
self._io_configuration(
spark_config=conf,
)
# Add any additional configuration given by the user.
self._additional_configuration(
spark_config=conf,
)
# Generate a spark-submit command given the configuration
self._launch_spark_job(
spark_config=conf,
entrypoint_command=entrypoint_command,
)
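Putting it together, the spark-submit command assembled by _launch_spark_job has roughly the following shape; the master URL, configuration values, and application path are placeholders:
from pyspark import SparkConf

# Rough sketch of the command assembly in _launch_spark_job; the entrypoint
# arguments appended at the end are omitted here.
conf = SparkConf()
conf.set("spark.executor.memory", "4g")
command = [
    "spark-submit --master k8s://https://127.0.0.1:6443 --deploy-mode cluster"
]
command += [f"--conf {key}={value}" for key, value in conf.getAll()]
command.append("local:///app/entrypoint.py")  # placeholder application path
print(" ".join(command))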
tekton
special
Initialization of the Tekton integration for ZenML.
The Tekton integration sub-module powers an alternative to the local orchestrator. You can enable it by registering the Tekton orchestrator with the CLI tool.
TektonIntegration (Integration)
Definition of Tekton Integration for ZenML.
Source code in zenml/integrations/tekton/__init__.py
class TektonIntegration(Integration):
"""Definition of Tekton Integration for ZenML."""
NAME = TEKTON
REQUIREMENTS = ["kfp-tekton==1.3.1"]
@classmethod
def flavors(cls) -> List[Type[Flavor]]:
"""Declare the stack component flavors for the Tekton integration.
Returns:
List of stack component flavors for this integration.
"""
from zenml.integrations.tekton.flavors import TektonOrchestratorFlavor
return [TektonOrchestratorFlavor]
flavors()
classmethod
Declare the stack component flavors for the Tekton integration.
Returns:
Type | Description
---|---
List[Type[zenml.stack.flavor.Flavor]] | List of stack component flavors for this integration.
Source code in zenml/integrations/tekton/__init__.py
@classmethod
def flavors(cls) -> List[Type[Flavor]]:
"""Declare the stack component flavors for the Tekton integration.
Returns:
List of stack component flavors for this integration.
"""
from zenml.integrations.tekton.flavors import TektonOrchestratorFlavor
return [TektonOrchestratorFlavor]
flavors
special
Tekton integration flavors.
tekton_orchestrator_flavor
Tekton orchestrator flavor.
TektonOrchestratorConfig (BaseOrchestratorConfig)
pydantic-model
Configuration for the Tekton orchestrator.
Attributes:
Name | Type | Description
---|---|---
kubernetes_context | str | Name of a kubernetes context to run pipelines in.
kubernetes_namespace | str | Name of the kubernetes namespace in which the pods that run the pipeline steps should be running.
tekton_ui_port | int | A local port to which the Tekton UI will be forwarded.
skip_ui_daemon_provisioning | bool | If True, provisioning the Tekton UI daemon will be skipped.
Source code in zenml/integrations/tekton/flavors/tekton_orchestrator_flavor.py
class TektonOrchestratorConfig(BaseOrchestratorConfig):
"""Configuration for the Tekton orchestrator.
Attributes:
kubernetes_context: Name of a kubernetes context to run
pipelines in.
kubernetes_namespace: Name of the kubernetes namespace in which the
pods that run the pipeline steps should be running.
tekton_ui_port: A local port to which the Tekton UI will be forwarded.
skip_ui_daemon_provisioning: If `True`, provisioning the Tekton UI
daemon will be skipped.
"""
kubernetes_context: str
kubernetes_namespace: str = "zenml"
tekton_ui_port: int = DEFAULT_TEKTON_UI_PORT
skip_ui_daemon_provisioning: bool = False
@property
def is_remote(self) -> bool:
"""Checks if this stack component is running remotely.
This designation is used to determine if the stack component can be
used with a local ZenML database or if it requires a remote ZenML
server.
Returns:
True if this config is for a remote component, False otherwise.
"""
return True
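A minimal construction sketch; the context name is a placeholder:
# Minimal sketch: only kubernetes_context is required; the namespace, UI port
# and daemon provisioning flag fall back to their defaults.
config = TektonOrchestratorConfig(kubernetes_context="my-k8s-context")
assert config.kubernetes_namespace == "zenml"
assert config.is_remote  # always treated as a remote component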
is_remote: bool
property
readonly
Checks if this stack component is running remotely.
This designation is used to determine if the stack component can be used with a local ZenML database or if it requires a remote ZenML server.
Returns:
Type | Description
---|---
bool | True if this config is for a remote component, False otherwise.
TektonOrchestratorFlavor (BaseOrchestratorFlavor)
Flavor for the Tekton orchestrator.
Source code in zenml/integrations/tekton/flavors/tekton_orchestrator_flavor.py
class TektonOrchestratorFlavor(BaseOrchestratorFlavor):
"""Flavor for the Tekton orchestrator."""
@property
def name(self) -> str:
"""Name of the orchestrator flavor.
Returns:
Name of the orchestrator flavor.
"""
return TEKTON_ORCHESTRATOR_FLAVOR
@property
def config_class(self) -> Type[TektonOrchestratorConfig]:
"""Returns `TektonOrchestratorConfig` config class.
Returns:
The config class.
"""
return TektonOrchestratorConfig
@property
def implementation_class(self) -> Type["TektonOrchestrator"]:
"""Implementation class for this flavor.
Returns:
Implementation class for this flavor.
"""
from zenml.integrations.tekton.orchestrators import TektonOrchestrator
return TektonOrchestrator
config_class: Type[zenml.integrations.tekton.flavors.tekton_orchestrator_flavor.TektonOrchestratorConfig]
property
readonly
Returns TektonOrchestratorConfig config class.
Returns:
Type | Description
---|---
Type[zenml.integrations.tekton.flavors.tekton_orchestrator_flavor.TektonOrchestratorConfig] | The config class.
implementation_class: Type[TektonOrchestrator]
property
readonly
Implementation class for this flavor.
Returns:
Type | Description
---|---
Type[TektonOrchestrator] | Implementation class for this flavor.
name: str
property
readonly
Name of the orchestrator flavor.
Returns:
Type | Description
---|---
str | Name of the orchestrator flavor.
orchestrators
special
Initialization of the Tekton ZenML orchestrator.
tekton_entrypoint_configuration
Implementation of the Tekton entrypoint configuration.
TektonEntrypointConfiguration (StepEntrypointConfiguration)
Entrypoint configuration for running steps on Tekton.
Source code in zenml/integrations/tekton/orchestrators/tekton_entrypoint_configuration.py
class TektonEntrypointConfiguration(StepEntrypointConfiguration):
"""Entrypoint configuration for running steps on Tekton."""
@classmethod
def get_entrypoint_options(cls) -> Set[str]:
"""Gets all options required for running with this configuration.
Returns:
The superclass options as well as an option for the run name.
"""
return super().get_entrypoint_options() | {RUN_NAME_OPTION}
@classmethod
def get_entrypoint_arguments(cls, **kwargs: Any) -> List[str]:
"""Gets all arguments that the entrypoint command should be called with.
Args:
**kwargs: Kwargs, must include the run name.
Returns:
The superclass arguments as well as arguments for the run name.
"""
# Tekton replaces the `$(context.pipelineRun.name)` with the actual
# run name when executing a container. This allows users to re-trigger
# runs on the Tekton UI and uses the new run name for storing
# information in the metadata store.
return super().get_entrypoint_arguments(**kwargs) + [
f"--{RUN_NAME_OPTION}",
"$(context.pipelineRun.name)",
]
def get_run_name(self, pipeline_name: str) -> Optional[str]:
"""Returns the pipeline run name.
Args:
pipeline_name: The name of the pipeline.
Returns:
The pipeline run name passed as argument to the entrypoint.
"""
return cast(str, self.entrypoint_args[RUN_NAME_OPTION])
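For illustration, assuming the RUN_NAME_OPTION constant resolves to the string "run_name" (an assumption, not confirmed by this listing), the extra arguments appended for the run name look like this:
# Hypothetical sketch of the two arguments appended by
# get_entrypoint_arguments, assuming RUN_NAME_OPTION == "run_name".
# Tekton substitutes $(context.pipelineRun.name) with the actual run name
# when it executes the container.
run_name_args = [
    "--run_name",
    "$(context.pipelineRun.name)",
]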
get_entrypoint_arguments(**kwargs)
classmethod
Gets all arguments that the entrypoint command should be called with.
Parameters:
Name | Type | Description | Default
---|---|---|---
**kwargs | Any | Kwargs, must include the run name. | {}
Returns:
Type | Description
---|---
List[str] | The superclass arguments as well as arguments for the run name.
Source code in zenml/integrations/tekton/orchestrators/tekton_entrypoint_configuration.py
@classmethod
def get_entrypoint_arguments(cls, **kwargs: Any) -> List[str]:
"""Gets all arguments that the entrypoint command should be called with.
Args:
**kwargs: Kwargs, must include the run name.
Returns:
The superclass arguments as well as arguments for the run name.
"""
# Tekton replaces the `$(context.pipelineRun.name)` with the actual
# run name when executing a container. This allows users to re-trigger
# runs on the Tekton UI and uses the new run name for storing
# information in the metadata store.
return super().get_entrypoint_arguments(**kwargs) + [
f"--{RUN_NAME_OPTION}",
"$(context.pipelineRun.name)",
]
get_entrypoint_options()
classmethod
Gets all options required for running with this configuration.
Returns:
Type | Description
---|---
Set[str] | The superclass options as well as an option for the run name.
Source code in zenml/integrations/tekton/orchestrators/tekton_entrypoint_configuration.py
@classmethod
def get_entrypoint_options(cls) -> Set[str]:
"""Gets all options required for running with this configuration.
Returns:
The superclass options as well as an option for the run name.
"""
return super().get_entrypoint_options() | {RUN_NAME_OPTION}
get_run_name(self, pipeline_name)
Returns the pipeline run name.
Parameters:
Name | Type | Description | Default
---|---|---|---
pipeline_name | str | The name of the pipeline. | required
Returns:
Type | Description
---|---
Optional[str] | The pipeline run name passed as argument to the entrypoint.
Source code in zenml/integrations/tekton/orchestrators/tekton_entrypoint_configuration.py
def get_run_name(self, pipeline_name: str) -> Optional[str]:
"""Returns the pipeline run name.
Args:
pipeline_name: The name of the pipeline.
Returns:
The pipeline run name passed as argument to the entrypoint.
"""
return cast(str, self.entrypoint_args[RUN_NAME_OPTION])
tekton_orchestrator
Implementation of the Tekton orchestrator.
TektonOrchestrator (BaseOrchestrator)
Orchestrator responsible for running pipelines using Tekton.
Source code in zenml/integrations/tekton/orchestrators/tekton_orchestrator.py
class TektonOrchestrator(BaseOrchestrator):
"""Orchestrator responsible for running pipelines using Tekton."""
@property
def config(self) -> TektonOrchestratorConfig:
"""Returns the `TektonOrchestratorConfig` config.
Returns:
The configuration.
"""
return cast(TektonOrchestratorConfig, self._config)
def get_kubernetes_contexts(self) -> Tuple[List[str], Optional[str]]:
"""Get the list of configured Kubernetes contexts and the active context.
Returns:
A tuple containing the list of configured Kubernetes contexts and
the active context.
"""
try:
contexts, active_context = k8s_config.list_kube_config_contexts()
except k8s_config.config_exception.ConfigException:
return [], None
context_names = [c["name"] for c in contexts]
active_context_name = active_context["name"]
return context_names, active_context_name
@property
def validator(self) -> Optional[StackValidator]:
"""Ensures a stack with only remote components and a container registry.
Returns:
A `StackValidator` instance.
"""
def _validate(stack: "Stack") -> Tuple[bool, str]:
container_registry = stack.container_registry
# should not happen, because the stack validation takes care of
# this, but just in case
assert container_registry is not None
contexts, _ = self.get_kubernetes_contexts()
if self.config.kubernetes_context not in contexts:
return False, (
f"Could not find a Kubernetes context named "
f"'{self.config.kubernetes_context}' in the local "
f"Kubernetes configuration. Please make sure that the "
f"Kubernetes cluster is running and that the kubeconfig "
f"file is configured correctly. To list all configured "
f"contexts, run:\n\n"
f" `kubectl config get-contexts`\n"
)
# go through all stack components and identify those that
# advertise a local path where they persist information that
# they need to be available when running pipelines.
for stack_component in stack.components.values():
local_path = stack_component.local_path
if local_path is None:
continue
return False, (
f"The Tekton orchestrator is configured to run "
f"pipelines in a remote Kubernetes cluster designated "
f"by the '{self.config.kubernetes_context}' configuration "
f"context, but the '{stack_component.name}' "
f"{stack_component.type.value} is a local stack component "
f"and will not be available in the Tekton pipeline "
f"step.\nPlease ensure that you always use non-local "
f"stack components with a Tekton orchestrator, "
f"otherwise you may run into pipeline execution "
f"problems. You should use a flavor of "
f"{stack_component.type.value} other than "
f"'{stack_component.flavor}'."
)
if container_registry.config.is_local:
return False, (
f"The Tekton orchestrator is configured to run "
f"pipelines in a remote Kubernetes cluster designated "
f"by the '{self.config.kubernetes_context}' configuration "
f"context, but the '{container_registry.name}' "
f"container registry URI '{container_registry.config.uri}' "
f"points to a local container registry. Please ensure "
f"that you always use non-local stack components with "
f"a Tekton orchestrator, otherwise you will "
f"run into problems. You should use a flavor of "
f"container registry other than "
f"'{container_registry.flavor}'."
)
return True, ""
return StackValidator(
required_components={StackComponentType.CONTAINER_REGISTRY},
custom_validation_function=_validate,
)
def prepare_pipeline_deployment(
self,
deployment: "PipelineDeployment",
stack: "Stack",
) -> None:
"""Build a Docker image and push it to the container registry.
Args:
deployment: The pipeline deployment configuration.
stack: The stack on which the pipeline will be deployed.
"""
docker_image_builder = PipelineDockerImageBuilder()
repo_digest = docker_image_builder.build_and_push_docker_image(
deployment=deployment, stack=stack
)
deployment.add_extra(ORCHESTRATOR_DOCKER_IMAGE_KEY, repo_digest)
@staticmethod
def _configure_container_resources(
container_op: dsl.ContainerOp,
resource_settings: "ResourceSettings",
) -> None:
"""Adds resource requirements to the container.
Args:
container_op: The container operation to configure.
resource_settings: The resource settings to use for this
container.
"""
if resource_settings.cpu_count is not None:
container_op = container_op.set_cpu_limit(
str(resource_settings.cpu_count)
)
if resource_settings.gpu_count is not None:
container_op = container_op.set_gpu_limit(
resource_settings.gpu_count
)
if resource_settings.memory is not None:
memory_limit = resource_settings.memory[:-1]
container_op = container_op.set_memory_limit(memory_limit)
def prepare_or_run_pipeline(
self,
deployment: "PipelineDeployment",
stack: "Stack",
) -> Any:
"""Runs the pipeline on Tekton.
This function first compiles the ZenML pipeline into a Tekton yaml
and then applies this configuration to run the pipeline.
Args:
deployment: The pipeline deployment to prepare or run.
stack: The stack the pipeline will run on.
Raises:
RuntimeError: If you try to run the pipelines in a notebook environment.
"""
# First check whether the code is running in a notebook
if Environment.in_notebook():
raise RuntimeError(
"The Tekton orchestrator cannot run pipelines in a notebook "
"environment. The reason is that it is non-trivial to create "
"a Docker image of a notebook. Please consider refactoring "
"your notebook cells into separate scripts in a Python module "
"and run the code outside of a notebook when using this "
"orchestrator."
)
image_name = deployment.pipeline.extra[ORCHESTRATOR_DOCKER_IMAGE_KEY]
def _construct_kfp_pipeline() -> None:
"""Create a container_op for each step.
Each container_op contains the name of the docker image and
configures the entrypoint of the docker image to run the step.
Additionally, this gives each container_op information about its
direct upstream steps.
"""
# Dictionary of container_ops indexed by the associated step name
step_name_to_container_op: Dict[str, dsl.ContainerOp] = {}
for step_name, step in deployment.steps.items():
command = TektonEntrypointConfiguration.get_entrypoint_command()
arguments = (
TektonEntrypointConfiguration.get_entrypoint_arguments(
step_name=step_name,
)
)
container_op = dsl.ContainerOp(
name=step.config.name,
image=image_name,
command=command,
arguments=arguments,
)
if self.requires_resources_in_orchestration_environment(step):
self._configure_container_resources(
container_op=container_op,
resource_settings=step.config.resource_settings,
)
# Find the upstream container ops of the current step and
# configure the current container op to run after them
for upstream_step_name in step.spec.upstream_steps:
upstream_container_op = step_name_to_container_op[
upstream_step_name
]
container_op.after(upstream_container_op)
# Update dictionary of container ops with the current one
step_name_to_container_op[step.config.name] = container_op
# Get a filepath to use to save the finished yaml to
fileio.makedirs(self.pipeline_directory)
pipeline_file_path = os.path.join(
self.pipeline_directory, f"{deployment.run_name}.yaml"
)
# Set the run name, which Tekton reads from this attribute of the
# pipeline function
setattr(
_construct_kfp_pipeline,
"_component_human_name",
deployment.run_name,
)
TektonCompiler().compile(_construct_kfp_pipeline, pipeline_file_path)
if deployment.schedule:
logger.warning(
"The Tekton Orchestrator currently does not support the "
"use of schedules. The `schedule` will be ignored "
"and the pipeline will be run immediately."
)
logger.info(
"Running Tekton pipeline in kubernetes context '%s' and namespace "
"'%s'.",
self.config.kubernetes_context,
self.config.kubernetes_namespace,
)
try:
subprocess.check_call(
[
"kubectl",
"--context",
self.config.kubernetes_context,
"--namespace",
self.config.kubernetes_namespace,
"apply",
"-f",
pipeline_file_path,
]
)
except subprocess.CalledProcessError as e:
raise RuntimeError(
f"Failed to upload Tekton pipeline: {str(e)}. "
f"Please make sure your kubernetes config is present and the "
f"{self.config.kubernetes_context} kubernetes context is "
f"configured correctly.",
)
@property
def root_directory(self) -> str:
"""Returns path to the root directory for all files concerning this orchestrator.
Returns:
Path to the root directory.
"""
return os.path.join(
io_utils.get_global_config_directory(),
"tekton",
str(self.id),
)
@property
def pipeline_directory(self) -> str:
"""Path to a directory in which the Tekton pipeline files are stored.
Returns:
Path to the pipeline directory.
"""
return os.path.join(self.root_directory, "pipelines")
@property
def _pid_file_path(self) -> str:
"""Returns path to the daemon PID file.
Returns:
Path to the daemon PID file.
"""
return os.path.join(self.root_directory, "tekton_daemon.pid")
@property
def log_file(self) -> str:
"""Path of the daemon log file.
Returns:
Path of the daemon log file.
"""
return os.path.join(self.root_directory, "tekton_daemon.log")
@property
def is_provisioned(self) -> bool:
"""Returns if a local k3d cluster for this orchestrator exists.
Returns:
True if a local k3d cluster exists, False otherwise.
"""
return fileio.exists(self.root_directory)
@property
def is_running(self) -> bool:
"""Checks if the local UI daemon is running.
Returns:
True if the local UI daemon for this orchestrator is running.
"""
if self.config.skip_ui_daemon_provisioning:
return True
if sys.platform != "win32":
from zenml.utils.daemon import check_if_daemon_is_running
return check_if_daemon_is_running(self._pid_file_path)
else:
return True
def provision(self) -> None:
"""Provisions resources for the orchestrator."""
fileio.makedirs(self.root_directory)
def deprovision(self) -> None:
"""Deprovisions the orchestrator resources."""
if self.is_running:
self.suspend()
if fileio.exists(self.log_file):
fileio.remove(self.log_file)
def resume(self) -> None:
"""Starts the UI forwarding daemon if necessary."""
if self.is_running:
logger.info("Tekton UI forwarding is already running.")
return
self.start_ui_daemon()
def suspend(self) -> None:
"""Stops the UI forwarding daemon if it's running."""
if not self.is_running:
logger.info("Tekton UI forwarding not running.")
return
self.stop_ui_daemon()
def start_ui_daemon(self) -> None:
"""Starts the UI forwarding daemon if possible."""
port = self.config.tekton_ui_port
if (
port == DEFAULT_TEKTON_UI_PORT
and not networking_utils.port_available(port)
):
# if the user didn't specify a specific port and the default
# port is occupied, fallback to a random open port
port = networking_utils.find_available_port()
command = [
"kubectl",
"--context",
self.config.kubernetes_context,
"--namespace",
"tekton-pipelines",
"port-forward",
"svc/tekton-dashboard",
f"{port}:9097",
]
if not networking_utils.port_available(port):
modified_command = command.copy()
modified_command[-1] = "<PORT>:9097"
logger.warning(
"Unable to port-forward Tekton UI to local port %d "
"because the port is occupied. In order to access the Tekton "
"UI at http://localhost:<PORT>/, please run '%s' in a "
"separate command line shell (replace <PORT> with a free port "
"of your choice).",
port,
" ".join(modified_command),
)
elif sys.platform == "win32":
logger.warning(
"Daemon functionality not supported on Windows. "
"In order to access the Tekton UI at "
"http://localhost:%d/, please run '%s' in a separate command "
"line shell.",
port,
" ".join(command),
)
else:
from zenml.utils import daemon
def _daemon_function() -> None:
"""Port-forwards the Tekton UI pod."""
subprocess.check_call(command)
daemon.run_as_daemon(
_daemon_function,
pid_file=self._pid_file_path,
log_file=self.log_file,
)
logger.info(
"Started Tekton UI daemon (check the daemon logs at %s "
"in case you're not able to view the UI). The Tekton "
"UI should now be accessible at http://localhost:%d/.",
self.log_file,
port,
)
def stop_ui_daemon(self) -> None:
"""Stops the UI forwarding daemon if it's running."""
if fileio.exists(self._pid_file_path):
if sys.platform == "win32":
# Daemon functionality is not supported on Windows, so the PID
# file won't exist. This if clause exists just for mypy to not
# complain about missing functions
pass
else:
from zenml.utils import daemon
daemon.stop_daemon(self._pid_file_path)
fileio.remove(self._pid_file_path)
logger.info("Stopped Tektion UI daemon.")
config: TektonOrchestratorConfig
property
readonly
Returns the TektonOrchestratorConfig config.
Returns:
Type | Description
---|---
TektonOrchestratorConfig | The configuration.
is_provisioned: bool
property
readonly
Returns if a local k3d cluster for this orchestrator exists.
Returns:
Type | Description
---|---
bool | True if a local k3d cluster exists, False otherwise.
is_running: bool
property
readonly
Checks if the local UI daemon is running.
Returns:
Type | Description
---|---
bool | True if the local UI daemon for this orchestrator is running.
log_file: str
property
readonly
Path of the daemon log file.
Returns:
Type | Description
---|---
str | Path of the daemon log file.
pipeline_directory: str
property
readonly
Path to a directory in which the Tekton pipeline files are stored.
Returns:
Type | Description
---|---
str | Path to the pipeline directory.
root_directory: str
property
readonly
Returns path to the root directory for all files concerning this orchestrator.
Returns:
Type | Description
---|---
str | Path to the root directory.
validator: Optional[zenml.stack.stack_validator.StackValidator]
property
readonly
Ensures a stack with only remote components and a container registry.
Returns:
Type | Description
---|---
Optional[zenml.stack.stack_validator.StackValidator] | A StackValidator instance.
deprovision(self)
Deprovisions the orchestrator resources.
Source code in zenml/integrations/tekton/orchestrators/tekton_orchestrator.py
def deprovision(self) -> None:
"""Deprovisions the orchestrator resources."""
if self.is_running:
self.suspend()
if fileio.exists(self.log_file):
fileio.remove(self.log_file)
get_kubernetes_contexts(self)
Get the list of configured Kubernetes contexts and the active context.
Returns:
Type | Description
---|---
Tuple[List[str], Optional[str]] | A tuple containing the list of configured Kubernetes contexts and the active context.
Source code in zenml/integrations/tekton/orchestrators/tekton_orchestrator.py
def get_kubernetes_contexts(self) -> Tuple[List[str], Optional[str]]:
"""Get the list of configured Kubernetes contexts and the active context.
Returns:
A tuple containing the list of configured Kubernetes contexts and
the active context.
"""
try:
contexts, active_context = k8s_config.list_kube_config_contexts()
except k8s_config.config_exception.ConfigException:
return [], None
context_names = [c["name"] for c in contexts]
active_context_name = active_context["name"]
return context_names, active_context_name
prepare_or_run_pipeline(self, deployment, stack)
Runs the pipeline on Tekton.
This function first compiles the ZenML pipeline into a Tekton yaml and then applies this configuration to run the pipeline.
Parameters:
Name | Type | Description | Default
---|---|---|---
deployment | PipelineDeployment | The pipeline deployment to prepare or run. | required
stack | Stack | The stack the pipeline will run on. | required
Exceptions:
Type | Description
---|---
RuntimeError | If you try to run the pipelines in a notebook environment.
Source code in zenml/integrations/tekton/orchestrators/tekton_orchestrator.py
def prepare_or_run_pipeline(
self,
deployment: "PipelineDeployment",
stack: "Stack",
) -> Any:
"""Runs the pipeline on Tekton.
This function first compiles the ZenML pipeline into a Tekton yaml
and then applies this configuration to run the pipeline.
Args:
deployment: The pipeline deployment to prepare or run.
stack: The stack the pipeline will run on.
Raises:
RuntimeError: If you try to run the pipelines in a notebook environment.
"""
# First check whether the code is running in a notebook
if Environment.in_notebook():
raise RuntimeError(
"The Tekton orchestrator cannot run pipelines in a notebook "
"environment. The reason is that it is non-trivial to create "
"a Docker image of a notebook. Please consider refactoring "
"your notebook cells into separate scripts in a Python module "
"and run the code outside of a notebook when using this "
"orchestrator."
)
image_name = deployment.pipeline.extra[ORCHESTRATOR_DOCKER_IMAGE_KEY]
def _construct_kfp_pipeline() -> None:
"""Create a container_op for each step.
Each container_op contains the name of the docker image and
configures the entrypoint of the docker image to run the step.
Additionally, this gives each container_op information about its
direct upstream steps.
"""
# Dictionary of container_ops indexed by the associated step name
step_name_to_container_op: Dict[str, dsl.ContainerOp] = {}
for step_name, step in deployment.steps.items():
command = TektonEntrypointConfiguration.get_entrypoint_command()
arguments = (
TektonEntrypointConfiguration.get_entrypoint_arguments(
step_name=step_name,
)
)
container_op = dsl.ContainerOp(
name=step.config.name,
image=image_name,
command=command,
arguments=arguments,
)
if self.requires_resources_in_orchestration_environment(step):
self._configure_container_resources(
container_op=container_op,
resource_settings=step.config.resource_settings,
)
# Find the upstream container ops of the current step and
# configure the current container op to run after them
for upstream_step_name in step.spec.upstream_steps:
upstream_container_op = step_name_to_container_op[
upstream_step_name
]
container_op.after(upstream_container_op)
# Update dictionary of container ops with the current one
step_name_to_container_op[step.config.name] = container_op
# Get a filepath to use to save the finished yaml to
fileio.makedirs(self.pipeline_directory)
pipeline_file_path = os.path.join(
self.pipeline_directory, f"{deployment.run_name}.yaml"
)
# Set the run name, which Tekton reads from this attribute of the
# pipeline function
setattr(
_construct_kfp_pipeline,
"_component_human_name",
deployment.run_name,
)
TektonCompiler().compile(_construct_kfp_pipeline, pipeline_file_path)
if deployment.schedule:
logger.warning(
"The Tekton Orchestrator currently does not support the "
"use of schedules. The `schedule` will be ignored "
"and the pipeline will be run immediately."
)
logger.info(
"Running Tekton pipeline in kubernetes context '%s' and namespace "
"'%s'.",
self.config.kubernetes_context,
self.config.kubernetes_namespace,
)
try:
subprocess.check_call(
[
"kubectl",
"--context",
self.config.kubernetes_context,
"--namespace",
self.config.kubernetes_namespace,
"apply",
"-f",
pipeline_file_path,
]
)
except subprocess.CalledProcessError as e:
raise RuntimeError(
f"Failed to upload Tekton pipeline: {str(e)}. "
f"Please make sure your kubernetes config is present and the "
f"{self.config.kubernetes_context} kubernetes context is "
f"configured correctly.",
)
prepare_pipeline_deployment(self, deployment, stack)
Build a Docker image and push it to the container registry.
Parameters:
Name | Type | Description | Default
---|---|---|---
deployment | PipelineDeployment | The pipeline deployment configuration. | required
stack | Stack | The stack on which the pipeline will be deployed. | required
Source code in zenml/integrations/tekton/orchestrators/tekton_orchestrator.py
def prepare_pipeline_deployment(
self,
deployment: "PipelineDeployment",
stack: "Stack",
) -> None:
"""Build a Docker image and push it to the container registry.
Args:
deployment: The pipeline deployment configuration.
stack: The stack on which the pipeline will be deployed.
"""
docker_image_builder = PipelineDockerImageBuilder()
repo_digest = docker_image_builder.build_and_push_docker_image(
deployment=deployment, stack=stack
)
deployment.add_extra(ORCHESTRATOR_DOCKER_IMAGE_KEY, repo_digest)
provision(self)
Provisions resources for the orchestrator.
Source code in zenml/integrations/tekton/orchestrators/tekton_orchestrator.py
def provision(self) -> None:
"""Provisions resources for the orchestrator."""
fileio.makedirs(self.root_directory)
resume(self)
Starts the UI forwarding daemon if necessary.
Source code in zenml/integrations/tekton/orchestrators/tekton_orchestrator.py
def resume(self) -> None:
"""Starts the UI forwarding daemon if necessary."""
if self.is_running:
logger.info("Tekton UI forwarding is already running.")
return
self.start_ui_daemon()
start_ui_daemon(self)
Starts the UI forwarding daemon if possible.
Source code in zenml/integrations/tekton/orchestrators/tekton_orchestrator.py
def start_ui_daemon(self) -> None:
"""Starts the UI forwarding daemon if possible."""
port = self.config.tekton_ui_port
if (
port == DEFAULT_TEKTON_UI_PORT
and not networking_utils.port_available(port)
):
# if the user didn't specify a specific port and the default
# port is occupied, fallback to a random open port
port = networking_utils.find_available_port()
command = [
"kubectl",
"--context",
self.config.kubernetes_context,
"--namespace",
"tekton-pipelines",
"port-forward",
"svc/tekton-dashboard",
f"{port}:9097",
]
if not networking_utils.port_available(port):
modified_command = command.copy()
modified_command[-1] = "<PORT>:9097"
logger.warning(
"Unable to port-forward Tekton UI to local port %d "
"because the port is occupied. In order to access the Tekton "
"UI at http://localhost:<PORT>/, please run '%s' in a "
"separate command line shell (replace <PORT> with a free port "
"of your choice).",
port,
" ".join(modified_command),
)
elif sys.platform == "win32":
logger.warning(
"Daemon functionality not supported on Windows. "
"In order to access the Tekton UI at "
"http://localhost:%d/, please run '%s' in a separate command "
"line shell.",
port,
" ".join(command),
)
else:
from zenml.utils import daemon
def _daemon_function() -> None:
"""Port-forwards the Tekton UI pod."""
subprocess.check_call(command)
daemon.run_as_daemon(
_daemon_function,
pid_file=self._pid_file_path,
log_file=self.log_file,
)
logger.info(
"Started Tekton UI daemon (check the daemon logs at %s "
"in case you're not able to view the UI). The Tekton "
"UI should now be accessible at http://localhost:%d/.",
self.log_file,
port,
)
stop_ui_daemon(self)
Stops the UI forwarding daemon if it's running.
Source code in zenml/integrations/tekton/orchestrators/tekton_orchestrator.py
def stop_ui_daemon(self) -> None:
"""Stops the UI forwarding daemon if it's running."""
if fileio.exists(self._pid_file_path):
if sys.platform == "win32":
# Daemon functionality is not supported on Windows, so the PID
# file won't exist. This if clause exists just for mypy to not
# complain about missing functions
pass
else:
from zenml.utils import daemon
daemon.stop_daemon(self._pid_file_path)
fileio.remove(self._pid_file_path)
logger.info("Stopped Tektion UI daemon.")
suspend(self)
Stops the UI forwarding daemon if it's running.
Source code in zenml/integrations/tekton/orchestrators/tekton_orchestrator.py
def suspend(self) -> None:
"""Stops the UI forwarding daemon if it's running."""
if not self.is_running:
logger.info("Tekton UI forwarding not running.")
return
self.stop_ui_daemon()
tensorboard
special
Initialization for TensorBoard integration.
TensorBoardIntegration (Integration)
Definition of TensorBoard integration for ZenML.
Source code in zenml/integrations/tensorboard/__init__.py
class TensorBoardIntegration(Integration):
"""Definition of TensorBoard integration for ZenML."""
NAME = TENSORBOARD
REQUIREMENTS = ["tensorboard==2.8.0"]
@classmethod
def activate(cls) -> None:
"""Activates the integration."""
from zenml.integrations.tensorboard import services # noqa
activate()
classmethod
Activates the integration.
Source code in zenml/integrations/tensorboard/__init__.py
@classmethod
def activate(cls) -> None:
"""Activates the integration."""
from zenml.integrations.tensorboard import services # noqa
services
special
Initialization for TensorBoard services.
tensorboard_service
Implementation of the TensorBoard service.
TensorboardService (LocalDaemonService)
pydantic-model
TensorBoard service.
This can be used to start a local TensorBoard server for one or more models.
Attributes:
Name | Type | Description |
---|---|---|
SERVICE_TYPE |
ClassVar[zenml.services.service_type.ServiceType] |
a service type descriptor with information describing the TensorBoard service class |
config |
TensorboardServiceConfig |
service configuration |
endpoint |
LocalDaemonServiceEndpoint |
optional service endpoint |
Source code in zenml/integrations/tensorboard/services/tensorboard_service.py
class TensorboardService(LocalDaemonService):
"""TensorBoard service.
This can be used to start a local TensorBoard server for one or more models.
Attributes:
SERVICE_TYPE: a service type descriptor with information describing
the TensorBoard service class
config: service configuration
endpoint: optional service endpoint
"""
SERVICE_TYPE = ServiceType(
name="tensorboard",
type="visualization",
flavor="tensorboard",
description="TensorBoard visualization service",
)
config: TensorboardServiceConfig
endpoint: LocalDaemonServiceEndpoint
def __init__(
self,
config: Union[TensorboardServiceConfig, Dict[str, Any]],
**attrs: Any,
) -> None:
"""Initialization for TensorBoard service.
Args:
config: service configuration
**attrs: additional attributes
"""
# ensure that the endpoint is created before the service is initialized
# TODO [ENG-697]: implement a service factory or builder for TensorBoard
# deployment services
if (
isinstance(config, TensorboardServiceConfig)
and "endpoint" not in attrs
):
endpoint = LocalDaemonServiceEndpoint(
config=LocalDaemonServiceEndpointConfig(
protocol=ServiceEndpointProtocol.HTTP,
),
monitor=HTTPEndpointHealthMonitor(
config=HTTPEndpointHealthMonitorConfig(
healthcheck_uri_path="",
use_head_request=True,
)
),
)
attrs["endpoint"] = endpoint
super().__init__(config=config, **attrs)
def run(self) -> None:
"""Initialize and run the TensorBoard server."""
logger.info(
"Starting TensorBoard service as blocking "
"process... press CTRL+C once to stop it."
)
self.endpoint.prepare_for_start()
try:
tensorboard = program.TensorBoard(
plugins=default.get_plugins(),
subcommands=[uploader_subcommand.UploaderSubcommand()],
)
tensorboard.configure(
logdir=self.config.logdir,
port=self.endpoint.status.port,
host="localhost",
max_reload_threads=self.config.max_reload_threads,
reload_interval=self.config.reload_interval,
)
tensorboard.main()
except KeyboardInterrupt:
logger.info(
"TensorBoard service stopped. Resuming normal execution."
)
__init__(self, config, **attrs)
special
Initialization for TensorBoard service.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
config |
Union[zenml.integrations.tensorboard.services.tensorboard_service.TensorboardServiceConfig, Dict[str, Any]] |
service configuration |
required |
**attrs |
Any |
additional attributes |
{} |
Source code in zenml/integrations/tensorboard/services/tensorboard_service.py
def __init__(
self,
config: Union[TensorboardServiceConfig, Dict[str, Any]],
**attrs: Any,
) -> None:
"""Initialization for TensorBoard service.
Args:
config: service configuration
**attrs: additional attributes
"""
# ensure that the endpoint is created before the service is initialized
# TODO [ENG-697]: implement a service factory or builder for TensorBoard
# deployment services
if (
isinstance(config, TensorboardServiceConfig)
and "endpoint" not in attrs
):
endpoint = LocalDaemonServiceEndpoint(
config=LocalDaemonServiceEndpointConfig(
protocol=ServiceEndpointProtocol.HTTP,
),
monitor=HTTPEndpointHealthMonitor(
config=HTTPEndpointHealthMonitorConfig(
healthcheck_uri_path="",
use_head_request=True,
)
),
)
attrs["endpoint"] = endpoint
super().__init__(config=config, **attrs)
run(self)
Initialize and run the TensorBoard server.
Source code in zenml/integrations/tensorboard/services/tensorboard_service.py
def run(self) -> None:
"""Initialize and run the TensorBoard server."""
logger.info(
"Starting TensorBoard service as blocking "
"process... press CTRL+C once to stop it."
)
self.endpoint.prepare_for_start()
try:
tensorboard = program.TensorBoard(
plugins=default.get_plugins(),
subcommands=[uploader_subcommand.UploaderSubcommand()],
)
tensorboard.configure(
logdir=self.config.logdir,
port=self.endpoint.status.port,
host="localhost",
max_reload_threads=self.config.max_reload_threads,
reload_interval=self.config.reload_interval,
)
tensorboard.main()
except KeyboardInterrupt:
logger.info(
"TensorBoard service stopped. Resuming normal execution."
)
TensorboardServiceConfig (LocalDaemonServiceConfig)
pydantic-model
TensorBoard service configuration.
Attributes:
Name | Type | Description |
---|---|---|
logdir |
str |
location of TensorBoard log files. |
max_reload_threads |
int |
the max number of threads that TensorBoard can use to reload runs. Each thread reloads one run at a time. |
reload_interval |
int |
how often the backend should load more data, in seconds. Set to 0 to load just once at startup. |
Source code in zenml/integrations/tensorboard/services/tensorboard_service.py
class TensorboardServiceConfig(LocalDaemonServiceConfig):
"""TensorBoard service configuration.
Attributes:
logdir: location of TensorBoard log files.
max_reload_threads: the max number of threads that TensorBoard can use
to reload runs. Each thread reloads one run at a time.
reload_interval: how often the backend should load more data, in
seconds. Set to 0 to load just once at startup.
"""
logdir: str
max_reload_threads: int = 1
reload_interval: int = 5
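A minimal sketch of wiring the configuration and service together; the `logdir` value and the other numbers are hypothetical:
from zenml.integrations.tensorboard.services.tensorboard_service import (
    TensorboardService,
    TensorboardServiceConfig,
)

config = TensorboardServiceConfig(
    logdir="/tmp/zenml/tensorboard_logs",  # hypothetical log directory
    max_reload_threads=2,
    reload_interval=10,
)
service = TensorboardService(config)
service.start(timeout=20)  # start the daemon, as the visualizer below also does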
visualizers
special
Initialization for TensorBoard visualizer.
tensorboard_visualizer
Implementation of a TensorBoard visualizer step.
TensorboardVisualizer (BaseVisualizer)
The implementation of a TensorBoard Visualizer.
Source code in zenml/integrations/tensorboard/visualizers/tensorboard_visualizer.py
class TensorboardVisualizer(BaseVisualizer):
"""The implementation of a TensorBoard Visualizer."""
@classmethod
def find_running_tensorboard_server(
cls, logdir: str
) -> Optional[TensorBoardInfo]:
"""Find a local TensorBoard server instance.
Finds whether it is running for the supplied logdir location and returns its
TCP port.
Args:
logdir: The logdir location where the TensorBoard server is running.
Returns:
The TensorBoardInfo describing the running TensorBoard server or
None if no server is running for the supplied logdir location.
"""
for server in get_all():
if (
server.logdir == logdir
and server.pid
and psutil.pid_exists(server.pid)
):
return server
return None
def visualize(
self,
object: StepView,
height: int = 800,
*args: Any,
**kwargs: Any,
) -> None:
"""Start a TensorBoard server.
Allows for the visualization of all models logged as artifacts by the
indicated step. The server will monitor and display all the models
logged by past and future step runs.
Args:
object: StepView fetched from run.get_step().
height: Height of the generated visualization.
*args: Additional arguments.
**kwargs: Additional keyword arguments.
"""
for _, artifact_view in object.outputs.items():
# filter out anything but model artifacts
if artifact_view.type == ModelArtifact.TYPE_NAME:
logdir = os.path.dirname(artifact_view.uri)
# first check if a TensorBoard server is already running for
# the same logdir location and use that one
running_server = self.find_running_tensorboard_server(logdir)
if running_server:
self.visualize_tensorboard(running_server.port, height)
return
if sys.platform == "win32":
# Daemon service functionality is currently not supported on Windows
print(
"You can run:\n"
f"[italic green] tensorboard --logdir {logdir}"
"[/italic green]\n"
"...to visualize the TensorBoard logs for your trained model."
)
else:
# start a new TensorBoard server
service = TensorboardService(
TensorboardServiceConfig(
logdir=logdir,
)
)
service.start(timeout=20)
if service.endpoint.status.port:
self.visualize_tensorboard(
service.endpoint.status.port, height
)
return
def visualize_tensorboard(
self,
port: int,
height: int,
) -> None:
"""Generate a visualization of a TensorBoard.
Args:
port: the TCP port where the TensorBoard server is listening for
requests.
height: Height of the generated visualization.
"""
if Environment.in_notebook():
notebook.display(port, height=height)
return
print(
"You can visit:\n"
f"[italic green] http://localhost:{port}/[/italic green]\n"
"...to visualize the TensorBoard logs for your trained model."
)
def stop(
self,
object: StepView,
) -> None:
"""Stop the TensorBoard server previously started for a pipeline step.
Args:
object: StepView fetched from run.get_step().
"""
for _, artifact_view in object.outputs.items():
# filter out anything but model artifacts
if artifact_view.type == ModelArtifact.TYPE_NAME:
logdir = os.path.dirname(artifact_view.uri)
# first check if a TensorBoard server is already running for
# the same logdir location and use that one
running_server = self.find_running_tensorboard_server(logdir)
if not running_server:
return
logger.debug(
"Stopping tensorboard server with PID '%d' ...",
running_server.pid,
)
try:
p = psutil.Process(running_server.pid)
except psutil.Error:
logger.error(
"Could not find process for PID '%d' ...",
running_server.pid,
)
continue
p.kill()
return
find_running_tensorboard_server(logdir)
classmethod
Find a local TensorBoard server instance.
Finds whether it is running for the supplied logdir location and returns its TCP port.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
logdir |
str |
The logdir location where the TensorBoard server is running. |
required |
Returns:
Type | Description |
---|---|
Optional[collections.TensorBoardInfo] |
The TensorBoardInfo describing the running TensorBoard server or None if no server is running for the supplied logdir location. |
Source code in zenml/integrations/tensorboard/visualizers/tensorboard_visualizer.py
@classmethod
def find_running_tensorboard_server(
cls, logdir: str
) -> Optional[TensorBoardInfo]:
"""Find a local TensorBoard server instance.
Finds whether it is running for the supplied logdir location and returns its
TCP port.
Args:
logdir: The logdir location where the TensorBoard server is running.
Returns:
The TensorBoardInfo describing the running TensorBoard server or
None if no server is running for the supplied logdir location.
"""
for server in get_all():
if (
server.logdir == logdir
and server.pid
and psutil.pid_exists(server.pid)
):
return server
return None
stop(self, object)
Stop the TensorBoard server previously started for a pipeline step.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
object |
StepView |
StepView fetched from run.get_step(). |
required |
Source code in zenml/integrations/tensorboard/visualizers/tensorboard_visualizer.py
def stop(
self,
object: StepView,
) -> None:
"""Stop the TensorBoard server previously started for a pipeline step.
Args:
object: StepView fetched from run.get_step().
"""
for _, artifact_view in object.outputs.items():
# filter out anything but model artifacts
if artifact_view.type == ModelArtifact.TYPE_NAME:
logdir = os.path.dirname(artifact_view.uri)
# first check if a TensorBoard server is already running for
# the same logdir location and use that one
running_server = self.find_running_tensorboard_server(logdir)
if not running_server:
return
logger.debug(
"Stopping tensorboard server with PID '%d' ...",
running_server.pid,
)
try:
p = psutil.Process(running_server.pid)
except psutil.Error:
logger.error(
"Could not find process for PID '%d' ...",
running_server.pid,
)
continue
p.kill()
return
visualize(self, object, height=800, *args, **kwargs)
Start a TensorBoard server.
Allows for the visualization of all models logged as artifacts by the indicated step. The server will monitor and display all the models logged by past and future step runs.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
object |
StepView |
StepView fetched from run.get_step(). |
required |
height |
int |
Height of the generated visualization. |
800 |
*args |
Any |
Additional arguments. |
() |
**kwargs |
Any |
Additional keyword arguments. |
{} |
Source code in zenml/integrations/tensorboard/visualizers/tensorboard_visualizer.py
def visualize(
self,
object: StepView,
height: int = 800,
*args: Any,
**kwargs: Any,
) -> None:
"""Start a TensorBoard server.
Allows for the visualization of all models logged as artifacts by the
indicated step. The server will monitor and display all the models
logged by past and future step runs.
Args:
object: StepView fetched from run.get_step().
height: Height of the generated visualization.
*args: Additional arguments.
**kwargs: Additional keyword arguments.
"""
for _, artifact_view in object.outputs.items():
# filter out anything but model artifacts
if artifact_view.type == ModelArtifact.TYPE_NAME:
logdir = os.path.dirname(artifact_view.uri)
# first check if a TensorBoard server is already running for
# the same logdir location and use that one
running_server = self.find_running_tensorboard_server(logdir)
if running_server:
self.visualize_tensorboard(running_server.port, height)
return
if sys.platform == "win32":
# Daemon service functionality is currently not supported on Windows
print(
"You can run:\n"
f"[italic green] tensorboard --logdir {logdir}"
"[/italic green]\n"
"...to visualize the TensorBoard logs for your trained model."
)
else:
# start a new TensorBoard server
service = TensorboardService(
TensorboardServiceConfig(
logdir=logdir,
)
)
service.start(timeout=20)
if service.endpoint.status.port:
self.visualize_tensorboard(
service.endpoint.status.port, height
)
return
visualize_tensorboard(self, port, height)
Generate a visualization of a TensorBoard.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
port |
int |
the TCP port where the TensorBoard server is listening for requests. |
required |
height |
int |
Height of the generated visualization. |
required |
Source code in zenml/integrations/tensorboard/visualizers/tensorboard_visualizer.py
def visualize_tensorboard(
self,
port: int,
height: int,
) -> None:
"""Generate a visualization of a TensorBoard.
Args:
port: the TCP port where the TensorBoard server is listening for
requests.
height: Height of the generated visualization.
"""
if Environment.in_notebook():
notebook.display(port, height=height)
return
print(
"You can visit:\n"
f"[italic green] http://localhost:{port}/[/italic green]\n"
"...to visualize the TensorBoard logs for your trained model."
)
get_step(pipeline_name, step_name)
Get the StepView for the specified pipeline and step name.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
pipeline_name |
str |
The name of the pipeline. |
required |
step_name |
str |
The name of the step. |
required |
Returns:
Type | Description |
---|---|
StepView |
The StepView for the specified pipeline and step name. |
Exceptions:
Type | Description |
---|---|
RuntimeError |
If the step is not found. |
Source code in zenml/integrations/tensorboard/visualizers/tensorboard_visualizer.py
def get_step(pipeline_name: str, step_name: str) -> StepView:
"""Get the StepView for the specified pipeline and step name.
Args:
pipeline_name: The name of the pipeline.
step_name: The name of the step.
Returns:
The StepView for the specified pipeline and step name.
Raises:
RuntimeError: If the step is not found.
"""
pipeline = get_pipeline(pipeline_name)
if pipeline is None:
raise RuntimeError(f"No pipeline with name `{pipeline_name}` was found")
last_run = pipeline.runs[-1]
step = last_run.get_step(step=step_name)
if step is None:
raise RuntimeError(
f"No pipeline step with name `{step_name}` was found in "
f"pipeline `{pipeline_name}`"
)
return cast(StepView, step)
stop_tensorboard_server(pipeline_name, step_name)
Stop the TensorBoard server previously started for a pipeline step.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
pipeline_name |
str |
the name of the pipeline |
required |
step_name |
str |
pipeline step name |
required |
Source code in zenml/integrations/tensorboard/visualizers/tensorboard_visualizer.py
def stop_tensorboard_server(pipeline_name: str, step_name: str) -> None:
"""Stop the TensorBoard server previously started for a pipeline step.
Args:
pipeline_name: the name of the pipeline
step_name: pipeline step name
"""
step = get_step(pipeline_name, step_name)
TensorboardVisualizer().stop(step)
visualize_tensorboard(pipeline_name, step_name)
Start a TensorBoard server.
Allows for the visualization of all models logged as output by the named pipeline step. The server will monitor and display all the models logged by past and future step runs.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
pipeline_name |
str |
the name of the pipeline |
required |
step_name |
str |
pipeline step name |
required |
Source code in zenml/integrations/tensorboard/visualizers/tensorboard_visualizer.py
def visualize_tensorboard(pipeline_name: str, step_name: str) -> None:
"""Start a TensorBoard server.
Allows for the visualization of all models logged as output by the named
pipeline step. The server will monitor and display all the models logged by
past and future step runs.
Args:
pipeline_name: the name of the pipeline
step_name: pipeline step name
"""
step = get_step(pipeline_name, step_name)
TensorboardVisualizer().visualize(step)
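A short usage sketch of the two convenience functions documented above; the pipeline and step names are hypothetical:
from zenml.integrations.tensorboard.visualizers.tensorboard_visualizer import (
    stop_tensorboard_server,
    visualize_tensorboard,
)

# start (or attach to) a TensorBoard server for the trainer step's models
visualize_tensorboard(pipeline_name="training_pipeline", step_name="trainer")
# later, shut the server down again
stop_tensorboard_server(pipeline_name="training_pipeline", step_name="trainer")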
tensorflow
special
Initialization for TensorFlow integration.
TensorflowIntegration (Integration)
Definition of Tensorflow integration for ZenML.
Source code in zenml/integrations/tensorflow/__init__.py
class TensorflowIntegration(Integration):
"""Definition of Tensorflow integration for ZenML."""
NAME = TENSORFLOW
REQUIREMENTS = ["tensorflow==2.8.0", "tensorflow_io==0.24.0"]
@classmethod
def activate(cls) -> None:
"""Activates the integration."""
# need to import this explicitly to load the Tensorflow file IO support
# for S3 and other file systems
import tensorflow_io # type: ignore [import]
from zenml.integrations.tensorflow import materializers # noqa
activate()
classmethod
Activates the integration.
Source code in zenml/integrations/tensorflow/__init__.py
@classmethod
def activate(cls) -> None:
"""Activates the integration."""
# need to import this explicitly to load the Tensorflow file IO support
# for S3 and other file systems
import tensorflow_io # type: ignore [import]
from zenml.integrations.tensorflow import materializers # noqa
materializers
special
Initialization for the TensorFlow materializers.
keras_materializer
Implementation of the TensorFlow Keras materializer.
KerasMaterializer (BaseMaterializer)
Materializer to read/write Keras models.
Source code in zenml/integrations/tensorflow/materializers/keras_materializer.py
class KerasMaterializer(BaseMaterializer):
"""Materializer to read/write Keras models."""
ASSOCIATED_TYPES = (keras.Model,)
ASSOCIATED_ARTIFACT_TYPES = (ModelArtifact,)
def handle_input(self, data_type: Type[Any]) -> keras.Model:
"""Reads and returns a Keras model after copying it to temporary path.
Args:
data_type: The type of the data to read.
Returns:
A tf.keras.Model model.
"""
super().handle_input(data_type)
# Create a temporary directory to store the model
temp_dir = tempfile.TemporaryDirectory()
# Copy from artifact store to temporary directory
io_utils.copy_dir(self.artifact.uri, temp_dir.name)
# Load the model from the temporary directory
model = keras.models.load_model(temp_dir.name)
# Cleanup and return
fileio.rmtree(temp_dir.name)
return model
def handle_return(self, model: keras.Model) -> None:
"""Writes a keras model to the artifact store.
Args:
model: A tf.keras.Model model.
"""
super().handle_return(model)
# Create a temporary directory to store the model
temp_dir = tempfile.TemporaryDirectory()
model.save(temp_dir.name)
io_utils.copy_dir(temp_dir.name, self.artifact.uri)
# Remove the temporary directory
fileio.rmtree(temp_dir.name)
handle_input(self, data_type)
Reads and returns a Keras model after copying it to temporary path.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
data_type |
Type[Any] |
The type of the data to read. |
required |
Returns:
Type | Description |
---|---|
Model |
A tf.keras.Model model. |
Source code in zenml/integrations/tensorflow/materializers/keras_materializer.py
def handle_input(self, data_type: Type[Any]) -> keras.Model:
"""Reads and returns a Keras model after copying it to temporary path.
Args:
data_type: The type of the data to read.
Returns:
A tf.keras.Model model.
"""
super().handle_input(data_type)
# Create a temporary directory to store the model
temp_dir = tempfile.TemporaryDirectory()
# Copy from artifact store to temporary directory
io_utils.copy_dir(self.artifact.uri, temp_dir.name)
# Load the model from the temporary directory
model = keras.models.load_model(temp_dir.name)
# Cleanup and return
fileio.rmtree(temp_dir.name)
return model
handle_return(self, model)
Writes a keras model to the artifact store.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model |
Model |
A tf.keras.Model model. |
required |
Source code in zenml/integrations/tensorflow/materializers/keras_materializer.py
def handle_return(self, model: keras.Model) -> None:
"""Writes a keras model to the artifact store.
Args:
model: A tf.keras.Model model.
"""
super().handle_return(model)
# Create a temporary directory to store the model
temp_dir = tempfile.TemporaryDirectory()
model.save(temp_dir.name)
io_utils.copy_dir(temp_dir.name, self.artifact.uri)
# Remove the temporary directory
fileio.rmtree(temp_dir.name)
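Because `keras.Model` is listed in `ASSOCIATED_TYPES`, any step output of that type is persisted through this materializer automatically. A minimal sketch, assuming the `@step` decorator from `zenml.steps`; the model architecture is hypothetical:
from tensorflow import keras
from zenml.steps import step

@step
def trainer() -> keras.Model:
    """Returns a model that ZenML saves via the KerasMaterializer."""
    model = keras.Sequential([keras.layers.Dense(1, input_shape=(4,))])
    model.compile(optimizer="adam", loss="mse")
    return model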
tf_dataset_materializer
Implementation of the TensorFlow dataset materializer.
TensorflowDatasetMaterializer (BaseMaterializer)
Materializer to read data to and from tf.data.Dataset.
Source code in zenml/integrations/tensorflow/materializers/tf_dataset_materializer.py
class TensorflowDatasetMaterializer(BaseMaterializer):
"""Materializer to read data to and from tf.data.Dataset."""
ASSOCIATED_TYPES = (tf.data.Dataset,)
ASSOCIATED_ARTIFACT_TYPES = (DataArtifact,)
def handle_input(self, data_type: Type[Any]) -> Any:
"""Reads data into tf.data.Dataset.
Args:
data_type: The type of the data to read.
Returns:
A tf.data.Dataset object.
"""
super().handle_input(data_type)
temp_dir = tempfile.mkdtemp()
io_utils.copy_dir(self.artifact.uri, temp_dir)
path = os.path.join(temp_dir, DEFAULT_FILENAME)
dataset = tf.data.experimental.load(path)
# Don't delete the temporary directory here as the dataset is lazily
# loaded and needs to read it when the object gets used
return dataset
def handle_return(self, dataset: tf.data.Dataset) -> None:
"""Persists a tf.data.Dataset object.
Args:
dataset: The dataset to persist.
"""
super().handle_return(dataset)
temp_dir = tempfile.TemporaryDirectory()
path = os.path.join(temp_dir.name, DEFAULT_FILENAME)
try:
tf.data.experimental.save(
dataset, path, compression=None, shard_func=None
)
io_utils.copy_dir(temp_dir.name, self.artifact.uri)
finally:
fileio.rmtree(temp_dir.name)
handle_input(self, data_type)
Reads data into tf.data.Dataset.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
data_type |
Type[Any] |
The type of the data to read. |
required |
Returns:
Type | Description |
---|---|
Any |
A tf.data.Dataset object. |
Source code in zenml/integrations/tensorflow/materializers/tf_dataset_materializer.py
def handle_input(self, data_type: Type[Any]) -> Any:
"""Reads data into tf.data.Dataset.
Args:
data_type: The type of the data to read.
Returns:
A tf.data.Dataset object.
"""
super().handle_input(data_type)
temp_dir = tempfile.mkdtemp()
io_utils.copy_dir(self.artifact.uri, temp_dir)
path = os.path.join(temp_dir, DEFAULT_FILENAME)
dataset = tf.data.experimental.load(path)
# Don't delete the temporary directory here as the dataset is lazily
# loaded and needs to read it when the object gets used
return dataset
handle_return(self, dataset)
Persists a tf.data.Dataset object.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
dataset |
DatasetV2 |
The dataset to persist. |
required |
Source code in zenml/integrations/tensorflow/materializers/tf_dataset_materializer.py
def handle_return(self, dataset: tf.data.Dataset) -> None:
"""Persists a tf.data.Dataset object.
Args:
dataset: The dataset to persist.
"""
super().handle_return(dataset)
temp_dir = tempfile.TemporaryDirectory()
path = os.path.join(temp_dir.name, DEFAULT_FILENAME)
try:
tf.data.experimental.save(
dataset, path, compression=None, shard_func=None
)
io_utils.copy_dir(temp_dir.name, self.artifact.uri)
finally:
fileio.rmtree(temp_dir.name)
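The same mechanism applies to datasets: a step that returns a `tf.data.Dataset` is stored and reloaded through this materializer. A minimal sketch with hypothetical toy data, again assuming the `@step` decorator from `zenml.steps`:
import tensorflow as tf
from zenml.steps import step

@step
def importer() -> tf.data.Dataset:
    """Returns a toy dataset that ZenML persists via the materializer."""
    return tf.data.Dataset.from_tensor_slices([1.0, 2.0, 3.0])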
utils
Utility functions for the integrations module.
get_integration_for_module(module_name)
Gets the integration class for a module inside an integration.
If the module given by `module_name` is not part of a ZenML integration, this method will return `None`. If it is part of a ZenML integration, it will return the integration class found inside the integration `__init__` file.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
module_name |
str |
The name of the module to get the integration for. |
required |
Returns:
Type | Description |
---|---|
Optional[Type[zenml.integrations.integration.Integration]] |
The integration class for the module. |
Source code in zenml/integrations/utils.py
def get_integration_for_module(
module_name: str,
) -> Optional[Type[Integration]]:
"""Gets the integration class for a module inside an integration.
If the module given by `module_name` is not part of a ZenML integration,
this method will return `None`. If it is part of a ZenML integration,
it will return the integration class found inside the integration
__init__ file.
Args:
module_name: The name of the module to get the integration for.
Returns:
The integration class for the module.
"""
integration_prefix = "zenml.integrations."
if not module_name.startswith(integration_prefix):
return None
integration_module_name = ".".join(module_name.split(".", 3)[:3])
try:
integration_module = sys.modules[integration_module_name]
except KeyError:
integration_module = importlib.import_module(integration_module_name)
for name, member in inspect.getmembers(integration_module):
if (
member is not Integration
and isinstance(member, IntegrationMeta)
and issubclass(member, Integration)
):
return cast(Type[Integration], member)
return None
get_requirements_for_module(module_name)
Gets requirements for a module inside an integration.
If the module given by `module_name` is not part of a ZenML integration, this method will return an empty list. If it is part of a ZenML integration, it will return the list of requirements specified inside the integration class found inside the integration `__init__` file.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
module_name |
str |
The name of the module to get requirements for. |
required |
Returns:
Type | Description |
---|---|
List[str] |
A list of requirements for the module. |
Source code in zenml/integrations/utils.py
def get_requirements_for_module(module_name: str) -> List[str]:
"""Gets requirements for a module inside an integration.
If the module given by `module_name` is not part of a ZenML integration,
this method will return an empty list. If it is part of a ZenML integration,
it will return the list of requirements specified inside the integration
class found inside the integration __init__ file.
Args:
module_name: The name of the module to get requirements for.
Returns:
A list of requirements for the module.
"""
integration = get_integration_for_module(module_name)
return integration.REQUIREMENTS if integration else []
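For example, querying a module that lives inside the TensorBoard integration documented above:
from zenml.integrations.utils import get_requirements_for_module

# part of the TensorBoard integration -> ["tensorboard==2.8.0"]
print(get_requirements_for_module("zenml.integrations.tensorboard.services"))
# not part of any ZenML integration -> []
print(get_requirements_for_module("os.path"))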
vault
special
Initialization for the Vault Secrets Manager integration.
The Vault secrets manager integration submodule provides a way to access the HashiCorp Vault secrets manager from within your ZenML pipeline runs.
VaultSecretsManagerIntegration (Integration)
Definition of HashiCorp Vault integration with ZenML.
Source code in zenml/integrations/vault/__init__.py
class VaultSecretsManagerIntegration(Integration):
"""Definition of HashiCorp Vault integration with ZenML."""
NAME = VAULT
REQUIREMENTS = ["hvac>=0.11.2"]
@classmethod
def flavors(cls) -> List[Type[Flavor]]:
"""Declare the stack component flavors for the Vault integration.
Returns:
List of stack component flavors.
"""
from zenml.integrations.vault.flavors import VaultSecretsManagerFlavor
return [VaultSecretsManagerFlavor]
flavors()
classmethod
Declare the stack component flavors for the Vault integration.
Returns:
Type | Description |
---|---|
List[Type[zenml.stack.flavor.Flavor]] |
List of stack component flavors. |
Source code in zenml/integrations/vault/__init__.py
@classmethod
def flavors(cls) -> List[Type[Flavor]]:
"""Declare the stack component flavors for the Vault integration.
Returns:
List of stack component flavors.
"""
from zenml.integrations.vault.flavors import VaultSecretsManagerFlavor
return [VaultSecretsManagerFlavor]
flavors
special
HashiCorp Vault integration flavors.
vault_secrets_manager_flavor
HashiCorp Vault secrets manager flavor.
VaultSecretsManagerConfig (BaseSecretsManagerConfig)
pydantic-model
Configuration for the Vault Secrets Manager.
Attributes:
Name | Type | Description |
---|---|---|
url |
str |
The url of the Vault server. |
token |
str |
The token to use to authenticate with Vault. |
cert |
Optional[str] |
The path to the certificate to use to authenticate with Vault. |
verify |
Optional[str] |
Whether to verify the certificate or not. |
mount_point |
str |
The mount point of the secrets manager. |
Source code in zenml/integrations/vault/flavors/vault_secrets_manager_flavor.py
class VaultSecretsManagerConfig(BaseSecretsManagerConfig):
"""Configuration for the Vault Secrets Manager.
Attributes:
url: The url of the Vault server.
token: The token to use to authenticate with Vault.
cert: The path to the certificate to use to authenticate with Vault.
verify: Whether to verify the certificate or not.
mount_point: The mount point of the secrets manager.
"""
SUPPORTS_SCOPING: ClassVar[bool] = True
url: str
token: str
mount_point: str
cert: Optional[str]
verify: Optional[str]
@classmethod
def _validate_scope(
cls,
scope: SecretsManagerScope,
namespace: Optional[str],
) -> None:
"""Validate the scope and namespace value.
Args:
scope: Scope value.
namespace: Optional namespace value.
"""
if namespace:
validate_vault_secret_name_or_namespace(namespace)
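A minimal sketch of instantiating the configuration directly; every value shown is a hypothetical placeholder:
from zenml.integrations.vault.flavors.vault_secrets_manager_flavor import (
    VaultSecretsManagerConfig,
)

config = VaultSecretsManagerConfig(
    url="https://vault.example.com:8200",  # hypothetical Vault server URL
    token="s.example-token",               # hypothetical authentication token
    mount_point="zenml/",                  # hypothetical KV v2 mount point
)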
VaultSecretsManagerFlavor (BaseSecretsManagerFlavor)
Class for the `VaultSecretsManagerFlavor`.
Source code in zenml/integrations/vault/flavors/vault_secrets_manager_flavor.py
class VaultSecretsManagerFlavor(BaseSecretsManagerFlavor):
"""Class for the `VaultSecretsManagerFlavor`."""
@property
def name(self) -> str:
"""Name of the flavor.
Returns:
The name of the flavor.
"""
return VAULT_SECRETS_MANAGER_FLAVOR
@property
def config_class(self) -> Type[VaultSecretsManagerConfig]:
"""Returns `VaultSecretsManagerConfig` config class.
Returns:
The config class.
"""
return VaultSecretsManagerConfig
@property
def implementation_class(self) -> Type["VaultSecretsManager"]:
"""Implementation class for this flavor.
Returns:
The implementation class.
"""
from zenml.integrations.vault.secrets_manager import VaultSecretsManager
return VaultSecretsManager
config_class: Type[zenml.integrations.vault.flavors.vault_secrets_manager_flavor.VaultSecretsManagerConfig]
property
readonly
Returns the `VaultSecretsManagerConfig` config class.
Returns:
Type | Description |
---|---|
Type[zenml.integrations.vault.flavors.vault_secrets_manager_flavor.VaultSecretsManagerConfig] |
The config class. |
implementation_class: Type[VaultSecretsManager]
property
readonly
Implementation class for this flavor.
Returns:
Type | Description |
---|---|
Type[VaultSecretsManager] |
The implementation class. |
name: str
property
readonly
Name of the flavor.
Returns:
Type | Description |
---|---|
str |
The name of the flavor. |
validate_vault_secret_name_or_namespace(name)
Validate a secret name or namespace.
For compatibility across secret managers the secret names should contain only alphanumeric characters and the characters /_+=.@-. The `/` character is only used internally to delimit scopes.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
name |
str |
the secret name or namespace |
required |
Exceptions:
Type | Description |
---|---|
ValueError |
if the secret name or namespace is invalid |
Source code in zenml/integrations/vault/flavors/vault_secrets_manager_flavor.py
def validate_vault_secret_name_or_namespace(name: str) -> None:
"""Validate a secret name or namespace.
For compatibility across secret managers the secret names should contain
only alphanumeric characters and the characters /_+=.@-. The `/`
character is only used internally to delimit scopes.
Args:
name: the secret name or namespace
Raises:
ValueError: if the secret name or namespace is invalid
"""
if not re.fullmatch(r"[a-zA-Z0-9_+=\.@\-]*", name):
raise ValueError(
f"Invalid secret name or namespace '{name}'. Must contain "
f"only alphanumeric characters and the characters _+=.@-."
)
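Two quick examples of the rule, following the regular expression above:
from zenml.integrations.vault.flavors.vault_secrets_manager_flavor import (
    validate_vault_secret_name_or_namespace,
)

validate_vault_secret_name_or_namespace("db_credentials")  # passes silently
validate_vault_secret_name_or_namespace("db/credentials")  # raises ValueError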
secrets_manager
special
HashiCorp Vault Secrets Manager.
vault_secrets_manager
Implementation of the HashiCorp Vault Secrets Manager integration.
VaultSecretsManager (BaseSecretsManager)
Class to interact with the Vault secrets manager - Key/value Engine.
Source code in zenml/integrations/vault/secrets_manager/vault_secrets_manager.py
class VaultSecretsManager(BaseSecretsManager):
"""Class to interact with the Vault secrets manager - Key/value Engine."""
CLIENT: ClassVar[Any] = None
@property
def config(self) -> VaultSecretsManagerConfig:
"""Returns the `VaultSecretsManagerConfig` config.
Returns:
The configuration.
"""
return cast(VaultSecretsManagerConfig, self._config)
@classmethod
def _ensure_client_connected(cls, url: str, token: str) -> None:
"""Ensure the client is connected.
This function initializes the client if it is not initialized.
Args:
url: The url of the Vault server.
token: The token to use to authenticate with Vault.
"""
if cls.CLIENT is None:
# Create a Vault Secrets Manager client
cls.CLIENT = hvac.Client(
url=url,
token=token,
)
def _ensure_client_is_authenticated(self) -> None:
"""Ensure the client is authenticated.
Raises:
RuntimeError: If the client is not initialized or authenticated.
"""
self._ensure_client_connected(
url=self.config.url, token=self.config.token
)
if not self.CLIENT.is_authenticated():
raise RuntimeError(
"There was an error authenticating with Vault. Please check "
"your configuration."
)
def register_secret(self, secret: BaseSecretSchema) -> None:
"""Registers a new secret.
Args:
secret: The secret to register.
Raises:
SecretExistsError: If the secret already exists.
"""
self._ensure_client_is_authenticated()
validate_vault_secret_name_or_namespace(secret.name)
try:
self.get_secret(secret.name)
raise SecretExistsError(
f"A Secret with the name '{secret.name}' already " f"exists."
)
except KeyError:
pass
secret_path = self._get_scoped_secret_name(secret.name)
secret_value = secret_to_dict(secret, encode=False)
self.CLIENT.secrets.kv.v2.create_or_update_secret(
path=secret_path,
mount_point=self.config.mount_point,
secret=secret_value,
)
logger.info("Created secret: %s", f"{secret_path}")
logger.info("Added value to secret.")
def get_secret(self, secret_name: str) -> BaseSecretSchema:
"""Gets the value of a secret.
Args:
secret_name: The name of the secret to get.
Returns:
The secret.
Raises:
KeyError: If the secret does not exist.
"""
self._ensure_client_is_authenticated()
secret_path = self._get_scoped_secret_name(secret_name)
try:
secret_items = (
self.CLIENT.secrets.kv.v2.read_secret_version(
path=secret_path,
mount_point=self.config.mount_point,
)
.get("data", {})
.get("data", {})
)
except InvalidPath as e:
raise KeyError(e)
zenml_schema_name = secret_items.pop(ZENML_SCHEMA_NAME)
secret_schema = SecretSchemaClassRegistry.get_class(
secret_schema=zenml_schema_name
)
secret_items["name"] = secret_name
return secret_schema(**secret_items)
def get_all_secret_keys(self) -> List[str]:
"""List all secrets in Vault without any reformatting.
This function tries to get all secrets from Vault and returns
them as a list of strings (all secrets' names).
Returns:
A list of all secrets in the secrets manager.
"""
self._ensure_client_is_authenticated()
set_of_secrets: Set[str] = set()
secret_path = "/".join(self._get_scope_path())
try:
secrets = self.CLIENT.secrets.kv.v2.list_secrets(
path=secret_path, mount_point=self.config.mount_point
)
except hvac.exceptions.InvalidPath:
logger.error(
f"There are no secrets created within the scope `{secret_path}`"
)
return list(set_of_secrets)
secrets_keys = secrets.get("data", {}).get("keys", [])
for secret_key in secrets_keys:
# vault scopes end with / and are not themselves secrets
if "/" not in secret_key:
set_of_secrets.add(secret_key)
return list(set_of_secrets)
def update_secret(self, secret: BaseSecretSchema) -> None:
"""Update an existing secret.
Args:
secret: The secret to update.
Raises:
KeyError: If the secret does not exist.
"""
self._ensure_client_is_authenticated()
validate_vault_secret_name_or_namespace(secret.name)
if secret.name in self.get_all_secret_keys():
secret_path = self._get_scoped_secret_name(secret.name)
secret_value = secret_to_dict(secret, encode=False)
self.CLIENT.secrets.kv.v2.create_or_update_secret(
path=secret_path,
mount_point=self.config.mount_point,
secret=secret_value,
)
else:
raise KeyError(
f"A Secret with the name '{secret.name}'" f" does not exist."
)
logger.info("Updated secret: %s", secret_path)
logger.info("Added value to secret.")
def delete_secret(self, secret_name: str) -> None:
"""Delete an existing secret.
Args:
secret_name: The name of the secret to delete.
"""
self._ensure_client_is_authenticated()
secret_path = self._get_scoped_secret_name(secret_name)
self.CLIENT.secrets.kv.v2.delete_metadata_and_all_versions(
path=secret_path,
mount_point=self.config.mount_point,
)
logger.info("Deleted secret: %s", f"{secret_path}")
def delete_all_secrets(self) -> None:
"""Delete all existing secrets."""
self._ensure_client_is_authenticated()
for secret_name in self.get_all_secret_keys():
self.delete_secret(secret_name)
logger.info("Deleted all secrets.")
config: VaultSecretsManagerConfig
property
readonly
Returns the `VaultSecretsManagerConfig` config.
Returns:
Type | Description |
---|---|
VaultSecretsManagerConfig |
The configuration. |
delete_all_secrets(self)
Delete all existing secrets.
Source code in zenml/integrations/vault/secrets_manager/vault_secrets_manager.py
def delete_all_secrets(self) -> None:
"""Delete all existing secrets."""
self._ensure_client_is_authenticated()
for secret_name in self.get_all_secret_keys():
self.delete_secret(secret_name)
logger.info("Deleted all secrets.")
delete_secret(self, secret_name)
Delete an existing secret.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
secret_name |
str |
The name of the secret to delete. |
required |
Source code in zenml/integrations/vault/secrets_manager/vault_secrets_manager.py
def delete_secret(self, secret_name: str) -> None:
"""Delete an existing secret.
Args:
secret_name: The name of the secret to delete.
"""
self._ensure_client_is_authenticated()
secret_path = self._get_scoped_secret_name(secret_name)
self.CLIENT.secrets.kv.v2.delete_metadata_and_all_versions(
path=secret_path,
mount_point=self.config.mount_point,
)
logger.info("Deleted secret: %s", f"{secret_path}")
get_all_secret_keys(self)
List all secrets in Vault without any reformatting.
This function tries to get all secrets from Vault and returns them as a list of strings (all secrets' names).
Returns:
Type | Description |
---|---|
List[str] |
A list of all secrets in the secrets manager. |
Source code in zenml/integrations/vault/secrets_manager/vault_secrets_manager.py
def get_all_secret_keys(self) -> List[str]:
"""List all secrets in Vault without any reformatting.
This function tries to get all secrets from Vault and returns
them as a list of strings (all secrets' names).
Returns:
A list of all secrets in the secrets manager.
"""
self._ensure_client_is_authenticated()
set_of_secrets: Set[str] = set()
secret_path = "/".join(self._get_scope_path())
try:
secrets = self.CLIENT.secrets.kv.v2.list_secrets(
path=secret_path, mount_point=self.config.mount_point
)
except hvac.exceptions.InvalidPath:
logger.error(
f"There are no secrets created within the scope `{secret_path}`"
)
return list(set_of_secrets)
secrets_keys = secrets.get("data", {}).get("keys", [])
for secret_key in secrets_keys:
# vault scopes end with / and are not themselves secrets
if "/" not in secret_key:
set_of_secrets.add(secret_key)
return list(set_of_secrets)
get_secret(self, secret_name)
Gets the value of a secret.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
secret_name |
str |
The name of the secret to get. |
required |
Returns:
Type | Description |
---|---|
BaseSecretSchema |
The secret. |
Exceptions:
Type | Description |
---|---|
KeyError |
If the secret does not exist. |
Source code in zenml/integrations/vault/secrets_manager/vault_secrets_manager.py
def get_secret(self, secret_name: str) -> BaseSecretSchema:
"""Gets the value of a secret.
Args:
secret_name: The name of the secret to get.
Returns:
The secret.
Raises:
KeyError: If the secret does not exist.
"""
self._ensure_client_is_authenticated()
secret_path = self._get_scoped_secret_name(secret_name)
try:
secret_items = (
self.CLIENT.secrets.kv.v2.read_secret_version(
path=secret_path,
mount_point=self.config.mount_point,
)
.get("data", {})
.get("data", {})
)
except InvalidPath as e:
raise KeyError(e)
zenml_schema_name = secret_items.pop(ZENML_SCHEMA_NAME)
secret_schema = SecretSchemaClassRegistry.get_class(
secret_schema=zenml_schema_name
)
secret_items["name"] = secret_name
return secret_schema(**secret_items)
register_secret(self, secret)
Registers a new secret.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
secret |
BaseSecretSchema |
The secret to register. |
required |
Exceptions:
Type | Description |
---|---|
SecretExistsError |
If the secret already exists. |
Source code in zenml/integrations/vault/secrets_manager/vault_secrets_manager.py
def register_secret(self, secret: BaseSecretSchema) -> None:
"""Registers a new secret.
Args:
secret: The secret to register.
Raises:
SecretExistsError: If the secret already exists.
"""
self._ensure_client_is_authenticated()
validate_vault_secret_name_or_namespace(secret.name)
try:
self.get_secret(secret.name)
raise SecretExistsError(
f"A Secret with the name '{secret.name}' already " f"exists."
)
except KeyError:
pass
secret_path = self._get_scoped_secret_name(secret.name)
secret_value = secret_to_dict(secret, encode=False)
self.CLIENT.secrets.kv.v2.create_or_update_secret(
path=secret_path,
mount_point=self.config.mount_point,
secret=secret_value,
)
logger.info("Created secret: %s", f"{secret_path}")
logger.info("Added value to secret.")
update_secret(self, secret)
Update an existing secret.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
secret |
BaseSecretSchema |
The secret to update. |
required |
Exceptions:
Type | Description |
---|---|
KeyError |
If the secret does not exist. |
Source code in zenml/integrations/vault/secrets_manager/vault_secrets_manager.py
def update_secret(self, secret: BaseSecretSchema) -> None:
"""Update an existing secret.
Args:
secret: The secret to update.
Raises:
KeyError: If the secret does not exist.
"""
self._ensure_client_is_authenticated()
validate_vault_secret_name_or_namespace(secret.name)
if secret.name in self.get_all_secret_keys():
secret_path = self._get_scoped_secret_name(secret.name)
secret_value = secret_to_dict(secret, encode=False)
self.CLIENT.secrets.kv.v2.create_or_update_secret(
path=secret_path,
mount_point=self.config.mount_point,
secret=secret_value,
)
else:
raise KeyError(
f"A Secret with the name '{secret.name}'" f" does not exist."
)
logger.info("Updated secret: %s", secret_path)
logger.info("Added value to secret.")
wandb
special
Initialization for the wandb integration.
The wandb integration currently enables you to use wandb tracking as a convenient way to visualize your experiment runs within the wandb UI.
WandbIntegration (Integration)
Definition of Weights & Biases integration for ZenML.
Source code in zenml/integrations/wandb/__init__.py
class WandbIntegration(Integration):
"""Definition of Plotly integration for ZenML."""
NAME = WANDB
REQUIREMENTS = ["wandb>=0.12.12", "Pillow>=9.1.0"]
@classmethod
def flavors(cls) -> List[Type[Flavor]]:
"""Declare the stack component flavors for the Weights and Biases integration.
Returns:
List of stack component flavors for this integration.
"""
from zenml.integrations.wandb.flavors import (
WandbExperimentTrackerFlavor,
)
return [WandbExperimentTrackerFlavor]
flavors()
classmethod
Declare the stack component flavors for the Weights and Biases integration.
Returns:
Type | Description |
---|---|
List[Type[zenml.stack.flavor.Flavor]] |
List of stack component flavors for this integration. |
Source code in zenml/integrations/wandb/__init__.py
@classmethod
def flavors(cls) -> List[Type[Flavor]]:
"""Declare the stack component flavors for the Weights and Biases integration.
Returns:
List of stack component flavors for this integration.
"""
from zenml.integrations.wandb.flavors import (
WandbExperimentTrackerFlavor,
)
return [WandbExperimentTrackerFlavor]
experiment_trackers
special
Initialization for the wandb experiment tracker.
wandb_experiment_tracker
Implementation for the wandb experiment tracker.
WandbExperimentTracker (BaseExperimentTracker)
Track experiment using Wandb.
Source code in zenml/integrations/wandb/experiment_trackers/wandb_experiment_tracker.py
class WandbExperimentTracker(BaseExperimentTracker):
"""Track experiment using Wandb."""
@property
def config(self) -> WandbExperimentTrackerConfig:
"""Returns the `WandbExperimentTrackerConfig` config.
Returns:
The configuration.
"""
return cast(WandbExperimentTrackerConfig, self._config)
@property
def settings_class(self) -> Optional[Type["BaseSettings"]]:
"""settings class for the Wandb experiment tracker.
Returns:
The settings class.
"""
return WandbExperimentTrackerSettings
def prepare_step_run(self, info: "StepRunInfo") -> None:
"""Configures a Wandb run.
Args:
info: Info about the step that will be executed.
"""
os.environ[WANDB_API_KEY] = self.config.api_key
settings = cast(
WandbExperimentTrackerSettings,
self.get_settings(info) or WandbExperimentTrackerSettings(),
)
tags = settings.tags + [info.run_name, info.pipeline.name]
wandb_run_name = (
settings.run_name or f"{info.run_name}_{info.config.name}"
)
self._initialize_wandb(
run_name=wandb_run_name, tags=tags, settings=settings.settings
)
def cleanup_step_run(self, info: "StepRunInfo") -> None:
"""Stops the Wandb run.
Args:
info: Info about the step that was executed.
"""
wandb.finish()
os.environ.pop(WANDB_API_KEY, None)
def _initialize_wandb(
self,
run_name: str,
tags: List[str],
settings: Union[wandb.Settings, Dict[str, Any], None] = None,
) -> None:
"""Initializes a wandb run.
Args:
run_name: Name of the wandb run to create.
tags: Tags to attach to the wandb run.
settings: Additional settings for the wandb run.
"""
logger.info(
f"Initializing wandb with entity {self.config.entity}, project "
f"name: {self.config.project_name}, run_name: {run_name}."
)
wandb.init(
entity=self.config.entity,
project=self.config.project_name,
name=run_name,
tags=tags,
settings=settings,
)
config: WandbExperimentTrackerConfig
property
readonly
Returns the `WandbExperimentTrackerConfig` config.
Returns:
Type | Description |
---|---|
WandbExperimentTrackerConfig |
The configuration. |
settings_class: Optional[Type[BaseSettings]]
property
readonly
Settings class for the Wandb experiment tracker.
Returns:
Type | Description |
---|---|
Optional[Type[BaseSettings]] |
The settings class. |
cleanup_step_run(self, info)
Stops the Wandb run.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
info |
StepRunInfo |
Info about the step that was executed. |
required |
Source code in zenml/integrations/wandb/experiment_trackers/wandb_experiment_tracker.py
def cleanup_step_run(self, info: "StepRunInfo") -> None:
"""Stops the Wandb run.
Args:
info: Info about the step that was executed.
"""
wandb.finish()
os.environ.pop(WANDB_API_KEY, None)
prepare_step_run(self, info)
Configures a Wandb run.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
info |
StepRunInfo |
Info about the step that will be executed. |
required |
Source code in zenml/integrations/wandb/experiment_trackers/wandb_experiment_tracker.py
def prepare_step_run(self, info: "StepRunInfo") -> None:
"""Configures a Wandb run.
Args:
info: Info about the step that will be executed.
"""
os.environ[WANDB_API_KEY] = self.config.api_key
settings = cast(
WandbExperimentTrackerSettings,
self.get_settings(info) or WandbExperimentTrackerSettings(),
)
tags = settings.tags + [info.run_name, info.pipeline.name]
wandb_run_name = (
settings.run_name or f"{info.run_name}_{info.config.name}"
)
self._initialize_wandb(
run_name=wandb_run_name, tags=tags, settings=settings.settings
)
flavors
special
Weights & Biases integration flavors.
wandb_experiment_tracker_flavor
Weights & Biases experiment tracker flavor.
WandbExperimentTrackerConfig (BaseExperimentTrackerConfig)
pydantic-model
Config for the Wandb experiment tracker.
Attributes:
Name | Type | Description |
---|---|---|
entity |
Optional[str] |
Name of an existing wandb entity. |
project_name |
Optional[str] |
Name of an existing wandb project to log to. |
api_key |
str |
API key that should be authorized to log to the configured wandb entity and project. |
Source code in zenml/integrations/wandb/flavors/wandb_experiment_tracker_flavor.py
class WandbExperimentTrackerConfig(BaseExperimentTrackerConfig):
"""Config for the Wandb experiment tracker.
Attributes:
entity: Name of an existing wandb entity.
project_name: Name of an existing wandb project to log to.
api_key: API key that should be authorized to log to the configured wandb
entity and project.
"""
api_key: str = SecretField()
entity: Optional[str] = None
project_name: Optional[str] = None
WandbExperimentTrackerFlavor (BaseExperimentTrackerFlavor)
Flavor for the Wandb experiment tracker.
Source code in zenml/integrations/wandb/flavors/wandb_experiment_tracker_flavor.py
class WandbExperimentTrackerFlavor(BaseExperimentTrackerFlavor):
"""Flavor for the Wandb experiment tracker."""
@property
def name(self) -> str:
"""Name of the flavor.
Returns:
The name of the flavor.
"""
return WANDB_EXPERIMENT_TRACKER_FLAVOR
@property
def config_class(self) -> Type[WandbExperimentTrackerConfig]:
"""Returns `WandbExperimentTrackerConfig` config class.
Returns:
The config class.
"""
return WandbExperimentTrackerConfig
@property
def implementation_class(self) -> Type["WandbExperimentTracker"]:
"""Implementation class for this flavor.
Returns:
The implementation class.
"""
from zenml.integrations.wandb.experiment_trackers import (
WandbExperimentTracker,
)
return WandbExperimentTracker
config_class: Type[zenml.integrations.wandb.flavors.wandb_experiment_tracker_flavor.WandbExperimentTrackerConfig]
property
readonly
Returns the `WandbExperimentTrackerConfig` config class.
Returns:
Type | Description |
---|---|
Type[zenml.integrations.wandb.flavors.wandb_experiment_tracker_flavor.WandbExperimentTrackerConfig] |
The config class. |
implementation_class: Type[WandbExperimentTracker]
property
readonly
Implementation class for this flavor.
Returns:
Type | Description |
---|---|
Type[WandbExperimentTracker] |
The implementation class. |
name: str
property
readonly
Name of the flavor.
Returns:
Type | Description |
---|---|
str |
The name of the flavor. |
WandbExperimentTrackerSettings (BaseSettings)
pydantic-model
Settings for the Wandb experiment tracker.
Attributes:
Name | Type | Description |
---|---|---|
run_name |
Optional[str] |
The Wandb run name. |
tags |
List[str] |
Tags for the Wandb run. |
settings |
Dict[str, Any] |
Settings for the Wandb run. |
Source code in zenml/integrations/wandb/flavors/wandb_experiment_tracker_flavor.py
class WandbExperimentTrackerSettings(BaseSettings):
"""Settings for the Wandb experiment tracker.
Attributes:
run_name: The Wandb run name.
tags: Tags for the Wandb run.
settings: Settings for the Wandb run.
"""
run_name: Optional[str] = None
tags: List[str] = []
settings: Dict[str, Any] = {}
@validator("settings", pre=True)
def _convert_settings(
cls, value: Union[Dict[str, Any], "wandb.Settings"]
) -> Dict[str, Any]:
"""Converts settings to a dictionary.
Args:
value: The settings.
Returns:
Dict representation of the settings.
"""
import wandb
if isinstance(value, wandb.Settings):
return cast(Dict[str, Any], value.make_static())
else:
return value
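A minimal sketch of constructing the settings, showing the pre-validator above turning a `wandb.Settings` object into a plain dictionary; the run name and tags are hypothetical:
import wandb
from zenml.integrations.wandb.flavors.wandb_experiment_tracker_flavor import (
    WandbExperimentTrackerSettings,
)

settings = WandbExperimentTrackerSettings(
    run_name="baseline-run",    # hypothetical run name
    tags=["example"],           # hypothetical tags
    settings=wandb.Settings(),  # converted to a dict by the validator
)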
whylogs
special
Initialization of the whylogs integration.
WhylogsIntegration (Integration)
Definition of whylogs integration for ZenML.
Source code in zenml/integrations/whylogs/__init__.py
class WhylogsIntegration(Integration):
"""Definition of [whylogs](https://github.com/whylabs/whylogs) integration for ZenML."""
NAME = WHYLOGS
REQUIREMENTS = ["whylogs[viz]~=1.0.5", "whylogs[whylabs]~=1.0.5"]
@classmethod
def activate(cls) -> None:
"""Activates the integration."""
from zenml.integrations.whylogs import materializers # noqa
from zenml.integrations.whylogs import secret_schemas # noqa
from zenml.integrations.whylogs import visualizers # noqa
@classmethod
def flavors(cls) -> List[Type[Flavor]]:
"""Declare the stack component flavors for the Great Expectations integration.
Returns:
List of stack component flavors for this integration.
"""
from zenml.integrations.whylogs.flavors import (
WhylogsDataValidatorFlavor,
)
return [WhylogsDataValidatorFlavor]
activate()
classmethod
Activates the integration.
Source code in zenml/integrations/whylogs/__init__.py
@classmethod
def activate(cls) -> None:
"""Activates the integration."""
from zenml.integrations.whylogs import materializers # noqa
from zenml.integrations.whylogs import secret_schemas # noqa
from zenml.integrations.whylogs import visualizers # noqa
flavors()
classmethod
Declare the stack component flavors for the whylogs integration.
Returns:
Type | Description |
---|---|
List[Type[zenml.stack.flavor.Flavor]] |
List of stack component flavors for this integration. |
Source code in zenml/integrations/whylogs/__init__.py
@classmethod
def flavors(cls) -> List[Type[Flavor]]:
"""Declare the stack component flavors for the Great Expectations integration.
Returns:
List of stack component flavors for this integration.
"""
from zenml.integrations.whylogs.flavors import (
WhylogsDataValidatorFlavor,
)
return [WhylogsDataValidatorFlavor]
constants
Whylogs integration constants.
data_validators
special
Initialization of the whylogs data validator for ZenML.
whylogs_data_validator
Implementation of the whylogs data validator.
WhylogsDataValidator (BaseDataValidator, AuthenticationMixin)
Whylogs data validator stack component.
Attributes:
Name | Type | Description |
---|---|---|
authentication_secret |
Optional ZenML secret with Whylabs credentials. If configured, all the data profiles returned by all pipeline steps will automatically be uploaded to Whylabs in addition to being stored in the ZenML Artifact Store. |
Source code in zenml/integrations/whylogs/data_validators/whylogs_data_validator.py
class WhylogsDataValidator(BaseDataValidator, AuthenticationMixin):
"""Whylogs data validator stack component.
Attributes:
authentication_secret: Optional ZenML secret with Whylabs credentials.
If configured, all the data profiles returned by all pipeline steps
will automatically be uploaded to Whylabs in addition to being
stored in the ZenML Artifact Store.
"""
NAME: ClassVar[str] = "whylogs"
@property
def config(self) -> WhylogsDataValidatorConfig:
"""Returns the `WhylogsDataValidatorConfig` config.
Returns:
The configuration.
"""
return cast(WhylogsDataValidatorConfig, self._config)
@property
def settings_class(self) -> Optional[Type["BaseSettings"]]:
"""Settings class for the Whylogs data validator.
Returns:
The settings class.
"""
return WhylogsDataValidatorSettings
def prepare_step_run(self, info: "StepRunInfo") -> None:
"""Configures Whylabs logging.
Args:
info: Info about the step that will be executed.
"""
settings = cast(
WhylogsDataValidatorSettings,
self.get_settings(info) or WhylogsDataValidatorSettings(),
)
if settings.enable_whylabs:
os.environ[WHYLABS_LOGGING_ENABLED_ENV] = "true"
if settings.dataset_id:
os.environ[WHYLABS_DATASET_ID_ENV] = settings.dataset_id
def cleanup_step_run(self, info: "StepRunInfo") -> None:
"""Resets Whylabs configuration.
Args:
info: Info about the step that was executed.
"""
settings = cast(
WhylogsDataValidatorSettings,
self.get_settings(info) or WhylogsDataValidatorSettings(),
)
if settings.enable_whylabs:
del os.environ[WHYLABS_LOGGING_ENABLED_ENV]
if settings.dataset_id:
del os.environ[WHYLABS_DATASET_ID_ENV]
def data_profiling(
self,
dataset: pd.DataFrame,
comparison_dataset: Optional[pd.DataFrame] = None,
profile_list: Optional[Sequence[str]] = None,
dataset_timestamp: Optional[datetime.datetime] = None,
**kwargs: Any,
) -> DatasetProfileView:
"""Analyze a dataset and generate a data profile with whylogs.
Args:
dataset: Target dataset to be profiled.
comparison_dataset: Optional dataset to be used for data profiles
that require a baseline for comparison (e.g. data drift profiles).
profile_list: Optional list identifying the categories of whylogs
data profiles to be generated (unused).
dataset_timestamp: timestamp to associate with the generated
dataset profile (Optional). The current time is used if not
supplied.
**kwargs: Extra keyword arguments (unused).
Returns:
A whylogs profile view object.
"""
results = why.log(pandas=dataset)
profile = results.profile()
dataset_timestamp = dataset_timestamp or datetime.datetime.utcnow()
profile.set_dataset_timestamp(dataset_timestamp=dataset_timestamp)
return profile.view()
def upload_profile_view(
self,
profile_view: DatasetProfileView,
dataset_id: Optional[str] = None,
) -> None:
"""Upload a whylogs data profile view to Whylabs, if configured to do so.
Args:
profile_view: Whylogs profile view to upload.
dataset_id: Optional dataset identifier to use for the uploaded
data profile. If omitted, a dataset identifier will be retrieved
using other means, in order:
* the default dataset identifier configured in the Data
Validator secret
* a dataset ID will be generated automatically based on the
current pipeline/step information.
Raises:
ValueError: If the dataset ID was not provided and could not be
retrieved or inferred from other sources.
"""
secret = self.get_authentication_secret(
expected_schema_type=WhylabsSecretSchema
)
if not secret:
return
dataset_id = dataset_id or secret.whylabs_default_dataset_id
if not dataset_id:
# use the current pipeline name and the step name to generate a
# unique dataset name
try:
# get pipeline name and step name
step_env = cast(
StepEnvironment, Environment()[STEP_ENVIRONMENT_NAME]
)
dataset_id = f"{step_env.pipeline_name}_{step_env.step_name}"
except KeyError:
raise ValueError(
"A dataset ID was not specified and could not be "
"generated from the current pipeline and step name."
)
# Instantiate WhyLabs Writer
writer = WhyLabsWriter(
org_id=secret.whylabs_default_org_id,
api_key=secret.whylabs_api_key,
dataset_id=dataset_id,
)
# pass a profile view to the writer's write method
writer.write(profile=profile_view)
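As a usage sketch (assuming a whylogs data validator is registered in the active stack), a pipeline step can fetch the active data validator and profile a DataFrame directly:

```python
import pandas as pd
from whylogs.core import DatasetProfileView

from zenml.steps import step
from zenml.integrations.whylogs.data_validators import WhylogsDataValidator

@step
def data_profiler(df: pd.DataFrame) -> DatasetProfileView:
    # Fetch the whylogs data validator registered in the active stack
    # and generate a profile for the input DataFrame.
    data_validator = WhylogsDataValidator.get_active_data_validator()
    return data_validator.data_profiling(df)
```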
config: WhylogsDataValidatorConfig
property
readonly
Returns the WhylogsDataValidatorConfig config.
Returns:
Type | Description |
---|---|
WhylogsDataValidatorConfig |
The configuration. |
settings_class: Optional[Type[BaseSettings]]
property
readonly
Settings class for the Whylogs data validator.
Returns:
Type | Description |
---|---|
Optional[Type[BaseSettings]] |
The settings class. |
cleanup_step_run(self, info)
Resets Whylabs configuration.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
info |
StepRunInfo |
Info about the step that was executed. |
required |
Source code in zenml/integrations/whylogs/data_validators/whylogs_data_validator.py
def cleanup_step_run(self, info: "StepRunInfo") -> None:
"""Resets Whylabs configuration.
Args:
info: Info about the step that was executed.
"""
settings = cast(
WhylogsDataValidatorSettings,
self.get_settings(info) or WhylogsDataValidatorSettings(),
)
if settings.enable_whylabs:
del os.environ[WHYLABS_LOGGING_ENABLED_ENV]
if settings.dataset_id:
del os.environ[WHYLABS_DATASET_ID_ENV]
data_profiling(self, dataset, comparison_dataset=None, profile_list=None, dataset_timestamp=None, **kwargs)
Analyze a dataset and generate a data profile with whylogs.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
dataset |
DataFrame |
Target dataset to be profiled. |
required |
comparison_dataset |
Optional[pandas.core.frame.DataFrame] |
Optional dataset to be used for data profiles that require a baseline for comparison (e.g. data drift profiles). |
None |
profile_list |
Optional[Sequence[str]] |
Optional list identifying the categories of whylogs data profiles to be generated (unused). |
None |
dataset_timestamp |
Optional[datetime.datetime] |
timestamp to associate with the generated dataset profile (Optional). The current time is used if not supplied. |
None |
**kwargs |
Any |
Extra keyword arguments (unused). |
{} |
Returns:
Type | Description |
---|---|
DatasetProfileView |
A whylogs profile view object. |
Source code in zenml/integrations/whylogs/data_validators/whylogs_data_validator.py
def data_profiling(
self,
dataset: pd.DataFrame,
comparison_dataset: Optional[pd.DataFrame] = None,
profile_list: Optional[Sequence[str]] = None,
dataset_timestamp: Optional[datetime.datetime] = None,
**kwargs: Any,
) -> DatasetProfileView:
"""Analyze a dataset and generate a data profile with whylogs.
Args:
dataset: Target dataset to be profiled.
comparison_dataset: Optional dataset to be used for data profiles
that require a baseline for comparison (e.g. data drift profiles).
profile_list: Optional list identifying the categories of whylogs
data profiles to be generated (unused).
dataset_timestamp: timestamp to associate with the generated
dataset profile (Optional). The current time is used if not
supplied.
**kwargs: Extra keyword arguments (unused).
Returns:
A whylogs profile view object.
"""
results = why.log(pandas=dataset)
profile = results.profile()
dataset_timestamp = dataset_timestamp or datetime.datetime.utcnow()
profile.set_dataset_timestamp(dataset_timestamp=dataset_timestamp)
return profile.view()
prepare_step_run(self, info)
Configures Whylabs logging.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
info |
StepRunInfo |
Info about the step that will be executed. |
required |
Source code in zenml/integrations/whylogs/data_validators/whylogs_data_validator.py
def prepare_step_run(self, info: "StepRunInfo") -> None:
"""Configures Whylabs logging.
Args:
info: Info about the step that will be executed.
"""
settings = cast(
WhylogsDataValidatorSettings,
self.get_settings(info) or WhylogsDataValidatorSettings(),
)
if settings.enable_whylabs:
os.environ[WHYLABS_LOGGING_ENABLED_ENV] = "true"
if settings.dataset_id:
os.environ[WHYLABS_DATASET_ID_ENV] = settings.dataset_id
upload_profile_view(self, profile_view, dataset_id=None)
Upload a whylogs data profile view to Whylabs, if configured to do so.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
profile_view |
DatasetProfileView |
Whylogs profile view to upload. |
required |
dataset_id |
Optional[str] |
Optional dataset identifier to use for the uploaded data profile. If omitted, a dataset identifier will be retrieved using other means, in order: * the default dataset identifier configured in the Data Validator secret * a dataset ID will be generated automatically based on the current pipeline/step information. |
None |
Exceptions:
Type | Description |
---|---|
ValueError |
If the dataset ID was not provided and could not be retrieved or inferred from other sources. |
Source code in zenml/integrations/whylogs/data_validators/whylogs_data_validator.py
def upload_profile_view(
self,
profile_view: DatasetProfileView,
dataset_id: Optional[str] = None,
) -> None:
"""Upload a whylogs data profile view to Whylabs, if configured to do so.
Args:
profile_view: Whylogs profile view to upload.
dataset_id: Optional dataset identifier to use for the uploaded
data profile. If omitted, a dataset identifier will be retrieved
using other means, in order:
* the default dataset identifier configured in the Data
Validator secret
* a dataset ID will be generated automatically based on the
current pipeline/step information.
Raises:
ValueError: If the dataset ID was not provided and could not be
retrieved or inferred from other sources.
"""
secret = self.get_authentication_secret(
expected_schema_type=WhylabsSecretSchema
)
if not secret:
return
dataset_id = dataset_id or secret.whylabs_default_dataset_id
if not dataset_id:
# use the current pipeline name and the step name to generate a
# unique dataset name
try:
# get pipeline name and step name
step_env = cast(
StepEnvironment, Environment()[STEP_ENVIRONMENT_NAME]
)
dataset_id = f"{step_env.pipeline_name}_{step_env.step_name}"
except KeyError:
raise ValueError(
"A dataset ID was not specified and could not be "
"generated from the current pipeline and step name."
)
# Instantiate WhyLabs Writer
writer = WhyLabsWriter(
org_id=secret.whylabs_default_org_id,
api_key=secret.whylabs_api_key,
dataset_id=dataset_id,
)
# pass a profile view to the writer's write method
writer.write(profile=profile_view)
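A hedged sketch of calling this method by hand; the dataset ID is a placeholder, and, as the source above shows, the upload is silently skipped when no Whylabs authentication secret is configured:

```python
import pandas as pd
import whylogs as why

from zenml.integrations.whylogs.data_validators import WhylogsDataValidator

# Build a profile view with plain whylogs...
df = pd.DataFrame({"feature": [1.0, 2.0, 3.0]})
profile_view = why.log(pandas=df).profile().view()

# ...and upload it explicitly. Without a Whylabs authentication secret
# on the data validator, the call returns without uploading anything.
data_validator = WhylogsDataValidator.get_active_data_validator()
data_validator.upload_profile_view(profile_view, dataset_id="model-1")
```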
flavors
special
WhyLabs whylogs integration flavors.
whylogs_data_validator_flavor
WhyLabs whylogs data validator flavor.
WhylogsDataValidatorConfig (BaseDataValidatorConfig, AuthenticationConfigMixin)
pydantic-model
Config for the whylogs data validator.
Source code in zenml/integrations/whylogs/flavors/whylogs_data_validator_flavor.py
class WhylogsDataValidatorConfig(
BaseDataValidatorConfig, AuthenticationConfigMixin
):
"""Config for the whylogs data validator."""
WhylogsDataValidatorFlavor (BaseDataValidatorFlavor)
Whylogs data validator flavor.
Source code in zenml/integrations/whylogs/flavors/whylogs_data_validator_flavor.py
class WhylogsDataValidatorFlavor(BaseDataValidatorFlavor):
"""Whylogs data validator flavor."""
@property
def name(self) -> str:
"""Name of the flavor.
Returns:
The name of the flavor.
"""
return WHYLOGS_DATA_VALIDATOR_FLAVOR
@property
def config_class(self) -> Type[WhylogsDataValidatorConfig]:
"""Returns `WhylogsDataValidatorConfig` config class.
Returns:
The config class.
"""
return WhylogsDataValidatorConfig
@property
def implementation_class(self) -> Type["WhylogsDataValidator"]:
"""Implementation class for this flavor.
Returns:
The implementation class.
"""
from zenml.integrations.whylogs.data_validators import (
WhylogsDataValidator,
)
return WhylogsDataValidator
config_class: Type[zenml.integrations.whylogs.flavors.whylogs_data_validator_flavor.WhylogsDataValidatorConfig]
property
readonly
Returns the WhylogsDataValidatorConfig config class.
Returns:
Type | Description |
---|---|
Type[zenml.integrations.whylogs.flavors.whylogs_data_validator_flavor.WhylogsDataValidatorConfig] |
The config class. |
implementation_class: Type[WhylogsDataValidator]
property
readonly
Implementation class for this flavor.
Returns:
Type | Description |
---|---|
Type[WhylogsDataValidator] |
The implementation class. |
name: str
property
readonly
Name of the flavor.
Returns:
Type | Description |
---|---|
str |
The name of the flavor. |
WhylogsDataValidatorSettings (BaseSettings)
pydantic-model
Settings for the Whylogs data validator.
Attributes:
Name | Type | Description |
---|---|---|
enable_whylabs |
bool |
If set to True for a step, all the whylogs data profile views returned by the step will automatically be uploaded to the Whylabs platform if Whylabs credentials are configured. |
dataset_id |
Optional[str] |
Dataset ID to use when uploading profiles to Whylabs. |
Source code in zenml/integrations/whylogs/flavors/whylogs_data_validator_flavor.py
class WhylogsDataValidatorSettings(BaseSettings):
"""Settings for the Whylogs data validator.
Attributes:
enable_whylabs: If set to `True` for a step, all the whylogs data
profile views returned by the step will automatically be uploaded
to the Whylabs platform if Whylabs credentials are configured.
dataset_id: Dataset ID to use when uploading profiles to Whylabs.
"""
enable_whylabs: bool = False
dataset_id: Optional[str] = None
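A minimal sketch of enabling Whylabs uploads for a single step via these settings; the settings key is assumed to follow the usual "<component_type>.<flavor>" pattern, and the dataset ID is a placeholder:

```python
import pandas as pd
import whylogs as why
from whylogs.core import DatasetProfileView

from zenml.steps import step
from zenml.integrations.whylogs.flavors import WhylogsDataValidatorSettings

whylogs_settings = WhylogsDataValidatorSettings(
    enable_whylabs=True, dataset_id="model-1"
)

@step(settings={"data_validator.whylogs": whylogs_settings})
def profiled_step(df: pd.DataFrame) -> DatasetProfileView:
    # Given configured Whylabs credentials, the returned profile view is
    # uploaded to Whylabs after the step runs.
    return why.log(pandas=df).profile().view()
```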
materializers
special
Initialization of the whylogs materializer.
whylogs_materializer
Implementation of the whylogs materializer.
WhylogsMaterializer (BaseMaterializer)
Materializer to read/write whylogs dataset profile views.
Source code in zenml/integrations/whylogs/materializers/whylogs_materializer.py
class WhylogsMaterializer(BaseMaterializer):
"""Materializer to read/write whylogs dataset profile views."""
ASSOCIATED_TYPES = (DatasetProfileView,)
ASSOCIATED_ARTIFACT_TYPES = (StatisticsArtifact,)
def handle_input(self, data_type: Type[Any]) -> DatasetProfileView:
"""Reads and returns a whylogs dataset profile view.
Args:
data_type: The type of the data to read.
Returns:
A loaded whylogs dataset profile view object.
"""
super().handle_input(data_type)
filepath = os.path.join(self.artifact.uri, PROFILE_FILENAME)
# Create a temporary folder
temp_dir = tempfile.mkdtemp(prefix="zenml-temp-")
temp_file = os.path.join(str(temp_dir), PROFILE_FILENAME)
# Copy from artifact store to temporary file
fileio.copy(filepath, temp_file)
profile_view = DatasetProfileView.read(temp_file)
# Cleanup and return
fileio.rmtree(temp_dir)
return profile_view
def handle_return(self, profile_view: DatasetProfileView) -> None:
"""Writes a whylogs dataset profile view.
Args:
profile_view: A whylogs dataset profile view object.
"""
super().handle_return(profile_view)
filepath = os.path.join(self.artifact.uri, PROFILE_FILENAME)
# Create a temporary folder
temp_dir = tempfile.mkdtemp(prefix="zenml-temp-")
temp_file = os.path.join(str(temp_dir), PROFILE_FILENAME)
profile_view.write(temp_file)
# Copy it into artifact store
fileio.copy(temp_file, filepath)
fileio.rmtree(temp_dir)
# Use the data validator to upload the profile view to Whylabs,
# if configured to do so. This logic is only enabled if the pipeline
# step was decorated with the `enable_whylabs` decorator
whylabs_enabled = os.environ.get(WHYLABS_LOGGING_ENABLED_ENV)
if not whylabs_enabled:
return
dataset_id = os.environ.get(WHYLABS_DATASET_ID_ENV)
data_validator = cast(
WhylogsDataValidator,
WhylogsDataValidator.get_active_data_validator(),
)
data_validator.upload_profile_view(profile_view, dataset_id=dataset_id)
handle_input(self, data_type)
Reads and returns a whylogs dataset profile view.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
data_type |
Type[Any] |
The type of the data to read. |
required |
Returns:
Type | Description |
---|---|
DatasetProfileView |
A loaded whylogs dataset profile view object. |
Source code in zenml/integrations/whylogs/materializers/whylogs_materializer.py
def handle_input(self, data_type: Type[Any]) -> DatasetProfileView:
"""Reads and returns a whylogs dataset profile view.
Args:
data_type: The type of the data to read.
Returns:
A loaded whylogs dataset profile view object.
"""
super().handle_input(data_type)
filepath = os.path.join(self.artifact.uri, PROFILE_FILENAME)
# Create a temporary folder
temp_dir = tempfile.mkdtemp(prefix="zenml-temp-")
temp_file = os.path.join(str(temp_dir), PROFILE_FILENAME)
# Copy from artifact store to temporary file
fileio.copy(filepath, temp_file)
profile_view = DatasetProfileView.read(temp_file)
# Cleanup and return
fileio.rmtree(temp_dir)
return profile_view
handle_return(self, profile_view)
Writes a whylogs dataset profile view.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
profile_view |
DatasetProfileView |
A whylogs dataset profile view object. |
required |
Source code in zenml/integrations/whylogs/materializers/whylogs_materializer.py
def handle_return(self, profile_view: DatasetProfileView) -> None:
"""Writes a whylogs dataset profile view.
Args:
profile_view: A whylogs dataset profile view object.
"""
super().handle_return(profile_view)
filepath = os.path.join(self.artifact.uri, PROFILE_FILENAME)
# Create a temporary folder
temp_dir = tempfile.mkdtemp(prefix="zenml-temp-")
temp_file = os.path.join(str(temp_dir), PROFILE_FILENAME)
profile_view.write(temp_file)
# Copy it into artifact store
fileio.copy(temp_file, filepath)
fileio.rmtree(temp_dir)
# Use the data validator to upload the profile view to Whylabs,
# if configured to do so. This logic is only enabled if the pipeline
# step was decorated with the `enable_whylabs` decorator
whylabs_enabled = os.environ.get(WHYLABS_LOGGING_ENABLED_ENV)
if not whylabs_enabled:
return
dataset_id = os.environ.get(WHYLABS_DATASET_ID_ENV)
data_validator = cast(
WhylogsDataValidator,
WhylogsDataValidator.get_active_data_validator(),
)
data_validator.upload_profile_view(profile_view, dataset_id=dataset_id)
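In practice this materializer runs implicitly whenever a step returns a DatasetProfileView. A sketch of loading a materialized profile back after a run, with illustrative pipeline and step names:

```python
from zenml.post_execution import get_pipeline

# Load a previously materialized profile view back from the artifact store.
run = get_pipeline("data_profiling_pipeline").runs[-1]
profile_view = run.get_step("data_profiler").output.read()
```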
secret_schemas
special
Initialization for the Whylabs secret schema.
This schema can be used to configure a ZenML secret that authenticates ZenML with the Whylabs platform, so that all whylogs data profiles generated by pipeline steps can be logged there automatically.
whylabs_secret_schema
Implementation of the Whylabs secret schema.
WhylabsSecretSchema (BaseSecretSchema)
pydantic-model
Whylabs credentials.
Attributes:
Name | Type | Description |
---|---|---|
whylabs_default_org_id |
str |
the Whylabs organization ID. |
whylabs_api_key |
str |
Whylabs API key. |
whylabs_default_dataset_id |
Optional[str] |
default Whylabs dataset ID to use when logging data profiles. |
Source code in zenml/integrations/whylogs/secret_schemas/whylabs_secret_schema.py
class WhylabsSecretSchema(BaseSecretSchema):
"""Whylabs credentials.
Attributes:
whylabs_default_org_id: the Whylabs organization ID.
whylabs_api_key: Whylabs API key.
whylabs_default_dataset_id: default Whylabs dataset ID to use when
logging data profiles.
"""
TYPE: ClassVar[str] = WHYLABS_SECRET_SCHEMA_TYPE
whylabs_default_org_id: str
whylabs_api_key: str
whylabs_default_dataset_id: Optional[str] = None
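A sketch of the fields such a secret carries; in practice the secret is registered with the active secrets manager rather than constructed by hand, and all values below are placeholders:

```python
from zenml.integrations.whylogs.secret_schemas import WhylabsSecretSchema

secret = WhylabsSecretSchema(
    name="whylabs_secret",  # assumption: the base secret schema keys secrets by name
    whylabs_default_org_id="org-xxxxxxx",
    whylabs_api_key="<your-whylabs-api-key>",
    whylabs_default_dataset_id="model-1",
)
```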
steps
special
Initialization of the whylogs steps.
whylogs_profiler
Implementation of the whylogs profiler step.
WhylogsProfilerParameters (BaseAnalyzerParameters)
pydantic-model
Parameters class for the WhylogsProfiler step.
Attributes:
Name | Type | Description |
---|---|---|
dataset_timestamp |
Optional[datetime.datetime] |
timestamp to associate with the generated dataset profile (Optional). The current time is used if not supplied. |
Source code in zenml/integrations/whylogs/steps/whylogs_profiler.py
class WhylogsProfilerParameters(BaseAnalyzerParameters):
"""Parameters class for the WhylogsProfiler step.
Attributes:
dataset_timestamp: timestamp to associate with the generated
dataset profile (Optional). The current time is used if not
supplied.
"""
dataset_timestamp: Optional[datetime.datetime]
WhylogsProfilerStep (BaseAnalyzerStep)
Generates a whylogs data profile from a given pd.DataFrame.
Source code in zenml/integrations/whylogs/steps/whylogs_profiler.py
class WhylogsProfilerStep(BaseAnalyzerStep):
"""Generates a whylogs data profile from a given pd.DataFrame."""
@staticmethod
def entrypoint( # type: ignore[override]
dataset: pd.DataFrame,
params: WhylogsProfilerParameters,
) -> DatasetProfileView:
"""Main entrypoint function for the whylogs profiler.
Args:
dataset: pd.DataFrame, the given dataset
params: the parameters of the step
Returns:
whylogs profile with statistics generated for the input dataset
"""
data_validator = cast(
WhylogsDataValidator,
WhylogsDataValidator.get_active_data_validator(),
)
return data_validator.data_profiling(
dataset, dataset_timestamp=params.dataset_timestamp
)
PARAMETERS_CLASS (BaseAnalyzerParameters)
pydantic-model
Parameters class for the WhylogsProfiler step.
Attributes:
Name | Type | Description |
---|---|---|
dataset_timestamp |
Optional[datetime.datetime] |
timestamp to associate with the generated dataset profile (Optional). The current time is used if not supplied. |
Source code in zenml/integrations/whylogs/steps/whylogs_profiler.py
class WhylogsProfilerParameters(BaseAnalyzerParameters):
"""Parameters class for the WhylogsProfiler step.
Attributes:
dataset_timestamp: timestamp to associate with the generated
dataset profile (Optional). The current time is used if not
supplied.
"""
dataset_timestamp: Optional[datetime.datetime]
entrypoint(dataset, params)
staticmethod
Main entrypoint function for the whylogs profiler.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
dataset |
DataFrame |
pd.DataFrame, the given dataset |
required |
params |
WhylogsProfilerParameters |
the parameters of the step |
required |
Returns:
Type | Description |
---|---|
DatasetProfileView |
whylogs profile with statistics generated for the input dataset |
Source code in zenml/integrations/whylogs/steps/whylogs_profiler.py
@staticmethod
def entrypoint( # type: ignore[override]
dataset: pd.DataFrame,
params: WhylogsProfilerParameters,
) -> DatasetProfileView:
"""Main entrypoint function for the whylogs profiler.
Args:
dataset: pd.DataFrame, the given dataset
params: the parameters of the step
Returns:
whylogs profile with statistics generated for the input dataset
"""
data_validator = cast(
WhylogsDataValidator,
WhylogsDataValidator.get_active_data_validator(),
)
return data_validator.data_profiling(
dataset, dataset_timestamp=params.dataset_timestamp
)
whylogs_profiler_step(step_name, params, dataset_id=None)
Shortcut function to create a new instance of the WhylogsProfilerStep step.
The returned WhylogsProfilerStep can be used in a pipeline to generate a whylogs DatasetProfileView from a given pd.DataFrame and save it as an artifact.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
step_name |
str |
The name of the step |
required |
params |
WhylogsProfilerParameters |
The step parameters |
required |
dataset_id |
Optional[str] |
Optional dataset ID to use to upload the profile to Whylabs. |
None |
Returns:
Type | Description |
---|---|
BaseStep |
a WhylogsProfilerStep step instance |
Source code in zenml/integrations/whylogs/steps/whylogs_profiler.py
def whylogs_profiler_step(
step_name: str,
params: WhylogsProfilerParameters,
dataset_id: Optional[str] = None,
) -> BaseStep:
"""Shortcut function to create a new instance of the WhylogsProfilerStep step.
The returned WhylogsProfilerStep can be used in a pipeline to generate a
whylogs DatasetProfileView from a given pd.DataFrame and save it as an
artifact.
Args:
step_name: The name of the step
params: The step parameters
dataset_id: Optional dataset ID to use to upload the profile to Whylabs.
Returns:
a WhylogsProfilerStep step instance
"""
step_class = clone_step(WhylogsProfilerStep, step_name)
step_instance = step_class(params=params)
key = settings_utils.get_flavor_setting_key(WhylogsDataValidatorFlavor())
settings = WhylogsDataValidatorSettings(
enable_whylabs=True, dataset_id=dataset_id
)
step_instance.configure(settings={key: settings})
return step_instance
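Usage sketch, with an illustrative step name and dataset ID:

```python
from zenml.integrations.whylogs.steps import (
    WhylogsProfilerParameters,
    whylogs_profiler_step,
)

# Create a named profiler step; the shortcut enables Whylabs uploads for
# the step and tags profiles with the given (placeholder) dataset ID.
data_profiler = whylogs_profiler_step(
    step_name="data_profiler",
    params=WhylogsProfilerParameters(),
    dataset_id="model-1",
)
```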
visualizers
special
Initialization of the whylogs visualizer.
whylogs_visualizer
Implementation of the whylogs visualizer step.
WhylogsVisualizer (BaseVisualizer)
The implementation of a Whylogs Visualizer.
Source code in zenml/integrations/whylogs/visualizers/whylogs_visualizer.py
class WhylogsVisualizer(BaseVisualizer):
"""The implementation of a Whylogs Visualizer."""
def visualize(
self,
object: StepView,
reference_step_view: Optional[StepView] = None,
*args: Any,
**kwargs: Any,
) -> None:
"""Visualize whylogs dataset profiles present as outputs in the step view.
Args:
object: StepView fetched from run.get_step().
reference_step_view: second StepView fetched from run.get_step() to
use as a reference to visualize data drift
*args: additional positional arguments to pass to the visualize
method
**kwargs: additional keyword arguments to pass to the visualize
method
"""
def extract_profile(
step_view: StepView,
) -> Optional[DatasetProfileView]:
"""Extract a whylogs DatasetProfileView from a step view.
Args:
step_view: a step view
Returns:
A whylogs DatasetProfileView object loaded from the step view,
if one could be found, otherwise None.
"""
whylogs_artifact_datatype = (
f"{DatasetProfileView.__module__}.{DatasetProfileView.__name__}"
)
for _, artifact_view in step_view.outputs.items():
# filter out anything but whylogs dataset profile artifacts
if artifact_view.data_type == whylogs_artifact_datatype:
profile = artifact_view.read()
return cast(DatasetProfileView, profile)
return None
profile = extract_profile(object)
reference_profile: Optional[DatasetProfileView] = None
if reference_step_view:
reference_profile = extract_profile(reference_step_view)
self.visualize_profile(profile, reference_profile)
def visualize_profile(
self,
profile: DatasetProfileView,
reference_profile: Optional[DatasetProfileView] = None,
) -> None:
"""Generate a visualization of one or two whylogs dataset profile.
Args:
profile: whylogs DatasetProfileView to visualize
reference_profile: second optional DatasetProfileView to use to
generate a data drift visualization
"""
# currently, whylogs doesn't support visualizing a single profile, so
# we trick it by using the same profile twice, both as reference and
# target, in a drift report
reference_profile = reference_profile or profile
visualization = NotebookProfileVisualizer()
visualization.set_profiles(
target_profile_view=profile,
reference_profile_view=reference_profile,
)
rendered_html = visualization.summary_drift_report()
if Environment.in_notebook():
from IPython.core.display import display
display(rendered_html)
for column in sorted(list(profile.get_columns().keys())):
display(visualization.double_histogram(feature_name=column))
else:
logger.warning(
"The magic functions are only usable in a Jupyter notebook."
)
with tempfile.NamedTemporaryFile(
mode="w", delete=False, suffix=".html", encoding="utf-8"
) as f:
f.write(rendered_html.data)
url = f"file:///{f.name}"
logger.info("Opening %s in a new browser.." % f.name)
webbrowser.open(url, new=2)
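Usage sketch (pipeline and step names are illustrative): fetch a step view from the latest run and visualize its profile.

```python
from zenml.integrations.whylogs.visualizers import WhylogsVisualizer
from zenml.post_execution import get_pipeline

# Fetch the profiling step of the latest run and render its profile
# (inline in a notebook, otherwise in a new browser tab).
last_run = get_pipeline("data_profiling_pipeline").runs[-1]
WhylogsVisualizer().visualize(last_run.get_step("data_profiler"))
```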
visualize(self, object, reference_step_view=None, *args, **kwargs)
Visualize whylogs dataset profiles present as outputs in the step view.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
object |
StepView |
StepView fetched from run.get_step(). |
required |
reference_step_view |
Optional[zenml.post_execution.step.StepView] |
second StepView fetched from run.get_step() to use as a reference to visualize data drift |
None |
*args |
Any |
additional positional arguments to pass to the visualize method |
() |
**kwargs |
Any |
additional keyword arguments to pass to the visualize method |
{} |
Source code in zenml/integrations/whylogs/visualizers/whylogs_visualizer.py
def visualize(
self,
object: StepView,
reference_step_view: Optional[StepView] = None,
*args: Any,
**kwargs: Any,
) -> None:
"""Visualize whylogs dataset profiles present as outputs in the step view.
Args:
object: StepView fetched from run.get_step().
reference_step_view: second StepView fetched from run.get_step() to
use as a reference to visualize data drift
*args: additional positional arguments to pass to the visualize
method
**kwargs: additional keyword arguments to pass to the visualize
method
"""
def extract_profile(
step_view: StepView,
) -> Optional[DatasetProfileView]:
"""Extract a whylogs DatasetProfileView from a step view.
Args:
step_view: a step view
Returns:
A whylogs DatasetProfileView object loaded from the step view,
if one could be found, otherwise None.
"""
whylogs_artifact_datatype = (
f"{DatasetProfileView.__module__}.{DatasetProfileView.__name__}"
)
for _, artifact_view in step_view.outputs.items():
# filter out anything but whylogs dataset profile artifacts
if artifact_view.data_type == whylogs_artifact_datatype:
profile = artifact_view.read()
return cast(DatasetProfileView, profile)
return None
profile = extract_profile(object)
reference_profile: Optional[DatasetProfileView] = None
if reference_step_view:
reference_profile = extract_profile(reference_step_view)
self.visualize_profile(profile, reference_profile)
visualize_profile(self, profile, reference_profile=None)
Generate a visualization of one or two whylogs dataset profiles.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
profile |
DatasetProfileView |
whylogs DatasetProfileView to visualize |
required |
reference_profile |
Optional[whylogs.core.view.dataset_profile_view.DatasetProfileView] |
second optional DatasetProfileView to use to generate a data drift visualization |
None |
Source code in zenml/integrations/whylogs/visualizers/whylogs_visualizer.py
def visualize_profile(
self,
profile: DatasetProfileView,
reference_profile: Optional[DatasetProfileView] = None,
) -> None:
"""Generate a visualization of one or two whylogs dataset profile.
Args:
profile: whylogs DatasetProfileView to visualize
reference_profile: second optional DatasetProfileView to use to
generate a data drift visualization
"""
# currently, whylogs doesn't support visualizing a single profile, so
# we trick it by using the same profile twice, both as reference and
# target, in a drift report
reference_profile = reference_profile or profile
visualization = NotebookProfileVisualizer()
visualization.set_profiles(
target_profile_view=profile,
reference_profile_view=reference_profile,
)
rendered_html = visualization.summary_drift_report()
if Environment.in_notebook():
from IPython.core.display import display
display(rendered_html)
for column in sorted(list(profile.get_columns().keys())):
display(visualization.double_histogram(feature_name=column))
else:
logger.warning(
"The magic functions are only usable in a Jupyter notebook."
)
with tempfile.NamedTemporaryFile(
mode="w", delete=False, suffix=".html", encoding="utf-8"
) as f:
f.write(rendered_html.data)
url = f"file:///{f.name}"
logger.info("Opening %s in a new browser.." % f.name)
webbrowser.open(url, new=2)
xgboost
special
Initialization of the XGBoost integration.
XgboostIntegration (Integration)
Definition of xgboost integration for ZenML.
Source code in zenml/integrations/xgboost/__init__.py
class XgboostIntegration(Integration):
"""Definition of xgboost integration for ZenML."""
NAME = XGBOOST
REQUIREMENTS = ["xgboost>=1.0.0"]
@classmethod
def activate(cls) -> None:
"""Activates the integration."""
from zenml.integrations.xgboost import materializers # noqa
activate()
classmethod
Activates the integration.
Source code in zenml/integrations/xgboost/__init__.py
@classmethod
def activate(cls) -> None:
"""Activates the integration."""
from zenml.integrations.xgboost import materializers # noqa
materializers
special
Initialization of the XGBoost materializers.
xgboost_booster_materializer
Implementation of an XGBoost booster materializer.
XgboostBoosterMaterializer (BaseMaterializer)
Materializer to read data to and from xgboost.Booster.
Source code in zenml/integrations/xgboost/materializers/xgboost_booster_materializer.py
class XgboostBoosterMaterializer(BaseMaterializer):
"""Materializer to read data to and from xgboost.Booster."""
ASSOCIATED_TYPES = (xgb.Booster,)
ASSOCIATED_ARTIFACT_TYPES = (ModelArtifact,)
def handle_input(self, data_type: Type[Any]) -> xgb.Booster:
"""Reads a xgboost Booster model from a serialized JSON file.
Args:
data_type: A xgboost Booster type.
Returns:
A xgboost Booster object.
"""
super().handle_input(data_type)
filepath = os.path.join(self.artifact.uri, DEFAULT_FILENAME)
# Create a temporary folder
temp_dir = tempfile.mkdtemp(prefix="zenml-temp-")
temp_file = os.path.join(str(temp_dir), DEFAULT_FILENAME)
# Copy from artifact store to temporary file
fileio.copy(filepath, temp_file)
booster = xgb.Booster()
booster.load_model(temp_file)
# Cleanup and return
fileio.rmtree(temp_dir)
return booster
def handle_return(self, booster: xgb.Booster) -> None:
"""Creates a JSON serialization for a xgboost Booster model.
Args:
booster: A xgboost Booster model.
"""
super().handle_return(booster)
filepath = os.path.join(self.artifact.uri, DEFAULT_FILENAME)
# Make a temporary phantom artifact
with tempfile.NamedTemporaryFile(
mode="w", suffix=".json", delete=False
) as f:
booster.save_model(f.name)
# Copy it into artifact store
fileio.copy(f.name, filepath)
# Close and remove the temporary file
f.close()
fileio.remove(f.name)
handle_input(self, data_type)
Reads a xgboost Booster model from a serialized JSON file.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
data_type |
Type[Any] |
A xgboost Booster type. |
required |
Returns:
Type | Description |
---|---|
Booster |
A xgboost Booster object. |
Source code in zenml/integrations/xgboost/materializers/xgboost_booster_materializer.py
def handle_input(self, data_type: Type[Any]) -> xgb.Booster:
"""Reads a xgboost Booster model from a serialized JSON file.
Args:
data_type: A xgboost Booster type.
Returns:
A xgboost Booster object.
"""
super().handle_input(data_type)
filepath = os.path.join(self.artifact.uri, DEFAULT_FILENAME)
# Create a temporary folder
temp_dir = tempfile.mkdtemp(prefix="zenml-temp-")
temp_file = os.path.join(str(temp_dir), DEFAULT_FILENAME)
# Copy from artifact store to temporary file
fileio.copy(filepath, temp_file)
booster = xgb.Booster()
booster.load_model(temp_file)
# Cleanup and return
fileio.rmtree(temp_dir)
return booster
handle_return(self, booster)
Creates a JSON serialization for a xgboost Booster model.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
booster |
Booster |
A xgboost Booster model. |
required |
Source code in zenml/integrations/xgboost/materializers/xgboost_booster_materializer.py
def handle_return(self, booster: xgb.Booster) -> None:
"""Creates a JSON serialization for a xgboost Booster model.
Args:
booster: A xgboost Booster model.
"""
super().handle_return(booster)
filepath = os.path.join(self.artifact.uri, DEFAULT_FILENAME)
# Make a temporary phantom artifact
with tempfile.NamedTemporaryFile(
mode="w", suffix=".json", delete=False
) as f:
booster.save_model(f.name)
# Copy it into artifact store
fileio.copy(f.name, filepath)
# Close and remove the temporary file
f.close()
fileio.remove(f.name)
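As a sketch, any step that returns an xgb.Booster is handled by this materializer when the integration is active:

```python
import numpy as np
import xgboost as xgb

from zenml.steps import step

@step
def trainer() -> xgb.Booster:
    # The returned Booster is serialized to JSON by the materializer.
    dtrain = xgb.DMatrix(
        np.array([[1.0, 2.0], [3.0, 4.0]]), label=np.array([0, 1])
    )
    return xgb.train({"objective": "binary:logistic"}, dtrain, num_boost_round=2)
```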
xgboost_dmatrix_materializer
Implementation of the XGBoost dmatrix materializer.
XgboostDMatrixMaterializer (BaseMaterializer)
Materializer to read data to and from xgboost.DMatrix.
Source code in zenml/integrations/xgboost/materializers/xgboost_dmatrix_materializer.py
class XgboostDMatrixMaterializer(BaseMaterializer):
"""Materializer to read data to and from xgboost.DMatrix."""
ASSOCIATED_TYPES = (xgb.DMatrix,)
ASSOCIATED_ARTIFACT_TYPES = (DataArtifact,)
def handle_input(self, data_type: Type[Any]) -> xgb.DMatrix:
"""Reads a xgboost.DMatrix binary file and loads it.
Args:
data_type: The datatype which should be read.
Returns:
Materialized xgboost matrix.
"""
super().handle_input(data_type)
filepath = os.path.join(self.artifact.uri, DEFAULT_FILENAME)
# Create a temporary folder
temp_dir = tempfile.mkdtemp(prefix="zenml-temp-")
temp_file = os.path.join(str(temp_dir), DEFAULT_FILENAME)
# Copy from artifact store to temporary file
fileio.copy(filepath, temp_file)
matrix = xgb.DMatrix(temp_file)
# Cleanup and return
fileio.rmtree(temp_dir)
return matrix
def handle_return(self, matrix: xgb.DMatrix) -> None:
"""Creates a binary serialization for a xgboost.DMatrix object.
Args:
matrix: A xgboost.DMatrix object.
"""
super().handle_return(matrix)
filepath = os.path.join(self.artifact.uri, DEFAULT_FILENAME)
# Make a temporary phantom artifact
with tempfile.NamedTemporaryFile(mode="wb", delete=False) as f:
matrix.save_binary(f.name)
# Copy it into artifact store
fileio.copy(f.name, filepath)
# Close and remove the temporary file
f.close()
fileio.remove(f.name)
handle_input(self, data_type)
Reads a xgboost.DMatrix binary file and loads it.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
data_type |
Type[Any] |
The datatype which should be read. |
required |
Returns:
Type | Description |
---|---|
DMatrix |
Materialized xgboost matrix. |
Source code in zenml/integrations/xgboost/materializers/xgboost_dmatrix_materializer.py
def handle_input(self, data_type: Type[Any]) -> xgb.DMatrix:
"""Reads a xgboost.DMatrix binary file and loads it.
Args:
data_type: The datatype which should be read.
Returns:
Materialized xgboost matrix.
"""
super().handle_input(data_type)
filepath = os.path.join(self.artifact.uri, DEFAULT_FILENAME)
# Create a temporary folder
temp_dir = tempfile.mkdtemp(prefix="zenml-temp-")
temp_file = os.path.join(str(temp_dir), DEFAULT_FILENAME)
# Copy from artifact store to temporary file
fileio.copy(filepath, temp_file)
matrix = xgb.DMatrix(temp_file)
# Cleanup and return
fileio.rmtree(temp_dir)
return matrix
handle_return(self, matrix)
Creates a binary serialization for a xgboost.DMatrix object.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
matrix |
DMatrix |
A xgboost.DMatrix object. |
required |
Source code in zenml/integrations/xgboost/materializers/xgboost_dmatrix_materializer.py
def handle_return(self, matrix: xgb.DMatrix) -> None:
"""Creates a binary serialization for a xgboost.DMatrix object.
Args:
matrix: A xgboost.DMatrix object.
"""
super().handle_return(matrix)
filepath = os.path.join(self.artifact.uri, DEFAULT_FILENAME)
# Make a temporary phantom artifact
with tempfile.NamedTemporaryFile(mode="wb", delete=False) as f:
matrix.save_binary(f.name)
# Copy it into artifact store
fileio.copy(f.name, filepath)
# Close and remove the temporary file
f.close()
fileio.remove(f.name)
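Similarly, a sketch of a step whose xgb.DMatrix output is handled by this materializer:

```python
import numpy as np
import xgboost as xgb

from zenml.steps import step

@step
def data_loader() -> xgb.DMatrix:
    # The DMatrix output is stored in xgboost's binary format and can be
    # consumed directly by a downstream training step.
    return xgb.DMatrix(np.array([[1.0, 2.0], [3.0, 4.0]]), label=np.array([0, 1]))
```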