Aws
zenml.integrations.aws
special
Integrates multiple AWS Tools as Stack Components.
The AWS integration lets users manage their secrets through AWS Secrets Manager and use Amazon ECR as a container registry. Additionally, the SageMaker submodules provide a step operator for running individual ZenML steps on SageMaker and an orchestrator for running entire pipelines on SageMaker.
AWSIntegration (Integration)
Definition of AWS integration for ZenML.
Source code in zenml/integrations/aws/__init__.py
class AWSIntegration(Integration):
    """ZenML integration that registers the AWS stack component flavors.

    The integration ships flavors for AWS Secrets Manager, the Amazon ECR
    container registry, the SageMaker step operator and the SageMaker
    orchestrator.
    """

    NAME = AWS
    REQUIREMENTS = ["sagemaker==2.117.0"]

    @classmethod
    def flavors(cls) -> List[Type[Flavor]]:
        """Collect the stack component flavors provided by this integration.

        The flavor module is imported inside the method rather than at module
        level (the motivation is not visible here; possibly to keep the
        package importable without its requirements — TODO confirm).

        Returns:
            The flavor classes, ordered as secrets manager, container
            registry, SageMaker step operator and SageMaker orchestrator.
        """
        from zenml.integrations.aws import flavors as aws_flavors

        ordered_flavors: List[Type[Flavor]] = [
            aws_flavors.AWSSecretsManagerFlavor,
            aws_flavors.AWSContainerRegistryFlavor,
            aws_flavors.SagemakerStepOperatorFlavor,
            aws_flavors.SagemakerOrchestratorFlavor,
        ]
        return ordered_flavors
flavors()
classmethod
Declare the stack component flavors for the AWS integration.
Returns:
Type | Description |
---|---|
List[Type[zenml.stack.flavor.Flavor]] |
List of stack component flavors for this integration. |
Source code in zenml/integrations/aws/__init__.py
@classmethod
def flavors(cls) -> List[Type[Flavor]]:
    """Collect the stack component flavors provided by the AWS integration.

    Returns:
        The flavor classes, ordered as secrets manager, container registry,
        SageMaker step operator and SageMaker orchestrator.
    """
    from zenml.integrations.aws import flavors as aws_flavors

    ordered_flavors: List[Type[Flavor]] = [
        aws_flavors.AWSSecretsManagerFlavor,
        aws_flavors.AWSContainerRegistryFlavor,
        aws_flavors.SagemakerStepOperatorFlavor,
        aws_flavors.SagemakerOrchestratorFlavor,
    ]
    return ordered_flavors
container_registries
special
Initialization of AWS Container Registry integration.
aws_container_registry
Implementation of the AWS container registry integration.
AWSContainerRegistry (BaseContainerRegistry)
Class for AWS Container Registry.
Source code in zenml/integrations/aws/container_registries/aws_container_registry.py
class AWSContainerRegistry(BaseContainerRegistry):
    """Container registry stack component backed by Amazon ECR."""

    @property
    def config(self) -> AWSContainerRegistryConfig:
        """Returns the `AWSContainerRegistryConfig` config.

        Returns:
            The configuration.
        """
        return cast(AWSContainerRegistryConfig, self._config)

    def _get_region(self) -> str:
        """Parses the AWS region from the registry URI.

        ECR registry URIs have the form
        `<account-id>.dkr.ecr.<region>.amazonaws.com`; registries in the AWS
        China partition additionally end in `.cn`.

        Raises:
            RuntimeError: If the region parsing fails due to an invalid URI.

        Returns:
            The region string.
        """
        match = re.fullmatch(
            r".*\.dkr\.ecr\.(.*)\.amazonaws\.com(?:\.cn)?", self.config.uri
        )
        if not match:
            raise RuntimeError(
                f"Unable to parse region from ECR URI {self.config.uri}."
            )
        return match.group(1)

    def prepare_image_push(self, image_name: str) -> None:
        """Logs a warning if no ECR repository exists for the image.

        Amazon ECR only accepts pushes to repositories that already exist.
        This check is best-effort: if the repositories cannot be listed, it
        is skipped and the push is attempted anyway.

        Args:
            image_name: Name of the docker image that will be pushed.

        Raises:
            ValueError: If the docker image name is invalid.
        """
        from botocore.exceptions import BotoCoreError

        # Resolved outside the `try` so an invalid registry URI still raises.
        region = self._get_region()
        try:
            # The boto call itself must be inside the `try`: it is what raises
            # `ClientError`/`BotoCoreError`. A single `describe_repositories`
            # call only returns the first page of results, so paginate to see
            # every repository in the registry.
            paginator = boto3.client(
                "ecr", region_name=region
            ).get_paginator("describe_repositories")
            repo_uris: List[str] = [
                repository["repositoryUri"]
                for page in paginator.paginate()
                for repository in page["repositories"]
            ]
        except (KeyError, ClientError, BotoCoreError) as e:
            # invalid boto response, let's hope for the best and just push
            logger.debug("Error while trying to fetch ECR repositories: %s", e)
            return

        repo_exists = any(
            image_name.startswith(f"{uri}:") for uri in repo_uris
        )
        if repo_exists:
            return

        # The registry URI contains dots, so escape it before using it as a
        # regex prefix.
        match = re.search(
            f"{re.escape(self.config.uri)}/(.*):.*", image_name
        )
        if not match:
            raise ValueError(f"Invalid docker image name '{image_name}'.")
        repo_name = match.group(1)
        logger.warning(
            "Amazon ECR requires you to create a repository before you can "
            f"push an image to it. ZenML is trying to push the image "
            f"{image_name} but could only detect the following "
            f"repositories: {repo_uris}. We will try to push anyway, but "
            f"in case it fails you need to create a repository named "
            f"`{repo_name}`."
        )

    @property
    def post_registration_message(self) -> Optional[str]:
        """Optional message printed after the stack component is registered.

        Returns:
            Info message regarding docker repositories in AWS.
        """
        return (
            "Amazon ECR requires you to create a repository before you can "
            "push an image to it. If you want to for example run a pipeline "
            "using our Kubeflow orchestrator, ZenML will automatically build a "
            f"docker image called `{self.config.uri}/zenml-kubeflow:<PIPELINE_NAME>` "
            f"and try to push it. This will fail unless you create the "
            f"repository `zenml-kubeflow` inside your amazon registry."
        )
config: AWSContainerRegistryConfig
property
readonly
Returns the `AWSContainerRegistryConfig` config.
Returns:
Type | Description |
---|---|
AWSContainerRegistryConfig |
The configuration. |
post_registration_message: Optional[str]
property
readonly
Optional message printed after the stack component is registered.
Returns:
Type | Description |
---|---|
Optional[str] |
Info message regarding docker repositories in AWS. |
prepare_image_push(self, image_name)
Logs a warning message if trying to push an image for which no repository exists.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
image_name |
str |
Name of the docker image that will be pushed. |
required |
Exceptions:
Type | Description |
---|---|
ValueError |
If the docker image name is invalid. |
Source code in zenml/integrations/aws/container_registries/aws_container_registry.py
def prepare_image_push(self, image_name: str) -> None:
    """Logs a warning if no ECR repository exists for the image.

    Amazon ECR only accepts pushes to repositories that already exist.
    This check is best-effort: if the repositories cannot be listed, it is
    skipped and the push is attempted anyway.

    Args:
        image_name: Name of the docker image that will be pushed.

    Raises:
        ValueError: If the docker image name is invalid.
    """
    from botocore.exceptions import BotoCoreError

    # Resolved outside the `try` so an invalid registry URI still raises.
    region = self._get_region()
    try:
        # The boto call itself must be inside the `try`: it is what raises
        # `ClientError`/`BotoCoreError`. A single `describe_repositories`
        # call only returns the first page of results, so paginate to see
        # every repository in the registry.
        paginator = boto3.client(
            "ecr", region_name=region
        ).get_paginator("describe_repositories")
        repo_uris: List[str] = [
            repository["repositoryUri"]
            for page in paginator.paginate()
            for repository in page["repositories"]
        ]
    except (KeyError, ClientError, BotoCoreError) as e:
        # invalid boto response, let's hope for the best and just push
        logger.debug("Error while trying to fetch ECR repositories: %s", e)
        return

    repo_exists = any(
        image_name.startswith(f"{uri}:") for uri in repo_uris
    )
    if repo_exists:
        return

    # The registry URI contains dots, so escape it before using it as a
    # regex prefix.
    match = re.search(f"{re.escape(self.config.uri)}/(.*):.*", image_name)
    if not match:
        raise ValueError(f"Invalid docker image name '{image_name}'.")
    repo_name = match.group(1)
    logger.warning(
        "Amazon ECR requires you to create a repository before you can "
        f"push an image to it. ZenML is trying to push the image "
        f"{image_name} but could only detect the following "
        f"repositories: {repo_uris}. We will try to push anyway, but "
        f"in case it fails you need to create a repository named "
        f"`{repo_name}`."
    )
