Aws
zenml.integrations.aws
special
Integrates multiple AWS Tools as Stack Components.
The AWS integration lets users manage their secrets through AWS Secrets Manager and use the AWS Elastic Container Registry (ECR). Additionally, it provides SageMaker flavors (a step operator and an orchestrator) for running ZenML steps and pipelines on Amazon SageMaker.
AWSIntegration (Integration)
Definition of AWS integration for ZenML.
Source code in zenml/integrations/aws/__init__.py
class AWSIntegration(Integration):
    """ZenML integration bundling the AWS stack component flavors.

    Activation imports the AWS service connectors module, and `flavors`
    lists the secrets manager, container registry, SageMaker step operator
    and SageMaker orchestrator flavors provided by this integration.
    """

    NAME = AWS
    REQUIREMENTS = [
        "sagemaker==2.117.0",
        "kubernetes",
        "aws-profile-manager",
    ]

    @staticmethod
    def activate() -> None:
        """Activate the AWS integration by importing its service connectors."""
        # The import is done only for its side effects; the module object
        # itself is not used afterwards.
        import zenml.integrations.aws.service_connectors  # noqa

    @classmethod
    def flavors(cls) -> List[Type[Flavor]]:
        """Declare the stack component flavors for the AWS integration.

        Returns:
            List of stack component flavors for this integration.
        """
        # Imported inside the method rather than at module level, presumably
        # to defer loading the flavor modules until they are requested.
        from zenml.integrations.aws import flavors as aws_flavors

        return [
            aws_flavors.AWSSecretsManagerFlavor,
            aws_flavors.AWSContainerRegistryFlavor,
            aws_flavors.SagemakerStepOperatorFlavor,
            aws_flavors.SagemakerOrchestratorFlavor,
        ]
activate()
staticmethod
Activate the AWS integration.
Source code in zenml/integrations/aws/__init__.py
@staticmethod
def activate() -> None:
    """Activate the AWS integration by importing its service connectors."""
    # The import is done only for its side effects; the module object
    # itself is not used afterwards.
    import zenml.integrations.aws.service_connectors  # noqa
flavors()
classmethod
Declare the stack component flavors for the AWS integration.
Returns:
Type | Description |
---|---|
List[Type[zenml.stack.flavor.Flavor]] |
List of stack component flavors for this integration. |
Source code in zenml/integrations/aws/__init__.py
@classmethod
def flavors(cls) -> List[Type[Flavor]]:
    """Declare the stack component flavors for the AWS integration.

    Returns:
        List of stack component flavors for this integration.
    """
    # Imported inside the method rather than at module level, presumably
    # to defer loading the flavor modules until they are requested.
    from zenml.integrations.aws import flavors as aws_flavors

    return [
        aws_flavors.AWSSecretsManagerFlavor,
        aws_flavors.AWSContainerRegistryFlavor,
        aws_flavors.SagemakerStepOperatorFlavor,
        aws_flavors.SagemakerOrchestratorFlavor,
    ]
container_registries
special
Initialization of AWS Container Registry integration.
aws_container_registry
Implementation of the AWS container registry integration.
AWSContainerRegistry (BaseContainerRegistry)
Class for AWS Container Registry.
Source code in zenml/integrations/aws/container_registries/aws_container_registry.py
class AWSContainerRegistry(BaseContainerRegistry):
    """Class for AWS Container Registry."""

    @property
    def config(self) -> AWSContainerRegistryConfig:
        """Returns the `AWSContainerRegistryConfig` config.

        Returns:
            The configuration.
        """
        return cast(AWSContainerRegistryConfig, self._config)

    def _get_region(self) -> str:
        """Parses the AWS region from the registry URI.

        Accepts both standard ECR hostnames
        (`<account>.dkr.ecr.<region>.amazonaws.com`) and China-partition
        hostnames (`<account>.dkr.ecr.<region>.amazonaws.com.cn`).

        Raises:
            RuntimeError: If the region parsing fails due to an invalid URI.

        Returns:
            The region string.
        """
        match = re.fullmatch(
            r".*\.dkr\.ecr\.(.*)\.amazonaws\.com(?:\.cn)?", self.config.uri
        )
        if not match:
            raise RuntimeError(
                f"Unable to parse region from ECR URI {self.config.uri}."
            )
        return match.group(1)

    def prepare_image_push(self, image_name: str) -> None:
        """Logs warning message if trying to push an image for which no repository exists.

        Amazon ECR requires a repository to exist before an image can be
        pushed to it. This looks up the repositories visible with the local
        AWS credentials and only logs (never fails) when the target
        repository cannot be confirmed.

        Args:
            image_name: Name of the docker image that will be pushed.

        Raises:
            ValueError: If the docker image name is invalid.
        """
        # Find the repository name from the image name. The registry URI is
        # escaped so that its dots are matched literally, not as wildcards.
        match = re.search(
            f"{re.escape(self.config.uri)}/(.*):.*", image_name
        )
        if not match:
            raise ValueError(f"Invalid docker image name '{image_name}'.")
        repo_name = match.group(1)

        try:
            ecr_client = boto3.client("ecr", region_name=self._get_region())
            # `describe_repositories` is paginated; walk every page so that
            # repositories beyond the first page are not reported missing.
            paginator = ecr_client.get_paginator("describe_repositories")
            repo_uris: List[str] = [
                repository["repositoryUri"]
                for page in paginator.paginate()
                for repository in page["repositories"]
            ]
        except NoCredentialsError:
            logger.warning(
                "Amazon ECR requires you to create a repository before you can "
                f"push an image to it. ZenML is trying to push the image "
                f"{image_name} but could not find any repositories because "
                "your local AWS credentials are not set. We will try to push "
                "anyway, but in case it fails you need to create a repository "
                f"named `{repo_name}`."
            )
            return
        except (KeyError, ClientError) as e:
            # Failed API call (e.g. missing permissions) or invalid boto
            # response: let's hope for the best and just push.
            logger.debug("Error while trying to fetch ECR repositories: %s", e)
            return

        repo_exists = any(
            image_name.startswith(f"{uri}:") for uri in repo_uris
        )
        if not repo_exists:
            logger.warning(
                "Amazon ECR requires you to create a repository before you can "
                f"push an image to it. ZenML is trying to push the image "
                f"{image_name} but could only detect the following "
                f"repositories: {repo_uris}. We will try to push anyway, but "
                f"in case it fails you need to create a repository named "
                f"`{repo_name}`."
            )

    @property
    def post_registration_message(self) -> Optional[str]:
        """Optional message printed after the stack component is registered.

        Returns:
            Info message regarding docker repositories in AWS.
        """
        return (
            "Amazon ECR requires you to create a repository before you can "
            "push an image to it. If you want to for example run a pipeline "
            "using our Kubeflow orchestrator, ZenML will automatically build a "
            f"docker image called `{self.config.uri}/zenml-kubeflow:<PIPELINE_NAME>` "
            f"and try to push it. This will fail unless you create the "
            f"repository `zenml-kubeflow` inside your amazon registry."
        )
config: AWSContainerRegistryConfig
property
readonly
Returns the `AWSContainerRegistryConfig` config.
Returns:
Type | Description |
---|---|
AWSContainerRegistryConfig |
The configuration. |
post_registration_message: Optional[str]
property
readonly
Optional message printed after the stack component is registered.
Returns:
Type | Description |
---|---|
Optional[str] |
Info message regarding docker repositories in AWS. |
prepare_image_push(self, image_name)
Logs warning message if trying to push an image for which no repository exists.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
image_name |
str |
Name of the docker image that will be pushed. |
required |
Exceptions:
Type | Description |
---|---|
ValueError |
If the docker image name is invalid. |
Source code in zenml/integrations/aws/container_registries/aws_container_registry.py
def prepare_image_push(self, image_name: str) -> None:
    """Logs warning message if trying to push an image for which no repository exists.

    Amazon ECR requires a repository to exist before an image can be pushed
    to it. This looks up the repositories visible with the local AWS
    credentials and only logs (never fails) when the target repository
    cannot be confirmed.

    Args:
        image_name: Name of the docker image that will be pushed.

    Raises:
        ValueError: If the docker image name is invalid.
    """
    # Find the repository name from the image name. The registry URI is
    # escaped so that its dots are matched literally, not as wildcards.
    match = re.search(f"{re.escape(self.config.uri)}/(.*):.*", image_name)
    if not match:
        raise ValueError(f"Invalid docker image name '{image_name}'.")
    repo_name = match.group(1)

    try:
        ecr_client = boto3.client("ecr", region_name=self._get_region())
        # `describe_repositories` is paginated; walk every page so that
        # repositories beyond the first page are not reported missing.
        paginator = ecr_client.get_paginator("describe_repositories")
        repo_uris: List[str] = [
            repository["repositoryUri"]
            for page in paginator.paginate()
            for repository in page["repositories"]
        ]
    except NoCredentialsError:
        logger.warning(
            "Amazon ECR requires you to create a repository before you can "
            f"push an image to it. ZenML is trying to push the image "
            f"{image_name} but could not find any repositories because "
            "your local AWS credentials are not set. We will try to push "
            "anyway, but in case it fails you need to create a repository "
            f"named `{repo_name}`."
        )
        return
    except (KeyError, ClientError) as e:
        # Failed API call (e.g. missing permissions) or invalid boto
        # response: let's hope for the best and just push.
        logger.debug("Error while trying to fetch ECR repositories: %s", e)
        return

    repo_exists = any(image_name.startswith(f"{uri}:") for uri in repo_uris)
    if not repo_exists:
        logger.warning(
            "Amazon ECR requires you to create a repository before you can "
            f"push an image to it. ZenML is trying to push the image "
            f"{image_name} but could only detect the following "
            f"repositories: {repo_uris}. We will try to push anyway, but "
            f"in case it fails you need to create a repository named "
            f"`{repo_name}`."
        )
flavors
special
AWS integration flavors.
aws_container_registry_flavor
AWS container registry flavor.
AWSContainerRegistryConfig (BaseContainerRegistryConfig)
pydantic-model
Configuration for AWS Container Registry.
Source code in zenml/integrations/aws/flavors/aws_container_registry_flavor.py
class AWSContainerRegistryConfig(BaseContainerRegistryConfig):
    """Configuration for AWS Container Registry."""

    @validator("uri")
    def validate_aws_uri(cls, uri: str) -> str:
        """Ensure the registry URI is a bare registry host without a path.

        Args:
            uri: URI to validate.

        Returns:
            The unchanged URI when it is valid.

        Raises:
            ValueError: If the URI contains a slash character.
        """
        if "/" not in uri:
            return uri
        raise ValueError(
            "Property `uri` can not contain a `/`. An example of a valid "
            "URI is: `715803424592.dkr.ecr.us-east-1.amazonaws.com`"
        )
validate_aws_uri(uri)
classmethod
Validates that the URI is in the correct format.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
uri |
str |
URI to validate. |
required |
Returns:
Type | Description |
---|---|
str |
URI in the correct format. |
Exceptions:
Type | Description |
---|---|
ValueError |
If the URI contains a slash character. |
Source code in zenml/integrations/aws/flavors/aws_container_registry_flavor.py
@validator("uri")
def validate_aws_uri(cls, uri: str) -> str:
    """Ensure the registry URI is a bare registry host without a path.

    Args:
        uri: URI to validate.

    Returns:
        The unchanged URI when it is valid.

    Raises:
        ValueError: If the URI contains a slash character.
    """
    if "/" not in uri:
        return uri
    raise ValueError(
        "Property `uri` can not contain a `/`. An example of a valid "
        "URI is: `715803424592.dkr.ecr.us-east-1.amazonaws.com`"
    )
AWSContainerRegistryFlavor (BaseContainerRegistryFlavor)
AWS Container Registry flavor.
Source code in zenml/integrations/aws/flavors/aws_container_registry_flavor.py
class AWSContainerRegistryFlavor(BaseContainerRegistryFlavor):
    """Flavor describing the AWS (ECR) container registry component."""

    @property
    def name(self) -> str:
        """Identifier under which this flavor is registered.

        Returns:
            The `AWS_CONTAINER_REGISTRY_FLAVOR` constant.
        """
        flavor_name = AWS_CONTAINER_REGISTRY_FLAVOR
        return flavor_name

    @property
    def service_connector_requirements(
        self,
    ) -> Optional[ServiceConnectorRequirements]:
        """Requirements a service connector must meet to serve this flavor.

        These are used to narrow the available service connector types down
        to those compatible with this flavor: connector type `aws`, resource
        type `docker-registry`, with the resource identified by the `uri`
        config attribute.

        Returns:
            The connector requirements for this flavor.
        """
        requirements = ServiceConnectorRequirements(
            connector_type="aws",
            resource_type="docker-registry",
            resource_id_attr="uri",
        )
        return requirements

    @property
    def docs_url(self) -> Optional[str]:
        """Link to the user documentation for this flavor.

        Returns:
            The URL produced by `generate_default_docs_url`.
        """
        url = self.generate_default_docs_url()
        return url

    @property
    def sdk_docs_url(self) -> Optional[str]:
        """Link to the SDK documentation for this flavor.

        Returns:
            The URL produced by `generate_default_sdk_docs_url`.
        """
        url = self.generate_default_sdk_docs_url()
        return url

    @property
    def logo_url(self) -> str:
        """Image shown for this flavor in the dashboard.

        Returns:
            The flavor logo URL.
        """
        return "https://public-flavor-logos.s3.eu-central-1.amazonaws.com/container_registry/aws.png"

    @property
    def config_class(self) -> Type[AWSContainerRegistryConfig]:
        """Pydantic config model used by this flavor.

        Returns:
            The `AWSContainerRegistryConfig` class.
        """
        config_cls = AWSContainerRegistryConfig
        return config_cls

    @property
    def implementation_class(self) -> Type["AWSContainerRegistry"]:
        """Stack component class implementing this flavor.

        Returns:
            The `AWSContainerRegistry` class.
        """
        # Imported on access rather than at module level, presumably to
        # defer loading the implementation until it is requested.
        from zenml.integrations.aws import container_registries

        return container_registries.AWSContainerRegistry
config_class: Type[zenml.integrations.aws.flavors.aws_container_registry_flavor.AWSContainerRegistryConfig]
property
readonly
Config class for this flavor.
Returns:
Type | Description |
---|---|
Type[zenml.integrations.aws.flavors.aws_container_registry_flavor.AWSContainerRegistryConfig] |
The config class. |
docs_url: Optional[str]
property
readonly
A url to point at docs explaining this flavor.
Returns:
Type | Description |
---|---|
Optional[str] |
A flavor docs url. |
implementation_class: Type[AWSContainerRegistry]
property
readonly
Implementation class.
Returns:
Type | Description |
---|---|
Type[AWSContainerRegistry] |
The implementation class. |
logo_url: str
property
readonly
A url to represent the flavor in the dashboard.
Returns:
Type | Description |
---|---|
str |
The flavor logo. |
name: str
property
readonly
Name of the flavor.
Returns:
Type | Description |
---|---|
str |
The name of the flavor. |
sdk_docs_url: Optional[str]
property
readonly
A url to point at SDK docs explaining this flavor.
Returns:
Type | Description |
---|---|
Optional[str] |
A flavor SDK docs url. |
service_connector_requirements: Optional[zenml.models.service_connector_models.ServiceConnectorRequirements]
property
readonly
Service connector resource requirements for service connectors.
Specifies resource requirements that are used to filter the available service connector types that are compatible with this flavor.
Returns:
Type | Description |
---|---|
Optional[zenml.models.service_connector_models.ServiceConnectorRequirements] |
Requirements for compatible service connectors, if a service connector is required for this flavor. |
aws_secrets_manager_flavor
AWS secrets manager flavor.
AWSSecretsManagerConfig (BaseSecretsManagerConfig)
pydantic-model
Configuration for the AWS Secrets Manager.
Attributes:
Name | Type | Description |
---|---|---|
region_name |
str |
The region name of the AWS Secrets Manager. |
Source code in zenml/integrations/aws/flavors/aws_secrets_manager_flavor.py
class AWSSecretsManagerConfig(BaseSecretsManagerConfig):
    """Configuration for the AWS Secrets Manager.

    Attributes:
        region_name: The region name of the AWS Secrets Manager.
    """

    SUPPORTS_SCOPING: ClassVar[bool] = True
    region_name: str

    @classmethod
    def _validate_scope(
        cls,
        scope: "SecretsManagerScope",
        namespace: Optional[str],
    ) -> None:
        """Validate the scope and namespace value.

        Only the namespace is checked. When one is given, it must pass
        `validate_aws_secret_name_or_namespace`. The scope value is not
        inspected by this implementation.

        Args:
            scope: Scope value.
            namespace: Optional namespace value.
        """
        if not namespace:
            # No (or an empty) namespace: nothing to validate.
            return
        validate_aws_secret_name_or_namespace(namespace)
AWSSecretsManagerFlavor (BaseSecretsManagerFlavor)
Class for the `AWSSecretsManagerFlavor`.
Source code in zenml/integrations/aws/flavors/aws_secrets_manager_flavor.py
class AWSSecretsManagerFlavor(BaseSecretsManagerFlavor):
    """Flavor describing the AWS Secrets Manager stack component."""

    @property
    def name(self) -> str:
        """Identifier under which this flavor is registered.

        Returns:
            The `AWS_SECRET_MANAGER_FLAVOR` constant.
        """
        flavor_name = AWS_SECRET_MANAGER_FLAVOR
        return flavor_name

    @property
    def docs_url(self) -> Optional[str]:
        """Link to the user documentation for this flavor.

        Returns:
            The URL produced by `generate_default_docs_url`.
        """
        url = self.generate_default_docs_url()
        return url

    @property
    def sdk_docs_url(self) -> Optional[str]:
        """Link to the SDK documentation for this flavor.

        Returns:
            The URL produced by `generate_default_sdk_docs_url`.
        """
        url = self.generate_default_sdk_docs_url()
        return url

    @property
    def logo_url(self) -> str:
        """Image shown for this flavor in the dashboard.

        Returns:
            The flavor logo URL.
        """
        return "https://public-flavor-logos.s3.eu-central-1.amazonaws.com/secrets_managers/aws.png"

    @property
    def config_class(self) -> Type[AWSSecretsManagerConfig]:
        """Pydantic config model used by this flavor.

        Returns:
            The `AWSSecretsManagerConfig` class.
        """
        config_cls = AWSSecretsManagerConfig
        return config_cls

    @property
    def implementation_class(self) -> Type["AWSSecretsManager"]:
        """Stack component class implementing this flavor.

        Returns:
            The `AWSSecretsManager` class.
        """
        # Imported on access rather than at module level, presumably to
        # defer loading the implementation until it is requested.
        from zenml.integrations.aws import secrets_managers

        return secrets_managers.AWSSecretsManager
