Aws
zenml.integrations.aws
special
Integrates multiple AWS Tools as Stack Components.
The AWS integration provides a way for our users to manage their secrets through AWS and to use the AWS container registry. Additionally, the Sagemaker integration submodule provides a way to run ZenML steps in Sagemaker.
AWSIntegration (Integration)
Definition of AWS integration for ZenML.
Source code in zenml/integrations/aws/__init__.py
class AWSIntegration(Integration):
    """AWS integration for ZenML.

    Declares the integration name, its pinned package requirements and the
    stack component flavors it contributes.
    """

    NAME = AWS
    REQUIREMENTS = ["sagemaker==2.117.0"]

    @classmethod
    def flavors(cls) -> List[Type[Flavor]]:
        """Return the stack component flavors of the AWS integration.

        Returns:
            The flavor classes provided by this integration.
        """
        # The flavor classes are imported inside the method rather than at
        # module level, exactly as in the original listing.
        from zenml.integrations.aws.flavors import (
            AWSContainerRegistryFlavor,
            AWSSecretsManagerFlavor,
            SagemakerOrchestratorFlavor,
            SagemakerStepOperatorFlavor,
        )

        integration_flavors: List[Type[Flavor]] = [
            AWSSecretsManagerFlavor,
            AWSContainerRegistryFlavor,
            SagemakerStepOperatorFlavor,
            SagemakerOrchestratorFlavor,
        ]
        return integration_flavors
flavors()
classmethod
Declare the stack component flavors for the AWS integration.
Returns:
Type | Description |
---|---|
List[Type[zenml.stack.flavor.Flavor]] |
List of stack component flavors for this integration. |
Source code in zenml/integrations/aws/__init__.py
@classmethod
def flavors(cls) -> List[Type[Flavor]]:
    """Return the stack component flavors of the AWS integration.

    Returns:
        The flavor classes provided by this integration.
    """
    # The flavor classes are imported inside the method rather than at
    # module level, exactly as in the original listing.
    from zenml.integrations.aws.flavors import (
        AWSContainerRegistryFlavor,
        AWSSecretsManagerFlavor,
        SagemakerOrchestratorFlavor,
        SagemakerStepOperatorFlavor,
    )

    integration_flavors: List[Type[Flavor]] = [
        AWSSecretsManagerFlavor,
        AWSContainerRegistryFlavor,
        SagemakerStepOperatorFlavor,
        SagemakerOrchestratorFlavor,
    ]
    return integration_flavors
container_registries
special
Initialization of AWS Container Registry integration.
aws_container_registry
Implementation of the AWS container registry integration.
AWSContainerRegistry (BaseContainerRegistry)
Class for AWS Container Registry.
Source code in zenml/integrations/aws/container_registries/aws_container_registry.py
class AWSContainerRegistry(BaseContainerRegistry):
    """Class for AWS Container Registry."""

    @property
    def config(self) -> AWSContainerRegistryConfig:
        """Returns the `AWSContainerRegistryConfig` config.

        Returns:
            The configuration.
        """
        return cast(AWSContainerRegistryConfig, self._config)

    def _get_region(self) -> str:
        """Parses the AWS region from the registry URI.

        Raises:
            RuntimeError: If the region parsing fails due to an invalid URI.

        Returns:
            The region string.
        """
        match = re.fullmatch(
            r".*\.dkr\.ecr\.(.*)\.amazonaws\.com", self.config.uri
        )
        if not match:
            raise RuntimeError(
                f"Unable to parse region from ECR URI {self.config.uri}."
            )
        return match.group(1)

    def prepare_image_push(self, image_name: str) -> None:
        """Logs warning message if trying to push an image for which no repository exists.

        The check is best effort: if the repositories cannot be listed, the
        method returns silently and the push is attempted anyway.

        Args:
            image_name: Name of the docker image that will be pushed.

        Raises:
            ValueError: If the docker image name is invalid.
        """
        try:
            # The API call lives inside the `try` so that a `ClientError`
            # raised while listing repositories is actually handled by the
            # `except` clause below.
            response = boto3.client(
                "ecr", region_name=self._get_region()
            ).describe_repositories()
            repo_uris: List[str] = [
                repository["repositoryUri"]
                for repository in response["repositories"]
            ]
        except (KeyError, ClientError) as e:
            # invalid boto response, let's hope for the best and just push
            logger.debug("Error while trying to fetch ECR repositories: %s", e)
            return

        repo_exists = any(
            image_name.startswith(f"{uri}:") for uri in repo_uris
        )
        if not repo_exists:
            # Escape the URI: it contains dots which would otherwise match
            # any character in the regular expression.
            match = re.search(
                f"{re.escape(self.config.uri)}/(.*):.*", image_name
            )
            if not match:
                raise ValueError(f"Invalid docker image name '{image_name}'.")

            repo_name = match.group(1)
            logger.warning(
                "Amazon ECR requires you to create a repository before you can "
                f"push an image to it. ZenML is trying to push the image "
                f"{image_name} but could only detect the following "
                f"repositories: {repo_uris}. We will try to push anyway, but "
                f"in case it fails you need to create a repository named "
                f"`{repo_name}`."
            )

    @property
    def post_registration_message(self) -> Optional[str]:
        """Optional message printed after the stack component is registered.

        Returns:
            Info message regarding docker repositories in AWS.
        """
        return (
            "Amazon ECR requires you to create a repository before you can "
            "push an image to it. If you want to for example run a pipeline "
            "using our Kubeflow orchestrator, ZenML will automatically build a "
            f"docker image called `{self.config.uri}/zenml-kubeflow:<PIPELINE_NAME>` "
            f"and try to push it. This will fail unless you create the "
            f"repository `zenml-kubeflow` inside your amazon registry."
        )
config: AWSContainerRegistryConfig
property
readonly
Returns the `AWSContainerRegistryConfig` config.
Returns:
Type | Description |
---|---|
AWSContainerRegistryConfig |
The configuration. |
post_registration_message: Optional[str]
property
readonly
Optional message printed after the stack component is registered.
Returns:
Type | Description |
---|---|
Optional[str] |
Info message regarding docker repositories in AWS. |
prepare_image_push(self, image_name)
Logs warning message if trying to push an image for which no repository exists.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
image_name |
str |
Name of the docker image that will be pushed. |
required |
Exceptions:
Type | Description |
---|---|
ValueError |
If the docker image name is invalid. |
Source code in zenml/integrations/aws/container_registries/aws_container_registry.py
def prepare_image_push(self, image_name: str) -> None:
    """Logs warning message if trying to push an image for which no repository exists.

    The check is best effort: if the repositories cannot be listed, the
    method returns silently and the push is attempted anyway.

    Args:
        image_name: Name of the docker image that will be pushed.

    Raises:
        ValueError: If the docker image name is invalid.
    """
    try:
        # The API call lives inside the `try` so that a `ClientError` raised
        # while listing repositories is actually handled by the `except`
        # clause below.
        response = boto3.client(
            "ecr", region_name=self._get_region()
        ).describe_repositories()
        repo_uris: List[str] = [
            repository["repositoryUri"]
            for repository in response["repositories"]
        ]
    except (KeyError, ClientError) as e:
        # invalid boto response, let's hope for the best and just push
        logger.debug("Error while trying to fetch ECR repositories: %s", e)
        return

    repo_exists = any(
        image_name.startswith(f"{uri}:") for uri in repo_uris
    )
    if not repo_exists:
        # Escape the URI: it contains dots which would otherwise match any
        # character in the regular expression.
        match = re.search(
            f"{re.escape(self.config.uri)}/(.*):.*", image_name
        )
        if not match:
            raise ValueError(f"Invalid docker image name '{image_name}'.")

        repo_name = match.group(1)
        logger.warning(
            "Amazon ECR requires you to create a repository before you can "
            f"push an image to it. ZenML is trying to push the image "
            f"{image_name} but could only detect the following "
            f"repositories: {repo_uris}. We will try to push anyway, but "
            f"in case it fails you need to create a repository named "
            f"`{repo_name}`."
        )
flavors
special
AWS integration flavors.
aws_container_registry_flavor
AWS container registry flavor.
AWSContainerRegistryConfig (BaseContainerRegistryConfig)
pydantic-model
Configuration for AWS Container Registry.
Source code in zenml/integrations/aws/flavors/aws_container_registry_flavor.py
class AWSContainerRegistryConfig(BaseContainerRegistryConfig):
    """Configuration for AWS Container Registry."""

    @validator("uri")
    def validate_aws_uri(cls, uri: str) -> str:
        """Check that the registry URI does not contain a slash.

        Args:
            uri: URI to validate.

        Returns:
            The unchanged URI when it is valid.

        Raises:
            ValueError: If the URI contains a slash character.
        """
        has_slash = "/" in uri
        if not has_slash:
            return uri

        raise ValueError(
            "Property `uri` can not contain a `/`. An example of a valid "
            "URI is: `715803424592.dkr.ecr.us-east-1.amazonaws.com`"
        )