flavors
special
AWS integration flavors.
aws_container_registry_flavor
AWS container registry flavor.
AWSContainerRegistryConfig (BaseContainerRegistryConfig)
pydantic-model
Configuration for AWS Container Registry.
Source code in zenml/integrations/aws/flavors/aws_container_registry_flavor.py
class AWSContainerRegistryConfig(BaseContainerRegistryConfig):
    """Configuration of the Amazon ECR container registry component."""

    @validator("uri")
    def validate_aws_uri(cls, uri: str) -> str:
        """Reject registry URIs that contain a path component.

        The URI must be the bare registry host, for example
        `715803424592.dkr.ecr.us-east-1.amazonaws.com`; repository names are
        appended to it when image names are built.

        Args:
            uri: The registry URI to check.

        Returns:
            The URI, unchanged.

        Raises:
            ValueError: If the URI contains a `/` character.
        """
        if "/" not in uri:
            return uri
        raise ValueError(
            "Property `uri` can not contain a `/`. An example of a valid "
            "URI is: `715803424592.dkr.ecr.us-east-1.amazonaws.com`"
        )
validate_aws_uri(uri)
classmethod
Validates that the URI is in the correct format.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
uri |
str |
URI to validate. |
required |
Returns:
Type | Description |
---|---|
str |
URI in the correct format. |
Exceptions:
Type | Description |
---|---|
ValueError |
If the URI contains a slash character. |
Source code in zenml/integrations/aws/flavors/aws_container_registry_flavor.py
@validator("uri")
def validate_aws_uri(cls, uri: str) -> str:
    """Reject registry URIs that contain a path component.

    The URI must be the bare registry host, for example
    `715803424592.dkr.ecr.us-east-1.amazonaws.com`.

    Args:
        uri: The registry URI to check.

    Returns:
        The URI, unchanged.

    Raises:
        ValueError: If the URI contains a `/` character.
    """
    if "/" not in uri:
        return uri
    raise ValueError(
        "Property `uri` can not contain a `/`. An example of a valid "
        "URI is: `715803424592.dkr.ecr.us-east-1.amazonaws.com`"
    )
AWSContainerRegistryFlavor (BaseContainerRegistryFlavor)
AWS Container Registry flavor.
Source code in zenml/integrations/aws/flavors/aws_container_registry_flavor.py
class AWSContainerRegistryFlavor(BaseContainerRegistryFlavor):
    """Flavor that exposes Amazon ECR as a ZenML container registry."""

    @property
    def name(self) -> str:
        """Identifier under which this flavor is registered.

        Returns:
            The value of `AWS_CONTAINER_REGISTRY_FLAVOR`.
        """
        return AWS_CONTAINER_REGISTRY_FLAVOR

    @property
    def config_class(self) -> Type[AWSContainerRegistryConfig]:
        """Pydantic model that validates this flavor's configuration.

        Returns:
            The `AWSContainerRegistryConfig` class.
        """
        return AWSContainerRegistryConfig

    @property
    def implementation_class(self) -> Type["AWSContainerRegistry"]:
        """Stack component class implementing this flavor.

        The implementation module is imported only when this property is
        accessed.

        Returns:
            The `AWSContainerRegistry` class.
        """
        from zenml.integrations.aws import container_registries

        return container_registries.AWSContainerRegistry
config_class: Type[zenml.integrations.aws.flavors.aws_container_registry_flavor.AWSContainerRegistryConfig]
property
readonly
Config class for this flavor.
Returns:
Type | Description |
---|---|
Type[zenml.integrations.aws.flavors.aws_container_registry_flavor.AWSContainerRegistryConfig] |
The config class. |
implementation_class: Type[AWSContainerRegistry]
property
readonly
Implementation class.
Returns:
Type | Description |
---|---|
Type[AWSContainerRegistry] |
The implementation class. |
name: str
property
readonly
Name of the flavor.
Returns:
Type | Description |
---|---|
str |
The name of the flavor. |
aws_secrets_manager_flavor
AWS secrets manager flavor.
AWSSecretsManagerConfig (BaseSecretsManagerConfig)
pydantic-model
Configuration for the AWS Secrets Manager.
Attributes:
Name | Type | Description |
---|---|---|
region_name |
str |
The region name of the AWS Secrets Manager. |
Source code in zenml/integrations/aws/flavors/aws_secrets_manager_flavor.py
class AWSSecretsManagerConfig(BaseSecretsManagerConfig):
    """Configuration of the AWS Secrets Manager stack component.

    Attributes:
        region_name: The AWS region of the Secrets Manager.
    """

    SUPPORTS_SCOPING: ClassVar[bool] = True
    region_name: str

    @classmethod
    def _validate_scope(
        cls,
        scope: "SecretsManagerScope",
        namespace: Optional[str],
    ) -> None:
        """Check that a namespace, if given, is a valid AWS secret name part.

        An empty or missing namespace is accepted without further checks;
        `scope` itself is not inspected here.

        Args:
            scope: Scope value.
            namespace: Optional namespace value.
        """
        if not namespace:
            return
        validate_aws_secret_name_or_namespace(namespace)
AWSSecretsManagerFlavor (BaseSecretsManagerFlavor)
Class for the `AWSSecretsManagerFlavor`.
Source code in zenml/integrations/aws/flavors/aws_secrets_manager_flavor.py
class AWSSecretsManagerFlavor(BaseSecretsManagerFlavor):
    """Flavor that exposes AWS Secrets Manager as a ZenML secrets manager."""

    @property
    def name(self) -> str:
        """Identifier under which this flavor is registered.

        Returns:
            The value of `AWS_SECRET_MANAGER_FLAVOR`.
        """
        return AWS_SECRET_MANAGER_FLAVOR

    @property
    def config_class(self) -> Type[AWSSecretsManagerConfig]:
        """Pydantic model that validates this flavor's configuration.

        Returns:
            The `AWSSecretsManagerConfig` class.
        """
        return AWSSecretsManagerConfig

    @property
    def implementation_class(self) -> Type["AWSSecretsManager"]:
        """Stack component class implementing this flavor.

        The implementation module is imported only when this property is
        accessed.

        Returns:
            The `AWSSecretsManager` class.
        """
        from zenml.integrations.aws import secrets_managers

        return secrets_managers.AWSSecretsManager
config_class: Type[zenml.integrations.aws.flavors.aws_secrets_manager_flavor.AWSSecretsManagerConfig]
property
readonly
Config class for this flavor.
Returns:
Type | Description |
---|---|
Type[zenml.integrations.aws.flavors.aws_secrets_manager_flavor.AWSSecretsManagerConfig] |
Config class for this flavor. |
implementation_class: Type[AWSSecretsManager]
property
readonly
Implementation class.
Returns:
Type | Description |
---|---|
Type[AWSSecretsManager] |
Implementation class. |
name: str
property
readonly
Name of the flavor.
Returns:
Type | Description |
---|---|
str |
Name of the flavor. |
validate_aws_secret_name_or_namespace(name)
Validate a secret name or namespace.
AWS secret names must contain only alphanumeric characters and the characters `/_+=.@-`. The `/` character is only used internally to delimit scopes, so it is rejected in user-provided names and namespaces.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
name |
str |
the secret name or namespace |
required |
Exceptions:
Type | Description |
---|---|
ValueError |
if the secret name or namespace is invalid |
Source code in zenml/integrations/aws/flavors/aws_secrets_manager_flavor.py
def validate_aws_secret_name_or_namespace(name: str) -> None:
    """Check that a secret name or namespace only uses allowed characters.

    AWS secret names may contain alphanumeric characters and `/_+=.@-`.
    Because `/` is reserved internally as the scope delimiter, it is
    rejected here; only alphanumerics and `_+=.@-` are accepted. The empty
    string is accepted.

    Args:
        name: The secret name or namespace to check.

    Raises:
        ValueError: If the name contains any other character.
    """
    allowed_pattern = r"[a-zA-Z0-9_+=\.@\-]*"
    if re.fullmatch(allowed_pattern, name) is not None:
        return
    raise ValueError(
        f"Invalid secret name or namespace '{name}'. Must contain "
        f"only alphanumeric characters and the characters _+=.@-."
    )