config_class: Type[zenml.integrations.aws.flavors.aws_secrets_manager_flavor.AWSSecretsManagerConfig]
property
readonly
Config class for this flavor.
Returns:
Type | Description |
---|---|
Type[zenml.integrations.aws.flavors.aws_secrets_manager_flavor.AWSSecretsManagerConfig] |
Config class for this flavor. |
docs_url: Optional[str]
property
readonly
A url to point at docs explaining this flavor.
Returns:
Type | Description |
---|---|
Optional[str] |
A flavor docs url. |
implementation_class: Type[AWSSecretsManager]
property
readonly
Implementation class.
Returns:
Type | Description |
---|---|
Type[AWSSecretsManager] |
Implementation class. |
logo_url: str
property
readonly
A url to represent the flavor in the dashboard.
Returns:
Type | Description |
---|---|
str |
The flavor logo. |
name: str
property
readonly
Name of the flavor.
Returns:
Type | Description |
---|---|
str |
Name of the flavor. |
sdk_docs_url: Optional[str]
property
readonly
A url to point at SDK docs explaining this flavor.
Returns:
Type | Description |
---|---|
Optional[str] |
A flavor SDK docs url. |
validate_aws_secret_name_or_namespace(name)
Validate a secret name or namespace.
AWS secret names must contain only alphanumeric characters and the characters `/_+=.@-`. The `/` character is only used internally to delimit scopes, so it is rejected here.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
name |
str |
the secret name or namespace |
required |
Exceptions:
Type | Description |
---|---|
ValueError |
if the secret name or namespace is invalid |
Source code in zenml/integrations/aws/flavors/aws_secrets_manager_flavor.py
def validate_aws_secret_name_or_namespace(name: str) -> None:
    """Validate a secret name or namespace.

    AWS secret names may contain only alphanumeric characters and the
    characters /_+=.@-. Because `/` is reserved internally as the scope
    delimiter, it is not accepted in a user-supplied name or namespace.
    The empty string is accepted.

    Args:
        name: the secret name or namespace

    Raises:
        ValueError: if the secret name or namespace is invalid
    """
    is_valid = re.fullmatch(r"[a-zA-Z0-9_+=\.@\-]*", name) is not None
    if is_valid:
        return
    raise ValueError(
        f"Invalid secret name or namespace '{name}'. Must contain "
        "only alphanumeric characters and the characters _+=.@-."
    )
sagemaker_orchestrator_flavor
Amazon SageMaker orchestrator flavor.
SagemakerOrchestratorConfig (BaseOrchestratorConfig, SagemakerOrchestratorSettings)
pydantic-model
Config for the Sagemaker orchestrator.
There are three ways to authenticate to AWS:
- By connecting a `ServiceConnector` to the orchestrator,
- By configuring explicit AWS credentials `aws_access_key_id`, `aws_secret_access_key`, and optional `aws_auth_role_arn`,
- If none of the above are provided, unspecified credentials will be loaded from the default AWS config.
Attributes:
Name | Type | Description |
---|---|---|
synchronous |
bool |
Whether to run the processing job synchronously or asynchronously. Defaults to False. |
execution_role |
str |
The IAM role ARN to use for the pipeline. |
aws_access_key_id |
Optional[str] |
The AWS access key ID to use to authenticate to AWS. If not provided, the value from the default AWS config will be used. |
aws_secret_access_key |
Optional[str] |
The AWS secret access key to use to authenticate to AWS. If not provided, the value from the default AWS config will be used. |
aws_profile |
Optional[str] |
The AWS profile to use for authentication if not using service connectors or explicit credentials. If not provided, the default profile will be used. |
aws_auth_role_arn |
Optional[str] |
The ARN of an intermediate IAM role to assume when authenticating to AWS. |
region |
Optional[str] |
The AWS region where the processing job will be run. If not provided, the value from the default AWS config will be used. |
bucket |
Optional[str] |
Name of the S3 bucket to use for storing artifacts from the job run. If not provided, a default bucket will be created based on the following format: "sagemaker-{region}-{aws-account-id}". |
Source code in zenml/integrations/aws/flavors/sagemaker_orchestrator_flavor.py
class SagemakerOrchestratorConfig(  # type: ignore[misc] # https://github.com/pydantic/pydantic/issues/4173
    BaseOrchestratorConfig, SagemakerOrchestratorSettings
):
    """Config for the Sagemaker orchestrator.

    There are three ways to authenticate to AWS:
    - By connecting a `ServiceConnector` to the orchestrator,
    - By configuring explicit AWS credentials `aws_access_key_id`,
      `aws_secret_access_key`, and optional `aws_auth_role_arn`,
    - If none of the above are provided, unspecified credentials will be
      loaded from the default AWS config (using the `aws_profile` profile
      if set, otherwise the default profile).

    The settings fields (instance type, volume size, S3 input/output
    settings, ...) are inherited from `SagemakerOrchestratorSettings`.

    Attributes:
        synchronous: Whether to run the processing job synchronously or
            asynchronously. Defaults to False.
        execution_role: The IAM role ARN to use for the pipeline. Required.
        aws_access_key_id: The AWS access key ID to use to authenticate to AWS.
            If not provided, the value from the default AWS config will be used.
        aws_secret_access_key: The AWS secret access key to use to authenticate
            to AWS. If not provided, the value from the default AWS config will
            be used.
        aws_profile: The AWS profile to use for authentication if not using
            service connectors or explicit credentials. If not provided, the
            default profile will be used.
        aws_auth_role_arn: The ARN of an intermediate IAM role to assume when
            authenticating to AWS.
        region: The AWS region where the processing job will be run. If not
            provided, the value from the default AWS config will be used.
        bucket: Name of the S3 bucket to use for storing artifacts
            from the job run. If not provided, a default bucket will be created
            based on the following format:
            "sagemaker-{region}-{aws-account-id}".
    """

    synchronous: bool = False
    execution_role: str
    # Credentials are declared with `SecretField()`; presumably this marks
    # them as secret values in ZenML — confirm in `SecretField`.
    aws_access_key_id: Optional[str] = SecretField()
    aws_secret_access_key: Optional[str] = SecretField()
    aws_profile: Optional[str] = None
    aws_auth_role_arn: Optional[str] = None
    region: Optional[str] = None
    bucket: Optional[str] = None

    @property
    def is_remote(self) -> bool:
        """Checks if this stack component is running remotely.

        This designation is used to determine if the stack component can be
        used with a local ZenML database or if it requires a remote ZenML
        server. SageMaker jobs always run remotely, so this is always True.

        Returns:
            True if this config is for a remote component, False otherwise.
        """
        return True
is_remote: bool
property
readonly
Checks if this stack component is running remotely.
This designation is used to determine if the stack component can be used with a local ZenML database or if it requires a remote ZenML server.
Returns:
Type | Description |
---|---|
bool |
True if this config is for a remote component, False otherwise. |
SagemakerOrchestratorFlavor (BaseOrchestratorFlavor)
Flavor for the Sagemaker orchestrator.
Source code in zenml/integrations/aws/flavors/sagemaker_orchestrator_flavor.py
class SagemakerOrchestratorFlavor(BaseOrchestratorFlavor):
    """Flavor describing the Amazon SageMaker orchestrator component."""

    @property
    def name(self) -> str:
        """Identifier under which this flavor is registered.

        Returns:
            The name of the flavor.
        """
        # NOTE(review): this orchestrator flavor returns the *step operator*
        # flavor constant. Confirm whether a dedicated orchestrator constant
        # (e.g. `AWS_SAGEMAKER_ORCHESTRATOR_FLAVOR`) exists and should be
        # used instead; if both resolve to the same string, behavior is
        # unaffected.
        flavor_name = AWS_SAGEMAKER_STEP_OPERATOR_FLAVOR
        return flavor_name

    @property
    def service_connector_requirements(
        self,
    ) -> Optional[ServiceConnectorRequirements]:
        """Requirements a service connector must meet to serve this flavor.

        These are used to narrow the available service connector types down
        to those compatible with this flavor: any connector exposing the
        `aws-generic` resource type.

        Returns:
            The connector requirements for this flavor.
        """
        requirements = ServiceConnectorRequirements(resource_type="aws-generic")
        return requirements

    @property
    def docs_url(self) -> Optional[str]:
        """Link to the user documentation for this flavor.

        Returns:
            The URL produced by `generate_default_docs_url`.
        """
        url = self.generate_default_docs_url()
        return url

    @property
    def sdk_docs_url(self) -> Optional[str]:
        """Link to the SDK documentation for this flavor.

        Returns:
            The URL produced by `generate_default_sdk_docs_url`.
        """
        url = self.generate_default_sdk_docs_url()
        return url

    @property
    def logo_url(self) -> str:
        """Image shown for this flavor in the dashboard.

        Returns:
            The flavor logo URL.
        """
        return "https://public-flavor-logos.s3.eu-central-1.amazonaws.com/orchestrator/sagemaker.png"

    @property
    def config_class(self) -> Type[SagemakerOrchestratorConfig]:
        """Pydantic config model used by this flavor.

        Returns:
            The `SagemakerOrchestratorConfig` class.
        """
        config_cls = SagemakerOrchestratorConfig
        return config_cls

    @property
    def implementation_class(self) -> Type["SagemakerOrchestrator"]:
        """Stack component class implementing this flavor.

        Returns:
            The `SagemakerOrchestrator` class.
        """
        # Imported on access rather than at module level, presumably to
        # defer loading the implementation until it is requested.
        from zenml.integrations.aws import orchestrators

        return orchestrators.SagemakerOrchestrator
config_class: Type[zenml.integrations.aws.flavors.sagemaker_orchestrator_flavor.SagemakerOrchestratorConfig]
property
readonly
Returns SagemakerOrchestratorConfig config class.
Returns:
Type | Description |
---|---|
Type[zenml.integrations.aws.flavors.sagemaker_orchestrator_flavor.SagemakerOrchestratorConfig] |
The config class. |
docs_url: Optional[str]
property
readonly
A url to point at docs explaining this flavor.
Returns:
Type | Description |
---|---|
Optional[str] |
A flavor docs url. |
implementation_class: Type[SagemakerOrchestrator]
property
readonly
Implementation class.
Returns:
Type | Description |
---|---|
Type[SagemakerOrchestrator] |
The implementation class. |
logo_url: str
property
readonly
A url to represent the flavor in the dashboard.
Returns:
Type | Description |
---|---|
str |
The flavor logo. |
name: str
property
readonly
Name of the flavor.
Returns:
Type | Description |
---|---|
str |
The name of the flavor. |
sdk_docs_url: Optional[str]
property
readonly
A url to point at SDK docs explaining this flavor.
Returns:
Type | Description |
---|---|
Optional[str] |
A flavor SDK docs url. |
service_connector_requirements: Optional[zenml.models.service_connector_models.ServiceConnectorRequirements]
property
readonly
Service connector resource requirements for service connectors.
Specifies resource requirements that are used to filter the available service connector types that are compatible with this flavor.
Returns:
Type | Description |
---|---|
Optional[zenml.models.service_connector_models.ServiceConnectorRequirements] |
Requirements for compatible service connectors, if a service connector is required for this flavor. |
SagemakerOrchestratorSettings (BaseSettings)
pydantic-model
Settings for the Sagemaker orchestrator.
Attributes:
Name | Type | Description |
---|---|---|
instance_type |
str |
The instance type to use for the processing job. |
processor_role |
Optional[str] |
The IAM role to use for the step execution on a Processor. |
volume_size_in_gb |
int |
The size of the EBS volume to use for the processing job. |
max_runtime_in_seconds |
int |
The maximum runtime in seconds for the processing job. |
processor_tags |
Dict[str, str] |
Tags to apply to the Processor assigned to the step. |
processor_args |
Dict[str, Any] |
Arguments that are directly passed to the SageMaker Processor for a specific step, allowing for overriding the default settings provided when configuring the component. See https://sagemaker.readthedocs.io/en/stable/api/training/processing.html#sagemaker.processing.Processor for a full list of arguments. For processor_args.instance_type, check https://docs.aws.amazon.com/sagemaker/latest/dg/notebooks-available-instance-types.html for a list of available instance types. |
input_data_s3_mode |
str |
How data is made available to the container. Two possible input modes: File, Pipe. |
input_data_s3_uri |
Union[str, Dict[str, str]] |
S3 URI where data is located if not locally, e.g. s3://my-bucket/my-data/train. How data will be made available to the container is configured with input_data_s3_mode. Two possible input types: - str: S3 location where training data is saved. - Dict[str, str]: (ChannelName, S3Location) which represent channels (e.g. training, validation, testing) where specific parts of the data are saved in S3. |
output_data_s3_mode |
str |
How data is uploaded to the S3 bucket. Two possible output modes: EndOfJob, Continuous. |
output_data_s3_uri |
Union[str, Dict[str, str]] |
S3 URI where data is uploaded after or during the processing run, e.g. s3://my-bucket/my-data/output. How data is uploaded is configured with output_data_s3_mode. Two possible output types: - str: S3 location where data will be uploaded from a local folder named /opt/ml/processing/output/data. - Dict[str, str]: (ChannelName, S3Location) which represent channels (e.g. output_one, output_two) where specific parts of the data are stored locally for S3 upload. Data must be available locally in /opt/ml/processing/output/data/&lt;ChannelName&gt;. |
Source code in zenml/integrations/aws/flavors/sagemaker_orchestrator_flavor.py
class SagemakerOrchestratorSettings(BaseSettings):
    """Settings for the Sagemaker orchestrator.

    Attributes:
        instance_type: The instance type to use for the processing job.
            Defaults to "ml.t3.medium".
        processor_role: The IAM role to use for the step execution on a
            Processor.
        volume_size_in_gb: The size of the EBS volume to use for the processing
            job. Defaults to 30.
        max_runtime_in_seconds: The maximum runtime in seconds for the
            processing job. Defaults to 86400 (24 hours).
        processor_tags: Tags to apply to the Processor assigned to the step.
        processor_args: Arguments that are directly passed to the SageMaker
            Processor for a specific step, allowing for overriding the default
            settings provided when configuring the component. See
            https://sagemaker.readthedocs.io/en/stable/api/training/processing.html#sagemaker.processing.Processor
            for a full list of arguments.
            For processor_args.instance_type, check
            https://docs.aws.amazon.com/sagemaker/latest/dg/notebooks-available-instance-types.html
            for a list of available instance types.
        input_data_s3_mode: How data is made available to the container.
            Two possible input modes: File, Pipe. Defaults to "File".
        input_data_s3_uri: S3 URI where data is located if not locally,
            e.g. s3://my-bucket/my-data/train. How data will be made available
            to the container is configured with input_data_s3_mode. Two
            possible input types:
                - str: S3 location where training data is saved.
                - Dict[str, str]: (ChannelName, S3Location) which represent
                    channels (e.g. training, validation, testing) where
                    specific parts of the data are saved in S3.
            Defaults to None.
        output_data_s3_mode: How data is uploaded to the S3 bucket.
            Two possible output modes: EndOfJob, Continuous. Defaults to
            "EndOfJob".
        output_data_s3_uri: S3 URI where data is uploaded after or during the
            processing run, e.g. s3://my-bucket/my-data/output. How data is
            uploaded is configured with output_data_s3_mode. Two possible
            output types:
                - str: S3 location where data will be uploaded from a local
                    folder named /opt/ml/processing/output/data.
                - Dict[str, str]: (ChannelName, S3Location) which represent
                    channels (e.g. output_one, output_two) where
                    specific parts of the data are stored locally for S3
                    upload. Data must be available locally in
                    /opt/ml/processing/output/data/<ChannelName>.
            Defaults to None.
    """

    instance_type: str = "ml.t3.medium"
    processor_role: Optional[str] = None
    volume_size_in_gb: int = 30
    # 86400 seconds == 24 hours.
    max_runtime_in_seconds: int = 86400
    # NOTE(review): these mutable `{}` defaults rely on pydantic copying
    # field defaults per model instance — confirm for the pydantic version
    # in use.
    processor_tags: Dict[str, str] = {}
    processor_args: Dict[str, Any] = {}
    input_data_s3_mode: str = "File"
    input_data_s3_uri: Optional[Union[str, Dict[str, str]]] = None
    output_data_s3_mode: str = "EndOfJob"
    output_data_s3_uri: Optional[Union[str, Dict[str, str]]] = None
sagemaker_step_operator_flavor
Amazon SageMaker step operator flavor.
SagemakerStepOperatorConfig (BaseStepOperatorConfig, SagemakerStepOperatorSettings)
pydantic-model
Config for the Sagemaker step operator.
Attributes:
Name | Type | Description |
---|---|---|
role |
str |
The role that has to be assigned to the jobs which are running in Sagemaker. |
bucket |
Optional[str] |
Name of the S3 bucket to use for storing artifacts from the job run. If not provided, a default bucket will be created based on the following format: "sagemaker-{region}-{aws-account-id}". |
Source code in zenml/integrations/aws/flavors/sagemaker_step_operator_flavor.py
class SagemakerStepOperatorConfig(  # type: ignore[misc] # https://github.com/pydantic/pydantic/issues/4173
    BaseStepOperatorConfig, SagemakerStepOperatorSettings
):
    """Config for the Sagemaker step operator.

    Combines the generic step operator config with the fields of
    `SagemakerStepOperatorSettings` (declared elsewhere).

    Attributes:
        role: The role that has to be assigned to the jobs which are
            running in SageMaker. Required.
        bucket: Name of the S3 bucket to use for storing artifacts
            from the job run. If not provided, a default bucket will be created
            based on the following format: "sagemaker-{region}-{aws-account-id}".
    """

    role: str
    bucket: Optional[str] = None

    @property
    def is_remote(self) -> bool:
        """Checks if this stack component is running remotely.

        This designation is used to determine if the stack component can be
        used with a local ZenML database or if it requires a remote ZenML
        server. SageMaker jobs always run remotely, so this is always True.

        Returns:
            True if this config is for a remote component, False otherwise.
        """
        return True