validate_aws_uri(uri)
classmethod
Validates that the URI is in the correct format.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
uri |
str |
URI to validate. |
required |
Returns:
Type | Description |
---|---|
str |
URI in the correct format. |
Exceptions:
Type | Description |
---|---|
ValueError |
If the URI contains a slash character. |
Source code in zenml/integrations/aws/flavors/aws_container_registry_flavor.py
@validator("uri")
def validate_aws_uri(cls, uri: str) -> str:
    """Check that the registry URI does not contain a slash.

    Args:
        uri: URI to validate.

    Returns:
        The unchanged URI when it is valid.

    Raises:
        ValueError: If the URI contains a slash character.
    """
    has_slash = "/" in uri
    if not has_slash:
        return uri

    raise ValueError(
        "Property `uri` can not contain a `/`. An example of a valid "
        "URI is: `715803424592.dkr.ecr.us-east-1.amazonaws.com`"
    )
AWSContainerRegistryFlavor (BaseContainerRegistryFlavor)
AWS Container Registry flavor.
Source code in zenml/integrations/aws/flavors/aws_container_registry_flavor.py
class AWSContainerRegistryFlavor(BaseContainerRegistryFlavor):
    """Flavor describing the AWS container registry stack component."""

    @property
    def name(self) -> str:
        """Flavor name.

        Returns:
            The value of `AWS_CONTAINER_REGISTRY_FLAVOR`.
        """
        return AWS_CONTAINER_REGISTRY_FLAVOR

    @property
    def docs_url(self) -> Optional[str]:
        """Documentation URL for this flavor.

        Returns:
            The result of `generate_default_docs_url()`.
        """
        default_docs_url = self.generate_default_docs_url()
        return default_docs_url

    @property
    def sdk_docs_url(self) -> Optional[str]:
        """SDK documentation URL for this flavor.

        Returns:
            The result of `generate_default_sdk_docs_url()`.
        """
        default_sdk_docs_url = self.generate_default_sdk_docs_url()
        return default_sdk_docs_url

    @property
    def logo_url(self) -> str:
        """URL of the logo that represents the flavor in the dashboard.

        Returns:
            The flavor logo URL.
        """
        return (
            "https://public-flavor-logos.s3.eu-central-1.amazonaws.com/"
            "container_registry/aws.png"
        )

    @property
    def config_class(self) -> Type[AWSContainerRegistryConfig]:
        """Configuration class used by this flavor.

        Returns:
            The `AWSContainerRegistryConfig` class.
        """
        return AWSContainerRegistryConfig

    @property
    def implementation_class(self) -> Type["AWSContainerRegistry"]:
        """Class implementing the stack component.

        Returns:
            The `AWSContainerRegistry` class.
        """
        # Imported inside the property, as in the original listing.
        from zenml.integrations.aws.container_registries import (
            AWSContainerRegistry,
        )

        return AWSContainerRegistry
config_class: Type[zenml.integrations.aws.flavors.aws_container_registry_flavor.AWSContainerRegistryConfig]
property
readonly
Config class for this flavor.
Returns:
Type | Description |
---|---|
Type[zenml.integrations.aws.flavors.aws_container_registry_flavor.AWSContainerRegistryConfig] |
The config class. |
docs_url: Optional[str]
property
readonly
A url to point at docs explaining this flavor.
Returns:
Type | Description |
---|---|
Optional[str] |
A flavor docs url. |
implementation_class: Type[AWSContainerRegistry]
property
readonly
Implementation class.
Returns:
Type | Description |
---|---|
Type[AWSContainerRegistry] |
The implementation class. |
logo_url: str
property
readonly
A url to represent the flavor in the dashboard.
Returns:
Type | Description |
---|---|
str |
The flavor logo. |
name: str
property
readonly
Name of the flavor.
Returns:
Type | Description |
---|---|
str |
The name of the flavor. |
sdk_docs_url: Optional[str]
property
readonly
A url to point at SDK docs explaining this flavor.
Returns:
Type | Description |
---|---|
Optional[str] |
A flavor SDK docs url. |
aws_secrets_manager_flavor
AWS secrets manager flavor.
AWSSecretsManagerConfig (BaseSecretsManagerConfig)
pydantic-model
Configuration for the AWS Secrets Manager.
Attributes:
Name | Type | Description |
---|---|---|
region_name |
str |
The region name of the AWS Secrets Manager. |
Source code in zenml/integrations/aws/flavors/aws_secrets_manager_flavor.py
class AWSSecretsManagerConfig(BaseSecretsManagerConfig):
    """Configuration for the AWS Secrets Manager.

    Attributes:
        region_name: The region name of the AWS Secrets Manager.
    """

    SUPPORTS_SCOPING: ClassVar[bool] = True

    region_name: str

    @classmethod
    def _validate_scope(
        cls,
        scope: "SecretsManagerScope",
        namespace: Optional[str],
    ) -> None:
        """Validate the namespace that accompanies a scope.

        Only `namespace` is checked here; `scope` is accepted but not
        inspected by this method.

        Args:
            scope: Scope value.
            namespace: Optional namespace value.
        """
        # A missing or empty namespace needs no validation.
        if not namespace:
            return
        validate_aws_secret_name_or_namespace(namespace)
AWSSecretsManagerFlavor (BaseSecretsManagerFlavor)
Class for the `AWSSecretsManagerFlavor`.
Source code in zenml/integrations/aws/flavors/aws_secrets_manager_flavor.py
class AWSSecretsManagerFlavor(BaseSecretsManagerFlavor):
    """Flavor describing the AWS secrets manager stack component."""

    @property
    def name(self) -> str:
        """Flavor name.

        Returns:
            The value of `AWS_SECRET_MANAGER_FLAVOR`.
        """
        return AWS_SECRET_MANAGER_FLAVOR

    @property
    def docs_url(self) -> Optional[str]:
        """Documentation URL for this flavor.

        Returns:
            The result of `generate_default_docs_url()`.
        """
        default_docs_url = self.generate_default_docs_url()
        return default_docs_url

    @property
    def sdk_docs_url(self) -> Optional[str]:
        """SDK documentation URL for this flavor.

        Returns:
            The result of `generate_default_sdk_docs_url()`.
        """
        default_sdk_docs_url = self.generate_default_sdk_docs_url()
        return default_sdk_docs_url

    @property
    def logo_url(self) -> str:
        """URL of the logo that represents the flavor in the dashboard.

        Returns:
            The flavor logo URL.
        """
        return (
            "https://public-flavor-logos.s3.eu-central-1.amazonaws.com/"
            "secrets_managers/aws.png"
        )

    @property
    def config_class(self) -> Type[AWSSecretsManagerConfig]:
        """Configuration class used by this flavor.

        Returns:
            The `AWSSecretsManagerConfig` class.
        """
        return AWSSecretsManagerConfig

    @property
    def implementation_class(self) -> Type["AWSSecretsManager"]:
        """Class implementing the stack component.

        Returns:
            The `AWSSecretsManager` class.
        """
        # Imported inside the property, as in the original listing.
        from zenml.integrations.aws.secrets_managers import AWSSecretsManager

        return AWSSecretsManager
config_class: Type[zenml.integrations.aws.flavors.aws_secrets_manager_flavor.AWSSecretsManagerConfig]
property
readonly
Config class for this flavor.
Returns:
Type | Description |
---|---|
Type[zenml.integrations.aws.flavors.aws_secrets_manager_flavor.AWSSecretsManagerConfig] |
Config class for this flavor. |
docs_url: Optional[str]
property
readonly
A url to point at docs explaining this flavor.
Returns:
Type | Description |
---|---|
Optional[str] |
A flavor docs url. |
implementation_class: Type[AWSSecretsManager]
property
readonly
Implementation class.
Returns:
Type | Description |
---|---|
Type[AWSSecretsManager] |
Implementation class. |
logo_url: str
property
readonly
A url to represent the flavor in the dashboard.
Returns:
Type | Description |
---|---|
str |
The flavor logo. |
name: str
property
readonly
Name of the flavor.
Returns:
Type | Description |
---|---|
str |
Name of the flavor. |
sdk_docs_url: Optional[str]
property
readonly
A url to point at SDK docs explaining this flavor.
Returns:
Type | Description |
---|---|
Optional[str] |
A flavor SDK docs url. |
validate_aws_secret_name_or_namespace(name)
Validate a secret name or namespace.
AWS secret names must contain only alphanumeric characters and the
characters /_+=.@-. The `/` character is only used internally to delimit
scopes.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
name |
str |
the secret name or namespace |
required |
Exceptions:
Type | Description |
---|---|
ValueError |
if the secret name or namespace is invalid |
Source code in zenml/integrations/aws/flavors/aws_secrets_manager_flavor.py
def validate_aws_secret_name_or_namespace(name: str) -> None:
    """Reject secret names or namespaces with unsupported characters.

    AWS secret names must contain only alphanumeric characters and the
    characters /_+=.@-. The `/` character is only used internally to delimit
    scopes, so it is not accepted here.

    Args:
        name: the secret name or namespace

    Raises:
        ValueError: if the secret name or namespace is invalid
    """
    allowed_pattern = r"[a-zA-Z0-9_+=\.@\-]*"
    if re.fullmatch(allowed_pattern, name) is not None:
        return

    raise ValueError(
        f"Invalid secret name or namespace '{name}'. Must contain "
        f"only alphanumeric characters and the characters _+=.@-."
    )