sagemaker_orchestrator_flavor
Amazon SageMaker orchestrator flavor.
SagemakerOrchestratorConfig (BaseOrchestratorConfig, SagemakerOrchestratorSettings)
pydantic-model
Config for the Sagemaker orchestrator.
Attributes:
Name | Type | Description |
---|---|---|
synchronous |
bool |
Whether to run the processing job synchronously or asynchronously. Defaults to False. |
execution_role |
str |
The IAM role to use for the pipeline. |
bucket |
Optional[str] |
Name of the S3 bucket to use for storing artifacts from the job run. If not provided, a default bucket will be created based on the following format: "sagemaker-{region}-{aws-account-id}". |
Source code in zenml/integrations/aws/flavors/sagemaker_orchestrator_flavor.py
class SagemakerOrchestratorConfig(  # type: ignore[misc] # https://github.com/pydantic/pydantic/issues/4173
    BaseOrchestratorConfig, SagemakerOrchestratorSettings
):
    """Configuration of the SageMaker orchestrator.

    In addition to the fields below, this config inherits every per-step
    setting from `SagemakerOrchestratorSettings`.

    Attributes:
        synchronous: If True, block until the SageMaker pipeline execution
            finishes. Defaults to False.
        execution_role: IAM role used for the SageMaker pipeline. It is also
            the fallback role for steps without a `processor_role`.
        bucket: S3 bucket for artifacts produced by the job run. When unset,
            a default bucket named "sagemaker-{region}-{aws-account-id}" is
            created.
    """

    synchronous: bool = False
    execution_role: str
    bucket: Optional[str] = None

    @property
    def is_remote(self) -> bool:
        """Report whether this stack component runs remotely.

        ZenML uses this to decide whether the component can be combined with
        a local ZenML database or requires a remote ZenML server.

        Returns:
            Always True, since pipelines execute on SageMaker.
        """
        return True
is_remote: bool
property
readonly
Checks if this stack component is running remotely.
This designation is used to determine if the stack component can be used with a local ZenML database or if it requires a remote ZenML server.
Returns:
Type | Description |
---|---|
bool |
True if this config is for a remote component, False otherwise. |
SagemakerOrchestratorFlavor (BaseOrchestratorFlavor)
Flavor for the Sagemaker orchestrator.
Source code in zenml/integrations/aws/flavors/sagemaker_orchestrator_flavor.py
class SagemakerOrchestratorFlavor(BaseOrchestratorFlavor):
    """Flavor that exposes SageMaker Pipelines as a ZenML orchestrator."""

    @property
    def name(self) -> str:
        """Identifier under which this flavor is registered.

        Returns:
            The name of the flavor.
        """
        # NOTE(review): this reuses the *step operator* flavor constant. If a
        # dedicated orchestrator constant exists (e.g.
        # `AWS_SAGEMAKER_ORCHESTRATOR_FLAVOR`), it is presumably the intended
        # one — confirm both hold the same value before switching.
        return AWS_SAGEMAKER_STEP_OPERATOR_FLAVOR

    @property
    def config_class(self) -> Type[SagemakerOrchestratorConfig]:
        """Pydantic model that validates this flavor's configuration.

        Returns:
            The `SagemakerOrchestratorConfig` class.
        """
        return SagemakerOrchestratorConfig

    @property
    def implementation_class(self) -> Type["SagemakerOrchestrator"]:
        """Stack component class implementing this flavor.

        The implementation module is imported only when this property is
        accessed.

        Returns:
            The `SagemakerOrchestrator` class.
        """
        from zenml.integrations.aws import orchestrators

        return orchestrators.SagemakerOrchestrator
config_class: Type[zenml.integrations.aws.flavors.sagemaker_orchestrator_flavor.SagemakerOrchestratorConfig]
property
readonly
Returns SagemakerOrchestratorConfig config class.
Returns:
Type | Description |
---|---|
Type[zenml.integrations.aws.flavors.sagemaker_orchestrator_flavor.SagemakerOrchestratorConfig] |
The config class. |
implementation_class: Type[SagemakerOrchestrator]
property
readonly
Implementation class.
Returns:
Type | Description |
---|---|
Type[SagemakerOrchestrator] |
The implementation class. |
name: str
property
readonly
Name of the flavor.
Returns:
Type | Description |
---|---|
str |
The name of the flavor. |
SagemakerOrchestratorSettings (BaseSettings)
pydantic-model
Settings for the Sagemaker orchestrator.
Attributes:
Name | Type | Description |
---|---|---|
instance_type |
str |
The instance type to use for the processing job. |
processor_role |
Optional[str] |
The IAM role to use for the step execution on a Processor. |
volume_size_in_gb |
int |
The size of the EBS volume to use for the processing job. |
max_runtime_in_seconds |
int |
The maximum runtime in seconds for the processing job. |
processor_tags |
Dict[str, str] |
Tags to apply to the Processor assigned to the step. |
Source code in zenml/integrations/aws/flavors/sagemaker_orchestrator_flavor.py
class SagemakerOrchestratorSettings(BaseSettings):
    """Per-step settings of the SageMaker orchestrator.

    Attributes:
        instance_type: Instance type of the processing job that runs a step.
        processor_role: IAM role for the step's processing job. When unset,
            the orchestrator's `execution_role` is used instead.
        volume_size_in_gb: Size of the EBS volume of the processing job.
        max_runtime_in_seconds: Maximum runtime of the processing job.
        processor_tags: Tags to apply to the processor assigned to the step.
    """

    instance_type: str = "ml.t3.medium"
    processor_role: Optional[str] = None
    volume_size_in_gb: int = 30
    max_runtime_in_seconds: int = 24 * 60 * 60  # one day
    processor_tags: Dict[str, str] = {}
sagemaker_step_operator_flavor
Amazon SageMaker step operator flavor.
SagemakerStepOperatorConfig (BaseStepOperatorConfig, SagemakerStepOperatorSettings)
pydantic-model
Config for the Sagemaker step operator.
Attributes:
Name | Type | Description |
---|---|---|
role |
str |
The role that has to be assigned to the jobs which are running in Sagemaker. |
bucket |
Optional[str] |
Name of the S3 bucket to use for storing artifacts from the job run. If not provided, a default bucket will be created based on the following format: "sagemaker-{region}-{aws-account-id}". |
Source code in zenml/integrations/aws/flavors/sagemaker_step_operator_flavor.py
class SagemakerStepOperatorConfig(  # type: ignore[misc] # https://github.com/pydantic/pydantic/issues/4173
    BaseStepOperatorConfig, SagemakerStepOperatorSettings
):
    """Configuration of the SageMaker step operator.

    In addition to the fields below, this config inherits every per-step
    setting from `SagemakerStepOperatorSettings`.

    Attributes:
        role: IAM role assigned to the jobs that run in SageMaker.
        bucket: S3 bucket for artifacts produced by the job run. When unset,
            a default bucket named "sagemaker-{region}-{aws-account-id}" is
            created.
    """

    role: str
    bucket: Optional[str] = None

    @property
    def is_remote(self) -> bool:
        """Report whether this stack component runs remotely.

        ZenML uses this to decide whether the component can be combined with
        a local ZenML database or requires a remote ZenML server.

        Returns:
            Always True, since steps execute on SageMaker.
        """
        return True
is_remote: bool
property
readonly
Checks if this stack component is running remotely.
This designation is used to determine if the stack component can be used with a local ZenML database or if it requires a remote ZenML server.
Returns:
Type | Description |
---|---|
bool |
True if this config is for a remote component, False otherwise. |
SagemakerStepOperatorFlavor (BaseStepOperatorFlavor)
Flavor for the Sagemaker step operator.
Source code in zenml/integrations/aws/flavors/sagemaker_step_operator_flavor.py
class SagemakerStepOperatorFlavor(BaseStepOperatorFlavor):
    """Flavor that runs individual ZenML steps as SageMaker jobs."""

    @property
    def name(self) -> str:
        """Identifier under which this flavor is registered.

        Returns:
            The value of `AWS_SAGEMAKER_STEP_OPERATOR_FLAVOR`.
        """
        return AWS_SAGEMAKER_STEP_OPERATOR_FLAVOR

    @property
    def config_class(self) -> Type[SagemakerStepOperatorConfig]:
        """Pydantic model that validates this flavor's configuration.

        Returns:
            The `SagemakerStepOperatorConfig` class.
        """
        return SagemakerStepOperatorConfig

    @property
    def implementation_class(self) -> Type["SagemakerStepOperator"]:
        """Stack component class implementing this flavor.

        The implementation module is imported only when this property is
        accessed.

        Returns:
            The `SagemakerStepOperator` class.
        """
        from zenml.integrations.aws import step_operators

        return step_operators.SagemakerStepOperator
