Skip to content

Aws

zenml.integrations.aws special

Integrates multiple AWS Tools as Stack Components.

The AWS integration provides a way for our users to manage their secrets through AWS and to use the AWS container registry. Additionally, the Sagemaker integration submodule provides a way to run ZenML steps in Sagemaker.

AWSIntegration (Integration)

Definition of AWS integration for ZenML.

Source code in zenml/integrations/aws/__init__.py
class AWSIntegration(Integration):
    """ZenML integration that bundles the AWS stack components."""

    NAME = AWS
    REQUIREMENTS = [
        "sagemaker>=2.117.0",
        "kubernetes",
        "aws-profile-manager",
    ]
    REQUIREMENTS_IGNORED_ON_UNINSTALL = ["kubernetes"]

    @classmethod
    def activate(cls) -> None:
        """Activate the integration by importing its service connectors."""
        # The imported name is not used here (hence the `noqa`); presumably
        # the import itself registers the connectors as a side effect.
        from zenml.integrations.aws import service_connectors  # noqa

    @classmethod
    def flavors(cls) -> List[Type[Flavor]]:
        """Collect the stack component flavors shipped with this integration.

        Returns:
            The container registry, step operator and orchestrator flavor
            classes, in that order.
        """
        # Deferred import: the flavor modules are only loaded when the
        # flavors are actually requested.
        from zenml.integrations.aws.flavors import (
            AWSContainerRegistryFlavor,
            SagemakerOrchestratorFlavor,
            SagemakerStepOperatorFlavor,
        )

        aws_flavors: List[Type[Flavor]] = [
            AWSContainerRegistryFlavor,
            SagemakerStepOperatorFlavor,
            SagemakerOrchestratorFlavor,
        ]
        return aws_flavors

activate() classmethod

Activate the AWS integration.

Source code in zenml/integrations/aws/__init__.py
@classmethod
def activate(cls) -> None:
    """Activate the integration by importing its service connectors."""
    # The imported name is not used here (hence the `noqa`); presumably the
    # import itself registers the connectors as a side effect.
    from zenml.integrations.aws import service_connectors  # noqa

flavors() classmethod

Declare the stack component flavors for the AWS integration.

Returns:

Type Description
List[Type[zenml.stack.flavor.Flavor]]

List of stack component flavors for this integration.

Source code in zenml/integrations/aws/__init__.py
@classmethod
def flavors(cls) -> List[Type[Flavor]]:
    """Collect the stack component flavors shipped with this integration.

    Returns:
        The container registry, step operator and orchestrator flavor
        classes, in that order.
    """
    # Deferred import: the flavor modules are only loaded when the flavors
    # are actually requested.
    from zenml.integrations.aws.flavors import (
        AWSContainerRegistryFlavor,
        SagemakerOrchestratorFlavor,
        SagemakerStepOperatorFlavor,
    )

    aws_flavors: List[Type[Flavor]] = [
        AWSContainerRegistryFlavor,
        SagemakerStepOperatorFlavor,
        SagemakerOrchestratorFlavor,
    ]
    return aws_flavors

container_registries special

Initialization of AWS Container Registry integration.

aws_container_registry

Implementation of the AWS container registry integration.

AWSContainerRegistry (BaseContainerRegistry)

Class for AWS Container Registry.

Source code in zenml/integrations/aws/container_registries/aws_container_registry.py
class AWSContainerRegistry(BaseContainerRegistry):
    """Class for AWS Container Registry."""

    @property
    def config(self) -> AWSContainerRegistryConfig:
        """Returns the `AWSContainerRegistryConfig` config.

        Returns:
            The configuration.
        """
        return cast(AWSContainerRegistryConfig, self._config)

    def _get_region(self) -> str:
        """Parses the AWS region from the registry URI.

        The URI must match `<prefix>.dkr.ecr.<region>.amazonaws.com` in full.

        Raises:
            RuntimeError: If the region parsing fails due to an invalid URI.

        Returns:
            The region string.
        """
        match = re.fullmatch(
            r".*\.dkr\.ecr\.(.*)\.amazonaws\.com", self.config.uri
        )
        if not match:
            raise RuntimeError(
                f"Unable to parse region from ECR URI {self.config.uri}."
            )

        return match.group(1)

    def _get_ecr_client(self) -> BaseClient:
        """Get an ECR client.

        If this container registry is configured with an AWS service connector,
        we use that connector to create an authenticated client. Otherwise
        local AWS credentials will be used.

        Returns:
            An ECR client.
        """
        if self.connector:
            try:
                model = Client().get_service_connector(self.connector)
                connector = service_connector_registry.instantiate_connector(
                    model=model
                )
                assert isinstance(connector, AWSServiceConnector)
                return connector.get_ecr_client()
            except Exception as e:
                # Best effort: any connector failure is logged and we fall
                # through to a client built from the local boto3 session.
                logger.error(
                    "Unable to get ECR client from service connector: %s",
                    str(e),
                )

        return boto3.Session().client(
            "ecr",
            region_name=self._get_region(),
        )

    def prepare_image_push(self, image_name: str) -> None:
        """Logs warning message if trying to push an image for which no repository exists.

        Args:
            image_name: Name of the docker image that will be pushed.

        Raises:
            ValueError: If the docker image name is invalid.
        """
        # Find repository name from image name. The registry URI is escaped
        # so that its dots (and any other regex metacharacters) are matched
        # literally instead of as wildcards.
        match = re.search(
            f"{re.escape(self.config.uri)}/(.*):.*", image_name
        )
        if not match:
            raise ValueError(f"Invalid docker image name '{image_name}'.")
        repo_name = match.group(1)

        client = self._get_ecr_client()
        try:
            # NOTE(review): a single call is made here; if the API paginates
            # its results, repositories beyond the first response would not
            # be seen -- confirm.
            response = client.describe_repositories()
        except (BotoCoreError, ClientError):
            logger.warning(
                "Amazon ECR requires you to create a repository before you can "
                f"push an image to it. ZenML is trying to push the image "
                f"{image_name} but could not find any repositories because "
                "your local AWS credentials are not set. We will try to push "
                "anyway, but in case it fails you need to create a repository "
                f"named `{repo_name}`."
            )
            return

        try:
            repo_uris: List[str] = [
                repository["repositoryUri"]
                for repository in response["repositories"]
            ]
        except (KeyError, ClientError) as e:
            # invalid boto response, let's hope for the best and just push
            logger.debug("Error while trying to fetch ECR repositories: %s", e)
            return

        # The image belongs to a known repository when its name is that
        # repository's URI followed by `:<tag>`.
        repo_exists = any(
            image_name.startswith(f"{uri}:") for uri in repo_uris
        )
        if not repo_exists:
            logger.warning(
                "Amazon ECR requires you to create a repository before you can "
                f"push an image to it. ZenML is trying to push the image "
                f"{image_name} but could only detect the following "
                f"repositories: {repo_uris}. We will try to push anyway, but "
                f"in case it fails you need to create a repository named "
                f"`{repo_name}`."
            )

    @property
    def post_registration_message(self) -> Optional[str]:
        """Optional message printed after the stack component is registered.

        Returns:
            Info message regarding docker repositories in AWS.
        """
        return (
            "Amazon ECR requires you to create a repository before you can "
            "push an image to it. If you want to run a pipeline "
            "using a remote orchestrator, ZenML will automatically build a "
            f"docker image called `{self.config.uri}/zenml:<PIPELINE_NAME>` "
            f"and try to push it. This will fail unless you create a "
            f"repository called `zenml` inside your Amazon ECR."
        )
config: AWSContainerRegistryConfig property readonly

Returns the AWSContainerRegistryConfig config.

Returns:

Type Description
AWSContainerRegistryConfig

The configuration.

post_registration_message: Optional[str] property readonly

Optional message printed after the stack component is registered.

Returns:

Type Description
Optional[str]

Info message regarding docker repositories in AWS.

prepare_image_push(self, image_name)

Logs warning message if trying to push an image for which no repository exists.

Parameters:

Name Type Description Default
image_name str

Name of the docker image that will be pushed.

required

Exceptions:

Type Description
ValueError

If the docker image name is invalid.

Source code in zenml/integrations/aws/container_registries/aws_container_registry.py
def prepare_image_push(self, image_name: str) -> None:
    """Logs warning message if trying to push an image for which no repository exists.

    Args:
        image_name: Name of the docker image that will be pushed.

    Raises:
        ValueError: If the docker image name is invalid.
    """
    # Find repository name from image name. The registry URI is escaped so
    # that its dots (and any other regex metacharacters) are matched
    # literally instead of as wildcards.
    match = re.search(f"{re.escape(self.config.uri)}/(.*):.*", image_name)
    if not match:
        raise ValueError(f"Invalid docker image name '{image_name}'.")
    repo_name = match.group(1)

    client = self._get_ecr_client()
    try:
        # NOTE(review): a single call is made here; if the API paginates its
        # results, repositories beyond the first response would not be
        # seen -- confirm.
        response = client.describe_repositories()
    except (BotoCoreError, ClientError):
        logger.warning(
            "Amazon ECR requires you to create a repository before you can "
            f"push an image to it. ZenML is trying to push the image "
            f"{image_name} but could not find any repositories because "
            "your local AWS credentials are not set. We will try to push "
            "anyway, but in case it fails you need to create a repository "
            f"named `{repo_name}`."
        )
        return

    try:
        repo_uris: List[str] = [
            repository["repositoryUri"]
            for repository in response["repositories"]
        ]
    except (KeyError, ClientError) as e:
        # invalid boto response, let's hope for the best and just push
        logger.debug("Error while trying to fetch ECR repositories: %s", e)
        return

    # The image belongs to a known repository when its name is that
    # repository's URI followed by `:<tag>`.
    repo_exists = any(
        image_name.startswith(f"{uri}:") for uri in repo_uris
    )
    if not repo_exists:
        logger.warning(
            "Amazon ECR requires you to create a repository before you can "
            f"push an image to it. ZenML is trying to push the image "
            f"{image_name} but could only detect the following "
            f"repositories: {repo_uris}. We will try to push anyway, but "
            f"in case it fails you need to create a repository named "
            f"`{repo_name}`."
        )

flavors special

AWS integration flavors.

aws_container_registry_flavor

AWS container registry flavor.

AWSContainerRegistryConfig (BaseContainerRegistryConfig)

Configuration for AWS Container Registry.

Source code in zenml/integrations/aws/flavors/aws_container_registry_flavor.py
class AWSContainerRegistryConfig(BaseContainerRegistryConfig):
    """Configuration for AWS Container Registry."""

    @field_validator("uri")
    @classmethod
    def validate_aws_uri(cls, uri: str) -> str:
        """Check that the registry URI is a bare host without any path.

        Args:
            uri: The URI value to check.

        Returns:
            The unchanged URI when it is acceptable.

        Raises:
            ValueError: If the URI contains a `/` character.
        """
        # Guard clause: anything without a slash passes through untouched.
        if "/" not in uri:
            return uri

        raise ValueError(
            "Property `uri` can not contain a `/`. An example of a valid "
            "URI is: `715803424592.dkr.ecr.us-east-1.amazonaws.com`"
        )
validate_aws_uri(uri) classmethod

Validates that the URI is in the correct format.

Parameters:

Name Type Description Default
uri str

URI to validate.

required

Returns:

Type Description
str

URI in the correct format.

Exceptions:

Type Description
ValueError

If the URI contains a slash character.

Source code in zenml/integrations/aws/flavors/aws_container_registry_flavor.py
@field_validator("uri")
@classmethod
def validate_aws_uri(cls, uri: str) -> str:
    """Check that the registry URI is a bare host without any path.

    Args:
        uri: The URI value to check.

    Returns:
        The unchanged URI when it is acceptable.

    Raises:
        ValueError: If the URI contains a `/` character.
    """
    # Guard clause: anything without a slash passes through untouched.
    if "/" not in uri:
        return uri

    raise ValueError(
        "Property `uri` can not contain a `/`. An example of a valid "
        "URI is: `715803424592.dkr.ecr.us-east-1.amazonaws.com`"
    )
AWSContainerRegistryFlavor (BaseContainerRegistryFlavor)

AWS Container Registry flavor.

Source code in zenml/integrations/aws/flavors/aws_container_registry_flavor.py
class AWSContainerRegistryFlavor(BaseContainerRegistryFlavor):
    """Flavor describing the AWS container registry stack component."""

    @property
    def name(self) -> str:
        """Identifier of this flavor.

        Returns:
            The flavor name constant for the AWS container registry.
        """
        return AWS_CONTAINER_REGISTRY_FLAVOR

    @property
    def service_connector_requirements(
        self,
    ) -> Optional[ServiceConnectorRequirements]:
        """Describe which service connectors can be used with this flavor.

        The requirements are used to filter the available service connector
        types down to those compatible with this flavor.

        Returns:
            Requirements selecting the AWS connector type and the Docker
            registry resource type, with the resource ID taken from the
            `uri` config attribute.
        """
        requirements = ServiceConnectorRequirements(
            connector_type=AWS_CONNECTOR_TYPE,
            resource_type=DOCKER_REGISTRY_RESOURCE_TYPE,
            resource_id_attr="uri",
        )
        return requirements

    @property
    def docs_url(self) -> Optional[str]:
        """Link to the user documentation for this flavor.

        Returns:
            The default generated docs URL.
        """
        url = self.generate_default_docs_url()
        return url

    @property
    def sdk_docs_url(self) -> Optional[str]:
        """Link to the SDK documentation for this flavor.

        Returns:
            The default generated SDK docs URL.
        """
        url = self.generate_default_sdk_docs_url()
        return url

    @property
    def logo_url(self) -> str:
        """Link to the logo shown for this flavor in the dashboard.

        Returns:
            The logo URL.
        """
        return "https://public-flavor-logos.s3.eu-central-1.amazonaws.com/container_registry/aws.png"

    @property
    def config_class(self) -> Type[AWSContainerRegistryConfig]:
        """Configuration class paired with this flavor.

        Returns:
            The `AWSContainerRegistryConfig` class.
        """
        return AWSContainerRegistryConfig

    @property
    def implementation_class(self) -> Type["AWSContainerRegistry"]:
        """Stack component class that implements this flavor.

        Returns:
            The `AWSContainerRegistry` class.
        """
        # Imported lazily, only when the implementation is requested.
        from zenml.integrations.aws.container_registries import (
            AWSContainerRegistry,
        )

        return AWSContainerRegistry
config_class: Type[zenml.integrations.aws.flavors.aws_container_registry_flavor.AWSContainerRegistryConfig] property readonly

Config class for this flavor.

Returns:

Type Description
Type[zenml.integrations.aws.flavors.aws_container_registry_flavor.AWSContainerRegistryConfig]

The config class.

docs_url: Optional[str] property readonly

A url to point at docs explaining this flavor.

Returns:

Type Description
Optional[str]

A flavor docs url.

implementation_class: Type[AWSContainerRegistry] property readonly

Implementation class.

Returns:

Type Description
Type[AWSContainerRegistry]

The implementation class.

logo_url: str property readonly

A url to represent the flavor in the dashboard.

Returns:

Type Description
str

The flavor logo.

name: str property readonly

Name of the flavor.

Returns:

Type Description
str

The name of the flavor.

sdk_docs_url: Optional[str] property readonly

A url to point at SDK docs explaining this flavor.

Returns:

Type Description
Optional[str]

A flavor SDK docs url.

service_connector_requirements: Optional[zenml.models.v2.misc.service_connector_type.ServiceConnectorRequirements] property readonly

Service connector resource requirements for service connectors.

Specifies resource requirements that are used to filter the available service connector types that are compatible with this flavor.

Returns:

Type Description
Optional[zenml.models.v2.misc.service_connector_type.ServiceConnectorRequirements]

Requirements for compatible service connectors, if a service connector is required for this flavor.

sagemaker_orchestrator_flavor

Amazon SageMaker orchestrator flavor.

SagemakerOrchestratorConfig (BaseOrchestratorConfig, SagemakerOrchestratorSettings)

Config for the Sagemaker orchestrator.

There are three ways to authenticate to AWS: (1) by connecting a ServiceConnector to the orchestrator; (2) by configuring explicit AWS credentials aws_access_key_id, aws_secret_access_key, and optionally aws_auth_role_arn; (3) if none of the above are provided, unspecified credentials will be loaded from the default AWS config.

Attributes:

Name Type Description
synchronous bool

If True, the client running a pipeline using this orchestrator waits until all steps finish running. If False, the client returns immediately and the pipeline is executed asynchronously. Defaults to True.

execution_role str

The IAM role ARN to use for the pipeline.

aws_access_key_id Optional[str]

The AWS access key ID to use to authenticate to AWS. If not provided, the value from the default AWS config will be used.

aws_secret_access_key Optional[str]

The AWS secret access key to use to authenticate to AWS. If not provided, the value from the default AWS config will be used.

aws_profile Optional[str]

The AWS profile to use for authentication if not using service connectors or explicit credentials. If not provided, the default profile will be used.

aws_auth_role_arn Optional[str]

The ARN of an intermediate IAM role to assume when authenticating to AWS.

region Optional[str]

The AWS region where the processing job will be run. If not provided, the value from the default AWS config will be used.

bucket Optional[str]

Name of the S3 bucket to use for storing artifacts from the job run. If not provided, a default bucket will be created based on the following format: "sagemaker-{region}-{aws-account-id}".

