Aws
zenml.integrations.aws
special
Integrates multiple AWS Tools as Stack Components.
The AWS integration provides a way for our users to manage their secrets through AWS and to use the AWS container registry. Additionally, the Sagemaker integration submodule provides a way to run ZenML steps in Sagemaker.
AWSIntegration (Integration)
Definition of AWS integration for ZenML.
Source code in zenml/integrations/aws/__init__.py
class AWSIntegration(Integration):
    """AWS integration for ZenML.

    Declares the pip requirements of the integration and exposes the stack
    component flavors that it provides.
    """

    NAME = AWS
    REQUIREMENTS = [
        "sagemaker>=2.117.0",
        "kubernetes",
        "aws-profile-manager",
    ]
    REQUIREMENTS_IGNORED_ON_UNINSTALL = ["kubernetes"]

    @classmethod
    def activate(cls) -> None:
        """Activate the AWS integration."""
        # The imported name is unused here; presumably the import is needed
        # for its import-time side effects (hence the `noqa`).
        from zenml.integrations.aws import service_connectors  # noqa

    @classmethod
    def flavors(cls) -> List[Type[Flavor]]:
        """List the stack component flavors shipped with the AWS integration.

        Returns:
            The container registry, step operator and orchestrator flavor
            classes, in that order.
        """
        # Imported lazily, inside the method, as in the rest of this class.
        from zenml.integrations.aws.flavors import (
            AWSContainerRegistryFlavor,
            SagemakerOrchestratorFlavor,
            SagemakerStepOperatorFlavor,
        )

        aws_flavors: List[Type[Flavor]] = [
            AWSContainerRegistryFlavor,
            SagemakerStepOperatorFlavor,
            SagemakerOrchestratorFlavor,
        ]
        return aws_flavors
activate()
classmethod
Activate the AWS integration.
Source code in zenml/integrations/aws/__init__.py
@classmethod
def activate(cls) -> None:
    """Activate the AWS integration."""
    # The imported name is unused here; presumably the import is needed for
    # its import-time side effects (hence the `noqa`).
    from zenml.integrations.aws import service_connectors  # noqa
flavors()
classmethod
Declare the stack component flavors for the AWS integration.
Returns:
Type | Description |
---|---|
List[Type[zenml.stack.flavor.Flavor]] |
List of stack component flavors for this integration. |
Source code in zenml/integrations/aws/__init__.py
@classmethod
def flavors(cls) -> List[Type[Flavor]]:
    """List the stack component flavors shipped with the AWS integration.

    Returns:
        The container registry, step operator and orchestrator flavor
        classes, in that order.
    """
    # Imported lazily, inside the method.
    from zenml.integrations.aws.flavors import (
        AWSContainerRegistryFlavor,
        SagemakerOrchestratorFlavor,
        SagemakerStepOperatorFlavor,
    )

    aws_flavors: List[Type[Flavor]] = [
        AWSContainerRegistryFlavor,
        SagemakerStepOperatorFlavor,
        SagemakerOrchestratorFlavor,
    ]
    return aws_flavors
container_registries
special
Initialization of AWS Container Registry integration.
aws_container_registry
Implementation of the AWS container registry integration.
AWSContainerRegistry (BaseContainerRegistry)
Class for AWS Container Registry.
Source code in zenml/integrations/aws/container_registries/aws_container_registry.py
class AWSContainerRegistry(BaseContainerRegistry):
    """Class for AWS Container Registry."""

    @property
    def config(self) -> AWSContainerRegistryConfig:
        """Returns the `AWSContainerRegistryConfig` config.

        Returns:
            The configuration.
        """
        return cast(AWSContainerRegistryConfig, self._config)

    def _get_region(self) -> str:
        """Parses the AWS region from the registry URI.

        The whole URI must have the form
        `<prefix>.dkr.ecr.<region>.amazonaws.com`.

        Raises:
            RuntimeError: If the region parsing fails due to an invalid URI.

        Returns:
            The region string.
        """
        match = re.fullmatch(
            r".*\.dkr\.ecr\.(.*)\.amazonaws\.com", self.config.uri
        )
        if not match:
            raise RuntimeError(
                f"Unable to parse region from ECR URI {self.config.uri}."
            )
        return match.group(1)

    def _get_ecr_client(self) -> BaseClient:
        """Get an ECR client.

        If this container registry is configured with an AWS service
        connector, that connector is used to create an authenticated client.
        If that fails for any reason, the error is logged and a client built
        from a default `boto3` session is returned instead. Without a
        connector, the default `boto3` session is used directly.

        Returns:
            An ECR client.
        """
        if self.connector:
            try:
                model = Client().get_service_connector(self.connector)
                connector = service_connector_registry.instantiate_connector(
                    model=model
                )
                assert isinstance(connector, AWSServiceConnector)
                return connector.get_ecr_client()
            except Exception as e:
                # Best effort: fall through to the default session below.
                logger.error(
                    "Unable to get ECR client from service connector: %s",
                    str(e),
                )

        return boto3.Session().client(
            "ecr",
            region_name=self._get_region(),
        )

    def prepare_image_push(self, image_name: str) -> None:
        """Logs warning message if trying to push an image for which no repository exists.

        Args:
            image_name: Name of the docker image that will be pushed.

        Raises:
            ValueError: If the docker image name is invalid.
        """
        # Find repository name from image name. The URI is escaped so that
        # the dots in the registry hostname only match literal dots.
        match = re.search(
            f"{re.escape(self.config.uri)}/(.*):.*", image_name
        )
        if not match:
            raise ValueError(f"Invalid docker image name '{image_name}'.")
        repo_name = match.group(1)

        client = self._get_ecr_client()
        try:
            # NOTE(review): this is a single, unpaginated call; presumably
            # only the first page of repositories is inspected - confirm
            # against the boto3 ECR documentation.
            response = client.describe_repositories()
        except (BotoCoreError, ClientError):
            logger.warning(
                "Amazon ECR requires you to create a repository before you can "
                "push an image to it. ZenML is trying to push the image "
                f"{image_name} but could not find any repositories because "
                "your local AWS credentials are not set. We will try to push "
                "anyway, but in case it fails you need to create a repository "
                f"named `{repo_name}`."
            )
            return

        try:
            repo_uris: List[str] = [
                repository["repositoryUri"]
                for repository in response["repositories"]
            ]
        except (KeyError, ClientError) as e:
            # invalid boto response, let's hope for the best and just push
            logger.debug("Error while trying to fetch ECR repositories: %s", e)
            return

        repo_exists = any(
            image_name.startswith(f"{uri}:") for uri in repo_uris
        )
        if not repo_exists:
            logger.warning(
                "Amazon ECR requires you to create a repository before you can "
                "push an image to it. ZenML is trying to push the image "
                f"{image_name} but could only detect the following "
                f"repositories: {repo_uris}. We will try to push anyway, but "
                "in case it fails you need to create a repository named "
                f"`{repo_name}`."
            )

    @property
    def post_registration_message(self) -> Optional[str]:
        """Optional message printed after the stack component is registered.

        Returns:
            Info message regarding docker repositories in AWS.
        """
        return (
            "Amazon ECR requires you to create a repository before you can "
            "push an image to it. If you want to run a pipeline "
            "using a remote orchestrator, ZenML will automatically build a "
            f"docker image called `{self.config.uri}/zenml:<PIPELINE_NAME>` "
            "and try to push it. This will fail unless you create a "
            "repository called `zenml` inside your Amazon ECR."
        )
config: AWSContainerRegistryConfig
property
readonly
Returns the `AWSContainerRegistryConfig` config.
Returns:
Type | Description |
---|---|
AWSContainerRegistryConfig |
The configuration. |
post_registration_message: Optional[str]
property
readonly
Optional message printed after the stack component is registered.
Returns:
Type | Description |
---|---|
Optional[str] |
Info message regarding docker repositories in AWS. |
prepare_image_push(self, image_name)
Logs warning message if trying to push an image for which no repository exists.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
image_name |
str |
Name of the docker image that will be pushed. |
required |
Exceptions:
Type | Description |
---|---|
ValueError |
If the docker image name is invalid. |
Source code in zenml/integrations/aws/container_registries/aws_container_registry.py
def prepare_image_push(self, image_name: str) -> None:
    """Logs warning message if trying to push an image for which no repository exists.

    Args:
        image_name: Name of the docker image that will be pushed.

    Raises:
        ValueError: If the docker image name is invalid.
    """
    # Find repository name from image name. The URI is escaped so that the
    # dots in the registry hostname only match literal dots.
    match = re.search(f"{re.escape(self.config.uri)}/(.*):.*", image_name)
    if not match:
        raise ValueError(f"Invalid docker image name '{image_name}'.")
    repo_name = match.group(1)

    client = self._get_ecr_client()
    try:
        # NOTE(review): this is a single, unpaginated call; presumably only
        # the first page of repositories is inspected - confirm against the
        # boto3 ECR documentation.
        response = client.describe_repositories()
    except (BotoCoreError, ClientError):
        logger.warning(
            "Amazon ECR requires you to create a repository before you can "
            "push an image to it. ZenML is trying to push the image "
            f"{image_name} but could not find any repositories because "
            "your local AWS credentials are not set. We will try to push "
            "anyway, but in case it fails you need to create a repository "
            f"named `{repo_name}`."
        )
        return

    try:
        repo_uris: List[str] = [
            repository["repositoryUri"]
            for repository in response["repositories"]
        ]
    except (KeyError, ClientError) as e:
        # invalid boto response, let's hope for the best and just push
        logger.debug("Error while trying to fetch ECR repositories: %s", e)
        return

    repo_exists = any(image_name.startswith(f"{uri}:") for uri in repo_uris)
    if not repo_exists:
        logger.warning(
            "Amazon ECR requires you to create a repository before you can "
            "push an image to it. ZenML is trying to push the image "
            f"{image_name} but could only detect the following "
            f"repositories: {repo_uris}. We will try to push anyway, but "
            "in case it fails you need to create a repository named "
            f"`{repo_name}`."
        )
flavors
special
AWS integration flavors.
aws_container_registry_flavor
AWS container registry flavor.
AWSContainerRegistryConfig (BaseContainerRegistryConfig)
Configuration for AWS Container Registry.
Source code in zenml/integrations/aws/flavors/aws_container_registry_flavor.py
class AWSContainerRegistryConfig(BaseContainerRegistryConfig):
    """Configuration for AWS Container Registry."""

    @field_validator("uri")
    @classmethod
    def validate_aws_uri(cls, uri: str) -> str:
        """Reject registry URIs that contain a slash.

        Args:
            uri: URI to validate.

        Returns:
            The unchanged URI when it is valid.

        Raises:
            ValueError: If the URI contains a slash character.
        """
        if "/" not in uri:
            return uri

        raise ValueError(
            "Property `uri` can not contain a `/`. An example of a valid "
            "URI is: `715803424592.dkr.ecr.us-east-1.amazonaws.com`"
        )
validate_aws_uri(uri)
classmethod
Validates that the URI is in the correct format.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
uri |
str |
URI to validate. |
required |
Returns:
Type | Description |
---|---|
str |
URI in the correct format. |
Exceptions:
Type | Description |
---|---|
ValueError |
If the URI contains a slash character. |
Source code in zenml/integrations/aws/flavors/aws_container_registry_flavor.py
@field_validator("uri")
@classmethod
def validate_aws_uri(cls, uri: str) -> str:
    """Reject registry URIs that contain a slash.

    Args:
        uri: URI to validate.

    Returns:
        The unchanged URI when it is valid.

    Raises:
        ValueError: If the URI contains a slash character.
    """
    if "/" not in uri:
        return uri

    raise ValueError(
        "Property `uri` can not contain a `/`. An example of a valid "
        "URI is: `715803424592.dkr.ecr.us-east-1.amazonaws.com`"
    )
AWSContainerRegistryFlavor (BaseContainerRegistryFlavor)
AWS Container Registry flavor.
Source code in zenml/integrations/aws/flavors/aws_container_registry_flavor.py
class AWSContainerRegistryFlavor(BaseContainerRegistryFlavor):
    """AWS Container Registry flavor."""

    @property
    def name(self) -> str:
        """Identifier of this flavor.

        Returns:
            The name of the flavor.
        """
        return AWS_CONTAINER_REGISTRY_FLAVOR

    @property
    def service_connector_requirements(
        self,
    ) -> Optional[ServiceConnectorRequirements]:
        """Resource requirements a service connector must satisfy.

        These requirements are used to filter the available service connector
        types down to those compatible with this flavor.

        Returns:
            Requirements for an AWS connector exposing a Docker registry
            resource, with the resource ID taken from the `uri` attribute.
        """
        requirements = ServiceConnectorRequirements(
            connector_type=AWS_CONNECTOR_TYPE,
            resource_type=DOCKER_REGISTRY_RESOURCE_TYPE,
            resource_id_attr="uri",
        )
        return requirements

    @property
    def docs_url(self) -> Optional[str]:
        """URL of the documentation page for this flavor.

        Returns:
            A flavor docs url.
        """
        return self.generate_default_docs_url()

    @property
    def sdk_docs_url(self) -> Optional[str]:
        """URL of the SDK documentation page for this flavor.

        Returns:
            A flavor SDK docs url.
        """
        return self.generate_default_sdk_docs_url()

    @property
    def logo_url(self) -> str:
        """URL of the logo representing this flavor in the dashboard.

        Returns:
            The flavor logo.
        """
        return "https://public-flavor-logos.s3.eu-central-1.amazonaws.com/container_registry/aws.png"

    @property
    def config_class(self) -> Type[AWSContainerRegistryConfig]:
        """Configuration class used by this flavor.

        Returns:
            The config class.
        """
        return AWSContainerRegistryConfig

    @property
    def implementation_class(self) -> Type["AWSContainerRegistry"]:
        """Class implementing this flavor.

        Returns:
            The implementation class.
        """
        # Imported lazily, inside the property.
        from zenml.integrations.aws.container_registries import (
            AWSContainerRegistry,
        )

        return AWSContainerRegistry
config_class: Type[zenml.integrations.aws.flavors.aws_container_registry_flavor.AWSContainerRegistryConfig]
property
readonly
Config class for this flavor.
Returns:
Type | Description |
---|---|
Type[zenml.integrations.aws.flavors.aws_container_registry_flavor.AWSContainerRegistryConfig] |
The config class. |
docs_url: Optional[str]
property
readonly
A url to point at docs explaining this flavor.
Returns:
Type | Description |
---|---|
Optional[str] |
A flavor docs url. |
implementation_class: Type[AWSContainerRegistry]
property
readonly
Implementation class.
Returns:
Type | Description |
---|---|
Type[AWSContainerRegistry] |
The implementation class. |
logo_url: str
property
readonly
A url to represent the flavor in the dashboard.
Returns:
Type | Description |
---|---|
str |
The flavor logo. |
name: str
property
readonly
Name of the flavor.
Returns:
Type | Description |
---|---|
str |
The name of the flavor. |
sdk_docs_url: Optional[str]
property
readonly
A url to point at SDK docs explaining this flavor.
Returns:
Type | Description |
---|---|
Optional[str] |
A flavor SDK docs url. |
service_connector_requirements: Optional[zenml.models.v2.misc.service_connector_type.ServiceConnectorRequirements]
property
readonly
Service connector resource requirements for service connectors.
Specifies resource requirements that are used to filter the available service connector types that are compatible with this flavor.
Returns:
Type | Description |
---|---|
Optional[zenml.models.v2.misc.service_connector_type.ServiceConnectorRequirements] |
Requirements for compatible service connectors, if a service connector is required for this flavor. |
sagemaker_orchestrator_flavor
Amazon SageMaker orchestrator flavor.
SagemakerOrchestratorConfig (BaseOrchestratorConfig, SagemakerOrchestratorSettings)
Config for the Sagemaker orchestrator.
There are three ways to authenticate to AWS:
- By connecting a `ServiceConnector` to the orchestrator,
- By configuring explicit AWS credentials `aws_access_key_id`, `aws_secret_access_key`, and optional `aws_auth_role_arn`,
- If none of the above are provided, unspecified credentials will be loaded from the default AWS config.
Attributes:
Name | Type | Description |
---|---|---|
execution_role |
str |
The IAM role ARN to use for the pipeline. |
aws_access_key_id |
Optional[str] |
The AWS access key ID to use to authenticate to AWS. If not provided, the value from the default AWS config will be used. |
aws_secret_access_key |
Optional[str] |
The AWS secret access key to use to authenticate to AWS. If not provided, the value from the default AWS config will be used. |
aws_profile |
Optional[str] |
The AWS profile to use for authentication if not using service connectors or explicit credentials. If not provided, the default profile will be used. |
aws_auth_role_arn |
Optional[str] |
The ARN of an intermediate IAM role to assume when authenticating to AWS. |
region |
Optional[str] |
The AWS region where the processing job will be run. If not provided, the value from the default AWS config will be used. |
bucket |
Optional[str] |
Name of the S3 bucket to use for storing artifacts from the job run. If not provided, a default bucket will be created based on the following format: "sagemaker-{region}-{aws-account-id}". |
Source code in zenml/integrations/aws/flavors/sagemaker_orchestrator_flavor.py
class SagemakerOrchestratorConfig(
    BaseOrchestratorConfig, SagemakerOrchestratorSettings
):
    """Config for the Sagemaker orchestrator.

    There are three ways to authenticate to AWS:
    - By connecting a `ServiceConnector` to the orchestrator,
    - By configuring explicit AWS credentials `aws_access_key_id`,
      `aws_secret_access_key`, and optional `aws_auth_role_arn`,
    - If none of the above are provided, unspecified credentials will be
      loaded from the default AWS config.

    Attributes:
        execution_role: The IAM role ARN to use for the pipeline.
        aws_access_key_id: The AWS access key ID to use to authenticate to
            AWS. If not provided, the value from the default AWS config will
            be used.
        aws_secret_access_key: The AWS secret access key to use to
            authenticate to AWS. If not provided, the value from the default
            AWS config will be used.
        aws_profile: The AWS profile to use for authentication if not using
            service connectors or explicit credentials. If not provided, the
            default profile will be used.
        aws_auth_role_arn: The ARN of an intermediate IAM role to assume when
            authenticating to AWS.
        region: The AWS region where the processing job will be run. If not
            provided, the value from the default AWS config will be used.
        bucket: Name of the S3 bucket to use for storing artifacts from the
            job run. If not provided, a default bucket will be created based
            on the following format: "sagemaker-{region}-{aws-account-id}".
    """

    execution_role: str
    aws_access_key_id: Optional[str] = SecretField(default=None)
    aws_secret_access_key: Optional[str] = SecretField(default=None)
    aws_profile: Optional[str] = None
    aws_auth_role_arn: Optional[str] = None
    region: Optional[str] = None
    bucket: Optional[str] = None

    @property
    def is_remote(self) -> bool:
        """Tell whether this stack component runs remotely.

        This designation is used to determine if the stack component can be
        used with a local ZenML database or if it requires a remote ZenML
        server.

        Returns:
            Always True for this config.
        """
        return True

    @property
    def is_synchronous(self) -> bool:
        """Tell whether the orchestrator runs synchronously.

        Returns:
            The value of the `synchronous` setting.
        """
        return self.synchronous
is_remote: bool
property
readonly
Checks if this stack component is running remotely.
This designation is used to determine if the stack component can be used with a local ZenML database or if it requires a remote ZenML server.
Returns:
Type | Description |
---|---|
bool |
True if this config is for a remote component, False otherwise. |
is_synchronous: bool
property
readonly
Whether the orchestrator runs synchronous or not.
Returns:
Type | Description |
---|---|
bool |
Whether the orchestrator runs synchronous or not. |
SagemakerOrchestratorFlavor (BaseOrchestratorFlavor)
Flavor for the Sagemaker orchestrator.
Source code in zenml/integrations/aws/flavors/sagemaker_orchestrator_flavor.py
class SagemakerOrchestratorFlavor(BaseOrchestratorFlavor):
    """Flavor for the Sagemaker orchestrator."""

    @property
    def name(self) -> str:
        """Identifier of this flavor.

        Returns:
            The name of the flavor.
        """
        # NOTE(review): the orchestrator flavor returns the *step operator*
        # flavor constant. Presumably both flavors share the same name value;
        # confirm, or switch to an orchestrator-specific constant if one
        # exists.
        return AWS_SAGEMAKER_STEP_OPERATOR_FLAVOR

    @property
    def service_connector_requirements(
        self,
    ) -> Optional[ServiceConnectorRequirements]:
        """Resource requirements a service connector must satisfy.

        These requirements are used to filter the available service connector
        types down to those compatible with this flavor.

        Returns:
            Requirements restricted to the generic AWS resource type.
        """
        requirements = ServiceConnectorRequirements(
            resource_type=AWS_RESOURCE_TYPE
        )
        return requirements

    @property
    def docs_url(self) -> Optional[str]:
        """URL of the documentation page for this flavor.

        Returns:
            A flavor docs url.
        """
        return self.generate_default_docs_url()

    @property
    def sdk_docs_url(self) -> Optional[str]:
        """URL of the SDK documentation page for this flavor.

        Returns:
            A flavor SDK docs url.
        """
        return self.generate_default_sdk_docs_url()

    @property
    def logo_url(self) -> str:
        """URL of the logo representing this flavor in the dashboard.

        Returns:
            The flavor logo.
        """
        return "https://public-flavor-logos.s3.eu-central-1.amazonaws.com/orchestrator/sagemaker.png"

    @property
    def config_class(self) -> Type[SagemakerOrchestratorConfig]:
        """Configuration class used by this flavor.

        Returns:
            The config class.
        """
        return SagemakerOrchestratorConfig

    @property
    def implementation_class(self) -> Type["SagemakerOrchestrator"]:
        """Class implementing this flavor.

        Returns:
            The implementation class.
        """
        # Imported lazily, inside the property.
        from zenml.integrations.aws.orchestrators import SagemakerOrchestrator

        return SagemakerOrchestrator
config_class: Type[zenml.integrations.aws.flavors.sagemaker_orchestrator_flavor.SagemakerOrchestratorConfig]
property
readonly
Returns SagemakerOrchestratorConfig config class.
Returns:
Type | Description |
---|---|
Type[zenml.integrations.aws.flavors.sagemaker_orchestrator_flavor.SagemakerOrchestratorConfig] |
The config class. |
docs_url: Optional[str]
property
readonly
A url to point at docs explaining this flavor.
Returns:
Type | Description |
---|---|
Optional[str] |
A flavor docs url. |
implementation_class: Type[SagemakerOrchestrator]
property
readonly
Implementation class.
Returns:
Type | Description |
---|---|
Type[SagemakerOrchestrator] |
The implementation class. |
logo_url: str
property
readonly
A url to represent the flavor in the dashboard.
Returns:
Type | Description |
---|---|
str |
The flavor logo. |
name: str
property
readonly
Name of the flavor.
Returns:
Type | Description |
---|---|
str |
The name of the flavor. |
sdk_docs_url: Optional[str]
property
readonly
A url to point at SDK docs explaining this flavor.
Returns:
Type | Description |
---|---|
Optional[str] |
A flavor SDK docs url. |
service_connector_requirements: Optional[zenml.models.v2.misc.service_connector_type.ServiceConnectorRequirements]
property
readonly
Service connector resource requirements for service connectors.
Specifies resource requirements that are used to filter the available service connector types that are compatible with this flavor.
Returns:
Type | Description |
---|---|
Optional[zenml.models.v2.misc.service_connector_type.ServiceConnectorRequirements] |
Requirements for compatible service connectors, if a service connector is required for this flavor. |
SagemakerOrchestratorSettings (BaseSettings)
Settings for the Sagemaker orchestrator.
Attributes:
Name | Type | Description |
---|---|---|
synchronous |
bool |
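If `True`, the client running a pipeline using this orchestrator waits until all steps finish running. If `False`, the client returns immediately and the pipeline is executed asynchronously. Defaults to `True`. |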
instance_type |
Optional[str] |
The instance type to use for the processing job. |
execution_role |
str |
The IAM role to use for the step execution. |
processor_role |
Optional[str] |
DEPRECATED: use `execution_role` instead. |
volume_size_in_gb |
int |
The size of the EBS volume to use for the processing job. |
max_runtime_in_seconds |
int |
The maximum runtime in seconds for the processing job. |
tags |
Dict[str, str] |
Tags to apply to the Processor/Estimator assigned to the step. |
pipeline_tags |
Dict[str, str] |
Tags to apply to the pipeline via the sagemaker.workflow.pipeline.Pipeline.create method. |
processor_tags |
Dict[str, str] |
DEPRECATED: use `tags` instead. |
keep_alive_period_in_seconds |
Optional[int] |
The time in seconds after which the provisioned instance will be terminated if not used. This is only applicable for TrainingStep type and it is not possible to use TrainingStep type if the `output_data_s3_uri` is set to Dict[str, str]. |
use_training_step |
Optional[bool] |
Whether to use the TrainingStep type. It is not possible to use TrainingStep type if the `output_data_s3_uri` is set to Dict[str, str] or if the `output_data_s3_mode` != "EndOfJob". |
processor_args |
Dict[str, Any] |
Arguments that are directly passed to the SageMaker Processor for a specific step, allowing for overriding the default settings provided when configuring the component. See https://sagemaker.readthedocs.io/en/stable/api/training/processing.html#sagemaker.processing.Processor for a full list of arguments. For processor_args.instance_type, check https://docs.aws.amazon.com/sagemaker/latest/dg/notebooks-available-instance-types.html for a list of available instance types. |
estimator_args |
Dict[str, Any] |
Arguments that are directly passed to the SageMaker Estimator for a specific step, allowing for overriding the default settings provided when configuring the component. See https://sagemaker.readthedocs.io/en/stable/api/training/estimators.html#sagemaker.estimator.Estimator for a full list of arguments. For a list of available instance types, check https://docs.aws.amazon.com/sagemaker/latest/dg/cmn-info-instance-types.html. |
input_data_s3_mode |
str |
How data is made available to the container. Two possible input modes: File, Pipe. |
input_data_s3_uri |
Union[str, Dict[str, str]] |
S3 URI where data is located if not locally, e.g. s3://my-bucket/my-data/train. How data will be made available to the container is configured with input_data_s3_mode. Two possible input types: - str: S3 location where training data is saved. - Dict[str, str]: (ChannelName, S3Location) which represent channels (e.g. training, validation, testing) where specific parts of the data are saved in S3. |
output_data_s3_mode |
str |
How data is uploaded to the S3 bucket. Two possible output modes: EndOfJob, Continuous. |
output_data_s3_uri |
Union[str, Dict[str, str]] |
S3 URI where data is uploaded after or during processing run.
e.g. s3://my-bucket/my-data/output. How data will be made available
to the container is configured with output_data_s3_mode. Two possible
input types:
- str: S3 location where data will be uploaded from a local folder
named /opt/ml/processing/output/data.
- Dict[str, str]: (ChannelName, S3Location) which represent
channels (e.g. output_one, output_two) where
specific parts of the data are stored locally for S3 upload.
Data must be available locally in /opt/ml/processing/output/data/<ChannelName>. |
Source code in zenml/integrations/aws/flavors/sagemaker_orchestrator_flavor.py
class SagemakerOrchestratorSettings(BaseSettings):
    """Settings for the Sagemaker orchestrator.

    Attributes:
        synchronous: If `True`, the client running a pipeline using this
            orchestrator waits until all steps finish running. If `False`,
            the client returns immediately and the pipeline is executed
            asynchronously. Defaults to `True`.
        instance_type: The instance type to use for the processing job.
        execution_role: The IAM role to use for the step execution.
        processor_role: DEPRECATED: use `execution_role` instead.
        volume_size_in_gb: The size of the EBS volume to use for the processing
            job.
        max_runtime_in_seconds: The maximum runtime in seconds for the
            processing job.
        tags: Tags to apply to the Processor/Estimator assigned to the step.
        pipeline_tags: Tags to apply to the pipeline via the
            sagemaker.workflow.pipeline.Pipeline.create method.
        processor_tags: DEPRECATED: use `tags` instead.
        keep_alive_period_in_seconds: The time in seconds after which the
            provisioned instance will be terminated if not used. This is only
            applicable for TrainingStep type and it is not possible to use
            TrainingStep type if the `output_data_s3_uri` is set to
            Dict[str, str].
        use_training_step: Whether to use the TrainingStep type.
            It is not possible to use TrainingStep type
            if the `output_data_s3_uri` is set to Dict[str, str] or if the
            `output_data_s3_mode` != "EndOfJob".
        processor_args: Arguments that are directly passed to the SageMaker
            Processor for a specific step, allowing for overriding the default
            settings provided when configuring the component. See
            https://sagemaker.readthedocs.io/en/stable/api/training/processing.html#sagemaker.processing.Processor
            for a full list of arguments.
            For processor_args.instance_type, check
            https://docs.aws.amazon.com/sagemaker/latest/dg/notebooks-available-instance-types.html
            for a list of available instance types.
        estimator_args: Arguments that are directly passed to the SageMaker
            Estimator for a specific step, allowing for overriding the default
            settings provided when configuring the component. See
            https://sagemaker.readthedocs.io/en/stable/api/training/estimators.html#sagemaker.estimator.Estimator
            for a full list of arguments.
            For a list of available instance types, check
            https://docs.aws.amazon.com/sagemaker/latest/dg/cmn-info-instance-types.html.
        input_data_s3_mode: How data is made available to the container.
            Two possible input modes: File, Pipe.
        input_data_s3_uri: S3 URI where data is located if not locally,
            e.g. s3://my-bucket/my-data/train. How data will be made available
            to the container is configured with input_data_s3_mode. Two possible
            input types:
                - str: S3 location where training data is saved.
                - Dict[str, str]: (ChannelName, S3Location) which represent
                    channels (e.g. training, validation, testing) where
                    specific parts of the data are saved in S3.
        output_data_s3_mode: How data is uploaded to the S3 bucket.
            Two possible output modes: EndOfJob, Continuous.
        output_data_s3_uri: S3 URI where data is uploaded after or during
            processing run, e.g. s3://my-bucket/my-data/output. How data is
            uploaded is configured with output_data_s3_mode. Two possible
            value types:
                - str: S3 location where data will be uploaded from a local
                    folder named /opt/ml/processing/output/data.
                - Dict[str, str]: (ChannelName, S3Location) which represent
                    channels (e.g. output_one, output_two) where
                    specific parts of the data are stored locally for S3
                    upload. Data must be available locally in
                    /opt/ml/processing/output/data/<ChannelName>.
    """

    synchronous: bool = True

    instance_type: Optional[str] = None
    execution_role: Optional[str] = None
    volume_size_in_gb: int = 30
    max_runtime_in_seconds: int = 86400
    tags: Dict[str, str] = {}
    pipeline_tags: Dict[str, str] = {}
    keep_alive_period_in_seconds: Optional[int] = 300  # 5 minutes
    use_training_step: Optional[bool] = None

    processor_args: Dict[str, Any] = {}
    estimator_args: Dict[str, Any] = {}

    input_data_s3_mode: str = "File"
    input_data_s3_uri: Optional[Union[str, Dict[str, str]]] = Field(
        default=None, union_mode="left_to_right"
    )

    output_data_s3_mode: str = DEFAULT_OUTPUT_DATA_S3_MODE
    output_data_s3_uri: Optional[Union[str, Dict[str, str]]] = Field(
        default=None, union_mode="left_to_right"
    )

    # Deprecated aliases, mapped onto their replacements below.
    processor_role: Optional[str] = None
    processor_tags: Dict[str, str] = {}
    _deprecation_validator = deprecation_utils.deprecate_pydantic_attributes(
        ("processor_role", "execution_role"), ("processor_tags", "tags")
    )

    @model_validator(mode="before")
    @classmethod
    def validate_model(cls, data: Dict[str, Any]) -> Dict[str, Any]:
        """Check if model is configured correctly.

        Rejects configurations that request a training step together with an
        unsupported output configuration, and fills in a default
        `instance_type` when none is given.

        Args:
            data: The model data.

        Returns:
            The validated model data, as a new dict.

        Raises:
            ValueError: If the model is configured incorrectly.
        """
        # Work on a shallow copy so the caller's mapping is never mutated.
        data = dict(data)

        # A missing `use_training_step` is treated as True here.
        use_training_step = data.get("use_training_step", True)
        output_data_s3_uri = data.get("output_data_s3_uri", None)
        output_data_s3_mode = data.get(
            "output_data_s3_mode", DEFAULT_OUTPUT_DATA_S3_MODE
        )
        if use_training_step and (
            isinstance(output_data_s3_uri, dict)
            or (
                isinstance(output_data_s3_uri, str)
                and (output_data_s3_mode != DEFAULT_OUTPUT_DATA_S3_MODE)
            )
        ):
            raise ValueError(
                "`use_training_step=True` is not supported when `output_data_s3_uri` is a dict or "
                f"when `output_data_s3_mode` is not '{DEFAULT_OUTPUT_DATA_S3_MODE}'."
            )

        # The default instance type depends on the kind of step being used.
        instance_type = data.get("instance_type", None)
        if instance_type is None:
            if use_training_step:
                data["instance_type"] = DEFAULT_TRAINING_INSTANCE_TYPE
            else:
                data["instance_type"] = DEFAULT_PROCESSING_INSTANCE_TYPE
        return data
validate_model(data)
classmethod
Check if model is configured correctly.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
data |
Dict[str, Any] |
The model data. |
required |
Returns:
Type | Description |
---|---|
Dict[str, Any] |
The validated model data. |
Exceptions:
Type | Description |
---|---|
ValueError |
If the model is configured incorrectly. |
Source code in zenml/integrations/aws/flavors/sagemaker_orchestrator_flavor.py
@model_validator(mode="before")
@classmethod
def validate_model(cls, data: Dict[str, Any]) -> Dict[str, Any]:
    """Check if model is configured correctly.

    Rejects configurations that request a training step together with an
    unsupported output configuration, and fills in a default `instance_type`
    when none is given.

    Args:
        data: The model data.

    Returns:
        The validated model data, as a new dict.

    Raises:
        ValueError: If the model is configured incorrectly.
    """
    # Work on a shallow copy so the caller's mapping is never mutated.
    data = dict(data)

    # A missing `use_training_step` is treated as True here.
    use_training_step = data.get("use_training_step", True)
    output_data_s3_uri = data.get("output_data_s3_uri", None)
    output_data_s3_mode = data.get(
        "output_data_s3_mode", DEFAULT_OUTPUT_DATA_S3_MODE
    )
    if use_training_step and (
        isinstance(output_data_s3_uri, dict)
        or (
            isinstance(output_data_s3_uri, str)
            and (output_data_s3_mode != DEFAULT_OUTPUT_DATA_S3_MODE)
        )
    ):
        raise ValueError(
            "`use_training_step=True` is not supported when `output_data_s3_uri` is a dict or "
            f"when `output_data_s3_mode` is not '{DEFAULT_OUTPUT_DATA_S3_MODE}'."
        )

    # The default instance type depends on the kind of step being used.
    instance_type = data.get("instance_type", None)
    if instance_type is None:
        if use_training_step:
            data["instance_type"] = DEFAULT_TRAINING_INSTANCE_TYPE
        else:
            data["instance_type"] = DEFAULT_PROCESSING_INSTANCE_TYPE
    return data
sagemaker_step_operator_flavor
Amazon SageMaker step operator flavor.
SagemakerStepOperatorConfig (BaseStepOperatorConfig, SagemakerStepOperatorSettings)
Config for the Sagemaker step operator.
Attributes:
Name | Type | Description |
---|---|---|
role |
str |
The role that has to be assigned to the jobs which are running in Sagemaker. |
bucket |
Optional[str] |
Name of the S3 bucket to use for storing artifacts from the job run. If not provided, a default bucket will be created based on the following format: "sagemaker-{region}-{aws-account-id}". |
Source code in zenml/integrations/aws/flavors/sagemaker_step_operator_flavor.py
class SagemakerStepOperatorConfig(
    BaseStepOperatorConfig, SagemakerStepOperatorSettings
):
    """Configuration of the Sagemaker step operator.

    Attributes:
        role: The role that has to be assigned to the jobs which are
            running in Sagemaker.
        bucket: Name of the S3 bucket to use for storing artifacts
            from the job run. If not provided, a default bucket will be
            created based on the following format:
            "sagemaker-{region}-{aws-account-id}".
    """

    role: str
    bucket: Optional[str] = None

    @property
    def is_remote(self) -> bool:
        """Tells whether this stack component runs remotely.

        The flag is used to decide whether the component can be combined
        with a local ZenML database or whether a remote ZenML server is
        required.

        Returns:
            Always True for this config.
        """
        return True
is_remote: bool
property
readonly
Checks if this stack component is running remotely.
This designation is used to determine if the stack component can be used with a local ZenML database or if it requires a remote ZenML server.
Returns:
Type | Description |
---|---|
bool |
True if this config is for a remote component, False otherwise. |
SagemakerStepOperatorFlavor (BaseStepOperatorFlavor)
Flavor for the Sagemaker step operator.
Source code in zenml/integrations/aws/flavors/sagemaker_step_operator_flavor.py
class SagemakerStepOperatorFlavor(BaseStepOperatorFlavor):
    """Flavor describing the Sagemaker step operator."""

    @property
    def name(self) -> str:
        """The flavor name.

        Returns:
            The name of the flavor.
        """
        return AWS_SAGEMAKER_STEP_OPERATOR_FLAVOR

    @property
    def service_connector_requirements(
        self,
    ) -> Optional[ServiceConnectorRequirements]:
        """Resource requirements for compatible service connectors.

        These requirements are used to filter the available service
        connector types down to those compatible with this flavor.

        Returns:
            Requirements for compatible service connectors, if a service
            connector is required for this flavor.
        """
        requirements = ServiceConnectorRequirements(
            resource_type=AWS_RESOURCE_TYPE
        )
        return requirements

    @property
    def docs_url(self) -> Optional[str]:
        """URL of the docs page explaining this flavor.

        Returns:
            A flavor docs url.
        """
        return self.generate_default_docs_url()

    @property
    def sdk_docs_url(self) -> Optional[str]:
        """URL of the SDK docs page explaining this flavor.

        Returns:
            A flavor SDK docs url.
        """
        return self.generate_default_sdk_docs_url()

    @property
    def logo_url(self) -> str:
        """URL of the logo that represents the flavor in the dashboard.

        Returns:
            The flavor logo.
        """
        return "https://public-flavor-logos.s3.eu-central-1.amazonaws.com/step_operator/sagemaker.png"

    @property
    def config_class(self) -> Type[SagemakerStepOperatorConfig]:
        """The config class of this flavor.

        Returns:
            The config class.
        """
        return SagemakerStepOperatorConfig

    @property
    def implementation_class(self) -> Type["SagemakerStepOperator"]:
        """The implementation class of this flavor.

        Returns:
            The implementation class.
        """
        # Imported inside the property, as in the original listing.
        from zenml.integrations.aws.step_operators import SagemakerStepOperator

        return SagemakerStepOperator
config_class: Type[zenml.integrations.aws.flavors.sagemaker_step_operator_flavor.SagemakerStepOperatorConfig]
property
readonly
Returns SagemakerStepOperatorConfig config class.
Returns:
Type | Description |
---|---|
Type[zenml.integrations.aws.flavors.sagemaker_step_operator_flavor.SagemakerStepOperatorConfig] |
The config class. |
docs_url: Optional[str]
property
readonly
A url to point at docs explaining this flavor.
Returns:
Type | Description |
---|---|
Optional[str] |
A flavor docs url. |
implementation_class: Type[SagemakerStepOperator]
property
readonly
Implementation class.
Returns:
Type | Description |
---|---|
Type[SagemakerStepOperator] |
The implementation class. |
logo_url: str
property
readonly
A url to represent the flavor in the dashboard.
Returns:
Type | Description |
---|---|
str |
The flavor logo. |
name: str
property
readonly
Name of the flavor.
Returns:
Type | Description |
---|---|
str |
The name of the flavor. |
sdk_docs_url: Optional[str]
property
readonly
A url to point at SDK docs explaining this flavor.
Returns:
Type | Description |
---|---|
Optional[str] |
A flavor SDK docs url. |
service_connector_requirements: Optional[zenml.models.v2.misc.service_connector_type.ServiceConnectorRequirements]
property
readonly
Service connector resource requirements for service connectors.
Specifies resource requirements that are used to filter the available service connector types that are compatible with this flavor.
Returns:
Type | Description |
---|---|
Optional[zenml.models.v2.misc.service_connector_type.ServiceConnectorRequirements] |
Requirements for compatible service connectors, if a service connector is required for this flavor. |
SagemakerStepOperatorSettings (BaseSettings)
Settings for the Sagemaker step operator.
Attributes:
Name | Type | Description |
---|---|---|
experiment_name |
Optional[str] |
The name for the experiment to which the job will be associated. If not provided, the job runs would be independent. |
input_data_s3_uri |
Union[str, Dict[str, str]] |
S3 URI where training data is located if not locally, e.g. s3://my-bucket/my-data/train. How data will be made available to the container is configured with estimator_args.input_mode. Two possible input types: - str: S3 location where training data is saved. - Dict[str, str]: (ChannelName, S3Location) which represent channels (e.g. training, validation, testing) where specific parts of the data are saved in S3. |
estimator_args |
Dict[str, Any] |
Arguments that are directly passed to the SageMaker Estimator. See https://sagemaker.readthedocs.io/en/stable/api/training/estimators.html#sagemaker.estimator.Estimator for a full list of arguments. For estimator_args.instance_type, check https://docs.aws.amazon.com/sagemaker/latest/dg/notebooks-available-instance-types.html for a list of available instance types. |
Source code in zenml/integrations/aws/flavors/sagemaker_step_operator_flavor.py
class SagemakerStepOperatorSettings(BaseSettings):
    """Per-step settings of the Sagemaker step operator.

    Attributes:
        instance_type: Registered with
            `deprecation_utils.deprecate_pydantic_attributes` below;
            presumably superseded by `estimator_args.instance_type`
            (TODO confirm).
        experiment_name: The name for the experiment to which the job
            will be associated. If not provided, the job runs would be
            independent.
        input_data_s3_uri: S3 URI where training data is located if not
            locally, e.g. s3://my-bucket/my-data/train. How data will be
            made available to the container is configured with
            estimator_args.input_mode. Two possible input types:
                - str: S3 location where training data is saved.
                - Dict[str, str]: (ChannelName, S3Location) which represent
                    channels (e.g. training, validation, testing) where
                    specific parts of the data are saved in S3.
        estimator_args: Arguments that are directly passed to the SageMaker
            Estimator. See
            https://sagemaker.readthedocs.io/en/stable/api/training/estimators.html#sagemaker.estimator.Estimator
            for a full list of arguments.
            For estimator_args.instance_type, check
            https://docs.aws.amazon.com/sagemaker/latest/dg/notebooks-available-instance-types.html
            for a list of available instance types.
    """

    instance_type: Optional[str] = None
    experiment_name: Optional[str] = None
    # `left_to_right` union mode: the `str` member is tried before the dict.
    input_data_s3_uri: Optional[Union[str, Dict[str, str]]] = Field(
        default=None, union_mode="left_to_right"
    )
    estimator_args: Dict[str, Any] = {}

    _deprecation_validator = deprecation_utils.deprecate_pydantic_attributes(
        "instance_type"
    )
orchestrators
special
AWS Sagemaker orchestrator.
sagemaker_orchestrator
Implementation of the SageMaker orchestrator.
SagemakerOrchestrator (ContainerizedOrchestrator)
Orchestrator responsible for running pipelines on Sagemaker.
Source code in zenml/integrations/aws/orchestrators/sagemaker_orchestrator.py
class SagemakerOrchestrator(ContainerizedOrchestrator):
"""Orchestrator responsible for running pipelines on Sagemaker."""
@property
def config(self) -> SagemakerOrchestratorConfig:
    """The orchestrator configuration, typed as `SagemakerOrchestratorConfig`.

    Returns:
        The configuration.
    """
    typed_config = cast(SagemakerOrchestratorConfig, self._config)
    return typed_config
@property
def validator(self) -> Optional[StackValidator]:
    """Builds the validator for stacks that use this orchestrator.

    The validator requires a container registry and an image builder in the
    stack and rejects stacks that contain a local component.

    Returns:
        A `StackValidator` instance.
    """

    def _validate_remote_components(
        stack: "Stack",
    ) -> Tuple[bool, str]:
        # Look for the first component whose config is flagged as local.
        component = next(
            (
                candidate
                for candidate in stack.components.values()
                if candidate.config.is_local
            ),
            None,
        )
        if component is None:
            return True, ""

        return False, (
            f"The Sagemaker orchestrator runs pipelines remotely, "
            f"but the '{component.name}' {component.type.value} is "
            "a local stack component and will not be available in "
            "the Sagemaker step.\nPlease ensure that you always "
            "use non-local stack components with the Sagemaker "
            "orchestrator."
        )

    required = {
        StackComponentType.CONTAINER_REGISTRY,
        StackComponentType.IMAGE_BUILDER,
    }
    return StackValidator(
        required_components=required,
        custom_validation_function=_validate_remote_components,
    )
def get_orchestrator_run_id(self) -> str:
    """Reads the ID of the active orchestrator run from the environment.

    Important: This needs to be a unique ID and return the same value for
    all steps of a pipeline run.

    Returns:
        The orchestrator run id.

    Raises:
        RuntimeError: If the run id cannot be read from the environment.
    """
    try:
        orchestrator_run_id = os.environ[ENV_ZENML_SAGEMAKER_RUN_ID]
    except KeyError:
        raise RuntimeError(
            "Unable to read run id from environment variable "
            f"{ENV_ZENML_SAGEMAKER_RUN_ID}."
        )
    return orchestrator_run_id
@property
def settings_class(self) -> Optional[Type["BaseSettings"]]:
    """The settings class used by the Sagemaker orchestrator.

    Returns:
        The settings class.
    """
    orchestrator_settings_class = SagemakerOrchestratorSettings
    return orchestrator_settings_class
def _get_sagemaker_session(self) -> sagemaker.Session:
"""Create the SageMaker session with proper authentication.
The underlying boto3 session comes from the linked service connector when
`self.get_connector()` returns one; otherwise it is built from the
orchestrator config (access key, secret key, region and profile). When
`aws_auth_role_arn` is configured, that role is assumed through STS and a
new boto3 session is created from the returned temporary credentials.
Returns:
The Sagemaker Session, using `self.config.bucket` as default bucket.
Raises:
RuntimeError: If the connector returns the wrong type for the
session.
"""
# Get authenticated session
# Option 1: Service connector
boto_session: boto3.Session
if connector := self.get_connector():
boto_session = connector.connect()
# The connector result is only accepted if it is a boto3 session.
if not isinstance(boto_session, boto3.Session):
raise RuntimeError(
f"Expected to receive a `boto3.Session` object from the "
f"linked connector, but got type `{type(boto_session)}`."
)
# Option 2: Explicit configuration
# Args that are not provided will be taken from the default AWS config.
else:
boto_session = boto3.Session(
aws_access_key_id=self.config.aws_access_key_id,
aws_secret_access_key=self.config.aws_secret_access_key,
region_name=self.config.region,
profile_name=self.config.aws_profile,
)
# If a role ARN is provided for authentication, assume the role
# NOTE(review): indentation is lost in this listing, so it is not visible
# whether this role assumption is nested under the `else` branch above
# (explicit configuration only) or also applies to connector sessions.
# Confirm against the source file.
if self.config.aws_auth_role_arn:
sts = boto_session.client("sts")
response = sts.assume_role(
RoleArn=self.config.aws_auth_role_arn,
RoleSessionName="zenml-sagemaker-orchestrator",
)
# Replace the session with one that uses the temporary role credentials.
credentials = response["Credentials"]
boto_session = boto3.Session(
aws_access_key_id=credentials["AccessKeyId"],
aws_secret_access_key=credentials["SecretAccessKey"],
aws_session_token=credentials["SessionToken"],
region_name=self.config.region,
)
return sagemaker.Session(
boto_session=boto_session, default_bucket=self.config.bucket
)
def prepare_or_run_pipeline(
    self,
    deployment: "PipelineDeploymentResponse",
    stack: "Stack",
    environment: Dict[str, str],
) -> Iterator[Dict[str, MetadataType]]:
    """Prepares or runs a pipeline on Sagemaker.

    Every step of the deployment is converted into a SageMaker
    `TrainingStep` or `ProcessingStep` (depending on `use_training_step`),
    the steps are assembled into a SageMaker `Pipeline`, and the pipeline is
    created and started. If the `synchronous` setting is enabled, the method
    additionally waits for the execution to finish.

    Args:
        deployment: The deployment to prepare or run.
        stack: The stack to run on.
        environment: Environment variables to set in the orchestration
            environment.

    Raises:
        RuntimeError: If a connector is used that does not return a
            `boto3.Session` object, or if waiting for a synchronous
            execution times out.
        TypeError: If the network_config passed is not compatible with the
            AWS SageMaker NetworkConfig class.

    Yields:
        A dictionary of metadata related to the pipeline run.
    """
    if deployment.schedule:
        logger.warning(
            "The Sagemaker Orchestrator currently does not support the "
            "use of schedules. The `schedule` will be ignored "
            "and the pipeline will be run immediately."
        )

    # sagemaker requires pipelineName to use alphanum and hyphens only
    unsanitized_orchestrator_run_name = get_orchestrator_run_name(
        pipeline_name=deployment.pipeline_configuration.name
    )
    # replace all non-alphanum and non-hyphens with hyphens
    orchestrator_run_name = re.sub(
        r"[^a-zA-Z0-9\-]", "-", unsanitized_orchestrator_run_name
    )

    session = self._get_sagemaker_session()

    # Sagemaker does not allow environment variables longer than 256
    # characters to be passed to Processor steps. If an environment variable
    # is longer than 256 characters, we split it into multiple environment
    # variables (chunks) and re-construct it on the other side using the
    # custom entrypoint configuration.
    split_environment_variables(
        size_limit=SAGEMAKER_PROCESSOR_STEP_ENV_VAR_SIZE_LIMIT,
        env=environment,
    )

    sagemaker_steps = []
    for step_name, step in deployment.step_configurations.items():
        image = self.get_image(deployment=deployment, step_name=step_name)
        command = SagemakerEntrypointConfiguration.get_entrypoint_command()
        arguments = (
            SagemakerEntrypointConfiguration.get_entrypoint_arguments(
                step_name=step_name, deployment_id=deployment.id
            )
        )
        entrypoint = command + arguments

        step_settings = cast(
            SagemakerOrchestratorSettings, self.get_settings(step)
        )

        environment[ENV_ZENML_SAGEMAKER_RUN_ID] = (
            ExecutionVariables.PIPELINE_EXECUTION_ARN
        )

        # Step-level setting wins over the component config; default True.
        use_training_step = (
            step_settings.use_training_step
            if step_settings.use_training_step is not None
            else (
                self.config.use_training_step
                if self.config.use_training_step is not None
                else True
            )
        )

        # Retrieve Executor arguments provided in the Step settings.
        if use_training_step:
            args_for_step_executor = step_settings.estimator_args or {}
        else:
            args_for_step_executor = step_settings.processor_args or {}

        # Set default values from configured orchestrator Component to
        # arguments to be used when they are not present in processor_args.
        args_for_step_executor.setdefault(
            "role",
            step_settings.execution_role or self.config.execution_role,
        )
        args_for_step_executor.setdefault(
            "volume_size_in_gb", step_settings.volume_size_in_gb
        )
        args_for_step_executor.setdefault(
            "max_runtime_in_seconds", step_settings.max_runtime_in_seconds
        )
        tags = step_settings.tags
        args_for_step_executor.setdefault(
            "tags",
            (
                [
                    {"Key": key, "Value": value}
                    for key, value in tags.items()
                ]
                if tags
                else None
            ),
        )
        args_for_step_executor.setdefault(
            "instance_type", step_settings.instance_type
        )

        # Set values that cannot be overwritten
        args_for_step_executor["image_uri"] = image
        args_for_step_executor["instance_count"] = 1
        args_for_step_executor["sagemaker_session"] = session
        args_for_step_executor["base_job_name"] = orchestrator_run_name

        # Convert network_config to sagemaker.network.NetworkConfig if
        # present
        network_config = args_for_step_executor.get("network_config")
        if network_config and isinstance(network_config, dict):
            try:
                args_for_step_executor["network_config"] = NetworkConfig(
                    **network_config
                )
            except TypeError as e:
                # If the network_config passed is not compatible with the
                # NetworkConfig class, raise a more informative error.
                raise TypeError(
                    "Expected a sagemaker.network.NetworkConfig "
                    "compatible object for the network_config argument, "
                    "but the network_config processor argument is invalid. "
                    "See https://sagemaker.readthedocs.io/en/stable/api/utility/network.html#sagemaker.network.NetworkConfig "
                    "for more information about the NetworkConfig class."
                ) from e

        # Construct S3 inputs to container for step
        inputs = None

        if step_settings.input_data_s3_uri is None:
            pass
        elif isinstance(step_settings.input_data_s3_uri, str):
            inputs = [
                ProcessingInput(
                    source=step_settings.input_data_s3_uri,
                    destination="/opt/ml/processing/input/data",
                    s3_input_mode=step_settings.input_data_s3_mode,
                )
            ]
        elif isinstance(step_settings.input_data_s3_uri, dict):
            inputs = []
            for channel, s3_uri in step_settings.input_data_s3_uri.items():
                inputs.append(
                    ProcessingInput(
                        source=s3_uri,
                        destination=f"/opt/ml/processing/input/data/{channel}",
                        s3_input_mode=step_settings.input_data_s3_mode,
                    )
                )

        # Construct S3 outputs from container for step
        outputs = None
        output_path = None

        if step_settings.output_data_s3_uri is None:
            pass
        elif isinstance(step_settings.output_data_s3_uri, str):
            if use_training_step:
                # Training steps take the output location as a plain path.
                output_path = step_settings.output_data_s3_uri
            else:
                outputs = [
                    ProcessingOutput(
                        source="/opt/ml/processing/output/data",
                        destination=step_settings.output_data_s3_uri,
                        s3_upload_mode=step_settings.output_data_s3_mode,
                    )
                ]
        elif isinstance(step_settings.output_data_s3_uri, dict):
            outputs = []
            for (
                channel,
                s3_uri,
            ) in step_settings.output_data_s3_uri.items():
                outputs.append(
                    ProcessingOutput(
                        source=f"/opt/ml/processing/output/data/{channel}",
                        destination=s3_uri,
                        s3_upload_mode=step_settings.output_data_s3_mode,
                    )
                )

        if use_training_step:
            # Create Estimator and TrainingStep
            estimator = sagemaker.estimator.Estimator(
                keep_alive_period_in_seconds=step_settings.keep_alive_period_in_seconds,
                output_path=output_path,
                environment=environment,
                container_entry_point=entrypoint,
                **args_for_step_executor,
            )
            sagemaker_step = TrainingStep(
                name=step_name,
                depends_on=step.spec.upstream_steps,
                inputs=inputs,
                estimator=estimator,
            )
        else:
            # Create Processor and ProcessingStep
            processor = sagemaker.processing.Processor(
                entrypoint=entrypoint,
                env=environment,
                **args_for_step_executor,
            )
            sagemaker_step = ProcessingStep(
                name=step_name,
                processor=processor,
                depends_on=step.spec.upstream_steps,
                inputs=inputs,
                outputs=outputs,
            )

        sagemaker_steps.append(sagemaker_step)

    # construct the pipeline from the sagemaker_steps
    pipeline = Pipeline(
        name=orchestrator_run_name,
        steps=sagemaker_steps,
        sagemaker_session=session,
    )

    settings = cast(
        SagemakerOrchestratorSettings, self.get_settings(deployment)
    )

    pipeline.create(
        role_arn=self.config.execution_role,
        tags=[
            {"Key": key, "Value": value}
            for key, value in settings.pipeline_tags.items()
        ]
        if settings.pipeline_tags
        else None,
    )
    execution = pipeline.start()
    logger.warning(
        "Steps can take 5-15 minutes to start running "
        "when using the Sagemaker Orchestrator."
    )

    # Yield metadata based on the generated execution object
    yield from self.compute_metadata(execution=execution)

    # mainly for testing purposes, we wait for the pipeline to finish
    if settings.synchronous:
        logger.info(
            "Executing synchronously. Waiting for pipeline to finish... \n"
            "At this point you can `Ctrl-C` out without cancelling the "
            "execution."
        )
        try:
            execution.wait(
                delay=POLLING_DELAY, max_attempts=MAX_POLLING_ATTEMPTS
            )
            logger.info("Pipeline completed successfully.")
        except WaiterError:
            raise RuntimeError(
                "Timed out while waiting for pipeline execution to "
                "finish. For long-running pipelines we recommend "
                "configuring your orchestrator for asynchronous execution. "
                "The following command does this for you: \n"
                f"`zenml orchestrator update {self.name} "
                f"--synchronous=False`"
            )
def get_pipeline_run_metadata(
    self, run_id: UUID
) -> Dict[str, "MetadataType"]:
    """Collect general component-specific metadata for a pipeline run.

    The pipeline execution ARN is read from the environment variable named
    by `ENV_ZENML_SAGEMAKER_RUN_ID`.

    Args:
        run_id: The ID of the pipeline run.

    Returns:
        A dictionary of metadata.
    """
    metadata_for_run: Dict[str, "MetadataType"] = {
        "pipeline_execution_arn": os.environ[ENV_ZENML_SAGEMAKER_RUN_ID],
    }
    return metadata_for_run
def fetch_status(self, run: "PipelineRunResponse") -> ExecutionStatus:
    """Refreshes the status of a specific pipeline run.

    Args:
        run: The run that was executed by this orchestrator.

    Returns:
        The actual status of the pipeline job.

    Raises:
        AssertionError: If the run was not executed by this orchestrator.
        ValueError: If it fetches an unknown state or if we can not fetch
            the orchestrator run ID.
    """
    # Make sure that the stack exists and is accessible
    if run.stack is None:
        raise ValueError(
            "The stack that the run was executed on is not available "
            "anymore."
        )

    # Make sure that the run belongs to this orchestrator. The error is
    # raised explicitly instead of via `assert` so that the check is not
    # stripped when Python runs with `-O`.
    if not (
        self.id
        == run.stack.components[StackComponentType.ORCHESTRATOR][0].id
    ):
        raise AssertionError(
            "The run was not executed by this orchestrator."
        )

    # Initialize the Sagemaker client
    session = self._get_sagemaker_session()
    sagemaker_client = session.sagemaker_client

    # Fetch the status of the _PipelineExecution
    if METADATA_ORCHESTRATOR_RUN_ID in run.run_metadata:
        run_id = run.run_metadata[METADATA_ORCHESTRATOR_RUN_ID]
    elif run.orchestrator_run_id is not None:
        run_id = run.orchestrator_run_id
    else:
        raise ValueError(
            "Can not find the orchestrator run ID, thus can not fetch "
            "the status."
        )
    status = sagemaker_client.describe_pipeline_execution(
        PipelineExecutionArn=run_id
    )["PipelineExecutionStatus"]

    # Map the potential outputs to ZenML ExecutionStatus. Potential values:
    # https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_DescribePipelineExecution.html
    if status in ["Executing", "Stopping"]:
        return ExecutionStatus.RUNNING
    elif status in ["Stopped", "Failed"]:
        return ExecutionStatus.FAILED
    elif status in ["Succeeded"]:
        return ExecutionStatus.COMPLETED
    else:
        raise ValueError(
            f"Unknown status for the pipeline execution: {status!r}."
        )
def compute_metadata(
    self, execution: Any
) -> Iterator[Dict[str, MetadataType]]:
    """Build the run metadata for a started Sagemaker execution.

    Args:
        execution: The corresponding _PipelineExecution object.

    Yields:
        A dictionary of metadata related to the pipeline run.
    """
    collected: Dict[str, MetadataType] = {}

    # Orchestrator run ID, stored as-is when available.
    orchestrator_run_id = self._compute_orchestrator_run_id(execution)
    if orchestrator_run_id:
        collected[METADATA_ORCHESTRATOR_RUN_ID] = orchestrator_run_id

    # URL to the Sagemaker pipeline view.
    dashboard_url = self._compute_orchestrator_url(execution)
    if dashboard_url:
        collected[METADATA_ORCHESTRATOR_URL] = Uri(dashboard_url)

    # URL to the corresponding CloudWatch page.
    cloudwatch_url = self._compute_orchestrator_logs_url(execution)
    if cloudwatch_url:
        collected[METADATA_ORCHESTRATOR_LOGS_URL] = Uri(cloudwatch_url)

    yield collected
@staticmethod
def _compute_orchestrator_url(
    pipeline_execution: Any,
) -> Optional[str]:
    """Build the SageMaker Studio URL for a pipeline execution.

    Any failure along the way is logged as a warning and results in `None`.

    Args:
        pipeline_execution: The corresponding _PipelineExecution object.

    Returns:
        The URL to the dashboard view in SageMaker, or None if it could not
        be determined.
    """
    try:
        arn_parts = dissect_pipeline_execution_arn(pipeline_execution.arn)
        region_name, pipeline_name, execution_id = arn_parts

        # The Studio domain ID is taken from the first listed domain.
        sagemaker_client = pipeline_execution.sagemaker_session.sagemaker_client
        listed_domains = sagemaker_client.list_domains()
        studio_domain_id = listed_domains["Domains"][0]["DomainId"]

        return (
            f"https://studio-{studio_domain_id}.studio.{region_name}."
            f"sagemaker.aws/pipelines/view/{pipeline_name}/executions"
            f"/{execution_id}/graph"
        )
    except Exception as e:
        logger.warning(
            f"There was an issue while extracting the pipeline url: {e}"
        )
        return None
@staticmethod
def _compute_orchestrator_logs_url(
    pipeline_execution: Any,
) -> Optional[str]:
    """Build the CloudWatch logs URL for a pipeline execution.

    Any failure along the way is logged as a warning and results in `None`.

    Args:
        pipeline_execution: The corresponding _PipelineExecution object.

    Returns:
        The URL querying the pipeline logs in CloudWatch on AWS, or None if
        it could not be determined.
    """
    try:
        # The pipeline name (second element) is not needed for the logs URL.
        arn_parts = dissect_pipeline_execution_arn(pipeline_execution.arn)
        region_name, _, execution_id = arn_parts

        return (
            f"https://{region_name}.console.aws.amazon.com/"
            f"cloudwatch/home?region={region_name}#logsV2:log-groups/log-group"
            f"/$252Faws$252Fsagemaker$252FTrainingJobs$3FlogStreamNameFilter"
            f"$3Dpipelines-{execution_id}-"
        )
    except Exception as e:
        logger.warning(
            f"There was an issue while extracting the logs url: {e}"
        )
        return None
@staticmethod
def _compute_orchestrator_run_id(
    pipeline_execution: Any,
) -> Optional[str]:
    """Derive the orchestrator run ID from a pipeline execution.

    The run ID is the string form of the execution's `arn` attribute. Any
    failure is logged as a warning and results in `None`.

    Args:
        pipeline_execution: The corresponding _PipelineExecution object.

    Returns:
        The ARN of the execution as a string, or None if it could not be
        read.
    """
    try:
        execution_arn = str(pipeline_execution.arn)
    except Exception as e:
        logger.warning(
            f"There was an issue while extracting the pipeline run ID: {e}"
        )
        return None
    return execution_arn
config: SagemakerOrchestratorConfig
property
readonly
Returns the SagemakerOrchestratorConfig config.
Returns:
Type | Description |
---|---|
SagemakerOrchestratorConfig |
The configuration. |
settings_class: Optional[Type[BaseSettings]]
property
readonly
Settings class for the Sagemaker orchestrator.
Returns:
Type | Description |
---|---|
Optional[Type[BaseSettings]] |
The settings class. |
validator: Optional[zenml.stack.stack_validator.StackValidator]
property
readonly
Validates the stack.
In the remote case, checks that the stack contains a container registry, image builder and only remote components.
Returns:
Type | Description |
---|---|
Optional[zenml.stack.stack_validator.StackValidator] |
A StackValidator instance. |
compute_metadata(self, execution)
Generate run metadata based on the generated Sagemaker Execution.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
execution |
Any |
The corresponding _PipelineExecution object. |
required |
Yields:
Type | Description |
---|---|
Iterator[Dict[str, Union[str, int, float, bool, Dict[Any, Any], List[Any], Set[Any], Tuple[Any, ...], zenml.metadata.metadata_types.Uri, zenml.metadata.metadata_types.Path, zenml.metadata.metadata_types.DType, zenml.metadata.metadata_types.StorageSize]]] |
A dictionary of metadata related to the pipeline run. |
Source code in zenml/integrations/aws/orchestrators/sagemaker_orchestrator.py
def compute_metadata(
    self, execution: Any
) -> Iterator[Dict[str, MetadataType]]:
    """Build the run metadata for a started Sagemaker execution.

    Args:
        execution: The corresponding _PipelineExecution object.

    Yields:
        A dictionary of metadata related to the pipeline run.
    """
    collected: Dict[str, MetadataType] = {}

    # Orchestrator run ID, stored as-is when available.
    orchestrator_run_id = self._compute_orchestrator_run_id(execution)
    if orchestrator_run_id:
        collected[METADATA_ORCHESTRATOR_RUN_ID] = orchestrator_run_id

    # URL to the Sagemaker pipeline view.
    dashboard_url = self._compute_orchestrator_url(execution)
    if dashboard_url:
        collected[METADATA_ORCHESTRATOR_URL] = Uri(dashboard_url)

    # URL to the corresponding CloudWatch page.
    cloudwatch_url = self._compute_orchestrator_logs_url(execution)
    if cloudwatch_url:
        collected[METADATA_ORCHESTRATOR_LOGS_URL] = Uri(cloudwatch_url)

    yield collected
fetch_status(self, run)
Refreshes the status of a specific pipeline run.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
run |
PipelineRunResponse |
The run that was executed by this orchestrator. |
required |
Returns:
Type | Description |
---|---|
ExecutionStatus |
the actual status of the pipeline job. |
Exceptions:
Type | Description |
---|---|
AssertionError |
If the run was not executed by this orchestrator. |
ValueError |
If it fetches an unknown state or if we can not fetch the orchestrator run ID. |
Source code in zenml/integrations/aws/orchestrators/sagemaker_orchestrator.py
def fetch_status(self, run: "PipelineRunResponse") -> ExecutionStatus:
    """Refreshes the status of a specific pipeline run.

    Args:
        run: The run that was executed by this orchestrator.

    Returns:
        The actual status of the pipeline job.

    Raises:
        AssertionError: If the run was not executed by this orchestrator.
        ValueError: If it fetches an unknown state or if we can not fetch
            the orchestrator run ID.
    """
    # Make sure that the stack exists and is accessible
    if run.stack is None:
        raise ValueError(
            "The stack that the run was executed on is not available "
            "anymore."
        )

    # Make sure that the run belongs to this orchestrator. The error is
    # raised explicitly instead of via `assert` so that the check is not
    # stripped when Python runs with `-O`.
    if not (
        self.id
        == run.stack.components[StackComponentType.ORCHESTRATOR][0].id
    ):
        raise AssertionError(
            "The run was not executed by this orchestrator."
        )

    # Initialize the Sagemaker client
    session = self._get_sagemaker_session()
    sagemaker_client = session.sagemaker_client

    # Fetch the status of the _PipelineExecution
    if METADATA_ORCHESTRATOR_RUN_ID in run.run_metadata:
        run_id = run.run_metadata[METADATA_ORCHESTRATOR_RUN_ID]
    elif run.orchestrator_run_id is not None:
        run_id = run.orchestrator_run_id
    else:
        raise ValueError(
            "Can not find the orchestrator run ID, thus can not fetch "
            "the status."
        )
    status = sagemaker_client.describe_pipeline_execution(
        PipelineExecutionArn=run_id
    )["PipelineExecutionStatus"]

    # Map the potential outputs to ZenML ExecutionStatus. Potential values:
    # https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_DescribePipelineExecution.html
    if status in ["Executing", "Stopping"]:
        return ExecutionStatus.RUNNING
    elif status in ["Stopped", "Failed"]:
        return ExecutionStatus.FAILED
    elif status in ["Succeeded"]:
        return ExecutionStatus.COMPLETED
    else:
        raise ValueError(
            f"Unknown status for the pipeline execution: {status!r}."
        )
get_orchestrator_run_id(self)
Returns the run id of the active orchestrator run.
Important: This needs to be a unique ID and return the same value for all steps of a pipeline run.
Returns:
Type | Description |
---|---|
str |
The orchestrator run id. |
Exceptions:
Type | Description |
---|---|
RuntimeError |
If the run id cannot be read from the environment. |
Source code in zenml/integrations/aws/orchestrators/sagemaker_orchestrator.py
def get_orchestrator_run_id(self) -> str:
    """Reads the ID of the active orchestrator run from the environment.

    Important: This needs to be a unique ID and return the same value for
    all steps of a pipeline run.

    Returns:
        The orchestrator run id.

    Raises:
        RuntimeError: If the run id cannot be read from the environment.
    """
    try:
        orchestrator_run_id = os.environ[ENV_ZENML_SAGEMAKER_RUN_ID]
    except KeyError:
        raise RuntimeError(
            "Unable to read run id from environment variable "
            f"{ENV_ZENML_SAGEMAKER_RUN_ID}."
        )
    return orchestrator_run_id
get_pipeline_run_metadata(self, run_id)
Get general component-specific metadata for a pipeline run.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
run_id |
UUID |
The ID of the pipeline run. |
required |
Returns:
Type | Description |
---|---|
Dict[str, MetadataType] |
A dictionary of metadata. |
Source code in zenml/integrations/aws/orchestrators/sagemaker_orchestrator.py
def get_pipeline_run_metadata(
    self, run_id: UUID
) -> Dict[str, "MetadataType"]:
    """Collect general component-specific metadata for a pipeline run.

    The pipeline execution ARN is read from the environment variable named
    by `ENV_ZENML_SAGEMAKER_RUN_ID`.

    Args:
        run_id: The ID of the pipeline run.

    Returns:
        A dictionary of metadata.
    """
    metadata_for_run: Dict[str, "MetadataType"] = {
        "pipeline_execution_arn": os.environ[ENV_ZENML_SAGEMAKER_RUN_ID],
    }
    return metadata_for_run
prepare_or_run_pipeline(self, deployment, stack, environment)
Prepares or runs a pipeline on Sagemaker.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
deployment |
PipelineDeploymentResponse |
The deployment to prepare or run. |
required |
stack |
Stack |
The stack to run on. |
required |
environment |
Dict[str, str] |
Environment variables to set in the orchestration environment. |
required |
Exceptions:
Type | Description |
---|---|
RuntimeError |
If a connector is used that does not return a boto3.Session object. |
TypeError |
If the network_config passed is not compatible with the AWS SageMaker NetworkConfig class. |
Yields:
Type | Description |
---|---|
Iterator[Dict[str, Union[str, int, float, bool, Dict[Any, Any], List[Any], Set[Any], Tuple[Any, ...], zenml.metadata.metadata_types.Uri, zenml.metadata.metadata_types.Path, zenml.metadata.metadata_types.DType, zenml.metadata.metadata_types.StorageSize]]] |
A dictionary of metadata related to the pipeline run. |
Source code in zenml/integrations/aws/orchestrators/sagemaker_orchestrator.py
def prepare_or_run_pipeline(
    self,
    deployment: "PipelineDeploymentResponse",
    stack: "Stack",
    environment: Dict[str, str],
) -> Iterator[Dict[str, MetadataType]]:
    """Prepares or runs a pipeline on Sagemaker.

    One SageMaker step (a `TrainingStep` or a `ProcessingStep`) is built
    per ZenML step, the steps are assembled into a SageMaker `Pipeline`,
    and the pipeline is created and started. When the `synchronous`
    setting is enabled, the call then blocks until the execution
    finishes.

    Args:
        deployment: The deployment to prepare or run.
        stack: The stack to run on.
        environment: Environment variables to set in the orchestration
            environment. This dictionary is modified in place.

    Raises:
        RuntimeError: If a connector is used that does not return a
            `boto3.Session` object, or if waiting for a synchronous
            execution times out.
        TypeError: If the network_config passed is not compatible with the
            AWS SageMaker NetworkConfig class.

    Yields:
        A dictionary of metadata related to the pipeline run.
    """
    if deployment.schedule:
        logger.warning(
            "The Sagemaker Orchestrator currently does not support the "
            "use of schedules. The `schedule` will be ignored "
            "and the pipeline will be run immediately."
        )
    # sagemaker requires pipelineName to use alphanum and hyphens only
    unsanitized_orchestrator_run_name = get_orchestrator_run_name(
        pipeline_name=deployment.pipeline_configuration.name
    )
    # replace all non-alphanum and non-hyphens with hyphens
    orchestrator_run_name = re.sub(
        r"[^a-zA-Z0-9\-]", "-", unsanitized_orchestrator_run_name
    )
    session = self._get_sagemaker_session()
    # Sagemaker does not allow environment variables longer than 256
    # characters to be passed to Processor steps. If an environment variable
    # is longer than 256 characters, we split it into multiple environment
    # variables (chunks) and re-construct it on the other side using the
    # custom entrypoint configuration.
    split_environment_variables(
        size_limit=SAGEMAKER_PROCESSOR_STEP_ENV_VAR_SIZE_LIMIT,
        env=environment,
    )
    sagemaker_steps = []
    for step_name, step in deployment.step_configurations.items():
        image = self.get_image(deployment=deployment, step_name=step_name)
        command = SagemakerEntrypointConfiguration.get_entrypoint_command()
        arguments = (
            SagemakerEntrypointConfiguration.get_entrypoint_arguments(
                step_name=step_name, deployment_id=deployment.id
            )
        )
        entrypoint = command + arguments
        step_settings = cast(
            SagemakerOrchestratorSettings, self.get_settings(step)
        )
        # The same `environment` dict is shared by every step; the run ID
        # entry is (re)assigned on each iteration with the same value.
        environment[ENV_ZENML_SAGEMAKER_RUN_ID] = (
            ExecutionVariables.PIPELINE_EXECUTION_ARN
        )
        # Precedence: step-level setting, then orchestrator config, then
        # True as the final default.
        use_training_step = (
            step_settings.use_training_step
            if step_settings.use_training_step is not None
            else (
                self.config.use_training_step
                if self.config.use_training_step is not None
                else True
            )
        )
        # Retrieve Executor arguments provided in the Step settings.
        if use_training_step:
            args_for_step_executor = step_settings.estimator_args or {}
        else:
            args_for_step_executor = step_settings.processor_args or {}
        # Set default values from configured orchestrator Component to
        # arguments to be used when they are not present in processor_args.
        # When the settings dict is non-empty, the `setdefault` calls and
        # assignments below write into that same dict object.
        args_for_step_executor.setdefault(
            "role",
            step_settings.execution_role or self.config.execution_role,
        )
        args_for_step_executor.setdefault(
            "volume_size_in_gb", step_settings.volume_size_in_gb
        )
        args_for_step_executor.setdefault(
            "max_runtime_in_seconds", step_settings.max_runtime_in_seconds
        )
        tags = step_settings.tags
        # Tags are converted from a mapping into a list of
        # {"Key": ..., "Value": ...} dicts; no tags yields None.
        args_for_step_executor.setdefault(
            "tags",
            (
                [
                    {"Key": key, "Value": value}
                    for key, value in tags.items()
                ]
                if tags
                else None
            ),
        )
        args_for_step_executor.setdefault(
            "instance_type", step_settings.instance_type
        )
        # Set values that cannot be overwritten
        args_for_step_executor["image_uri"] = image
        args_for_step_executor["instance_count"] = 1
        args_for_step_executor["sagemaker_session"] = session
        args_for_step_executor["base_job_name"] = orchestrator_run_name
        # Convert network_config to sagemaker.network.NetworkConfig if
        # present
        network_config = args_for_step_executor.get("network_config")
        if network_config and isinstance(network_config, dict):
            try:
                args_for_step_executor["network_config"] = NetworkConfig(
                    **network_config
                )
            except TypeError:
                # If the network_config passed is not compatible with the
                # NetworkConfig class, raise a more informative error.
                raise TypeError(
                    "Expected a sagemaker.network.NetworkConfig "
                    "compatible object for the network_config argument, "
                    "but the network_config processor argument is invalid."
                    "See https://sagemaker.readthedocs.io/en/stable/api/utility/network.html#sagemaker.network.NetworkConfig "
                    "for more information about the NetworkConfig class."
                )
        # Construct S3 inputs to container for step
        # NOTE(review): `inputs` holds ProcessingInput objects even when a
        # TrainingStep is built below - confirm that TrainingStep accepts
        # this input type.
        inputs = None
        if step_settings.input_data_s3_uri is None:
            pass
        elif isinstance(step_settings.input_data_s3_uri, str):
            inputs = [
                ProcessingInput(
                    source=step_settings.input_data_s3_uri,
                    destination="/opt/ml/processing/input/data",
                    s3_input_mode=step_settings.input_data_s3_mode,
                )
            ]
        elif isinstance(step_settings.input_data_s3_uri, dict):
            # One input per channel, each mounted in its own subdirectory.
            inputs = []
            for channel, s3_uri in step_settings.input_data_s3_uri.items():
                inputs.append(
                    ProcessingInput(
                        source=s3_uri,
                        destination=f"/opt/ml/processing/input/data/{channel}",
                        s3_input_mode=step_settings.input_data_s3_mode,
                    )
                )
        # Construct S3 outputs from container for step
        # `output_path` is only passed to the Estimator and `outputs` only
        # to the ProcessingStep; a dict-valued output URI therefore has no
        # effect when a training step is used.
        outputs = None
        output_path = None
        if step_settings.output_data_s3_uri is None:
            pass
        elif isinstance(step_settings.output_data_s3_uri, str):
            if use_training_step:
                output_path = step_settings.output_data_s3_uri
            else:
                outputs = [
                    ProcessingOutput(
                        source="/opt/ml/processing/output/data",
                        destination=step_settings.output_data_s3_uri,
                        s3_upload_mode=step_settings.output_data_s3_mode,
                    )
                ]
        elif isinstance(step_settings.output_data_s3_uri, dict):
            outputs = []
            for (
                channel,
                s3_uri,
            ) in step_settings.output_data_s3_uri.items():
                outputs.append(
                    ProcessingOutput(
                        source=f"/opt/ml/processing/output/data/{channel}",
                        destination=s3_uri,
                        s3_upload_mode=step_settings.output_data_s3_mode,
                    )
                )
        if use_training_step:
            # Create Estimator and TrainingStep
            estimator = sagemaker.estimator.Estimator(
                keep_alive_period_in_seconds=step_settings.keep_alive_period_in_seconds,
                output_path=output_path,
                environment=environment,
                container_entry_point=entrypoint,
                **args_for_step_executor,
            )
            sagemaker_step = TrainingStep(
                name=step_name,
                depends_on=step.spec.upstream_steps,
                inputs=inputs,
                estimator=estimator,
            )
        else:
            # Create Processor and ProcessingStep
            processor = sagemaker.processing.Processor(
                entrypoint=entrypoint,
                env=environment,
                **args_for_step_executor,
            )
            sagemaker_step = ProcessingStep(
                name=step_name,
                processor=processor,
                depends_on=step.spec.upstream_steps,
                inputs=inputs,
                outputs=outputs,
            )
        sagemaker_steps.append(sagemaker_step)
    # construct the pipeline from the sagemaker_steps
    pipeline = Pipeline(
        name=orchestrator_run_name,
        steps=sagemaker_steps,
        sagemaker_session=session,
    )
    # Pipeline-level settings (tags, synchronous mode) are resolved from
    # the deployment rather than from an individual step.
    settings = cast(
        SagemakerOrchestratorSettings, self.get_settings(deployment)
    )
    pipeline.create(
        role_arn=self.config.execution_role,
        tags=[
            {"Key": key, "Value": value}
            for key, value in settings.pipeline_tags.items()
        ]
        if settings.pipeline_tags
        else None,
    )
    execution = pipeline.start()
    logger.warning(
        "Steps can take 5-15 minutes to start running "
        "when using the Sagemaker Orchestrator."
    )
    # Yield metadata based on the generated execution object
    yield from self.compute_metadata(execution=execution)
    # mainly for testing purposes, we wait for the pipeline to finish
    if settings.synchronous:
        logger.info(
            "Executing synchronously. Waiting for pipeline to finish... \n"
            "At this point you can `Ctrl-C` out without cancelling the "
            "execution."
        )
        try:
            execution.wait(
                delay=POLLING_DELAY, max_attempts=MAX_POLLING_ATTEMPTS
            )
            logger.info("Pipeline completed successfully.")
        except WaiterError:
            # A WaiterError is reported to the caller as a timeout.
            raise RuntimeError(
                "Timed out while waiting for pipeline execution to "
                "finish. For long-running pipelines we recommend "
                "configuring your orchestrator for asynchronous execution. "
                "The following command does this for you: \n"
                f"`zenml orchestrator update {self.name} "
                f"--synchronous=False`"
            )
dissect_pipeline_execution_arn(pipeline_execution_arn)
Extract region name, pipeline name, and execution id from the ARN.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
pipeline_execution_arn |
str |
the pipeline execution ARN |
required |
Returns:
Type | Description |
---|---|
Tuple[Optional[str], Optional[str], Optional[str]] |
Region Name, Pipeline Name, Execution ID in order |
Source code in zenml/integrations/aws/orchestrators/sagemaker_orchestrator.py
def dissect_pipeline_execution_arn(
    pipeline_execution_arn: str,
) -> Tuple[Optional[str], Optional[str], Optional[str]]:
    """Split a pipeline execution ARN into its interesting components.

    Each component is looked up independently, so a malformed or partial
    ARN yields ``None`` only for the parts that could not be found.

    Args:
        pipeline_execution_arn: the pipeline execution ARN

    Returns:
        Region Name, Pipeline Name, Execution ID in order
    """

    def _capture(pattern: str) -> Optional[str]:
        # Return the first capture group of `pattern`, or None on no match.
        found = re.search(pattern, pipeline_execution_arn)
        if found is None:
            return None
        return found.group(1)

    return (
        _capture(r"sagemaker:(.*?):"),
        _capture(r"pipeline/(.*?)/execution"),
        _capture(r"execution/(.*)"),
    )
sagemaker_orchestrator_entrypoint_config
Entrypoint configuration for ZenML Sagemaker pipeline steps.
SagemakerEntrypointConfiguration (StepEntrypointConfiguration)
Entrypoint configuration for ZenML Sagemaker pipeline steps.
The only purpose of this entrypoint configuration is to reconstruct the environment variables that exceed the maximum length of 256 characters allowed for Sagemaker Processor steps from their individual components.
Source code in zenml/integrations/aws/orchestrators/sagemaker_orchestrator_entrypoint_config.py
class SagemakerEntrypointConfiguration(StepEntrypointConfiguration):
    """Entrypoint configuration for ZenML Sagemaker pipeline steps.

    The only purpose of this entrypoint configuration is to reconstruct the
    environment variables that exceed the maximum length of 256 characters
    allowed for Sagemaker Processor steps from their individual components.
    """

    def run(self) -> None:
        """Runs the step.

        The environment is restored first so that the parent implementation
        runs with the reconstructed variables in place.
        """
        # Reconstruct the environment variables that exceed the maximum length
        # of 256 characters from their individual chunks
        reconstruct_environment_variables()
        # Run the step
        super().run()
run(self)
Runs the step.
Source code in zenml/integrations/aws/orchestrators/sagemaker_orchestrator_entrypoint_config.py
def run(self) -> None:
    """Runs the step.

    The environment is restored first so that the parent implementation
    runs with the reconstructed variables in place.
    """
    # Reconstruct the environment variables that exceed the maximum length
    # of 256 characters from their individual chunks
    reconstruct_environment_variables()
    # Run the step
    super().run()
service_connectors
special
AWS Service Connector.
aws_service_connector
AWS Service Connector.
The AWS Service Connector implements various authentication methods for AWS services:
- Explicit AWS secret key (access key, secret key)
- Explicit AWS STS tokens (access key, secret key, session token)
- IAM roles (i.e. generating temporary STS tokens on the fly by assuming an IAM role)
- IAM user federation tokens
- STS Session tokens
AWSAuthenticationMethods (StrEnum)
AWS Authentication methods.
Source code in zenml/integrations/aws/service_connectors/aws_service_connector.py
class AWSAuthenticationMethods(StrEnum):
    """AWS Authentication methods.

    The member values are the string identifiers that the connector's
    `auth_method` is compared against when authenticating.
    """

    # Credentials are resolved by boto3's default credentials provider.
    IMPLICIT = "implicit"
    # Long-term access key ID / secret access key pair.
    SECRET_KEY = "secret-key"
    # Pre-generated temporary STS token (key pair plus session token).
    STS_TOKEN = "sts-token"
    # Temporary credentials obtained by assuming an IAM role.
    IAM_ROLE = "iam-role"
    # Temporary credentials obtained through an STS session token request.
    SESSION_TOKEN = "session-token"
    # Temporary credentials obtained through an STS federation token request.
    FEDERATION_TOKEN = "federation-token"
AWSBaseConfig (AuthenticationConfig)
AWS base configuration.
Source code in zenml/integrations/aws/service_connectors/aws_service_connector.py
class AWSBaseConfig(AuthenticationConfig):
    """AWS base configuration.

    Holds the settings shared by every AWS authentication method.
    """

    # Required. Used as the region for boto3 sessions and clients.
    region: str = Field(
        title="AWS Region",
    )
    # Optional custom endpoint, forwarded to the boto3 client constructors.
    endpoint_url: Optional[str] = Field(
        default=None,
        title="AWS Endpoint URL",
    )
AWSImplicitConfig (AWSBaseConfig, AWSSessionPolicy)
AWS implicit configuration.
Source code in zenml/integrations/aws/service_connectors/aws_service_connector.py
class AWSImplicitConfig(AWSBaseConfig, AWSSessionPolicy):
    """AWS implicit configuration.

    Used with the implicit authentication method, where no explicit
    credentials are stored in the connector.
    """

    # Optional named profile handed to `boto3.Session(profile_name=...)`.
    profile_name: Optional[str] = Field(
        default=None,
        title="AWS Profile Name",
    )
    # When set, the implicit credentials are used to assume this IAM role.
    role_arn: Optional[str] = Field(
        default=None,
        title="Optional AWS IAM Role ARN to assume",
    )
AWSSecretKey (AuthenticationConfig)
AWS secret key credentials.
Source code in zenml/integrations/aws/service_connectors/aws_service_connector.py
class AWSSecretKey(AuthenticationConfig):
    """AWS secret key credentials.

    Both fields are secret strings; consumers read them through
    `get_secret_value()`.
    """

    # Public half of the long-term key pair.
    aws_access_key_id: PlainSerializedSecretStr = Field(
        title="AWS Access Key ID",
        description="An AWS access key ID associated with an AWS account or IAM user.",
    )
    # Private half of the long-term key pair.
    aws_secret_access_key: PlainSerializedSecretStr = Field(
        title="AWS Secret Access Key",
    )
AWSSecretKeyConfig (AWSBaseConfig, AWSSecretKey)
AWS secret key authentication configuration.
Source code in zenml/integrations/aws/service_connectors/aws_service_connector.py
class AWSSecretKeyConfig(AWSBaseConfig, AWSSecretKey):
    """AWS secret key authentication configuration.

    Combines the shared base settings (region, endpoint URL) with the
    long-term access key pair; it declares no fields of its own.
    """
AWSServiceConnector (ServiceConnector)
AWS service connector.
Source code in zenml/integrations/aws/service_connectors/aws_service_connector.py
class AWSServiceConnector(ServiceConnector):
"""AWS service connector."""
config: AWSBaseConfig
_account_id: Optional[str] = None
_session_cache: Dict[
Tuple[str, Optional[str], Optional[str]],
Tuple[boto3.Session, Optional[datetime.datetime]],
] = {}
@classmethod
def _get_connector_type(cls) -> ServiceConnectorTypeModel:
    """Get the service connector type specification.

    Returns:
        The service connector type specification.
    """
    # The specification is a module-level constant shared by all instances.
    return AWS_SERVICE_CONNECTOR_TYPE_SPEC
@property
def account_id(self) -> str:
    """Return the AWS account ID, fetching it from STS on first access.

    The value is cached on the instance after the first successful lookup.

    Returns:
        The AWS account ID.

    Raises:
        AuthorizationException: If the AWS account ID could not be
            determined.
    """
    # Fast path: the ID was already resolved by an earlier access.
    if self._account_id is not None:
        return self._account_id

    logger.debug("Getting account ID from AWS...")
    try:
        boto_session, _ = self.get_boto3_session(self.auth_method)
        identity = boto_session.client("sts").get_caller_identity()
    except (ClientError, BotoCoreError) as e:
        raise AuthorizationException(
            f"Failed to fetch the AWS account ID: {e}"
        ) from e

    self._account_id = identity["Account"]
    return self._account_id
def get_boto3_session(
    self,
    auth_method: str,
    resource_type: Optional[str] = None,
    resource_id: Optional[str] = None,
) -> Tuple[boto3.Session, Optional[datetime.datetime]]:
    """Return a cached or freshly authenticated boto3 session.

    Sessions are cached per (auth method, resource type, resource ID).
    A cached session is reused when it never expires, or when its
    expiration lies further in the future than the configured buffer.

    Args:
        auth_method: The authentication method to use.
        resource_type: The resource type to get a boto3 session for.
        resource_id: The resource ID to get a boto3 session for.

    Returns:
        A boto3 session for the specified resource and its expiration
        timestamp, if applicable.
    """
    cache_key = (auth_method, resource_type, resource_id)
    cached_entry = self._session_cache.get(cache_key)
    if cached_entry is not None:
        cached_session, cached_expiry = cached_entry
        if cached_expiry is None:
            # Non-expiring credentials can be reused indefinitely.
            return cached_session, None

        # Compare in UTC and treat tokens that expire soon as expired.
        current_time = datetime.datetime.now(datetime.timezone.utc)
        cached_expiry = cached_expiry.replace(tzinfo=datetime.timezone.utc)
        refresh_threshold = current_time + datetime.timedelta(
            minutes=BOTO3_SESSION_EXPIRATION_BUFFER
        )
        if cached_expiry > refresh_threshold:
            return cached_session, cached_expiry

    logger.debug(
        f"Creating boto3 session for auth method '{auth_method}', "
        f"resource type '{resource_type}' and resource ID "
        f"'{resource_id}'..."
    )
    new_session, new_expiry = self._authenticate(
        auth_method, resource_type, resource_id
    )
    self._session_cache[cache_key] = (new_session, new_expiry)
    return new_session, new_expiry
def get_ecr_client(self) -> BaseClient:
    """Build a boto3 ECR client for the connector's configured region.

    Raises:
        ValueError: If the service connector is not able to instantiate an
            ECR client.

    Returns:
        An ECR client.
    """
    configured_type = self.resource_type
    ecr_capable_types = {
        AWS_RESOURCE_TYPE,
        DOCKER_REGISTRY_RESOURCE_TYPE,
    }
    # A connector without a resource type is accepted; one restricted to
    # a resource type other than the two above is rejected.
    if configured_type and configured_type not in ecr_capable_types:
        raise ValueError(
            f"Unable to instantiate ECR client for a connector that is "
            f"configured to provide access to a '{self.resource_type}' "
            "resource type."
        )

    boto_session, _ = self.get_boto3_session(
        auth_method=self.auth_method,
        resource_type=DOCKER_REGISTRY_RESOURCE_TYPE,
        resource_id=self.config.region,
    )
    client_kwargs = {
        "region_name": self.config.region,
        "endpoint_url": self.config.endpoint_url,
    }
    return boto_session.client("ecr", **client_kwargs)
def _get_iam_policy(
    self,
    region_id: str,
    resource_type: Optional[str],
    resource_id: Optional[str] = None,
) -> Optional[str]:
    """Build the JSON inline IAM policy scoped to a resource.

    Policies are generated for S3 buckets, EKS clusters and ECR
    registries. For any other resource type no policy is produced.

    Args:
        region_id: The AWS region ID to get the IAM inline policy for.
        resource_type: The resource type to get the IAM inline policy for.
        resource_id: The resource ID to get the IAM inline policy for.

    Returns:
        The IAM inline policy to use for the specified resource.
    """
    if resource_type == S3_RESOURCE_TYPE:
        if resource_id:
            bucket_name = self._parse_s3_resource_id(resource_id)
            arns = [
                f"arn:aws:s3:::{bucket_name}",
                f"arn:aws:s3:::{bucket_name}/*",
            ]
        else:
            # No bucket given: cover every bucket and every object.
            arns = ["arn:aws:s3:::*", "arn:aws:s3:::*/*"]
        statements = [
            {
                "Sid": "AllowS3BucketAccess",
                "Effect": "Allow",
                "Action": [
                    "s3:ListBucket",
                    "s3:GetObject",
                    "s3:PutObject",
                    "s3:DeleteObject",
                    "s3:ListAllMyBuckets",
                    "s3:GetBucketVersioning",
                    "s3:ListBucketVersions",
                    "s3:DeleteObjectVersion",
                ],
                "Resource": arns,
            },
        ]
    elif resource_type == KUBERNETES_CLUSTER_RESOURCE_TYPE:
        if resource_id:
            eks_cluster = self._parse_eks_resource_id(resource_id)
            arns = [
                f"arn:aws:eks:{region_id}:*:cluster/{eks_cluster}",
            ]
        else:
            arns = [f"arn:aws:eks:{region_id}:*:cluster/*"]
        statements = [
            {
                "Sid": "AllowEKSClusterAccess",
                "Effect": "Allow",
                "Action": [
                    "eks:ListClusters",
                    "eks:DescribeCluster",
                ],
                "Resource": arns,
            },
        ]
    elif resource_type == DOCKER_REGISTRY_RESOURCE_TYPE:
        # The resource ID is not used here: the policy always covers all
        # repositories in the region.
        arns = [
            f"arn:aws:ecr:{region_id}:*:repository/*",
            f"arn:aws:ecr:{region_id}:*:repository",
        ]
        statements = [
            {
                "Sid": "AllowECRRepositoryAccess",
                "Effect": "Allow",
                "Action": [
                    "ecr:DescribeRegistry",
                    "ecr:DescribeRepositories",
                    "ecr:ListRepositories",
                    "ecr:BatchGetImage",
                    "ecr:DescribeImages",
                    "ecr:BatchCheckLayerAvailability",
                    "ecr:GetDownloadUrlForLayer",
                    "ecr:InitiateLayerUpload",
                    "ecr:UploadLayerPart",
                    "ecr:CompleteLayerUpload",
                    "ecr:PutImage",
                ],
                "Resource": arns,
            },
            {
                "Sid": "AllowECRRepositoryGetToken",
                "Effect": "Allow",
                "Action": [
                    "ecr:GetAuthorizationToken",
                ],
                "Resource": ["*"],
            },
        ]
    else:
        return None

    # Key order is kept as Version -> Statement so the serialized policy
    # text is stable.
    return json.dumps({"Version": "2012-10-17", "Statement": statements})
def _authenticate(
    self,
    auth_method: str,
    resource_type: Optional[str] = None,
    resource_id: Optional[str] = None,
) -> Tuple[boto3.Session, Optional[datetime.datetime]]:
    """Authenticate to AWS and return a boto3 session.

    The branch taken depends on `auth_method`:

    - implicit: default credentials provider, optionally followed by an
      assume-role call when `role_arn` is configured;
    - secret key / STS token: a session built directly from the configured
      credentials;
    - IAM role / session token / federation token: the configured key pair
      is used to request temporary credentials from STS.

    Args:
        auth_method: The authentication method to use.
        resource_type: The resource type to authenticate for.
        resource_id: The resource ID to authenticate for.

    Returns:
        An authenticated boto3 session and the expiration time of the
        temporary credentials if applicable.

    Raises:
        AuthorizationException: If the IAM role authentication method is
            used and the role cannot be assumed, if a federation or
            session token cannot be obtained, or if no implicit
            credentials are found.
        NotImplementedError: If the authentication method is not supported.
    """
    cfg = self.config
    # Extra keyword arguments (Policy / PolicyArns) for the STS calls.
    policy_kwargs: Dict[str, Any] = {}
    if auth_method == AWSAuthenticationMethods.IMPLICIT:
        self._check_implicit_auth_method_allowed()
        assert isinstance(cfg, AWSImplicitConfig)
        # Create a boto3 session and use the default credentials provider
        session = boto3.Session(
            profile_name=cfg.profile_name, region_name=cfg.region
        )
        if cfg.role_arn:
            # If an IAM role is configured, assume it
            policy = cfg.policy
            # Fall back to a generated, resource-scoped policy only when
            # neither an inline policy nor policy ARNs are configured.
            if not cfg.policy and not cfg.policy_arns:
                policy = self._get_iam_policy(
                    region_id=cfg.region,
                    resource_type=resource_type,
                    resource_id=resource_id,
                )
            if policy:
                policy_kwargs["Policy"] = policy
            elif cfg.policy_arns:
                policy_kwargs["PolicyArns"] = cfg.policy_arns
            sts = session.client("sts", region_name=cfg.region)
            session_name = "zenml-connector"
            if self.id:
                session_name += f"-{self.id}"
            try:
                response = sts.assume_role(
                    RoleArn=cfg.role_arn,
                    RoleSessionName=session_name,
                    DurationSeconds=self.expiration_seconds,
                    **policy_kwargs,
                )
            except (ClientError, BotoCoreError) as e:
                raise AuthorizationException(
                    f"Failed to assume IAM role {cfg.role_arn} "
                    f"using the implicit AWS credentials: {e}"
                ) from e
            # Replace the implicit session with one built from the
            # temporary credentials; no region_name is passed here.
            session = boto3.Session(
                aws_access_key_id=response["Credentials"]["AccessKeyId"],
                aws_secret_access_key=response["Credentials"][
                    "SecretAccessKey"
                ],
                aws_session_token=response["Credentials"]["SessionToken"],
            )
            expiration = response["Credentials"]["Expiration"]
            # Add the UTC timezone to the expiration time
            expiration = expiration.replace(tzinfo=datetime.timezone.utc)
            return session, expiration
        credentials = session.get_credentials()
        if not credentials:
            raise AuthorizationException(
                "Failed to get AWS credentials from the default provider. "
                "Please check your AWS configuration or attached IAM role."
            )
        if credentials.token:
            # Temporary credentials were generated. It's not possible to
            # determine the expiration time of the temporary credentials
            # from the boto3 session, so we assume the default IAM role
            # expiration date is used
            expiration_time = datetime.datetime.now(
                tz=datetime.timezone.utc
            ) + datetime.timedelta(
                seconds=DEFAULT_IAM_ROLE_TOKEN_EXPIRATION
            )
            return session, expiration_time
        # Long-term implicit credentials: no expiration to report.
        return session, None
    elif auth_method == AWSAuthenticationMethods.SECRET_KEY:
        assert isinstance(cfg, AWSSecretKeyConfig)
        # Create a boto3 session using long-term AWS credentials
        session = boto3.Session(
            aws_access_key_id=cfg.aws_access_key_id.get_secret_value(),
            aws_secret_access_key=cfg.aws_secret_access_key.get_secret_value(),
            region_name=cfg.region,
        )
        return session, None
    elif auth_method == AWSAuthenticationMethods.STS_TOKEN:
        assert isinstance(cfg, STSTokenConfig)
        # Create a boto3 session using a temporary AWS STS token
        session = boto3.Session(
            aws_access_key_id=cfg.aws_access_key_id.get_secret_value(),
            aws_secret_access_key=cfg.aws_secret_access_key.get_secret_value(),
            aws_session_token=cfg.aws_session_token.get_secret_value(),
            region_name=cfg.region,
        )
        # The expiration is taken from the connector itself, since the
        # token was generated outside of this method.
        return session, self.expires_at
    elif auth_method in [
        AWSAuthenticationMethods.IAM_ROLE,
        AWSAuthenticationMethods.SESSION_TOKEN,
        AWSAuthenticationMethods.FEDERATION_TOKEN,
    ]:
        assert isinstance(cfg, AWSSecretKey)
        # Create a boto3 session
        session = boto3.Session(
            aws_access_key_id=cfg.aws_access_key_id.get_secret_value(),
            aws_secret_access_key=cfg.aws_secret_access_key.get_secret_value(),
            region_name=cfg.region,
        )
        sts = session.client("sts", region_name=cfg.region)
        session_name = "zenml-connector"
        if self.id:
            session_name += f"-{self.id}"
        # Next steps are different for each authentication method
        # The IAM role and federation token authentication methods
        # accept a managed IAM policy that restricts/grants permissions.
        # If one isn't explicitly configured, we generate one based on the
        # resource specified by the resource type and ID (if present).
        if auth_method in [
            AWSAuthenticationMethods.IAM_ROLE,
            AWSAuthenticationMethods.FEDERATION_TOKEN,
        ]:
            assert isinstance(cfg, AWSSessionPolicy)
            policy = cfg.policy
            if not cfg.policy and not cfg.policy_arns:
                policy = self._get_iam_policy(
                    region_id=cfg.region,
                    resource_type=resource_type,
                    resource_id=resource_id,
                )
            if policy:
                policy_kwargs["Policy"] = policy
            elif cfg.policy_arns:
                policy_kwargs["PolicyArns"] = cfg.policy_arns
            if auth_method == AWSAuthenticationMethods.IAM_ROLE:
                assert isinstance(cfg, IAMRoleAuthenticationConfig)
                try:
                    response = sts.assume_role(
                        RoleArn=cfg.role_arn,
                        RoleSessionName=session_name,
                        DurationSeconds=self.expiration_seconds,
                        **policy_kwargs,
                    )
                except (ClientError, BotoCoreError) as e:
                    raise AuthorizationException(
                        f"Failed to assume IAM role {cfg.role_arn} "
                        f"using the AWS credentials configured in the "
                        f"connector: {e}"
                    ) from e
            else:
                assert isinstance(cfg, FederationTokenAuthenticationConfig)
                try:
                    # The federated user name is truncated to its first
                    # 32 characters.
                    response = sts.get_federation_token(
                        Name=session_name[:32],
                        DurationSeconds=self.expiration_seconds,
                        **policy_kwargs,
                    )
                except (ClientError, BotoCoreError) as e:
                    raise AuthorizationException(
                        "Failed to get federation token "
                        "using the AWS credentials configured in the "
                        f"connector: {e}"
                    ) from e
        else:
            assert isinstance(cfg, SessionTokenAuthenticationConfig)
            try:
                # No policy arguments are passed for session tokens.
                response = sts.get_session_token(
                    DurationSeconds=self.expiration_seconds,
                )
            except (ClientError, BotoCoreError) as e:
                raise AuthorizationException(
                    "Failed to get session token "
                    "using the AWS credentials configured in the "
                    f"connector: {e}"
                ) from e
        # Build the final session from the temporary credentials returned
        # by whichever STS call was made above; no region_name is passed.
        session = boto3.Session(
            aws_access_key_id=response["Credentials"]["AccessKeyId"],
            aws_secret_access_key=response["Credentials"][
                "SecretAccessKey"
            ],
            aws_session_token=response["Credentials"]["SessionToken"],
        )
        expiration = response["Credentials"]["Expiration"]
        # Add the UTC timezone to the expiration time
        expiration = expiration.replace(tzinfo=datetime.timezone.utc)
        return session, expiration
    raise NotImplementedError(
        f"Authentication method '{auth_method}' is not supported by "
        "the AWS connector."
    )
@classmethod
def _get_eks_bearer_token(
    cls,
    session: boto3.Session,
    cluster_id: str,
    region: str,
) -> str:
    """Create a bearer token for authenticating to the EKS API server.

    The token is a presigned STS `GetCallerIdentity` URL carrying the
    cluster name in the `x-k8s-aws-id` header, base64-url encoded with
    the padding removed and prefixed with `k8s-aws-v1.`.

    Based on: https://github.com/kubernetes-sigs/aws-iam-authenticator/blob/master/README.md#api-authorization-from-outside-a-cluster

    Args:
        session: An authenticated boto3 session to use for generating the
            token.
        cluster_id: The name of the EKS cluster.
        region: The AWS region the EKS cluster is in.

    Returns:
        A bearer token for authenticating to the EKS API server.
    """
    presign_ttl_seconds = 60

    sts_client = session.client("sts", region_name=region)
    request_signer = RequestSigner(
        sts_client.meta.service_model.service_id,
        region,
        "sts",
        "v4",
        session.get_credentials(),
        session.events,
    )

    request_spec = {
        "method": "GET",
        "url": f"https://sts.{region}.amazonaws.com/?Action=GetCallerIdentity&Version=2011-06-15",
        "body": {},
        "headers": {"x-k8s-aws-id": cluster_id},
        "context": {},
    }
    presigned_url = request_signer.generate_presigned_url(
        request_spec,
        region_name=region,
        expires_in=presign_ttl_seconds,
        operation_name="",
    )

    encoded_url = base64.urlsafe_b64encode(presigned_url.encode("utf-8"))
    # remove any base64 encoding padding:
    unpadded = re.sub(r"=*", "", encoded_url.decode("utf-8"))
    return "k8s-aws-v1." + unpadded
def _parse_s3_resource_id(self, resource_id: str) -> str:
"""Validate and convert an S3 resource ID to an S3 bucket name.
Args:
resource_id: The resource ID to convert.
Returns:
The S3 bucket name.
Raises:
ValueError: If the provided resource ID is not a valid S3 bucket
name, ARN or URI.
"""
# The resource ID could mean different things:
#
# - an S3 bucket ARN
# - an S3 bucket URI
# - the S3 bucket name
#
# We need to extract the bucket name from the provided resource ID
bucket_name: Optional[str] = None
if re.match(
r"^arn:aws:s3:::[a-z0-9][a-z0-9\-\.]{1,61}[a-z0-9](/.*)*$",
resource_id,
):
# The resource ID is an S3 bucket ARN
bucket_name = resource_id.split(":")[-1].split("/")[0]
elif re.match(
r"^s3://[a-z0-9][a-z0-9\-\.]{1,61}[a-z0-9](/.*)*$",
resource_id,
):
# The resource ID is an S3 bucket URI
bucket_name = resource_id.split("/")[2]
elif re.match(
r"^[a-z0-9][a-z0-9\-\.]{1,61}[a-z0-9]$",
resource_id,
):
# The resource ID is the S3 bucket name
bucket_name = resource_id
else:
raise ValueError(
f"Invalid resource ID for an S3 bucket: {resource_id}. "
f"Supported formats are:\n"
f"S3 bucket ARN: arn:aws:s3:::<bucket-name>\n"
f"S3 bucket URI: s3://<bucket-name>\n"
f"S3 bucket name: <bucket-name>"
)
return bucket_name
def _parse_ecr_resource_id(
self,
resource_id: str,
) -> str:
"""Validate and convert an ECR resource ID to an ECR registry ID.
Args:
resource_id: The resource ID to convert.
Returns:
The ECR registry ID (AWS account ID).
Raises:
ValueError: If the provided resource ID is not a valid ECR
repository ARN or URI.
"""
# The resource ID could mean different things:
#
# - an ECR repository ARN
# - an ECR repository URI
#
# We need to extract the region ID and registry ID from
# the provided resource ID
config_region_id = self.config.region
region_id: Optional[str] = None
if re.match(
r"^arn:aws:ecr:[a-z0-9-]+:\d{12}:repository(/.+)*$",
resource_id,
):
# The resource ID is an ECR repository ARN
registry_id = resource_id.split(":")[4]
region_id = resource_id.split(":")[3]
elif re.match(
r"^(http[s]?://)?\d{12}\.dkr\.ecr\.[a-z0-9-]+\.amazonaws\.com(/.+)*$",
resource_id,
):
# The resource ID is an ECR repository URI
registry_id = resource_id.split(".")[0].split("/")[-1]
region_id = resource_id.split(".")[3]
else:
raise ValueError(
f"Invalid resource ID for a ECR registry: {resource_id}. "
f"Supported formats are:\n"
f"ECR repository ARN: arn:aws:ecr:<region>:<account-id>:repository[/<repository-name>]\n"
f"ECR repository URI: [https://]<account-id>.dkr.ecr.<region>.amazonaws.com[/<repository-name>]"
)
# If the connector is configured with a region and the resource ID
# is an ECR repository ARN or URI that specifies a different region
# we raise an error
if region_id and region_id != config_region_id:
raise ValueError(
f"The AWS region for the {resource_id} ECR repository region "
f"'{region_id}' does not match the region configured in "
f"the connector: '{config_region_id}'."
)
return registry_id
def _parse_eks_resource_id(self, resource_id: str) -> str:
"""Validate and convert an EKS resource ID to an AWS region and EKS cluster name.
Args:
resource_id: The resource ID to convert.
Returns:
The EKS cluster name.
Raises:
ValueError: If the provided resource ID is not a valid EKS cluster
name or ARN.
"""
# The resource ID could mean different things:
#
# - an EKS cluster ARN
# - an EKS cluster ID
#
# We need to extract the cluster name and region ID from the
# provided resource ID
config_region_id = self.config.region
cluster_name: Optional[str] = None
region_id: Optional[str] = None
if re.match(
r"^arn:aws:eks:[a-z0-9-]+:\d{12}:cluster/[0-9A-Za-z][A-Za-z0-9\-_]*$",
resource_id,
):
# The resource ID is an EKS cluster ARN
cluster_name = resource_id.split("/")[-1]
region_id = resource_id.split(":")[3]
elif re.match(
r"^[0-9A-Za-z][A-Za-z0-9\-_]*$",
resource_id,
):
# Assume the resource ID is an EKS cluster name
cluster_name = resource_id
else:
raise ValueError(
f"Invalid resource ID for a EKS cluster: {resource_id}. "
f"Supported formats are:\n"
f"EKS cluster ARN: arn:aws:eks:<region>:<account-id>:cluster/<cluster-name>\n"
f"ECR cluster name: <cluster-name>"
)
# If the connector is configured with a region and the resource ID
# is an EKS registry ARN or URI that specifies a different region
# we raise an error
if region_id and region_id != config_region_id:
raise ValueError(
f"The AWS region for the {resource_id} EKS cluster "
f"({region_id}) does not match the region configured in "
f"the connector ({config_region_id})."
)
return cluster_name
def _canonical_resource_id(
    self, resource_type: str, resource_id: str
) -> str:
    """Normalize a resource ID into the canonical form for its type.

    S3 IDs become `s3://<bucket>`, EKS IDs become the bare cluster name
    and ECR IDs become the registry host name. IDs of any other resource
    type are returned unchanged.

    Args:
        resource_type: The resource type to canonicalize.
        resource_id: The resource ID to canonicalize.

    Returns:
        The canonical resource ID.
    """
    if resource_type == S3_RESOURCE_TYPE:
        bucket_name = self._parse_s3_resource_id(resource_id)
        return f"s3://{bucket_name}"

    if resource_type == KUBERNETES_CLUSTER_RESOURCE_TYPE:
        return self._parse_eks_resource_id(resource_id)

    if resource_type == DOCKER_REGISTRY_RESOURCE_TYPE:
        account = self._parse_ecr_resource_id(
            resource_id,
        )
        return f"{account}.dkr.ecr.{self.config.region}.amazonaws.com"

    # Unknown resource types have no canonical form of their own.
    return resource_id
def _get_default_resource_id(self, resource_type: str) -> str:
    """Return the default resource ID for a single-instance resource type.

    Args:
        resource_type: The type of the resource to get a default resource ID
            for. Only called with resource types that do not support
            multiple instances.

    Returns:
        The default resource ID for the resource type: the configured
        region for generic AWS resources, or the registry host name for
        ECR.

    Raises:
        RuntimeError: If no default resource ID is available for the
            given resource type.
    """
    region = self.config.region
    if resource_type == AWS_RESOURCE_TYPE:
        return region

    if resource_type == DOCKER_REGISTRY_RESOURCE_TYPE:
        # we need to get the account ID (same as registry ID) from the
        # caller identity AWS service
        registry_account = self.account_id
        return f"{registry_account}.dkr.ecr.{self.config.region}.amazonaws.com"

    raise RuntimeError(
        f"Default resource ID for '{resource_type}' not available."
    )
def _connect_to_resource(
    self,
    **kwargs: Any,
) -> Any:
    """Authenticate and connect to the configured AWS resource.

    Depending on the resource type configured in the connector, this
    returns:

    - the authenticated boto3 session, for the generic AWS resource type
    - a boto3 S3 client, for the S3 resource type

    The Docker and Kubernetes resource types cannot be connected to
    directly; a connector client has to be generated for them instead.

    Args:
        kwargs: Additional implementation specific keyword arguments.
            They are accepted but not used by this implementation.

    Returns:
        A boto3 session for AWS generic resources and a boto3 S3 client
        for S3 resources.

    Raises:
        NotImplementedError: If the connector instance does not support
            directly connecting to the indicated resource type.
    """
    target_type = self.resource_type
    target_id = self.resource_id
    assert target_type is not None
    assert target_id is not None

    # Authentication against AWS always comes first, whatever the
    # resource type is.
    boto_session, _expires_at = self.get_boto3_session(
        self.auth_method,
        resource_type=target_type,
        resource_id=target_id,
    )

    if target_type == AWS_RESOURCE_TYPE:
        return boto_session

    if target_type != S3_RESOURCE_TYPE:
        raise NotImplementedError(
            f"Connecting to {target_type} resources is not directly "
            "supported by the AWS connector. Please call the "
            f"`get_connector_client` method to get a {target_type} connector "
            "instance for the resource."
        )

    # Called only for its validation side effect: the resource ID must be
    # a valid S3 bucket reference.
    self._parse_s3_resource_id(target_id)

    s3_client = boto_session.client(
        "s3",
        region_name=self.config.region,
        endpoint_url=self.config.endpoint_url,
    )
    # Some consumers need the raw credentials to configure 3rd party
    # services and the S3 client does not expose them, so they are
    # attached to the client object for later retrieval.
    s3_client.credentials = boto_session.get_credentials()
    return s3_client
def _configure_local_client(
    self,
    profile_name: Optional[str] = None,
    **kwargs: Any,
) -> None:
    """Configure a local client to authenticate and connect to a resource.

    This method uses the connector's configuration to write an AWS SDK
    profile into the local `~/.aws/credentials` file, using the
    aws-profile-manager package. Only the generic AWS and the S3 resource
    types are supported.

    Args:
        profile_name: The name of the AWS profile to use. If not specified,
            a profile name is generated based on the first 8 digits of the
            connector's UUID in the form 'zenml-<uuid[:8]>'. If a profile
            with the given or generated name already exists, the profile is
            overwritten.
        kwargs: Additional implementation specific keyword arguments to use
            to configure the client. Not used by this implementation.

    Raises:
        NotImplementedError: If the connector instance does not support
            local configuration for the configured resource type or
            authentication method.
        AuthorizationException: If the authenticated session does not
            carry any credentials.
    """
    resource_type = self.resource_type

    if resource_type not in [AWS_RESOURCE_TYPE, S3_RESOURCE_TYPE]:
        raise NotImplementedError(
            f"Configuring the local client for {resource_type} resources is "
            "not directly supported by the AWS connector. Please call the "
            f"`get_connector_client` method to get a {resource_type} connector "
            "instance for the resource."
        )

    session, _ = self.get_boto3_session(
        self.auth_method,
        resource_type=resource_type,
        resource_id=self.resource_id,
    )

    # Generate a profile name based on the first 8 digits from the
    # connector UUID, if one is not supplied
    aws_profile_name = profile_name or f"zenml-{str(self.id)[:8]}"

    common = Common()
    users_home = common.get_users_home()
    all_profiles = common.get_all_profiles(users_home)

    credentials = session.get_credentials()
    if credentials is None:
        # Fail with a clear error instead of an AttributeError on the
        # attribute accesses below.
        raise AuthorizationException(
            "Could not determine the AWS credentials from the "
            "authenticated session"
        )

    profile = {
        "region": self.config.region,
        "aws_access_key_id": credentials.access_key,
        "aws_secret_access_key": credentials.secret_key,
    }
    # Temporary credentials also need the session token to be usable.
    if credentials.token:
        profile["aws_session_token"] = credentials.token
    all_profiles[aws_profile_name] = profile

    aws_dir = os.path.join(users_home, ".aws")
    aws_credentials_path = os.path.join(aws_dir, "credentials")

    # Create the parent directory if needed. `exist_ok=True` avoids the
    # check-then-create race of testing `isdir` before `makedirs`.
    os.makedirs(aws_dir, exist_ok=True)

    # Make sure the credentials file exists before it is rewritten. When
    # the file is created here, it is created with owner-only (0o600)
    # permissions. An existing file is left untouched by this step.
    with os.fdopen(
        os.open(aws_credentials_path, os.O_WRONLY | os.O_CREAT, 0o600),
        "w",
    ):
        pass

    # Write the credentials to the file
    common.rewrite_credentials_file(all_profiles, users_home)

    logger.info(f"Configured local AWS SDK profile '{aws_profile_name}'.")
@classmethod
def _auto_configure(
    cls,
    auth_method: Optional[str] = None,
    resource_type: Optional[str] = None,
    resource_id: Optional[str] = None,
    region_name: Optional[str] = None,
    profile_name: Optional[str] = None,
    role_arn: Optional[str] = None,
    **kwargs: Any,
) -> "AWSServiceConnector":
    """Auto-configure the connector.

    Instantiate an AWS connector with a configuration extracted from the
    authentication configuration available in the environment (e.g.
    environment variables or local AWS client/SDK configuration files).

    The configuration that is produced depends on the requested
    authentication method and on the credentials found by boto3:

    - implicit method: no credentials are extracted, only the profile
      name and region are stored
    - temporary credentials obtained through an assumed role: the role
      ARN and the long-lived keys of the source profile are stored
      (IAM role method), unless the STS token method was requested
    - other temporary credentials: stored as-is (STS token method)
    - long-lived credentials: stored according to the requested method,
      defaulting to the IAM role method if a role ARN was supplied and to
      the session token method otherwise

    Args:
        auth_method: The particular authentication method to use. If not
            specified, the connector implementation must decide which
            authentication method to use or raise an exception.
        resource_type: The type of resource to configure.
        resource_id: The ID of the resource to configure. It is discarded
            when the resource type is the generic AWS resource type or is
            not specified.
        region_name: The name of the AWS region to use. If not specified,
            the implicit region is used.
        profile_name: The name of the AWS profile to use. If not specified,
            the implicit profile is used.
        role_arn: The ARN of the AWS role to assume. Applicable only if the
            IAM role authentication method is specified or long-term
            credentials are discovered.
        kwargs: Additional implementation specific keyword arguments to
            use. Not used by this implementation.

    Returns:
        An AWS connector instance configured with authentication credentials
        automatically extracted from the environment.

    Raises:
        NotImplementedError: If the connector implementation does not
            support auto-configuration for the specified authentication
            method.
        ValueError: If the supplied arguments are not valid.
        AuthorizationException: If no AWS credentials can be loaded from
            the environment.
    """
    auth_config: AWSBaseConfig
    # At most one of these two is set below: `expiration_seconds` when the
    # stored credentials are long-lived and temporary ones are to be
    # generated from them, `expires_at` when the stored credentials are
    # themselves temporary.
    expiration_seconds: Optional[int] = None
    expires_at: Optional[datetime.datetime] = None
    if auth_method == AWSAuthenticationMethods.IMPLICIT:
        # No session is created in this branch, so the region cannot be
        # discovered and must be supplied by the caller.
        if region_name is None:
            raise ValueError(
                "The AWS region name must be specified when using the "
                "implicit authentication method"
            )
        auth_config = AWSImplicitConfig(
            profile_name=profile_name,
            region=region_name,
        )
    else:
        # Initialize an AWS session with the default configuration loaded
        # from the environment.
        session = boto3.Session(
            profile_name=profile_name, region_name=region_name
        )
        region_name = region_name or session.region_name
        if not region_name:
            raise ValueError(
                "The AWS region name was not specified and could not "
                "be determined from the AWS session"
            )
        # NOTE(review): this reads a private attribute of the boto3
        # session (the underlying botocore session).
        endpoint_url = session._session.get_config_variable("endpoint_url")
        # Extract the AWS credentials from the session and store them in
        # the connector secrets.
        credentials = session.get_credentials()
        if not credentials:
            raise AuthorizationException(
                "Could not determine the AWS credentials from the "
                "environment"
            )
        if credentials.token:
            # The session picked up temporary STS credentials
            # Only the STS token and IAM role methods (or no explicit
            # method at all) can be served from temporary credentials.
            if auth_method and auth_method not in [
                None,
                AWSAuthenticationMethods.STS_TOKEN,
                AWSAuthenticationMethods.IAM_ROLE,
            ]:
                raise NotImplementedError(
                    f"The specified authentication method '{auth_method}' "
                    "could not be used to auto-configure the connector. "
                )
            if (
                credentials.method == "assume-role"
                and auth_method != AWSAuthenticationMethods.STS_TOKEN
            ):
                # In the special case of IAM role authentication, the
                # credentials in the boto3 session are the temporary STS
                # credentials instead of the long-lived credentials, and the
                # role ARN is not known. We have to dig deeper into the
                # botocore session internals to retrieve the role ARN and
                # the original long-lived credentials.
                botocore_session = session._session
                profile_config = botocore_session.get_scoped_config()
                source_profile = profile_config.get("source_profile")
                # The `role_arn` argument is overwritten here with the
                # value found in the profile configuration.
                role_arn = profile_config.get("role_arn")
                profile_map = botocore_session._build_profile_map()
                if not (
                    role_arn
                    and source_profile
                    and source_profile in profile_map
                ):
                    raise AuthorizationException(
                        "Could not determine the IAM role ARN and source "
                        "profile credentials from the environment"
                    )
                auth_method = AWSAuthenticationMethods.IAM_ROLE
                source_profile_config = profile_map[source_profile]
                # The long-lived keys are read from the source profile,
                # not from the (temporary) session credentials.
                auth_config = IAMRoleAuthenticationConfig(
                    region=region_name,
                    endpoint_url=endpoint_url,
                    aws_access_key_id=source_profile_config.get(
                        "aws_access_key_id"
                    ),
                    aws_secret_access_key=source_profile_config.get(
                        "aws_secret_access_key",
                    ),
                    role_arn=role_arn,
                )
                expiration_seconds = DEFAULT_IAM_ROLE_TOKEN_EXPIRATION
            else:
                # Without an assumed role there is no role ARN to store,
                # so the IAM role method cannot be honored.
                if auth_method == AWSAuthenticationMethods.IAM_ROLE:
                    raise NotImplementedError(
                        f"The specified authentication method "
                        f"'{auth_method}' could not be used to "
                        "auto-configure the connector."
                    )
                # Temporary credentials were picked up from the local
                # configuration. It's not possible to determine the
                # expiration time of the temporary credentials from the
                # boto3 session, so we assume the default IAM role
                # expiration period is used
                expires_at = datetime.datetime.now(
                    tz=datetime.timezone.utc
                ) + datetime.timedelta(
                    seconds=DEFAULT_IAM_ROLE_TOKEN_EXPIRATION
                )
                auth_method = AWSAuthenticationMethods.STS_TOKEN
                auth_config = STSTokenConfig(
                    region=region_name,
                    endpoint_url=endpoint_url,
                    aws_access_key_id=credentials.access_key,
                    aws_secret_access_key=credentials.secret_key,
                    aws_session_token=credentials.token,
                )
        else:
            # The session picked up long-lived credentials
            if not auth_method:
                if role_arn:
                    auth_method = AWSAuthenticationMethods.IAM_ROLE
                else:
                    # If no authentication method was specified, use the
                    # session token as a default recommended authentication
                    # method to be used with long-lived credentials.
                    auth_method = AWSAuthenticationMethods.SESSION_TOKEN
            # NOTE(review): `region_name` was already resolved and
            # validated right after the session was created, so this
            # re-check looks redundant.
            region_name = region_name or session.region_name
            if not region_name:
                raise ValueError(
                    "The AWS region name was not specified and could not "
                    "be determined from the AWS session"
                )
            if auth_method == AWSAuthenticationMethods.STS_TOKEN:
                # Generate a session token from the long-lived credentials
                # and store it in the connector secrets.
                sts_client = session.client("sts")
                response = sts_client.get_session_token(
                    DurationSeconds=DEFAULT_STS_TOKEN_EXPIRATION
                )
                # From here on `credentials` is the dictionary returned by
                # the STS call, not the session credentials object.
                credentials = response["Credentials"]
                auth_config = STSTokenConfig(
                    region=region_name,
                    endpoint_url=endpoint_url,
                    aws_access_key_id=credentials["AccessKeyId"],
                    aws_secret_access_key=credentials["SecretAccessKey"],
                    aws_session_token=credentials["SessionToken"],
                )
                # The expiration is computed locally from the requested
                # duration rather than read from the STS response.
                expires_at = datetime.datetime.now(
                    tz=datetime.timezone.utc
                ) + datetime.timedelta(
                    seconds=DEFAULT_STS_TOKEN_EXPIRATION
                )
            elif auth_method == AWSAuthenticationMethods.IAM_ROLE:
                if not role_arn:
                    raise ValueError(
                        "The ARN of the AWS role to assume must be "
                        "specified when using the IAM role authentication "
                        "method."
                    )
                auth_config = IAMRoleAuthenticationConfig(
                    region=region_name,
                    endpoint_url=endpoint_url,
                    aws_access_key_id=credentials.access_key,
                    aws_secret_access_key=credentials.secret_key,
                    role_arn=role_arn,
                )
                expiration_seconds = DEFAULT_IAM_ROLE_TOKEN_EXPIRATION
            elif auth_method == AWSAuthenticationMethods.SECRET_KEY:
                # Long-lived keys are stored as-is; no expiration applies.
                auth_config = AWSSecretKeyConfig(
                    region=region_name,
                    endpoint_url=endpoint_url,
                    aws_access_key_id=credentials.access_key,
                    aws_secret_access_key=credentials.secret_key,
                )
            elif auth_method == AWSAuthenticationMethods.SESSION_TOKEN:
                auth_config = SessionTokenAuthenticationConfig(
                    region=region_name,
                    endpoint_url=endpoint_url,
                    aws_access_key_id=credentials.access_key,
                    aws_secret_access_key=credentials.secret_key,
                )
                expiration_seconds = DEFAULT_STS_TOKEN_EXPIRATION
            else:  # auth_method is AWSAuthenticationMethods.FEDERATION_TOKEN
                # NOTE(review): this branch is a catch-all; any value not
                # matched above is treated as the federation token method.
                auth_config = FederationTokenAuthenticationConfig(
                    region=region_name,
                    endpoint_url=endpoint_url,
                    aws_access_key_id=credentials.access_key,
                    aws_secret_access_key=credentials.secret_key,
                )
                expiration_seconds = DEFAULT_STS_TOKEN_EXPIRATION
    # The resource ID is only kept for resource types other than the
    # generic AWS resource type.
    return cls(
        auth_method=auth_method,
        resource_type=resource_type,
        resource_id=resource_id
        if resource_type not in [AWS_RESOURCE_TYPE, None]
        else None,
        expiration_seconds=expiration_seconds,
        expires_at=expires_at,
        config=auth_config,
    )
def _verify(
    self,
    resource_type: Optional[str] = None,
    resource_id: Optional[str] = None,
) -> List[str]:
    """Verify access and list the resources that the connector can reach.

    The AWS account is always verified first by calling the STS caller
    identity API. The rest depends on the resource type:

    - generic AWS: the supplied resource ID is returned
    - S3: all buckets are listed, or the given bucket is checked
    - Docker registry: at least one ECR repository must be accessible
    - Kubernetes: all EKS clusters are listed, or the given cluster is
      described

    Args:
        resource_type: The type of the resource to verify. If omitted,
            only the AWS account access is verified.
        resource_id: The ID of the resource to connect to. May be omitted
            for resource types that support multiple instances, in which
            case all accessible resources of that type are listed.

    Returns:
        The list of resources IDs identifying the resources that the
        connector can access. The list is empty if the resource type is
        not specified or is not one of the types handled here.

    Raises:
        AuthorizationException: If the connector cannot authenticate or
            access the specified resource.
    """
    # An unspecified resource type is treated like the generic AWS one
    # for the purpose of authentication.
    session, _ = self.get_boto3_session(
        self.auth_method,
        resource_type=resource_type or AWS_RESOURCE_TYPE,
        resource_id=resource_id,
    )
    assert isinstance(session, boto3.Session)

    def _regional_client(service_name: str) -> Any:
        # Clients for resource checks share the connector's region and
        # endpoint URL.
        return session.client(
            service_name,
            region_name=self.config.region,
            endpoint_url=self.config.endpoint_url,
        )

    # Verify that the AWS account is accessible
    identity_client = session.client("sts")
    try:
        identity_client.get_caller_identity()
    except (ClientError, BotoCoreError) as err:
        msg = f"failed to verify AWS account access: {err}"
        logger.debug(msg)
        raise AuthorizationException(msg) from err

    if not resource_type:
        return []

    if resource_type == AWS_RESOURCE_TYPE:
        assert resource_id is not None
        return [resource_id]

    if resource_type == S3_RESOURCE_TYPE:
        s3 = _regional_client("s3")

        if resource_id:
            # Check that the one requested bucket is reachable
            bucket_name = self._parse_s3_resource_id(resource_id)
            try:
                s3.head_bucket(Bucket=bucket_name)
            except (ClientError, BotoCoreError) as e:
                msg = f"failed to fetch S3 bucket {bucket_name}: {e}"
                logger.error(msg)
                raise AuthorizationException(msg) from e
            else:
                return [resource_id]

        # No bucket requested: enumerate every bucket
        try:
            listing = s3.list_buckets()
        except (ClientError, BotoCoreError) as e:
            msg = f"failed to list S3 buckets: {e}"
            logger.error(msg)
            raise AuthorizationException(msg) from e
        return [f"s3://{bucket['Name']}" for bucket in listing["Buckets"]]

    if resource_type == DOCKER_REGISTRY_RESOURCE_TYPE:
        assert resource_id is not None
        ecr = _regional_client("ecr")

        # Access is considered verified if at least one repository can be
        # listed.
        try:
            described = ecr.describe_repositories()
        except (ClientError, BotoCoreError) as e:
            msg = f"failed to list ECR repositories: {e}"
            logger.error(msg)
            raise AuthorizationException(msg) from e

        if not described["repositories"]:
            raise AuthorizationException(
                "the AWS connector does not have access to any ECR "
                "repositories. Please adjust the AWS permissions "
                "associated with the authentication credentials to "
                "include access to at least one ECR repository."
            )
        return [resource_id]

    if resource_type == KUBERNETES_CLUSTER_RESOURCE_TYPE:
        eks = _regional_client("eks")

        if resource_id:
            # Check that the one requested cluster can be described
            cluster_name = self._parse_eks_resource_id(resource_id)
            try:
                eks.describe_cluster(name=cluster_name)
            except (ClientError, BotoCoreError) as e:
                msg = f"Failed to fetch EKS cluster {cluster_name}: {e}"
                logger.error(msg)
                raise AuthorizationException(msg) from e
            return [resource_id]

        # No cluster requested: enumerate every cluster
        try:
            cluster_listing = eks.list_clusters()
        except (ClientError, BotoCoreError) as e:
            msg = f"Failed to list EKS clusters: {e}"
            logger.error(msg)
            raise AuthorizationException(msg) from e
        return cast(List[str], cluster_listing["clusters"])

    return []
def _get_connector_client(
    self,
    resource_type: str,
    resource_id: str,
) -> "ServiceConnector":
    """Get a connector instance that can be used to connect to a resource.

    This method generates a client-side connector instance that can be used
    to connect to a resource of the given type. The client-side connector
    is configured with temporary AWS credentials extracted from the
    current connector and, depending on resource type, it may also be
    of a different connector type:

    - a Kubernetes connector for Kubernetes clusters
    - a Docker connector for Docker registries

    Args:
        resource_type: The type of the resources to connect to.
        resource_id: The ID of a particular resource to connect to.

    Returns:
        An AWS, Kubernetes or Docker connector instance that can be used to
        connect to the specified resource. For the generic AWS and S3
        resource types, the current connector itself is returned when it
        uses the secret key or STS token method and is already configured
        for exactly this resource.

    Raises:
        AuthorizationException: If authentication failed.
        ValueError: If the resource type is not supported.
        RuntimeError: If the Kubernetes connector is not installed and the
            resource type is Kubernetes.
    """
    connector_name = ""
    if self.name:
        connector_name = self.name
    if resource_id:
        connector_name += f" ({resource_type} | {resource_id} client)"
    else:
        connector_name += f" ({resource_type} client)"

    logger.debug(f"Getting connector client for {connector_name}")

    if resource_type in [AWS_RESOURCE_TYPE, S3_RESOURCE_TYPE]:
        auth_method = self.auth_method
        if self.auth_method in [
            AWSAuthenticationMethods.SECRET_KEY,
            AWSAuthenticationMethods.STS_TOKEN,
        ]:
            if (
                self.resource_type == resource_type
                and self.resource_id == resource_id
            ):
                # If the requested type and resource ID are the same as
                # those configured, we can return the current connector
                # instance because it's fully formed and ready to use
                # to connect to the specified resource
                return self

            # The secret key and STS token authentication methods do not
            # involve generating temporary credentials, so we can just
            # use the current connector configuration
            config = self.config
            expires_at = self.expires_at
        else:
            # Get an authenticated boto3 session
            session, expires_at = self.get_boto3_session(
                self.auth_method,
                resource_type=resource_type,
                resource_id=resource_id,
            )
            assert isinstance(session, boto3.Session)
            credentials = session.get_credentials()

            if (
                self.auth_method == AWSAuthenticationMethods.IMPLICIT
                and credentials.token is None
            ):
                # The implicit authentication method may involve picking up
                # long-lived credentials from the environment
                auth_method = AWSAuthenticationMethods.SECRET_KEY
                config = AWSSecretKeyConfig(
                    region=self.config.region,
                    endpoint_url=self.config.endpoint_url,
                    aws_access_key_id=credentials.access_key,
                    aws_secret_access_key=credentials.secret_key,
                )
            else:
                assert credentials.token is not None

                # Use the temporary credentials extracted from the boto3
                # session
                auth_method = AWSAuthenticationMethods.STS_TOKEN
                config = STSTokenConfig(
                    region=self.config.region,
                    endpoint_url=self.config.endpoint_url,
                    aws_access_key_id=credentials.access_key,
                    aws_secret_access_key=credentials.secret_key,
                    aws_session_token=credentials.token,
                )

        # Create a client-side AWS connector instance that is fully formed
        # and ready to use to connect to the specified resource (i.e. has
        # all the necessary configuration and credentials, a resource type
        # and a resource ID where applicable)
        return AWSServiceConnector(
            id=self.id,
            name=connector_name,
            auth_method=auth_method,
            resource_type=resource_type,
            resource_id=resource_id,
            config=config,
            expires_at=expires_at,
        )

    if resource_type == DOCKER_REGISTRY_RESOURCE_TYPE:
        assert resource_id is not None

        # Get an authenticated boto3 session
        session, expires_at = self.get_boto3_session(
            self.auth_method,
            resource_type=resource_type,
            resource_id=resource_id,
        )
        assert isinstance(session, boto3.Session)

        registry_id = self._parse_ecr_resource_id(resource_id)

        ecr_client = session.client(
            "ecr",
            region_name=self.config.region,
            endpoint_url=self.config.endpoint_url,
        )

        assert isinstance(ecr_client, BaseClient)
        assert registry_id is not None

        try:
            auth_token = ecr_client.get_authorization_token(
                registryIds=[
                    registry_id,
                ]
            )
        except (ClientError, BotoCoreError) as e:
            raise AuthorizationException(
                f"Failed to get authorization token from ECR: {e}"
            ) from e

        token = auth_token["authorizationData"][0]["authorizationToken"]
        endpoint = auth_token["authorizationData"][0]["proxyEndpoint"]
        # The token is base64 encoded and has the format
        # "username:password". Split on the first colon only, so that a
        # password containing a colon does not break the unpacking.
        username, password = (
            base64.b64decode(token).decode("utf-8").split(":", 1)
        )

        # Create a client-side Docker connector instance with the temporary
        # Docker credentials
        return DockerServiceConnector(
            id=self.id,
            name=connector_name,
            auth_method=DockerAuthenticationMethods.PASSWORD,
            resource_type=resource_type,
            config=DockerConfiguration(
                username=username,
                password=password,
                registry=endpoint,
            ),
            expires_at=expires_at,
        )

    if resource_type == KUBERNETES_CLUSTER_RESOURCE_TYPE:
        assert resource_id is not None

        # Get an authenticated boto3 session
        session, _ = self.get_boto3_session(
            self.auth_method,
            resource_type=resource_type,
            resource_id=resource_id,
        )
        assert isinstance(session, boto3.Session)

        cluster_name = self._parse_eks_resource_id(resource_id)

        # Get a boto3 EKS client
        eks_client = session.client(
            "eks",
            region_name=self.config.region,
            endpoint_url=self.config.endpoint_url,
        )

        assert isinstance(eks_client, BaseClient)

        try:
            cluster = eks_client.describe_cluster(name=cluster_name)
        except (ClientError, BotoCoreError) as e:
            raise AuthorizationException(
                f"Failed to get EKS cluster {cluster_name}: {e}"
            ) from e

        try:
            user_token = self._get_eks_bearer_token(
                session=session,
                cluster_id=cluster_name,
                region=self.config.region,
            )
        except (ClientError, BotoCoreError) as e:
            raise AuthorizationException(
                f"Failed to get EKS bearer token: {e}"
            ) from e

        # Kubernetes authentication tokens issued by AWS EKS have a fixed
        # expiration time of 15 minutes
        # source: https://aws.github.io/aws-eks-best-practices/security/docs/iam/#controlling-access-to-eks-clusters
        expires_at = datetime.datetime.now(
            tz=datetime.timezone.utc
        ) + datetime.timedelta(minutes=EKS_KUBE_API_TOKEN_EXPIRATION)

        # get cluster details
        cluster_arn = cluster["cluster"]["arn"]
        cluster_ca_cert = cluster["cluster"]["certificateAuthority"][
            "data"
        ]
        cluster_server = cluster["cluster"]["endpoint"]

        # Create a client-side Kubernetes connector instance with the
        # temporary Kubernetes credentials
        try:
            # Import libraries only when needed
            from zenml.integrations.kubernetes.service_connectors.kubernetes_service_connector import (
                KubernetesAuthenticationMethods,
                KubernetesServiceConnector,
                KubernetesTokenConfig,
            )
        except ImportError as e:
            # Chain the original error so the missing dependency stays
            # visible in the traceback.
            raise RuntimeError(
                f"The Kubernetes Service Connector functionality could not "
                f"be used due to missing dependencies: {e}"
            ) from e
        return KubernetesServiceConnector(
            id=self.id,
            name=connector_name,
            auth_method=KubernetesAuthenticationMethods.TOKEN,
            resource_type=resource_type,
            config=KubernetesTokenConfig(
                cluster_name=cluster_arn,
                certificate_authority=cluster_ca_cert,
                server=cluster_server,
                token=user_token,
            ),
            expires_at=expires_at,
        )

    raise ValueError(f"Unsupported resource type: {resource_type}")
account_id: str
property
readonly
Get the AWS account ID.
Returns:
| Type | Description |
|---|---|
| str | The AWS account ID. |
Exceptions:
| Type | Description |
|---|---|
| AuthorizationException | If the AWS account ID could not be determined. |
get_boto3_session(self, auth_method, resource_type=None, resource_id=None)
Get a boto3 session for the specified resource.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| auth_method | str | The authentication method to use. | required |
| resource_type | Optional[str] | The resource type to get a boto3 session for. | None |
| resource_id | Optional[str] | The resource ID to get a boto3 session for. | None |
Returns:
| Type | Description |
|---|---|
| Tuple[boto3.Session, Optional[datetime.datetime]] | A boto3 session for the specified resource and its expiration timestamp, if applicable. |
Source code in zenml/integrations/aws/service_connectors/aws_service_connector.py
def get_boto3_session(
    self,
    auth_method: str,
    resource_type: Optional[str] = None,
    resource_id: Optional[str] = None,
) -> Tuple[boto3.Session, Optional[datetime.datetime]]:
    """Return a boto3 session for a resource, reusing cached sessions.

    Sessions are cached per (authentication method, resource type,
    resource ID) combination. A cached session is reused if it has no
    expiration time, or if it does not expire within the next
    `BOTO3_SESSION_EXPIRATION_BUFFER` minutes. Otherwise a new session is
    created through `_authenticate` and stored in the cache.

    Args:
        auth_method: The authentication method to use.
        resource_type: The resource type to get a boto3 session for.
        resource_id: The resource ID to get a boto3 session for.

    Returns:
        A boto3 session for the specified resource and its expiration
        timestamp, if applicable.
    """
    cache_key = (auth_method, resource_type, resource_id)

    cached_entry = self._session_cache.get(cache_key)
    if cached_entry is not None:
        session, expires_at = cached_entry

        # Sessions without an expiration time never need a refresh.
        if expires_at is None:
            return session, None

        # The cached timestamp is stamped as UTC before it is compared.
        # NOTE(review): this presumes the stored value is UTC; `replace`
        # relabels the timezone without converting the time.
        expires_at = expires_at.replace(tzinfo=datetime.timezone.utc)
        refresh_deadline = datetime.datetime.now(
            datetime.timezone.utc
        ) + datetime.timedelta(minutes=BOTO3_SESSION_EXPIRATION_BUFFER)

        # Reuse the session unless it expires in the near future.
        if expires_at > refresh_deadline:
            return session, expires_at

    logger.debug(
        f"Creating boto3 session for auth method '{auth_method}', "
        f"resource type '{resource_type}' and resource ID "
        f"'{resource_id}'..."
    )
    fresh_entry = self._authenticate(auth_method, resource_type, resource_id)
    session, expires_at = fresh_entry
    self._session_cache[cache_key] = (session, expires_at)
    return session, expires_at
get_ecr_client(self)
Get an ECR client.
Exceptions:
| Type | Description |
|---|---|
| ValueError | If the service connector is not able to instantiate an ECR client. |
Returns:
| Type | Description |
|---|---|
| botocore.client.BaseClient | An ECR client. |
Source code in zenml/integrations/aws/service_connectors/aws_service_connector.py
def get_ecr_client(self) -> BaseClient:
    """Create a boto3 ECR client from the connector's credentials.

    The connector must either have no resource type configured, or be
    configured for the generic AWS or the Docker registry resource type.

    Raises:
        ValueError: If the service connector is not able to instantiate an
            ECR client.

    Returns:
        An ECR client.
    """
    compatible_types = {
        AWS_RESOURCE_TYPE,
        DOCKER_REGISTRY_RESOURCE_TYPE,
    }
    configured_type = self.resource_type
    if configured_type and configured_type not in compatible_types:
        raise ValueError(
            f"Unable to instantiate ECR client for a connector that is "
            f"configured to provide access to a '{self.resource_type}' "
            "resource type."
        )

    # The session is requested for the Docker registry resource type,
    # with the configured region passed as the resource ID.
    ecr_session, _ = self.get_boto3_session(
        auth_method=self.auth_method,
        resource_type=DOCKER_REGISTRY_RESOURCE_TYPE,
        resource_id=self.config.region,
    )
    ecr_client = ecr_session.client(
        "ecr",
        region_name=self.config.region,
        endpoint_url=self.config.endpoint_url,
    )
    return ecr_client
model_post_init(self, context, /)
This function is meant to behave like a BaseModel method to initialise private attributes.
It takes context as an argument since that's what pydantic-core passes when calling it.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| self | BaseModel | The BaseModel instance. | required |
| context | Any | The context. | required |
Source code in zenml/integrations/aws/service_connectors/aws_service_connector.py
def init_private_attributes(self: BaseModel, context: Any, /) -> None:
    """Initialise the private attributes of a model instance.

    This function is meant to behave like a BaseModel method. It takes
    context as an argument since that's what pydantic-core passes when
    calling it. Nothing is done if the instance's
    `__pydantic_private__` attribute is already set to a non-None value.

    Args:
        self: The BaseModel instance.
        context: The context. Not used by this function.
    """
    if getattr(self, '__pydantic_private__', None) is not None:
        return

    # Resolve every declared private attribute to its default and keep
    # only those that actually have one.
    resolved_defaults = (
        (attr_name, attr.get_default())
        for attr_name, attr in self.__private_attributes__.items()
    )
    private_values = {
        attr_name: value
        for attr_name, value in resolved_defaults
        if value is not PydanticUndefined
    }
    object_setattr(self, '__pydantic_private__', private_values)
AWSSessionPolicy (AuthenticationConfig)
AWS session IAM policy configuration.
Source code in zenml/integrations/aws/service_connectors/aws_service_connector.py
class AWSSessionPolicy(AuthenticationConfig):
    """AWS session IAM policy configuration.

    Declares two optional fields describing a session policy: a list of
    managed policy ARNs and an inline policy document. Both default to
    `None`.
    """

    # Optional list of managed IAM policy ARNs to use as a managed session
    # policy. Per the field title, the policies must exist in the same
    # account as the IAM user requesting temporary credentials.
    policy_arns: Optional[List[str]] = Field(
        default=None,
        title="ARNs of the IAM managed policies that you want to use as a "
        "managed session policy. The policies must exist in the same account "
        "as the IAM user that is requesting temporary credentials.",
    )
    # Optional inline session policy, given as an IAM policy document in
    # JSON format.
    policy: Optional[str] = Field(
        default=None,
        title="An IAM policy in JSON format that you want to use as an inline "
        "session policy",
    )
FederationTokenAuthenticationConfig (AWSSecretKeyConfig, AWSSessionPolicy)
AWS federation token authentication config.
Source code in zenml/integrations/aws/service_connectors/aws_service_connector.py
class FederationTokenAuthenticationConfig(
    AWSSecretKeyConfig, AWSSessionPolicy
):
    """AWS federation token authentication config.

    Combines the fields inherited from `AWSSecretKeyConfig` and
    `AWSSessionPolicy`; it declares no fields of its own.
    """
IAMRoleAuthenticationConfig (AWSSecretKeyConfig, AWSSessionPolicy)
AWS IAM authentication config.
Source code in zenml/integrations/aws/service_connectors/aws_service_connector.py
class IAMRoleAuthenticationConfig(AWSSecretKeyConfig, AWSSessionPolicy):
    """AWS IAM role authentication config.

    Extends the fields inherited from `AWSSecretKeyConfig` and
    `AWSSessionPolicy` with the ARN of an IAM role.
    """

    # Required: ARN of the AWS IAM role (the field has no default).
    role_arn: str = Field(
        title="AWS IAM Role ARN",
    )
STSToken (AWSSecretKey)
AWS STS token.
Source code in zenml/integrations/aws/service_connectors/aws_service_connector.py
class STSToken(AWSSecretKey):
    """AWS STS token.

    Extends the fields inherited from `AWSSecretKey` with a session
    token.
    """

    # Required: the AWS session token (the field has no default).
    aws_session_token: PlainSerializedSecretStr = Field(
        title="AWS Session Token",
    )
STSTokenConfig (AWSBaseConfig, STSToken)
AWS STS token authentication configuration.
Source code in zenml/integrations/aws/service_connectors/aws_service_connector.py
class STSTokenConfig(AWSBaseConfig, STSToken):
    """AWS STS token authentication configuration.

    Combines the fields inherited from `AWSBaseConfig` and `STSToken`; it
    declares no fields of its own.
    """
SessionTokenAuthenticationConfig (AWSSecretKeyConfig)
AWS session token authentication config.
Source code in zenml/integrations/aws/service_connectors/aws_service_connector.py
class SessionTokenAuthenticationConfig(AWSSecretKeyConfig):
    """AWS session token authentication config.

    Uses the fields inherited from `AWSSecretKeyConfig`; it declares no
    fields of its own.
    """
step_operators
special
Initialization of the Sagemaker Step Operator.
sagemaker_step_operator
Implementation of the Sagemaker Step Operator.
SagemakerStepOperator (BaseStepOperator)
Step operator to run a step on Sagemaker.
This class defines code that builds an image with the ZenML entrypoint to run using Sagemaker's Estimator.
Source code in zenml/integrations/aws/step_operators/sagemaker_step_operator.py
class SagemakerStepOperator(BaseStepOperator):
"""Step operator to run a step on Sagemaker.
This class defines code that builds an image with the ZenML entrypoint
to run using Sagemaker's Estimator.
"""
@property
def config(self) -> SagemakerStepOperatorConfig:
    """Return the component configuration, typed for this step operator.

    Returns:
        The configuration.
    """
    typed_config = cast(SagemakerStepOperatorConfig, self._config)
    return typed_config
@property
def settings_class(self) -> Optional[Type["BaseSettings"]]:
    """Return the settings class used by the SageMaker step operator.

    Returns:
        The settings class.
    """
    operator_settings = SagemakerStepOperatorSettings
    return operator_settings
@property
def entrypoint_config_class(
    self,
) -> Type[StepOperatorEntrypointConfiguration]:
    """Return the entrypoint configuration class of this step operator.

    Returns:
        The entrypoint configuration class for this step operator.
    """
    entrypoint_class = SagemakerEntrypointConfiguration
    return entrypoint_class
@property
def validator(self) -> Optional[StackValidator]:
    """Build the stack validator for this step operator.

    Returns:
        A validator that requires the stack to contain a container
        registry and an image builder, and that additionally checks that
        neither the artifact store nor the container registry is local.
    """

    def _check_components_are_remote(stack: "Stack") -> Tuple[bool, str]:
        # Returns a (valid, reason) pair; the reason is empty when the
        # stack is valid.
        if stack.artifact_store.config.is_local:
            reason = (
                "The SageMaker step operator runs code remotely and "
                "needs to write files into the artifact store, but the "
                f"artifact store `{stack.artifact_store.name}` of the "
                "active stack is local. Please ensure that your stack "
                "contains a remote artifact store when using the SageMaker "
                "step operator."
            )
            return False, reason

        container_registry = stack.container_registry
        assert container_registry is not None

        if container_registry.config.is_local:
            reason = (
                "The SageMaker step operator runs code remotely and "
                "needs to push/pull Docker images, but the "
                f"container registry `{container_registry.name}` of the "
                "active stack is local. Please ensure that your stack "
                "contains a remote container registry when using the "
                "SageMaker step operator."
            )
            return False, reason

        return True, ""

    mandatory_components = {
        StackComponentType.CONTAINER_REGISTRY,
        StackComponentType.IMAGE_BUILDER,
    }
    return StackValidator(
        required_components=mandatory_components,
        custom_validation_function=_check_components_are_remote,
    )
def get_docker_builds(
self, deployment: "PipelineDeploymentBase"
) -> List["BuildConfiguration"]:
"""Gets the Docker builds required for the component.
Args:
deployment: The pipeline deployment for which to get the builds.
Returns:
The required Docker builds.
"""
builds = []
for step_name, step in deployment.step_configurations.items():
if step.config.step_operator == self.name:
build = BuildConfiguration(
key=SAGEMAKER_DOCKER_IMAGE_KEY,
settings=step.config.docker_settings,
step_name=step_name,
entrypoint=f"${_ENTRYPOINT_ENV_VARIABLE}",
)
builds.append(build)
return builds
def launch(
self,
info: "StepRunInfo",
entrypoint_command: List[str],
environment: Dict[str, str],
) -> None:
"""Launches a step on SageMaker.
Args:
info: Information about the step run.
entrypoint_command: Command that executes the step.
environment: Environment variables to set in the step operator
environment.
Raises:
RuntimeError: If the connector returns an object that is not a
`boto3.Session`.
"""
if not info.config.resource_settings.empty:
logger.warning(
"Specifying custom step resources is not supported for "
"the SageMaker step operator. If you want to run this step "
"operator on specific resources, you can do so by configuring "
"a different instance type like this: "
"`zenml step-operator update %s "
"--instance_type=<INSTANCE_TYPE>`",
self.name,
)
# Sagemaker does not allow environment variables longer than 512
# characters to be passed to Estimator steps. If an environment variable
# is longer than 512 characters, we split it into multiple environment
# variables (chunks) and re-construct it on the other side using the
# custom entrypoint configuration.
split_environment_variables(
env=environment,
size_limit=SAGEMAKER_ESTIMATOR_STEP_ENV_VAR_SIZE_LIMIT,
)
image_name = info.get_image(key=SAGEMAKER_DOCKER_IMAGE_KEY)
environment[_ENTRYPOINT_ENV_VARIABLE] = " ".join(entrypoint_command)
settings = cast(SagemakerStepOperatorSettings, self.get_settings(info))
# Get and default fill SageMaker estimator arguments for full ZenML support
estimator_args = settings.estimator_args
# Get authenticated session
# Option 1: Service connector
boto_session: boto3.Session
if connector := self.get_connector():
boto_session = connector.connect()
if not isinstance(boto_session, boto3.Session):
raise RuntimeError(
f"Expected to receive a `boto3.Session` object from the "
f"linked connector, but got type `{type(boto_session)}`."
)
# Option 2: Implicit configuration
else:
boto_session = boto3.Session()
session = sagemaker.Session(
boto_session=boto_session, default_bucket=self.config.bucket
)
estimator_args.setdefault(
"instance_type", settings.instance_type or "ml.m5.large"
)
estimator_args["environment"] = environment
estimator_args["instance_count"] = 1
estimator_args["sagemaker_session"] = session
# Create Estimator
estimator = sagemaker.estimator.Estimator(
image_name, self.config.role, **estimator_args
)
# SageMaker allows 63 characters at maximum for job name - ZenML uses 60 for safety margin.
step_name = Client().get_run_step(info.step_run_id).name
training_job_name = f"{info.pipeline.name}-{step_name}"[:55]
suffix = random_str(4)
unique_training_job_name = f"{training_job_name}-{suffix}"
# Sagemaker doesn't allow any underscores in job/experiment/trial names
sanitized_training_job_name = unique_training_job_name.replace(
"_", "-"
)
# Construct training input object, if necessary
inputs = None
if isinstance(settings.input_data_s3_uri, str):
inputs = sagemaker.inputs.TrainingInput(
s3_data=settings.input_data_s3_uri
)
elif isinstance(settings.input_data_s3_uri, dict):
inputs = {}
for channel, s3_uri in settings.input_data_s3_uri.items():
inputs[channel] = sagemaker.inputs.TrainingInput(
s3_data=s3_uri
)
experiment_config = {}
if settings.experiment_name:
experiment_config = {
"ExperimentName": settings.experiment_name,
"TrialName": sanitized_training_job_name,
}
info.force_write_logs()
estimator.fit(
wait=True,
inputs=inputs,
experiment_config=experiment_config,
job_name=sanitized_training_job_name,
)
config: SagemakerStepOperatorConfig
property
readonly
Returns the `SagemakerStepOperatorConfig` config.
Returns:
Type | Description |
---|---|
SagemakerStepOperatorConfig |
The configuration. |
entrypoint_config_class: Type[zenml.step_operators.step_operator_entrypoint_configuration.StepOperatorEntrypointConfiguration]
property
readonly
Returns the entrypoint configuration class for this step operator.
Returns:
Type | Description |
---|---|
Type[zenml.step_operators.step_operator_entrypoint_configuration.StepOperatorEntrypointConfiguration] |
The entrypoint configuration class for this step operator. |
settings_class: Optional[Type[BaseSettings]]
property
readonly
Settings class for the SageMaker step operator.
Returns:
Type | Description |
---|---|
Optional[Type[BaseSettings]] |
The settings class. |
validator: Optional[zenml.stack.stack_validator.StackValidator]
property
readonly
Validates the stack.
Returns:
Type | Description |
---|---|
Optional[zenml.stack.stack_validator.StackValidator] |
A validator that checks that the stack contains a remote container registry and a remote artifact store. |
get_docker_builds(self, deployment)
Gets the Docker builds required for the component.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
deployment |
PipelineDeploymentBase |
The pipeline deployment for which to get the builds. |
required |
Returns:
Type | Description |
---|---|
List[BuildConfiguration] |
The required Docker builds. |
Source code in zenml/integrations/aws/step_operators/sagemaker_step_operator.py
def get_docker_builds(
    self, deployment: "PipelineDeploymentBase"
) -> List["BuildConfiguration"]:
    """Gets the Docker builds required for the component.

    One build is requested for every step of the deployment whose
    configured step operator is this component.

    Args:
        deployment: The pipeline deployment for which to get the builds.

    Returns:
        The required Docker builds.
    """
    return [
        BuildConfiguration(
            key=SAGEMAKER_DOCKER_IMAGE_KEY,
            settings=step.config.docker_settings,
            step_name=step_name,
            # The image entrypoint is a reference to an environment
            # variable rather than a literal command.
            entrypoint=f"${_ENTRYPOINT_ENV_VARIABLE}",
        )
        for step_name, step in deployment.step_configurations.items()
        if step.config.step_operator == self.name
    ]
launch(self, info, entrypoint_command, environment)
Launches a step on SageMaker.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
info |
StepRunInfo |
Information about the step run. |
required |
entrypoint_command |
List[str] |
Command that executes the step. |
required |
environment |
Dict[str, str] |
Environment variables to set in the step operator environment. |
required |
Exceptions:
Type | Description |
---|---|
RuntimeError |
If the connector returns an object that is not a `boto3.Session`. |
Source code in zenml/integrations/aws/step_operators/sagemaker_step_operator.py
def launch(
    self,
    info: "StepRunInfo",
    entrypoint_command: List[str],
    environment: Dict[str, str],
) -> None:
    """Launches a step on SageMaker.

    The step is submitted through `sagemaker.estimator.Estimator` and
    `fit` is called with `wait=True`.

    Args:
        info: Information about the step run.
        entrypoint_command: Command that executes the step.
        environment: Environment variables to set in the step operator
            environment. The entrypoint command is added to this
            dictionary in place.

    Raises:
        RuntimeError: If the connector returns an object that is not a
            `boto3.Session`.
    """
    if not info.config.resource_settings.empty:
        logger.warning(
            "Specifying custom step resources is not supported for "
            "the SageMaker step operator. If you want to run this step "
            "operator on specific resources, you can do so by configuring "
            "a different instance type like this: "
            "`zenml step-operator update %s "
            "--instance_type=<INSTANCE_TYPE>`",
            self.name,
        )

    # Sagemaker does not allow environment variables longer than 512
    # characters to be passed to Estimator steps. If an environment
    # variable is longer than 512 characters, we split it into multiple
    # environment variables (chunks) and re-construct it on the other
    # side using the custom entrypoint configuration.
    split_environment_variables(
        env=environment,
        size_limit=SAGEMAKER_ESTIMATOR_STEP_ENV_VAR_SIZE_LIMIT,
    )

    image_name = info.get_image(key=SAGEMAKER_DOCKER_IMAGE_KEY)
    # NOTE(review): the entrypoint variable is added after the split, so
    # it is never chunked - presumably because the image entrypoint
    # expands it directly; confirm before reordering.
    environment[_ENTRYPOINT_ENV_VARIABLE] = " ".join(entrypoint_command)

    settings = cast(SagemakerStepOperatorSettings, self.get_settings(info))

    # Work on a shallow copy: the defaults and the live session injected
    # below must not leak back into the settings object.
    estimator_args = dict(settings.estimator_args)

    # Get authenticated session
    boto_session: boto3.Session
    if connector := self.get_connector():
        # Option 1: Service connector
        boto_session = connector.connect()
        if not isinstance(boto_session, boto3.Session):
            raise RuntimeError(
                "Expected to receive a `boto3.Session` object from the "
                f"linked connector, but got type `{type(boto_session)}`."
            )
    else:
        # Option 2: Implicit configuration
        boto_session = boto3.Session()

    session = sagemaker.Session(
        boto_session=boto_session, default_bucket=self.config.bucket
    )

    # A user-supplied `instance_type` in `estimator_args` wins; the
    # remaining keys are always overwritten.
    estimator_args.setdefault(
        "instance_type", settings.instance_type or "ml.m5.large"
    )
    estimator_args["environment"] = environment
    estimator_args["instance_count"] = 1
    estimator_args["sagemaker_session"] = session

    # Create Estimator
    estimator = sagemaker.estimator.Estimator(
        image_name, self.config.role, **estimator_args
    )

    # SageMaker allows 63 characters at maximum for job name - ZenML uses
    # 60 for safety margin (55 + "-" + 4 random characters).
    step_name = Client().get_run_step(info.step_run_id).name
    training_job_name = f"{info.pipeline.name}-{step_name}"[:55]
    suffix = random_str(4)
    unique_training_job_name = f"{training_job_name}-{suffix}"

    # SageMaker job/experiment/trial names may only contain ASCII
    # letters, digits and hyphens and must start with an alphanumeric
    # character. Every other character (underscores, spaces, dots, ...)
    # becomes a hyphen and leading hyphens are dropped.
    sanitized_training_job_name = "".join(
        char if char.isascii() and char.isalnum() else "-"
        for char in unique_training_job_name
    ).lstrip("-")

    # Construct training input object, if necessary
    inputs = None
    if isinstance(settings.input_data_s3_uri, str):
        inputs = sagemaker.inputs.TrainingInput(
            s3_data=settings.input_data_s3_uri
        )
    elif isinstance(settings.input_data_s3_uri, dict):
        # One training input per named channel.
        inputs = {
            channel: sagemaker.inputs.TrainingInput(s3_data=s3_uri)
            for channel, s3_uri in settings.input_data_s3_uri.items()
        }

    experiment_config = {}
    if settings.experiment_name:
        experiment_config = {
            "ExperimentName": settings.experiment_name,
            "TrialName": sanitized_training_job_name,
        }

    # NOTE(review): presumably flushes buffered step logs before the
    # blocking `fit` call - confirm against `StepRunInfo`.
    info.force_write_logs()
    estimator.fit(
        wait=True,
        inputs=inputs,
        experiment_config=experiment_config,
        job_name=sanitized_training_job_name,
    )
sagemaker_step_operator_entrypoint_config
Entrypoint configuration for ZenML Sagemaker step operator.
SagemakerEntrypointConfiguration (StepOperatorEntrypointConfiguration)
Entrypoint configuration for ZenML Sagemaker step operator.
The only purpose of this entrypoint configuration is to reconstruct the environment variables that exceed the maximum length of 512 characters allowed for Sagemaker Estimator steps from their individual components.
Source code in zenml/integrations/aws/step_operators/sagemaker_step_operator_entrypoint_config.py
class SagemakerEntrypointConfiguration(StepOperatorEntrypointConfiguration):
    """Entrypoint configuration for ZenML Sagemaker step operator.

    The only purpose of this entrypoint configuration is to reconstruct the
    environment variables that exceed the maximum length of 512 characters
    allowed for Sagemaker Estimator steps from their individual components.
    """

    def run(self) -> None:
        """Runs the step.

        Environment variables that were split into chunks are reassembled
        first; the parent class's `run` is then invoked.
        """
        # Variables longer than 512 characters arrive as several chunks;
        # stitch them back together before anything reads the environment.
        reconstruct_environment_variables()

        # Hand over to the regular step operator entrypoint.
        super().run()
run(self)
Runs the step.
Source code in zenml/integrations/aws/step_operators/sagemaker_step_operator_entrypoint_config.py
def run(self) -> None:
    """Runs the step.

    Environment variables that were split into chunks are reassembled
    first; the parent class's `run` is then invoked.
    """
    # Variables longer than 512 characters arrive as several chunks;
    # stitch them back together before anything reads the environment.
    reconstruct_environment_variables()

    # Hand over to the regular step operator entrypoint.
    super().run()