config_class: Type[zenml.integrations.aws.flavors.sagemaker_step_operator_flavor.SagemakerStepOperatorConfig]
property
readonly
Returns SagemakerStepOperatorConfig config class.
Returns:
Type | Description |
---|---|
Type[zenml.integrations.aws.flavors.sagemaker_step_operator_flavor.SagemakerStepOperatorConfig] |
The config class. |
implementation_class: Type[SagemakerStepOperator]
property
readonly
Implementation class.
Returns:
Type | Description |
---|---|
Type[SagemakerStepOperator] |
The implementation class. |
name: str
property
readonly
Name of the flavor.
Returns:
Type | Description |
---|---|
str |
The name of the flavor. |
SagemakerStepOperatorSettings (BaseSettings)
pydantic-model
Settings for the Sagemaker step operator.
Attributes:
Name | Type | Description |
---|---|---|
instance_type |
Optional[str] |
The type of the compute instance where jobs will run. Check https://docs.aws.amazon.com/sagemaker/latest/dg/notebooks-available-instance-types.html for a list of available instance types. |
experiment_name |
Optional[str] |
The name for the experiment to which the job will be associated. If not provided, the job runs would be independent. |
Source code in zenml/integrations/aws/flavors/sagemaker_step_operator_flavor.py
class SagemakerStepOperatorSettings(BaseSettings):
    """Per-step settings of the SageMaker step operator.

    Attributes:
        instance_type: Compute instance type the job runs on. See
            https://docs.aws.amazon.com/sagemaker/latest/dg/notebooks-available-instance-types.html
            for the available instance types.
        experiment_name: Experiment the job is associated with. When unset,
            each job run is independent.
    """

    instance_type: Optional[str] = None
    experiment_name: Optional[str] = None
orchestrators
special
AWS Sagemaker orchestrator.
sagemaker_orchestrator
Implementation of the SageMaker orchestrator.
SagemakerOrchestrator (BaseOrchestrator)
Orchestrator responsible for running pipelines on Sagemaker.
Source code in zenml/integrations/aws/orchestrators/sagemaker_orchestrator.py
class SagemakerOrchestrator(BaseOrchestrator):
    """Orchestrator responsible for running pipelines on Sagemaker."""

    @property
    def config(self) -> SagemakerOrchestratorConfig:
        """Returns the `SagemakerOrchestratorConfig` config.

        Returns:
            The configuration.
        """
        return cast(SagemakerOrchestratorConfig, self._config)

    @property
    def validator(self) -> Optional[StackValidator]:
        """Validates the stack.

        Checks that the stack contains a container registry and an image
        builder, and that none of its components is local.

        Returns:
            A `StackValidator` instance.
        """

        def _validate_remote_components(
            stack: "Stack",
        ) -> Tuple[bool, str]:
            # Steps run inside SageMaker processing jobs, so a component
            # whose config reports `is_local` would not be reachable there.
            for component in stack.components.values():
                if not component.config.is_local:
                    continue
                return False, (
                    f"The Sagemaker orchestrator runs pipelines remotely, "
                    f"but the '{component.name}' {component.type.value} is "
                    "a local stack component and will not be available in "
                    "the Sagemaker step.\nPlease ensure that you always "
                    "use non-local stack components with the Sagemaker "
                    "orchestrator."
                )
            return True, ""

        return StackValidator(
            required_components={
                StackComponentType.CONTAINER_REGISTRY,
                StackComponentType.IMAGE_BUILDER,
            },
            custom_validation_function=_validate_remote_components,
        )

    def get_orchestrator_run_id(self) -> str:
        """Returns the run id of the active orchestrator run.

        Important: This needs to be a unique ID and return the same value for
        all steps of a pipeline run.

        Returns:
            The orchestrator run id.

        Raises:
            RuntimeError: If the run id cannot be read from the environment.
        """
        try:
            return os.environ[ENV_ZENML_SAGEMAKER_RUN_ID]
        except KeyError:
            raise RuntimeError(
                "Unable to read run id from environment variable "
                f"{ENV_ZENML_SAGEMAKER_RUN_ID}."
            )

    def prepare_pipeline_deployment(
        self, deployment: "PipelineDeployment", stack: "Stack"
    ) -> None:
        """Build a Docker image and push it to the container registry.

        The resulting image reference is stored on the deployment under
        `ORCHESTRATOR_DOCKER_IMAGE_KEY` for `prepare_or_run_pipeline`.

        Args:
            deployment: The pipeline deployment configuration.
            stack: The stack on which the pipeline will be deployed.
        """
        docker_image_builder = PipelineDockerImageBuilder()
        repo_digest = docker_image_builder.build_docker_image(
            deployment=deployment, stack=stack
        )
        deployment.add_extra(ORCHESTRATOR_DOCKER_IMAGE_KEY, repo_digest)

    @property
    def settings_class(self) -> Optional[Type["BaseSettings"]]:
        """Settings class for the Sagemaker orchestrator.

        Returns:
            The settings class.
        """
        return SagemakerOrchestratorSettings

    def prepare_or_run_pipeline(
        self, deployment: "PipelineDeployment", stack: "Stack"
    ) -> None:
        """Prepares or runs a pipeline on Sagemaker.

        Every ZenML step becomes a SageMaker `ProcessingStep` that runs the
        step entrypoint inside the Docker image built in
        `prepare_pipeline_deployment`. The SageMaker pipeline is then created
        and started; if `synchronous` is set in the config, this method waits
        for the execution to finish.

        Args:
            deployment: The deployment to prepare or run.
            stack: The stack to run on.
        """
        if deployment.schedule:
            logger.warning(
                "The Sagemaker Orchestrator currently does not support the "
                "use of schedules. The `schedule` will be ignored "
                "and the pipeline will be run immediately."
            )

        # Underscores are replaced, presumably because SageMaker pipeline and
        # job names only allow alphanumerics and hyphens.
        orchestrator_run_name = get_orchestrator_run_name(
            pipeline_name=deployment.pipeline.name
        ).replace("_", "-")

        session = sagemaker.Session(default_bucket=self.config.bucket)
        image_name = deployment.pipeline.extra[ORCHESTRATOR_DOCKER_IMAGE_KEY]

        sagemaker_steps = []
        for step_name, step in deployment.steps.items():
            command = StepEntrypointConfiguration.get_entrypoint_command()
            arguments = StepEntrypointConfiguration.get_entrypoint_arguments(
                step_name=step_name
            )
            entrypoint = command + arguments

            step_settings = cast(
                SagemakerOrchestratorSettings, self.get_settings(step)
            )
            processor_role = (
                step_settings.processor_role or self.config.execution_role
            )

            # SageMaker expects tags as a list of {"Key": ..., "Value": ...}
            # entries, not as a plain mapping.
            kwargs = (
                {
                    "tags": [
                        {"Key": key, "Value": value}
                        for key, value in step_settings.processor_tags.items()
                    ]
                }
                if step_settings.processor_tags
                else {}
            )

            processor = sagemaker.processing.Processor(
                role=processor_role,
                image_uri=image_name,
                instance_count=1,
                sagemaker_session=session,
                instance_type=step_settings.instance_type,
                entrypoint=entrypoint,
                base_job_name=orchestrator_run_name,
                env={
                    # SageMaker pipeline execution variable; every step of one
                    # execution sees the same ARN, which
                    # `get_orchestrator_run_id` reads back as the run id.
                    ENV_ZENML_SAGEMAKER_RUN_ID: ExecutionVariables.PIPELINE_EXECUTION_ARN,
                },
                volume_size_in_gb=step_settings.volume_size_in_gb,
                max_runtime_in_seconds=step_settings.max_runtime_in_seconds,
                **kwargs,
            )

            sagemaker_step = ProcessingStep(
                name=step.config.name,
                processor=processor,
                depends_on=step.spec.upstream_steps,
            )
            sagemaker_steps.append(sagemaker_step)

        # construct the pipeline from the sagemaker_steps
        pipeline = Pipeline(
            name=orchestrator_run_name,
            steps=sagemaker_steps,
            sagemaker_session=session,
        )

        pipeline.create(role_arn=self.config.execution_role)
        pipeline_execution = pipeline.start()

        # mainly for testing purposes, we wait for the pipeline to finish
        if self.config.synchronous:
            logger.info(
                "Executing synchronously. Waiting for pipeline to finish..."
            )
            pipeline_execution.wait(
                delay=POLLING_DELAY, max_attempts=MAX_POLLING_ATTEMPTS
            )
            logger.info("Pipeline completed successfully.")

    def get_pipeline_run_metadata(
        self, run_id: UUID
    ) -> Dict[str, "MetadataType"]:
        """Get general component-specific metadata for a pipeline run.

        Only the SageMaker pipeline execution ARN is reported; `run_id` is
        not used. Reading the environment variable raises `KeyError` when
        it is not set.

        Args:
            run_id: The ID of the pipeline run.

        Returns:
            A dictionary of metadata.
        """
        # TODO: Add this once we can get the region
        # run_url = (
        #     f"https://{region}.console.aws.amazon.com/sagemaker/"
        #     f"home?region={region}"
        # )
        return {
            "pipeline_execution_arn": os.environ[ENV_ZENML_SAGEMAKER_RUN_ID]
        }