is_remote: bool
property
readonly
Checks if this stack component is running remotely.
This designation is used to determine if the stack component can be used with a local ZenML database or if it requires a remote ZenML server.
Returns:
Type | Description |
---|---|
bool |
True if this config is for a remote component, False otherwise. |
SagemakerStepOperatorFlavor (BaseStepOperatorFlavor)
Flavor for the Sagemaker step operator.
Source code in zenml/integrations/aws/flavors/sagemaker_step_operator_flavor.py
class SagemakerStepOperatorFlavor(BaseStepOperatorFlavor):
    """Flavor for the Sagemaker step operator.

    Bundles the metadata ZenML uses to register and display the SageMaker
    step operator: its name, config and implementation classes, docs links
    and dashboard logo.
    """

    @property
    def name(self) -> str:
        """Identifier under which this flavor is registered.

        Returns:
            The name of the flavor.
        """
        flavor_name = AWS_SAGEMAKER_STEP_OPERATOR_FLAVOR
        return flavor_name

    @property
    def config_class(self) -> Type[SagemakerStepOperatorConfig]:
        """Configuration model used by this flavor.

        Returns:
            The `SagemakerStepOperatorConfig` class.
        """
        config_model = SagemakerStepOperatorConfig
        return config_model

    @property
    def implementation_class(self) -> Type["SagemakerStepOperator"]:
        """Step operator class implementing this flavor.

        The import happens inside the property, so the step operator module
        is only loaded when the implementation is actually requested.

        Returns:
            The implementation class.
        """
        from zenml.integrations.aws.step_operators import SagemakerStepOperator

        implementation = SagemakerStepOperator
        return implementation

    @property
    def docs_url(self) -> Optional[str]:
        """Link to the user documentation for this flavor.

        Returns:
            A flavor docs url.
        """
        url = self.generate_default_docs_url()
        return url

    @property
    def sdk_docs_url(self) -> Optional[str]:
        """Link to the SDK reference documentation for this flavor.

        Returns:
            A flavor SDK docs url.
        """
        url = self.generate_default_sdk_docs_url()
        return url

    @property
    def logo_url(self) -> str:
        """Image shown for this flavor in the dashboard.

        Returns:
            The flavor logo.
        """
        logo = "https://public-flavor-logos.s3.eu-central-1.amazonaws.com/step_operator/sagemaker.png"
        return logo
config_class: Type[zenml.integrations.aws.flavors.sagemaker_step_operator_flavor.SagemakerStepOperatorConfig]
property
readonly
Returns SagemakerStepOperatorConfig config class.
Returns:
Type | Description |
---|---|
Type[zenml.integrations.aws.flavors.sagemaker_step_operator_flavor.SagemakerStepOperatorConfig] |
The config class. |
docs_url: Optional[str]
property
readonly
A url to point at docs explaining this flavor.
Returns:
Type | Description |
---|---|
Optional[str] |
A flavor docs url. |
implementation_class: Type[SagemakerStepOperator]
property
readonly
Implementation class.
Returns:
Type | Description |
---|---|
Type[SagemakerStepOperator] |
The implementation class. |
logo_url: str
property
readonly
A url to represent the flavor in the dashboard.
Returns:
Type | Description |
---|---|
str |
The flavor logo. |
name: str
property
readonly
Name of the flavor.
Returns:
Type | Description |
---|---|
str |
The name of the flavor. |
sdk_docs_url: Optional[str]
property
readonly
A url to point at SDK docs explaining this flavor.
Returns:
Type | Description |
---|---|
Optional[str] |
A flavor SDK docs url. |
SagemakerStepOperatorSettings (BaseSettings)
pydantic-model
Settings for the Sagemaker step operator.
Attributes:
Name | Type | Description |
---|---|---|
experiment_name |
Optional[str] |
The name for the experiment to which the job will be associated. If not provided, the job runs would be independent. |
input_data_s3_uri |
Union[str, Dict[str, str]] |
S3 URI where training data is located if not locally, e.g. s3://my-bucket/my-data/train. How data will be made available to the container is configured with estimator_args.input_mode. Two possible input types: - str: S3 location where training data is saved. - Dict[str, str]: (ChannelName, S3Location) which represent channels (e.g. training, validation, testing) where specific parts of the data are saved in S3. |
estimator_args |
Dict[str, Any] |
Arguments that are directly passed to the SageMaker Estimator. See https://sagemaker.readthedocs.io/en/stable/api/training/estimators.html#sagemaker.estimator.Estimator for a full list of arguments. For estimator_args.instance_type, check https://docs.aws.amazon.com/sagemaker/latest/dg/notebooks-available-instance-types.html for a list of available instance types. |
Source code in zenml/integrations/aws/flavors/sagemaker_step_operator_flavor.py
class SagemakerStepOperatorSettings(BaseSettings):
    """Settings for the Sagemaker step operator.

    Attributes:
        instance_type: Deprecated attribute (see `_deprecation_validator`
            below). Presumably the instance type should now be supplied as
            `estimator_args["instance_type"]` instead - TODO confirm.
        experiment_name: The name for the experiment to which the job
            will be associated. If not provided, the job runs would be
            independent.
        input_data_s3_uri: S3 URI where training data is located if not locally,
            e.g. s3://my-bucket/my-data/train. How data will be made available
            to the container is configured with estimator_args.input_mode. Two possible
            input types:
                - str: S3 location where training data is saved.
                - Dict[str, str]: (ChannelName, S3Location) which represent
                    channels (e.g. training, validation, testing) where
                    specific parts of the data are saved in S3.
        estimator_args: Arguments that are directly passed to the SageMaker
            Estimator. See
            https://sagemaker.readthedocs.io/en/stable/api/training/estimators.html#sagemaker.estimator.Estimator
            for a full list of arguments.
            For estimator_args.instance_type, check
            https://docs.aws.amazon.com/sagemaker/latest/dg/notebooks-available-instance-types.html
            for a list of available instance types.
    """

    # Deprecated; still declared so existing configurations keep validating.
    instance_type: Optional[str] = None
    experiment_name: Optional[str] = None
    input_data_s3_uri: Optional[Union[str, Dict[str, str]]] = None
    estimator_args: Dict[str, Any] = {}

    # Marks `instance_type` as a deprecated pydantic attribute.
    _deprecation_validator = deprecation_utils.deprecate_pydantic_attributes(
        "instance_type"
    )
orchestrators
special
AWS Sagemaker orchestrator.
sagemaker_orchestrator
Implementation of the SageMaker orchestrator.
SagemakerOrchestrator (ContainerizedOrchestrator)
Orchestrator responsible for running pipelines on Sagemaker.
Source code in zenml/integrations/aws/orchestrators/sagemaker_orchestrator.py
class SagemakerOrchestrator(ContainerizedOrchestrator):
    """Orchestrator responsible for running pipelines on Sagemaker."""

    @property
    def config(self) -> SagemakerOrchestratorConfig:
        """Returns the `SagemakerOrchestratorConfig` config.

        Returns:
            The configuration.
        """
        return cast(SagemakerOrchestratorConfig, self._config)

    @property
    def validator(self) -> Optional[StackValidator]:
        """Validates the stack.

        In the remote case, checks that the stack contains a container registry,
        image builder and only remote components.

        Returns:
            A `StackValidator` instance.
        """

        def _validate_remote_components(
            stack: "Stack",
        ) -> Tuple[bool, str]:
            # Steps execute inside Sagemaker, so the first component whose
            # config reports `is_local` makes the stack invalid.
            for component in stack.components.values():
                if not component.config.is_local:
                    continue

                return False, (
                    f"The Sagemaker orchestrator runs pipelines remotely, "
                    f"but the '{component.name}' {component.type.value} is "
                    "a local stack component and will not be available in "
                    "the Sagemaker step.\nPlease ensure that you always "
                    "use non-local stack components with the Sagemaker "
                    "orchestrator."
                )
            return True, ""

        return StackValidator(
            required_components={
                StackComponentType.CONTAINER_REGISTRY,
                StackComponentType.IMAGE_BUILDER,
            },
            custom_validation_function=_validate_remote_components,
        )

    def get_orchestrator_run_id(self) -> str:
        """Returns the run id of the active orchestrator run.

        Important: This needs to be a unique ID and return the same value for
        all steps of a pipeline run.

        Returns:
            The orchestrator run id.

        Raises:
            RuntimeError: If the run id cannot be read from the environment.
        """
        try:
            return os.environ[ENV_ZENML_SAGEMAKER_RUN_ID]
        except KeyError:
            raise RuntimeError(
                "Unable to read run id from environment variable "
                f"{ENV_ZENML_SAGEMAKER_RUN_ID}."
            )

    @property
    def settings_class(self) -> Optional[Type["BaseSettings"]]:
        """Settings class for the Sagemaker orchestrator.

        Returns:
            The settings class.
        """
        return SagemakerOrchestratorSettings

    def prepare_or_run_pipeline(
        self,
        deployment: "PipelineDeploymentResponseModel",
        stack: "Stack",
        environment: Dict[str, str],
    ) -> None:
        """Prepares or runs a pipeline on Sagemaker.

        Builds one Sagemaker `ProcessingStep` per ZenML step, creates the
        Sagemaker pipeline and starts it. If `config.synchronous` is set,
        the method waits for the pipeline execution to finish.

        Args:
            deployment: The deployment to prepare or run.
            stack: The stack to run on.
            environment: Environment variables to set in the orchestration
                environment. Note: this dict is updated in place with the
                Sagemaker run id variable.

        Raises:
            RuntimeError: If a connector is used that does not return a
                `boto3.Session` object.
        """
        if deployment.schedule:
            logger.warning(
                "The Sagemaker Orchestrator currently does not support the "
                "use of schedules. The `schedule` will be ignored "
                "and the pipeline will be run immediately."
            )

        # sagemaker requires pipelineName to use alphanum and hyphens only
        unsanitized_orchestrator_run_name = get_orchestrator_run_name(
            pipeline_name=deployment.pipeline_configuration.name
        )
        # replace all non-alphanum and non-hyphens with hyphens
        orchestrator_run_name = re.sub(
            r"[^a-zA-Z0-9\-]", "-", unsanitized_orchestrator_run_name
        )

        # Get authenticated session
        # Option 1: Service connector
        boto_session: boto3.Session
        if connector := self.get_connector():
            boto_session = connector.connect()
            if not isinstance(boto_session, boto3.Session):
                raise RuntimeError(
                    f"Expected to receive a `boto3.Session` object from the "
                    f"linked connector, but got type `{type(boto_session)}`."
                )
        # Option 2: Explicit configuration
        # Args that are not provided will be taken from the default AWS config.
        else:
            boto_session = boto3.Session(
                aws_access_key_id=self.config.aws_access_key_id,
                aws_secret_access_key=self.config.aws_secret_access_key,
                region_name=self.config.region,
                profile_name=self.config.aws_profile,
            )
            # If a role ARN is provided for authentication, assume the role
            if self.config.aws_auth_role_arn:
                sts = boto_session.client("sts")
                response = sts.assume_role(
                    RoleArn=self.config.aws_auth_role_arn,
                    RoleSessionName="zenml-sagemaker-orchestrator",
                )
                credentials = response["Credentials"]
                boto_session = boto3.Session(
                    aws_access_key_id=credentials["AccessKeyId"],
                    aws_secret_access_key=credentials["SecretAccessKey"],
                    aws_session_token=credentials["SessionToken"],
                    region_name=self.config.region,
                )

        session = sagemaker.Session(
            boto_session=boto_session, default_bucket=self.config.bucket
        )

        # `get_orchestrator_run_id` reads this variable inside each step. The
        # value is the same for every step, so it is set once here instead
        # of once per loop iteration. (Presumably Sagemaker substitutes the
        # execution ARN placeholder at run time - TODO confirm.)
        environment[
            ENV_ZENML_SAGEMAKER_RUN_ID
        ] = ExecutionVariables.PIPELINE_EXECUTION_ARN

        sagemaker_steps = []
        for step_name, step in deployment.step_configurations.items():
            image = self.get_image(deployment=deployment, step_name=step_name)
            command = StepEntrypointConfiguration.get_entrypoint_command()
            arguments = StepEntrypointConfiguration.get_entrypoint_arguments(
                step_name=step_name, deployment_id=deployment.id
            )
            entrypoint = command + arguments

            step_settings = cast(
                SagemakerOrchestratorSettings, self.get_settings(step)
            )

            # Retrieve Processor arguments provided in the Step settings.
            # Copy them so the defaults and fixed values below (session,
            # entrypoint, env, ...) are not written back into the settings.
            processor_args_for_step = dict(step_settings.processor_args or {})

            # Set default values from configured orchestrator Component to arguments
            # to be used when they are not present in processor_args.
            processor_args_for_step.setdefault(
                "instance_type", step_settings.instance_type
            )
            processor_args_for_step.setdefault(
                "role",
                step_settings.processor_role or self.config.execution_role,
            )
            processor_args_for_step.setdefault(
                "volume_size_in_gb", step_settings.volume_size_in_gb
            )
            processor_args_for_step.setdefault(
                "max_runtime_in_seconds", step_settings.max_runtime_in_seconds
            )
            processor_args_for_step.setdefault(
                "tags",
                [
                    {"Key": key, "Value": value}
                    for key, value in step_settings.processor_tags.items()
                ]
                if step_settings.processor_tags
                else None,
            )

            # Set values that cannot be overwritten
            processor_args_for_step["image_uri"] = image
            processor_args_for_step["instance_count"] = 1
            processor_args_for_step["sagemaker_session"] = session
            processor_args_for_step["entrypoint"] = entrypoint
            processor_args_for_step["base_job_name"] = orchestrator_run_name
            processor_args_for_step["env"] = environment

            # Construct S3 inputs to container for step
            inputs = None

            if step_settings.input_data_s3_uri is None:
                pass
            elif isinstance(step_settings.input_data_s3_uri, str):
                inputs = [
                    ProcessingInput(
                        source=step_settings.input_data_s3_uri,
                        destination="/opt/ml/processing/input/data",
                        s3_input_mode=step_settings.input_data_s3_mode,
                    )
                ]
            elif isinstance(step_settings.input_data_s3_uri, dict):
                inputs = []
                for channel, s3_uri in step_settings.input_data_s3_uri.items():
                    inputs.append(
                        ProcessingInput(
                            source=s3_uri,
                            destination=f"/opt/ml/processing/input/data/{channel}",
                            s3_input_mode=step_settings.input_data_s3_mode,
                        )
                    )

            # Construct S3 outputs from container for step
            outputs = None

            if step_settings.output_data_s3_uri is None:
                pass
            elif isinstance(step_settings.output_data_s3_uri, str):
                outputs = [
                    ProcessingOutput(
                        source="/opt/ml/processing/output/data",
                        destination=step_settings.output_data_s3_uri,
                        s3_upload_mode=step_settings.output_data_s3_mode,
                    )
                ]
            elif isinstance(step_settings.output_data_s3_uri, dict):
                outputs = []
                for (
                    channel,
                    s3_uri,
                ) in step_settings.output_data_s3_uri.items():
                    outputs.append(
                        ProcessingOutput(
                            source=f"/opt/ml/processing/output/data/{channel}",
                            destination=s3_uri,
                            s3_upload_mode=step_settings.output_data_s3_mode,
                        )
                    )

            # Create Processor and ProcessingStep
            processor = sagemaker.processing.Processor(
                **processor_args_for_step
            )
            sagemaker_step = ProcessingStep(
                name=step_name,
                processor=processor,
                depends_on=step.spec.upstream_steps,
                inputs=inputs,
                outputs=outputs,
            )
            sagemaker_steps.append(sagemaker_step)

        # construct the pipeline from the sagemaker_steps
        pipeline = Pipeline(
            name=orchestrator_run_name,
            steps=sagemaker_steps,
            sagemaker_session=session,
        )

        pipeline.create(role_arn=self.config.execution_role)
        pipeline_execution = pipeline.start()
        logger.warning(
            "Steps can take 5-15 minutes to start running "
            "when using the Sagemaker Orchestrator."
        )

        # mainly for testing purposes, we wait for the pipeline to finish
        if self.config.synchronous:
            logger.info(
                "Executing synchronously. Waiting for pipeline to finish..."
            )
            pipeline_execution.wait(
                delay=POLLING_DELAY, max_attempts=MAX_POLLING_ATTEMPTS
            )
            logger.info("Pipeline completed successfully.")

    def _get_region_name(self) -> str:
        """Returns the AWS region name.

        Returns:
            The region name.

        Raises:
            RuntimeError: If the region name cannot be retrieved.
        """
        try:
            return cast(str, sagemaker.Session().boto_region_name)
        except Exception as e:
            raise RuntimeError(
                "Unable to get region name. Please ensure that you have "
                "configured your AWS credentials correctly."
            ) from e

    def get_pipeline_run_metadata(
        self, run_id: UUID
    ) -> Dict[str, "MetadataType"]:
        """Get general component-specific metadata for a pipeline run.

        Args:
            run_id: The ID of the pipeline run (not used here; the run is
                identified through the execution ARN in the environment).

        Returns:
            A dictionary of metadata: always the pipeline execution ARN, plus
            a CloudWatch logs URL when the AWS region can be determined.
        """
        # Read the execution ARN once and reuse it for both metadata entries.
        execution_arn = os.environ[ENV_ZENML_SAGEMAKER_RUN_ID]
        run_metadata: Dict[str, "MetadataType"] = {
            "pipeline_execution_arn": execution_arn,
        }
        try:
            region_name = self._get_region_name()
        except RuntimeError:
            logger.warning("Unable to get region name from AWS Sagemaker.")
            return run_metadata

        # The last "/"-separated segment of the ARN is used in the log
        # stream filter of the CloudWatch URL below.
        aws_run_id = execution_arn.split("/")[-1]
        orchestrator_logs_url = (
            f"https://{region_name}.console.aws.amazon.com/"
            f"cloudwatch/home?region={region_name}#logsV2:log-groups/log-group"
            f"/$252Faws$252Fsagemaker$252FProcessingJobs$3FlogStreamNameFilter"
            f"$3Dpipelines-{aws_run_id}-"
        )
        run_metadata[METADATA_ORCHESTRATOR_URL] = Uri(orchestrator_logs_url)
        return run_metadata
config: SagemakerOrchestratorConfig
property
readonly
Returns the SagemakerOrchestratorConfig config.
Returns:
Type | Description |
---|---|
SagemakerOrchestratorConfig |
The configuration. |
settings_class: Optional[Type[BaseSettings]]
property
readonly
Settings class for the Sagemaker orchestrator.
Returns:
Type | Description |
---|---|
Optional[Type[BaseSettings]] |
The settings class. |
validator: Optional[zenml.stack.stack_validator.StackValidator]
property
readonly
Validates the stack.
In the remote case, checks that the stack contains a container registry, image builder and only remote components.
Returns:
Type | Description |
---|---|
Optional[zenml.stack.stack_validator.StackValidator] |
A StackValidator instance. |
get_orchestrator_run_id(self)
Returns the run id of the active orchestrator run.
Important: This needs to be a unique ID and return the same value for all steps of a pipeline run.
Returns:
Type | Description |
---|---|
str |
The orchestrator run id. |
Exceptions:
Type | Description |
---|---|
RuntimeError |
If the run id cannot be read from the environment. |
Source code in zenml/integrations/aws/orchestrators/sagemaker_orchestrator.py
def get_orchestrator_run_id(self) -> str:
    """Look up the identifier of the currently running Sagemaker pipeline.

    Important: the value must be unique per pipeline run and identical for
    every step within that run.

    Returns:
        The orchestrator run id, as stored in the environment.

    Raises:
        RuntimeError: If the run id environment variable is not set.
    """
    run_id = os.environ.get(ENV_ZENML_SAGEMAKER_RUN_ID)
    if run_id is None:
        raise RuntimeError(
            "Unable to read run id from environment variable "
            f"{ENV_ZENML_SAGEMAKER_RUN_ID}."
        )
    return run_id