sagemaker_orchestrator_flavor
Amazon SageMaker orchestrator flavor.
SagemakerOrchestratorConfig (BaseOrchestratorConfig, SagemakerOrchestratorSettings)
pydantic-model
Config for the Sagemaker orchestrator.
Attributes:
Name | Type | Description |
---|---|---|
synchronous |
bool |
Whether to run the processing job synchronously or asynchronously. Defaults to False. |
execution_role |
str |
The IAM role ARN to use for the pipeline. |
bucket |
Optional[str] |
Name of the S3 bucket to use for storing artifacts from the job run. If not provided, a default bucket will be created based on the following format: "sagemaker-{region}-{aws-account-id}". |
Source code in zenml/integrations/aws/flavors/sagemaker_orchestrator_flavor.py
class SagemakerOrchestratorConfig(  # type: ignore[misc] # https://github.com/pydantic/pydantic/issues/4173
    BaseOrchestratorConfig, SagemakerOrchestratorSettings
):
    """Configuration of the Sagemaker orchestrator.

    Attributes:
        synchronous: Whether to run the processing job synchronously or
            asynchronously. Defaults to False.
        execution_role: The IAM role ARN to use for the pipeline.
        bucket: Name of the S3 bucket to use for storing artifacts
            from the job run. If not provided, a default bucket will be
            created based on the following format:
            "sagemaker-{region}-{aws-account-id}".
    """

    synchronous: bool = False
    execution_role: str
    bucket: Optional[str] = None

    @property
    def is_remote(self) -> bool:
        """Report whether this stack component runs remotely.

        This designation is used to determine if the stack component can be
        used with a local ZenML database or if it requires a remote ZenML
        server.

        Returns:
            Always True for this config.
        """
        return True
is_remote: bool
property
readonly
Checks if this stack component is running remotely.
This designation is used to determine if the stack component can be used with a local ZenML database or if it requires a remote ZenML server.
Returns:
Type | Description |
---|---|
bool |
True if this config is for a remote component, False otherwise. |
SagemakerOrchestratorFlavor (BaseOrchestratorFlavor)
Flavor for the Sagemaker orchestrator.
Source code in zenml/integrations/aws/flavors/sagemaker_orchestrator_flavor.py
class SagemakerOrchestratorFlavor(BaseOrchestratorFlavor):
    """Flavor describing the Sagemaker orchestrator stack component."""

    @property
    def name(self) -> str:
        """Flavor name.

        Returns:
            The value of `AWS_SAGEMAKER_STEP_OPERATOR_FLAVOR`.
        """
        # NOTE(review): this is the step operator constant, which
        # `SagemakerStepOperatorFlavor.name` returns as well - confirm that
        # sharing the constant between both flavors is intended.
        return AWS_SAGEMAKER_STEP_OPERATOR_FLAVOR

    @property
    def docs_url(self) -> Optional[str]:
        """Documentation URL for this flavor.

        Returns:
            The result of `generate_default_docs_url()`.
        """
        default_docs_url = self.generate_default_docs_url()
        return default_docs_url

    @property
    def sdk_docs_url(self) -> Optional[str]:
        """SDK documentation URL for this flavor.

        Returns:
            The result of `generate_default_sdk_docs_url()`.
        """
        default_sdk_docs_url = self.generate_default_sdk_docs_url()
        return default_sdk_docs_url

    @property
    def logo_url(self) -> str:
        """URL of the logo that represents the flavor in the dashboard.

        Returns:
            The flavor logo URL.
        """
        return (
            "https://public-flavor-logos.s3.eu-central-1.amazonaws.com/"
            "orchestrator/sagemaker.png"
        )

    @property
    def config_class(self) -> Type[SagemakerOrchestratorConfig]:
        """Configuration class used by this flavor.

        Returns:
            The `SagemakerOrchestratorConfig` class.
        """
        return SagemakerOrchestratorConfig

    @property
    def implementation_class(self) -> Type["SagemakerOrchestrator"]:
        """Class implementing the stack component.

        Returns:
            The `SagemakerOrchestrator` class.
        """
        # Imported inside the property, as in the original listing.
        from zenml.integrations.aws.orchestrators import SagemakerOrchestrator

        return SagemakerOrchestrator
config_class: Type[zenml.integrations.aws.flavors.sagemaker_orchestrator_flavor.SagemakerOrchestratorConfig]
property
readonly
Returns SagemakerOrchestratorConfig config class.
Returns:
Type | Description |
---|---|
Type[zenml.integrations.aws.flavors.sagemaker_orchestrator_flavor.SagemakerOrchestratorConfig] |
The config class. |
docs_url: Optional[str]
property
readonly
A url to point at docs explaining this flavor.
Returns:
Type | Description |
---|---|
Optional[str] |
A flavor docs url. |
implementation_class: Type[SagemakerOrchestrator]
property
readonly
Implementation class.
Returns:
Type | Description |
---|---|
Type[SagemakerOrchestrator] |
The implementation class. |
logo_url: str
property
readonly
A url to represent the flavor in the dashboard.
Returns:
Type | Description |
---|---|
str |
The flavor logo. |
name: str
property
readonly
Name of the flavor.
Returns:
Type | Description |
---|---|
str |
The name of the flavor. |
sdk_docs_url: Optional[str]
property
readonly
A url to point at SDK docs explaining this flavor.
Returns:
Type | Description |
---|---|
Optional[str] |
A flavor SDK docs url. |
SagemakerOrchestratorSettings (BaseSettings)
pydantic-model
Settings for the Sagemaker orchestrator.
Attributes:
Name | Type | Description |
---|---|---|
instance_type |
str |
The instance type to use for the processing job. |
processor_role |
Optional[str] |
The IAM role to use for the step execution on a Processor. |
volume_size_in_gb |
int |
The size of the EBS volume to use for the processing job. |
max_runtime_in_seconds |
int |
The maximum runtime in seconds for the processing job. |
processor_tags |
Dict[str, str] |
Tags to apply to the Processor assigned to the step. |
Source code in zenml/integrations/aws/flavors/sagemaker_orchestrator_flavor.py
class SagemakerOrchestratorSettings(BaseSettings):
    """Per-step settings understood by the Sagemaker orchestrator.

    Attributes:
        instance_type: The instance type to use for the processing job.
            Defaults to "ml.t3.medium".
        processor_role: The IAM role to use for the step execution on a
            Processor. Defaults to None.
        volume_size_in_gb: The size of the EBS volume to use for the
            processing job. Defaults to 30.
        max_runtime_in_seconds: The maximum runtime in seconds for the
            processing job. Defaults to 86400.
        processor_tags: Tags to apply to the Processor assigned to the step.
            Defaults to an empty dictionary.
    """

    instance_type: str = "ml.t3.medium"
    processor_role: Optional[str] = None
    volume_size_in_gb: int = 30
    max_runtime_in_seconds: int = 86400
    processor_tags: Dict[str, str] = {}
sagemaker_step_operator_flavor
Amazon SageMaker step operator flavor.
SagemakerStepOperatorConfig (BaseStepOperatorConfig, SagemakerStepOperatorSettings)
pydantic-model
Config for the Sagemaker step operator.
Attributes:
Name | Type | Description |
---|---|---|
role |
str |
The role that has to be assigned to the jobs which are running in Sagemaker. |
bucket |
Optional[str] |
Name of the S3 bucket to use for storing artifacts from the job run. If not provided, a default bucket will be created based on the following format: "sagemaker-{region}-{aws-account-id}". |
Source code in zenml/integrations/aws/flavors/sagemaker_step_operator_flavor.py
class SagemakerStepOperatorConfig(  # type: ignore[misc] # https://github.com/pydantic/pydantic/issues/4173
    BaseStepOperatorConfig, SagemakerStepOperatorSettings
):
    """Configuration of the Sagemaker step operator.

    Attributes:
        role: The role that has to be assigned to the jobs which are
            running in Sagemaker.
        bucket: Name of the S3 bucket to use for storing artifacts
            from the job run. If not provided, a default bucket will be
            created based on the following format:
            "sagemaker-{region}-{aws-account-id}".
    """

    role: str
    bucket: Optional[str] = None

    @property
    def is_remote(self) -> bool:
        """Report whether this stack component runs remotely.

        This designation is used to determine if the stack component can be
        used with a local ZenML database or if it requires a remote ZenML
        server.

        Returns:
            Always True for this config.
        """
        return True