config: SagemakerOrchestratorConfig
property
readonly
Returns the `SagemakerOrchestratorConfig` config.
Returns:
Type | Description |
---|---|
SagemakerOrchestratorConfig |
The configuration. |
settings_class: Optional[Type[BaseSettings]]
property
readonly
Settings class for the Sagemaker orchestrator.
Returns:
Type | Description |
---|---|
Optional[Type[BaseSettings]] |
The settings class. |
validator: Optional[zenml.stack.stack_validator.StackValidator]
property
readonly
Validates the stack.
In the remote case, checks that the stack contains a container registry, image builder and only remote components.
Returns:
Type | Description |
---|---|
Optional[zenml.stack.stack_validator.StackValidator] |
A `StackValidator` instance. |
get_orchestrator_run_id(self)
Returns the run id of the active orchestrator run.
Important: This needs to be a unique ID and return the same value for all steps of a pipeline run.
Returns:
Type | Description |
---|---|
str |
The orchestrator run id. |
Exceptions:
Type | Description |
---|---|
RuntimeError |
If the run id cannot be read from the environment. |
Source code in zenml/integrations/aws/orchestrators/sagemaker_orchestrator.py
def get_orchestrator_run_id(self) -> str:
    """Return the ID shared by all steps of the current SageMaker run.

    The ID must be unique per pipeline run and identical for every step of
    that run. It is read from the `ENV_ZENML_SAGEMAKER_RUN_ID` environment
    variable.

    Returns:
        The orchestrator run id.

    Raises:
        RuntimeError: If the environment variable is not set.
    """
    run_id = os.environ.get(ENV_ZENML_SAGEMAKER_RUN_ID)
    if run_id is None:
        raise RuntimeError(
            "Unable to read run id from environment variable "
            f"{ENV_ZENML_SAGEMAKER_RUN_ID}."
        )
    return run_id
get_pipeline_run_metadata(self, run_id)
Get general component-specific metadata for a pipeline run.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
run_id |
UUID |
The ID of the pipeline run. |
required |
Returns:
Type | Description |
---|---|
Dict[str, MetadataType] |
A dictionary of metadata. |
Source code in zenml/integrations/aws/orchestrators/sagemaker_orchestrator.py
def get_pipeline_run_metadata(
    self, run_id: UUID
) -> Dict[str, "MetadataType"]:
    """Collect SageMaker-specific metadata for a pipeline run.

    Only the pipeline execution ARN stored in the
    `ENV_ZENML_SAGEMAKER_RUN_ID` environment variable is reported; `run_id`
    is not used. A missing variable surfaces as a `KeyError`.
    TODO: also report a SageMaker console URL once the region is available.

    Args:
        run_id: The ID of the pipeline run.

    Returns:
        A dictionary with the single key `pipeline_execution_arn`.
    """
    execution_arn = os.environ[ENV_ZENML_SAGEMAKER_RUN_ID]
    return {"pipeline_execution_arn": execution_arn}
prepare_or_run_pipeline(self, deployment, stack)
Prepares or runs a pipeline on Sagemaker.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
deployment |
PipelineDeployment |
The deployment to prepare or run. |
required |
stack |
Stack |
The stack to run on. |
required |
Source code in zenml/integrations/aws/orchestrators/sagemaker_orchestrator.py
def prepare_or_run_pipeline(
    self, deployment: "PipelineDeployment", stack: "Stack"
) -> None:
    """Build and submit a SageMaker pipeline for the given deployment.

    Each ZenML step is turned into a SageMaker `ProcessingStep` that runs
    the ZenML step entrypoint inside the orchestrator Docker image. The
    pipeline is then created and started. If `self.config.synchronous` is
    set, this method waits for the execution to finish.

    Args:
        deployment: The deployment to prepare or run.
        stack: The stack to run on.
    """
    if deployment.schedule:
        logger.warning(
            "The Sagemaker Orchestrator currently does not support the "
            "use of schedules. The `schedule` will be ignored "
            "and the pipeline will be run immediately."
        )

    # Underscores are replaced because they are not accepted in these names.
    run_name = get_orchestrator_run_name(
        pipeline_name=deployment.pipeline.name
    ).replace("_", "-")

    sm_session = sagemaker.Session(default_bucket=self.config.bucket)
    docker_image = deployment.pipeline.extra[ORCHESTRATOR_DOCKER_IMAGE_KEY]

    processing_steps = []
    for zenml_step_name, deployed_step in deployment.steps.items():
        step_entrypoint = (
            StepEntrypointConfiguration.get_entrypoint_command()
            + StepEntrypointConfiguration.get_entrypoint_arguments(
                step_name=zenml_step_name
            )
        )
        settings = cast(
            SagemakerOrchestratorSettings, self.get_settings(deployed_step)
        )
        role = settings.processor_role or self.config.execution_role

        # NOTE(review): the tags object is wrapped in a one-element list as-is.
        # Confirm that `processor_tags` already has the shape SageMaker
        # expects for a single tag entry.
        tag_kwargs = (
            {"tags": [settings.processor_tags]}
            if settings.processor_tags
            else {}
        )

        processor = sagemaker.processing.Processor(
            role=role,
            image_uri=docker_image,
            instance_count=1,
            sagemaker_session=sm_session,
            instance_type=settings.instance_type,
            entrypoint=step_entrypoint,
            base_job_name=run_name,
            # Every step receives the execution ARN so that all steps of a run
            # share one run id (see `get_orchestrator_run_id`).
            env={
                ENV_ZENML_SAGEMAKER_RUN_ID: ExecutionVariables.PIPELINE_EXECUTION_ARN,
            },
            volume_size_in_gb=settings.volume_size_in_gb,
            max_runtime_in_seconds=settings.max_runtime_in_seconds,
            **tag_kwargs,
        )
        processing_steps.append(
            ProcessingStep(
                name=deployed_step.config.name,
                processor=processor,
                depends_on=deployed_step.spec.upstream_steps,
            )
        )

    pipeline = Pipeline(
        name=run_name,
        steps=processing_steps,
        sagemaker_session=sm_session,
    )
    pipeline.create(role_arn=self.config.execution_role)
    execution = pipeline.start()

    # Waiting is mainly used for testing purposes.
    if not self.config.synchronous:
        return
    logger.info("Executing synchronously. Waiting for pipeline to finish...")
    execution.wait(delay=POLLING_DELAY, max_attempts=MAX_POLLING_ATTEMPTS)
    logger.info("Pipeline completed successfully.")