Source code in zenml/integrations/aws/flavors/sagemaker_orchestrator_flavor.py
class SagemakerOrchestratorConfig(
    BaseOrchestratorConfig, SagemakerOrchestratorSettings
):
    """Config for the Sagemaker orchestrator.

    There are three ways to authenticate to AWS:
    - By connecting a `ServiceConnector` to the orchestrator,
    - By configuring explicit AWS credentials `aws_access_key_id`,
        `aws_secret_access_key`, and optional `aws_auth_role_arn`,
    - If none of the above are provided, unspecified credentials will be
        loaded from the default AWS config.

    Attributes:
        synchronous: If `True`, the client running a pipeline using this
            orchestrator waits until all steps finish running. If `False`,
            the client returns immediately and the pipeline is executed
            asynchronously. Defaults to `True`.
        execution_role: The IAM role ARN to use for the pipeline. Required.
        aws_access_key_id: The AWS access key ID to use to authenticate to AWS.
            If not provided, the value from the default AWS config will be used.
        aws_secret_access_key: The AWS secret access key to use to authenticate
            to AWS. If not provided, the value from the default AWS config will
            be used.
        aws_profile: The AWS profile to use for authentication if not using
            service connectors or explicit credentials. If not provided, the
            default profile will be used.
        aws_auth_role_arn: The ARN of an intermediate IAM role to assume when
            authenticating to AWS.
        region: The AWS region where the processing job will be run. If not
            provided, the value from the default AWS config will be used.
        bucket: Name of the S3 bucket to use for storing artifacts
            from the job run. If not provided, a default bucket will be created
            based on the following format:
            "sagemaker-{region}-{aws-account-id}".
    """

    synchronous: bool = True
    # The only field declared here without a default value.
    execution_role: str
    # Both credential parts are declared through `SecretField`.
    aws_access_key_id: Optional[str] = SecretField(default=None)
    aws_secret_access_key: Optional[str] = SecretField(default=None)
    aws_profile: Optional[str] = None
    aws_auth_role_arn: Optional[str] = None
    region: Optional[str] = None
    bucket: Optional[str] = None

    @property
    def is_remote(self) -> bool:
        """Checks if this stack component is running remotely.

        This designation is used to determine if the stack component can be
        used with a local ZenML database or if it requires a remote ZenML
        server.

        Returns:
            True if this config is for a remote component, False otherwise.
            This config always reports True.
        """
        return True

    @property
    def is_synchronous(self) -> bool:
        """Whether the orchestrator runs synchronous or not.

        Returns:
            The value of the `synchronous` attribute.
        """
        return self.synchronous
is_remote: bool property readonly

Checks if this stack component is running remotely.

This designation is used to determine if the stack component can be used with a local ZenML database or if it requires a remote ZenML server.

Returns:

Type Description
bool

True if this config is for a remote component, False otherwise.

is_synchronous: bool property readonly

Whether the orchestrator runs synchronous or not.

Returns:

Type Description
bool

Whether the orchestrator runs synchronous or not.

SagemakerOrchestratorFlavor (BaseOrchestratorFlavor)

Flavor for the Sagemaker orchestrator.

Source code in zenml/integrations/aws/flavors/sagemaker_orchestrator_flavor.py
class SagemakerOrchestratorFlavor(BaseOrchestratorFlavor):
    """Flavor describing the Sagemaker orchestrator stack component."""

    @property
    def name(self) -> str:
        """Identifier of this flavor.

        Returns:
            The flavor name constant.
        """
        # NOTE(review): this orchestrator flavor returns the *step operator*
        # flavor constant -- confirm whether that is intentional or whether
        # an orchestrator-specific constant should be used instead.
        return AWS_SAGEMAKER_STEP_OPERATOR_FLAVOR

    @property
    def service_connector_requirements(
        self,
    ) -> Optional[ServiceConnectorRequirements]:
        """Describe which service connectors can be used with this flavor.

        The requirements are used to filter the available service connector
        types down to those compatible with this flavor.

        Returns:
            Requirements selecting connectors by the generic AWS resource
            type.
        """
        requirements = ServiceConnectorRequirements(
            resource_type=AWS_RESOURCE_TYPE
        )
        return requirements

    @property
    def docs_url(self) -> Optional[str]:
        """Link to the user documentation for this flavor.

        Returns:
            The default generated docs URL.
        """
        url = self.generate_default_docs_url()
        return url

    @property
    def sdk_docs_url(self) -> Optional[str]:
        """Link to the SDK documentation for this flavor.

        Returns:
            The default generated SDK docs URL.
        """
        url = self.generate_default_sdk_docs_url()
        return url

    @property
    def logo_url(self) -> str:
        """Link to the logo shown for this flavor in the dashboard.

        Returns:
            The logo URL.
        """
        return "https://public-flavor-logos.s3.eu-central-1.amazonaws.com/orchestrator/sagemaker.png"

    @property
    def config_class(self) -> Type[SagemakerOrchestratorConfig]:
        """Configuration class paired with this flavor.

        Returns:
            The `SagemakerOrchestratorConfig` class.
        """
        return SagemakerOrchestratorConfig

    @property
    def implementation_class(self) -> Type["SagemakerOrchestrator"]:
        """Stack component class that implements this flavor.

        Returns:
            The `SagemakerOrchestrator` class.
        """
        # Imported lazily, only when the implementation is requested.
        from zenml.integrations.aws.orchestrators import SagemakerOrchestrator

        return SagemakerOrchestrator
config_class: Type[zenml.integrations.aws.flavors.sagemaker_orchestrator_flavor.SagemakerOrchestratorConfig] property readonly

Returns SagemakerOrchestratorConfig config class.

Returns:

Type Description
Type[zenml.integrations.aws.flavors.sagemaker_orchestrator_flavor.SagemakerOrchestratorConfig]

The config class.

docs_url: Optional[str] property readonly

A url to point at docs explaining this flavor.

Returns:

Type Description
Optional[str]

A flavor docs url.

implementation_class: Type[SagemakerOrchestrator] property readonly

Implementation class.

Returns:

Type Description
Type[SagemakerOrchestrator]

The implementation class.

logo_url: str property readonly

A url to represent the flavor in the dashboard.

Returns:

Type Description
str

The flavor logo.

name: str property readonly

Name of the flavor.

Returns:

Type Description
str

The name of the flavor.

sdk_docs_url: Optional[str] property readonly

A url to point at SDK docs explaining this flavor.

Returns:

Type Description
Optional[str]

A flavor SDK docs url.

service_connector_requirements: Optional[zenml.models.v2.misc.service_connector_type.ServiceConnectorRequirements] property readonly

Service connector resource requirements for service connectors.

Specifies resource requirements that are used to filter the available service connector types that are compatible with this flavor.

Returns:

Type Description
Optional[zenml.models.v2.misc.service_connector_type.ServiceConnectorRequirements]

Requirements for compatible service connectors, if a service connector is required for this flavor.

SagemakerOrchestratorSettings (BaseSettings)

Settings for the Sagemaker orchestrator.

Attributes:

Name Type Description
instance_type str

The instance type to use for the processing job.

processor_role Optional[str]

The IAM role to use for the step execution on a Processor.

volume_size_in_gb int

The size of the EBS volume to use for the processing job.

max_runtime_in_seconds int

The maximum runtime in seconds for the processing job.

processor_tags Dict[str, str]

Tags to apply to the Processor assigned to the step.

processor_args Dict[str, Any]

Arguments that are directly passed to the SageMaker Processor for a specific step, allowing for overriding the default settings provided when configuring the component. See https://sagemaker.readthedocs.io/en/stable/api/training/processing.html#sagemaker.processing.Processor for a full list of arguments. For processor_args.instance_type, check https://docs.aws.amazon.com/sagemaker/latest/dg/notebooks-available-instance-types.html for a list of available instance types.

input_data_s3_mode str

How data is made available to the container. Two possible input modes: File, Pipe.

input_data_s3_uri Union[str, Dict[str, str]]

S3 URI where data is located if not locally, e.g. s3://my-bucket/my-data/train. How data will be made available to the container is configured with input_data_s3_mode. Two possible input types: - str: S3 location where training data is saved. - Dict[str, str]: (ChannelName, S3Location) which represent channels (e.g. training, validation, testing) where specific parts of the data are saved in S3.

output_data_s3_mode str

How data is uploaded to the S3 bucket. Two possible output modes: EndOfJob, Continuous.

output_data_s3_uri Union[str, Dict[str, str]]

S3 URI where data is uploaded after or during processing run. e.g. s3://my-bucket/my-data/output. How data will be made available to the container is configured with output_data_s3_mode. Two possible types: - str: S3 location where data will be uploaded from a local folder named /opt/ml/processing/output/data. - Dict[str, str]: (ChannelName, S3Location) which represent channels (e.g. output_one, output_two) where specific parts of the data are stored locally for S3 upload. Data must be available locally in /opt/ml/processing/output/data/<ChannelName>.

Source code in zenml/integrations/aws/flavors/sagemaker_orchestrator_flavor.py
class SagemakerOrchestratorSettings(BaseSettings):
    """Settings for the Sagemaker orchestrator.

    Attributes:
        instance_type: The instance type to use for the processing job.
            Defaults to `ml.t3.medium`.
        processor_role: The IAM role to use for the step execution on a Processor.
        volume_size_in_gb: The size of the EBS volume to use for the processing
            job. Defaults to 30.
        max_runtime_in_seconds: The maximum runtime in seconds for the
            processing job. Defaults to 86400.
        processor_tags: Tags to apply to the Processor assigned to the step.
        processor_args: Arguments that are directly passed to the SageMaker
            Processor for a specific step, allowing for overriding the default
            settings provided when configuring the component. See
            https://sagemaker.readthedocs.io/en/stable/api/training/processing.html#sagemaker.processing.Processor
            for a full list of arguments.
            For processor_args.instance_type, check
            https://docs.aws.amazon.com/sagemaker/latest/dg/notebooks-available-instance-types.html
            for a list of available instance types.
        input_data_s3_mode: How data is made available to the container.
            Two possible input modes: File, Pipe. Defaults to `File`.
        input_data_s3_uri: S3 URI where data is located if not locally,
            e.g. s3://my-bucket/my-data/train. How data will be made available
            to the container is configured with input_data_s3_mode. Two possible
            input types:
                - str: S3 location where training data is saved.
                - Dict[str, str]: (ChannelName, S3Location) which represent
                    channels (e.g. training, validation, testing) where
                    specific parts of the data are saved in S3.
        output_data_s3_mode: How data is uploaded to the S3 bucket.
            Two possible output modes: EndOfJob, Continuous. Defaults to
            `EndOfJob`.
        output_data_s3_uri: S3 URI where data is uploaded after or during processing run.
            e.g. s3://my-bucket/my-data/output. How data will be made available
            to the container is configured with output_data_s3_mode. Two possible
            types:
                - str: S3 location where data will be uploaded from a local folder
                    named /opt/ml/processing/output/data.
                - Dict[str, str]: (ChannelName, S3Location) which represent
                    channels (e.g. output_one, output_two) where
                    specific parts of the data are stored locally for S3 upload.
                    Data must be available locally in /opt/ml/processing/output/data/<ChannelName>.
    """

    instance_type: str = "ml.t3.medium"
    processor_role: Optional[str] = None
    volume_size_in_gb: int = 30
    # 86400 seconds = 24 hours.
    max_runtime_in_seconds: int = 86400
    processor_tags: Dict[str, str] = {}

    processor_args: Dict[str, Any] = {}
    input_data_s3_mode: str = "File"
    # `union_mode="left_to_right"`: the union members are tried in the order
    # they are declared (`str` first, then `Dict[str, str]`).
    input_data_s3_uri: Optional[Union[str, Dict[str, str]]] = Field(
        default=None, union_mode="left_to_right"
    )

    output_data_s3_mode: str = "EndOfJob"
    # Same left-to-right union handling as `input_data_s3_uri`.
    output_data_s3_uri: Optional[Union[str, Dict[str, str]]] = Field(
        default=None, union_mode="left_to_right"
    )

sagemaker_step_operator_flavor

Amazon SageMaker step operator flavor.

SagemakerStepOperatorConfig (BaseStepOperatorConfig, SagemakerStepOperatorSettings)

Config for the Sagemaker step operator.

Attributes:

Name Type Description
role str

The role that has to be assigned to the jobs which are running in Sagemaker.

bucket Optional[str]

Name of the S3 bucket to use for storing artifacts from the job run. If not provided, a default bucket will be created based on the following format: "sagemaker-{region}-{aws-account-id}".

Source code in zenml/integrations/aws/flavors/sagemaker_step_operator_flavor.py
class SagemakerStepOperatorConfig(
    BaseStepOperatorConfig, SagemakerStepOperatorSettings
):
    """Config for the Sagemaker step operator.

    Inherits every field of `SagemakerStepOperatorSettings` in addition to
    the fields declared here.

    Attributes:
        role: The role that has to be assigned to the jobs which are
            running in Sagemaker. This field is required (it has no default).
        bucket: Name of the S3 bucket to use for storing artifacts
            from the job run. If not provided, a default bucket will be created
            based on the following format: "sagemaker-{region}-{aws-account-id}".
    """

    role: str
    bucket: Optional[str] = None

    @property
    def is_remote(self) -> bool:
        """Checks if this stack component is running remotely.

        This designation is used to determine if the stack component can be
        used with a local ZenML database or if it requires a remote ZenML
        server.

        Returns:
            True if this config is for a remote component, False otherwise.
            This config always reports True.
        """
        return True
is_remote: bool property readonly

Checks if this stack component is running remotely.

This designation is used to determine if the stack component can be used with a local ZenML database or if it requires a remote ZenML server.

Returns:

Type Description
bool

True if this config is for a remote component, False otherwise.

SagemakerStepOperatorFlavor (BaseStepOperatorFlavor)

Flavor for the Sagemaker step operator.

Source code in zenml/integrations/aws/flavors/sagemaker_step_operator_flavor.py
class SagemakerStepOperatorFlavor(BaseStepOperatorFlavor):
    """Flavor definition of the Sagemaker step operator."""

    @property
    def name(self) -> str:
        """The name under which this flavor is known.

        Returns:
            The flavor name.
        """
        flavor_name: str = AWS_SAGEMAKER_STEP_OPERATOR_FLAVOR
        return flavor_name

    @property
    def config_class(self) -> Type[SagemakerStepOperatorConfig]:
        """The config class that belongs to this flavor.

        Returns:
            The `SagemakerStepOperatorConfig` class.
        """
        return SagemakerStepOperatorConfig

    @property
    def implementation_class(self) -> Type["SagemakerStepOperator"]:
        """The class that implements this flavor.

        Returns:
            The `SagemakerStepOperator` class.
        """
        # Imported inside the property rather than at module level
        # (presumably so the flavor can be loaded without importing the
        # implementation and its dependencies).
        from zenml.integrations.aws.step_operators import SagemakerStepOperator

        return SagemakerStepOperator

    @property
    def service_connector_requirements(
        self,
    ) -> Optional[ServiceConnectorRequirements]:
        """The service connector requirements of this flavor.

        The requirements are used to filter the available service connector
        types down to those that are compatible with this flavor.

        Returns:
            Requirements that restrict compatible service connectors to the
            AWS resource type.
        """
        requirements = ServiceConnectorRequirements(
            resource_type=AWS_RESOURCE_TYPE
        )
        return requirements

    @property
    def docs_url(self) -> Optional[str]:
        """A URL pointing at the docs that explain this flavor.

        Returns:
            The default docs URL generated for this flavor.
        """
        default_docs_url = self.generate_default_docs_url()
        return default_docs_url

    @property
    def sdk_docs_url(self) -> Optional[str]:
        """A URL pointing at the SDK docs that explain this flavor.

        Returns:
            The default SDK docs URL generated for this flavor.
        """
        default_sdk_docs_url = self.generate_default_sdk_docs_url()
        return default_sdk_docs_url

    @property
    def logo_url(self) -> str:
        """A URL of the logo that represents the flavor in the dashboard.

        Returns:
            The flavor logo URL.
        """
        return "https://public-flavor-logos.s3.eu-central-1.amazonaws.com/step_operator/sagemaker.png"
config_class: Type[zenml.integrations.aws.flavors.sagemaker_step_operator_flavor.SagemakerStepOperatorConfig] property readonly

Returns SagemakerStepOperatorConfig config class.

Returns:

Type Description
Type[zenml.integrations.aws.flavors.sagemaker_step_operator_flavor.SagemakerStepOperatorConfig]

The config class.

docs_url: Optional[str] property readonly

A url to point at docs explaining this flavor.

Returns:

Type Description
Optional[str]

A flavor docs url.

implementation_class: Type[SagemakerStepOperator] property readonly

Implementation class.

Returns:

Type Description
Type[SagemakerStepOperator]

The implementation class.

logo_url: str property readonly

A url to represent the flavor in the dashboard.

Returns:

Type Description
str

The flavor logo.

name: str property readonly

Name of the flavor.

Returns:

Type Description
str

The name of the flavor.

sdk_docs_url: Optional[str] property readonly

A url to point at SDK docs explaining this flavor.

Returns:

Type Description
Optional[str]

A flavor SDK docs url.

service_connector_requirements: Optional[zenml.models.v2.misc.service_connector_type.ServiceConnectorRequirements] property readonly

Service connector resource requirements for service connectors.

Specifies resource requirements that are used to filter the available service connector types that are compatible with this flavor.

Returns:

Type Description
Optional[zenml.models.v2.misc.service_connector_type.ServiceConnectorRequirements]

Requirements for compatible service connectors, if a service connector is required for this flavor.

SagemakerStepOperatorSettings (BaseSettings)

Settings for the Sagemaker step operator.

Attributes:

Name Type Description
experiment_name Optional[str]

The name of the experiment with which the job will be associated. If not provided, the job runs are independent of any experiment.

input_data_s3_uri Union[str, Dict[str, str]]

S3 URI where the training data is located if it is not available locally, e.g. s3://my-bucket/my-data/train. How data will be made available to the container is configured with estimator_args.input_mode. Two possible input types: - str: S3 location where training data is saved. - Dict[str, str]: (ChannelName, S3Location) which represent channels (e.g. training, validation, testing) where specific parts of the data are saved in S3.

estimator_args Dict[str, Any]

Arguments that are directly passed to the SageMaker Estimator. See https://sagemaker.readthedocs.io/en/stable/api/training/estimators.html#sagemaker.estimator.Estimator for a full list of arguments. For estimator_args.instance_type, check https://docs.aws.amazon.com/sagemaker/latest/dg/notebooks-available-instance-types.html for a list of available instance types.

Source code in zenml/integrations/aws/flavors/sagemaker_step_operator_flavor.py
class SagemakerStepOperatorSettings(BaseSettings):
    """Settings for the Sagemaker step operator.

    Attributes:
        instance_type: Deprecated attribute; it is registered with
            `deprecation_utils.deprecate_pydantic_attributes` below.
            NOTE(review): `estimator_args.instance_type` looks like the
            intended replacement - confirm.
        experiment_name: The name of the experiment with which the job
            will be associated. If not provided, the job runs are
            independent of any experiment.
        input_data_s3_uri: S3 URI where the training data is located if it
            is not available locally, e.g. s3://my-bucket/my-data/train. How
            data will be made available to the container is configured with
            estimator_args.input_mode. Two possible input types:
                - str: S3 location where training data is saved.
                - Dict[str, str]: (ChannelName, S3Location) which represent
                    channels (e.g. training, validation, testing) where
                    specific parts of the data are saved in S3.
        estimator_args: Arguments that are directly passed to the SageMaker
            Estimator. See
            https://sagemaker.readthedocs.io/en/stable/api/training/estimators.html#sagemaker.estimator.Estimator
            for a full list of arguments.
            For estimator_args.instance_type, check
            https://docs.aws.amazon.com/sagemaker/latest/dg/notebooks-available-instance-types.html
            for a list of available instance types.
    """

    instance_type: Optional[str] = None
    experiment_name: Optional[str] = None
    # union_mode="left_to_right": per the pydantic documentation, the members
    # of the Union are tried in declaration order (str first, then Dict).
    input_data_s3_uri: Optional[Union[str, Dict[str, str]]] = Field(
        default=None, union_mode="left_to_right"
    )
    estimator_args: Dict[str, Any] = {}

    # Registers `instance_type` as a deprecated attribute of this model.
    _deprecation_validator = deprecation_utils.deprecate_pydantic_attributes(
        "instance_type"
    )