is_remote: bool
property
readonly
Checks if this stack component is running remotely.
This designation is used to determine if the stack component can be used with a local ZenML database or if it requires a remote ZenML server.
Returns:
Type | Description |
---|---|
bool |
True if this config is for a remote component, False otherwise. |
SagemakerStepOperatorFlavor (BaseStepOperatorFlavor)
Flavor for the Sagemaker step operator.
Source code in zenml/integrations/aws/flavors/sagemaker_step_operator_flavor.py
class SagemakerStepOperatorFlavor(BaseStepOperatorFlavor):
    """Flavor describing the Sagemaker step operator stack component."""

    @property
    def name(self) -> str:
        """Flavor name.

        Returns:
            The value of `AWS_SAGEMAKER_STEP_OPERATOR_FLAVOR`.
        """
        return AWS_SAGEMAKER_STEP_OPERATOR_FLAVOR

    @property
    def docs_url(self) -> Optional[str]:
        """Documentation URL for this flavor.

        Returns:
            The result of `generate_default_docs_url()`.
        """
        default_docs_url = self.generate_default_docs_url()
        return default_docs_url

    @property
    def sdk_docs_url(self) -> Optional[str]:
        """SDK documentation URL for this flavor.

        Returns:
            The result of `generate_default_sdk_docs_url()`.
        """
        default_sdk_docs_url = self.generate_default_sdk_docs_url()
        return default_sdk_docs_url

    @property
    def logo_url(self) -> str:
        """URL of the logo that represents the flavor in the dashboard.

        Returns:
            The flavor logo URL.
        """
        return (
            "https://public-flavor-logos.s3.eu-central-1.amazonaws.com/"
            "step_operator/sagemaker.png"
        )

    @property
    def config_class(self) -> Type[SagemakerStepOperatorConfig]:
        """Configuration class used by this flavor.

        Returns:
            The `SagemakerStepOperatorConfig` class.
        """
        return SagemakerStepOperatorConfig

    @property
    def implementation_class(self) -> Type["SagemakerStepOperator"]:
        """Class implementing the stack component.

        Returns:
            The `SagemakerStepOperator` class.
        """
        # Imported inside the property, as in the original listing.
        from zenml.integrations.aws.step_operators import SagemakerStepOperator

        return SagemakerStepOperator
config_class: Type[zenml.integrations.aws.flavors.sagemaker_step_operator_flavor.SagemakerStepOperatorConfig]
property
readonly
Returns SagemakerStepOperatorConfig config class.
Returns:
Type | Description |
---|---|
Type[zenml.integrations.aws.flavors.sagemaker_step_operator_flavor.SagemakerStepOperatorConfig] |
The config class. |
docs_url: Optional[str]
property
readonly
A url to point at docs explaining this flavor.
Returns:
Type | Description |
---|---|
Optional[str] |
A flavor docs url. |
implementation_class: Type[SagemakerStepOperator]
property
readonly
Implementation class.
Returns:
Type | Description |
---|---|
Type[SagemakerStepOperator] |
The implementation class. |
logo_url: str
property
readonly
A url to represent the flavor in the dashboard.
Returns:
Type | Description |
---|---|
str |
The flavor logo. |
name: str
property
readonly
Name of the flavor.
Returns:
Type | Description |
---|---|
str |
The name of the flavor. |
sdk_docs_url: Optional[str]
property
readonly
A url to point at SDK docs explaining this flavor.
Returns:
Type | Description |
---|---|
Optional[str] |
A flavor SDK docs url. |
SagemakerStepOperatorSettings (BaseSettings)
pydantic-model
Settings for the Sagemaker step operator.
Attributes:
Name | Type | Description |
---|---|---|
experiment_name |
Optional[str] |
The name for the experiment to which the job will be associated. If not provided, the job runs would be independent. |
input_data_s3_uri |
Union[str, Dict[str, str]] |
S3 URI where training data is located if not locally, e.g. s3://my-bucket/my-data/train. How data will be made available to the container is configured with estimator_args.input_mode. Two possible input types: - str: S3 location where training data is saved. - Dict[str, str]: (ChannelName, S3Location) which represent channels (e.g. training, validation, testing) where specific parts of the data are saved in S3. |
estimator_args |
Dict[str, Any] |
Arguments that are directly passed to the SageMaker Estimator. See https://sagemaker.readthedocs.io/en/stable/api/training/estimators.html#sagemaker.estimator.Estimator for a full list of arguments. For estimator_args.instance_type, check https://docs.aws.amazon.com/sagemaker/latest/dg/notebooks-available-instance-types.html for a list of available instance types. |
Source code in zenml/integrations/aws/flavors/sagemaker_step_operator_flavor.py
class SagemakerStepOperatorSettings(BaseSettings):
    """Per-step settings understood by the Sagemaker step operator.

    Attributes:
        instance_type: Registered with the deprecation validator declared at
            the bottom of this class.
        experiment_name: The name for the experiment to which the job
            will be associated. If not provided, the job runs would be
            independent.
        input_data_s3_uri: S3 URI where training data is located if not
            locally, e.g. s3://my-bucket/my-data/train. How data will be made
            available to the container is configured with
            estimator_args.input_mode. Two possible input types:
                - str: S3 location where training data is saved.
                - Dict[str, str]: (ChannelName, S3Location) which represent
                    channels (e.g. training, validation, testing) where
                    specific parts of the data are saved in S3.
        estimator_args: Arguments that are directly passed to the SageMaker
            Estimator. See
            https://sagemaker.readthedocs.io/en/stable/api/training/estimators.html#sagemaker.estimator.Estimator
            for a full list of arguments.
            For estimator_args.instance_type, check
            https://docs.aws.amazon.com/sagemaker/latest/dg/notebooks-available-instance-types.html
            for a list of available instance types.
    """

    instance_type: Optional[str] = None
    experiment_name: Optional[str] = None
    input_data_s3_uri: Optional[Union[str, Dict[str, str]]] = None
    estimator_args: Dict[str, Any] = {}

    # Marks `instance_type` as a deprecated attribute of this settings class.
    _deprecation_validator = deprecation_utils.deprecate_pydantic_attributes(
        "instance_type"
    )
orchestrators
special
AWS Sagemaker orchestrator.
sagemaker_orchestrator
Implementation of the SageMaker orchestrator.
SagemakerOrchestrator (ContainerizedOrchestrator)
Orchestrator responsible for running pipelines on Sagemaker.
Source code in zenml/integrations/aws/orchestrators/sagemaker_orchestrator.py
class SagemakerOrchestrator(ContainerizedOrchestrator):
"""Orchestrator responsible for running pipelines on Sagemaker."""
@property
def config(self) -> SagemakerOrchestratorConfig:
"""Returns the `SagemakerOrchestratorConfig` config.
Returns:
The configuration.
"""
return cast(SagemakerOrchestratorConfig, self._config)
@property
def validator(self) -> Optional[StackValidator]:
"""Validates the stack.
In the remote case, checks that the stack contains a container registry,
image builder and only remote components.
Returns:
A `StackValidator` instance.
"""
def _validate_remote_components(
stack: "Stack",
) -> Tuple[bool, str]:
for component in stack.components.values():
if not component.config.is_local:
continue
return False, (
f"The Sagemaker orchestrator runs pipelines remotely, "
f"but the '{component.name}' {component.type.value} is "
"a local stack component and will not be available in "
"the Sagemaker step.\nPlease ensure that you always "
"use non-local stack components with the Sagemaker "
"orchestrator."
)
return True, ""
return StackValidator(
required_components={
StackComponentType.CONTAINER_REGISTRY,
StackComponentType.IMAGE_BUILDER,
},
custom_validation_function=_validate_remote_components,
)
def get_orchestrator_run_id(self) -> str:
"""Returns the run id of the active orchestrator run.
Important: This needs to be a unique ID and return the same value for
all steps of a pipeline run.
Returns:
The orchestrator run id.
Raises:
RuntimeError: If the run id cannot be read from the environment.
"""
try:
return os.environ[ENV_ZENML_SAGEMAKER_RUN_ID]
except KeyError:
raise RuntimeError(
"Unable to read run id from environment variable "
f"{ENV_ZENML_SAGEMAKER_RUN_ID}."
)
@property
def settings_class(self) -> Optional[Type["BaseSettings"]]:
"""Settings class for the Sagemaker orchestrator.
Returns:
The settings class.
"""
return SagemakerOrchestratorSettings
def prepare_or_run_pipeline(
self,
deployment: "PipelineDeploymentResponseModel",
stack: "Stack",
environment: Dict[str, str],
) -> None:
"""Prepares or runs a pipeline on Sagemaker.
Args:
deployment: The deployment to prepare or run.
stack: The stack to run on.
environment: Environment variables to set in the orchestration
environment.
"""
if deployment.schedule:
logger.warning(
"The Sagemaker Orchestrator currently does not support the "
"use of schedules. The `schedule` will be ignored "
"and the pipeline will be run immediately."
)
orchestrator_run_name = get_orchestrator_run_name(
pipeline_name=deployment.pipeline_configuration.name
).replace("_", "-")
session = sagemaker.Session(default_bucket=self.config.bucket)
sagemaker_steps = []
for step_name, step in deployment.step_configurations.items():
image = self.get_image(deployment=deployment, step_name=step_name)
command = StepEntrypointConfiguration.get_entrypoint_command()
arguments = StepEntrypointConfiguration.get_entrypoint_arguments(
step_name=step_name, deployment_id=deployment.id
)
entrypoint = command + arguments
step_settings = cast(
SagemakerOrchestratorSettings, self.get_settings(step)
)
processor_role = (
step_settings.processor_role or self.config.execution_role
)
kwargs = (
{"tags": [step_settings.processor_tags]}
if step_settings.processor_tags
else {}
)
environment[
ENV_ZENML_SAGEMAKER_RUN_ID
] = ExecutionVariables.PIPELINE_EXECUTION_ARN
processor = sagemaker.processing.Processor(
role=processor_role,
image_uri=image,
instance_count=1,
sagemaker_session=session,
instance_type=step_settings.instance_type,
entrypoint=entrypoint,
base_job_name=orchestrator_run_name,
env=environment,
volume_size_in_gb=step_settings.volume_size_in_gb,
max_runtime_in_seconds=step_settings.max_runtime_in_seconds,
**kwargs,
)
sagemaker_step = ProcessingStep(
name=step.config.name,
processor=processor,
depends_on=step.spec.upstream_steps,
)
sagemaker_steps.append(sagemaker_step)
# construct the pipeline from the sagemaker_steps
pipeline = Pipeline(
name=orchestrator_run_name,
steps=sagemaker_steps,
sagemaker_session=session,
)
pipeline.create(role_arn=self.config.execution_role)
pipeline_execution = pipeline.start()
# mainly for testing purposes, we wait for the pipeline to finish
if self.config.synchronous:
logger.info(
"Executing synchronously. Waiting for pipeline to finish..."
)
pipeline_execution.wait(
delay=POLLING_DELAY, max_attempts=MAX_POLLING_ATTEMPTS
)
logger.info("Pipeline completed successfully.")
def _get_region_name(self) -> str:
    """Look up the AWS region of a default Sagemaker session.

    Returns:
        The region name reported by a newly created `sagemaker.Session`.

    Raises:
        RuntimeError: If the session cannot be created or its region
            cannot be read.
    """
    try:
        region = sagemaker.Session().boto_region_name
        return cast(str, region)
    except Exception as error:
        # Any failure is reported as a credentials/configuration problem.
        raise RuntimeError(
            "Unable to get region name. Please ensure that you have "
            "configured your AWS credentials correctly."
        ) from error