prepare_pipeline_deployment(self, deployment, stack)
Build a Docker image and push it to the container registry.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
deployment |
PipelineDeployment |
The pipeline deployment configuration. |
required |
stack |
Stack |
The stack on which the pipeline will be deployed. |
required |
Source code in zenml/integrations/aws/orchestrators/sagemaker_orchestrator.py
def prepare_pipeline_deployment(
    self, deployment: "PipelineDeployment", stack: "Stack"
) -> None:
    """Build the pipeline Docker image, push it, and record its reference.

    Args:
        deployment: The pipeline deployment configuration.
        stack: The stack on which the pipeline will be deployed.
    """
    image_builder = PipelineDockerImageBuilder()
    image_digest = image_builder.build_docker_image(
        deployment=deployment, stack=stack
    )
    # The image is stored under ORCHESTRATOR_DOCKER_IMAGE_KEY, the key that
    # `prepare_or_run_pipeline` reads from `deployment.pipeline.extra`.
    deployment.add_extra(ORCHESTRATOR_DOCKER_IMAGE_KEY, image_digest)
secrets_managers
special
AWS Secrets Manager.
aws_secrets_manager
Implementation of the AWS Secrets Manager integration.
AWSSecretsManager (BaseSecretsManager)
Class to interact with the AWS secrets manager.
Source code in zenml/integrations/aws/secrets_managers/aws_secrets_manager.py
class AWSSecretsManager(BaseSecretsManager):
    """Class to interact with the AWS secrets manager."""

    # boto3 Secrets Manager client, created lazily and shared by all
    # instances of this class.
    CLIENT: ClassVar[Any] = None

    @property
    def config(self) -> AWSSecretsManagerConfig:
        """Returns the `AWSSecretsManagerConfig` config.

        Returns:
            The configuration.
        """
        return cast(AWSSecretsManagerConfig, self._config)

    @classmethod
    def _ensure_client_connected(cls, region_name: str) -> None:
        """Ensure that the client is connected to the AWS secrets manager.

        NOTE(review): the client is only created while `CLIENT` is `None`, so
        it stays bound to the region of the first call; later calls with a
        different `region_name` reuse it. Confirm whether multi-region use
        needs a per-region client.

        Args:
            region_name: the AWS region name
        """
        if cls.CLIENT is None:
            # Create a Secrets Manager client
            session = boto3.session.Session()
            cls.CLIENT = session.client(
                service_name="secretsmanager", region_name=region_name
            )

    def _get_secret_tags(
        self, secret: BaseSecretSchema
    ) -> List[Dict[str, str]]:
        """Return a list of AWS secret tag values for a given secret.

        Args:
            secret: the secret object

        Returns:
            A list of AWS secret tag values
        """
        metadata = self._get_secret_metadata(secret)
        return [{"Key": k, "Value": v} for k, v in metadata.items()]

    def _get_secret_scope_filters(
        self,
        secret_name: Optional[str] = None,
    ) -> List[Dict[str, Any]]:
        """Return a list of AWS filters for the entire scope or just a scoped secret.

        These filters can be used when querying the AWS Secrets Manager
        for all secrets or for a single secret available in the configured
        scope. For more information see: https://docs.aws.amazon.com/secretsmanager/latest/userguide/manage_search-secret.html

        Example AWS filters for all secrets in the current (namespace) scope:

        ```python
        [
            {
                "Key": "tag-key",
                "Values": ["zenml_scope"],
            },
            {
                "Key": "tag-value",
                "Values": ["namespace"],
            },
            {
                "Key": "tag-key",
                "Values": ["zenml_namespace"],
            },
            {
                "Key": "tag-value",
                "Values": ["my_namespace"],
            },
        ]
        ```

        Example AWS filters for a particular secret in the current (namespace)
        scope:

        ```python
        [
            {
                "Key": "tag-key",
                "Values": ["zenml_secret_name"],
            },
            {
                "Key": "tag-value",
                "Values": ["my_secret"],
            },
            {
                "Key": "tag-key",
                "Values": ["zenml_scope"],
            },
            {
                "Key": "tag-value",
                "Values": ["namespace"],
            },
            {
                "Key": "tag-key",
                "Values": ["zenml_namespace"],
            },
            {
                "Key": "tag-value",
                "Values": ["my_namespace"],
            },
        ]
        ```

        Args:
            secret_name: Optional secret name to filter for.

        Returns:
            A list of AWS filters uniquely identifying all secrets
            or a named secret within the configured scope.
        """
        metadata = self._get_secret_scope_metadata(secret_name)
        filters: List[Dict[str, Any]] = []
        # Each metadata entry becomes a (tag-key, tag-value) filter pair.
        for key, value in metadata.items():
            filters.append({"Key": "tag-key", "Values": [key]})
            filters.append({"Key": "tag-value", "Values": [str(value)]})
        return filters

    def _list_secrets(self, secret_name: Optional[str] = None) -> List[str]:
        """List all secrets matching a name.

        This method lists all the secrets in the current scope without loading
        their contents. An optional secret name can be supplied to filter out
        all but a single secret identified by name.

        Args:
            secret_name: Optional secret name to filter for.

        Returns:
            A list of secret names in the current scope and the optional
            secret name.
        """
        self._ensure_client_connected(self.config.region_name)

        filters: List[Dict[str, Any]] = []
        prefix: Optional[str] = None
        if self.config.scope == SecretsManagerScope.NONE:
            # unscoped (legacy) secrets don't have tags. We want to filter out
            # non-legacy secrets
            filters = [
                {
                    "Key": "tag-key",
                    "Values": [
                        "!zenml_scope",
                    ],
                },
            ]
            if secret_name:
                prefix = secret_name
        else:
            filters = self._get_secret_scope_filters()
            if secret_name:
                prefix = self._get_scoped_secret_name(secret_name)
            else:
                # add the name prefix to the filters to account for the fact
                # that AWS does not do exact matching but prefix-matching on the
                # filters
                prefix = self._get_scoped_secret_name_prefix()

        if prefix:
            filters.append({"Key": "name", "Values": [prefix]})

        paginator = self.CLIENT.get_paginator(_BOTO_CLIENT_LIST_SECRETS)
        pages = paginator.paginate(
            Filters=filters,
            PaginationConfig={
                "PageSize": 100,
            },
        )
        results = []
        for page in pages:
            for secret in page[_PAGINATOR_RESPONSE_SECRETS_LIST_KEY]:
                name = self._get_unscoped_secret_name(secret["Name"])
                # keep only the names that are in scope and filter by secret name,
                # if one was given
                if name and (not secret_name or secret_name == name):
                    results.append(name)

        return results

    def register_secret(self, secret: BaseSecretSchema) -> None:
        """Registers a new secret.

        Args:
            secret: the secret to register

        Raises:
            SecretExistsError: if the secret already exists
        """
        validate_aws_secret_name_or_namespace(secret.name)
        self._ensure_client_connected(self.config.region_name)

        if self._list_secrets(secret.name):
            raise SecretExistsError(
                f"A Secret with the name {secret.name} already exists"
            )

        secret_value = json.dumps(secret_to_dict(secret, encode=False))
        kwargs: Dict[str, Any] = {
            "Name": self._get_scoped_secret_name(secret.name),
            "SecretString": secret_value,
            "Tags": self._get_secret_tags(secret),
        }

        self.CLIENT.create_secret(**kwargs)

        logger.debug("Created AWS secret: %s", kwargs["Name"])

    def get_secret(self, secret_name: str) -> BaseSecretSchema:
        """Gets a secret.

        Args:
            secret_name: the name of the secret to get

        Returns:
            The secret.

        Raises:
            KeyError: if the secret does not exist
            RuntimeError: if the AWS secret has no string value and therefore
                cannot be parsed as a ZenML secret.
        """
        validate_aws_secret_name_or_namespace(secret_name)
        self._ensure_client_connected(self.config.region_name)

        if not self._list_secrets(secret_name):
            raise KeyError(f"Can't find the specified secret '{secret_name}'")

        get_secret_value_response = self.CLIENT.get_secret_value(
            SecretId=self._get_scoped_secret_name(secret_name)
        )
        if "SecretString" not in get_secret_value_response:
            # `register_secret` always stores a JSON `SecretString`; without it
            # there is nothing to parse (subscripting here used to crash with
            # an opaque TypeError).
            raise RuntimeError(
                f"The AWS secret '{secret_name}' does not contain a string "
                "value and cannot be loaded as a ZenML secret."
            )

        return secret_from_dict(
            json.loads(get_secret_value_response["SecretString"]),
            secret_name=secret_name,
            decode=False,
        )

    def get_all_secret_keys(self) -> List[str]:
        """Get all secret keys.

        Returns:
            A list of all secret keys
        """
        return self._list_secrets()

    def update_secret(self, secret: BaseSecretSchema) -> None:
        """Update an existing secret.

        Args:
            secret: the secret to update

        Raises:
            KeyError: if the secret does not exist
        """
        validate_aws_secret_name_or_namespace(secret.name)
        self._ensure_client_connected(self.config.region_name)

        if not self._list_secrets(secret.name):
            raise KeyError(f"Can't find the specified secret '{secret.name}'")

        # Store values unencoded, matching `register_secret` (encode=False) and
        # `get_secret` (decode=False). If `secret_to_dict` encodes by default,
        # updated values would otherwise be read back in encoded form.
        secret_value = json.dumps(secret_to_dict(secret, encode=False))

        kwargs = {
            "SecretId": self._get_scoped_secret_name(secret.name),
            "SecretString": secret_value,
        }

        self.CLIENT.put_secret_value(**kwargs)

    def delete_secret(self, secret_name: str) -> None:
        """Delete an existing secret.

        Args:
            secret_name: the name of the secret to delete

        Raises:
            KeyError: if the secret does not exist
        """
        self._ensure_client_connected(self.config.region_name)

        if not self._list_secrets(secret_name):
            raise KeyError(f"Can't find the specified secret '{secret_name}'")

        self.CLIENT.delete_secret(
            SecretId=self._get_scoped_secret_name(secret_name),
            ForceDeleteWithoutRecovery=True,
        )

    def delete_all_secrets(self) -> None:
        """Delete all existing secrets.

        This method will force delete all your secrets. You will not be able to
        recover them once this method is called.
        """
        self._ensure_client_connected(self.config.region_name)
        for secret_name in self._list_secrets():
            self.CLIENT.delete_secret(
                SecretId=self._get_scoped_secret_name(secret_name),
                ForceDeleteWithoutRecovery=True,
            )