get_pipeline_run_metadata(self, run_id)
Get general component-specific metadata for a pipeline run.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
run_id |
UUID |
The ID of the pipeline run. |
required |
Returns:
Type | Description |
---|---|
Dict[str, MetadataType] |
A dictionary of metadata. |
Source code in zenml/integrations/aws/orchestrators/sagemaker_orchestrator.py
def get_pipeline_run_metadata(
    self, run_id: UUID
) -> Dict[str, "MetadataType"]:
    """Collect Sagemaker-specific metadata for a pipeline run.

    Args:
        run_id: The ID of the pipeline run (unused; the run is identified
            through the execution ARN stored in the environment).

    Returns:
        A dictionary with the pipeline execution ARN and, if the AWS region
        can be resolved, a link to the run's CloudWatch logs.
    """
    execution_arn = os.environ[ENV_ZENML_SAGEMAKER_RUN_ID]
    metadata: Dict[str, "MetadataType"] = {
        "pipeline_execution_arn": execution_arn,
    }

    try:
        region = self._get_region_name()
    except RuntimeError:
        # Best effort: without a region only the ARN can be reported.
        logger.warning("Unable to get region name from AWS Sagemaker.")
        return metadata

    # Trailing "/"-separated segment of the ARN, used as log stream filter.
    execution_id = execution_arn.rsplit("/", 1)[-1]
    logs_url = (
        f"https://{region}.console.aws.amazon.com/"
        f"cloudwatch/home?region={region}#logsV2:log-groups/log-group"
        f"/$252Faws$252Fsagemaker$252FProcessingJobs$3FlogStreamNameFilter"
        f"$3Dpipelines-{execution_id}-"
    )
    metadata[METADATA_ORCHESTRATOR_URL] = Uri(logs_url)
    return metadata
prepare_or_run_pipeline(self, deployment, stack, environment)
Prepares or runs a pipeline on Sagemaker.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
deployment |
PipelineDeploymentResponseModel |
The deployment to prepare or run. |
required |
stack |
Stack |
The stack to run on. |
required |
environment |
Dict[str, str] |
Environment variables to set in the orchestration environment. |
required |
Exceptions:
Type | Description |
---|---|
RuntimeError |
If a connector is used that does not return a boto3.Session object. |
Source code in zenml/integrations/aws/orchestrators/sagemaker_orchestrator.py
def prepare_or_run_pipeline(
    self,
    deployment: "PipelineDeploymentResponseModel",
    stack: "Stack",
    environment: Dict[str, str],
) -> None:
    """Prepares or runs a pipeline on Sagemaker.

    Builds one Sagemaker `ProcessingStep` per ZenML step, creates the
    Sagemaker pipeline and starts it. If `config.synchronous` is set, the
    function waits for the pipeline execution to finish.

    Args:
        deployment: The deployment to prepare or run.
        stack: The stack to run on.
        environment: Environment variables to set in the orchestration
            environment. Note: this dict is updated in place with the
            Sagemaker run id variable.

    Raises:
        RuntimeError: If a connector is used that does not return a
            `boto3.Session` object.
    """
    if deployment.schedule:
        logger.warning(
            "The Sagemaker Orchestrator currently does not support the "
            "use of schedules. The `schedule` will be ignored "
            "and the pipeline will be run immediately."
        )

    # sagemaker requires pipelineName to use alphanum and hyphens only
    unsanitized_orchestrator_run_name = get_orchestrator_run_name(
        pipeline_name=deployment.pipeline_configuration.name
    )
    # replace all non-alphanum and non-hyphens with hyphens
    orchestrator_run_name = re.sub(
        r"[^a-zA-Z0-9\-]", "-", unsanitized_orchestrator_run_name
    )

    # Get authenticated session
    # Option 1: Service connector
    boto_session: boto3.Session
    if connector := self.get_connector():
        boto_session = connector.connect()
        if not isinstance(boto_session, boto3.Session):
            raise RuntimeError(
                f"Expected to receive a `boto3.Session` object from the "
                f"linked connector, but got type `{type(boto_session)}`."
            )
    # Option 2: Explicit configuration
    # Args that are not provided will be taken from the default AWS config.
    else:
        boto_session = boto3.Session(
            aws_access_key_id=self.config.aws_access_key_id,
            aws_secret_access_key=self.config.aws_secret_access_key,
            region_name=self.config.region,
            profile_name=self.config.aws_profile,
        )
        # If a role ARN is provided for authentication, assume the role
        if self.config.aws_auth_role_arn:
            sts = boto_session.client("sts")
            response = sts.assume_role(
                RoleArn=self.config.aws_auth_role_arn,
                RoleSessionName="zenml-sagemaker-orchestrator",
            )
            credentials = response["Credentials"]
            boto_session = boto3.Session(
                aws_access_key_id=credentials["AccessKeyId"],
                aws_secret_access_key=credentials["SecretAccessKey"],
                aws_session_token=credentials["SessionToken"],
                region_name=self.config.region,
            )

    session = sagemaker.Session(
        boto_session=boto_session, default_bucket=self.config.bucket
    )

    # `get_orchestrator_run_id` reads this variable inside each step. The
    # value is the same for every step, so it is set once here instead of
    # once per loop iteration. (Presumably Sagemaker substitutes the
    # execution ARN placeholder at run time - TODO confirm.)
    environment[
        ENV_ZENML_SAGEMAKER_RUN_ID
    ] = ExecutionVariables.PIPELINE_EXECUTION_ARN

    sagemaker_steps = []
    for step_name, step in deployment.step_configurations.items():
        image = self.get_image(deployment=deployment, step_name=step_name)
        command = StepEntrypointConfiguration.get_entrypoint_command()
        arguments = StepEntrypointConfiguration.get_entrypoint_arguments(
            step_name=step_name, deployment_id=deployment.id
        )
        entrypoint = command + arguments

        step_settings = cast(
            SagemakerOrchestratorSettings, self.get_settings(step)
        )

        # Retrieve Processor arguments provided in the Step settings.
        # Copy them so the defaults and fixed values below (session,
        # entrypoint, env, ...) are not written back into the settings.
        processor_args_for_step = dict(step_settings.processor_args or {})

        # Set default values from configured orchestrator Component to arguments
        # to be used when they are not present in processor_args.
        processor_args_for_step.setdefault(
            "instance_type", step_settings.instance_type
        )
        processor_args_for_step.setdefault(
            "role",
            step_settings.processor_role or self.config.execution_role,
        )
        processor_args_for_step.setdefault(
            "volume_size_in_gb", step_settings.volume_size_in_gb
        )
        processor_args_for_step.setdefault(
            "max_runtime_in_seconds", step_settings.max_runtime_in_seconds
        )
        processor_args_for_step.setdefault(
            "tags",
            [
                {"Key": key, "Value": value}
                for key, value in step_settings.processor_tags.items()
            ]
            if step_settings.processor_tags
            else None,
        )

        # Set values that cannot be overwritten
        processor_args_for_step["image_uri"] = image
        processor_args_for_step["instance_count"] = 1
        processor_args_for_step["sagemaker_session"] = session
        processor_args_for_step["entrypoint"] = entrypoint
        processor_args_for_step["base_job_name"] = orchestrator_run_name
        processor_args_for_step["env"] = environment

        # Construct S3 inputs to container for step
        inputs = None

        if step_settings.input_data_s3_uri is None:
            pass
        elif isinstance(step_settings.input_data_s3_uri, str):
            inputs = [
                ProcessingInput(
                    source=step_settings.input_data_s3_uri,
                    destination="/opt/ml/processing/input/data",
                    s3_input_mode=step_settings.input_data_s3_mode,
                )
            ]
        elif isinstance(step_settings.input_data_s3_uri, dict):
            inputs = []
            for channel, s3_uri in step_settings.input_data_s3_uri.items():
                inputs.append(
                    ProcessingInput(
                        source=s3_uri,
                        destination=f"/opt/ml/processing/input/data/{channel}",
                        s3_input_mode=step_settings.input_data_s3_mode,
                    )
                )

        # Construct S3 outputs from container for step
        outputs = None

        if step_settings.output_data_s3_uri is None:
            pass
        elif isinstance(step_settings.output_data_s3_uri, str):
            outputs = [
                ProcessingOutput(
                    source="/opt/ml/processing/output/data",
                    destination=step_settings.output_data_s3_uri,
                    s3_upload_mode=step_settings.output_data_s3_mode,
                )
            ]
        elif isinstance(step_settings.output_data_s3_uri, dict):
            outputs = []
            for (
                channel,
                s3_uri,
            ) in step_settings.output_data_s3_uri.items():
                outputs.append(
                    ProcessingOutput(
                        source=f"/opt/ml/processing/output/data/{channel}",
                        destination=s3_uri,
                        s3_upload_mode=step_settings.output_data_s3_mode,
                    )
                )

        # Create Processor and ProcessingStep
        processor = sagemaker.processing.Processor(
            **processor_args_for_step
        )
        sagemaker_step = ProcessingStep(
            name=step_name,
            processor=processor,
            depends_on=step.spec.upstream_steps,
            inputs=inputs,
            outputs=outputs,
        )
        sagemaker_steps.append(sagemaker_step)

    # construct the pipeline from the sagemaker_steps
    pipeline = Pipeline(
        name=orchestrator_run_name,
        steps=sagemaker_steps,
        sagemaker_session=session,
    )

    pipeline.create(role_arn=self.config.execution_role)
    pipeline_execution = pipeline.start()
    logger.warning(
        "Steps can take 5-15 minutes to start running "
        "when using the Sagemaker Orchestrator."
    )

    # mainly for testing purposes, we wait for the pipeline to finish
    if self.config.synchronous:
        logger.info(
            "Executing synchronously. Waiting for pipeline to finish..."
        )
        pipeline_execution.wait(
            delay=POLLING_DELAY, max_attempts=MAX_POLLING_ATTEMPTS
        )
        logger.info("Pipeline completed successfully.")
secrets_managers
special
AWS Secrets Manager.
aws_secrets_manager
Implementation of the AWS Secrets Manager integration.
AWSSecretsManager (BaseSecretsManager)
Class to interact with the AWS secrets manager.
Source code in zenml/integrations/aws/secrets_managers/aws_secrets_manager.py
class AWSSecretsManager(BaseSecretsManager):
    """Class to interact with the AWS secrets manager."""

    # boto3 Secrets Manager client stored on the class, so it is shared by
    # all instances. Created lazily by `_ensure_client_connected`.
    CLIENT: ClassVar[Any] = None

    @property
    def config(self) -> AWSSecretsManagerConfig:
        """Returns the `AWSSecretsManagerConfig` config.

        Returns:
            The configuration.
        """
        return cast(AWSSecretsManagerConfig, self._config)

    @classmethod
    def _ensure_client_connected(cls, region_name: str) -> None:
        """Ensure that the client is connected to the AWS secrets manager.

        The client is only created while `CLIENT` is still unset. Once it
        exists, the `region_name` passed on later calls is not used.

        Args:
            region_name: the AWS region name
        """
        if cls.CLIENT is None:
            # Create a Secrets Manager client
            session = boto3.session.Session()
            cls.CLIENT = session.client(
                service_name="secretsmanager", region_name=region_name
            )

    def _get_secret_tags(
        self, secret: BaseSecretSchema
    ) -> List[Dict[str, str]]:
        """Return a list of AWS secret tag values for a given secret.

        Args:
            secret: the secret object

        Returns:
            A list of AWS secret tag values
        """
        metadata = self._get_secret_metadata(secret)
        return [{"Key": k, "Value": v} for k, v in metadata.items()]

    def _get_secret_scope_filters(
        self,
        secret_name: Optional[str] = None,
    ) -> List[Dict[str, Any]]:
        """Return a list of AWS filters for the entire scope or just a scoped secret.

        These filters can be used when querying the AWS Secrets Manager
        for all secrets or for a single secret available in the configured
        scope. For more information see: https://docs.aws.amazon.com/secretsmanager/latest/userguide/manage_search-secret.html

        Example AWS filters for all secrets in the current (namespace) scope:

        ```python
        [
            {
                "Key": "tag-key",
                "Values": ["zenml_scope"],
            },
            {
                "Key": "tag-value",
                "Values": ["namespace"],
            },
            {
                "Key": "tag-key",
                "Values": ["zenml_namespace"],
            },
            {
                "Key": "tag-value",
                "Values": ["my_namespace"],
            },
        ]
        ```

        Example AWS filters for a particular secret in the current (namespace)
        scope:

        ```python
        [
            {
                "Key": "tag-key",
                "Values": ["zenml_secret_name"],
            },
            {
                "Key": "tag-value",
                "Values": ["my_secret"],
            },
            {
                "Key": "tag-key",
                "Values": ["zenml_scope"],
            },
            {
                "Key": "tag-value",
                "Values": ["namespace"],
            },
            {
                "Key": "tag-key",
                "Values": ["zenml_namespace"],
            },
            {
                "Key": "tag-value",
                "Values": ["my_namespace"],
            },
        ]
        ```

        Args:
            secret_name: Optional secret name to filter for.

        Returns:
            A list of AWS filters uniquely identifying all secrets
            or a named secret within the configured scope.
        """
        metadata = self._get_secret_scope_metadata(secret_name)
        filters: List[Dict[str, Any]] = []
        # Each metadata entry becomes a "tag-key" filter followed by a
        # "tag-value" filter.
        for k, v in metadata.items():
            filters.append(
                {
                    "Key": "tag-key",
                    "Values": [
                        k,
                    ],
                }
            )
            filters.append(
                {
                    "Key": "tag-value",
                    "Values": [
                        str(v),
                    ],
                }
            )

        return filters

    def _list_secrets(self, secret_name: Optional[str] = None) -> List[str]:
        """List all secrets matching a name.

        This method lists all the secrets in the current scope without loading
        their contents. An optional secret name can be supplied to filter out
        all but a single secret identified by name.

        Args:
            secret_name: Optional secret name to filter for.

        Returns:
            A list of secret names in the current scope and the optional
            secret name.
        """
        self._ensure_client_connected(self.config.region_name)

        filters: List[Dict[str, Any]] = []
        prefix: Optional[str] = None
        if self.config.scope == SecretsManagerScope.NONE:
            # unscoped (legacy) secrets don't have tags. We want to filter out
            # non-legacy secrets
            filters = [
                {
                    "Key": "tag-key",
                    "Values": [
                        "!zenml_scope",
                    ],
                },
            ]
            if secret_name:
                prefix = secret_name
        else:
            filters = self._get_secret_scope_filters()
            if secret_name:
                prefix = self._get_scoped_secret_name(secret_name)
            else:
                # add the name prefix to the filters to account for the fact
                # that AWS does not do exact matching but prefix-matching on the
                # filters
                prefix = self._get_scoped_secret_name_prefix()

        if prefix:
            filters.append(
                {
                    "Key": "name",
                    "Values": [
                        f"{prefix}",
                    ],
                }
            )

        paginator = self.CLIENT.get_paginator(_BOTO_CLIENT_LIST_SECRETS)
        pages = paginator.paginate(
            Filters=filters,
            PaginationConfig={
                "PageSize": 100,
            },
        )
        results: List[str] = []
        for page in pages:
            for secret in page[_PAGINATOR_RESPONSE_SECRETS_LIST_KEY]:
                name = self._get_unscoped_secret_name(secret["Name"])
                # keep only the names that are in scope and filter by secret name,
                # if one was given
                if name and (not secret_name or secret_name == name):
                    results.append(name)

        return results

    def register_secret(self, secret: BaseSecretSchema) -> None:
        """Registers a new secret.

        Args:
            secret: the secret to register

        Raises:
            SecretExistsError: if the secret already exists
        """
        validate_aws_secret_name_or_namespace(secret.name)
        self._ensure_client_connected(self.config.region_name)

        if self._list_secrets(secret.name):
            raise SecretExistsError(
                f"A Secret with the name {secret.name} already exists"
            )

        secret_value = json.dumps(secret_to_dict(secret, encode=False))
        kwargs: Dict[str, Any] = {
            "Name": self._get_scoped_secret_name(secret.name),
            "SecretString": secret_value,
            "Tags": self._get_secret_tags(secret),
        }

        self.CLIENT.create_secret(**kwargs)

        logger.debug("Created AWS secret: %s", kwargs["Name"])

    def get_secret(self, secret_name: str) -> BaseSecretSchema:
        """Gets a secret.

        Args:
            secret_name: the name of the secret to get

        Returns:
            The secret.

        Raises:
            KeyError: if the secret does not exist
            RuntimeError: if the AWS secret has no string value and therefore
                cannot be parsed as a ZenML secret
        """
        validate_aws_secret_name_or_namespace(secret_name)
        self._ensure_client_connected(self.config.region_name)

        if not self._list_secrets(secret_name):
            raise KeyError(f"Can't find the specified secret '{secret_name}'")

        get_secret_value_response = self.CLIENT.get_secret_value(
            SecretId=self._get_scoped_secret_name(secret_name)
        )
        # `register_secret` and `update_secret` always store a JSON
        # `SecretString`. A response without one (presumably a binary secret
        # not written by ZenML) cannot be decoded, so fail with a clear
        # message instead of subscripting `None`.
        if "SecretString" not in get_secret_value_response:
            raise RuntimeError(
                f"The AWS secret '{secret_name}' does not contain a string "
                "value and cannot be loaded as a ZenML secret."
            )

        return secret_from_dict(
            json.loads(get_secret_value_response["SecretString"]),
            secret_name=secret_name,
            decode=False,
        )

    def get_all_secret_keys(self) -> List[str]:
        """Get all secret keys.

        Returns:
            A list of all secret keys
        """
        return self._list_secrets()

    def update_secret(self, secret: BaseSecretSchema) -> None:
        """Update an existing secret.

        Args:
            secret: the secret to update

        Raises:
            KeyError: if the secret does not exist
        """
        validate_aws_secret_name_or_namespace(secret.name)
        self._ensure_client_connected(self.config.region_name)

        if not self._list_secrets(secret.name):
            raise KeyError(f"Can't find the specified secret '{secret.name}'")

        # Use the same (unencoded) format as `register_secret`, matching
        # `get_secret`, which reads values back with `decode=False`.
        secret_value = json.dumps(secret_to_dict(secret, encode=False))

        kwargs = {
            "SecretId": self._get_scoped_secret_name(secret.name),
            "SecretString": secret_value,
        }

        self.CLIENT.put_secret_value(**kwargs)

    def delete_secret(self, secret_name: str) -> None:
        """Delete an existing secret.

        Args:
            secret_name: the name of the secret to delete

        Raises:
            KeyError: if the secret does not exist
        """
        self._ensure_client_connected(self.config.region_name)

        if not self._list_secrets(secret_name):
            raise KeyError(f"Can't find the specified secret '{secret_name}'")

        self.CLIENT.delete_secret(
            SecretId=self._get_scoped_secret_name(secret_name),
            ForceDeleteWithoutRecovery=True,
        )

    def delete_all_secrets(self) -> None:
        """Delete all existing secrets.

        This method will force delete all your secrets. You will not be able to
        recover them once this method is called.
        """
        self._ensure_client_connected(self.config.region_name)
        for secret_name in self._list_secrets():
            self.CLIENT.delete_secret(
                SecretId=self._get_scoped_secret_name(secret_name),
                ForceDeleteWithoutRecovery=True,
            )