def get_pipeline_run_metadata(
    self, run_id: UUID
) -> Dict[str, "MetadataType"]:
    """Collect Sagemaker-specific metadata for a pipeline run.

    The pipeline execution ARN is read from the environment variable named
    by `ENV_ZENML_SAGEMAKER_RUN_ID`. When the AWS region can be resolved, a
    CloudWatch console URL filtered on this execution is added as well.

    Args:
        run_id: The ID of the pipeline run.

    Returns:
        A dictionary of metadata.
    """
    execution_arn = os.environ[ENV_ZENML_SAGEMAKER_RUN_ID]
    metadata: Dict[str, "MetadataType"] = {
        "pipeline_execution_arn": execution_arn,
    }

    try:
        region_name = self._get_region_name()
    except RuntimeError:
        # Without a region no console URL can be built; return the ARN only.
        logger.warning("Unable to get region name from AWS Sagemaker.")
        return metadata

    # The last path segment of the ARN is used in the log stream filter.
    aws_run_id = execution_arn.rsplit("/", 1)[-1]
    orchestrator_logs_url = (
        f"https://{region_name}.console.aws.amazon.com/"
        f"cloudwatch/home?region={region_name}#logsV2:log-groups/log-group"
        f"/$252Faws$252Fsagemaker$252FProcessingJobs$3FlogStreamNameFilter"
        f"$3Dpipelines-{aws_run_id}-"
    )
    metadata[METADATA_ORCHESTRATOR_URL] = Uri(orchestrator_logs_url)
    return metadata
config: SagemakerOrchestratorConfig
property
readonly
Returns the `SagemakerOrchestratorConfig` config.
Returns:
Type | Description |
---|---|
SagemakerOrchestratorConfig |
The configuration. |
settings_class: Optional[Type[BaseSettings]]
property
readonly
Settings class for the Sagemaker orchestrator.
Returns:
Type | Description |
---|---|
Optional[Type[BaseSettings]] |
The settings class. |
validator: Optional[zenml.stack.stack_validator.StackValidator]
property
readonly
Validates the stack.
In the remote case, checks that the stack contains a container registry, image builder and only remote components.
Returns:
Type | Description |
---|---|
Optional[zenml.stack.stack_validator.StackValidator] |
A `StackValidator` instance. |
get_orchestrator_run_id(self)
Returns the run id of the active orchestrator run.
Important: This needs to be a unique ID and return the same value for all steps of a pipeline run.
Returns:
Type | Description |
---|---|
str |
The orchestrator run id. |
Exceptions:
Type | Description |
---|---|
RuntimeError |
If the run id cannot be read from the environment. |
Source code in zenml/integrations/aws/orchestrators/sagemaker_orchestrator.py
def get_orchestrator_run_id(self) -> str:
    """Return the run id of the active orchestrator run.

    Important: This needs to be a unique ID and return the same value for
    all steps of a pipeline run.

    Returns:
        The value of the environment variable named by
        `ENV_ZENML_SAGEMAKER_RUN_ID`.

    Raises:
        RuntimeError: If the run id cannot be read from the environment.
    """
    try:
        run_id = os.environ[ENV_ZENML_SAGEMAKER_RUN_ID]
    except KeyError:
        raise RuntimeError(
            "Unable to read run id from environment variable "
            f"{ENV_ZENML_SAGEMAKER_RUN_ID}."
        )
    return run_id
get_pipeline_run_metadata(self, run_id)
Get general component-specific metadata for a pipeline run.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
run_id |
UUID |
The ID of the pipeline run. |
required |
Returns:
Type | Description |
---|---|
Dict[str, MetadataType] |
A dictionary of metadata. |
Source code in zenml/integrations/aws/orchestrators/sagemaker_orchestrator.py
def get_pipeline_run_metadata(
    self, run_id: UUID
) -> Dict[str, "MetadataType"]:
    """Collect Sagemaker-specific metadata for a pipeline run.

    The pipeline execution ARN is read from the environment variable named
    by `ENV_ZENML_SAGEMAKER_RUN_ID`. When the AWS region can be resolved, a
    CloudWatch console URL filtered on this execution is added as well.

    Args:
        run_id: The ID of the pipeline run.

    Returns:
        A dictionary of metadata.
    """
    execution_arn = os.environ[ENV_ZENML_SAGEMAKER_RUN_ID]
    metadata: Dict[str, "MetadataType"] = {
        "pipeline_execution_arn": execution_arn,
    }

    try:
        region_name = self._get_region_name()
    except RuntimeError:
        # Without a region no console URL can be built; return the ARN only.
        logger.warning("Unable to get region name from AWS Sagemaker.")
        return metadata

    # The last path segment of the ARN is used in the log stream filter.
    aws_run_id = execution_arn.rsplit("/", 1)[-1]
    orchestrator_logs_url = (
        f"https://{region_name}.console.aws.amazon.com/"
        f"cloudwatch/home?region={region_name}#logsV2:log-groups/log-group"
        f"/$252Faws$252Fsagemaker$252FProcessingJobs$3FlogStreamNameFilter"
        f"$3Dpipelines-{aws_run_id}-"
    )
    metadata[METADATA_ORCHESTRATOR_URL] = Uri(orchestrator_logs_url)
    return metadata
prepare_or_run_pipeline(self, deployment, stack, environment)
Prepares or runs a pipeline on Sagemaker.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
deployment |
PipelineDeploymentResponseModel |
The deployment to prepare or run. |
required |
stack |
Stack |
The stack to run on. |
required |
environment |
Dict[str, str] |
Environment variables to set in the orchestration environment. |
required |
Source code in zenml/integrations/aws/orchestrators/sagemaker_orchestrator.py
def prepare_or_run_pipeline(
    self,
    deployment: "PipelineDeploymentResponseModel",
    stack: "Stack",
    environment: Dict[str, str],
) -> None:
    """Build a Sagemaker pipeline from the deployment and start it.

    One `ProcessingStep` is created per step configuration. The resulting
    pipeline is created and started, and the call blocks until completion
    only when the orchestrator is configured as synchronous.

    Args:
        deployment: The deployment to prepare or run.
        stack: The stack to run on.
        environment: Environment variables to set in the orchestration
            environment. The Sagemaker run id entry is added to this
            dictionary in place.
    """
    if deployment.schedule:
        logger.warning(
            "The Sagemaker Orchestrator currently does not support the "
            "use of schedules. The `schedule` will be ignored "
            "and the pipeline will be run immediately."
        )

    raw_run_name = get_orchestrator_run_name(
        pipeline_name=deployment.pipeline_configuration.name
    )
    # Underscores in the generated name are replaced with dashes.
    orchestrator_run_name = raw_run_name.replace("_", "-")
    session = sagemaker.Session(default_bucket=self.config.bucket)

    sagemaker_steps = []
    for step_name, step in deployment.step_configurations.items():
        image = self.get_image(deployment=deployment, step_name=step_name)
        entrypoint = (
            StepEntrypointConfiguration.get_entrypoint_command()
            + StepEntrypointConfiguration.get_entrypoint_arguments(
                step_name=step_name, deployment_id=deployment.id
            )
        )
        step_settings = cast(
            SagemakerOrchestratorSettings, self.get_settings(step)
        )

        # Tags are only forwarded when the step settings define some.
        optional_args = {}
        if step_settings.processor_tags:
            optional_args["tags"] = [step_settings.processor_tags]

        # Every step receives the pipeline execution ARN as its run id.
        environment[
            ENV_ZENML_SAGEMAKER_RUN_ID
        ] = ExecutionVariables.PIPELINE_EXECUTION_ARN

        processor = sagemaker.processing.Processor(
            role=step_settings.processor_role or self.config.execution_role,
            image_uri=image,
            instance_count=1,
            sagemaker_session=session,
            instance_type=step_settings.instance_type,
            entrypoint=entrypoint,
            base_job_name=orchestrator_run_name,
            env=environment,
            volume_size_in_gb=step_settings.volume_size_in_gb,
            max_runtime_in_seconds=step_settings.max_runtime_in_seconds,
            **optional_args,
        )
        sagemaker_steps.append(
            ProcessingStep(
                name=step.config.name,
                processor=processor,
                depends_on=step.spec.upstream_steps,
            )
        )

    # construct the pipeline from the sagemaker_steps
    pipeline = Pipeline(
        name=orchestrator_run_name,
        steps=sagemaker_steps,
        sagemaker_session=session,
    )
    pipeline.create(role_arn=self.config.execution_role)
    execution = pipeline.start()

    if not self.config.synchronous:
        return

    # mainly for testing purposes, we wait for the pipeline to finish
    logger.info(
        "Executing synchronously. Waiting for pipeline to finish..."
    )
    execution.wait(delay=POLLING_DELAY, max_attempts=MAX_POLLING_ATTEMPTS)
    logger.info("Pipeline completed successfully.")