orchestrators special

AWS Sagemaker orchestrator.

sagemaker_orchestrator

Implementation of the SageMaker orchestrator.

SagemakerOrchestrator (ContainerizedOrchestrator)

Orchestrator responsible for running pipelines on Sagemaker.

Source code in zenml/integrations/aws/orchestrators/sagemaker_orchestrator.py
class SagemakerOrchestrator(ContainerizedOrchestrator):
    """Orchestrator responsible for running pipelines on Sagemaker."""

    @property
    def config(self) -> SagemakerOrchestratorConfig:
        """The orchestrator config, typed as `SagemakerOrchestratorConfig`.

        Returns:
            The configuration.
        """
        # `cast` only narrows the static type; no runtime conversion happens.
        orchestrator_config = cast(SagemakerOrchestratorConfig, self._config)
        return orchestrator_config

    @property
    def validator(self) -> Optional[StackValidator]:
        """Builds the validator for stacks that use this orchestrator.

        The validator requires a container registry and an image builder in
        the stack and rejects stacks that contain any local component.

        Returns:
            A `StackValidator` instance.
        """

        def _reject_local_components(
            stack: "Stack",
        ) -> Tuple[bool, str]:
            # The first local component found makes the whole stack invalid.
            for component in stack.components.values():
                if component.config.is_local:
                    return False, (
                        f"The Sagemaker orchestrator runs pipelines remotely, "
                        f"but the '{component.name}' {component.type.value} is "
                        "a local stack component and will not be available in "
                        "the Sagemaker step.\nPlease ensure that you always "
                        "use non-local stack components with the Sagemaker "
                        "orchestrator."
                    )

            return True, ""

        required_component_types = {
            StackComponentType.CONTAINER_REGISTRY,
            StackComponentType.IMAGE_BUILDER,
        }
        return StackValidator(
            required_components=required_component_types,
            custom_validation_function=_reject_local_components,
        )

    def get_orchestrator_run_id(self) -> str:
        """Reads the run id of the active orchestrator run.

        Important: The returned value needs to be a unique ID that is the
        same for all steps of a pipeline run.

        Returns:
            The orchestrator run id.

        Raises:
            RuntimeError: If the environment variable that holds the run id
                is not set.
        """
        try:
            orchestrator_run_id = os.environ[ENV_ZENML_SAGEMAKER_RUN_ID]
        except KeyError:
            raise RuntimeError(
                "Unable to read run id from environment variable "
                f"{ENV_ZENML_SAGEMAKER_RUN_ID}."
            )
        return orchestrator_run_id

    @property
    def settings_class(self) -> Optional[Type["BaseSettings"]]:
        """The settings class used by the Sagemaker orchestrator.

        Returns:
            The `SagemakerOrchestratorSettings` class.
        """
        orchestrator_settings_class = SagemakerOrchestratorSettings
        return orchestrator_settings_class

    def prepare_or_run_pipeline(
        self,
        deployment: "PipelineDeploymentResponse",
        stack: "Stack",
        environment: Dict[str, str],
    ) -> None:
        """Prepares or runs a pipeline on Sagemaker.

        Every step of the deployment is turned into a Sagemaker
        `ProcessingStep`; the steps are assembled into a Sagemaker `Pipeline`
        which is created and started. If the orchestrator is configured as
        synchronous, the call blocks until the execution finishes.

        Args:
            deployment: The deployment to prepare or run.
            stack: The stack to run on.
            environment: Environment variables to set in the orchestration
                environment. The dictionary is updated in place (long values
                are split and the Sagemaker run ID variable is added).

        Raises:
            RuntimeError: If a connector is used that does not return a
                `boto3.Session` object, or if waiting for a synchronous
                pipeline execution times out.
            TypeError: If the network_config passed is not compatible with the
                AWS SageMaker NetworkConfig class.
        """
        if deployment.schedule:
            logger.warning(
                "The Sagemaker Orchestrator currently does not support the "
                "use of schedules. The `schedule` will be ignored "
                "and the pipeline will be run immediately."
            )

        # sagemaker requires pipelineName to use alphanum and hyphens only
        unsanitized_orchestrator_run_name = get_orchestrator_run_name(
            pipeline_name=deployment.pipeline_configuration.name
        )
        # replace all non-alphanum and non-hyphens with hyphens
        orchestrator_run_name = re.sub(
            r"[^a-zA-Z0-9\-]", "-", unsanitized_orchestrator_run_name
        )

        # Get authenticated session
        # Option 1: Service connector
        boto_session: boto3.Session
        if connector := self.get_connector():
            boto_session = connector.connect()
            if not isinstance(boto_session, boto3.Session):
                raise RuntimeError(
                    f"Expected to receive a `boto3.Session` object from the "
                    f"linked connector, but got type `{type(boto_session)}`."
                )
        # Option 2: Explicit configuration
        # Args that are not provided will be taken from the default AWS config.
        else:
            boto_session = boto3.Session(
                aws_access_key_id=self.config.aws_access_key_id,
                aws_secret_access_key=self.config.aws_secret_access_key,
                region_name=self.config.region,
                profile_name=self.config.aws_profile,
            )
            # If a role ARN is provided for authentication, assume the role
            if self.config.aws_auth_role_arn:
                sts = boto_session.client("sts")
                response = sts.assume_role(
                    RoleArn=self.config.aws_auth_role_arn,
                    RoleSessionName="zenml-sagemaker-orchestrator",
                )
                credentials = response["Credentials"]
                boto_session = boto3.Session(
                    aws_access_key_id=credentials["AccessKeyId"],
                    aws_secret_access_key=credentials["SecretAccessKey"],
                    aws_session_token=credentials["SessionToken"],
                    region_name=self.config.region,
                )
        session = sagemaker.Session(
            boto_session=boto_session, default_bucket=self.config.bucket
        )

        # Sagemaker does not allow environment variables longer than 256
        # characters to be passed to Processor steps. If an environment variable
        # is longer than 256 characters, we split it into multiple environment
        # variables (chunks) and re-construct it on the other side using the
        # custom entrypoint configuration.
        split_environment_variables(
            size_limit=SAGEMAKER_PROCESSOR_STEP_ENV_VAR_SIZE_LIMIT,
            env=environment,
        )

        # The run ID variable is identical for every step, so it is set once
        # instead of on every loop iteration. It is added after the split
        # above so that it is not processed by it.
        environment[ENV_ZENML_SAGEMAKER_RUN_ID] = (
            ExecutionVariables.PIPELINE_EXECUTION_ARN
        )

        sagemaker_steps = []
        for step_name, step in deployment.step_configurations.items():
            image = self.get_image(deployment=deployment, step_name=step_name)
            command = SagemakerEntrypointConfiguration.get_entrypoint_command()
            arguments = (
                SagemakerEntrypointConfiguration.get_entrypoint_arguments(
                    step_name=step_name, deployment_id=deployment.id
                )
            )
            entrypoint = command + arguments

            step_settings = cast(
                SagemakerOrchestratorSettings, self.get_settings(step)
            )

            # Retrieve Processor arguments provided in the Step settings.
            # Work on a shallow copy: the defaults and forced values written
            # below (including the live session object) must not leak back
            # into the settings object.
            processor_args_for_step = dict(step_settings.processor_args or {})

            # Set default values from configured orchestrator Component to arguments
            # to be used when they are not present in processor_args.
            processor_args_for_step.setdefault(
                "instance_type", step_settings.instance_type
            )
            processor_args_for_step.setdefault(
                "role",
                step_settings.processor_role or self.config.execution_role,
            )
            processor_args_for_step.setdefault(
                "volume_size_in_gb", step_settings.volume_size_in_gb
            )
            processor_args_for_step.setdefault(
                "max_runtime_in_seconds", step_settings.max_runtime_in_seconds
            )
            processor_args_for_step.setdefault(
                "tags",
                [
                    {"Key": key, "Value": value}
                    for key, value in step_settings.processor_tags.items()
                ]
                if step_settings.processor_tags
                else None,
            )

            # Set values that cannot be overwritten
            processor_args_for_step["image_uri"] = image
            processor_args_for_step["instance_count"] = 1
            processor_args_for_step["sagemaker_session"] = session
            processor_args_for_step["entrypoint"] = entrypoint
            processor_args_for_step["base_job_name"] = orchestrator_run_name
            processor_args_for_step["env"] = environment

            # Convert network_config to sagemaker.network.NetworkConfig if present
            network_config = processor_args_for_step.get("network_config")
            if network_config and isinstance(network_config, dict):
                try:
                    processor_args_for_step["network_config"] = NetworkConfig(
                        **network_config
                    )
                except TypeError as e:
                    # If the network_config passed is not compatible with the NetworkConfig class,
                    # raise a more informative error that keeps the original as its cause.
                    raise TypeError(
                        "Expected a sagemaker.network.NetworkConfig compatible object for the network_config argument, "
                        "but the network_config processor argument is invalid. "
                        "See https://sagemaker.readthedocs.io/en/stable/api/utility/network.html#sagemaker.network.NetworkConfig "
                        "for more information about the NetworkConfig class."
                    ) from e

            # Construct S3 inputs to container for step. A plain string maps
            # to a single input; a dict maps each channel to its own
            # sub-directory of the input data directory.
            inputs = None
            input_data_s3_uri = step_settings.input_data_s3_uri
            if isinstance(input_data_s3_uri, str):
                inputs = [
                    ProcessingInput(
                        source=input_data_s3_uri,
                        destination="/opt/ml/processing/input/data",
                        s3_input_mode=step_settings.input_data_s3_mode,
                    )
                ]
            elif isinstance(input_data_s3_uri, dict):
                inputs = [
                    ProcessingInput(
                        source=s3_uri,
                        destination=f"/opt/ml/processing/input/data/{channel}",
                        s3_input_mode=step_settings.input_data_s3_mode,
                    )
                    for channel, s3_uri in input_data_s3_uri.items()
                ]

            # Construct S3 outputs from container for step, mirroring the
            # input handling above.
            outputs = None
            output_data_s3_uri = step_settings.output_data_s3_uri
            if isinstance(output_data_s3_uri, str):
                outputs = [
                    ProcessingOutput(
                        source="/opt/ml/processing/output/data",
                        destination=output_data_s3_uri,
                        s3_upload_mode=step_settings.output_data_s3_mode,
                    )
                ]
            elif isinstance(output_data_s3_uri, dict):
                outputs = [
                    ProcessingOutput(
                        source=f"/opt/ml/processing/output/data/{channel}",
                        destination=s3_uri,
                        s3_upload_mode=step_settings.output_data_s3_mode,
                    )
                    for channel, s3_uri in output_data_s3_uri.items()
                ]

            # Create Processor and ProcessingStep
            processor = sagemaker.processing.Processor(
                **processor_args_for_step
            )
            sagemaker_step = ProcessingStep(
                name=step_name,
                processor=processor,
                depends_on=step.spec.upstream_steps,
                inputs=inputs,
                outputs=outputs,
            )
            sagemaker_steps.append(sagemaker_step)

        # construct the pipeline from the sagemaker_steps
        pipeline = Pipeline(
            name=orchestrator_run_name,
            steps=sagemaker_steps,
            sagemaker_session=session,
        )

        pipeline.create(role_arn=self.config.execution_role)
        pipeline_execution = pipeline.start()
        logger.warning(
            "Steps can take 5-15 minutes to start running "
            "when using the Sagemaker Orchestrator."
        )

        # mainly for testing purposes, we wait for the pipeline to finish
        if self.config.synchronous:
            logger.info(
                "Executing synchronously. Waiting for pipeline to finish... \n"
                "At this point you can `Ctrl-C` out without cancelling the execution."
            )
            try:
                pipeline_execution.wait(
                    delay=POLLING_DELAY, max_attempts=MAX_POLLING_ATTEMPTS
                )
            except WaiterError as e:
                raise RuntimeError(
                    "Timed out while waiting for pipeline execution to finish. For long-running "
                    "pipelines we recommend configuring your orchestrator for asynchronous execution. "
                    "The following command does this for you: \n"
                    f"`zenml orchestrator update {self.name} --synchronous=False`"
                ) from e
            else:
                logger.info("Pipeline completed successfully.")

    def _get_region_name(self) -> str:
        """Looks up the AWS region name of a default Sagemaker session.

        Returns:
            The region name.

        Raises:
            RuntimeError: If creating the session or reading its region
                fails for any reason.
        """
        try:
            default_session = sagemaker.Session()
            region = default_session.boto_region_name
        except Exception as e:
            raise RuntimeError(
                "Unable to get region name. Please ensure that you have "
                "configured your AWS credentials correctly."
            ) from e
        # `cast` only narrows the static type and cannot fail at runtime.
        return cast(str, region)

    def get_pipeline_run_metadata(
        self, run_id: UUID
    ) -> Dict[str, "MetadataType"]:
        """Collects orchestrator-specific metadata for a pipeline run.

        The pipeline execution ARN read from the environment is always
        reported. A CloudWatch logs URL is added when the region name can be
        determined.

        Args:
            run_id: The ID of the pipeline run.

        Returns:
            A dictionary of metadata.
        """
        execution_arn = os.environ[ENV_ZENML_SAGEMAKER_RUN_ID]
        metadata: Dict[str, "MetadataType"] = {
            "pipeline_execution_arn": execution_arn,
        }

        try:
            region_name = self._get_region_name()
        except RuntimeError:
            logger.warning("Unable to get region name from AWS Sagemaker.")
        else:
            # The last segment of the ARN is used to filter the log streams.
            aws_run_id = execution_arn.split("/")[-1]
            metadata[METADATA_ORCHESTRATOR_URL] = Uri(
                f"https://{region_name}.console.aws.amazon.com/"
                f"cloudwatch/home?region={region_name}#logsV2:log-groups/log-group"
                f"/$252Faws$252Fsagemaker$252FProcessingJobs$3FlogStreamNameFilter"
                f"$3Dpipelines-{aws_run_id}-"
            )

        return metadata
config: SagemakerOrchestratorConfig property readonly

Returns the SagemakerOrchestratorConfig config.

Returns:

Type Description
SagemakerOrchestratorConfig

The configuration.

settings_class: Optional[Type[BaseSettings]] property readonly

Settings class for the Sagemaker orchestrator.

Returns:

Type Description
Optional[Type[BaseSettings]]

The settings class.

validator: Optional[zenml.stack.stack_validator.StackValidator] property readonly

Validates the stack.

In the remote case, checks that the stack contains a container registry, image builder and only remote components.

Returns:

Type Description
Optional[zenml.stack.stack_validator.StackValidator]

A StackValidator instance.

get_orchestrator_run_id(self)

Returns the run id of the active orchestrator run.

Important: This needs to be a unique ID and return the same value for all steps of a pipeline run.

Returns:

Type Description
str

The orchestrator run id.

Exceptions:

Type Description
RuntimeError

If the run id cannot be read from the environment.

Source code in zenml/integrations/aws/orchestrators/sagemaker_orchestrator.py
def get_orchestrator_run_id(self) -> str:
    """Reads the run id of the active orchestrator run.

    Important: The returned value needs to be a unique ID that is the
    same for all steps of a pipeline run.

    Returns:
        The orchestrator run id.

    Raises:
        RuntimeError: If the environment variable that holds the run id
            is not set.
    """
    try:
        orchestrator_run_id = os.environ[ENV_ZENML_SAGEMAKER_RUN_ID]
    except KeyError:
        raise RuntimeError(
            "Unable to read run id from environment variable "
            f"{ENV_ZENML_SAGEMAKER_RUN_ID}."
        )
    return orchestrator_run_id
get_pipeline_run_metadata(self, run_id)

Get general component-specific metadata for a pipeline run.

Parameters:

Name Type Description Default
run_id UUID

The ID of the pipeline run.

required

Returns:

Type Description
Dict[str, MetadataType]

A dictionary of metadata.

Source code in zenml/integrations/aws/orchestrators/sagemaker_orchestrator.py
def get_pipeline_run_metadata(
    self, run_id: UUID
) -> Dict[str, "MetadataType"]:
    """Collects orchestrator-specific metadata for a pipeline run.

    The pipeline execution ARN read from the environment is always
    reported. A CloudWatch logs URL is added when the region name can be
    determined.

    Args:
        run_id: The ID of the pipeline run.

    Returns:
        A dictionary of metadata.
    """
    execution_arn = os.environ[ENV_ZENML_SAGEMAKER_RUN_ID]
    metadata: Dict[str, "MetadataType"] = {
        "pipeline_execution_arn": execution_arn,
    }

    try:
        region_name = self._get_region_name()
    except RuntimeError:
        logger.warning("Unable to get region name from AWS Sagemaker.")
    else:
        # The last segment of the ARN is used to filter the log streams.
        aws_run_id = execution_arn.split("/")[-1]
        metadata[METADATA_ORCHESTRATOR_URL] = Uri(
            f"https://{region_name}.console.aws.amazon.com/"
            f"cloudwatch/home?region={region_name}#logsV2:log-groups/log-group"
            f"/$252Faws$252Fsagemaker$252FProcessingJobs$3FlogStreamNameFilter"
            f"$3Dpipelines-{aws_run_id}-"
        )

    return metadata
prepare_or_run_pipeline(self, deployment, stack, environment)

Prepares or runs a pipeline on Sagemaker.

Parameters:

Name Type Description Default
deployment PipelineDeploymentResponse

The deployment to prepare or run.

required
stack Stack

The stack to run on.

required
environment Dict[str, str]

Environment variables to set in the orchestration environment.

required

Exceptions:

Type Description
RuntimeError

If a connector is used that does not return a boto3.Session object.

TypeError

If the network_config passed is not compatible with the AWS SageMaker NetworkConfig class.