config: AWSSecretsManagerConfig
property
readonly
Returns the AWSSecretsManagerConfig config.
Returns:
Type | Description |
---|---|
AWSSecretsManagerConfig |
The configuration. |
delete_all_secrets(self)
Delete all existing secrets.
This method will force delete all your secrets. You will not be able to recover them once this method is called.
Source code in zenml/integrations/aws/secrets_managers/aws_secrets_manager.py
def delete_all_secrets(self) -> None:
    """Force-delete every secret managed by this secrets manager.

    Each secret is removed with `ForceDeleteWithoutRecovery=True`, so none
    of them can be recovered once this method has been called.
    """
    self._ensure_client_connected(self.config.region_name)
    secret_names = self._list_secrets()
    for name in secret_names:
        secret_id = self._get_scoped_secret_name(name)
        self.CLIENT.delete_secret(
            SecretId=secret_id, ForceDeleteWithoutRecovery=True
        )
delete_secret(self, secret_name)
Delete an existing secret.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| secret_name | str | the name of the secret to delete | required |
Exceptions:
| Type | Description |
|---|---|
| KeyError | if the secret does not exist |
Source code in zenml/integrations/aws/secrets_managers/aws_secrets_manager.py
def delete_secret(self, secret_name: str) -> None:
    """Permanently delete a single secret from AWS Secrets Manager.

    The deletion is forced (`ForceDeleteWithoutRecovery=True`), so the
    secret cannot be restored afterwards.

    Args:
        secret_name: the name of the secret to delete

    Raises:
        KeyError: if the secret does not exist
    """
    self._ensure_client_connected(self.config.region_name)
    matching_secrets = self._list_secrets(secret_name)
    if not matching_secrets:
        raise KeyError(f"Can't find the specified secret '{secret_name}'")
    scoped_secret_id = self._get_scoped_secret_name(secret_name)
    self.CLIENT.delete_secret(
        SecretId=scoped_secret_id,
        ForceDeleteWithoutRecovery=True,
    )
get_all_secret_keys(self)
Get all secret keys.
Returns:
| Type | Description |
|---|---|
| List[str] | A list of all secret keys |
Source code in zenml/integrations/aws/secrets_managers/aws_secrets_manager.py
def get_all_secret_keys(self) -> List[str]:
    """List the names of all secrets managed by this secrets manager.

    Returns:
        The names of all secrets, exactly as returned by an unfiltered
        `_list_secrets()` call.
    """
    all_secret_names = self._list_secrets()
    return all_secret_names
get_secret(self, secret_name)
Gets a secret.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| secret_name | str | the name of the secret to get | required |
Returns:
| Type | Description |
|---|---|
| BaseSecretSchema | The secret. |
Exceptions:
| Type | Description |
|---|---|
| KeyError | if the secret does not exist |
Source code in zenml/integrations/aws/secrets_managers/aws_secrets_manager.py
def get_secret(self, secret_name: str) -> BaseSecretSchema:
    """Gets a secret.

    Args:
        secret_name: the name of the secret to get

    Returns:
        The secret.

    Raises:
        KeyError: if the secret does not exist, or if the stored AWS secret
            has no string value and therefore cannot be loaded as a ZenML
            secret.
    """
    validate_aws_secret_name_or_namespace(secret_name)
    self._ensure_client_connected(self.config.region_name)

    if not self._list_secrets(secret_name):
        raise KeyError(f"Can't find the specified secret '{secret_name}'")

    get_secret_value_response = self.CLIENT.get_secret_value(
        SecretId=self._get_scoped_secret_name(secret_name)
    )
    secret_string = get_secret_value_response.get("SecretString")
    if secret_string is None:
        # Secrets created through `register_secret` are always stored as a
        # JSON `SecretString`. Previously this case set the response to
        # None and then subscripted it, raising an opaque TypeError.
        raise KeyError(
            f"The secret '{secret_name}' does not have a string value and "
            "cannot be loaded as a ZenML secret."
        )

    return secret_from_dict(
        json.loads(secret_string),
        secret_name=secret_name,
        decode=False,
    )
register_secret(self, secret)
Registers a new secret.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| secret | BaseSecretSchema | the secret to register | required |
Exceptions:
| Type | Description |
|---|---|
| SecretExistsError | if the secret already exists |
Source code in zenml/integrations/aws/secrets_managers/aws_secrets_manager.py
def register_secret(self, secret: BaseSecretSchema) -> None:
    """Create a new secret in AWS Secrets Manager.

    The secret is serialized to JSON (values not encoded), stored under its
    scoped name and tagged with the tags from `_get_secret_tags`.

    Args:
        secret: the secret to register

    Raises:
        SecretExistsError: if the secret already exists
    """
    validate_aws_secret_name_or_namespace(secret.name)
    self._ensure_client_connected(self.config.region_name)

    already_registered = self._list_secrets(secret.name)
    if already_registered:
        raise SecretExistsError(
            f"A Secret with the name {secret.name} already exists"
        )

    payload = json.dumps(secret_to_dict(secret, encode=False))
    create_request: Dict[str, Any] = dict(
        Name=self._get_scoped_secret_name(secret.name),
        SecretString=payload,
        Tags=self._get_secret_tags(secret),
    )
    self.CLIENT.create_secret(**create_request)
    logger.debug("Created AWS secret: %s", create_request["Name"])
update_secret(self, secret)
Update an existing secret.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| secret | BaseSecretSchema | the secret to update | required |
Exceptions:
| Type | Description |
|---|---|
| KeyError | if the secret does not exist |
Source code in zenml/integrations/aws/secrets_managers/aws_secrets_manager.py
def update_secret(self, secret: BaseSecretSchema) -> None:
    """Update an existing secret.

    The new contents are serialized the same way `register_secret` stores
    them (values not encoded), so that `get_secret`, which deserializes with
    `decode=False`, reads back the original values.

    Args:
        secret: the secret to update

    Raises:
        KeyError: if the secret does not exist
    """
    validate_aws_secret_name_or_namespace(secret.name)
    self._ensure_client_connected(self.config.region_name)

    if not self._list_secrets(secret.name):
        raise KeyError(f"Can't find the specified secret '{secret.name}'")

    # Pass `encode=False` explicitly so updates use the same value encoding
    # as `register_secret` and the `decode=False` used by `get_secret`.
    # Relying on the helper's default could store values in an encoding
    # that readers do not undo.
    secret_value = json.dumps(secret_to_dict(secret, encode=False))
    kwargs = {
        "SecretId": self._get_scoped_secret_name(secret.name),
        "SecretString": secret_value,
    }
    self.CLIENT.put_secret_value(**kwargs)
service_connectors
special
AWS Service Connector.
aws_service_connector
AWS Service Connector.
The AWS Service Connector implements various authentication methods for AWS services:
- Implicit authentication (credentials discovered by the default boto3 credentials provider, optionally using a named AWS profile)
- Explicit AWS secret key (access key, secret key)
- Explicit AWS STS tokens (access key, secret key, session token)
- IAM roles (i.e. generating temporary STS tokens on the fly by assuming an IAM role)
- IAM user federation tokens
- STS session tokens
AWSAuthenticationMethods (StrEnum)
AWS Authentication methods.
Source code in zenml/integrations/aws/service_connectors/aws_service_connector.py
class AWSAuthenticationMethods(StrEnum):
    """AWS Authentication methods."""

    # Credentials discovered by boto3's default credentials provider,
    # optionally for a named AWS profile.
    IMPLICIT = "implicit"
    # Long-term AWS access key ID and secret access key.
    SECRET_KEY = "secret-key"
    # Pre-issued temporary STS credentials (access key, secret key and
    # session token).
    STS_TOKEN = "sts-token"
    # Temporary credentials obtained by assuming an IAM role through STS,
    # starting from long-term credentials.
    IAM_ROLE = "iam-role"
    # Temporary credentials obtained from STS `get_session_token`.
    SESSION_TOKEN = "session-token"
    # Temporary credentials obtained from STS `get_federation_token`.
    FEDERATION_TOKEN = "federation-token"
AWSBaseConfig (AuthenticationConfig)
pydantic-model
AWS base configuration.
Source code in zenml/integrations/aws/service_connectors/aws_service_connector.py
class AWSBaseConfig(AuthenticationConfig):
    """Settings shared by every AWS authentication method.

    Attributes:
        region: the AWS region that boto3 sessions and clients are bound to.
        endpoint_url: optional custom endpoint URL, used when the S3 client
            is created.
    """

    region: str = Field(title="AWS Region")
    endpoint_url: Optional[str] = Field(title="AWS Endpoint URL", default=None)
AWSImplicitConfig (AWSBaseConfig)
pydantic-model
AWS implicit configuration.
Source code in zenml/integrations/aws/service_connectors/aws_service_connector.py
class AWSImplicitConfig(AWSBaseConfig):
    """Settings for implicit AWS authentication.

    Attributes:
        profile_name: optional AWS profile name, passed as `profile_name`
            to `boto3.Session` when the implicit method is used.
    """

    profile_name: Optional[str] = Field(title="AWS Profile Name", default=None)
AWSSecretKey (AuthenticationConfig)
pydantic-model
AWS secret key credentials.
Source code in zenml/integrations/aws/service_connectors/aws_service_connector.py
class AWSSecretKey(AuthenticationConfig):
    """AWS secret key credentials (access key ID and secret access key)."""

    aws_access_key_id: SecretStr = Field(
        title="AWS Access Key ID",
        description="An AWS access key ID associated with an AWS account or IAM user.",
    )
    # Given a description to match `aws_access_key_id`, so both credential
    # fields are documented in the generated schema and docs.
    aws_secret_access_key: SecretStr = Field(
        title="AWS Secret Access Key",
        description="The AWS secret access key paired with the access key ID.",
    )
aws_access_key_id: SecretStr
pydantic-field
required
An AWS access key ID associated with an AWS account or IAM user.
AWSSecretKeyConfig (AWSBaseConfig, AWSSecretKey)
pydantic-model
AWS secret key authentication configuration.
Source code in zenml/integrations/aws/service_connectors/aws_service_connector.py
class AWSSecretKeyConfig(AWSBaseConfig, AWSSecretKey):
    """Configuration for the AWS secret key authentication method.

    Combines the region and endpoint settings of `AWSBaseConfig` with the
    long-term access key credentials of `AWSSecretKey`.
    """
AWSServiceConnector (ServiceConnector)
pydantic-model
AWS service connector.
Source code in zenml/integrations/aws/service_connectors/aws_service_connector.py
class AWSServiceConnector(ServiceConnector):
"""AWS service connector."""
config: AWSBaseConfig
_account_id: Optional[str] = None
_session_cache: Dict[
Tuple[str, Optional[str], Optional[str]],
Tuple[boto3.Session, Optional[datetime.datetime]],
] = {}
@classmethod
def _get_connector_type(cls) -> ServiceConnectorTypeModel:
    """Return the static type specification of the AWS service connector.

    Returns:
        The module-level `AWS_SERVICE_CONNECTOR_TYPE_SPEC` object.
    """
    spec = AWS_SERVICE_CONNECTOR_TYPE_SPEC
    return spec
@property
def account_id(self) -> str:
    """The AWS account ID the connector's credentials belong to.

    The ID is looked up once through STS `get_caller_identity` and then
    memoized in `self._account_id`.

    Returns:
        The AWS account ID.

    Raises:
        AuthorizationException: If the AWS account ID could not be
            determined.
    """
    if self._account_id is not None:
        return self._account_id

    logger.debug("Getting account ID from AWS...")
    try:
        session, _ = self.get_boto3_session(self.auth_method)
        caller_identity = session.client("sts").get_caller_identity()
    except (ClientError, BotoCoreError) as e:
        raise AuthorizationException(
            f"Failed to fetch the AWS account ID: {e}"
        ) from e

    self._account_id = caller_identity["Account"]
    return self._account_id
def get_boto3_session(
    self,
    auth_method: str,
    resource_type: Optional[str] = None,
    resource_id: Optional[str] = None,
) -> Tuple[boto3.Session, Optional[datetime.datetime]]:
    """Return a (possibly cached) boto3 session for a resource.

    Sessions are cached per ``(auth_method, resource_type, resource_id)``.
    A cached session without an expiration time is always reused. A cached
    session with an expiration time is reused only while that time is still
    in the future. Otherwise a fresh session is created via `_authenticate`
    and cached.

    Args:
        auth_method: The authentication method to use.
        resource_type: The resource type to get a boto3 session for.
        resource_id: The resource ID to get a boto3 session for.

    Returns:
        The boto3 session and its expiration timestamp, or ``None`` as the
        timestamp when the credentials do not expire.
    """
    cache_key = (auth_method, resource_type, resource_id)
    # NOTE(review): `_session_cache` is declared on the class with a `{}`
    # default, and this key does not include the connector's credentials.
    # If the model base does not give each instance its own copy, cached
    # sessions would be shared between connector instances. Confirm.
    cached_entry = self._session_cache.get(cache_key)
    if cached_entry is not None:
        cached_session, cached_expiry = cached_entry
        if cached_expiry is None:
            return cached_session, None
        current_time = datetime.datetime.now(datetime.timezone.utc)
        # Stored expiration times are interpreted as UTC.
        cached_expiry = cached_expiry.replace(tzinfo=datetime.timezone.utc)
        if cached_expiry > current_time:
            return cached_session, cached_expiry

    logger.debug(
        f"Creating boto3 session for auth method '{auth_method}', "
        f"resource type '{resource_type}' and resource ID "
        f"'{resource_id}'..."
    )
    fresh_session, fresh_expiry = self._authenticate(
        auth_method, resource_type, resource_id
    )
    self._session_cache[cache_key] = (fresh_session, fresh_expiry)
    return fresh_session, fresh_expiry
def _get_iam_policy(
    self,
    region_id: str,
    resource_type: Optional[str],
    resource_id: Optional[str] = None,
) -> Optional[str]:
    """Build the inline IAM session policy for a resource, as JSON.

    S3 and EKS policies are scoped to the given bucket or cluster when a
    resource ID is supplied, and to all buckets or clusters otherwise. The
    ECR policy always covers every repository in the region, plus
    `ecr:GetAuthorizationToken` on all resources.

    Args:
        region_id: The AWS region ID to get the IAM inline policy for.
        resource_type: The resource type to get the IAM inline policy for.
        resource_id: The resource ID to get the IAM inline policy for.

    Returns:
        The JSON-encoded IAM inline policy, or None for resource types that
        have no generated policy.
    """

    def allow(
        sid: str, actions: List[str], resources: List[str]
    ) -> Dict[str, Any]:
        # Key order matches the serialized policy layout.
        return {
            "Sid": sid,
            "Effect": "Allow",
            "Action": actions,
            "Resource": resources,
        }

    def document(*statements: Dict[str, Any]) -> str:
        return json.dumps(
            {"Version": "2012-10-17", "Statement": list(statements)}
        )

    if resource_type == S3_RESOURCE_TYPE:
        if resource_id:
            bucket = self._parse_s3_resource_id(resource_id)
            s3_resources = [
                f"arn:aws:s3:::{bucket}",
                f"arn:aws:s3:::{bucket}/*",
            ]
        else:
            s3_resources = ["arn:aws:s3:::*", "arn:aws:s3:::*/*"]
        s3_actions = [
            "s3:ListBucket",
            "s3:GetObject",
            "s3:PutObject",
            "s3:DeleteObject",
            "s3:ListAllMyBuckets",
        ]
        return document(allow("AllowS3BucketAccess", s3_actions, s3_resources))

    if resource_type == KUBERNETES_CLUSTER_RESOURCE_TYPE:
        cluster_pattern = (
            self._parse_eks_resource_id(resource_id) if resource_id else "*"
        )
        return document(
            allow(
                "AllowEKSClusterAccess",
                ["eks:ListClusters", "eks:DescribeCluster"],
                [f"arn:aws:eks:{region_id}:*:cluster/{cluster_pattern}"],
            )
        )

    if resource_type == DOCKER_REGISTRY_RESOURCE_TYPE:
        ecr_actions = [
            "ecr:DescribeRegistry",
            "ecr:DescribeRepositories",
            "ecr:ListRepositories",
            "ecr:BatchGetImage",
            "ecr:DescribeImages",
            "ecr:BatchCheckLayerAvailability",
            "ecr:GetDownloadUrlForLayer",
            "ecr:InitiateLayerUpload",
            "ecr:UploadLayerPart",
            "ecr:CompleteLayerUpload",
            "ecr:PutImage",
        ]
        ecr_resources = [
            f"arn:aws:ecr:{region_id}:*:repository/*",
            f"arn:aws:ecr:{region_id}:*:repository",
        ]
        return document(
            allow("AllowECRRepositoryAccess", ecr_actions, ecr_resources),
            allow(
                "AllowECRRepositoryGetToken",
                ["ecr:GetAuthorizationToken"],
                ["*"],
            ),
        )

    return None