secrets_managers
special
AWS Secrets Manager.
aws_secrets_manager
Implementation of the AWS Secrets Manager integration.
AWSSecretsManager (BaseSecretsManager)
Class to interact with the AWS secrets manager.
Source code in zenml/integrations/aws/secrets_managers/aws_secrets_manager.py
class AWSSecretsManager(BaseSecretsManager):
"""Class to interact with the AWS secrets manager."""
CLIENT: ClassVar[Any] = None
@property
def config(self) -> AWSSecretsManagerConfig:
    """Return the component configuration typed as `AWSSecretsManagerConfig`.

    Returns:
        The configuration.
    """
    typed_config = cast(AWSSecretsManagerConfig, self._config)
    return typed_config
@classmethod
def _ensure_client_connected(cls, region_name: str) -> None:
    """Create the class-level Secrets Manager client if it is not set yet.

    The client is stored on `cls.CLIENT`; once it is set, later calls
    return without creating a new one.

    Args:
        region_name: the AWS region name
    """
    if cls.CLIENT is not None:
        return
    # Create a Secrets Manager client
    cls.CLIENT = boto3.session.Session().client(
        service_name="secretsmanager", region_name=region_name
    )
def _get_secret_tags(
    self, secret: BaseSecretSchema
) -> List[Dict[str, str]]:
    """Convert the metadata of a secret into AWS tag entries.

    Args:
        secret: the secret object

    Returns:
        A list with one `{"Key": ..., "Value": ...}` entry per metadata
        item of the secret.
    """
    tags: List[Dict[str, str]] = []
    for tag_key, tag_value in self._get_secret_metadata(secret).items():
        tags.append({"Key": tag_key, "Value": tag_value})
    return tags
def _get_secret_scope_filters(
    self,
    secret_name: Optional[str] = None,
) -> List[Dict[str, Any]]:
    """Build AWS filters for the whole scope or for one scoped secret.

    These filters can be used when querying the AWS Secrets Manager
    for all secrets or for a single secret available in the configured
    scope. For more information see: https://docs.aws.amazon.com/secretsmanager/latest/userguide/manage_search-secret.html

    Every scope metadata entry contributes two filters, a `tag-key` filter
    followed by a `tag-value` filter.

    Example AWS filters for all secrets in the current (namespace) scope:

    ```python
    [
        {"Key": "tag-key", "Values": ["zenml_scope"]},
        {"Key": "tag-value", "Values": ["namespace"]},
        {"Key": "tag-key", "Values": ["zenml_namespace"]},
        {"Key": "tag-value", "Values": ["my_namespace"]},
    ]
    ```

    Example AWS filters for a particular secret in the current (namespace)
    scope:

    ```python
    [
        {"Key": "tag-key", "Values": ["zenml_secret_name"]},
        {"Key": "tag-value", "Values": ["my_secret"]},
        {"Key": "tag-key", "Values": ["zenml_scope"]},
        {"Key": "tag-value", "Values": ["namespace"]},
        {"Key": "tag-key", "Values": ["zenml_namespace"]},
        {"Key": "tag-value", "Values": ["my_namespace"]},
    ]
    ```

    Args:
        secret_name: Optional secret name to filter for.

    Returns:
        A list of AWS filters uniquely identifying all secrets
        or a named secret within the configured scope.
    """
    scope_metadata = self._get_secret_scope_metadata(secret_name)
    filters: List[Dict[str, Any]] = []
    for tag_key, tag_value in scope_metadata.items():
        filters += [
            {"Key": "tag-key", "Values": [tag_key]},
            # Values are stringified before being used in a filter.
            {"Key": "tag-value", "Values": [str(tag_value)]},
        ]
    return filters
def _list_secrets(self, secret_name: Optional[str] = None) -> List[str]:
    """List the names of the secrets in the current scope.

    This method lists all the secrets in the current scope without loading
    their contents. An optional secret name can be supplied to filter out
    all but a single secret identified by name.

    Args:
        secret_name: Optional secret name to filter for.

    Returns:
        A list of secret names in the current scope and the optional
        secret name.
    """
    self._ensure_client_connected(self.config.region_name)

    prefix: Optional[str]
    filters: List[Dict[str, Any]]
    if self.config.scope == SecretsManagerScope.NONE:
        # unscoped (legacy) secrets don't have tags. We want to filter out
        # non-legacy secrets
        filters = [{"Key": "tag-key", "Values": ["!zenml_scope"]}]
        prefix = secret_name if secret_name else None
    else:
        filters = self._get_secret_scope_filters()
        # add the name prefix to the filters to account for the fact
        # that AWS does not do exact matching but prefix-matching on the
        # filters
        prefix = (
            self._get_scoped_secret_name(secret_name)
            if secret_name
            else self._get_scoped_secret_name_prefix()
        )

    if prefix:
        filters.append({"Key": "name", "Values": [f"{prefix}"]})

    pages = self.CLIENT.get_paginator(_BOTO_CLIENT_LIST_SECRETS).paginate(
        Filters=filters,
        PaginationConfig={
            "PageSize": 100,
        },
    )

    matches = []
    for page in pages:
        for entry in page[_PAGINATOR_RESPONSE_SECRETS_LIST_KEY]:
            unscoped_name = self._get_unscoped_secret_name(entry["Name"])
            # keep only the names that are in scope ...
            if not unscoped_name:
                continue
            # ... and filter by secret name, if one was given
            if secret_name and secret_name != unscoped_name:
                continue
            matches.append(unscoped_name)
    return matches
def register_secret(self, secret: BaseSecretSchema) -> None:
    """Create a new secret in the AWS Secrets Manager.

    The secret contents are stored as a JSON string under the scoped
    secret name, together with the tags derived from the secret.

    Args:
        secret: the secret to register

    Raises:
        SecretExistsError: if the secret already exists
    """
    validate_aws_secret_name_or_namespace(secret.name)
    self._ensure_client_connected(self.config.region_name)

    already_registered = self._list_secrets(secret.name)
    if already_registered:
        raise SecretExistsError(
            f"A Secret with the name {secret.name} already exists"
        )

    payload = json.dumps(secret_to_dict(secret, encode=False))
    scoped_name = self._get_scoped_secret_name(secret.name)
    self.CLIENT.create_secret(
        Name=scoped_name,
        SecretString=payload,
        Tags=self._get_secret_tags(secret),
    )
    logger.debug("Created AWS secret: %s", scoped_name)
def get_secret(self, secret_name: str) -> BaseSecretSchema:
    """Fetch a secret from the AWS Secrets Manager.

    Args:
        secret_name: the name of the secret to get

    Returns:
        The secret.

    Raises:
        KeyError: if the secret does not exist
        RuntimeError: if the AWS response carries no `SecretString` value,
            so the secret cannot be parsed
    """
    validate_aws_secret_name_or_namespace(secret_name)
    self._ensure_client_connected(self.config.region_name)
    if not self._list_secrets(secret_name):
        raise KeyError(f"Can't find the specified secret '{secret_name}'")

    get_secret_value_response = self.CLIENT.get_secret_value(
        SecretId=self._get_scoped_secret_name(secret_name)
    )
    # Fail with an explicit error when there is no string payload, rather
    # than subscripting `None` and surfacing an opaque `TypeError`.
    if "SecretString" not in get_secret_value_response:
        raise RuntimeError(
            f"The AWS secret '{secret_name}' does not contain a "
            "`SecretString` value and cannot be loaded."
        )

    return secret_from_dict(
        json.loads(get_secret_value_response["SecretString"]),
        secret_name=secret_name,
        decode=False,
    )
def get_all_secret_keys(self) -> List[str]:
    """Return the names of all secrets in the current scope.

    Returns:
        A list of all secret keys
    """
    secret_names = self._list_secrets()
    return secret_names