Source code in zenml/integrations/aws/orchestrators/sagemaker_orchestrator.py
def prepare_or_run_pipeline(
    self,
    deployment: "PipelineDeploymentResponse",
    stack: "Stack",
    environment: Dict[str, str],
) -> None:
    """Prepares or runs a pipeline on Sagemaker.

    One SageMaker `ProcessingStep` is built per ZenML step. The steps are
    assembled into a SageMaker `Pipeline`, which is created and started.
    When the orchestrator is configured as synchronous, this call blocks
    until the pipeline execution finishes or polling times out.

    Note that `environment` is modified in place: oversized values are split
    into chunks and the pipeline execution ARN placeholder is added.

    Args:
        deployment: The deployment to prepare or run.
        stack: The stack to run on.
        environment: Environment variables to set in the orchestration
            environment.

    Raises:
        RuntimeError: If a connector is used that does not return a
            `boto3.Session` object, or if waiting for a synchronous
            pipeline execution times out.
        TypeError: If the network_config passed is not compatible with the
            AWS SageMaker NetworkConfig class.
    """
    if deployment.schedule:
        logger.warning(
            "The Sagemaker Orchestrator currently does not support the "
            "use of schedules. The `schedule` will be ignored "
            "and the pipeline will be run immediately."
        )

    # sagemaker requires pipelineName to use alphanum and hyphens only
    unsanitized_orchestrator_run_name = get_orchestrator_run_name(
        pipeline_name=deployment.pipeline_configuration.name
    )
    # replace all non-alphanum and non-hyphens with hyphens
    orchestrator_run_name = re.sub(
        r"[^a-zA-Z0-9\-]", "-", unsanitized_orchestrator_run_name
    )

    # Get authenticated session
    # Option 1: Service connector
    boto_session: boto3.Session
    if connector := self.get_connector():
        boto_session = connector.connect()
        if not isinstance(boto_session, boto3.Session):
            raise RuntimeError(
                f"Expected to receive a `boto3.Session` object from the "
                f"linked connector, but got type `{type(boto_session)}`."
            )
    # Option 2: Explicit configuration
    # Args that are not provided will be taken from the default AWS config.
    else:
        boto_session = boto3.Session(
            aws_access_key_id=self.config.aws_access_key_id,
            aws_secret_access_key=self.config.aws_secret_access_key,
            region_name=self.config.region,
            profile_name=self.config.aws_profile,
        )
        # If a role ARN is provided for authentication, assume the role
        if self.config.aws_auth_role_arn:
            sts = boto_session.client("sts")
            response = sts.assume_role(
                RoleArn=self.config.aws_auth_role_arn,
                RoleSessionName="zenml-sagemaker-orchestrator",
            )
            credentials = response["Credentials"]
            boto_session = boto3.Session(
                aws_access_key_id=credentials["AccessKeyId"],
                aws_secret_access_key=credentials["SecretAccessKey"],
                aws_session_token=credentials["SessionToken"],
                region_name=self.config.region,
            )
    session = sagemaker.Session(
        boto_session=boto_session, default_bucket=self.config.bucket
    )

    # Sagemaker does not allow environment variables longer than 256
    # characters to be passed to Processor steps. If an environment variable
    # is longer than 256 characters, we split it into multiple environment
    # variables (chunks) and re-construct it on the other side using the
    # custom entrypoint configuration.
    split_environment_variables(
        size_limit=SAGEMAKER_PROCESSOR_STEP_ENV_VAR_SIZE_LIMIT,
        env=environment,
    )

    # The execution ARN placeholder is the same for every step, so it is set
    # once here instead of on every loop iteration.
    environment[ENV_ZENML_SAGEMAKER_RUN_ID] = (
        ExecutionVariables.PIPELINE_EXECUTION_ARN
    )

    sagemaker_steps = []
    for step_name, step in deployment.step_configurations.items():
        image = self.get_image(deployment=deployment, step_name=step_name)
        command = SagemakerEntrypointConfiguration.get_entrypoint_command()
        arguments = (
            SagemakerEntrypointConfiguration.get_entrypoint_arguments(
                step_name=step_name, deployment_id=deployment.id
            )
        )
        entrypoint = command + arguments

        step_settings = cast(
            SagemakerOrchestratorSettings, self.get_settings(step)
        )

        # Retrieve Processor arguments provided in the Step settings. Work on
        # a copy so that the defaults and forced values below (including the
        # SageMaker session object) are not written back into the settings.
        processor_args_for_step = dict(step_settings.processor_args or {})

        # Set default values from configured orchestrator Component to arguments
        # to be used when they are not present in processor_args.
        processor_args_for_step.setdefault(
            "instance_type", step_settings.instance_type
        )
        processor_args_for_step.setdefault(
            "role",
            step_settings.processor_role or self.config.execution_role,
        )
        processor_args_for_step.setdefault(
            "volume_size_in_gb", step_settings.volume_size_in_gb
        )
        processor_args_for_step.setdefault(
            "max_runtime_in_seconds", step_settings.max_runtime_in_seconds
        )
        processor_args_for_step.setdefault(
            "tags",
            [
                {"Key": key, "Value": value}
                for key, value in step_settings.processor_tags.items()
            ]
            if step_settings.processor_tags
            else None,
        )

        # Set values that cannot be overwritten
        processor_args_for_step["image_uri"] = image
        processor_args_for_step["instance_count"] = 1
        processor_args_for_step["sagemaker_session"] = session
        processor_args_for_step["entrypoint"] = entrypoint
        processor_args_for_step["base_job_name"] = orchestrator_run_name
        processor_args_for_step["env"] = environment

        # Convert network_config to sagemaker.network.NetworkConfig if present
        network_config = processor_args_for_step.get("network_config")
        if network_config and isinstance(network_config, dict):
            try:
                processor_args_for_step["network_config"] = NetworkConfig(
                    **network_config
                )
            except TypeError as e:
                # If the network_config passed is not compatible with the
                # NetworkConfig class, raise a more informative error and
                # keep the original error attached as the cause.
                raise TypeError(
                    "Expected a sagemaker.network.NetworkConfig compatible object for the network_config argument, "
                    "but the network_config processor argument is invalid. "
                    "See https://sagemaker.readthedocs.io/en/stable/api/utility/network.html#sagemaker.network.NetworkConfig "
                    "for more information about the NetworkConfig class."
                ) from e

        # Construct S3 inputs to container for step. A plain string maps to a
        # single input; a dict maps each channel name to its own sub-folder.
        inputs = None
        input_data_s3_uri = step_settings.input_data_s3_uri
        if isinstance(input_data_s3_uri, str):
            inputs = [
                ProcessingInput(
                    source=input_data_s3_uri,
                    destination="/opt/ml/processing/input/data",
                    s3_input_mode=step_settings.input_data_s3_mode,
                )
            ]
        elif isinstance(input_data_s3_uri, dict):
            inputs = [
                ProcessingInput(
                    source=s3_uri,
                    destination=f"/opt/ml/processing/input/data/{channel}",
                    s3_input_mode=step_settings.input_data_s3_mode,
                )
                for channel, s3_uri in input_data_s3_uri.items()
            ]

        # Construct S3 outputs from container for step, mirroring the input
        # handling above.
        outputs = None
        output_data_s3_uri = step_settings.output_data_s3_uri
        if isinstance(output_data_s3_uri, str):
            outputs = [
                ProcessingOutput(
                    source="/opt/ml/processing/output/data",
                    destination=output_data_s3_uri,
                    s3_upload_mode=step_settings.output_data_s3_mode,
                )
            ]
        elif isinstance(output_data_s3_uri, dict):
            outputs = [
                ProcessingOutput(
                    source=f"/opt/ml/processing/output/data/{channel}",
                    destination=s3_uri,
                    s3_upload_mode=step_settings.output_data_s3_mode,
                )
                for channel, s3_uri in output_data_s3_uri.items()
            ]

        # Create Processor and ProcessingStep
        processor = sagemaker.processing.Processor(
            **processor_args_for_step
        )
        sagemaker_step = ProcessingStep(
            name=step_name,
            processor=processor,
            depends_on=step.spec.upstream_steps,
            inputs=inputs,
            outputs=outputs,
        )
        sagemaker_steps.append(sagemaker_step)

    # construct the pipeline from the sagemaker_steps
    pipeline = Pipeline(
        name=orchestrator_run_name,
        steps=sagemaker_steps,
        sagemaker_session=session,
    )

    pipeline.create(role_arn=self.config.execution_role)
    pipeline_execution = pipeline.start()
    logger.warning(
        "Steps can take 5-15 minutes to start running "
        "when using the Sagemaker Orchestrator."
    )

    # mainly for testing purposes, we wait for the pipeline to finish
    if self.config.synchronous:
        logger.info(
            "Executing synchronously. Waiting for pipeline to finish... \n"
            "At this point you can `Ctrl-C` out without cancelling the execution."
        )
        try:
            pipeline_execution.wait(
                delay=POLLING_DELAY, max_attempts=MAX_POLLING_ATTEMPTS
            )
            logger.info("Pipeline completed successfully.")
        except WaiterError as e:
            raise RuntimeError(
                "Timed out while waiting for pipeline execution to finish. For long-running "
                "pipelines we recommend configuring your orchestrator for asynchronous execution. "
                "The following command does this for you: \n"
                f"`zenml orchestrator update {self.name} --synchronous=False`"
            ) from e

sagemaker_orchestrator_entrypoint_config

Entrypoint configuration for ZenML Sagemaker pipeline steps.

SagemakerEntrypointConfiguration (StepEntrypointConfiguration)

Entrypoint configuration for ZenML Sagemaker pipeline steps.

The only purpose of this entrypoint configuration is to reconstruct, from their individual chunks, the environment variables that exceed the maximum length of 256 characters allowed for Sagemaker Processor steps.

Source code in zenml/integrations/aws/orchestrators/sagemaker_orchestrator_entrypoint_config.py
class SagemakerEntrypointConfiguration(StepEntrypointConfiguration):
    """Entrypoint configuration for ZenML Sagemaker pipeline steps.

    The only purpose of this entrypoint configuration is to reconstruct,
    from their individual chunks, the environment variables that exceed the
    maximum length of 256 characters allowed for Sagemaker Processor steps.
    Everything else is delegated to the parent entrypoint configuration.
    """

    def run(self) -> None:
        """Rebuilds chunked environment variables, then runs the step."""
        # Reconstruct the environment variables that exceed the maximum length
        # of 256 characters from their individual chunks. This has to happen
        # before the parent implementation runs the step.
        reconstruct_environment_variables()

        # Run the step
        super().run()
run(self)

Runs the step.

Source code in zenml/integrations/aws/orchestrators/sagemaker_orchestrator_entrypoint_config.py
def run(self) -> None:
    """Rebuilds chunked environment variables, then runs the step."""
    # Reconstruct the environment variables that exceed the maximum length
    # of 256 characters from their individual chunks. This has to happen
    # before the parent implementation runs the step.
    reconstruct_environment_variables()

    # Run the step
    super().run()

service_connectors special

AWS Service Connector.

aws_service_connector

AWS Service Connector.

The AWS Service Connector implements various authentication methods for AWS services:

  • Explicit AWS secret key (access key, secret key)
  • Explicit AWS STS tokens (access key, secret key, session token)
  • IAM roles (i.e. generating temporary STS tokens on the fly by assuming an IAM role)
  • IAM user federation tokens
  • STS Session tokens
AWSAuthenticationMethods (StrEnum)

AWS Authentication methods.

Source code in zenml/integrations/aws/service_connectors/aws_service_connector.py
class AWSAuthenticationMethods(StrEnum):
    """AWS Authentication methods.

    Each member's value is the string identifier of one authentication
    method supported by the AWS service connector.
    """

    # Credentials picked up by boto3's default credentials provider.
    IMPLICIT = "implicit"
    # Long-term AWS access key ID + secret access key.
    SECRET_KEY = "secret-key"
    # Pre-generated temporary credentials (key pair + session token).
    STS_TOKEN = "sts-token"
    # Temporary credentials obtained by assuming an IAM role via STS.
    IAM_ROLE = "iam-role"
    # Temporary credentials obtained through the STS session token call.
    SESSION_TOKEN = "session-token"
    # Temporary credentials obtained through the STS federation token call.
    FEDERATION_TOKEN = "federation-token"
AWSBaseConfig (AuthenticationConfig)

AWS base configuration.

Source code in zenml/integrations/aws/service_connectors/aws_service_connector.py
class AWSBaseConfig(AuthenticationConfig):
    """AWS base configuration.

    Holds the settings shared by the AWS authentication configurations: the
    region to operate in and an optional custom endpoint URL.
    """

    # Required: no default is declared, so a region must always be supplied.
    region: str = Field(
        title="AWS Region",
    )
    # Optional custom endpoint; `None` leaves endpoint selection to boto3.
    endpoint_url: Optional[str] = Field(
        default=None,
        title="AWS Endpoint URL",
    )
AWSImplicitConfig (AWSBaseConfig, AWSSessionPolicy)

AWS implicit configuration.

Source code in zenml/integrations/aws/service_connectors/aws_service_connector.py
class AWSImplicitConfig(AWSBaseConfig, AWSSessionPolicy):
    """AWS implicit configuration.

    Configuration for the implicit authentication method, where credentials
    come from boto3's default credentials provider rather than from the
    connector configuration itself.
    """

    # Named AWS profile handed to `boto3.Session`; `None` uses the default.
    profile_name: Optional[str] = Field(
        default=None,
        title="AWS Profile Name",
    )
    # When set, this IAM role is assumed on top of the implicit credentials.
    role_arn: Optional[str] = Field(
        default=None,
        title="Optional AWS IAM Role ARN to assume",
    )
AWSSecretKey (AuthenticationConfig)

AWS secret key credentials.

Source code in zenml/integrations/aws/service_connectors/aws_service_connector.py
class AWSSecretKey(AuthenticationConfig):
    """AWS secret key credentials.

    A long-term AWS key pair. Both fields are required and are declared as
    secret string types; the raw value is read with `get_secret_value()`.
    """

    # Public half of the key pair.
    aws_access_key_id: PlainSerializedSecretStr = Field(
        title="AWS Access Key ID",
        description="An AWS access key ID associated with an AWS account or IAM user.",
    )
    # Private half of the key pair.
    aws_secret_access_key: PlainSerializedSecretStr = Field(
        title="AWS Secret Access Key",
    )
AWSSecretKeyConfig (AWSBaseConfig, AWSSecretKey)

AWS secret key authentication configuration.

Source code in zenml/integrations/aws/service_connectors/aws_service_connector.py
class AWSSecretKeyConfig(AWSBaseConfig, AWSSecretKey):
    """AWS secret key authentication configuration.

    Combines the base AWS settings (region, endpoint URL) with a long-term
    AWS key pair. It declares no fields of its own.
    """
AWSServiceConnector (ServiceConnector)

AWS service connector.