config: AWSSecretsManagerConfig
property
readonly
Returns the `AWSSecretsManagerConfig` config.
Returns:
Type | Description |
---|---|
AWSSecretsManagerConfig |
The configuration. |
delete_all_secrets(self)
Delete all existing secrets.
This method will force delete all your secrets. You will not be able to recover them once this method is called.
Source code in zenml/integrations/aws/secrets_managers/aws_secrets_manager.py
def delete_all_secrets(self) -> None:
    """Force-delete every secret visible in this secrets manager's scope.

    Each secret is deleted with `ForceDeleteWithoutRecovery=True`, so none of
    them can be recovered afterwards.
    """
    self._ensure_client_connected(self.config.region_name)
    for name in self._list_secrets():
        scoped_id = self._get_scoped_secret_name(name)
        self.CLIENT.delete_secret(
            SecretId=scoped_id, ForceDeleteWithoutRecovery=True
        )
delete_secret(self, secret_name)
Delete an existing secret.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
secret_name |
str |
the name of the secret to delete |
required |
Exceptions:
Type | Description |
---|---|
KeyError |
if the secret does not exist |
Source code in zenml/integrations/aws/secrets_managers/aws_secrets_manager.py
def delete_secret(self, secret_name: str) -> None:
    """Permanently delete one secret from the configured scope.

    Args:
        secret_name: the name of the secret to delete

    Raises:
        KeyError: if no secret with this name exists in the scope
    """
    self._ensure_client_connected(self.config.region_name)

    matching = self._list_secrets(secret_name)
    if not matching:
        raise KeyError(f"Can't find the specified secret '{secret_name}'")

    scoped_id = self._get_scoped_secret_name(secret_name)
    self.CLIENT.delete_secret(SecretId=scoped_id, ForceDeleteWithoutRecovery=True)
get_all_secret_keys(self)
Get all secret keys.
Returns:
Type | Description |
---|---|
List[str] |
A list of all secret keys |
Source code in zenml/integrations/aws/secrets_managers/aws_secrets_manager.py
def get_all_secret_keys(self) -> List[str]:
    """List the names of every secret in the configured scope.

    Returns:
        A list of all secret keys
    """
    # No name filter is passed, so every in-scope secret is returned.
    secret_names = self._list_secrets()
    return secret_names
get_secret(self, secret_name)
Gets a secret.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
secret_name |
str |
the name of the secret to get |
required |
Returns:
Type | Description |
---|---|
BaseSecretSchema |
The secret. |
Exceptions:
Type | Description |
---|---|
KeyError |
if the secret does not exist |
Source code in zenml/integrations/aws/secrets_managers/aws_secrets_manager.py
def get_secret(self, secret_name: str) -> BaseSecretSchema:
    """Gets a secret.

    Args:
        secret_name: the name of the secret to get

    Returns:
        The secret.

    Raises:
        KeyError: if the secret does not exist
        RuntimeError: if the AWS secret has no string value and therefore
            cannot be parsed as a ZenML secret.
    """
    validate_aws_secret_name_or_namespace(secret_name)
    self._ensure_client_connected(self.config.region_name)

    if not self._list_secrets(secret_name):
        raise KeyError(f"Can't find the specified secret '{secret_name}'")

    get_secret_value_response = self.CLIENT.get_secret_value(
        SecretId=self._get_scoped_secret_name(secret_name)
    )
    if "SecretString" not in get_secret_value_response:
        # Without a JSON `SecretString` there is nothing to parse; previously
        # the response was replaced by None and subscripted, crashing with an
        # opaque TypeError.
        raise RuntimeError(
            f"The AWS secret '{secret_name}' does not contain a string "
            "value and cannot be loaded as a ZenML secret."
        )

    return secret_from_dict(
        json.loads(get_secret_value_response["SecretString"]),
        secret_name=secret_name,
        decode=False,
    )
register_secret(self, secret)
Registers a new secret.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
secret |
BaseSecretSchema |
the secret to register |
required |
Exceptions:
Type | Description |
---|---|
SecretExistsError |
if the secret already exists |
Source code in zenml/integrations/aws/secrets_managers/aws_secrets_manager.py
def register_secret(self, secret: BaseSecretSchema) -> None:
    """Create a new secret in AWS Secrets Manager.

    The secret contents are stored as a JSON `SecretString` (not encoded),
    tagged with the metadata returned by `_get_secret_tags`.

    Args:
        secret: the secret to register

    Raises:
        SecretExistsError: if a secret with the same name already exists
    """
    validate_aws_secret_name_or_namespace(secret.name)
    self._ensure_client_connected(self.config.region_name)

    already_present = self._list_secrets(secret.name)
    if already_present:
        raise SecretExistsError(
            f"A Secret with the name {secret.name} already exists"
        )

    payload = json.dumps(secret_to_dict(secret, encode=False))
    create_args: Dict[str, Any] = dict(
        Name=self._get_scoped_secret_name(secret.name),
        SecretString=payload,
        Tags=self._get_secret_tags(secret),
    )
    self.CLIENT.create_secret(**create_args)
    logger.debug("Created AWS secret: %s", create_args["Name"])
update_secret(self, secret)
Update an existing secret.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
secret |
BaseSecretSchema |
the secret to update |
required |
Exceptions:
Type | Description |
---|---|
KeyError |
if the secret does not exist |
Source code in zenml/integrations/aws/secrets_managers/aws_secrets_manager.py
def update_secret(self, secret: BaseSecretSchema) -> None:
    """Update an existing secret.

    Args:
        secret: the secret to update

    Raises:
        KeyError: if the secret does not exist
    """
    validate_aws_secret_name_or_namespace(secret.name)
    self._ensure_client_connected(self.config.region_name)

    if not self._list_secrets(secret.name):
        raise KeyError(f"Can't find the specified secret '{secret.name}'")

    # Store values unencoded, matching `register_secret` (encode=False) and
    # `get_secret` (decode=False). If `secret_to_dict` encodes by default,
    # updated values would otherwise be read back in encoded form.
    secret_value = json.dumps(secret_to_dict(secret, encode=False))

    kwargs = {
        "SecretId": self._get_scoped_secret_name(secret.name),
        "SecretString": secret_value,
    }

    self.CLIENT.put_secret_value(**kwargs)