def update_secret(self, secret: BaseSecretSchema) -> None:
    """Overwrite the value of an existing secret.

    Args:
        secret: the secret to update

    Raises:
        KeyError: if the secret does not exist
    """
    validate_aws_secret_name_or_namespace(secret.name)
    self._ensure_client_connected(self.config.region_name)
    if not self._list_secrets(secret.name):
        raise KeyError(f"Can't find the specified secret '{secret.name}'")

    # Serialize exactly like `register_secret` (`encode=False`) so that the
    # stored value matches what `get_secret` reads back with `decode=False`.
    secret_value = json.dumps(secret_to_dict(secret, encode=False))
    self.CLIENT.put_secret_value(
        SecretId=self._get_scoped_secret_name(secret.name),
        SecretString=secret_value,
    )
def delete_secret(self, secret_name: str) -> None:
    """Force-delete a single secret without a recovery window.

    Args:
        secret_name: the name of the secret to delete

    Raises:
        KeyError: if the secret does not exist
    """
    self._ensure_client_connected(self.config.region_name)

    found = self._list_secrets(secret_name)
    if not found:
        raise KeyError(f"Can't find the specified secret '{secret_name}'")

    scoped_id = self._get_scoped_secret_name(secret_name)
    self.CLIENT.delete_secret(
        SecretId=scoped_id,
        ForceDeleteWithoutRecovery=True,
    )
def delete_all_secrets(self) -> None:
    """Force-delete every secret in the current scope.

    This method will force delete all your secrets. You will not be able to
    recover them once this method is called.
    """
    self._ensure_client_connected(self.config.region_name)
    for name in self._list_secrets():
        scoped_id = self._get_scoped_secret_name(name)
        self.CLIENT.delete_secret(
            SecretId=scoped_id,
            ForceDeleteWithoutRecovery=True,
        )
config: AWSSecretsManagerConfig
property
readonly
Returns the `AWSSecretsManagerConfig` config.
Returns:
Type | Description |
---|---|
AWSSecretsManagerConfig |
The configuration. |
delete_all_secrets(self)
Delete all existing secrets.
This method will force delete all your secrets. You will not be able to recover them once this method is called.
Source code in zenml/integrations/aws/secrets_managers/aws_secrets_manager.py
def delete_all_secrets(self) -> None:
    """Force-delete every secret in the current scope.

    This method will force delete all your secrets. You will not be able to
    recover them once this method is called.
    """
    self._ensure_client_connected(self.config.region_name)
    for name in self._list_secrets():
        scoped_id = self._get_scoped_secret_name(name)
        self.CLIENT.delete_secret(
            SecretId=scoped_id,
            ForceDeleteWithoutRecovery=True,
        )
delete_secret(self, secret_name)
Delete an existing secret.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
secret_name |
str |
the name of the secret to delete |
required |
Exceptions:
Type | Description |
---|---|
KeyError |
if the secret does not exist |
Source code in zenml/integrations/aws/secrets_managers/aws_secrets_manager.py
def delete_secret(self, secret_name: str) -> None:
    """Force-delete a single secret without a recovery window.

    Args:
        secret_name: the name of the secret to delete

    Raises:
        KeyError: if the secret does not exist
    """
    self._ensure_client_connected(self.config.region_name)

    found = self._list_secrets(secret_name)
    if not found:
        raise KeyError(f"Can't find the specified secret '{secret_name}'")

    scoped_id = self._get_scoped_secret_name(secret_name)
    self.CLIENT.delete_secret(
        SecretId=scoped_id,
        ForceDeleteWithoutRecovery=True,
    )
get_all_secret_keys(self)
Get all secret keys.
Returns:
Type | Description |
---|---|
List[str] |
A list of all secret keys |
Source code in zenml/integrations/aws/secrets_managers/aws_secrets_manager.py
def get_all_secret_keys(self) -> List[str]:
    """Return the names of all secrets in the current scope.

    Returns:
        A list of all secret keys
    """
    secret_names = self._list_secrets()
    return secret_names
get_secret(self, secret_name)
Gets a secret.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
secret_name |
str |
the name of the secret to get |
required |
Returns:
Type | Description |
---|---|
BaseSecretSchema |
The secret. |
Exceptions:
Type | Description |
---|---|
KeyError |
if the secret does not exist |
Source code in zenml/integrations/aws/secrets_managers/aws_secrets_manager.py
def get_secret(self, secret_name: str) -> BaseSecretSchema:
    """Fetch a secret from the AWS Secrets Manager.

    Args:
        secret_name: the name of the secret to get

    Returns:
        The secret.

    Raises:
        KeyError: if the secret does not exist
        RuntimeError: if the AWS response carries no `SecretString` value,
            so the secret cannot be parsed
    """
    validate_aws_secret_name_or_namespace(secret_name)
    self._ensure_client_connected(self.config.region_name)
    if not self._list_secrets(secret_name):
        raise KeyError(f"Can't find the specified secret '{secret_name}'")

    get_secret_value_response = self.CLIENT.get_secret_value(
        SecretId=self._get_scoped_secret_name(secret_name)
    )
    # Fail with an explicit error when there is no string payload, rather
    # than subscripting `None` and surfacing an opaque `TypeError`.
    if "SecretString" not in get_secret_value_response:
        raise RuntimeError(
            f"The AWS secret '{secret_name}' does not contain a "
            "`SecretString` value and cannot be loaded."
        )

    return secret_from_dict(
        json.loads(get_secret_value_response["SecretString"]),
        secret_name=secret_name,
        decode=False,
    )
register_secret(self, secret)
Registers a new secret.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
secret |
BaseSecretSchema |
the secret to register |
required |
Exceptions:
Type | Description |
---|---|
SecretExistsError |
if the secret already exists |
Source code in zenml/integrations/aws/secrets_managers/aws_secrets_manager.py
def register_secret(self, secret: BaseSecretSchema) -> None:
    """Create a new secret in the AWS Secrets Manager.

    The secret contents are stored as a JSON string under the scoped
    secret name, together with the tags derived from the secret.

    Args:
        secret: the secret to register

    Raises:
        SecretExistsError: if the secret already exists
    """
    validate_aws_secret_name_or_namespace(secret.name)
    self._ensure_client_connected(self.config.region_name)

    already_registered = self._list_secrets(secret.name)
    if already_registered:
        raise SecretExistsError(
            f"A Secret with the name {secret.name} already exists"
        )

    payload = json.dumps(secret_to_dict(secret, encode=False))
    scoped_name = self._get_scoped_secret_name(secret.name)
    self.CLIENT.create_secret(
        Name=scoped_name,
        SecretString=payload,
        Tags=self._get_secret_tags(secret),
    )
    logger.debug("Created AWS secret: %s", scoped_name)
update_secret(self, secret)
Update an existing secret.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
secret |
BaseSecretSchema |
the secret to update |
required |
Exceptions:
Type | Description |
---|---|
KeyError |
if the secret does not exist |
Source code in zenml/integrations/aws/secrets_managers/aws_secrets_manager.py
def update_secret(self, secret: BaseSecretSchema) -> None:
    """Overwrite the value of an existing secret.

    Args:
        secret: the secret to update

    Raises:
        KeyError: if the secret does not exist
    """
    validate_aws_secret_name_or_namespace(secret.name)
    self._ensure_client_connected(self.config.region_name)
    if not self._list_secrets(secret.name):
        raise KeyError(f"Can't find the specified secret '{secret.name}'")

    # Serialize exactly like `register_secret` (`encode=False`) so that the
    # stored value matches what `get_secret` reads back with `decode=False`.
    secret_value = json.dumps(secret_to_dict(secret, encode=False))
    self.CLIENT.put_secret_value(
        SecretId=self._get_scoped_secret_name(secret.name),
        SecretString=secret_value,
    )
step_operators
special
Initialization of the Sagemaker Step Operator.
sagemaker_step_operator
Implementation of the Sagemaker Step Operator.
SagemakerStepOperator (BaseStepOperator)
Step operator to run a step on Sagemaker.
This class defines code that builds an image with the ZenML entrypoint to run using Sagemaker's Estimator.
Source code in zenml/integrations/aws/step_operators/sagemaker_step_operator.py
class SagemakerStepOperator(BaseStepOperator):
"""Step operator to run a step on Sagemaker.
This class defines code that builds an image with the ZenML entrypoint
to run using Sagemaker's Estimator.
"""
@property
def config(self) -> SagemakerStepOperatorConfig:
    """Return the component configuration typed as `SagemakerStepOperatorConfig`.

    Returns:
        The configuration.
    """
    typed_config = cast(SagemakerStepOperatorConfig, self._config)
    return typed_config
@property
def settings_class(self) -> Optional[Type["BaseSettings"]]:
    """Expose the settings class used by the SageMaker step operator.

    Returns:
        The settings class.
    """
    settings_type = SagemakerStepOperatorSettings
    return settings_type