def _authenticate(
    self,
    auth_method: str,
    resource_type: Optional[str] = None,
    resource_id: Optional[str] = None,
) -> Tuple[boto3.Session, Optional[datetime.datetime]]:
    """Authenticate to AWS and return a boto3 session.

    Supported authentication methods:

    - implicit: credentials from boto3's default credentials provider,
      optionally for a named profile.
    - secret-key: long-term access key ID and secret access key.
    - sts-token: pre-issued temporary STS credentials.
    - iam-role / federation-token / session-token: long-term credentials
      are exchanged for temporary STS credentials through `assume_role`,
      `get_federation_token` or `get_session_token` respectively.

    Every returned session is bound to the configured AWS region, including
    sessions built from temporary STS credentials.

    Args:
        auth_method: The authentication method to use.
        resource_type: The resource type to authenticate for.
        resource_id: The resource ID to authenticate for.

    Returns:
        An authenticated boto3 session and the expiration time of the
        temporary credentials if applicable.

    Raises:
        AuthorizationException: If the implicit method finds no credentials,
            or if temporary credentials cannot be obtained for the IAM
            role, federation token or session token methods.
        NotImplementedError: If the authentication method is not supported.
    """
    cfg = self.config
    if auth_method == AWSAuthenticationMethods.IMPLICIT:
        self._check_implicit_auth_method_allowed()
        assert isinstance(cfg, AWSImplicitConfig)
        # Create a boto3 session and use the default credentials provider
        session = boto3.Session(
            profile_name=cfg.profile_name, region_name=cfg.region
        )
        credentials = session.get_credentials()
        if not credentials:
            raise AuthorizationException(
                "Failed to get AWS credentials from the default provider. "
                "Please check your AWS configuration or attached IAM role."
            )
        if credentials.token:
            # Temporary credentials were generated. It's not possible to
            # determine the expiration time of the temporary credentials
            # from the boto3 session, so we assume the default IAM role
            # expiration date is used
            expiration_time = datetime.datetime.now(
                tz=datetime.timezone.utc
            ) + datetime.timedelta(seconds=DEFAULT_IAM_ROLE_TOKEN_EXPIRATION)
            return session, expiration_time

        return session, None

    elif auth_method == AWSAuthenticationMethods.SECRET_KEY:
        assert isinstance(cfg, AWSSecretKeyConfig)
        # Create a boto3 session using long-term AWS credentials
        session = boto3.Session(
            aws_access_key_id=cfg.aws_access_key_id.get_secret_value(),
            aws_secret_access_key=cfg.aws_secret_access_key.get_secret_value(),
            region_name=cfg.region,
        )
        return session, None

    elif auth_method == AWSAuthenticationMethods.STS_TOKEN:
        assert isinstance(cfg, STSTokenConfig)
        # Create a boto3 session using a temporary AWS STS token
        session = boto3.Session(
            aws_access_key_id=cfg.aws_access_key_id.get_secret_value(),
            aws_secret_access_key=cfg.aws_secret_access_key.get_secret_value(),
            aws_session_token=cfg.aws_session_token.get_secret_value(),
            region_name=cfg.region,
        )
        return session, cfg.expires_at

    elif auth_method in [
        AWSAuthenticationMethods.IAM_ROLE,
        AWSAuthenticationMethods.SESSION_TOKEN,
        AWSAuthenticationMethods.FEDERATION_TOKEN,
    ]:
        assert isinstance(cfg, AWSSecretKey)

        # Create a boto3 session from the long-term credentials; it is only
        # used to call STS for temporary credentials.
        session = boto3.Session(
            aws_access_key_id=cfg.aws_access_key_id.get_secret_value(),
            aws_secret_access_key=cfg.aws_secret_access_key.get_secret_value(),
            region_name=cfg.region,
        )

        sts = session.client("sts", region_name=cfg.region)
        session_name = "zenml-connector"
        if self.id:
            session_name += f"-{self.id}"

        # Next steps are different for each authentication method

        # The IAM role and federation token authentication methods
        # accept a managed IAM policy that restricts/grants permissions.
        # If one isn't explicitly configured, we generate one based on the
        # resource specified by the resource type and ID (if present).
        if auth_method in [
            AWSAuthenticationMethods.IAM_ROLE,
            AWSAuthenticationMethods.FEDERATION_TOKEN,
        ]:
            assert isinstance(cfg, AWSSessionPolicy)
            policy_kwargs = {}
            policy = cfg.policy
            if not cfg.policy and not cfg.policy_arns:
                policy = self._get_iam_policy(
                    region_id=cfg.region,
                    resource_type=resource_type,
                    resource_id=resource_id,
                )
            if policy:
                policy_kwargs["Policy"] = policy
            elif cfg.policy_arns:
                policy_kwargs["PolicyArns"] = cfg.policy_arns

            if auth_method == AWSAuthenticationMethods.IAM_ROLE:
                assert isinstance(cfg, IAMRoleAuthenticationConfig)

                try:
                    response = sts.assume_role(
                        RoleArn=cfg.role_arn,
                        RoleSessionName=session_name,
                        DurationSeconds=self.expiration_seconds,
                        **policy_kwargs,
                    )
                except (ClientError, BotoCoreError) as e:
                    raise AuthorizationException(
                        f"Failed to assume IAM role {cfg.role_arn} "
                        f"using the AWS credentials configured in the "
                        f"connector: {e}"
                    ) from e

            else:
                assert isinstance(cfg, FederationTokenAuthenticationConfig)

                try:
                    response = sts.get_federation_token(
                        Name=session_name[:32],
                        DurationSeconds=self.expiration_seconds,
                        **policy_kwargs,
                    )
                except (ClientError, BotoCoreError) as e:
                    raise AuthorizationException(
                        "Failed to get federation token "
                        "using the AWS credentials configured in the "
                        f"connector: {e}"
                    ) from e

        else:
            assert isinstance(cfg, SessionTokenAuthenticationConfig)
            try:
                response = sts.get_session_token(
                    DurationSeconds=self.expiration_seconds,
                )
            except (ClientError, BotoCoreError) as e:
                raise AuthorizationException(
                    "Failed to get session token "
                    "using the AWS credentials configured in the "
                    f"connector: {e}"
                ) from e

        # Bind the temporary-credential session to the configured region,
        # like the sessions of every other method. Without `region_name`,
        # clients created from it (e.g. via `_connect_to_resource` for the
        # generic AWS resource type) fall back to whatever default region
        # the environment provides, or to none at all.
        session = boto3.Session(
            aws_access_key_id=response["Credentials"]["AccessKeyId"],
            aws_secret_access_key=response["Credentials"]["SecretAccessKey"],
            aws_session_token=response["Credentials"]["SessionToken"],
            region_name=cfg.region,
        )
        expiration = response["Credentials"]["Expiration"]
        # Add the UTC timezone to the expiration time
        expiration = expiration.replace(tzinfo=datetime.timezone.utc)
        return session, expiration

    raise NotImplementedError(
        f"Authentication method '{auth_method}' is not supported by "
        "the AWS connector."
    )
@classmethod
def _get_eks_bearer_token(
    cls,
    session: boto3.Session,
    cluster_id: str,
    region: str,
) -> str:
    """Generate a bearer token for the EKS Kubernetes API server.

    The token is a presigned STS `GetCallerIdentity` URL carrying the
    ``x-k8s-aws-id`` header for the cluster. It is URL-safe base64 encoded,
    stripped of padding and prefixed with ``k8s-aws-v1.``.

    Based on: https://github.com/kubernetes-sigs/aws-iam-authenticator/blob/master/README.md#api-authorization-from-outside-a-cluster

    Args:
        session: An authenticated boto3 session to use for generating the
            token.
        cluster_id: The name of the EKS cluster.
        region: The AWS region the EKS cluster is in.

    Returns:
        A bearer token for authenticating to the EKS API server.
    """
    sts_client = session.client("sts", region_name=region)
    signer = RequestSigner(
        sts_client.meta.service_model.service_id,
        region,
        "sts",
        "v4",
        session.get_credentials(),
        session.events,
    )
    request_dict = dict(
        method="GET",
        url=f"https://sts.{region}.amazonaws.com/?Action=GetCallerIdentity&Version=2011-06-15",
        body={},
        headers={"x-k8s-aws-id": cluster_id},
        context={},
    )
    presigned_url = signer.generate_presigned_url(
        request_dict,
        region_name=region,
        expires_in=EKS_KUBE_API_TOKEN_EXPIRATION,
        operation_name="",
    )
    encoded_url = base64.urlsafe_b64encode(
        presigned_url.encode("utf-8")
    ).decode("utf-8")
    # In base64 output '=' only ever appears as trailing padding, so
    # stripping it from the right removes exactly the padding.
    return "k8s-aws-v1." + encoded_url.rstrip("=")
def _parse_s3_resource_id(self, resource_id: str) -> str:
"""Validate and convert an S3 resource ID to an S3 bucket name.
Args:
resource_id: The resource ID to convert.
Returns:
The S3 bucket name.
Raises:
ValueError: If the provided resource ID is not a valid S3 bucket
name, ARN or URI.
"""
# The resource ID could mean different things:
#
# - an S3 bucket ARN
# - an S3 bucket URI
# - the S3 bucket name
#
# We need to extract the bucket name from the provided resource ID
bucket_name: Optional[str] = None
if re.match(
r"^arn:aws:s3:::[a-z0-9-]+(/.*)*$",
resource_id,
):
# The resource ID is an S3 bucket ARN
bucket_name = resource_id.split(":")[-1].split("/")[0]
elif re.match(
r"^s3://[a-z0-9-]+(/.*)*$",
resource_id,
):
# The resource ID is an S3 bucket URI
bucket_name = resource_id.split("/")[2]
elif re.match(
r"^[a-z0-9][a-z0-9-]{1,61}[a-z0-9]$",
resource_id,
):
# The resource ID is the S3 bucket name
bucket_name = resource_id
else:
raise ValueError(
f"Invalid resource ID for an S3 bucket: {resource_id}. "
f"Supported formats are:\n"
f"S3 bucket ARN: arn:aws:s3:::<bucket-name>\n"
f"S3 bucket URI: s3://<bucket-name>\n"
f"S3 bucket name: <bucket-name>"
)
return bucket_name
def _parse_ecr_resource_id(
self,
resource_id: str,
) -> str:
"""Validate and convert an ECR resource ID to an ECR registry ID.
Args:
resource_id: The resource ID to convert.
Returns:
The ECR registry ID (AWS account ID).
Raises:
ValueError: If the provided resource ID is not a valid ECR
repository ARN or URI.
"""
# The resource ID could mean different things:
#
# - an ECR repository ARN
# - an ECR repository URI
#
# We need to extract the region ID and registry ID from
# the provided resource ID
config_region_id = self.config.region
region_id: Optional[str] = None
if re.match(
r"^arn:aws:ecr:[a-z0-9-]+:\d{12}:repository(/.+)*$",
resource_id,
):
# The resource ID is an ECR repository ARN
registry_id = resource_id.split(":")[4]
region_id = resource_id.split(":")[3]
elif re.match(
r"^(http[s]?://)?\d{12}\.dkr\.ecr\.[a-z0-9-]+\.amazonaws\.com(/.+)*$",
resource_id,
):
# The resource ID is an ECR repository URI
registry_id = resource_id.split(".")[0].split("/")[-1]
region_id = resource_id.split(".")[3]
else:
raise ValueError(
f"Invalid resource ID for a ECR registry: {resource_id}. "
f"Supported formats are:\n"
f"ECR repository ARN: arn:aws:ecr:<region>:<account-id>:repository[/<repository-name>]\n"
f"ECR repository URI: [https://]<account-id>.dkr.ecr.<region>.amazonaws.com[/<repository-name>]"
)
# If the connector is configured with a region and the resource ID
# is an ECR repository ARN or URI that specifies a different region
# we raise an error
if region_id and region_id != config_region_id:
raise ValueError(
f"The AWS region for the {resource_id} ECR repository region "
f"'{region_id}' does not match the region configured in "
f"the connector: '{config_region_id}'."
)
return registry_id
def _parse_eks_resource_id(self, resource_id: str) -> str:
"""Validate and convert an EKS resource ID to an AWS region and EKS cluster name.
Args:
resource_id: The resource ID to convert.
Returns:
The EKS cluster name.
Raises:
ValueError: If the provided resource ID is not a valid EKS cluster
name or ARN.
"""
# The resource ID could mean different things:
#
# - an EKS cluster ARN
# - an EKS cluster ID
#
# We need to extract the cluster name and region ID from the
# provided resource ID
config_region_id = self.config.region
cluster_name: Optional[str] = None
region_id: Optional[str] = None
if re.match(
r"^arn:aws:eks:[a-z0-9-]+:\d{12}:cluster/.+$",
resource_id,
):
# The resource ID is an EKS cluster ARN
cluster_name = resource_id.split("/")[-1]
region_id = resource_id.split(":")[3]
elif re.match(
r"^[a-z0-9]+[a-z0-9_-]*$",
resource_id,
):
# Assume the resource ID is an EKS cluster name
cluster_name = resource_id
else:
raise ValueError(
f"Invalid resource ID for a EKS cluster: {resource_id}. "
f"Supported formats are:\n"
f"EKS cluster ARN: arn:aws:eks:<region>:<account-id>:cluster/<cluster-name>\n"
f"ECR cluster name: <cluster-name>"
)
# If the connector is configured with a region and the resource ID
# is an EKS registry ARN or URI that specifies a different region
# we raise an error
if region_id and region_id != config_region_id:
raise ValueError(
f"The AWS region for the {resource_id} EKS cluster "
f"({region_id}) does not match the region configured in "
f"the connector ({config_region_id})."
)
return cluster_name
def _canonical_resource_id(
    self, resource_type: str, resource_id: str
) -> str:
    """Normalize a resource ID to the connector's canonical form.

    S3 buckets become ``s3://<bucket>``, EKS clusters become the bare
    cluster name, and ECR registries become
    ``<account-id>.dkr.ecr.<region>.amazonaws.com``. IDs of any other
    resource type are returned unchanged.

    Args:
        resource_type: The resource type to canonicalize.
        resource_id: The resource ID to canonicalize.

    Returns:
        The canonical resource ID.
    """
    if resource_type == S3_RESOURCE_TYPE:
        return f"s3://{self._parse_s3_resource_id(resource_id)}"
    if resource_type == KUBERNETES_CLUSTER_RESOURCE_TYPE:
        return self._parse_eks_resource_id(resource_id)
    if resource_type == DOCKER_REGISTRY_RESOURCE_TYPE:
        account = self._parse_ecr_resource_id(resource_id)
        return f"{account}.dkr.ecr.{self.config.region}.amazonaws.com"
    return resource_id
def _get_default_resource_id(self, resource_type: str) -> str:
    """Return the default resource ID for a single-instance resource type.

    For the generic AWS resource type this is the configured region. For
    the Docker registry resource type it is the ECR registry URI of the
    caller's AWS account in the configured region.

    Args:
        resource_type: The type of the resource to get a default resource ID
            for. Only called with resource types that do not support
            multiple instances.

    Returns:
        The default resource ID for the resource type.

    Raises:
        RuntimeError: If no default resource ID is available for the given
            resource type. Looking up the account ID for the Docker
            registry type may also raise the exception of the `account_id`
            property (AuthorizationException) if the connector is not
            authorized.
    """
    if resource_type == AWS_RESOURCE_TYPE:
        return self.config.region
    if resource_type == DOCKER_REGISTRY_RESOURCE_TYPE:
        # The ECR registry ID is the AWS account ID of the caller
        registry_id = self.account_id
        return f"{registry_id}.dkr.ecr.{self.config.region}.amazonaws.com"
    raise RuntimeError(
        f"Default resource ID for '{resource_type}' not available."
    )
def _connect_to_resource(
    self,
    **kwargs: Any,
) -> Any:
    """Authenticate and connect to the connector's configured AWS resource.

    A boto3 session is always obtained first. For an S3 resource, a boto3
    S3 client bound to the configured region and endpoint is returned, with
    the session credentials attached as ``client.credentials``. For the
    generic AWS resource type, the session itself is returned. Docker and
    Kubernetes resources cannot be connected to directly; use
    `get_connector_client` for those.

    Args:
        kwargs: Additional implementation specific keyword arguments to pass
            to the session or client constructor.

    Returns:
        A boto3 S3 client for S3 resources, or a boto3 session for generic
        AWS resources.

    Raises:
        NotImplementedError: If the connector instance does not support
            directly connecting to the indicated resource type.
    """
    resource_type = self.resource_type
    resource_id = self.resource_id
    assert resource_type is not None
    assert resource_id is not None

    # Authentication is required no matter which resource is targeted
    session, _ = self.get_boto3_session(
        self.auth_method,
        resource_type=resource_type,
        resource_id=resource_id,
    )

    if resource_type == S3_RESOURCE_TYPE:
        # Fail early if the resource ID is not a valid S3 bucket identifier
        self._parse_s3_resource_id(resource_id)
        s3_client = session.client(
            "s3",
            region_name=self.config.region,
            endpoint_url=self.config.endpoint_url,
        )
        # boto3 clients do not expose their credentials, but some consumers
        # need them to configure 3rd party services, so they are attached
        # to the client object here.
        s3_client.credentials = session.get_credentials()
        return s3_client

    if resource_type != AWS_RESOURCE_TYPE:
        raise NotImplementedError(
            f"Connecting to {resource_type} resources is not directly "
            "supported by the AWS connector. Please call the "
            f"`get_connector_client` method to get a {resource_type} connector "
            "instance for the resource."
        )

    return session
def _configure_local_client(
    self,
    profile_name: Optional[str] = None,
    **kwargs: Any,
) -> None:
    """Write the connector's credentials into a local AWS SDK profile.

    Supported only for the generic AWS and S3 resource types. A boto3
    session is obtained for the configured resource, and its credentials
    (plus the session token, if any) and the configured region are written
    to the local AWS credentials file through `aws-profile-manager`.

    Args:
        profile_name: The name of the AWS profile to use. If not specified,
            a profile name is generated based on the first 8 digits of the
            connector's UUID in the form 'zenml-<uuid[:8]>'. If a profile
            with the given or generated name already exists, the profile is
            overwritten.
        kwargs: Additional implementation specific keyword arguments to use
            to configure the client.

    Raises:
        NotImplementedError: If the connector instance does not support
            local configuration for the configured resource type or
            authentication method.
    """
    resource_type = self.resource_type
    if resource_type not in [AWS_RESOURCE_TYPE, S3_RESOURCE_TYPE]:
        raise NotImplementedError(
            f"Configuring the local client for {resource_type} resources is "
            "not directly supported by the AWS connector. Please call the "
            f"`get_connector_client` method to get a {resource_type} connector "
            "instance for the resource."
        )

    session, _ = self.get_boto3_session(
        self.auth_method,
        resource_type=resource_type,
        resource_id=self.resource_id,
    )

    # Fall back to a profile name derived from the connector UUID
    target_profile = profile_name or f"zenml-{str(self.id)[:8]}"

    profile_manager = Common()
    home_dir = profile_manager.get_users_home()
    existing_profiles = profile_manager.get_all_profiles(home_dir)
    session_credentials = session.get_credentials()

    profile_entry = {
        "region": self.config.region,
        "aws_access_key_id": session_credentials.access_key,
        "aws_secret_access_key": session_credentials.secret_key,
    }
    if session_credentials.token:
        profile_entry["aws_session_token"] = session_credentials.token
    existing_profiles[target_profile] = profile_entry

    profile_manager.rewrite_credentials_file(existing_profiles, home_dir)

    logger.info(f"Configured local AWS SDK profile '{target_profile}'.")