step_operators
special
Initialization of the Sagemaker Step Operator.
sagemaker_step_operator
Implementation of the Sagemaker Step Operator.
SagemakerStepOperator (BaseStepOperator)
Step operator to run a step on Sagemaker.
This class defines code that builds an image with the ZenML entrypoint to run using Sagemaker's Estimator.
Source code in zenml/integrations/aws/step_operators/sagemaker_step_operator.py
class SagemakerStepOperator(BaseStepOperator):
    """Run individual ZenML steps as SageMaker training jobs.

    A Docker image containing the ZenML entrypoint is built for the steps
    that use this operator. Each step is then executed through a SageMaker
    `Estimator`.
    """

    @property
    def config(self) -> SagemakerStepOperatorConfig:
        """The typed configuration of this step operator.

        Returns:
            The configuration.
        """
        return cast(SagemakerStepOperatorConfig, self._config)

    @property
    def settings_class(self) -> Optional[Type["BaseSettings"]]:
        """The settings class used to configure individual steps.

        Returns:
            The settings class.
        """
        return SagemakerStepOperatorSettings

    @property
    def validator(self) -> Optional[StackValidator]:
        """Stack validator requiring remote storage and registry components.

        Returns:
            A validator that checks that the stack contains a remote container
            registry and a remote artifact store.
        """

        def _validate_remote_components(stack: "Stack") -> Tuple[bool, str]:
            artifact_store = stack.artifact_store
            if artifact_store.config.is_local:
                return False, (
                    "The SageMaker step operator runs code remotely and "
                    "needs to write files into the artifact store, but the "
                    f"artifact store `{artifact_store.name}` of the "
                    "active stack is local. Please ensure that your stack "
                    "contains a remote artifact store when using the SageMaker "
                    "step operator."
                )

            registry = stack.container_registry
            # Guaranteed by `required_components` below; narrows the type.
            assert registry is not None
            if registry.config.is_local:
                return False, (
                    "The SageMaker step operator runs code remotely and "
                    "needs to push/pull Docker images, but the "
                    f"container registry `{registry.name}` of the "
                    "active stack is local. Please ensure that your stack "
                    "contains a remote container registry when using the "
                    "SageMaker step operator."
                )

            return True, ""

        required = {
            StackComponentType.CONTAINER_REGISTRY,
            StackComponentType.IMAGE_BUILDER,
        }
        return StackValidator(
            required_components=required,
            custom_validation_function=_validate_remote_components,
        )

    def prepare_pipeline_deployment(
        self,
        deployment: "PipelineDeployment",
        stack: "Stack",
    ) -> None:
        """Build and push the step image if any step uses this operator.

        Args:
            deployment: The pipeline deployment configuration.
            stack: The stack on which the pipeline will be deployed.
        """
        own_steps = [
            candidate
            for candidate in deployment.steps.values()
            if candidate.config.step_operator == self.name
        ]
        if not own_steps:
            # No step runs on this operator, so no image is needed.
            return

        image_digest = PipelineDockerImageBuilder().build_docker_image(
            deployment=deployment,
            stack=stack,
            # `launch` places the real step command in this env variable.
            entrypoint=f"${_ENTRYPOINT_ENV_VARIABLE}",
        )
        for own_step in own_steps:
            own_step.config.extra[SAGEMAKER_DOCKER_IMAGE_KEY] = image_digest

    def launch(
        self,
        info: "StepRunInfo",
        entrypoint_command: List[str],
    ) -> None:
        """Run a single step as a SageMaker training job and wait for it.

        Args:
            info: Information about the step run.
            entrypoint_command: Command that executes the step.
        """
        if not info.config.resource_settings.empty:
            logger.warning(
                "Specifying custom step resources is not supported for "
                "the SageMaker step operator. If you want to run this step "
                "operator on specific resources, you can do so by configuring "
                "a different instance type like this: "
                "`zenml step-operator update %s "
                "--instance_type=<INSTANCE_TYPE>`",
                self.name,
            )

        image_uri = info.config.extra[SAGEMAKER_DOCKER_IMAGE_KEY]
        env = {_ENTRYPOINT_ENV_VARIABLE: " ".join(entrypoint_command)}
        step_settings = cast(
            SagemakerStepOperatorSettings, self.get_settings(info)
        )
        sm_session = sagemaker.Session(default_bucket=self.config.bucket)

        estimator = sagemaker.estimator.Estimator(
            image_uri,
            self.config.role,
            environment=env,
            instance_count=1,
            instance_type=step_settings.instance_type or "ml.m5.large",
            sagemaker_session=sm_session,
        )

        # Sagemaker doesn't allow any underscores in job/experiment/trial names
        # NOTE(review): the job name is derived only from the run name; if
        # several steps of one run use this operator, they all submit the same
        # job name. Confirm whether SageMaker accepts that.
        job_name = info.run_name.replace("_", "-")
        experiment_config = (
            {
                "ExperimentName": step_settings.experiment_name,
                "TrialName": job_name,
            }
            if step_settings.experiment_name
            else {}
        )

        estimator.fit(
            wait=True,
            experiment_config=experiment_config,
            job_name=job_name,
        )
config: SagemakerStepOperatorConfig
property
readonly
Returns the `SagemakerStepOperatorConfig` config.
Returns:
Type | Description |
---|---|
SagemakerStepOperatorConfig |
The configuration. |
settings_class: Optional[Type[BaseSettings]]
property
readonly
Settings class for the SageMaker step operator.
Returns:
Type | Description |
---|---|
Optional[Type[BaseSettings]] |
The settings class. |
validator: Optional[zenml.stack.stack_validator.StackValidator]
property
readonly
Validates the stack.
Returns:
Type | Description |
---|---|
Optional[zenml.stack.stack_validator.StackValidator] |
A validator that checks that the stack contains a remote container registry and a remote artifact store. |
launch(self, info, entrypoint_command)
Launches a step on SageMaker.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
info |
StepRunInfo |
Information about the step run. |
required |
entrypoint_command |
List[str] |
Command that executes the step. |
required |
Source code in zenml/integrations/aws/step_operators/sagemaker_step_operator.py
def launch(
    self,
    info: "StepRunInfo",
    entrypoint_command: List[str],
) -> None:
    """Run a single step as a SageMaker training job and wait for it.

    Args:
        info: Information about the step run.
        entrypoint_command: Command that executes the step.
    """
    if not info.config.resource_settings.empty:
        logger.warning(
            "Specifying custom step resources is not supported for "
            "the SageMaker step operator. If you want to run this step "
            "operator on specific resources, you can do so by configuring "
            "a different instance type like this: "
            "`zenml step-operator update %s "
            "--instance_type=<INSTANCE_TYPE>`",
            self.name,
        )

    image_uri = info.config.extra[SAGEMAKER_DOCKER_IMAGE_KEY]
    # The image entrypoint expands this variable into the step command.
    env = {_ENTRYPOINT_ENV_VARIABLE: " ".join(entrypoint_command)}
    step_settings = cast(SagemakerStepOperatorSettings, self.get_settings(info))
    sm_session = sagemaker.Session(default_bucket=self.config.bucket)

    estimator = sagemaker.estimator.Estimator(
        image_uri,
        self.config.role,
        environment=env,
        instance_count=1,
        instance_type=step_settings.instance_type or "ml.m5.large",
        sagemaker_session=sm_session,
    )

    # Sagemaker doesn't allow any underscores in job/experiment/trial names
    job_name = info.run_name.replace("_", "-")
    experiment_config = (
        {
            "ExperimentName": step_settings.experiment_name,
            "TrialName": job_name,
        }
        if step_settings.experiment_name
        else {}
    )

    estimator.fit(
        wait=True,
        experiment_config=experiment_config,
        job_name=job_name,
    )
prepare_pipeline_deployment(self, deployment, stack)
Build a Docker image and push it to the container registry.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
deployment |
PipelineDeployment |
The pipeline deployment configuration. |
required |
stack |
Stack |
The stack on which the pipeline will be deployed. |
required |
Source code in zenml/integrations/aws/step_operators/sagemaker_step_operator.py
def prepare_pipeline_deployment(
    self,
    deployment: "PipelineDeployment",
    stack: "Stack",
) -> None:
    """Build and push the step image if any step uses this operator.

    Args:
        deployment: The pipeline deployment configuration.
        stack: The stack on which the pipeline will be deployed.
    """
    own_steps = [
        candidate
        for candidate in deployment.steps.values()
        if candidate.config.step_operator == self.name
    ]
    if not own_steps:
        # No step runs on this operator, so no image is needed.
        return

    image_digest = PipelineDockerImageBuilder().build_docker_image(
        deployment=deployment,
        stack=stack,
        # The step's `launch` stores the real command in this env variable.
        entrypoint=f"${_ENTRYPOINT_ENV_VARIABLE}",
    )
    for own_step in own_steps:
        own_step.config.extra[SAGEMAKER_DOCKER_IMAGE_KEY] = image_digest