@property
def validator(self) -> Optional[StackValidator]:
    """Build the validator for stacks that use this step operator.

    Returns:
        A validator that checks that the stack contains a remote container
        registry and a remote artifact store.
    """

    def _validate_remote_components(stack: "Stack") -> Tuple[bool, str]:
        # The artifact store is checked first, then the container registry.
        if stack.artifact_store.config.is_local:
            artifact_store_error = (
                "The SageMaker step operator runs code remotely and "
                "needs to write files into the artifact store, but the "
                f"artifact store `{stack.artifact_store.name}` of the "
                "active stack is local. Please ensure that your stack "
                "contains a remote artifact store when using the SageMaker "
                "step operator."
            )
            return False, artifact_store_error

        container_registry = stack.container_registry
        assert container_registry is not None

        if container_registry.config.is_local:
            registry_error = (
                "The SageMaker step operator runs code remotely and "
                "needs to push/pull Docker images, but the "
                f"container registry `{container_registry.name}` of the "
                "active stack is local. Please ensure that your stack "
                "contains a remote container registry when using the "
                "SageMaker step operator."
            )
            return False, registry_error

        return True, ""

    required = {
        StackComponentType.CONTAINER_REGISTRY,
        StackComponentType.IMAGE_BUILDER,
    }
    return StackValidator(
        required_components=required,
        custom_validation_function=_validate_remote_components,
    )
def get_docker_builds(
    self, deployment: "PipelineDeploymentBaseModel"
) -> List["BuildConfiguration"]:
    """Collect the Docker builds for the steps that use this step operator.

    A build is produced only for steps whose configured step operator
    equals this component's name.

    Args:
        deployment: The pipeline deployment for which to get the builds.

    Returns:
        The required Docker builds.
    """
    return [
        BuildConfiguration(
            key=SAGEMAKER_DOCKER_IMAGE_KEY,
            settings=step.config.docker_settings,
            step_name=step_name,
            entrypoint=f"${_ENTRYPOINT_ENV_VARIABLE}",
        )
        for step_name, step in deployment.step_configurations.items()
        if step.config.step_operator == self.name
    ]
def launch(
    self,
    info: "StepRunInfo",
    entrypoint_command: List[str],
    environment: Dict[str, str],
) -> None:
    """Run a step as a SageMaker Estimator job and wait for it to finish.

    Args:
        info: Information about the step run.
        entrypoint_command: Command that executes the step.
        environment: Environment variables to set in the step operator
            environment. The joined entrypoint command is added to this
            dictionary in place.
    """
    if not info.config.resource_settings.empty:
        logger.warning(
            "Specifying custom step resources is not supported for "
            "the SageMaker step operator. If you want to run this step "
            "operator on specific resources, you can do so by configuring "
            "a different instance type like this: "
            "`zenml step-operator update %s "
            "--instance_type=<INSTANCE_TYPE>`",
            self.name,
        )

    image_name = info.get_image(key=SAGEMAKER_DOCKER_IMAGE_KEY)
    # The entrypoint is handed over through an environment variable.
    environment[_ENTRYPOINT_ENV_VARIABLE] = " ".join(entrypoint_command)

    settings = cast(SagemakerStepOperatorSettings, self.get_settings(info))

    # Get and default fill SageMaker estimator arguments for full ZenML support
    estimator_args = settings.estimator_args
    session = sagemaker.Session(default_bucket=self.config.bucket)
    estimator_args.setdefault(
        "instance_type", settings.instance_type or "ml.m5.large"
    )
    # These three entries always override user-provided estimator arguments.
    estimator_args.update(
        {
            "environment": environment,
            "instance_count": 1,
            "sagemaker_session": session,
        }
    )

    # Create Estimator
    estimator = sagemaker.estimator.Estimator(
        image_name, self.config.role, **estimator_args
    )

    # Sagemaker doesn't allow any underscores in job/experiment/trial names
    job_name = f"{info.run_name}-{info.pipeline_step_name}".replace(
        "_", "-"
    )

    # Construct training input object, if necessary
    data_uri = settings.input_data_s3_uri
    if isinstance(data_uri, str):
        inputs = sagemaker.inputs.TrainingInput(s3_data=data_uri)
    elif isinstance(data_uri, dict):
        inputs = {
            channel: sagemaker.inputs.TrainingInput(s3_data=s3_uri)
            for channel, s3_uri in data_uri.items()
        }
    else:
        inputs = None

    experiment_config = (
        {
            "ExperimentName": settings.experiment_name,
            "TrialName": job_name,
        }
        if settings.experiment_name
        else {}
    )

    estimator.fit(
        wait=True,
        inputs=inputs,
        experiment_config=experiment_config,
        job_name=job_name,
    )
config: SagemakerStepOperatorConfig
property
readonly
Returns the `SagemakerStepOperatorConfig` config.
Returns:
Type | Description |
---|---|
SagemakerStepOperatorConfig |
The configuration. |
settings_class: Optional[Type[BaseSettings]]
property
readonly
Settings class for the SageMaker step operator.
Returns:
Type | Description |
---|---|
Optional[Type[BaseSettings]] |
The settings class. |
validator: Optional[zenml.stack.stack_validator.StackValidator]
property
readonly
Validates the stack.
Returns:
Type | Description |
---|---|
Optional[zenml.stack.stack_validator.StackValidator] |
A validator that checks that the stack contains a remote container registry and a remote artifact store. |
get_docker_builds(self, deployment)
Gets the Docker builds required for the component.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
deployment |
PipelineDeploymentBaseModel |
The pipeline deployment for which to get the builds. |
required |
Returns:
Type | Description |
---|---|
List[BuildConfiguration] |
The required Docker builds. |
Source code in zenml/integrations/aws/step_operators/sagemaker_step_operator.py
def get_docker_builds(
    self, deployment: "PipelineDeploymentBaseModel"
) -> List["BuildConfiguration"]:
    """Collect the Docker builds for the steps that use this step operator.

    A build is produced only for steps whose configured step operator
    equals this component's name.

    Args:
        deployment: The pipeline deployment for which to get the builds.

    Returns:
        The required Docker builds.
    """
    return [
        BuildConfiguration(
            key=SAGEMAKER_DOCKER_IMAGE_KEY,
            settings=step.config.docker_settings,
            step_name=step_name,
            entrypoint=f"${_ENTRYPOINT_ENV_VARIABLE}",
        )
        for step_name, step in deployment.step_configurations.items()
        if step.config.step_operator == self.name
    ]
launch(self, info, entrypoint_command, environment)
Launches a step on SageMaker.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
info |
StepRunInfo |
Information about the step run. |
required |
entrypoint_command |
List[str] |
Command that executes the step. |
required |
environment |
Dict[str, str] |
Environment variables to set in the step operator environment. |
required |
Source code in zenml/integrations/aws/step_operators/sagemaker_step_operator.py
def launch(
    self,
    info: "StepRunInfo",
    entrypoint_command: List[str],
    environment: Dict[str, str],
) -> None:
    """Run a step as a SageMaker Estimator job and wait for it to finish.

    Args:
        info: Information about the step run.
        entrypoint_command: Command that executes the step.
        environment: Environment variables to set in the step operator
            environment. The joined entrypoint command is added to this
            dictionary in place.
    """
    if not info.config.resource_settings.empty:
        logger.warning(
            "Specifying custom step resources is not supported for "
            "the SageMaker step operator. If you want to run this step "
            "operator on specific resources, you can do so by configuring "
            "a different instance type like this: "
            "`zenml step-operator update %s "
            "--instance_type=<INSTANCE_TYPE>`",
            self.name,
        )

    image_name = info.get_image(key=SAGEMAKER_DOCKER_IMAGE_KEY)
    # The entrypoint is handed over through an environment variable.
    environment[_ENTRYPOINT_ENV_VARIABLE] = " ".join(entrypoint_command)

    settings = cast(SagemakerStepOperatorSettings, self.get_settings(info))

    # Get and default fill SageMaker estimator arguments for full ZenML support
    estimator_args = settings.estimator_args
    session = sagemaker.Session(default_bucket=self.config.bucket)
    estimator_args.setdefault(
        "instance_type", settings.instance_type or "ml.m5.large"
    )
    # These three entries always override user-provided estimator arguments.
    estimator_args.update(
        {
            "environment": environment,
            "instance_count": 1,
            "sagemaker_session": session,
        }
    )

    # Create Estimator
    estimator = sagemaker.estimator.Estimator(
        image_name, self.config.role, **estimator_args
    )

    # Sagemaker doesn't allow any underscores in job/experiment/trial names
    job_name = f"{info.run_name}-{info.pipeline_step_name}".replace(
        "_", "-"
    )

    # Construct training input object, if necessary
    data_uri = settings.input_data_s3_uri
    if isinstance(data_uri, str):
        inputs = sagemaker.inputs.TrainingInput(s3_data=data_uri)
    elif isinstance(data_uri, dict):
        inputs = {
            channel: sagemaker.inputs.TrainingInput(s3_data=s3_uri)
            for channel, s3_uri in data_uri.items()
        }
    else:
        inputs = None

    experiment_config = (
        {
            "ExperimentName": settings.experiment_name,
            "TrialName": job_name,
        }
        if settings.experiment_name
        else {}
    )

    estimator.fit(
        wait=True,
        inputs=inputs,
        experiment_config=experiment_config,
        job_name=job_name,
    )