@classmethod
def _auto_configure(
    cls,
    auth_method: Optional[str] = None,
    resource_type: Optional[str] = None,
    resource_id: Optional[str] = None,
    region_name: Optional[str] = None,
    profile_name: Optional[str] = None,
    role_arn: Optional[str] = None,
    **kwargs: Any,
) -> "AWSServiceConnector":
    """Auto-configure the connector.

    Instantiate an AWS connector with a configuration extracted from the
    authentication configuration available in the environment (e.g.
    environment variables or local AWS client/SDK configuration files).

    Args:
        auth_method: The particular authentication method to use. If not
            specified, the connector implementation must decide which
            authentication method to use or raise an exception.
        resource_type: The type of resource to configure.
        resource_id: The ID of the resource to configure. The
            implementation may choose to either require or ignore this
            parameter if it does not support or detect a resource type that
            supports multiple instances.
        region_name: The name of the AWS region to use. If not specified,
            the implicit region is used.
        profile_name: The name of the AWS profile to use. If not specified,
            the implicit profile is used.
        role_arn: The ARN of the AWS role to assume. Applicable only if the
            IAM role authentication method is specified or long-term
            credentials are discovered.
        kwargs: Additional implementation specific keyword arguments to use.

    Returns:
        An AWS connector instance configured with authentication credentials
        automatically extracted from the environment.

    Raises:
        NotImplementedError: If the connector implementation does not
            support auto-configuration for the specified authentication
            method.
        ValueError: If the supplied arguments are not valid.
        AuthorizationException: If no AWS credentials can be loaded from
            the environment.
    """
    auth_config: AWSBaseConfig
    expiration_seconds: Optional[int] = None
    expires_at: Optional[datetime.datetime] = None

    if auth_method == AWSAuthenticationMethods.IMPLICIT:
        cls._check_implicit_auth_method_allowed()
        if region_name is None:
            raise ValueError(
                "The AWS region name must be specified when using the "
                "implicit authentication method"
            )
        auth_config = AWSImplicitConfig(
            profile_name=profile_name,
            region=region_name,
        )
    else:
        # Initialize an AWS session with the default configuration loaded
        # from the environment.
        session = boto3.Session(
            profile_name=profile_name, region_name=region_name
        )

        # The region is validated once here; it is not reassigned below.
        region_name = region_name or session.region_name
        if not region_name:
            raise ValueError(
                "The AWS region name was not specified and could not "
                "be determined from the AWS session"
            )
        endpoint_url = session._session.get_config_variable("endpoint_url")

        # Extract the AWS credentials from the session and store them in
        # the connector secrets.
        credentials = session.get_credentials()
        if not credentials:
            raise AuthorizationException(
                "Could not determine the AWS credentials from the "
                "environment"
            )

        if credentials.token:
            # The session picked up temporary STS credentials
            if auth_method and auth_method not in [
                AWSAuthenticationMethods.STS_TOKEN,
                AWSAuthenticationMethods.IAM_ROLE,
            ]:
                raise NotImplementedError(
                    f"The specified authentication method '{auth_method}' "
                    "could not be used to auto-configure the connector. "
                )

            if (
                credentials.method == "assume-role"
                and auth_method != AWSAuthenticationMethods.STS_TOKEN
            ):
                # In the special case of IAM role authentication, the
                # credentials in the boto3 session are the temporary STS
                # credentials instead of the long-lived credentials, and the
                # role ARN is not known. We have to dig deeper into the
                # botocore session internals to retrieve the role ARN and
                # the original long-lived credentials.
                botocore_session = session._session
                profile_config = botocore_session.get_scoped_config()
                source_profile = profile_config.get("source_profile")
                role_arn = profile_config.get("role_arn")
                profile_map = botocore_session._build_profile_map()
                if not (
                    role_arn
                    and source_profile
                    and source_profile in profile_map
                ):
                    raise AuthorizationException(
                        "Could not determine the IAM role ARN and source "
                        "profile credentials from the environment"
                    )

                auth_method = AWSAuthenticationMethods.IAM_ROLE
                source_profile_config = profile_map[source_profile]
                auth_config = IAMRoleAuthenticationConfig(
                    region=region_name,
                    endpoint_url=endpoint_url,
                    aws_access_key_id=source_profile_config.get(
                        "aws_access_key_id"
                    ),
                    aws_secret_access_key=source_profile_config.get(
                        "aws_secret_access_key",
                    ),
                    role_arn=role_arn,
                )
                expiration_seconds = DEFAULT_IAM_ROLE_TOKEN_EXPIRATION

            else:
                if auth_method == AWSAuthenticationMethods.IAM_ROLE:
                    raise NotImplementedError(
                        f"The specified authentication method "
                        f"'{auth_method}' could not be used to "
                        "auto-configure the connector."
                    )

                # Temporary credentials were picked up from the local
                # configuration. It's not possible to determine the
                # expiration time of the temporary credentials from the
                # boto3 session, so we assume the default IAM role
                # expiration period is used
                expires_at = datetime.datetime.now(
                    tz=datetime.timezone.utc
                ) + datetime.timedelta(
                    seconds=DEFAULT_IAM_ROLE_TOKEN_EXPIRATION
                )

                auth_method = AWSAuthenticationMethods.STS_TOKEN
                auth_config = STSTokenConfig(
                    region=region_name,
                    endpoint_url=endpoint_url,
                    aws_access_key_id=credentials.access_key,
                    aws_secret_access_key=credentials.secret_key,
                    aws_session_token=credentials.token,
                )
        else:
            # The session picked up long-lived credentials
            if not auth_method:
                if role_arn:
                    auth_method = AWSAuthenticationMethods.IAM_ROLE
                else:
                    # If no authentication method was specified, use the
                    # session token as a default recommended authentication
                    # method to be used with long-lived credentials.
                    auth_method = AWSAuthenticationMethods.SESSION_TOKEN

            if auth_method == AWSAuthenticationMethods.STS_TOKEN:
                # Generate a session token from the long-lived credentials
                # and store it in the connector secrets.
                sts_client = session.client("sts")
                response = sts_client.get_session_token(
                    DurationSeconds=DEFAULT_STS_TOKEN_EXPIRATION
                )
                sts_credentials = response["Credentials"]
                auth_config = STSTokenConfig(
                    region=region_name,
                    endpoint_url=endpoint_url,
                    aws_access_key_id=sts_credentials["AccessKeyId"],
                    aws_secret_access_key=sts_credentials["SecretAccessKey"],
                    aws_session_token=sts_credentials["SessionToken"],
                )
                expires_at = datetime.datetime.now(
                    tz=datetime.timezone.utc
                ) + datetime.timedelta(
                    seconds=DEFAULT_STS_TOKEN_EXPIRATION
                )

            elif auth_method == AWSAuthenticationMethods.IAM_ROLE:
                if not role_arn:
                    raise ValueError(
                        "The ARN of the AWS role to assume must be "
                        "specified when using the IAM role authentication "
                        "method."
                    )
                auth_config = IAMRoleAuthenticationConfig(
                    region=region_name,
                    endpoint_url=endpoint_url,
                    aws_access_key_id=credentials.access_key,
                    aws_secret_access_key=credentials.secret_key,
                    role_arn=role_arn,
                )
                expiration_seconds = DEFAULT_IAM_ROLE_TOKEN_EXPIRATION

            elif auth_method == AWSAuthenticationMethods.SECRET_KEY:
                auth_config = AWSSecretKeyConfig(
                    region=region_name,
                    endpoint_url=endpoint_url,
                    aws_access_key_id=credentials.access_key,
                    aws_secret_access_key=credentials.secret_key,
                )

            elif auth_method == AWSAuthenticationMethods.SESSION_TOKEN:
                auth_config = SessionTokenAuthenticationConfig(
                    region=region_name,
                    endpoint_url=endpoint_url,
                    aws_access_key_id=credentials.access_key,
                    aws_secret_access_key=credentials.secret_key,
                )
                expiration_seconds = DEFAULT_STS_TOKEN_EXPIRATION

            elif auth_method == AWSAuthenticationMethods.FEDERATION_TOKEN:
                auth_config = FederationTokenAuthenticationConfig(
                    region=region_name,
                    endpoint_url=endpoint_url,
                    aws_access_key_id=credentials.access_key,
                    aws_secret_access_key=credentials.secret_key,
                )
                expiration_seconds = DEFAULT_STS_TOKEN_EXPIRATION

            else:
                # Previously an unknown method silently fell through to the
                # federation token config; reject it explicitly instead, as
                # promised by the documented NotImplementedError contract.
                raise NotImplementedError(
                    f"The specified authentication method '{auth_method}' "
                    "could not be used to auto-configure the connector."
                )

    return cls(
        auth_method=auth_method,
        resource_type=resource_type,
        resource_id=resource_id
        if resource_type not in [AWS_RESOURCE_TYPE, None]
        else None,
        expiration_seconds=expiration_seconds,
        expires_at=expires_at,
        config=auth_config,
    )
def _verify(
    self,
    resource_type: Optional[str] = None,
    resource_id: Optional[str] = None,
) -> List[str]:
    """Verify and list all the resources that the connector can access.

    Args:
        resource_type: The type of the resource to verify. If omitted and
            if the connector supports multiple resource types, the
            implementation must verify that it can authenticate and connect
            to any and all of the supported resource types.
        resource_id: The ID of the resource to connect to. Omitted if a
            resource type is not specified. It has the same value as the
            default resource ID if the supplied resource type doesn't
            support multiple instances. If the supplied resource type does
            allow multiple instances, this parameter may still be omitted
            to fetch a list of resource IDs identifying all the resources
            of the indicated type that the connector can access.

    Returns:
        The list of resources IDs in canonical format identifying the
        resources that the connector can access. This list is empty only
        if the resource type is not specified (i.e. for multi-type
        connectors).

    Raises:
        AuthorizationException: If the connector cannot authenticate or
            access the specified resource.
    """
    # If the resource type is not specified, treat this the
    # same as a generic AWS connector.
    session, _ = self.get_boto3_session(
        self.auth_method,
        resource_type=resource_type or AWS_RESOURCE_TYPE,
        resource_id=resource_id,
    )

    # Verify that the AWS account is accessible
    assert isinstance(session, boto3.Session)
    sts_client = session.client("sts")
    try:
        sts_client.get_caller_identity()
    except (ClientError, BotoCoreError) as err:
        msg = f"failed to verify AWS account access: {err}"
        logger.debug(msg)
        raise AuthorizationException(msg) from err

    if not resource_type:
        return []

    if resource_type == AWS_RESOURCE_TYPE:
        assert resource_id is not None
        return [resource_id]

    if resource_type == S3_RESOURCE_TYPE:
        s3_client = session.client(
            "s3",
            region_name=self.config.region,
            endpoint_url=self.config.endpoint_url,
        )
        if not resource_id:
            # List all S3 buckets
            try:
                response = s3_client.list_buckets()
            except (ClientError, BotoCoreError) as e:
                msg = f"failed to list S3 buckets: {e}"
                logger.error(msg)
                raise AuthorizationException(msg) from e

            return [
                f"s3://{bucket['Name']}" for bucket in response["Buckets"]
            ]

        # Check if the specified S3 bucket exists
        bucket_name = self._parse_s3_resource_id(resource_id)
        try:
            s3_client.head_bucket(Bucket=bucket_name)
        except (ClientError, BotoCoreError) as e:
            msg = f"failed to fetch S3 bucket {bucket_name}: {e}"
            logger.error(msg)
            raise AuthorizationException(msg) from e
        return [resource_id]

    if resource_type == DOCKER_REGISTRY_RESOURCE_TYPE:
        assert resource_id is not None

        ecr_client = session.client(
            "ecr",
            region_name=self.config.region,
            endpoint_url=self.config.endpoint_url,
        )
        # List all ECR repositories; only non-emptiness is checked, so the
        # first page of results is sufficient here.
        try:
            repositories = ecr_client.describe_repositories()
        except (ClientError, BotoCoreError) as e:
            msg = f"failed to list ECR repositories: {e}"
            logger.error(msg)
            raise AuthorizationException(msg) from e

        if len(repositories["repositories"]) == 0:
            raise AuthorizationException(
                "the AWS connector does not have access to any ECR "
                "repositories. Please adjust the AWS permissions "
                "associated with the authentication credentials to "
                "include access to at least one ECR repository."
            )

        return [resource_id]

    if resource_type == KUBERNETES_CLUSTER_RESOURCE_TYPE:
        eks_client = session.client(
            "eks",
            region_name=self.config.region,
            endpoint_url=self.config.endpoint_url,
        )
        if not resource_id:
            # List all EKS clusters. ListClusters is a paginated API, so
            # walk every page instead of returning only the first one.
            cluster_names: List[str] = []
            try:
                paginator = eks_client.get_paginator("list_clusters")
                for page in paginator.paginate():
                    cluster_names.extend(page["clusters"])
            except (ClientError, BotoCoreError) as e:
                msg = f"Failed to list EKS clusters: {e}"
                logger.error(msg)
                raise AuthorizationException(msg) from e

            return cluster_names

        # Check if the specified EKS cluster exists
        cluster_name = self._parse_eks_resource_id(resource_id)
        try:
            eks_client.describe_cluster(name=cluster_name)
        except (ClientError, BotoCoreError) as e:
            msg = f"Failed to fetch EKS cluster {cluster_name}: {e}"
            logger.error(msg)
            raise AuthorizationException(msg) from e

        return [resource_id]

    return []
def _get_connector_client(
    self,
    resource_type: str,
    resource_id: str,
) -> "ServiceConnector":
    """Get a connector instance that can be used to connect to a resource.

    This method generates a client-side connector instance that can be used
    to connect to a resource of the given type. The client-side connector
    is configured with temporary AWS credentials extracted from the
    current connector and, depending on resource type, it may also be
    of a different connector type:

    - a Kubernetes connector for Kubernetes clusters
    - a Docker connector for Docker registries

    Args:
        resource_type: The type of the resources to connect to.
        resource_id: The ID of a particular resource to connect to.

    Returns:
        An AWS, Kubernetes or Docker connector instance that can be used to
        connect to the specified resource.

    Raises:
        AuthorizationException: If authentication failed.
        ValueError: If the resource type is not supported.
        RuntimeError: If the Kubernetes connector is not installed and the
            resource type is Kubernetes.
    """
    connector_name = ""
    if self.name:
        connector_name = self.name
    if resource_id:
        connector_name += f" ({resource_type} | {resource_id} client)"
    else:
        connector_name += f" ({resource_type} client)"

    logger.debug(f"Getting connector client for {connector_name}")

    if resource_type in [AWS_RESOURCE_TYPE, S3_RESOURCE_TYPE]:
        auth_method = self.auth_method
        if self.auth_method in [
            AWSAuthenticationMethods.SECRET_KEY,
            AWSAuthenticationMethods.STS_TOKEN,
        ]:
            if (
                self.resource_type == resource_type
                and self.resource_id == resource_id
            ):
                # If the requested type and resource ID are the same as
                # those configured, we can return the current connector
                # instance because it's fully formed and ready to use
                # to connect to the specified resource
                return self

            # The secret key and STS token authentication methods do not
            # involve generating temporary credentials, so we can just
            # use the current connector configuration
            config = self.config
            expires_at = self.expires_at
        else:
            # Get an authenticated boto3 session
            session, expires_at = self.get_boto3_session(
                self.auth_method,
                resource_type=resource_type,
                resource_id=resource_id,
            )
            assert isinstance(session, boto3.Session)
            credentials = session.get_credentials()
            if credentials is None:
                # Without this guard the attribute accesses below would
                # fail with an unhelpful AttributeError.
                raise AuthorizationException(
                    "Could not determine the AWS credentials from the "
                    "boto3 session"
                )

            if (
                self.auth_method == AWSAuthenticationMethods.IMPLICIT
                and credentials.token is None
            ):
                # The implicit authentication method may involve picking up
                # long-lived credentials from the environment
                auth_method = AWSAuthenticationMethods.SECRET_KEY
                config = AWSSecretKeyConfig(
                    region=self.config.region,
                    endpoint_url=self.config.endpoint_url,
                    aws_access_key_id=credentials.access_key,
                    aws_secret_access_key=credentials.secret_key,
                )
            else:
                assert credentials.token is not None

                # Use the temporary credentials extracted from the boto3
                # session
                auth_method = AWSAuthenticationMethods.STS_TOKEN
                config = STSTokenConfig(
                    region=self.config.region,
                    endpoint_url=self.config.endpoint_url,
                    aws_access_key_id=credentials.access_key,
                    aws_secret_access_key=credentials.secret_key,
                    aws_session_token=credentials.token,
                )

        # Create a client-side AWS connector instance that is fully formed
        # and ready to use to connect to the specified resource (i.e. has
        # all the necessary configuration and credentials, a resource type
        # and a resource ID where applicable)
        return AWSServiceConnector(
            id=self.id,
            name=connector_name,
            auth_method=auth_method,
            resource_type=resource_type,
            resource_id=resource_id,
            config=config,
            expires_at=expires_at,
        )

    if resource_type == DOCKER_REGISTRY_RESOURCE_TYPE:
        assert resource_id is not None

        # Get an authenticated boto3 session
        session, expires_at = self.get_boto3_session(
            self.auth_method,
            resource_type=resource_type,
            resource_id=resource_id,
        )
        assert isinstance(session, boto3.Session)

        registry_id = self._parse_ecr_resource_id(resource_id)

        ecr_client = session.client(
            "ecr",
            region_name=self.config.region,
            endpoint_url=self.config.endpoint_url,
        )

        assert isinstance(ecr_client, BaseClient)
        assert registry_id is not None

        try:
            auth_token = ecr_client.get_authorization_token(
                registryIds=[
                    registry_id,
                ]
            )
        except (ClientError, BotoCoreError) as e:
            raise AuthorizationException(
                f"Failed to get authorization token from ECR: {e}"
            ) from e

        authorization_data = auth_token["authorizationData"][0]
        encoded_token = authorization_data["authorizationToken"]
        endpoint = authorization_data["proxyEndpoint"]
        # The token is base64 encoded and has the format
        # "username:password". Split on the first colon only, so that a
        # password containing ':' cannot break the unpacking.
        username, password = (
            base64.b64decode(encoded_token).decode("utf-8").split(":", 1)
        )

        # Create a client-side Docker connector instance with the temporary
        # Docker credentials
        return DockerServiceConnector(
            id=self.id,
            name=connector_name,
            auth_method=DockerAuthenticationMethods.PASSWORD,
            resource_type=resource_type,
            config=DockerConfiguration(
                username=username,
                password=password,
                registry=endpoint,
            ),
            expires_at=expires_at,
        )

    if resource_type == KUBERNETES_CLUSTER_RESOURCE_TYPE:
        assert resource_id is not None

        # Get an authenticated boto3 session
        session, expires_at = self.get_boto3_session(
            self.auth_method,
            resource_type=resource_type,
            resource_id=resource_id,
        )
        assert isinstance(session, boto3.Session)

        cluster_name = self._parse_eks_resource_id(resource_id)

        # Get a boto3 EKS client
        eks_client = session.client(
            "eks",
            region_name=self.config.region,
            endpoint_url=self.config.endpoint_url,
        )

        assert isinstance(eks_client, BaseClient)

        try:
            cluster = eks_client.describe_cluster(name=cluster_name)
        except (ClientError, BotoCoreError) as e:
            raise AuthorizationException(
                f"Failed to get EKS cluster {cluster_name}: {e}"
            ) from e

        try:
            user_token = self._get_eks_bearer_token(
                session=session,
                cluster_id=cluster_name,
                region=self.config.region,
            )
        except (ClientError, BotoCoreError) as e:
            raise AuthorizationException(
                f"Failed to get EKS bearer token: {e}"
            ) from e

        # get cluster details
        cluster_arn = cluster["cluster"]["arn"]
        cluster_ca_cert = cluster["cluster"]["certificateAuthority"][
            "data"
        ]
        cluster_server = cluster["cluster"]["endpoint"]

        # Create a client-side Kubernetes connector instance with the
        # temporary Kubernetes credentials
        try:
            # Import libraries only when needed
            from zenml.integrations.kubernetes.service_connectors.kubernetes_service_connector import (
                KubernetesAuthenticationMethods,
                KubernetesServiceConnector,
                KubernetesTokenConfig,
            )
        except ImportError as e:
            raise RuntimeError(
                f"The Kubernetes Service Connector functionality could not "
                f"be used due to missing dependencies: {e}"
            ) from e
        return KubernetesServiceConnector(
            id=self.id,
            name=connector_name,
            auth_method=KubernetesAuthenticationMethods.TOKEN,
            resource_type=resource_type,
            config=KubernetesTokenConfig(
                cluster_name=cluster_arn,
                certificate_authority=cluster_ca_cert,
                server=cluster_server,
                token=user_token,
            ),
            expires_at=expires_at,
        )

    raise ValueError(f"Unsupported resource type: {resource_type}")