Source code in zenml/integrations/aws/service_connectors/aws_service_connector.py
class AWSServiceConnector(ServiceConnector):
    """AWS service connector."""

    config: AWSBaseConfig

    _account_id: Optional[str] = None
    _session_cache: Dict[
        Tuple[str, Optional[str], Optional[str]],
        Tuple[boto3.Session, Optional[datetime.datetime]],
    ] = {}

    @classmethod
    def _get_connector_type(cls) -> ServiceConnectorTypeModel:
        """Get the service connector type specification.

        Returns:
            The service connector type specification, i.e. the module-level
            `AWS_SERVICE_CONNECTOR_TYPE_SPEC` object.
        """
        return AWS_SERVICE_CONNECTOR_TYPE_SPEC

    @property
    def account_id(self) -> str:
        """Return the AWS account ID, resolving it on first access.

        The ID is fetched once through the STS `get_caller_identity` call and
        then kept on the instance, so later accesses do not contact AWS.

        Returns:
            The AWS account ID.

        Raises:
            AuthorizationException: If the AWS account ID could not be
                determined.
        """
        # Fast path: the ID was already resolved by an earlier access.
        if self._account_id is not None:
            return self._account_id

        logger.debug("Getting account ID from AWS...")
        try:
            boto_session, _ = self.get_boto3_session(self.auth_method)
            caller_identity = boto_session.client("sts").get_caller_identity()
        except (ClientError, BotoCoreError) as err:
            raise AuthorizationException(
                f"Failed to fetch the AWS account ID: {err}"
            ) from err

        self._account_id = caller_identity["Account"]
        return self._account_id

    def get_boto3_session(
        self,
        auth_method: str,
        resource_type: Optional[str] = None,
        resource_id: Optional[str] = None,
    ) -> Tuple[boto3.Session, Optional[datetime.datetime]]:
        """Return a (possibly cached) boto3 session for a resource.

        Sessions are cached per `(auth_method, resource_type, resource_id)`
        combination. A cached session is reused when it has no expiration
        time or when it is still valid for longer than
        `BOTO3_SESSION_EXPIRATION_BUFFER` minutes; otherwise a new session is
        created through `_authenticate` and replaces the cached entry.

        Args:
            auth_method: The authentication method to use.
            resource_type: The resource type to get a boto3 session for.
            resource_id: The resource ID to get a boto3 session for.

        Returns:
            A boto3 session for the specified resource and its expiration
            timestamp, if applicable.
        """
        utc = datetime.timezone.utc
        cache_key = (auth_method, resource_type, resource_id)

        cached_entry = self._session_cache.get(cache_key)
        if cached_entry is not None:
            cached_session, cached_expiry = cached_entry
            # Sessions without an expiration time never need a refresh.
            if cached_expiry is None:
                return cached_session, None

            # Only reuse the session if it does not expire in the near
            # future; otherwise fall through and re-authenticate.
            cached_expiry = cached_expiry.replace(tzinfo=utc)
            refresh_threshold = datetime.datetime.now(utc) + datetime.timedelta(
                minutes=BOTO3_SESSION_EXPIRATION_BUFFER
            )
            if cached_expiry > refresh_threshold:
                return cached_session, cached_expiry

        logger.debug(
            f"Creating boto3 session for auth method '{auth_method}', "
            f"resource type '{resource_type}' and resource ID "
            f"'{resource_id}'..."
        )
        fresh_entry = self._authenticate(auth_method, resource_type, resource_id)
        new_session, new_expiry = fresh_entry
        self._session_cache[cache_key] = (new_session, new_expiry)
        return new_session, new_expiry

    def get_ecr_client(self) -> BaseClient:
        """Build an ECR client for the configured region.

        The client is created from a boto3 session obtained for the Docker
        registry resource type, using the configured region as resource ID.

        Raises:
            ValueError: If the service connector is not able to instantiate an
                ECR client.

        Returns:
            An ECR client.
        """
        # A connector with no resource type set, or one set to the generic
        # AWS or Docker registry types, may hand out ECR clients.
        ecr_capable_types = (AWS_RESOURCE_TYPE, DOCKER_REGISTRY_RESOURCE_TYPE)
        configured_type = self.resource_type
        if configured_type and configured_type not in ecr_capable_types:
            raise ValueError(
                f"Unable to instantiate ECR client for a connector that is "
                f"configured to provide access to a '{configured_type}' "
                "resource type."
            )

        region = self.config.region
        boto_session, _ = self.get_boto3_session(
            auth_method=self.auth_method,
            resource_type=DOCKER_REGISTRY_RESOURCE_TYPE,
            resource_id=region,
        )
        ecr_client = boto_session.client(
            "ecr",
            region_name=region,
            endpoint_url=self.config.endpoint_url,
        )
        return ecr_client

    def _get_iam_policy(
        self,
        region_id: str,
        resource_type: Optional[str],
        resource_id: Optional[str] = None,
    ) -> Optional[str]:
        """Build the IAM inline policy document for a resource.

        A policy is generated for the S3, Kubernetes cluster and Docker
        registry resource types. When a resource ID is given for S3 or
        Kubernetes, the policy is scoped to that bucket or cluster; otherwise
        it covers all of them. Any other resource type yields no policy.

        Args:
            region_id: The AWS region ID to get the IAM inline policy for.
            resource_type: The resource type to get the IAM inline policy for.
            resource_id: The resource ID to get the IAM inline policy for.

        Returns:
            The IAM inline policy to use for the specified resource, as a
            JSON string, or `None` if no policy applies.
        """
        if resource_type == S3_RESOURCE_TYPE:
            # "*" in place of the bucket name matches every bucket.
            bucket = (
                self._parse_s3_resource_id(resource_id) if resource_id else "*"
            )
            statements = [
                {
                    "Sid": "AllowS3BucketAccess",
                    "Effect": "Allow",
                    "Action": [
                        "s3:ListBucket",
                        "s3:GetObject",
                        "s3:PutObject",
                        "s3:DeleteObject",
                        "s3:ListAllMyBuckets",
                    ],
                    "Resource": [
                        f"arn:aws:s3:::{bucket}",
                        f"arn:aws:s3:::{bucket}/*",
                    ],
                },
            ]
        elif resource_type == KUBERNETES_CLUSTER_RESOURCE_TYPE:
            # "*" in place of the cluster name matches every cluster.
            cluster = (
                self._parse_eks_resource_id(resource_id)
                if resource_id
                else "*"
            )
            statements = [
                {
                    "Sid": "AllowEKSClusterAccess",
                    "Effect": "Allow",
                    "Action": [
                        "eks:ListClusters",
                        "eks:DescribeCluster",
                    ],
                    "Resource": [
                        f"arn:aws:eks:{region_id}:*:cluster/{cluster}",
                    ],
                },
            ]
        elif resource_type == DOCKER_REGISTRY_RESOURCE_TYPE:
            # The resource ID is not used here: the policy always covers
            # every repository in the region.
            statements = [
                {
                    "Sid": "AllowECRRepositoryAccess",
                    "Effect": "Allow",
                    "Action": [
                        "ecr:DescribeRegistry",
                        "ecr:DescribeRepositories",
                        "ecr:ListRepositories",
                        "ecr:BatchGetImage",
                        "ecr:DescribeImages",
                        "ecr:BatchCheckLayerAvailability",
                        "ecr:GetDownloadUrlForLayer",
                        "ecr:InitiateLayerUpload",
                        "ecr:UploadLayerPart",
                        "ecr:CompleteLayerUpload",
                        "ecr:PutImage",
                    ],
                    "Resource": [
                        f"arn:aws:ecr:{region_id}:*:repository/*",
                        f"arn:aws:ecr:{region_id}:*:repository",
                    ],
                },
                {
                    "Sid": "AllowECRRepositoryGetToken",
                    "Effect": "Allow",
                    "Action": [
                        "ecr:GetAuthorizationToken",
                    ],
                    "Resource": ["*"],
                },
            ]
        else:
            return None

        # All policies share the same envelope around their statements.
        return json.dumps({"Version": "2012-10-17", "Statement": statements})

    def _authenticate(
        self,
        auth_method: str,
        resource_type: Optional[str] = None,
        resource_id: Optional[str] = None,
    ) -> Tuple[boto3.Session, Optional[datetime.datetime]]:
        """Authenticate to AWS and return a boto3 session.

        What happens depends on the authentication method:

        * implicit: uses boto3's default credentials provider and, if a role
          ARN is configured, assumes that role through STS.
        * secret key: builds a session from the configured long-term key pair.
        * STS token: builds a session from the configured temporary
          credentials and reports the connector's own expiration time.
        * IAM role / session token / federation token: uses the configured
          key pair to request temporary credentials from STS and builds the
          session from those.

        Args:
            auth_method: The authentication method to use.
            resource_type: The resource type to authenticate for.
            resource_id: The resource ID to authenticate for.

        Returns:
            An authenticated boto3 session and the expiration time of the
            temporary credentials if applicable.

        Raises:
            AuthorizationException: If no credentials can be obtained from
                the default provider, if the IAM role cannot be assumed, or
                if requesting a federation or session token fails.
            NotImplementedError: If the authentication method is not supported.
        """
        cfg = self.config
        # Extra keyword arguments (session policy) for the STS calls below.
        policy_kwargs: Dict[str, Any] = {}

        if auth_method == AWSAuthenticationMethods.IMPLICIT:
            self._check_implicit_auth_method_allowed()

            assert isinstance(cfg, AWSImplicitConfig)
            # Create a boto3 session and use the default credentials provider
            session = boto3.Session(
                profile_name=cfg.profile_name, region_name=cfg.region
            )

            if cfg.role_arn:
                # If an IAM role is configured, assume it
                # An explicitly configured inline policy wins. A policy is
                # generated from the resource only when neither an inline
                # policy nor policy ARNs are configured.
                policy = cfg.policy
                if not cfg.policy and not cfg.policy_arns:
                    policy = self._get_iam_policy(
                        region_id=cfg.region,
                        resource_type=resource_type,
                        resource_id=resource_id,
                    )
                if policy:
                    policy_kwargs["Policy"] = policy
                elif cfg.policy_arns:
                    policy_kwargs["PolicyArns"] = cfg.policy_arns

                sts = session.client("sts", region_name=cfg.region)
                session_name = "zenml-connector"
                if self.id:
                    session_name += f"-{self.id}"

                try:
                    response = sts.assume_role(
                        RoleArn=cfg.role_arn,
                        RoleSessionName=session_name,
                        DurationSeconds=self.expiration_seconds,
                        **policy_kwargs,
                    )
                except (ClientError, BotoCoreError) as e:
                    raise AuthorizationException(
                        f"Failed to assume IAM role {cfg.role_arn} "
                        f"using the implicit AWS credentials: {e}"
                    ) from e

                # NOTE(review): this session is created without a region
                # name, unlike the initial session above.
                session = boto3.Session(
                    aws_access_key_id=response["Credentials"]["AccessKeyId"],
                    aws_secret_access_key=response["Credentials"][
                        "SecretAccessKey"
                    ],
                    aws_session_token=response["Credentials"]["SessionToken"],
                )
                expiration = response["Credentials"]["Expiration"]
                # Add the UTC timezone to the expiration time
                expiration = expiration.replace(tzinfo=datetime.timezone.utc)
                return session, expiration

            credentials = session.get_credentials()
            if not credentials:
                raise AuthorizationException(
                    "Failed to get AWS credentials from the default provider. "
                    "Please check your AWS configuration or attached IAM role."
                )
            if credentials.token:
                # Temporary credentials were generated. It's not possible to
                # determine the expiration time of the temporary credentials
                # from the boto3 session, so we assume the default IAM role
                # expiration date is used
                expiration_time = datetime.datetime.now(
                    tz=datetime.timezone.utc
                ) + datetime.timedelta(
                    seconds=DEFAULT_IAM_ROLE_TOKEN_EXPIRATION
                )
                return session, expiration_time

            # Long-term credentials: no expiration to report.
            return session, None
        elif auth_method == AWSAuthenticationMethods.SECRET_KEY:
            assert isinstance(cfg, AWSSecretKeyConfig)
            # Create a boto3 session using long-term AWS credentials
            session = boto3.Session(
                aws_access_key_id=cfg.aws_access_key_id.get_secret_value(),
                aws_secret_access_key=cfg.aws_secret_access_key.get_secret_value(),
                region_name=cfg.region,
            )
            return session, None
        elif auth_method == AWSAuthenticationMethods.STS_TOKEN:
            assert isinstance(cfg, STSTokenConfig)
            # Create a boto3 session using a temporary AWS STS token
            session = boto3.Session(
                aws_access_key_id=cfg.aws_access_key_id.get_secret_value(),
                aws_secret_access_key=cfg.aws_secret_access_key.get_secret_value(),
                aws_session_token=cfg.aws_session_token.get_secret_value(),
                region_name=cfg.region,
            )
            # The token was issued elsewhere, so the expiration time comes
            # from the connector itself rather than from an STS response.
            return session, self.expires_at
        elif auth_method in [
            AWSAuthenticationMethods.IAM_ROLE,
            AWSAuthenticationMethods.SESSION_TOKEN,
            AWSAuthenticationMethods.FEDERATION_TOKEN,
        ]:
            assert isinstance(cfg, AWSSecretKey)

            # Create a boto3 session from the configured long-term key pair;
            # it is only used to talk to STS.
            session = boto3.Session(
                aws_access_key_id=cfg.aws_access_key_id.get_secret_value(),
                aws_secret_access_key=cfg.aws_secret_access_key.get_secret_value(),
                region_name=cfg.region,
            )

            sts = session.client("sts", region_name=cfg.region)
            session_name = "zenml-connector"
            if self.id:
                session_name += f"-{self.id}"

            # Next steps are different for each authentication method

            # The IAM role and federation token authentication methods
            # accept a managed IAM policy that restricts/grants permissions.
            # If one isn't explicitly configured, we generate one based on the
            # resource specified by the resource type and ID (if present).
            if auth_method in [
                AWSAuthenticationMethods.IAM_ROLE,
                AWSAuthenticationMethods.FEDERATION_TOKEN,
            ]:
                assert isinstance(cfg, AWSSessionPolicy)
                # Same precedence as in the implicit branch: inline policy,
                # then generated policy, then policy ARNs.
                policy = cfg.policy
                if not cfg.policy and not cfg.policy_arns:
                    policy = self._get_iam_policy(
                        region_id=cfg.region,
                        resource_type=resource_type,
                        resource_id=resource_id,
                    )
                if policy:
                    policy_kwargs["Policy"] = policy
                elif cfg.policy_arns:
                    policy_kwargs["PolicyArns"] = cfg.policy_arns

                if auth_method == AWSAuthenticationMethods.IAM_ROLE:
                    assert isinstance(cfg, IAMRoleAuthenticationConfig)

                    try:
                        response = sts.assume_role(
                            RoleArn=cfg.role_arn,
                            RoleSessionName=session_name,
                            DurationSeconds=self.expiration_seconds,
                            **policy_kwargs,
                        )
                    except (ClientError, BotoCoreError) as e:
                        raise AuthorizationException(
                            f"Failed to assume IAM role {cfg.role_arn} "
                            f"using the AWS credentials configured in the "
                            f"connector: {e}"
                        ) from e

                else:
                    assert isinstance(cfg, FederationTokenAuthenticationConfig)

                    try:
                        # The name is cut to 32 characters — presumably to
                        # stay within the STS limit for federated user names;
                        # TODO confirm against the STS documentation.
                        response = sts.get_federation_token(
                            Name=session_name[:32],
                            DurationSeconds=self.expiration_seconds,
                            **policy_kwargs,
                        )
                    except (ClientError, BotoCoreError) as e:
                        raise AuthorizationException(
                            "Failed to get federation token "
                            "using the AWS credentials configured in the "
                            f"connector: {e}"
                        ) from e

            else:
                # Session tokens are requested without any session policy.
                assert isinstance(cfg, SessionTokenAuthenticationConfig)
                try:
                    response = sts.get_session_token(
                        DurationSeconds=self.expiration_seconds,
                    )
                except (ClientError, BotoCoreError) as e:
                    raise AuthorizationException(
                        "Failed to get session token "
                        "using the AWS credentials configured in the "
                        f"connector: {e}"
                    ) from e

            # All three methods end the same way: build the final session
            # from the temporary credentials in the STS response.
            session = boto3.Session(
                aws_access_key_id=response["Credentials"]["AccessKeyId"],
                aws_secret_access_key=response["Credentials"][
                    "SecretAccessKey"
                ],
                aws_session_token=response["Credentials"]["SessionToken"],
            )
            expiration = response["Credentials"]["Expiration"]
            # Add the UTC timezone to the expiration time
            expiration = expiration.replace(tzinfo=datetime.timezone.utc)
            return session, expiration

        raise NotImplementedError(
            f"Authentication method '{auth_method}' is not supported by "
            "the AWS connector."
        )

    @classmethod
    def _get_eks_bearer_token(
        cls,
        session: boto3.Session,
        cluster_id: str,
        region: str,
    ) -> str:
        """Build the bearer token used to authenticate to an EKS API server.

        The token is a pre-signed STS `GetCallerIdentity` URL that carries the
        cluster name in the `x-k8s-aws-id` header, base64-encoded (URL-safe,
        without padding) and prefixed with `k8s-aws-v1.`.

        Based on: https://github.com/kubernetes-sigs/aws-iam-authenticator/blob/master/README.md#api-authorization-from-outside-a-cluster

        Args:
            session: An authenticated boto3 session whose credentials are used
                to sign the request.
            cluster_id: The name of the EKS cluster.
            region: The AWS region the EKS cluster is in.

        Returns:
            A bearer token for authenticating to the EKS API server.
        """
        # Lifetime, in seconds, requested for the pre-signed URL.
        token_lifetime_seconds = 60

        sts_client = session.client("sts", region_name=region)
        request_signer = RequestSigner(
            sts_client.meta.service_model.service_id,
            region,
            "sts",
            "v4",
            session.get_credentials(),
            session.events,
        )

        presigned_url = request_signer.generate_presigned_url(
            {
                "method": "GET",
                "url": f"https://sts.{region}.amazonaws.com/?Action=GetCallerIdentity&Version=2011-06-15",
                "body": {},
                "headers": {"x-k8s-aws-id": cluster_id},
                "context": {},
            },
            region_name=region,
            expires_in=token_lifetime_seconds,
            operation_name="",
        )

        encoded_url = base64.urlsafe_b64encode(
            presigned_url.encode("utf-8")
        ).decode("utf-8")

        # The token must not carry any base64 padding characters.
        return "k8s-aws-v1." + encoded_url.replace("=", "")

    def _parse_s3_resource_id(self, resource_id: str) -> str:
        """Validate and convert an S3 resource ID to an S3 bucket name.

        Args:
            resource_id: The resource ID to convert.

        Returns:
            The S3 bucket name.

        Raises:
            ValueError: If the provided resource ID is not a valid S3 bucket
                name, ARN or URI.
        """
        # The resource ID could mean different things:
        #
        # - an S3 bucket ARN
        # - an S3 bucket URI
        # - the S3 bucket name
        #
        # We need to extract the bucket name from the provided resource ID
        bucket_name: Optional[str] = None
        if re.match(
            r"^arn:aws:s3:::[a-z0-9][a-z0-9\-\.]{1,61}[a-z0-9](/.*)*$",
            resource_id,
        ):
            # The resource ID is an S3 bucket ARN
            bucket_name = resource_id.split(":")[-1].split("/")[0]
        elif re.match(
            r"^s3://[a-z0-9][a-z0-9\-\.]{1,61}[a-z0-9](/.*)*$",
            resource_id,
        ):
            # The resource ID is an S3 bucket URI
            bucket_name = resource_id.split("/")[2]
        elif re.match(
            r"^[a-z0-9][a-z0-9\-\.]{1,61}[a-z0-9]$",
            resource_id,
        ):
            # The resource ID is the S3 bucket name
            bucket_name = resource_id
        else:
            raise ValueError(
                f"Invalid resource ID for an S3 bucket: {resource_id}. "
                f"Supported formats are:\n"
                f"S3 bucket ARN: arn:aws:s3:::<bucket-name>\n"
                f"S3 bucket URI: s3://<bucket-name>\n"
                f"S3 bucket name: <bucket-name>"
            )

        return bucket_name

    def _parse_ecr_resource_id(
        self,
        resource_id: str,
    ) -> str:
        """Validate and convert an ECR resource ID to an ECR registry ID.

        Args:
            resource_id: The resource ID to convert.

        Returns:
            The ECR registry ID (AWS account ID).

        Raises:
            ValueError: If the provided resource ID is not a valid ECR
                repository ARN or URI.
        """
        # The resource ID could mean different things:
        #
        # - an ECR repository ARN
        # - an ECR repository URI
        #
        # We need to extract the region ID and registry ID from
        # the provided resource ID
        config_region_id = self.config.region
        region_id: Optional[str] = None
        if re.match(
            r"^arn:aws:ecr:[a-z0-9-]+:\d{12}:repository(/.+)*$",
            resource_id,
        ):
            # The resource ID is an ECR repository ARN
            registry_id = resource_id.split(":")[4]
            region_id = resource_id.split(":")[3]
        elif re.match(
            r"^(http[s]?://)?\d{12}\.dkr\.ecr\.[a-z0-9-]+\.amazonaws\.com(/.+)*$",
            resource_id,
        ):
            # The resource ID is an ECR repository URI
            registry_id = resource_id.split(".")[0].split("/")[-1]
            region_id = resource_id.split(".")[3]
        else:
            raise ValueError(
                f"Invalid resource ID for a ECR registry: {resource_id}. "
                f"Supported formats are:\n"
                f"ECR repository ARN: arn:aws:ecr:<region>:<account-id>:repository[/<repository-name>]\n"
                f"ECR repository URI: [https://]<account-id>.dkr.ecr.<region>.amazonaws.com[/<repository-name>]"
            )

        # If the connector is configured with a region and the resource ID
        # is an ECR repository ARN or URI that specifies a different region
        # we raise an error
        if region_id and region_id != config_region_id:
            raise ValueError(
                f"The AWS region for the {resource_id} ECR repository region "
                f"'{region_id}' does not match the region configured in "
                f"the connector: '{config_region_id}'."
            )

        return registry_id

    def _parse_eks_resource_id(self, resource_id: str) -> str:
        """Validate and convert an EKS resource ID to an AWS region and EKS cluster name.

        Args:
            resource_id: The resource ID to convert.

        Returns:
            The EKS cluster name.

        Raises:
            ValueError: If the provided resource ID is not a valid EKS cluster
                name or ARN.
        """
        # The resource ID could mean different things:
        #
        # - an EKS cluster ARN
        # - an EKS cluster ID
        #
        # We need to extract the cluster name and region ID from the
        # provided resource ID
        config_region_id = self.config.region
        cluster_name: Optional[str] = None
        region_id: Optional[str] = None
        if re.match(
            r"^arn:aws:eks:[a-z0-9-]+:\d{12}:cluster/[0-9A-Za-z][A-Za-z0-9\-_]*$",
            resource_id,
        ):
            # The resource ID is an EKS cluster ARN
            cluster_name = resource_id.split("/")[-1]
            region_id = resource_id.split(":")[3]
        elif re.match(
            r"^[0-9A-Za-z][A-Za-z0-9\-_]*$",
            resource_id,
        ):
            # Assume the resource ID is an EKS cluster name
            cluster_name = resource_id
        else:
            raise ValueError(
                f"Invalid resource ID for a EKS cluster: {resource_id}. "
                f"Supported formats are:\n"
                f"EKS cluster ARN: arn:aws:eks:<region>:<account-id>:cluster/<cluster-name>\n"
                f"ECR cluster name: <cluster-name>"
            )

        # If the connector is configured with a region and the resource ID
        # is an EKS registry ARN or URI that specifies a different region
        # we raise an error
        if region_id and region_id != config_region_id:
            raise ValueError(
                f"The AWS region for the {resource_id} EKS cluster "
                f"({region_id}) does not match the region configured in "
                f"the connector ({config_region_id})."
            )

        return cluster_name

    def _canonical_resource_id(
        self, resource_type: str, resource_id: str
    ) -> str:
        """Normalize a resource ID into the canonical form for its type.

        Args:
            resource_type: The type of the resource the ID refers to.
            resource_id: The resource ID to normalize.

        Returns:
            The canonical resource ID: an `s3://<bucket>` URI for S3 buckets,
            the cluster name for Kubernetes clusters, the
            `<registry-id>.dkr.ecr.<region>.amazonaws.com` host name for
            Docker registries, and the unchanged resource ID for any other
            resource type.
        """
        if resource_type == S3_RESOURCE_TYPE:
            return f"s3://{self._parse_s3_resource_id(resource_id)}"

        if resource_type == KUBERNETES_CLUSTER_RESOURCE_TYPE:
            return self._parse_eks_resource_id(resource_id)

        if resource_type == DOCKER_REGISTRY_RESOURCE_TYPE:
            account_id = self._parse_ecr_resource_id(resource_id)
            return f"{account_id}.dkr.ecr.{self.config.region}.amazonaws.com"

        # Any other resource type is returned as is.
        return resource_id

    def _get_default_resource_id(self, resource_type: str) -> str:
        """Return the default resource ID for a single-instance resource type.

        Args:
            resource_type: The type of the resource to get a default resource ID
                for. Only called with resource types that do not support
                multiple instances.

        Returns:
            The configured region for the generic AWS resource type, or the
            ECR registry host name built from the connector's account ID and
            configured region for the Docker registry resource type.

        Raises:
            RuntimeError: If no default resource ID is available for the
                given resource type.
        """
        if resource_type == AWS_RESOURCE_TYPE:
            return self.config.region

        if resource_type != DOCKER_REGISTRY_RESOURCE_TYPE:
            raise RuntimeError(
                f"Default resource ID for '{resource_type}' not available."
            )

        # The ECR registry ID is the same as the AWS account ID, which is
        # read from the connector's `account_id` attribute.
        registry_id = self.account_id
        return f"{registry_id}.dkr.ecr.{self.config.region}.amazonaws.com"

    def _connect_to_resource(
        self,
        **kwargs: Any,
    ) -> Any:
        """Authenticate to AWS and return a session or client for the resource.

        What is returned depends on the resource type configured in the
        connector:

        - a boto3 session for the generic AWS resource type
        - a boto3 S3 client for the S3 resource type

        The Docker and Kubernetes resource types cannot be connected to
        directly; a dedicated connector client has to be requested for them
        instead.

        Args:
            kwargs: Additional implementation specific keyword arguments.
                They are accepted but not used by this implementation.

        Returns:
            A boto3 session for AWS generic resources and a boto3 S3 client for
            S3 resources.

        Raises:
            NotImplementedError: If the connector instance does not support
                directly connecting to the indicated resource type.
        """
        target_type = self.resource_type
        target_id = self.resource_id

        assert target_type is not None
        assert target_id is not None

        # Authentication to AWS always comes first, whatever the resource type.
        session, _ = self.get_boto3_session(
            self.auth_method,
            resource_type=target_type,
            resource_id=target_id,
        )

        if target_type == AWS_RESOURCE_TYPE:
            return session

        if target_type != S3_RESOURCE_TYPE:
            raise NotImplementedError(
                f"Connecting to {target_type} resources is not directly "
                "supported by the AWS connector. Please call the "
                f"`get_connector_client` method to get a {target_type} connector "
                "instance for the resource."
            )

        # Reject resource IDs that are not valid S3 bucket references.
        self._parse_s3_resource_id(target_id)

        s3_client = session.client(
            "s3",
            region_name=self.config.region,
            endpoint_url=self.config.endpoint_url,
        )

        # Some consumers need the raw credentials to configure 3rd party
        # services and the S3 client offers no way to get them back, so they
        # are attached to the client object for later retrieval.
        s3_client.credentials = session.get_credentials()
        return s3_client

    def _configure_local_client(
        self,
        profile_name: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """Configure a local client to authenticate and connect to a resource.

        This method uses the connector's configuration to configure a local
        client or SDK installed on the localhost for the indicated resource.
        For the generic AWS and S3 resource types, the credentials of an
        authenticated boto3 session are written as a profile to the user's
        `~/.aws/credentials` file.

        Args:
            profile_name: The name of the AWS profile to use. If not specified,
                a profile name is generated based on the first 8 digits of the
                connector's UUID in the form 'zenml-<uuid[:8]>'. If a profile
                with the given or generated name already exists, the profile is
                overwritten.
            kwargs: Additional implementation specific keyword arguments to use
                to configure the client. They are accepted but not used by
                this implementation.

        Raises:
            NotImplementedError: If the connector instance does not support
                local configuration for the configured resource type or
                authentication method.
        """
        resource_type = self.resource_type

        if resource_type in [AWS_RESOURCE_TYPE, S3_RESOURCE_TYPE]:
            session, _ = self.get_boto3_session(
                self.auth_method,
                resource_type=resource_type,
                resource_id=self.resource_id,
            )

            # Configure a new AWS SDK profile with the credentials
            # from the session using the aws-profile-manager package

            # Generate a profile name based on the first 8 digits from the
            # connector UUID, if one is not supplied
            aws_profile_name = profile_name or f"zenml-{str(self.id)[:8]}"
            common = Common()
            users_home = common.get_users_home()
            all_profiles = common.get_all_profiles(users_home)

            credentials = session.get_credentials()
            all_profiles[aws_profile_name] = {
                "region": self.config.region,
                "aws_access_key_id": credentials.access_key,
                "aws_secret_access_key": credentials.secret_key,
            }

            # Only temporary credentials come with a session token.
            if credentials.token:
                all_profiles[aws_profile_name]["aws_session_token"] = (
                    credentials.token
                )

            aws_credentials_path = os.path.join(
                users_home, ".aws", "credentials"
            )

            # Create the parent dir if needed. `exist_ok=True` avoids the
            # check-then-create race of testing for the directory first.
            os.makedirs(os.path.dirname(aws_credentials_path), exist_ok=True)

            # Make sure the credentials file exists before it is rewritten.
            # The open flags do not include O_TRUNC and nothing is written
            # here; a newly created file gets mode 0o600.
            with os.fdopen(
                os.open(aws_credentials_path, os.O_WRONLY | os.O_CREAT, 0o600),
                "w",
            ):
                pass

            # Write the credentials to the file
            common.rewrite_credentials_file(all_profiles, users_home)

            logger.info(
                f"Configured local AWS SDK profile '{aws_profile_name}'."
            )

            return

        raise NotImplementedError(
            f"Configuring the local client for {resource_type} resources is "
            "not directly supported by the AWS connector. Please call the "
            f"`get_connector_client` method to get a {resource_type} connector "
            "instance for the resource."
        )

    @classmethod
    def _auto_configure(
        cls,
        auth_method: Optional[str] = None,
        resource_type: Optional[str] = None,
        resource_id: Optional[str] = None,
        region_name: Optional[str] = None,
        profile_name: Optional[str] = None,
        role_arn: Optional[str] = None,
        **kwargs: Any,
    ) -> "AWSServiceConnector":
        """Auto-configure the connector.

        Instantiate an AWS connector with a configuration extracted from the
        authentication configuration available in the environment (e.g.
        environment variables or local AWS client/SDK configuration files).

        Args:
            auth_method: The particular authentication method to use. If not
                specified, the connector implementation must decide which
                authentication method to use or raise an exception.
            resource_type: The type of resource to configure.
            resource_id: The ID of the resource to configure. The
                implementation may choose to either require or ignore this
                parameter if it does not support or detect a resource type that
                supports multiple instances.
            region_name: The name of the AWS region to use. If not specified,
                the implicit region is used.
            profile_name: The name of the AWS profile to use. If not specified,
                the implicit profile is used.
            role_arn: The ARN of the AWS role to assume. Applicable only if the
                IAM role authentication method is specified or long-term
                credentials are discovered.
            kwargs: Additional implementation specific keyword arguments to use.

        Returns:
            An AWS connector instance configured with authentication credentials
            automatically extracted from the environment.

        Raises:
            NotImplementedError: If the connector implementation does not
                support auto-configuration for the specified authentication
                method.
            ValueError: If the supplied arguments are not valid.
            AuthorizationException: If no AWS credentials can be loaded from
                the environment.
        """
        auth_config: AWSBaseConfig
        # At most one of these two is set below, depending on the selected
        # authentication method; both are passed to the connector constructor.
        expiration_seconds: Optional[int] = None
        expires_at: Optional[datetime.datetime] = None
        if auth_method == AWSAuthenticationMethods.IMPLICIT:
            # Implicit authentication: no credentials are extracted, only the
            # profile name and the region are recorded in the configuration.
            if region_name is None:
                raise ValueError(
                    "The AWS region name must be specified when using the "
                    "implicit authentication method"
                )
            auth_config = AWSImplicitConfig(
                profile_name=profile_name,
                region=region_name,
            )
        else:
            # Initialize an AWS session with the default configuration loaded
            # from the environment.
            session = boto3.Session(
                profile_name=profile_name, region_name=region_name
            )

            # Fall back to the region detected by the session when none was
            # passed explicitly.
            region_name = region_name or session.region_name
            if not region_name:
                raise ValueError(
                    "The AWS region name was not specified and could not "
                    "be determined from the AWS session"
                )
            # Read through the private botocore session held by the boto3
            # session.
            endpoint_url = session._session.get_config_variable("endpoint_url")

            # Extract the AWS credentials from the session and store them in
            # the connector secrets.
            credentials = session.get_credentials()
            if not credentials:
                raise AuthorizationException(
                    "Could not determine the AWS credentials from the "
                    "environment"
                )
            if credentials.token:
                # The session picked up temporary STS credentials
                # NOTE(review): the `None` entry in the list below is
                # redundant, the leading `auth_method and` check already
                # excludes falsy values.
                if auth_method and auth_method not in [
                    None,
                    AWSAuthenticationMethods.STS_TOKEN,
                    AWSAuthenticationMethods.IAM_ROLE,
                ]:
                    raise NotImplementedError(
                        f"The specified authentication method '{auth_method}' "
                        "could not be used to auto-configure the connector. "
                    )

                if (
                    credentials.method == "assume-role"
                    and auth_method != AWSAuthenticationMethods.STS_TOKEN
                ):
                    # In the special case of IAM role authentication, the
                    # credentials in the boto3 session are the temporary STS
                    # credentials instead of the long-lived credentials, and the
                    # role ARN is not known. We have to dig deeper into the
                    # botocore session internals to retrieve the role ARN and
                    # the original long-lived credentials.

                    botocore_session = session._session
                    profile_config = botocore_session.get_scoped_config()
                    source_profile = profile_config.get("source_profile")
                    # This overrides any `role_arn` value passed by the caller.
                    role_arn = profile_config.get("role_arn")
                    profile_map = botocore_session._build_profile_map()
                    if not (
                        role_arn
                        and source_profile
                        and source_profile in profile_map
                    ):
                        raise AuthorizationException(
                            "Could not determine the IAM role ARN and source "
                            "profile credentials from the environment"
                        )

                    auth_method = AWSAuthenticationMethods.IAM_ROLE
                    # The long-lived keys are taken from the source profile
                    # entry, not from the temporary session credentials.
                    source_profile_config = profile_map[source_profile]
                    auth_config = IAMRoleAuthenticationConfig(
                        region=region_name,
                        endpoint_url=endpoint_url,
                        aws_access_key_id=source_profile_config.get(
                            "aws_access_key_id"
                        ),
                        aws_secret_access_key=source_profile_config.get(
                            "aws_secret_access_key",
                        ),
                        role_arn=role_arn,
                    )
                    expiration_seconds = DEFAULT_IAM_ROLE_TOKEN_EXPIRATION

                else:
                    # Temporary credentials that do not come from an assumed
                    # role (or an explicit STS token request) cannot be turned
                    # into an IAM role configuration.
                    if auth_method == AWSAuthenticationMethods.IAM_ROLE:
                        raise NotImplementedError(
                            f"The specified authentication method "
                            f"'{auth_method}' could not be used to "
                            "auto-configure the connector."
                        )

                    # Temporary credentials were picked up from the local
                    # configuration. It's not possible to determine the
                    # expiration time of the temporary credentials from the
                    # boto3 session, so we assume the default IAM role
                    # expiration period is used
                    expires_at = datetime.datetime.now(
                        tz=datetime.timezone.utc
                    ) + datetime.timedelta(
                        seconds=DEFAULT_IAM_ROLE_TOKEN_EXPIRATION
                    )

                    auth_method = AWSAuthenticationMethods.STS_TOKEN
                    auth_config = STSTokenConfig(
                        region=region_name,
                        endpoint_url=endpoint_url,
                        aws_access_key_id=credentials.access_key,
                        aws_secret_access_key=credentials.secret_key,
                        aws_session_token=credentials.token,
                    )
            else:
                # The session picked up long-lived credentials
                if not auth_method:
                    if role_arn:
                        auth_method = AWSAuthenticationMethods.IAM_ROLE
                    else:
                        # If no authentication method was specified, use the
                        # session token as a default recommended authentication
                        # method to be used with long-lived credentials.
                        auth_method = AWSAuthenticationMethods.SESSION_TOKEN

                # NOTE(review): region_name was already resolved and validated
                # right after the session was created, so this re-check looks
                # redundant.
                region_name = region_name or session.region_name
                if not region_name:
                    raise ValueError(
                        "The AWS region name was not specified and could not "
                        "be determined from the AWS session"
                    )

                if auth_method == AWSAuthenticationMethods.STS_TOKEN:
                    # Generate a session token from the long-lived credentials
                    # and store it in the connector secrets.
                    sts_client = session.client("sts")
                    response = sts_client.get_session_token(
                        DurationSeconds=DEFAULT_STS_TOKEN_EXPIRATION
                    )
                    # From here on `credentials` is the response dictionary,
                    # not the session credentials object.
                    credentials = response["Credentials"]
                    auth_config = STSTokenConfig(
                        region=region_name,
                        endpoint_url=endpoint_url,
                        aws_access_key_id=credentials["AccessKeyId"],
                        aws_secret_access_key=credentials["SecretAccessKey"],
                        aws_session_token=credentials["SessionToken"],
                    )
                    # The expiration time is computed locally from the current
                    # time and the requested duration; it is not read from the
                    # response.
                    expires_at = datetime.datetime.now(
                        tz=datetime.timezone.utc
                    ) + datetime.timedelta(
                        seconds=DEFAULT_STS_TOKEN_EXPIRATION
                    )

                elif auth_method == AWSAuthenticationMethods.IAM_ROLE:
                    if not role_arn:
                        raise ValueError(
                            "The ARN of the AWS role to assume must be "
                            "specified when using the IAM role authentication "
                            "method."
                        )
                    auth_config = IAMRoleAuthenticationConfig(
                        region=region_name,
                        endpoint_url=endpoint_url,
                        aws_access_key_id=credentials.access_key,
                        aws_secret_access_key=credentials.secret_key,
                        role_arn=role_arn,
                    )
                    expiration_seconds = DEFAULT_IAM_ROLE_TOKEN_EXPIRATION

                elif auth_method == AWSAuthenticationMethods.SECRET_KEY:
                    # The long-lived keys are stored as is; no expiration is
                    # set for this method.
                    auth_config = AWSSecretKeyConfig(
                        region=region_name,
                        endpoint_url=endpoint_url,
                        aws_access_key_id=credentials.access_key,
                        aws_secret_access_key=credentials.secret_key,
                    )

                elif auth_method == AWSAuthenticationMethods.SESSION_TOKEN:
                    auth_config = SessionTokenAuthenticationConfig(
                        region=region_name,
                        endpoint_url=endpoint_url,
                        aws_access_key_id=credentials.access_key,
                        aws_secret_access_key=credentials.secret_key,
                    )
                    expiration_seconds = DEFAULT_STS_TOKEN_EXPIRATION

                # NOTE(review): any auth_method value not matched by the
                # branches above also ends up here, not only FEDERATION_TOKEN.
                else:  # auth_method is AWSAuthenticationMethods.FEDERATION_TOKEN
                    auth_config = FederationTokenAuthenticationConfig(
                        region=region_name,
                        endpoint_url=endpoint_url,
                        aws_access_key_id=credentials.access_key,
                        aws_secret_access_key=credentials.secret_key,
                    )
                    expiration_seconds = DEFAULT_STS_TOKEN_EXPIRATION

        # The resource ID is dropped for the generic AWS resource type and
        # when no resource type was given.
        return cls(
            auth_method=auth_method,
            resource_type=resource_type,
            resource_id=resource_id
            if resource_type not in [AWS_RESOURCE_TYPE, None]
            else None,
            expiration_seconds=expiration_seconds,
            expires_at=expires_at,
            config=auth_config,
        )

    def _verify(
        self,
        resource_type: Optional[str] = None,
        resource_id: Optional[str] = None,
    ) -> List[str]:
        """Verify and list all the resources that the connector can access.

        Args:
            resource_type: The type of the resource to verify. If omitted and
                if the connector supports multiple resource types, the
                implementation must verify that it can authenticate and connect
                to any and all of the supported resource types.
            resource_id: The ID of the resource to connect to. Omitted if a
                resource type is not specified. It has the same value as the
                default resource ID if the supplied resource type doesn't
                support multiple instances. If the supplied resource type does
                allows multiple instances, this parameter may still be omitted
                to fetch a list of resource IDs identifying all the resources
                of the indicated type that the connector can access.

        Returns:
            The list of resources IDs in canonical format identifying the
            resources that the connector can access. This list is empty only
            if the resource type is not specified (i.e. for multi-type
            connectors).

        Raises:
            AuthorizationException: If the connector cannot authenticate or
                access the specified resource.
        """
        # If the resource type is not specified, treat this the
        # same as a generic AWS connector.
        session, _ = self.get_boto3_session(
            self.auth_method,
            resource_type=resource_type or AWS_RESOURCE_TYPE,
            resource_id=resource_id,
        )

        # Verify that the AWS account is accessible
        assert isinstance(session, boto3.Session)
        sts_client = session.client("sts")
        try:
            sts_client.get_caller_identity()
        except (ClientError, BotoCoreError) as err:
            msg = f"failed to verify AWS account access: {err}"
            logger.debug(msg)
            raise AuthorizationException(msg) from err

        if not resource_type:
            return []

        if resource_type == AWS_RESOURCE_TYPE:
            assert resource_id is not None
            return [resource_id]

        if resource_type == S3_RESOURCE_TYPE:
            s3_client = session.client(
                "s3",
                region_name=self.config.region,
                endpoint_url=self.config.endpoint_url,
            )
            if not resource_id:
                # List all S3 buckets
                try:
                    response = s3_client.list_buckets()
                except (ClientError, BotoCoreError) as e:
                    msg = f"failed to list S3 buckets: {e}"
                    logger.error(msg)
                    raise AuthorizationException(msg) from e

                return [
                    f"s3://{bucket['Name']}" for bucket in response["Buckets"]
                ]
            else:
                # Check if the specified S3 bucket exists
                bucket_name = self._parse_s3_resource_id(resource_id)
                try:
                    s3_client.head_bucket(Bucket=bucket_name)
                    return [resource_id]
                except (ClientError, BotoCoreError) as e:
                    msg = f"failed to fetch S3 bucket {bucket_name}: {e}"
                    logger.error(msg)
                    raise AuthorizationException(msg) from e

        if resource_type == DOCKER_REGISTRY_RESOURCE_TYPE:
            assert resource_id is not None

            ecr_client = session.client(
                "ecr",
                region_name=self.config.region,
                endpoint_url=self.config.endpoint_url,
            )
            # List all ECR repositories
            try:
                repositories = ecr_client.describe_repositories()
            except (ClientError, BotoCoreError) as e:
                msg = f"failed to list ECR repositories: {e}"
                logger.error(msg)
                raise AuthorizationException(msg) from e

            if len(repositories["repositories"]) == 0:
                raise AuthorizationException(
                    "the AWS connector does not have access to any ECR "
                    "repositories. Please adjust the AWS permissions "
                    "associated with the authentication credentials to "
                    "include access to at least one ECR repository."
                )

            return [resource_id]

        if resource_type == KUBERNETES_CLUSTER_RESOURCE_TYPE:
            eks_client = session.client(
                "eks",
                region_name=self.config.region,
                endpoint_url=self.config.endpoint_url,
            )
            if not resource_id:
                # List all EKS clusters
                try:
                    clusters = eks_client.list_clusters()
                except (ClientError, BotoCoreError) as e:
                    msg = f"Failed to list EKS clusters: {e}"
                    logger.error(msg)
                    raise AuthorizationException(msg) from e

                return cast(List[str], clusters["clusters"])
            else:
                # Check if the specified EKS cluster exists
                cluster_name = self._parse_eks_resource_id(resource_id)
                try:
                    clusters = eks_client.describe_cluster(name=cluster_name)
                except (ClientError, BotoCoreError) as e:
                    msg = f"Failed to fetch EKS cluster {cluster_name}: {e}"
                    logger.error(msg)
                    raise AuthorizationException(msg) from e

                return [resource_id]

        return []

    def _get_connector_client(
        self,
        resource_type: str,
        resource_id: str,
    ) -> "ServiceConnector":
        """Get a connector instance that can be used to connect to a resource.

        This method generates a client-side connector instance that can be used
        to connect to a resource of the given type. The client-side connector
        is configured with temporary AWS credentials extracted from the
        current connector and, depending on resource type, it may also be
        of a different connector type:

        - a Kubernetes connector for Kubernetes clusters
        - a Docker connector for Docker registries

        Args:
            resource_type: The type of the resources to connect to.
            resource_id: The ID of a particular resource to connect to.

        Returns:
            An AWS, Kubernetes or Docker connector instance that can be used to
            connect to the specified resource.

        Raises:
            AuthorizationException: If authentication failed.
            ValueError: If the resource type is not supported.
            RuntimeError: If the Kubernetes connector is not installed and the
                resource type is Kubernetes.
        """
        # Build a descriptive name for the client connector: the current
        # connector name (if any) followed by the targeted resource.
        connector_name = ""
        if self.name:
            connector_name = self.name
        if resource_id:
            connector_name += f" ({resource_type} | {resource_id} client)"
        else:
            connector_name += f" ({resource_type} client)"

        logger.debug(f"Getting connector client for {connector_name}")

        if resource_type in [AWS_RESOURCE_TYPE, S3_RESOURCE_TYPE]:
            auth_method = self.auth_method
            if self.auth_method in [
                AWSAuthenticationMethods.SECRET_KEY,
                AWSAuthenticationMethods.STS_TOKEN,
            ]:
                if (
                    self.resource_type == resource_type
                    and self.resource_id == resource_id
                ):
                    # If the requested type and resource ID are the same as
                    # those configured, we can return the current connector
                    # instance because it's fully formed and ready to use
                    # to connect to the specified resource
                    return self

                # The secret key and STS token authentication methods do not
                # involve generating temporary credentials, so we can just
                # use the current connector configuration
                config = self.config
                expires_at = self.expires_at
            else:
                # Get an authenticated boto3 session
                session, expires_at = self.get_boto3_session(
                    self.auth_method,
                    resource_type=resource_type,
                    resource_id=resource_id,
                )
                assert isinstance(session, boto3.Session)
                credentials = session.get_credentials()

                if (
                    self.auth_method == AWSAuthenticationMethods.IMPLICIT
                    and credentials.token is None
                ):
                    # The implicit authentication method may involve picking up
                    # long-lived credentials from the environment
                    auth_method = AWSAuthenticationMethods.SECRET_KEY
                    config = AWSSecretKeyConfig(
                        region=self.config.region,
                        endpoint_url=self.config.endpoint_url,
                        aws_access_key_id=credentials.access_key,
                        aws_secret_access_key=credentials.secret_key,
                    )
                else:
                    assert credentials.token is not None

                    # Use the temporary credentials extracted from the boto3
                    # session
                    auth_method = AWSAuthenticationMethods.STS_TOKEN
                    config = STSTokenConfig(
                        region=self.config.region,
                        endpoint_url=self.config.endpoint_url,
                        aws_access_key_id=credentials.access_key,
                        aws_secret_access_key=credentials.secret_key,
                        aws_session_token=credentials.token,
                    )

            # Create a client-side AWS connector instance that is fully formed
            # and ready to use to connect to the specified resource (i.e. has
            # all the necessary configuration and credentials, a resource type
            # and a resource ID where applicable)
            return AWSServiceConnector(
                id=self.id,
                name=connector_name,
                auth_method=auth_method,
                resource_type=resource_type,
                resource_id=resource_id,
                config=config,
                expires_at=expires_at,
            )

        if resource_type == DOCKER_REGISTRY_RESOURCE_TYPE:
            assert resource_id is not None

            # Get an authenticated boto3 session
            session, expires_at = self.get_boto3_session(
                self.auth_method,
                resource_type=resource_type,
                resource_id=resource_id,
            )
            assert isinstance(session, boto3.Session)

            registry_id = self._parse_ecr_resource_id(resource_id)

            ecr_client = session.client(
                "ecr",
                region_name=self.config.region,
                endpoint_url=self.config.endpoint_url,
            )

            assert isinstance(ecr_client, BaseClient)
            assert registry_id is not None

            try:
                auth_token = ecr_client.get_authorization_token(
                    registryIds=[
                        registry_id,
                    ]
                )
            except (ClientError, BotoCoreError) as e:
                raise AuthorizationException(
                    f"Failed to get authorization token from ECR: {e}"
                ) from e

            auth_data = auth_token["authorizationData"][0]
            token = auth_data["authorizationToken"]
            endpoint = auth_data["proxyEndpoint"]
            # The token is base64 encoded and has the format
            # "username:password". Split on the first colon only so that a
            # password containing ":" does not break the unpacking.
            username, token = (
                base64.b64decode(token).decode("utf-8").split(":", 1)
            )

            # Create a client-side Docker connector instance with the temporary
            # Docker credentials
            return DockerServiceConnector(
                id=self.id,
                name=connector_name,
                auth_method=DockerAuthenticationMethods.PASSWORD,
                resource_type=resource_type,
                config=DockerConfiguration(
                    username=username,
                    password=token,
                    registry=endpoint,
                ),
                expires_at=expires_at,
            )

        if resource_type == KUBERNETES_CLUSTER_RESOURCE_TYPE:
            assert resource_id is not None

            # Get an authenticated boto3 session. The session expiration is
            # not used here: the expiration of the generated Kubernetes token
            # is computed separately below.
            session, _ = self.get_boto3_session(
                self.auth_method,
                resource_type=resource_type,
                resource_id=resource_id,
            )
            assert isinstance(session, boto3.Session)

            cluster_name = self._parse_eks_resource_id(resource_id)

            # Get a boto3 EKS client
            eks_client = session.client(
                "eks",
                region_name=self.config.region,
                endpoint_url=self.config.endpoint_url,
            )

            assert isinstance(eks_client, BaseClient)

            try:
                cluster = eks_client.describe_cluster(name=cluster_name)
            except (ClientError, BotoCoreError) as e:
                raise AuthorizationException(
                    f"Failed to get EKS cluster {cluster_name}: {e}"
                ) from e

            try:
                user_token = self._get_eks_bearer_token(
                    session=session,
                    cluster_id=cluster_name,
                    region=self.config.region,
                )
            except (ClientError, BotoCoreError) as e:
                raise AuthorizationException(
                    f"Failed to get EKS bearer token: {e}"
                ) from e

            # Kubernetes authentication tokens issued by AWS EKS have a fixed
            # expiration time of 15 minutes
            # source: https://aws.github.io/aws-eks-best-practices/security/docs/iam/#controlling-access-to-eks-clusters
            expires_at = datetime.datetime.now(
                tz=datetime.timezone.utc
            ) + datetime.timedelta(minutes=EKS_KUBE_API_TOKEN_EXPIRATION)

            # get cluster details
            cluster_arn = cluster["cluster"]["arn"]
            cluster_ca_cert = cluster["cluster"]["certificateAuthority"][
                "data"
            ]
            cluster_server = cluster["cluster"]["endpoint"]

            # Create a client-side Kubernetes connector instance with the
            # temporary Kubernetes credentials
            try:
                # Import libraries only when needed
                from zenml.integrations.kubernetes.service_connectors.kubernetes_service_connector import (
                    KubernetesAuthenticationMethods,
                    KubernetesServiceConnector,
                    KubernetesTokenConfig,
                )
            except ImportError as e:
                # Chain the original import error explicitly, consistent with
                # the other re-raised exceptions in this method.
                raise RuntimeError(
                    f"The Kubernetes Service Connector functionality could not "
                    f"be used due to missing dependencies: {e}"
                ) from e
            return KubernetesServiceConnector(
                id=self.id,
                name=connector_name,
                auth_method=KubernetesAuthenticationMethods.TOKEN,
                resource_type=resource_type,
                config=KubernetesTokenConfig(
                    cluster_name=cluster_arn,
                    certificate_authority=cluster_ca_cert,
                    server=cluster_server,
                    token=user_token,
                ),
                expires_at=expires_at,
            )

        raise ValueError(f"Unsupported resource type: {resource_type}")
account_id: str property readonly

Get the AWS account ID.

Returns:

Type Description
str

The AWS account ID.

Exceptions:

Type Description
AuthorizationException

If the AWS account ID could not be determined.

get_boto3_session(self, auth_method, resource_type=None, resource_id=None)

Get a boto3 session for the specified resource.

Parameters:

Name Type Description Default
auth_method str

The authentication method to use.

required
resource_type Optional[str]

The resource type to get a boto3 session for.

None
resource_id Optional[str]

The resource ID to get a boto3 session for.

None

Returns:

Type Description
Tuple[boto3.Session, Optional[datetime.datetime]]

A boto3 session for the specified resource and its expiration timestamp, if applicable.

Source code in zenml/integrations/aws/service_connectors/aws_service_connector.py
def get_boto3_session(
    self,
    auth_method: str,
    resource_type: Optional[str] = None,
    resource_id: Optional[str] = None,
) -> Tuple[boto3.Session, Optional[datetime.datetime]]:
    """Return a boto3 session for the given resource, reusing cached ones.

    Sessions are cached per (auth method, resource type, resource ID)
    combination. A cached session is reused when it has no expiration
    timestamp, or when its expiration lies further in the future than the
    configured refresh buffer. Otherwise a new session is created through
    `_authenticate` and stored in the cache.

    Args:
        auth_method: The authentication method to use.
        resource_type: The resource type to get a boto3 session for.
        resource_id: The resource ID to get a boto3 session for.

    Returns:
        A boto3 session for the specified resource and its expiration
        timestamp, if applicable.
    """
    cache_key = (auth_method, resource_type, resource_id)

    cached_entry = self._session_cache.get(cache_key)
    if cached_entry is not None:
        cached_session, cached_expiry = cached_entry

        # Sessions without an expiration timestamp never need a refresh
        if cached_expiry is None:
            return cached_session, None

        # The stored timestamp is interpreted as UTC before comparing it
        # with the current time plus the refresh buffer
        cached_expiry = cached_expiry.replace(tzinfo=datetime.timezone.utc)
        refresh_threshold = datetime.datetime.now(
            datetime.timezone.utc
        ) + datetime.timedelta(minutes=BOTO3_SESSION_EXPIRATION_BUFFER)
        if cached_expiry > refresh_threshold:
            return cached_session, cached_expiry

    # Cache miss, or the cached session is expired or about to expire:
    # authenticate again and replace the cache entry
    logger.debug(
        f"Creating boto3 session for auth method '{auth_method}', "
        f"resource type '{resource_type}' and resource ID "
        f"'{resource_id}'..."
    )
    new_session, new_expiry = self._authenticate(
        auth_method, resource_type, resource_id
    )
    self._session_cache[cache_key] = (new_session, new_expiry)
    return new_session, new_expiry
get_ecr_client(self)

Get an ECR client.

Exceptions:

Type Description
ValueError

If the service connector is not able to instantiate an ECR client.

Returns:

Type Description
botocore.client.BaseClient

An ECR client.

Source code in zenml/integrations/aws/service_connectors/aws_service_connector.py
def get_ecr_client(self) -> BaseClient:
    """Build a boto3 ECR client from this connector's credentials.

    The connector must either have no resource type configured or be
    configured for the generic AWS or the Docker registry resource type.

    Raises:
        ValueError: If the service connector is not able to instantiate an
            ECR client.

    Returns:
        An ECR client.
    """
    supported_types = {AWS_RESOURCE_TYPE, DOCKER_REGISTRY_RESOURCE_TYPE}
    configured_type = self.resource_type
    if configured_type and configured_type not in supported_types:
        raise ValueError(
            f"Unable to instantiate ECR client for a connector that is "
            f"configured to provide access to a '{configured_type}' "
            "resource type."
        )

    # The configured region is passed as the resource ID of the Docker
    # registry session
    region = self.config.region
    boto_session, _ = self.get_boto3_session(
        auth_method=self.auth_method,
        resource_type=DOCKER_REGISTRY_RESOURCE_TYPE,
        resource_id=region,
    )
    ecr_client = boto_session.client(
        "ecr",
        region_name=region,
        endpoint_url=self.config.endpoint_url,
    )
    return ecr_client
model_post_init(self, context, /)

This function is meant to behave like a BaseModel method to initialise private attributes.

It takes context as an argument since that's what pydantic-core passes when calling it.

Parameters:

Name Type Description Default
self BaseModel

The BaseModel instance.

required
context Any

The context.

required
Source code in zenml/integrations/aws/service_connectors/aws_service_connector.py
def init_private_attributes(self: BaseModel, context: Any, /) -> None:
    """Initialise the private attributes of a model instance.

    This function is meant to behave like a BaseModel method. It takes
    context as an argument since that's what pydantic-core passes when
    calling it. Nothing is done if `__pydantic_private__` is already set
    to something other than `None`.

    Args:
        self: The BaseModel instance.
        context: The context.
    """
    if getattr(self, '__pydantic_private__', None) is not None:
        return

    # Collect the default of every private attribute that defines one
    private_defaults = {
        attr_name: value
        for attr_name, attr in self.__private_attributes__.items()
        if (value := attr.get_default()) is not PydanticUndefined
    }
    object_setattr(self, '__pydantic_private__', private_defaults)
AWSSessionPolicy (AuthenticationConfig)

AWS session IAM policy configuration.

Source code in zenml/integrations/aws/service_connectors/aws_service_connector.py
class AWSSessionPolicy(AuthenticationConfig):
    """AWS session IAM policy configuration.

    Declares two optional session policy fields, both defaulting to `None`:
    a list of managed policy ARNs and an inline policy in JSON format.
    """

    # Optional ARNs of IAM managed policies to use as managed session policies
    policy_arns: Optional[List[str]] = Field(
        default=None,
        title="ARNs of the IAM managed policies that you want to use as a "
        "managed session policy. The policies must exist in the same account "
        "as the IAM user that is requesting temporary credentials.",
    )
    # Optional inline session policy, provided as a JSON formatted string
    policy: Optional[str] = Field(
        default=None,
        title="An IAM policy in JSON format that you want to use as an inline "
        "session policy",
    )
FederationTokenAuthenticationConfig (AWSSecretKeyConfig, AWSSessionPolicy)

AWS federation token authentication config.

Source code in zenml/integrations/aws/service_connectors/aws_service_connector.py
class FederationTokenAuthenticationConfig(
    AWSSecretKeyConfig, AWSSessionPolicy
):
    """AWS federation token authentication config.

    Declares no fields of its own: everything is inherited from
    `AWSSecretKeyConfig` and `AWSSessionPolicy`.
    """
IAMRoleAuthenticationConfig (AWSSecretKeyConfig, AWSSessionPolicy)

AWS IAM authentication config.

Source code in zenml/integrations/aws/service_connectors/aws_service_connector.py
class IAMRoleAuthenticationConfig(AWSSecretKeyConfig, AWSSessionPolicy):
    """AWS IAM authentication config.

    Extends the fields inherited from `AWSSecretKeyConfig` and
    `AWSSessionPolicy` with the ARN of an IAM role.
    """

    # ARN of the IAM role; declared without a default value
    role_arn: str = Field(
        title="AWS IAM Role ARN",
    )
STSToken (AWSSecretKey)

AWS STS token.

Source code in zenml/integrations/aws/service_connectors/aws_service_connector.py
class STSToken(AWSSecretKey):
    """AWS STS token.

    Extends the fields inherited from `AWSSecretKey` with a session token.
    """

    # Session token stored as a secret string; declared without a default
    aws_session_token: PlainSerializedSecretStr = Field(
        title="AWS Session Token",
    )
STSTokenConfig (AWSBaseConfig, STSToken)

AWS STS token authentication configuration.

Source code in zenml/integrations/aws/service_connectors/aws_service_connector.py
class STSTokenConfig(AWSBaseConfig, STSToken):
    """AWS STS token authentication configuration.

    Declares no fields of its own: everything is inherited from
    `AWSBaseConfig` and `STSToken`.
    """
SessionTokenAuthenticationConfig (AWSSecretKeyConfig)

AWS session token authentication config.

Source code in zenml/integrations/aws/service_connectors/aws_service_connector.py
class SessionTokenAuthenticationConfig(AWSSecretKeyConfig):
    """AWS session token authentication config.

    Declares no fields of its own: everything is inherited from
    `AWSSecretKeyConfig`.
    """

step_operators special

Initialization of the Sagemaker Step Operator.

sagemaker_step_operator

Implementation of the Sagemaker Step Operator.

SagemakerStepOperator (BaseStepOperator)

Step operator to run a step on Sagemaker.

This class defines code that builds an image with the ZenML entrypoint to run using Sagemaker's Estimator.

Source code in zenml/integrations/aws/step_operators/sagemaker_step_operator.py
class SagemakerStepOperator(BaseStepOperator):
    """Step operator to run a step on Sagemaker.

    This class defines code that builds an image with the ZenML entrypoint
    to run using Sagemaker's Estimator.
    """

    @property
    def config(self) -> SagemakerStepOperatorConfig:
        """Returns the `SagemakerStepOperatorConfig` config.

        Returns:
            The configuration.
        """
        return cast(SagemakerStepOperatorConfig, self._config)

    @property
    def settings_class(self) -> Optional[Type["BaseSettings"]]:
        """Settings class for the SageMaker step operator.

        Returns:
            The settings class.
        """
        return SagemakerStepOperatorSettings

    @property
    def entrypoint_config_class(
        self,
    ) -> Type[StepOperatorEntrypointConfiguration]:
        """Returns the entrypoint configuration class for this step operator.

        Returns:
            The entrypoint configuration class for this step operator.
        """
        return SagemakerEntrypointConfiguration

    @property
    def validator(self) -> Optional[StackValidator]:
        """Validates the stack.

        Returns:
            A validator that checks that the stack contains a remote container
            registry and a remote artifact store. It also requires a
            container registry and an image builder to be part of the stack.
        """

        def _validate_remote_components(stack: "Stack") -> Tuple[bool, str]:
            # The step runs remotely, so a local artifact store would not be
            # reachable from SageMaker.
            if stack.artifact_store.config.is_local:
                return False, (
                    "The SageMaker step operator runs code remotely and "
                    "needs to write files into the artifact store, but the "
                    f"artifact store `{stack.artifact_store.name}` of the "
                    "active stack is local. Please ensure that your stack "
                    "contains a remote artifact store when using the SageMaker "
                    "step operator."
                )

            # The container registry is listed in `required_components`
            # below; the assert only narrows the optional type.
            container_registry = stack.container_registry
            assert container_registry is not None

            if container_registry.config.is_local:
                return False, (
                    "The SageMaker step operator runs code remotely and "
                    "needs to push/pull Docker images, but the "
                    f"container registry `{container_registry.name}` of the "
                    "active stack is local. Please ensure that your stack "
                    "contains a remote container registry when using the "
                    "SageMaker step operator."
                )

            return True, ""

        return StackValidator(
            required_components={
                StackComponentType.CONTAINER_REGISTRY,
                StackComponentType.IMAGE_BUILDER,
            },
            custom_validation_function=_validate_remote_components,
        )

    def get_docker_builds(
        self, deployment: "PipelineDeploymentBase"
    ) -> List["BuildConfiguration"]:
        """Gets the Docker builds required for the component.

        One build is requested for every step of the deployment that is
        configured to run with this step operator.

        Args:
            deployment: The pipeline deployment for which to get the builds.

        Returns:
            The required Docker builds.
        """
        builds = []
        for step_name, step in deployment.step_configurations.items():
            if step.config.step_operator == self.name:
                build = BuildConfiguration(
                    key=SAGEMAKER_DOCKER_IMAGE_KEY,
                    settings=step.config.docker_settings,
                    step_name=step_name,
                    # The image entrypoint is read from an environment
                    # variable that is set when the step is launched.
                    entrypoint=f"${_ENTRYPOINT_ENV_VARIABLE}",
                )
                builds.append(build)

        return builds

    def launch(
        self,
        info: "StepRunInfo",
        entrypoint_command: List[str],
        environment: Dict[str, str],
    ) -> None:
        """Launches a step on SageMaker.

        Args:
            info: Information about the step run.
            entrypoint_command: Command that executes the step.
            environment: Environment variables to set in the step operator
                environment. This dictionary is modified in place.

        Raises:
            RuntimeError: If the connector returns an object that is not a
                `boto3.Session`.
        """
        if not info.config.resource_settings.empty:
            logger.warning(
                "Specifying custom step resources is not supported for "
                "the SageMaker step operator. If you want to run this step "
                "operator on specific resources, you can do so by configuring "
                "a different instance type like this: "
                "`zenml step-operator update %s "
                "--instance_type=<INSTANCE_TYPE>`",
                self.name,
            )

        # Sagemaker does not allow environment variables longer than 512
        # characters to be passed to Estimator steps. If an environment variable
        # is longer than 512 characters, we split it into multiple environment
        # variables (chunks) and re-construct it on the other side using the
        # custom entrypoint configuration.
        split_environment_variables(
            env=environment,
            size_limit=SAGEMAKER_ESTIMATOR_STEP_ENV_VAR_SIZE_LIMIT,
        )

        image_name = info.get_image(key=SAGEMAKER_DOCKER_IMAGE_KEY)
        environment[_ENTRYPOINT_ENV_VARIABLE] = " ".join(entrypoint_command)

        settings = cast(SagemakerStepOperatorSettings, self.get_settings(info))

        # Get and default fill SageMaker estimator arguments for full ZenML
        # support. Work on a shallow copy so that the defaults and runtime
        # objects injected below do not leak into the settings object.
        estimator_args = dict(settings.estimator_args)

        # Get authenticated session
        # Option 1: Service connector
        boto_session: boto3.Session
        if connector := self.get_connector():
            boto_session = connector.connect()
            if not isinstance(boto_session, boto3.Session):
                raise RuntimeError(
                    f"Expected to receive a `boto3.Session` object from the "
                    f"linked connector, but got type `{type(boto_session)}`."
                )
        # Option 2: Implicit configuration
        else:
            boto_session = boto3.Session()

        session = sagemaker.Session(
            boto_session=boto_session, default_bucket=self.config.bucket
        )

        # An instance type passed explicitly through the estimator arguments
        # takes precedence over the one configured in the settings.
        estimator_args.setdefault(
            "instance_type", settings.instance_type or "ml.m5.large"
        )

        estimator_args["environment"] = environment
        estimator_args["instance_count"] = 1
        estimator_args["sagemaker_session"] = session

        # Create Estimator
        estimator = sagemaker.estimator.Estimator(
            image_name, self.config.role, **estimator_args
        )

        # SageMaker allows 63 characters at maximum for job name - ZenML uses 60 for safety margin.
        # 55 characters of the name + "-" + 4 random characters = 60.
        step_name = Client().get_run_step(info.step_run_id).name
        training_job_name = f"{info.pipeline.name}-{step_name}"[:55]
        suffix = random_str(4)
        unique_training_job_name = f"{training_job_name}-{suffix}"

        # Sagemaker doesn't allow any underscores in job/experiment/trial names
        sanitized_training_job_name = unique_training_job_name.replace(
            "_", "-"
        )

        # Construct training input object, if necessary: a single S3 URI
        # becomes one training input, a dictionary becomes one training input
        # per channel.
        inputs = None

        if isinstance(settings.input_data_s3_uri, str):
            inputs = sagemaker.inputs.TrainingInput(
                s3_data=settings.input_data_s3_uri
            )
        elif isinstance(settings.input_data_s3_uri, dict):
            inputs = {}
            for channel, s3_uri in settings.input_data_s3_uri.items():
                inputs[channel] = sagemaker.inputs.TrainingInput(
                    s3_data=s3_uri
                )

        experiment_config = {}
        if settings.experiment_name:
            experiment_config = {
                "ExperimentName": settings.experiment_name,
                "TrialName": sanitized_training_job_name,
            }
        info.force_write_logs()
        estimator.fit(
            wait=True,
            inputs=inputs,
            experiment_config=experiment_config,
            job_name=sanitized_training_job_name,
        )
config: SagemakerStepOperatorConfig property readonly

Returns the SagemakerStepOperatorConfig config.

Returns:

Type Description
SagemakerStepOperatorConfig

The configuration.

entrypoint_config_class: Type[zenml.step_operators.step_operator_entrypoint_configuration.StepOperatorEntrypointConfiguration] property readonly

Returns the entrypoint configuration class for this step operator.

Returns:

Type Description
Type[zenml.step_operators.step_operator_entrypoint_configuration.StepOperatorEntrypointConfiguration]

The entrypoint configuration class for this step operator.

settings_class: Optional[Type[BaseSettings]] property readonly

Settings class for the SageMaker step operator.

Returns:

Type Description
Optional[Type[BaseSettings]]

The settings class.

validator: Optional[zenml.stack.stack_validator.StackValidator] property readonly

Validates the stack.

Returns:

Type Description
Optional[zenml.stack.stack_validator.StackValidator]

A validator that checks that the stack contains a remote container registry and a remote artifact store.

get_docker_builds(self, deployment)

Gets the Docker builds required for the component.

Parameters:

Name Type Description Default
deployment PipelineDeploymentBase

The pipeline deployment for which to get the builds.

required

Returns:

Type Description
List[BuildConfiguration]

The required Docker builds.

Source code in zenml/integrations/aws/step_operators/sagemaker_step_operator.py
def get_docker_builds(
    self, deployment: "PipelineDeploymentBase"
) -> List["BuildConfiguration"]:
    """Collect the Docker builds needed by this step operator.

    One build configuration is produced for every step of the deployment
    whose configured step operator matches the name of this component.

    Args:
        deployment: The pipeline deployment for which to get the builds.

    Returns:
        The required Docker builds.
    """
    return [
        BuildConfiguration(
            key=SAGEMAKER_DOCKER_IMAGE_KEY,
            settings=step_config.config.docker_settings,
            step_name=name,
            entrypoint=f"${_ENTRYPOINT_ENV_VARIABLE}",
        )
        for name, step_config in deployment.step_configurations.items()
        if step_config.config.step_operator == self.name
    ]
launch(self, info, entrypoint_command, environment)

Launches a step on SageMaker.

Parameters:

Name Type Description Default
info StepRunInfo

Information about the step run.

required
entrypoint_command List[str]

Command that executes the step.

required
environment Dict[str, str]

Environment variables to set in the step operator environment.

required

Exceptions:

Type Description
RuntimeError

If the connector returns an object that is not a boto3.Session.

Source code in zenml/integrations/aws/step_operators/sagemaker_step_operator.py
def launch(
    self,
    info: "StepRunInfo",
    entrypoint_command: List[str],
    environment: Dict[str, str],
) -> None:
    """Runs a step as a SageMaker Estimator job and waits for it to finish.

    Args:
        info: Information about the step run.
        entrypoint_command: Command that executes the step.
        environment: Environment variables to set in the step operator
            environment. The joined entrypoint command is added to this
            dictionary in place.

    Raises:
        RuntimeError: If the connector returns an object that is not a
            `boto3.Session`.
    """
    resources_requested = not info.config.resource_settings.empty
    if resources_requested:
        logger.warning(
            "Specifying custom step resources is not supported for "
            "the SageMaker step operator. If you want to run this step "
            "operator on specific resources, you can do so by configuring "
            "a different instance type like this: "
            "`zenml step-operator update %s "
            "--instance_type=<INSTANCE_TYPE>`",
            self.name,
        )

    # Estimator steps reject environment variables longer than 512
    # characters. Oversized values are broken up into chunks here and put
    # back together by the custom entrypoint configuration inside the job.
    split_environment_variables(
        size_limit=SAGEMAKER_ESTIMATOR_STEP_ENV_VAR_SIZE_LIMIT,
        env=environment,
    )

    docker_image = info.get_image(key=SAGEMAKER_DOCKER_IMAGE_KEY)
    # Added only after the split above, so this variable is always passed
    # as a single value.
    environment[_ENTRYPOINT_ENV_VARIABLE] = " ".join(entrypoint_command)

    step_settings = cast(
        SagemakerStepOperatorSettings, self.get_settings(info)
    )

    # User-supplied estimator arguments; completed further below with the
    # values ZenML needs.
    estimator_kwargs = step_settings.estimator_args

    # Authenticate: prefer a linked service connector, otherwise fall back
    # to the implicit boto3 configuration.
    connector = self.get_connector()
    aws_session: boto3.Session
    if connector:
        aws_session = connector.connect()
        if not isinstance(aws_session, boto3.Session):
            raise RuntimeError(
                f"Expected to receive a `boto3.Session` object from the "
                f"linked connector, but got type `{type(aws_session)}`."
            )
    else:
        aws_session = boto3.Session()

    sagemaker_session = sagemaker.Session(
        boto_session=aws_session, default_bucket=self.config.bucket
    )

    # A user-provided instance type in the estimator arguments wins; the
    # remaining three entries are always overwritten.
    estimator_kwargs.setdefault(
        "instance_type", step_settings.instance_type or "ml.m5.large"
    )
    estimator_kwargs.update(
        environment=environment,
        instance_count=1,
        sagemaker_session=sagemaker_session,
    )

    estimator = sagemaker.estimator.Estimator(
        docker_image, self.config.role, **estimator_kwargs
    )

    # Job names may be at most 63 characters long. 55 characters of
    # "<pipeline>-<step>" plus "-" plus a 4 character random suffix keeps
    # the name at 60 characters or fewer.
    run_step_name = Client().get_run_step(info.step_run_id).name
    job_prefix = f"{info.pipeline.name}-{run_step_name}"[:55]
    # Underscores are not allowed in job/experiment/trial names.
    job_name = f"{job_prefix}-{random_str(4)}".replace("_", "-")

    # Build the training input: a single input for a plain URI, one input
    # per channel for a mapping, nothing otherwise.
    s3_input = step_settings.input_data_s3_uri
    if isinstance(s3_input, str):
        training_inputs = sagemaker.inputs.TrainingInput(s3_data=s3_input)
    elif isinstance(s3_input, dict):
        training_inputs = {
            channel: sagemaker.inputs.TrainingInput(s3_data=uri)
            for channel, uri in s3_input.items()
        }
    else:
        training_inputs = None

    # Only attach the job to an experiment when one is configured; the job
    # name doubles as the trial name.
    experiment_config = (
        {
            "ExperimentName": step_settings.experiment_name,
            "TrialName": job_name,
        }
        if step_settings.experiment_name
        else {}
    )

    info.force_write_logs()
    estimator.fit(
        wait=True,
        inputs=training_inputs,
        experiment_config=experiment_config,
        job_name=job_name,
    )

sagemaker_step_operator_entrypoint_config

Entrypoint configuration for ZenML Sagemaker step operator.

SagemakerEntrypointConfiguration (StepOperatorEntrypointConfiguration)

Entrypoint configuration for ZenML Sagemaker step operator.

The only purpose of this entrypoint configuration is to reconstruct, from their individual chunks, the environment variables that exceed the maximum length of 512 characters allowed for SageMaker Estimator steps.

Source code in zenml/integrations/aws/step_operators/sagemaker_step_operator_entrypoint_config.py
class SagemakerEntrypointConfiguration(StepOperatorEntrypointConfiguration):
    """Entrypoint configuration for the ZenML SageMaker step operator.

    SageMaker Estimator steps limit environment variables to 512 characters.
    Variables that exceed this limit arrive split into individual chunks;
    the only purpose of this entrypoint configuration is to reassemble them
    before handing control to the regular step operator entrypoint.
    """

    def run(self) -> None:
        """Restores chunked environment variables, then runs the step."""
        # Must happen first: rebuild every environment variable that exceeded
        # the 512 character limit from its individual chunks, so the step
        # sees the complete values.
        reconstruct_environment_variables()

        # Delegate the actual step execution to the parent entrypoint.
        super().run()
run(self)

Runs the step.

Source code in zenml/integrations/aws/step_operators/sagemaker_step_operator_entrypoint_config.py
def run(self) -> None:
    """Restores chunked environment variables, then runs the step."""
    # Must happen first: rebuild every environment variable that exceeded
    # the 512 character limit from its individual chunks, so the step sees
    # the complete values.
    reconstruct_environment_variables()

    # Delegate the actual step execution to the parent entrypoint.
    super().run()