account_id: str (property, read-only)
Get the AWS account ID.
Returns:
| Type | Description |
|---|---|
| str | The AWS account ID. |
Exceptions:
| Type | Description |
|---|---|
| AuthorizationException | If the AWS account ID could not be determined. |
get_boto3_session(self, auth_method, resource_type=None, resource_id=None)
Get a boto3 session for the specified resource.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| auth_method | str | The authentication method to use. | required |
| resource_type | Optional[str] | The resource type to get a boto3 session for. | None |
| resource_id | Optional[str] | The resource ID to get a boto3 session for. | None |
Returns:
| Type | Description |
|---|---|
| Tuple[boto3.session.Session, Optional[datetime.datetime]] | A boto3 session for the specified resource and its expiration timestamp, if applicable. |
Source code in zenml/integrations/aws/service_connectors/aws_service_connector.py
def get_boto3_session(
    self,
    auth_method: str,
    resource_type: Optional[str] = None,
    resource_id: Optional[str] = None,
) -> Tuple[boto3.Session, Optional[datetime.datetime]]:
    """Get a boto3 session for the specified resource.

    Sessions are cached per (auth method, resource type, resource ID) and
    reused until their expiration timestamp has passed; sessions without
    an expiration timestamp are reused indefinitely.

    Args:
        auth_method: The authentication method to use.
        resource_type: The resource type to get a boto3 session for.
        resource_id: The resource ID to get a boto3 session for.

    Returns:
        A boto3 session for the specified resource and its expiration
        timestamp, if applicable.
    """
    # We maintain a cache of all sessions to avoid re-authenticating
    # multiple times for the same resource
    key = (auth_method, resource_type, resource_id)
    if key in self._session_cache:
        session, expires_at = self._session_cache[key]
        if expires_at is None:
            return session, None

        # Naive timestamps are interpreted as UTC. Aware timestamps are
        # compared as-is: overwriting their tzinfo with UTC would shift
        # the instant whenever the original offset is not UTC.
        if expires_at.tzinfo is None:
            expires_at = expires_at.replace(tzinfo=datetime.timezone.utc)
        now = datetime.datetime.now(datetime.timezone.utc)
        if expires_at > now:
            return session, expires_at
        # Otherwise the cached session has expired: fall through and
        # authenticate again.

    logger.debug(
        f"Creating boto3 session for auth method '{auth_method}', "
        f"resource type '{resource_type}' and resource ID "
        f"'{resource_id}'..."
    )
    session, expires_at = self._authenticate(
        auth_method, resource_type, resource_id
    )
    self._session_cache[key] = (session, expires_at)
    return session, expires_at
AWSSessionPolicy (AuthenticationConfig)
pydantic-model
AWS session IAM policy configuration.
Source code in zenml/integrations/aws/service_connectors/aws_service_connector.py
class AWSSessionPolicy(AuthenticationConfig):
    """AWS session IAM policy configuration.

    Holds two optional session policy settings, both defaulting to ``None``.
    NOTE(review): presumably these are passed to AWS STS when temporary
    credentials are requested — confirm against the connector's
    authentication code.
    """

    # Optional list of IAM managed policy ARNs to use as session policies.
    policy_arns: Optional[List[str]] = Field(
        default=None,
        title="ARNs of the IAM managed policies that you want to use as a "
        "managed session policy. The policies must exist in the same account "
        "as the IAM user that is requesting temporary credentials.",
    )
    # Optional inline session policy, supplied as a JSON-formatted string.
    policy: Optional[str] = Field(
        default=None,
        title="An IAM policy in JSON format that you want to use as an inline "
        "session policy",
    )
FederationTokenAuthenticationConfig (AWSSecretKeyConfig, AWSSessionPolicy)
pydantic-model
AWS federation token authentication config.
Source code in zenml/integrations/aws/service_connectors/aws_service_connector.py
class FederationTokenAuthenticationConfig(
    AWSSecretKeyConfig, AWSSessionPolicy
):
    """AWS federation token authentication config.

    Declares no fields of its own: it is the combination of the secret key
    configuration (``AWSSecretKeyConfig``) and the optional session policy
    fields (``AWSSessionPolicy``).
    """
IAMRoleAuthenticationConfig (AWSSecretKeyConfig, AWSSessionPolicy)
pydantic-model
AWS IAM authentication config.
Source code in zenml/integrations/aws/service_connectors/aws_service_connector.py
class IAMRoleAuthenticationConfig(AWSSecretKeyConfig, AWSSessionPolicy):
    """AWS IAM authentication config.

    Combines the secret key configuration and the optional session policy
    fields with the ARN of the IAM role to assume.
    """

    # ARN of the IAM role to assume; declared without a default, so it
    # must be provided.
    role_arn: str = Field(
        title="AWS IAM Role ARN",
    )
STSToken (AWSSecretKey)
pydantic-model
AWS STS token.
Source code in zenml/integrations/aws/service_connectors/aws_service_connector.py
class STSToken(AWSSecretKey):
    """AWS STS token.

    Extends the AWS secret key credentials with a session token.
    """

    # Session token for the temporary credentials; stored as a SecretStr
    # and declared without a default, so it must be provided.
    aws_session_token: SecretStr = Field(
        title="AWS Session Token",
    )
STSTokenConfig (AWSBaseConfig, STSToken)
pydantic-model
AWS STS token authentication configuration.
Source code in zenml/integrations/aws/service_connectors/aws_service_connector.py
class STSTokenConfig(AWSBaseConfig, STSToken):
    """AWS STS token authentication configuration.

    Combines the base AWS configuration with STS token credentials and an
    optional expiration timestamp.
    """

    # Optional expiration timestamp of the STS token; defaults to None.
    expires_at: Optional[datetime.datetime] = Field(
        default=None,
        title="AWS STS Token Expiration",
    )
SessionTokenAuthenticationConfig (AWSSecretKeyConfig)
pydantic-model
AWS session token authentication config.
Source code in zenml/integrations/aws/service_connectors/aws_service_connector.py
class SessionTokenAuthenticationConfig(AWSSecretKeyConfig):
    """AWS session token authentication config.

    Declares no fields beyond those inherited from ``AWSSecretKeyConfig``.
    """
step_operators
special
Initialization of the Sagemaker Step Operator.
sagemaker_step_operator
Implementation of the Sagemaker Step Operator.
SagemakerStepOperator (BaseStepOperator)
Step operator to run a step on Sagemaker.
This class defines code that builds an image with the ZenML entrypoint to run using Sagemaker's Estimator.
Source code in zenml/integrations/aws/step_operators/sagemaker_step_operator.py
class SagemakerStepOperator(BaseStepOperator):
    """Step operator to run a step on Sagemaker.

    This class defines code that builds an image with the ZenML entrypoint
    to run using Sagemaker's Estimator.
    """

    @property
    def config(self) -> SagemakerStepOperatorConfig:
        """Returns the `SagemakerStepOperatorConfig` config.

        Returns:
            The configuration.
        """
        return cast(SagemakerStepOperatorConfig, self._config)

    @property
    def settings_class(self) -> Optional[Type["BaseSettings"]]:
        """Settings class for the SageMaker step operator.

        Returns:
            The settings class.
        """
        return SagemakerStepOperatorSettings

    @property
    def validator(self) -> Optional[StackValidator]:
        """Validates the stack.

        Returns:
            A validator that checks that the stack contains a remote container
            registry and a remote artifact store.
        """

        def _validate_remote_components(stack: "Stack") -> Tuple[bool, str]:
            if stack.artifact_store.config.is_local:
                return False, (
                    "The SageMaker step operator runs code remotely and "
                    "needs to write files into the artifact store, but the "
                    f"artifact store `{stack.artifact_store.name}` of the "
                    "active stack is local. Please ensure that your stack "
                    "contains a remote artifact store when using the SageMaker "
                    "step operator."
                )

            container_registry = stack.container_registry
            # A container registry is listed in `required_components` below.
            assert container_registry is not None

            if container_registry.config.is_local:
                return False, (
                    "The SageMaker step operator runs code remotely and "
                    "needs to push/pull Docker images, but the "
                    f"container registry `{container_registry.name}` of the "
                    "active stack is local. Please ensure that your stack "
                    "contains a remote container registry when using the "
                    "SageMaker step operator."
                )

            return True, ""

        return StackValidator(
            required_components={
                StackComponentType.CONTAINER_REGISTRY,
                StackComponentType.IMAGE_BUILDER,
            },
            custom_validation_function=_validate_remote_components,
        )

    def get_docker_builds(
        self, deployment: "PipelineDeploymentBaseModel"
    ) -> List["BuildConfiguration"]:
        """Gets the Docker builds required for the component.

        Args:
            deployment: The pipeline deployment for which to get the builds.

        Returns:
            The required Docker builds: one per step that is configured to
            run on this step operator.
        """
        return [
            BuildConfiguration(
                key=SAGEMAKER_DOCKER_IMAGE_KEY,
                settings=step.config.docker_settings,
                step_name=step_name,
                # The actual command is injected at launch time through
                # this environment variable (see `launch`).
                entrypoint=f"${_ENTRYPOINT_ENV_VARIABLE}",
            )
            for step_name, step in deployment.step_configurations.items()
            if step.config.step_operator == self.name
        ]

    def launch(
        self,
        info: "StepRunInfo",
        entrypoint_command: List[str],
        environment: Dict[str, str],
    ) -> None:
        """Launches a step on SageMaker.

        Args:
            info: Information about the step run.
            entrypoint_command: Command that executes the step.
            environment: Environment variables to set in the step operator
                environment.
        """
        if not info.config.resource_settings.empty:
            logger.warning(
                "Specifying custom step resources is not supported for "
                "the SageMaker step operator. If you want to run this step "
                "operator on specific resources, you can do so by configuring "
                "a different instance type like this: "
                "`zenml step-operator update %s "
                "--instance_type=<INSTANCE_TYPE>`",
                self.name,
            )

        image_name = info.get_image(key=SAGEMAKER_DOCKER_IMAGE_KEY)
        # The image entrypoint built in `get_docker_builds` expands this
        # variable into the step command.
        environment[_ENTRYPOINT_ENV_VARIABLE] = " ".join(entrypoint_command)

        settings = cast(SagemakerStepOperatorSettings, self.get_settings(info))

        # Get and default fill SageMaker estimator arguments for full ZenML
        # support. Work on a copy so that the defaults, the environment and
        # the SageMaker session added below are not written back into the
        # settings object itself.
        estimator_args = dict(settings.estimator_args)
        session = sagemaker.Session(default_bucket=self.config.bucket)

        estimator_args.setdefault(
            "instance_type", settings.instance_type or "ml.m5.large"
        )

        estimator_args["environment"] = environment
        estimator_args["instance_count"] = 1
        estimator_args["sagemaker_session"] = session

        # Create Estimator
        estimator = sagemaker.estimator.Estimator(
            image_name, self.config.role, **estimator_args
        )

        # SageMaker allows 63 characters at maximum for job name - ZenML
        # uses 60 for safety margin (55 characters + "-" + 4-char suffix).
        step_name = Client().get_run_step(info.step_run_id).name
        training_job_name = f"{info.pipeline.name}-{step_name}"[:55]
        suffix = random_str(4)
        unique_training_job_name = f"{training_job_name}-{suffix}"

        # Sagemaker doesn't allow any underscores in job/experiment/trial names
        sanitized_training_job_name = unique_training_job_name.replace(
            "_", "-"
        )

        # Construct training input object, if necessary
        inputs = None

        if isinstance(settings.input_data_s3_uri, str):
            inputs = sagemaker.inputs.TrainingInput(
                s3_data=settings.input_data_s3_uri
            )
        elif isinstance(settings.input_data_s3_uri, dict):
            inputs = {}
            for channel, s3_uri in settings.input_data_s3_uri.items():
                inputs[channel] = sagemaker.inputs.TrainingInput(
                    s3_data=s3_uri
                )

        experiment_config = {}
        if settings.experiment_name:
            experiment_config = {
                "ExperimentName": settings.experiment_name,
                "TrialName": sanitized_training_job_name,
            }

        estimator.fit(
            wait=True,
            inputs=inputs,
            experiment_config=experiment_config,
            job_name=sanitized_training_job_name,
        )
config: SagemakerStepOperatorConfig (property, read-only)
Returns the `SagemakerStepOperatorConfig` config.
Returns:
| Type | Description |
|---|---|
| SagemakerStepOperatorConfig | The configuration. |
settings_class: Optional[Type[BaseSettings]] (property, read-only)
Settings class for the SageMaker step operator.
Returns:
| Type | Description |
|---|---|
| Optional[Type[BaseSettings]] | The settings class. |
validator: Optional[zenml.stack.stack_validator.StackValidator] (property, read-only)
Validates the stack.
Returns:
| Type | Description |
|---|---|
| Optional[zenml.stack.stack_validator.StackValidator] | A validator that checks that the stack contains a remote container registry and a remote artifact store. |
get_docker_builds(self, deployment)
Gets the Docker builds required for the component.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| deployment | PipelineDeploymentBaseModel | The pipeline deployment for which to get the builds. | required |
Returns:
| Type | Description |
|---|---|
| List[BuildConfiguration] | The required Docker builds. |
Source code in zenml/integrations/aws/step_operators/sagemaker_step_operator.py
def get_docker_builds(
    self, deployment: "PipelineDeploymentBaseModel"
) -> List["BuildConfiguration"]:
    """Collect the Docker builds this step operator needs.

    Args:
        deployment: The pipeline deployment for which to get the builds.

    Returns:
        One build configuration for every step of the deployment whose
        configured step operator is this component, in step order.
    """
    return [
        BuildConfiguration(
            key=SAGEMAKER_DOCKER_IMAGE_KEY,
            settings=step_cfg.config.docker_settings,
            step_name=name,
            entrypoint=f"${_ENTRYPOINT_ENV_VARIABLE}",
        )
        for name, step_cfg in deployment.step_configurations.items()
        if step_cfg.config.step_operator == self.name
    ]
launch(self, info, entrypoint_command, environment)
Launches a step on SageMaker.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| info | StepRunInfo | Information about the step run. | required |
| entrypoint_command | List[str] | Command that executes the step. | required |
| environment | Dict[str, str] | Environment variables to set in the step operator environment. | required |
Source code in zenml/integrations/aws/step_operators/sagemaker_step_operator.py
def launch(
    self,
    info: "StepRunInfo",
    entrypoint_command: List[str],
    environment: Dict[str, str],
) -> None:
    """Launches a step on SageMaker.

    The step command is handed to the container through an environment
    variable. The step is then submitted as a SageMaker training job with
    `Estimator.fit(wait=True)`, so this call blocks until the job finishes.

    Args:
        info: Information about the step run.
        entrypoint_command: Command that executes the step.
        environment: Environment variables to set in the step operator
            environment. The mapping passed in is not modified.
    """
    import re

    if not info.config.resource_settings.empty:
        logger.warning(
            "Specifying custom step resources is not supported for "
            "the SageMaker step operator. If you want to run this step "
            "operator on specific resources, you can do so by configuring "
            "a different instance type like this: "
            "`zenml step-operator update %s "
            "--instance_type=<INSTANCE_TYPE>`",
            self.name,
        )

    image_name = info.get_image(key=SAGEMAKER_DOCKER_IMAGE_KEY)

    # Work on a copy so the caller's mapping is left untouched. The image
    # entrypoint expands this variable into the step command.
    environment = dict(environment)
    environment[_ENTRYPOINT_ENV_VARIABLE] = " ".join(entrypoint_command)

    settings = cast(SagemakerStepOperatorSettings, self.get_settings(info))

    # Copy the user-provided estimator arguments before filling in the
    # ZenML-managed ones. Mutating `settings.estimator_args` in place would
    # store this run's environment and session on the settings object.
    estimator_args = dict(settings.estimator_args)
    session = sagemaker.Session(default_bucket=self.config.bucket)
    estimator_args.setdefault(
        "instance_type", settings.instance_type or "ml.m5.large"
    )
    estimator_args["environment"] = environment
    estimator_args["instance_count"] = 1
    estimator_args["sagemaker_session"] = session

    # Create Estimator
    estimator = sagemaker.estimator.Estimator(
        image_name, self.config.role, **estimator_args
    )

    # SageMaker job names must match ^[a-zA-Z0-9](-*[a-zA-Z0-9]){0,62}.
    # That means at most 63 characters, ASCII alphanumerics and hyphens
    # only, starting with an alphanumeric. Underscores are not the only
    # character to replace. The name is capped at 60 characters (55-char
    # prefix + "-" + 4-char random suffix) as a safety margin. The same
    # name is reused as the experiment trial name below.
    invalid_name_chars = re.compile(r"[^a-zA-Z0-9-]")
    step_name = Client().get_run_step(info.step_run_id).name
    job_name_prefix = invalid_name_chars.sub(
        "-", f"{info.pipeline.name}-{step_name}"
    ).lstrip("-")[:55]
    sanitized_training_job_name = invalid_name_chars.sub(
        "-", f"{job_name_prefix}-{random_str(4)}"
    )

    # Construct the training input: a single S3 URI, or one per channel.
    inputs = None
    if isinstance(settings.input_data_s3_uri, str):
        inputs = sagemaker.inputs.TrainingInput(
            s3_data=settings.input_data_s3_uri
        )
    elif isinstance(settings.input_data_s3_uri, dict):
        inputs = {
            channel: sagemaker.inputs.TrainingInput(s3_data=s3_uri)
            for channel, s3_uri in settings.input_data_s3_uri.items()
        }

    experiment_config = {}
    if settings.experiment_name:
        experiment_config = {
            "ExperimentName": settings.experiment_name,
            "TrialName": sanitized_training_job_name,
        }

    estimator.fit(
        wait=True,
        inputs=inputs,
        experiment_config=experiment_config,
        job_name=sanitized_training_job_name,
    )