Skip to content

Aws

zenml.integrations.aws special

Integrates multiple AWS Tools as Stack Components.

The AWS integration provides a way for our users to manage their secrets through AWS and a way to use the AWS container registry. Additionally, the Sagemaker integration submodule provides a way to run ZenML steps in Sagemaker.

AWSIntegration (Integration)

Definition of AWS integration for ZenML.

Source code in zenml/integrations/aws/__init__.py
class AWSIntegration(Integration):
    """ZenML integration that bundles the AWS stack component flavors."""

    NAME = AWS
    REQUIREMENTS = ["sagemaker==2.117.0"]

    @classmethod
    def flavors(cls) -> List[Type[Flavor]]:
        """List the stack component flavors shipped with the AWS integration.

        The flavor classes are imported inside the method body rather than at
        module import time.

        Returns:
            The secrets manager, container registry, Sagemaker step operator
            and Sagemaker orchestrator flavor classes, in that order.
        """
        from zenml.integrations.aws.flavors import (
            AWSContainerRegistryFlavor,
            AWSSecretsManagerFlavor,
            SagemakerOrchestratorFlavor,
            SagemakerStepOperatorFlavor,
        )

        aws_flavors: List[Type[Flavor]] = [
            AWSSecretsManagerFlavor,
            AWSContainerRegistryFlavor,
            SagemakerStepOperatorFlavor,
            SagemakerOrchestratorFlavor,
        ]
        return aws_flavors

flavors() classmethod

Declare the stack component flavors for the AWS integration.

Returns:

Type Description
List[Type[zenml.stack.flavor.Flavor]]

List of stack component flavors for this integration.

Source code in zenml/integrations/aws/__init__.py
@classmethod
def flavors(cls) -> List[Type[Flavor]]:
    """List the stack component flavors shipped with the AWS integration.

    The flavor classes are imported inside the function body rather than at
    module import time.

    Returns:
        The secrets manager, container registry, Sagemaker step operator and
        Sagemaker orchestrator flavor classes, in that order.
    """
    from zenml.integrations.aws.flavors import (
        AWSContainerRegistryFlavor,
        AWSSecretsManagerFlavor,
        SagemakerOrchestratorFlavor,
        SagemakerStepOperatorFlavor,
    )

    aws_flavors: List[Type[Flavor]] = [
        AWSSecretsManagerFlavor,
        AWSContainerRegistryFlavor,
        SagemakerStepOperatorFlavor,
        SagemakerOrchestratorFlavor,
    ]
    return aws_flavors

container_registries special

Initialization of AWS Container Registry integration.

aws_container_registry

Implementation of the AWS container registry integration.

AWSContainerRegistry (BaseContainerRegistry)

Class for AWS Container Registry.

Source code in zenml/integrations/aws/container_registries/aws_container_registry.py
class AWSContainerRegistry(BaseContainerRegistry):
    """Class for AWS Container Registry."""

    @property
    def config(self) -> AWSContainerRegistryConfig:
        """Returns the `AWSContainerRegistryConfig` config.

        Returns:
            The configuration.
        """
        return cast(AWSContainerRegistryConfig, self._config)

    def _get_region(self) -> str:
        """Parses the AWS region from the registry URI.

        The URI is expected to look like
        `<prefix>.dkr.ecr.<region>.amazonaws.com`.

        Raises:
            RuntimeError: If the region parsing fails due to an invalid URI.

        Returns:
            The region string.
        """
        match = re.fullmatch(
            r".*\.dkr\.ecr\.(.*)\.amazonaws\.com", self.config.uri
        )
        if not match:
            raise RuntimeError(
                f"Unable to parse region from ECR URI {self.config.uri}."
            )

        return match.group(1)

    def prepare_image_push(self, image_name: str) -> None:
        """Logs warning message if trying to push an image for which no repository exists.

        The repository lookup is best effort: if the ECR API call fails with a
        `ClientError` or the response lacks the expected keys, the check is
        skipped and the push is left to proceed.

        Args:
            image_name: Name of the docker image that will be pushed.

        Raises:
            ValueError: If the docker image name is invalid.
        """
        try:
            # The API call itself must be inside the `try` block, otherwise
            # the `ClientError` handler below can never be triggered.
            response = boto3.client(
                "ecr", region_name=self._get_region()
            ).describe_repositories()
            repo_uris: List[str] = [
                repository["repositoryUri"]
                for repository in response["repositories"]
            ]
        except (KeyError, ClientError) as e:
            # invalid boto response, let's hope for the best and just push
            logger.debug("Error while trying to fetch ECR repositories: %s", e)
            return

        repo_exists = any(
            image_name.startswith(f"{uri}:") for uri in repo_uris
        )
        if not repo_exists:
            # Escape the URI so that its dots are matched literally instead of
            # being interpreted as regex wildcards.
            match = re.search(
                f"{re.escape(self.config.uri)}/(.*):.*", image_name
            )
            if not match:
                raise ValueError(f"Invalid docker image name '{image_name}'.")

            repo_name = match.group(1)
            logger.warning(
                "Amazon ECR requires you to create a repository before you can "
                f"push an image to it. ZenML is trying to push the image "
                f"{image_name} but could only detect the following "
                f"repositories: {repo_uris}. We will try to push anyway, but "
                f"in case it fails you need to create a repository named "
                f"`{repo_name}`."
            )

    @property
    def post_registration_message(self) -> Optional[str]:
        """Optional message printed after the stack component is registered.

        Returns:
            Info message regarding docker repositories in AWS.
        """
        return (
            "Amazon ECR requires you to create a repository before you can "
            "push an image to it. If you want to for example run a pipeline "
            "using our Kubeflow orchestrator, ZenML will automatically build a "
            f"docker image called `{self.config.uri}/zenml-kubeflow:<PIPELINE_NAME>` "
            f"and try to push it. This will fail unless you create the "
            f"repository `zenml-kubeflow` inside your amazon registry."
        )
config: AWSContainerRegistryConfig property readonly

Returns the AWSContainerRegistryConfig config.

Returns:

Type Description
AWSContainerRegistryConfig

The configuration.

post_registration_message: Optional[str] property readonly

Optional message printed after the stack component is registered.

Returns:

Type Description
Optional[str]

Info message regarding docker repositories in AWS.

prepare_image_push(self, image_name)

Logs warning message if trying to push an image for which no repository exists.

Parameters:

Name Type Description Default
image_name str

Name of the docker image that will be pushed.

required

Exceptions:

Type Description
ValueError

If the docker image name is invalid.

Source code in zenml/integrations/aws/container_registries/aws_container_registry.py
def prepare_image_push(self, image_name: str) -> None:
    """Logs warning message if trying to push an image for which no repository exists.

    The repository lookup is best effort: if the ECR API call fails with a
    `ClientError` or the response lacks the expected keys, the check is
    skipped and the push is left to proceed.

    Args:
        image_name: Name of the docker image that will be pushed.

    Raises:
        ValueError: If the docker image name is invalid.
    """
    try:
        # The API call itself must be inside the `try` block, otherwise the
        # `ClientError` handler below can never be triggered.
        response = boto3.client(
            "ecr", region_name=self._get_region()
        ).describe_repositories()
        repo_uris: List[str] = [
            repository["repositoryUri"]
            for repository in response["repositories"]
        ]
    except (KeyError, ClientError) as e:
        # invalid boto response, let's hope for the best and just push
        logger.debug("Error while trying to fetch ECR repositories: %s", e)
        return

    repo_exists = any(
        image_name.startswith(f"{uri}:") for uri in repo_uris
    )
    if not repo_exists:
        # Escape the URI so that its dots are matched literally instead of
        # being interpreted as regex wildcards.
        match = re.search(f"{re.escape(self.config.uri)}/(.*):.*", image_name)
        if not match:
            raise ValueError(f"Invalid docker image name '{image_name}'.")

        repo_name = match.group(1)
        logger.warning(
            "Amazon ECR requires you to create a repository before you can "
            f"push an image to it. ZenML is trying to push the image "
            f"{image_name} but could only detect the following "
            f"repositories: {repo_uris}. We will try to push anyway, but "
            f"in case it fails you need to create a repository named "
            f"`{repo_name}`."
        )

flavors special

AWS integration flavors.

aws_container_registry_flavor

AWS container registry flavor.

AWSContainerRegistryConfig (BaseContainerRegistryConfig) pydantic-model

Configuration for AWS Container Registry.

Source code in zenml/integrations/aws/flavors/aws_container_registry_flavor.py
class AWSContainerRegistryConfig(BaseContainerRegistryConfig):
    """Configuration for AWS Container Registry."""

    @validator("uri")
    def validate_aws_uri(cls, uri: str) -> str:
        """Reject registry URIs that contain a path separator.

        Args:
            uri: URI to validate.

        Returns:
            The unchanged URI when it contains no `/`.

        Raises:
            ValueError: If the URI contains a slash character.
        """
        if "/" not in uri:
            return uri

        raise ValueError(
            "Property `uri` can not contain a `/`. An example of a valid "
            "URI is: `715803424592.dkr.ecr.us-east-1.amazonaws.com`"
        )
validate_aws_uri(uri) classmethod

Validates that the URI is in the correct format.

Parameters:

Name Type Description Default
uri str

URI to validate.

required

Returns:

Type Description
str

URI in the correct format.

Exceptions:

Type Description
ValueError

If the URI contains a slash character.

Source code in zenml/integrations/aws/flavors/aws_container_registry_flavor.py
@validator("uri")
def validate_aws_uri(cls, uri: str) -> str:
    """Reject registry URIs that contain a path separator.

    Args:
        uri: URI to validate.

    Returns:
        The unchanged URI when it contains no `/`.

    Raises:
        ValueError: If the URI contains a slash character.
    """
    if "/" not in uri:
        return uri

    raise ValueError(
        "Property `uri` can not contain a `/`. An example of a valid "
        "URI is: `715803424592.dkr.ecr.us-east-1.amazonaws.com`"
    )
AWSContainerRegistryFlavor (BaseContainerRegistryFlavor)

AWS Container Registry flavor.

Source code in zenml/integrations/aws/flavors/aws_container_registry_flavor.py
class AWSContainerRegistryFlavor(BaseContainerRegistryFlavor):
    """Flavor describing the AWS container registry stack component."""

    @property
    def name(self) -> str:
        """Identifier of this flavor.

        Returns:
            The value of `AWS_CONTAINER_REGISTRY_FLAVOR`.
        """
        flavor_name = AWS_CONTAINER_REGISTRY_FLAVOR
        return flavor_name

    @property
    def docs_url(self) -> Optional[str]:
        """Link to the documentation page of this flavor.

        Returns:
            Whatever `generate_default_docs_url()` produces for this flavor.
        """
        default_url = self.generate_default_docs_url()
        return default_url

    @property
    def logo_url(self) -> str:
        """Link to the image used to represent the flavor in the dashboard.

        Returns:
            The logo URL.
        """
        logo = "https://public-flavor-logos.s3.eu-central-1.amazonaws.com/container_registry/aws.png"
        return logo

    @property
    def config_class(self) -> Type[AWSContainerRegistryConfig]:
        """Configuration model used by this flavor.

        Returns:
            The `AWSContainerRegistryConfig` class.
        """
        return AWSContainerRegistryConfig

    @property
    def implementation_class(self) -> Type["AWSContainerRegistry"]:
        """Stack component class that implements this flavor.

        The class is imported inside the property body rather than at module
        import time.

        Returns:
            The `AWSContainerRegistry` class.
        """
        from zenml.integrations.aws.container_registries import (
            AWSContainerRegistry,
        )

        implementation = AWSContainerRegistry
        return implementation
config_class: Type[zenml.integrations.aws.flavors.aws_container_registry_flavor.AWSContainerRegistryConfig] property readonly

Config class for this flavor.

Returns:

Type Description
Type[zenml.integrations.aws.flavors.aws_container_registry_flavor.AWSContainerRegistryConfig]

The config class.

docs_url: Optional[str] property readonly

A url to point at docs explaining this flavor.

Returns:

Type Description
Optional[str]

A flavor docs url.

implementation_class: Type[AWSContainerRegistry] property readonly

Implementation class.

Returns:

Type Description
Type[AWSContainerRegistry]

The implementation class.

logo_url: str property readonly

A url to represent the flavor in the dashboard.

Returns:

Type Description
str

The flavor logo.

name: str property readonly

Name of the flavor.

Returns:

Type Description
str

The name of the flavor.

aws_secrets_manager_flavor

AWS secrets manager flavor.

AWSSecretsManagerConfig (BaseSecretsManagerConfig) pydantic-model

Configuration for the AWS Secrets Manager.

Attributes:

Name Type Description
region_name str

The region name of the AWS Secrets Manager.

Source code in zenml/integrations/aws/flavors/aws_secrets_manager_flavor.py
class AWSSecretsManagerConfig(BaseSecretsManagerConfig):
    """Configuration for the AWS Secrets Manager.

    Attributes:
        region_name: The region name of the AWS Secrets Manager.
    """

    SUPPORTS_SCOPING: ClassVar[bool] = True

    region_name: str

    @classmethod
    def _validate_scope(
        cls,
        scope: "SecretsManagerScope",
        namespace: Optional[str],
    ) -> None:
        """Check that a namespace, when given, is a valid AWS secret name part.

        The `scope` argument is not inspected here; only a non-empty
        `namespace` is validated.

        Args:
            scope: Scope value.
            namespace: Optional namespace value.
        """
        if not namespace:
            # Nothing to validate for a missing or empty namespace.
            return

        validate_aws_secret_name_or_namespace(namespace)
AWSSecretsManagerFlavor (BaseSecretsManagerFlavor)

Class for the AWSSecretsManagerFlavor.

Source code in zenml/integrations/aws/flavors/aws_secrets_manager_flavor.py
class AWSSecretsManagerFlavor(BaseSecretsManagerFlavor):
    """Flavor describing the AWS secrets manager stack component."""

    @property
    def name(self) -> str:
        """Identifier of this flavor.

        Returns:
            The value of `AWS_SECRET_MANAGER_FLAVOR`.
        """
        flavor_name = AWS_SECRET_MANAGER_FLAVOR
        return flavor_name

    @property
    def docs_url(self) -> Optional[str]:
        """Link to the documentation page of this flavor.

        Returns:
            Whatever `generate_default_docs_url()` produces for this flavor.
        """
        default_url = self.generate_default_docs_url()
        return default_url

    @property
    def logo_url(self) -> str:
        """Link to the image used to represent the flavor in the dashboard.

        Returns:
            The logo URL.
        """
        logo = "https://public-flavor-logos.s3.eu-central-1.amazonaws.com/secrets_managers/aws.png"
        return logo

    @property
    def config_class(self) -> Type[AWSSecretsManagerConfig]:
        """Configuration model used by this flavor.

        Returns:
            The `AWSSecretsManagerConfig` class.
        """
        return AWSSecretsManagerConfig

    @property
    def implementation_class(self) -> Type["AWSSecretsManager"]:
        """Stack component class that implements this flavor.

        The class is imported inside the property body rather than at module
        import time.

        Returns:
            The `AWSSecretsManager` class.
        """
        from zenml.integrations.aws.secrets_managers import AWSSecretsManager

        implementation = AWSSecretsManager
        return implementation
config_class: Type[zenml.integrations.aws.flavors.aws_secrets_manager_flavor.AWSSecretsManagerConfig] property readonly

Config class for this flavor.

Returns:

Type Description
Type[zenml.integrations.aws.flavors.aws_secrets_manager_flavor.AWSSecretsManagerConfig]

Config class for this flavor.

docs_url: Optional[str] property readonly

A url to point at docs explaining this flavor.

Returns:

Type Description
Optional[str]

A flavor docs url.

implementation_class: Type[AWSSecretsManager] property readonly

Implementation class.

Returns:

Type Description
Type[AWSSecretsManager]

Implementation class.

logo_url: str property readonly

A url to represent the flavor in the dashboard.

Returns:

Type Description
str

The flavor logo.

name: str property readonly

Name of the flavor.

Returns:

Type Description
str

Name of the flavor.

validate_aws_secret_name_or_namespace(name)

Validate a secret name or namespace.

AWS secret names must contain only alphanumeric characters and the characters /_+=.@-. The / character is only used internally to delimit scopes.

Parameters:

Name Type Description Default
name str

the secret name or namespace

required

Exceptions:

Type Description
ValueError

if the secret name or namespace is invalid

Source code in zenml/integrations/aws/flavors/aws_secrets_manager_flavor.py
def validate_aws_secret_name_or_namespace(name: str) -> None:
    """Check that a secret name or namespace uses only permitted characters.

    AWS secret names must contain only alphanumeric characters and the
    characters /_+=.@-. The `/` character is only used internally to delimit
    scopes, so it is not accepted here. An empty string passes the check.

    Args:
        name: the secret name or namespace

    Raises:
        ValueError: if the secret name or namespace is invalid
    """
    is_valid = re.fullmatch(r"[a-zA-Z0-9_+=\.@\-]*", name) is not None
    if is_valid:
        return

    raise ValueError(
        f"Invalid secret name or namespace '{name}'. Must contain "
        f"only alphanumeric characters and the characters _+=.@-."
    )

sagemaker_orchestrator_flavor

Amazon SageMaker orchestrator flavor.

SagemakerOrchestratorConfig (BaseOrchestratorConfig, SagemakerOrchestratorSettings) pydantic-model

Config for the Sagemaker orchestrator.

Attributes:

Name Type Description
synchronous bool

Whether to run the processing job synchronously or asynchronously. Defaults to False.

execution_role str

The IAM role to use for the pipeline.

bucket Optional[str]

Name of the S3 bucket to use for storing artifacts from the job run. If not provided, a default bucket will be created based on the following format: "sagemaker-{region}-{aws-account-id}".

Source code in zenml/integrations/aws/flavors/sagemaker_orchestrator_flavor.py
class SagemakerOrchestratorConfig(  # type: ignore[misc] # https://github.com/pydantic/pydantic/issues/4173
    BaseOrchestratorConfig, SagemakerOrchestratorSettings
):
    """Config for the Sagemaker orchestrator.

    Combines the base orchestrator config with the fields of
    `SagemakerOrchestratorSettings`.

    Attributes:
        synchronous: Whether to run the processing job synchronously or
            asynchronously. Defaults to False.
        execution_role: The IAM role to use for the pipeline. Required, as it
            has no default value.
        bucket: Name of the S3 bucket to use for storing artifacts
            from the job run. If not provided, a default bucket will be created
            based on the following format:
            "sagemaker-{region}-{aws-account-id}".
    """

    synchronous: bool = False
    execution_role: str
    bucket: Optional[str] = None

    @property
    def is_remote(self) -> bool:
        """Checks if this stack component is running remotely.

        This designation is used to determine if the stack component can be
        used with a local ZenML database or if it requires a remote ZenML
        server.

        Returns:
            Always True for this config.
        """
        return True
is_remote: bool property readonly

Checks if this stack component is running remotely.

This designation is used to determine if the stack component can be used with a local ZenML database or if it requires a remote ZenML server.

Returns:

Type Description
bool

True if this config is for a remote component, False otherwise.

SagemakerOrchestratorFlavor (BaseOrchestratorFlavor)

Flavor for the Sagemaker orchestrator.

Source code in zenml/integrations/aws/flavors/sagemaker_orchestrator_flavor.py
class SagemakerOrchestratorFlavor(BaseOrchestratorFlavor):
    """Flavor describing the Sagemaker orchestrator stack component."""

    @property
    def name(self) -> str:
        """Identifier of this flavor.

        Returns:
            The value of `AWS_SAGEMAKER_STEP_OPERATOR_FLAVOR`.
        """
        # NOTE(review): the orchestrator flavor returns the step operator
        # constant; confirm that sharing this name between the two flavors is
        # intended.
        flavor_name = AWS_SAGEMAKER_STEP_OPERATOR_FLAVOR
        return flavor_name

    @property
    def docs_url(self) -> Optional[str]:
        """Link to the documentation page of this flavor.

        Returns:
            Whatever `generate_default_docs_url()` produces for this flavor.
        """
        default_url = self.generate_default_docs_url()
        return default_url

    @property
    def logo_url(self) -> str:
        """Link to the image used to represent the flavor in the dashboard.

        Returns:
            The logo URL.
        """
        logo = "https://public-flavor-logos.s3.eu-central-1.amazonaws.com/orchestrator/sagemaker.png"
        return logo

    @property
    def config_class(self) -> Type[SagemakerOrchestratorConfig]:
        """Configuration model used by this flavor.

        Returns:
            The `SagemakerOrchestratorConfig` class.
        """
        return SagemakerOrchestratorConfig

    @property
    def implementation_class(self) -> Type["SagemakerOrchestrator"]:
        """Stack component class that implements this flavor.

        The class is imported inside the property body rather than at module
        import time.

        Returns:
            The `SagemakerOrchestrator` class.
        """
        from zenml.integrations.aws.orchestrators import SagemakerOrchestrator

        implementation = SagemakerOrchestrator
        return implementation
config_class: Type[zenml.integrations.aws.flavors.sagemaker_orchestrator_flavor.SagemakerOrchestratorConfig] property readonly

Returns SagemakerOrchestratorConfig config class.

Returns:

Type Description
Type[zenml.integrations.aws.flavors.sagemaker_orchestrator_flavor.SagemakerOrchestratorConfig]

The config class.

docs_url: Optional[str] property readonly

A url to point at docs explaining this flavor.

Returns:

Type Description
Optional[str]

A flavor docs url.

implementation_class: Type[SagemakerOrchestrator] property readonly

Implementation class.

Returns:

Type Description
Type[SagemakerOrchestrator]

The implementation class.

logo_url: str property readonly

A url to represent the flavor in the dashboard.

Returns:

Type Description
str

The flavor logo.

name: str property readonly

Name of the flavor.

Returns:

Type Description
str

The name of the flavor.

SagemakerOrchestratorSettings (BaseSettings) pydantic-model

Settings for the Sagemaker orchestrator.

Attributes:

Name Type Description
instance_type str

The instance type to use for the processing job.

processor_role Optional[str]

The IAM role to use for the step execution on a Processor.

volume_size_in_gb int

The size of the EBS volume to use for the processing job.

max_runtime_in_seconds int

The maximum runtime in seconds for the processing job.

processor_tags Dict[str, str]

Tags to apply to the Processor assigned to the step.

Source code in zenml/integrations/aws/flavors/sagemaker_orchestrator_flavor.py
class SagemakerOrchestratorSettings(BaseSettings):
    """Settings for the Sagemaker orchestrator.

    Attributes:
        instance_type: The instance type to use for the processing job.
            Defaults to "ml.t3.medium".
        processor_role: The IAM role to use for the step execution on a
            Processor. Defaults to None.
        volume_size_in_gb: The size of the EBS volume to use for the processing
            job. Defaults to 30.
        max_runtime_in_seconds: The maximum runtime in seconds for the
            processing job. Defaults to 86400 (24 hours).
        processor_tags: Tags to apply to the Processor assigned to the step.
            Defaults to an empty dict.
    """

    instance_type: str = "ml.t3.medium"
    processor_role: Optional[str] = None
    volume_size_in_gb: int = 30
    max_runtime_in_seconds: int = 86400
    processor_tags: Dict[str, str] = {}

sagemaker_step_operator_flavor

Amazon SageMaker step operator flavor.

SagemakerStepOperatorConfig (BaseStepOperatorConfig, SagemakerStepOperatorSettings) pydantic-model

Config for the Sagemaker step operator.

Attributes:

Name Type Description
role str

The role that has to be assigned to the jobs which are running in Sagemaker.

bucket Optional[str]

Name of the S3 bucket to use for storing artifacts from the job run. If not provided, a default bucket will be created based on the following format: "sagemaker-{region}-{aws-account-id}".

Source code in zenml/integrations/aws/flavors/sagemaker_step_operator_flavor.py
class SagemakerStepOperatorConfig(  # type: ignore[misc] # https://github.com/pydantic/pydantic/issues/4173
    BaseStepOperatorConfig, SagemakerStepOperatorSettings
):
    """Config for the Sagemaker step operator.

    Combines the base step operator config with the fields of
    `SagemakerStepOperatorSettings`.

    Attributes:
        role: The role that has to be assigned to the jobs which are
            running in Sagemaker. Required, as it has no default value.
        bucket: Name of the S3 bucket to use for storing artifacts
            from the job run. If not provided, a default bucket will be created
            based on the following format: "sagemaker-{region}-{aws-account-id}".
    """

    role: str
    bucket: Optional[str] = None

    @property
    def is_remote(self) -> bool:
        """Checks if this stack component is running remotely.

        This designation is used to determine if the stack component can be
        used with a local ZenML database or if it requires a remote ZenML
        server.

        Returns:
            Always True for this config.
        """
        return True
is_remote: bool property readonly

Checks if this stack component is running remotely.

This designation is used to determine if the stack component can be used with a local ZenML database or if it requires a remote ZenML server.

Returns:

Type Description
bool

True if this config is for a remote component, False otherwise.

SagemakerStepOperatorFlavor (BaseStepOperatorFlavor)

Flavor for the Sagemaker step operator.

Source code in zenml/integrations/aws/flavors/sagemaker_step_operator_flavor.py
class SagemakerStepOperatorFlavor(BaseStepOperatorFlavor):
    """Flavor describing the Sagemaker step operator stack component."""

    @property
    def name(self) -> str:
        """Identifier of this flavor.

        Returns:
            The value of `AWS_SAGEMAKER_STEP_OPERATOR_FLAVOR`.
        """
        flavor_name = AWS_SAGEMAKER_STEP_OPERATOR_FLAVOR
        return flavor_name

    @property
    def docs_url(self) -> Optional[str]:
        """Link to the documentation page of this flavor.

        Returns:
            Whatever `generate_default_docs_url()` produces for this flavor.
        """
        default_url = self.generate_default_docs_url()
        return default_url

    @property
    def logo_url(self) -> str:
        """Link to the image used to represent the flavor in the dashboard.

        Returns:
            The logo URL.
        """
        logo = "https://public-flavor-logos.s3.eu-central-1.amazonaws.com/step_operator/sagemaker.png"
        return logo

    @property
    def config_class(self) -> Type[SagemakerStepOperatorConfig]:
        """Configuration model used by this flavor.

        Returns:
            The `SagemakerStepOperatorConfig` class.
        """
        return SagemakerStepOperatorConfig

    @property
    def implementation_class(self) -> Type["SagemakerStepOperator"]:
        """Stack component class that implements this flavor.

        The class is imported inside the property body rather than at module
        import time.

        Returns:
            The `SagemakerStepOperator` class.
        """
        from zenml.integrations.aws.step_operators import SagemakerStepOperator

        implementation = SagemakerStepOperator
        return implementation
config_class: Type[zenml.integrations.aws.flavors.sagemaker_step_operator_flavor.SagemakerStepOperatorConfig] property readonly

Returns SagemakerStepOperatorConfig config class.

Returns:

Type Description
Type[zenml.integrations.aws.flavors.sagemaker_step_operator_flavor.SagemakerStepOperatorConfig]

The config class.

docs_url: Optional[str] property readonly

A url to point at docs explaining this flavor.

Returns:

Type Description
Optional[str]

A flavor docs url.

implementation_class: Type[SagemakerStepOperator] property readonly

Implementation class.

Returns:

Type Description
Type[SagemakerStepOperator]

The implementation class.

logo_url: str property readonly

A url to represent the flavor in the dashboard.

Returns:

Type Description
str

The flavor logo.

name: str property readonly

Name of the flavor.

Returns:

Type Description
str

The name of the flavor.

SagemakerStepOperatorSettings (BaseSettings) pydantic-model

Settings for the Sagemaker step operator.

Attributes:

Name Type Description
instance_type Optional[str]

The type of the compute instance where jobs will run. Check https://docs.aws.amazon.com/sagemaker/latest/dg/notebooks-available-instance-types.html for a list of available instance types.

experiment_name Optional[str]

The name for the experiment to which the job will be associated. If not provided, the job runs would be independent.

Source code in zenml/integrations/aws/flavors/sagemaker_step_operator_flavor.py
class SagemakerStepOperatorSettings(BaseSettings):
    """Settings for the Sagemaker step operator.

    Attributes:
        instance_type: The type of the compute instance where jobs will run.
            Check https://docs.aws.amazon.com/sagemaker/latest/dg/notebooks-available-instance-types.html
            for a list of available instance types. Defaults to None.
        experiment_name: The name for the experiment to which the job
            will be associated. If not provided, the job runs would be
            independent. Defaults to None.
    """

    instance_type: Optional[str] = None
    experiment_name: Optional[str] = None

orchestrators special

AWS Sagemaker orchestrator.

sagemaker_orchestrator

Implementation of the SageMaker orchestrator.

SagemakerOrchestrator (ContainerizedOrchestrator)

Orchestrator responsible for running pipelines on Sagemaker.

Source code in zenml/integrations/aws/orchestrators/sagemaker_orchestrator.py
class SagemakerOrchestrator(ContainerizedOrchestrator):
    """Orchestrator that executes ZenML pipelines as Sagemaker pipelines."""

    @property
    def config(self) -> SagemakerOrchestratorConfig:
        """Typed accessor for the orchestrator configuration.

        Returns:
            The component config, cast to `SagemakerOrchestratorConfig`.
        """
        orchestrator_config = cast(SagemakerOrchestratorConfig, self._config)
        return orchestrator_config

    @property
    def validator(self) -> Optional[StackValidator]:
        """Builds the validator that checks stacks using this orchestrator.

        The validator requires a container registry and an image builder
        in the stack, and rejects any stack that contains a component
        whose config is flagged as local.

        Returns:
            A `StackValidator` instance.
        """

        def _validate_remote_components(
            stack: "Stack",
        ) -> Tuple[bool, str]:
            # Report the first local component that is encountered.
            for component in stack.components.values():
                if component.config.is_local:
                    reason = (
                        f"The Sagemaker orchestrator runs pipelines remotely, "
                        f"but the '{component.name}' {component.type.value} is "
                        "a local stack component and will not be available in "
                        "the Sagemaker step.\nPlease ensure that you always "
                        "use non-local stack components with the Sagemaker "
                        "orchestrator."
                    )
                    return False, reason

            return True, ""

        needed_components = {
            StackComponentType.CONTAINER_REGISTRY,
            StackComponentType.IMAGE_BUILDER,
        }
        return StackValidator(
            required_components=needed_components,
            custom_validation_function=_validate_remote_components,
        )

    def get_orchestrator_run_id(self) -> str:
        """Reads the id of the active orchestrator run from the environment.

        Important: the id has to be unique and identical for all steps of
        a pipeline run.

        Returns:
            The orchestrator run id.

        Raises:
            RuntimeError: If the run id cannot be read from the environment.
        """
        try:
            run_id = os.environ[ENV_ZENML_SAGEMAKER_RUN_ID]
        except KeyError:
            message = (
                "Unable to read run id from environment variable "
                f"{ENV_ZENML_SAGEMAKER_RUN_ID}."
            )
            raise RuntimeError(message)
        return run_id

    @property
    def settings_class(self) -> Optional[Type["BaseSettings"]]:
        """Settings class that belongs to the Sagemaker orchestrator.

        Returns:
            The settings class.
        """
        return SagemakerOrchestratorSettings

    def prepare_or_run_pipeline(
        self,
        deployment: "PipelineDeploymentResponseModel",
        stack: "Stack",
    ) -> None:
        """Builds a Sagemaker pipeline for the deployment and starts it.

        One `ProcessingStep` is created per step configuration of the
        deployment. The steps are assembled into a Sagemaker `Pipeline`,
        which is created and started. If the orchestrator is configured
        as synchronous, this method waits for the execution.

        Args:
            deployment: The deployment to prepare or run.
            stack: The stack to run on.
        """
        if deployment.schedule:
            logger.warning(
                "The Sagemaker Orchestrator currently does not support the "
                "use of schedules. The `schedule` will be ignored "
                "and the pipeline will be run immediately."
            )

        # Underscores in the generated run name are replaced with dashes.
        run_name = get_orchestrator_run_name(
            pipeline_name=deployment.pipeline_configuration.name
        )
        run_name = run_name.replace("_", "-")

        sagemaker_session = sagemaker.Session(
            default_bucket=self.config.bucket
        )

        processing_steps = []
        for name, step_config in deployment.step_configurations.items():
            step_image = self.get_image(deployment=deployment, step_name=name)
            step_entrypoint = (
                StepEntrypointConfiguration.get_entrypoint_command()
                + StepEntrypointConfiguration.get_entrypoint_arguments(
                    step_name=name, deployment_id=deployment.id
                )
            )

            settings = cast(
                SagemakerOrchestratorSettings, self.get_settings(step_config)
            )
            # A step-level role takes precedence over the configured one.
            role = settings.processor_role or self.config.execution_role

            # Tags are only handed to the processor when some are set.
            tags = settings.processor_tags
            optional_args = {"tags": [tags]} if tags else {}

            # Every step receives the pipeline execution ARN as its run id.
            step_environment = {
                ENV_ZENML_SAGEMAKER_RUN_ID: ExecutionVariables.PIPELINE_EXECUTION_ARN,
            }

            step_processor = sagemaker.processing.Processor(
                role=role,
                image_uri=step_image,
                instance_count=1,
                sagemaker_session=sagemaker_session,
                instance_type=settings.instance_type,
                entrypoint=step_entrypoint,
                base_job_name=run_name,
                env=step_environment,
                volume_size_in_gb=settings.volume_size_in_gb,
                max_runtime_in_seconds=settings.max_runtime_in_seconds,
                **optional_args,
            )

            processing_steps.append(
                ProcessingStep(
                    name=step_config.config.name,
                    processor=step_processor,
                    depends_on=step_config.spec.upstream_steps,
                )
            )

        # Assemble the collected steps into a single Sagemaker pipeline.
        sagemaker_pipeline = Pipeline(
            name=run_name,
            steps=processing_steps,
            sagemaker_session=sagemaker_session,
        )

        sagemaker_pipeline.create(role_arn=self.config.execution_role)
        execution = sagemaker_pipeline.start()

        if not self.config.synchronous:
            return

        # Mainly for testing purposes: wait for the pipeline to finish.
        logger.info(
            "Executing synchronously. Waiting for pipeline to finish..."
        )
        execution.wait(delay=POLLING_DELAY, max_attempts=MAX_POLLING_ATTEMPTS)
        logger.info("Pipeline completed successfully.")

    def get_pipeline_run_metadata(
        self, run_id: UUID
    ) -> Dict[str, "MetadataType"]:
        """Collects orchestrator-specific metadata for a pipeline run.

        The pipeline execution ARN is read from the environment; `run_id`
        is not used by this implementation.

        Args:
            run_id: The ID of the pipeline run.

        Returns:
            A dictionary of metadata.
        """
        # TODO: Add this once we can get the region
        # run_url = (
        #     f"https://{region}.console.aws.amazon.com/sagemaker/"
        #     f"home?region={region}"
        # )
        execution_arn = os.environ[ENV_ZENML_SAGEMAKER_RUN_ID]
        return {"pipeline_execution_arn": execution_arn}
config: SagemakerOrchestratorConfig property readonly

Returns the SagemakerOrchestratorConfig config.

Returns:

Type Description
SagemakerOrchestratorConfig

The configuration.

settings_class: Optional[Type[BaseSettings]] property readonly

Settings class for the Sagemaker orchestrator.

Returns:

Type Description
Optional[Type[BaseSettings]]

The settings class.

validator: Optional[zenml.stack.stack_validator.StackValidator] property readonly

Validates the stack.

In the remote case, checks that the stack contains a container registry, image builder and only remote components.

Returns:

Type Description
Optional[zenml.stack.stack_validator.StackValidator]

A StackValidator instance.

get_orchestrator_run_id(self)

Returns the run id of the active orchestrator run.

Important: This needs to be a unique ID and return the same value for all steps of a pipeline run.

Returns:

Type Description
str

The orchestrator run id.

Exceptions:

Type Description
RuntimeError

If the run id cannot be read from the environment.

Source code in zenml/integrations/aws/orchestrators/sagemaker_orchestrator.py
def get_orchestrator_run_id(self) -> str:
    """Reads the id of the active orchestrator run from the environment.

    Important: the id has to be unique and identical for all steps of a
    pipeline run.

    Returns:
        The orchestrator run id.

    Raises:
        RuntimeError: If the run id cannot be read from the environment.
    """
    try:
        run_id = os.environ[ENV_ZENML_SAGEMAKER_RUN_ID]
    except KeyError:
        message = (
            "Unable to read run id from environment variable "
            f"{ENV_ZENML_SAGEMAKER_RUN_ID}."
        )
        raise RuntimeError(message)
    return run_id
get_pipeline_run_metadata(self, run_id)

Get general component-specific metadata for a pipeline run.

Parameters:

Name Type Description Default
run_id UUID

The ID of the pipeline run.

required

Returns:

Type Description
Dict[str, MetadataType]

A dictionary of metadata.

Source code in zenml/integrations/aws/orchestrators/sagemaker_orchestrator.py
def get_pipeline_run_metadata(
    self, run_id: UUID
) -> Dict[str, "MetadataType"]:
    """Collects orchestrator-specific metadata for a pipeline run.

    The pipeline execution ARN is read from the environment; `run_id` is
    not used by this implementation.

    Args:
        run_id: The ID of the pipeline run.

    Returns:
        A dictionary of metadata.
    """
    # TODO: Add this once we can get the region
    # run_url = (
    #     f"https://{region}.console.aws.amazon.com/sagemaker/"
    #     f"home?region={region}"
    # )
    execution_arn = os.environ[ENV_ZENML_SAGEMAKER_RUN_ID]
    return {"pipeline_execution_arn": execution_arn}
prepare_or_run_pipeline(self, deployment, stack)

Prepares or runs a pipeline on Sagemaker.

Parameters:

Name Type Description Default
deployment PipelineDeploymentResponseModel

The deployment to prepare or run.

required
stack Stack

The stack to run on.

required
Source code in zenml/integrations/aws/orchestrators/sagemaker_orchestrator.py
def prepare_or_run_pipeline(
    self,
    deployment: "PipelineDeploymentResponseModel",
    stack: "Stack",
) -> None:
    """Builds a Sagemaker pipeline for the deployment and starts it.

    One `ProcessingStep` is created per step configuration of the
    deployment. The steps are assembled into a Sagemaker `Pipeline`, which
    is created and started. If the orchestrator is configured as
    synchronous, this method waits for the execution.

    Args:
        deployment: The deployment to prepare or run.
        stack: The stack to run on.
    """
    if deployment.schedule:
        logger.warning(
            "The Sagemaker Orchestrator currently does not support the "
            "use of schedules. The `schedule` will be ignored "
            "and the pipeline will be run immediately."
        )

    # Underscores in the generated run name are replaced with dashes.
    run_name = get_orchestrator_run_name(
        pipeline_name=deployment.pipeline_configuration.name
    )
    run_name = run_name.replace("_", "-")

    sagemaker_session = sagemaker.Session(default_bucket=self.config.bucket)

    processing_steps = []
    for name, step_config in deployment.step_configurations.items():
        step_image = self.get_image(deployment=deployment, step_name=name)
        step_entrypoint = (
            StepEntrypointConfiguration.get_entrypoint_command()
            + StepEntrypointConfiguration.get_entrypoint_arguments(
                step_name=name, deployment_id=deployment.id
            )
        )

        settings = cast(
            SagemakerOrchestratorSettings, self.get_settings(step_config)
        )
        # A step-level role takes precedence over the configured one.
        role = settings.processor_role or self.config.execution_role

        # Tags are only handed to the processor when some are set.
        tags = settings.processor_tags
        optional_args = {"tags": [tags]} if tags else {}

        # Every step receives the pipeline execution ARN as its run id.
        step_environment = {
            ENV_ZENML_SAGEMAKER_RUN_ID: ExecutionVariables.PIPELINE_EXECUTION_ARN,
        }

        step_processor = sagemaker.processing.Processor(
            role=role,
            image_uri=step_image,
            instance_count=1,
            sagemaker_session=sagemaker_session,
            instance_type=settings.instance_type,
            entrypoint=step_entrypoint,
            base_job_name=run_name,
            env=step_environment,
            volume_size_in_gb=settings.volume_size_in_gb,
            max_runtime_in_seconds=settings.max_runtime_in_seconds,
            **optional_args,
        )

        processing_steps.append(
            ProcessingStep(
                name=step_config.config.name,
                processor=step_processor,
                depends_on=step_config.spec.upstream_steps,
            )
        )

    # Assemble the collected steps into a single Sagemaker pipeline.
    sagemaker_pipeline = Pipeline(
        name=run_name,
        steps=processing_steps,
        sagemaker_session=sagemaker_session,
    )

    sagemaker_pipeline.create(role_arn=self.config.execution_role)
    execution = sagemaker_pipeline.start()

    if not self.config.synchronous:
        return

    # Mainly for testing purposes: wait for the pipeline to finish.
    logger.info("Executing synchronously. Waiting for pipeline to finish...")
    execution.wait(delay=POLLING_DELAY, max_attempts=MAX_POLLING_ATTEMPTS)
    logger.info("Pipeline completed successfully.")

secrets_managers special

AWS Secrets Manager.

aws_secrets_manager

Implementation of the AWS Secrets Manager integration.

AWSSecretsManager (BaseSecretsManager)

Class to interact with the AWS secrets manager.

Source code in zenml/integrations/aws/secrets_managers/aws_secrets_manager.py
class AWSSecretsManager(BaseSecretsManager):
    """Class to interact with the AWS secrets manager."""

    # boto3 Secrets Manager client, shared by all instances of this class.
    CLIENT: ClassVar[Any] = None

    @property
    def config(self) -> AWSSecretsManagerConfig:
        """Returns the `AWSSecretsManagerConfig` config.

        Returns:
            The configuration.
        """
        return cast(AWSSecretsManagerConfig, self._config)

    @classmethod
    def _ensure_client_connected(cls, region_name: str) -> None:
        """Ensure that the client is connected to the AWS secrets manager.

        The client is created on first use and cached on the class. Later
        calls reuse the cached client and do not look at `region_name`
        again.

        Args:
            region_name: the AWS region name
        """
        if cls.CLIENT is None:
            # Create a Secrets Manager client
            session = boto3.session.Session()
            cls.CLIENT = session.client(
                service_name="secretsmanager", region_name=region_name
            )

    def _get_secret_tags(
        self, secret: BaseSecretSchema
    ) -> List[Dict[str, str]]:
        """Return a list of AWS secret tag values for a given secret.

        Each entry of the secret metadata becomes one
        `{"Key": ..., "Value": ...}` tag.

        Args:
            secret: the secret object

        Returns:
            A list of AWS secret tag values
        """
        metadata = self._get_secret_metadata(secret)
        return [{"Key": k, "Value": v} for k, v in metadata.items()]

    def _get_secret_scope_filters(
        self,
        secret_name: Optional[str] = None,
    ) -> List[Dict[str, Any]]:
        """Return a list of AWS filters for the entire scope or just a scoped secret.

        These filters can be used when querying the AWS Secrets Manager
        for all secrets or for a single secret available in the configured
        scope. For more information see: https://docs.aws.amazon.com/secretsmanager/latest/userguide/manage_search-secret.html

        Example AWS filters for all secrets in the current (namespace) scope:

        ```python
        [
            {
                "Key": "tag-key",
                "Values": ["zenml_scope"],
            },
            {
                "Key": "tag-value",
                "Values": ["namespace"],
            },
            {
                "Key": "tag-key",
                "Values": ["zenml_namespace"],
            },
            {
                "Key": "tag-value",
                "Values": ["my_namespace"],
            },
        ]
        ```

        Example AWS filters for a particular secret in the current (namespace)
        scope:

        ```python
        [
            {
                "Key": "tag-key",
                "Values": ["zenml_secret_name"],
            },
            {
                "Key": "tag-value",
                "Values": ["my_secret"],
            },
            {
                "Key": "tag-key",
                "Values": ["zenml_scope"],
            },
            {
                "Key": "tag-value",
                "Values": ["namespace"],
            },
            {
                "Key": "tag-key",
                "Values": ["zenml_namespace"],
            },
            {
                "Key": "tag-value",
                "Values": ["my_namespace"],
            },
        ]
        ```

        Args:
            secret_name: Optional secret name to filter for.

        Returns:
            A list of AWS filters uniquely identifying all secrets
            or a named secret within the configured scope.
        """
        metadata = self._get_secret_scope_metadata(secret_name)
        filters: List[Dict[str, Any]] = []
        # Every metadata entry yields a tag-key filter followed by the
        # matching tag-value filter.
        for k, v in metadata.items():
            filters.append(
                {
                    "Key": "tag-key",
                    "Values": [
                        k,
                    ],
                }
            )
            filters.append(
                {
                    "Key": "tag-value",
                    "Values": [
                        str(v),
                    ],
                }
            )

        return filters

    def _list_secrets(self, secret_name: Optional[str] = None) -> List[str]:
        """List all secrets matching a name.

        This method lists all the secrets in the current scope without loading
        their contents. An optional secret name can be supplied to filter out
        all but a single secret identified by name.

        Args:
            secret_name: Optional secret name to filter for.

        Returns:
            A list of secret names in the current scope and the optional
            secret name.
        """
        self._ensure_client_connected(self.config.region_name)

        filters: List[Dict[str, Any]] = []
        prefix: Optional[str] = None
        if self.config.scope == SecretsManagerScope.NONE:
            # unscoped (legacy) secrets don't have tags. We want to filter out
            # non-legacy secrets
            filters = [
                {
                    "Key": "tag-key",
                    "Values": [
                        "!zenml_scope",
                    ],
                },
            ]
            if secret_name:
                prefix = secret_name
        else:
            filters = self._get_secret_scope_filters()
            if secret_name:
                prefix = self._get_scoped_secret_name(secret_name)
            else:
                # add the name prefix to the filters to account for the fact
                # that AWS does not do exact matching but prefix-matching on the
                # filters
                prefix = self._get_scoped_secret_name_prefix()

        if prefix:
            filters.append(
                {
                    "Key": "name",
                    "Values": [
                        f"{prefix}",
                    ],
                }
            )

        paginator = self.CLIENT.get_paginator(_BOTO_CLIENT_LIST_SECRETS)
        pages = paginator.paginate(
            Filters=filters,
            PaginationConfig={
                "PageSize": 100,
            },
        )
        results = []
        for page in pages:
            for secret in page[_PAGINATOR_RESPONSE_SECRETS_LIST_KEY]:
                name = self._get_unscoped_secret_name(secret["Name"])
                # keep only the names that are in scope and filter by secret name,
                # if one was given
                if name and (not secret_name or secret_name == name):
                    results.append(name)

        return results

    def register_secret(self, secret: BaseSecretSchema) -> None:
        """Registers a new secret.

        Args:
            secret: the secret to register

        Raises:
            SecretExistsError: if the secret already exists
        """
        validate_aws_secret_name_or_namespace(secret.name)
        self._ensure_client_connected(self.config.region_name)

        if self._list_secrets(secret.name):
            raise SecretExistsError(
                f"A Secret with the name {secret.name} already exists"
            )

        secret_value = json.dumps(secret_to_dict(secret, encode=False))
        kwargs: Dict[str, Any] = {
            "Name": self._get_scoped_secret_name(secret.name),
            "SecretString": secret_value,
            "Tags": self._get_secret_tags(secret),
        }

        self.CLIENT.create_secret(**kwargs)

        logger.debug("Created AWS secret: %s", kwargs["Name"])

    def get_secret(self, secret_name: str) -> BaseSecretSchema:
        """Gets a secret.

        Args:
            secret_name: the name of the secret to get

        Returns:
            The secret.

        Raises:
            KeyError: if the secret does not exist
            ValueError: if the AWS response for the secret contains no
                `SecretString` value
        """
        validate_aws_secret_name_or_namespace(secret_name)
        self._ensure_client_connected(self.config.region_name)

        if not self._list_secrets(secret_name):
            raise KeyError(f"Can't find the specified secret '{secret_name}'")

        get_secret_value_response = self.CLIENT.get_secret_value(
            SecretId=self._get_scoped_secret_name(secret_name)
        )
        # Fail with a clear message instead of subscripting `None`, which
        # only produced an opaque `TypeError`.
        if "SecretString" not in get_secret_value_response:
            raise ValueError(
                f"The secret '{secret_name}' has no 'SecretString' value "
                "and cannot be loaded."
            )

        return secret_from_dict(
            json.loads(get_secret_value_response["SecretString"]),
            secret_name=secret_name,
            decode=False,
        )

    def get_all_secret_keys(self) -> List[str]:
        """Get all secret keys.

        Returns:
            A list of all secret keys
        """
        return self._list_secrets()

    def update_secret(self, secret: BaseSecretSchema) -> None:
        """Update an existing secret.

        Args:
            secret: the secret to update

        Raises:
            KeyError: if the secret does not exist
        """
        validate_aws_secret_name_or_namespace(secret.name)
        self._ensure_client_connected(self.config.region_name)

        if not self._list_secrets(secret.name):
            raise KeyError(f"Can't find the specified secret '{secret.name}'")

        # NOTE(review): `register_secret` passes `encode=False` here while
        # this call relies on the default - confirm both agree.
        secret_value = json.dumps(secret_to_dict(secret))

        kwargs = {
            "SecretId": self._get_scoped_secret_name(secret.name),
            "SecretString": secret_value,
        }

        self.CLIENT.put_secret_value(**kwargs)

    def delete_secret(self, secret_name: str) -> None:
        """Delete an existing secret.

        The secret is deleted with `ForceDeleteWithoutRecovery=True`.

        Args:
            secret_name: the name of the secret to delete

        Raises:
            KeyError: if the secret does not exist
        """
        self._ensure_client_connected(self.config.region_name)

        if not self._list_secrets(secret_name):
            raise KeyError(f"Can't find the specified secret '{secret_name}'")

        self.CLIENT.delete_secret(
            SecretId=self._get_scoped_secret_name(secret_name),
            ForceDeleteWithoutRecovery=True,
        )

    def delete_all_secrets(self) -> None:
        """Delete all existing secrets.

        This method will force delete all your secrets. You will not be able to
        recover them once this method is called.
        """
        self._ensure_client_connected(self.config.region_name)
        for secret_name in self._list_secrets():
            self.CLIENT.delete_secret(
                SecretId=self._get_scoped_secret_name(secret_name),
                ForceDeleteWithoutRecovery=True,
            )
config: AWSSecretsManagerConfig property readonly

Returns the AWSSecretsManagerConfig config.

Returns:

Type Description
AWSSecretsManagerConfig

The configuration.

delete_all_secrets(self)

Delete all existing secrets.

This method will force delete all your secrets. You will not be able to recover them once this method is called.

Source code in zenml/integrations/aws/secrets_managers/aws_secrets_manager.py
def delete_all_secrets(self) -> None:
    """Force delete every secret listed in the current scope.

    Each secret is removed with `ForceDeleteWithoutRecovery=True`, so the
    secrets cannot be recovered once this method has been called.
    """
    self._ensure_client_connected(self.config.region_name)
    names_in_scope = self._list_secrets()
    for unscoped_name in names_in_scope:
        scoped_name = self._get_scoped_secret_name(unscoped_name)
        self.CLIENT.delete_secret(
            SecretId=scoped_name, ForceDeleteWithoutRecovery=True
        )
delete_secret(self, secret_name)

Delete an existing secret.

Parameters:

Name Type Description Default
secret_name str

the name of the secret to delete

required

Exceptions:

Type Description
KeyError

if the secret does not exist

Source code in zenml/integrations/aws/secrets_managers/aws_secrets_manager.py
def delete_secret(self, secret_name: str) -> None:
    """Force delete a single secret by name.

    The secret is removed with `ForceDeleteWithoutRecovery=True`.

    Args:
        secret_name: the name of the secret to delete

    Raises:
        KeyError: if the secret does not exist
    """
    self._ensure_client_connected(self.config.region_name)

    matching_names = self._list_secrets(secret_name)
    if not matching_names:
        raise KeyError(f"Can't find the specified secret '{secret_name}'")

    scoped_name = self._get_scoped_secret_name(secret_name)
    self.CLIENT.delete_secret(
        SecretId=scoped_name, ForceDeleteWithoutRecovery=True
    )
get_all_secret_keys(self)

Get all secret keys.

Returns:

Type Description
List[str]

A list of all secret keys

Source code in zenml/integrations/aws/secrets_managers/aws_secrets_manager.py
def get_all_secret_keys(self) -> List[str]:
    """List the names of all secrets without filtering by name.

    Returns:
        A list of all secret keys
    """
    secret_names = self._list_secrets()
    return secret_names
get_secret(self, secret_name)

Gets a secret.

Parameters:

Name Type Description Default
secret_name str

the name of the secret to get

required

Returns:

Type Description
BaseSecretSchema

The secret.

Exceptions:

Type Description
KeyError

if the secret does not exist

Source code in zenml/integrations/aws/secrets_managers/aws_secrets_manager.py
def get_secret(self, secret_name: str) -> BaseSecretSchema:
    """Gets a secret.

    The secret name is validated and looked up in the current scope
    before its value is fetched from AWS and parsed as JSON.

    Args:
        secret_name: the name of the secret to get

    Returns:
        The secret.

    Raises:
        KeyError: if the secret does not exist
        ValueError: if the AWS response for the secret contains no
            `SecretString` value
    """
    validate_aws_secret_name_or_namespace(secret_name)
    self._ensure_client_connected(self.config.region_name)

    if not self._list_secrets(secret_name):
        raise KeyError(f"Can't find the specified secret '{secret_name}'")

    get_secret_value_response = self.CLIENT.get_secret_value(
        SecretId=self._get_scoped_secret_name(secret_name)
    )
    # Fail with a clear message instead of subscripting `None`, which only
    # produced an opaque `TypeError`.
    if "SecretString" not in get_secret_value_response:
        raise ValueError(
            f"The secret '{secret_name}' has no 'SecretString' value "
            "and cannot be loaded."
        )

    return secret_from_dict(
        json.loads(get_secret_value_response["SecretString"]),
        secret_name=secret_name,
        decode=False,
    )
register_secret(self, secret)

Registers a new secret.

Parameters:

Name Type Description Default
secret BaseSecretSchema

the secret to register

required

Exceptions:

Type Description
SecretExistsError

if the secret already exists

Source code in zenml/integrations/aws/secrets_managers/aws_secrets_manager.py
def register_secret(self, secret: BaseSecretSchema) -> None:
    """Creates a new AWS secret from the given secret object.

    The secret is stored as a JSON string under its scoped name, together
    with the tags derived from the secret.

    Args:
        secret: the secret to register

    Raises:
        SecretExistsError: if the secret already exists
    """
    validate_aws_secret_name_or_namespace(secret.name)
    self._ensure_client_connected(self.config.region_name)

    already_registered = self._list_secrets(secret.name)
    if already_registered:
        raise SecretExistsError(
            f"A Secret with the name {secret.name} already exists"
        )

    serialized_secret = json.dumps(secret_to_dict(secret, encode=False))
    scoped_name = self._get_scoped_secret_name(secret.name)
    secret_tags = self._get_secret_tags(secret)

    self.CLIENT.create_secret(
        Name=scoped_name,
        SecretString=serialized_secret,
        Tags=secret_tags,
    )

    logger.debug("Created AWS secret: %s", scoped_name)
update_secret(self, secret)

Update an existing secret.

Parameters:

Name Type Description Default
secret BaseSecretSchema

the secret to update

required

Exceptions:

Type Description
KeyError

if the secret does not exist

Source code in zenml/integrations/aws/secrets_managers/aws_secrets_manager.py
def update_secret(self, secret: BaseSecretSchema) -> None:
    """Replaces the stored value of an existing secret.

    The secret is serialized to a JSON string and written under its scoped
    name.

    Args:
        secret: the secret to update

    Raises:
        KeyError: if the secret does not exist
    """
    validate_aws_secret_name_or_namespace(secret.name)
    self._ensure_client_connected(self.config.region_name)

    existing_names = self._list_secrets(secret.name)
    if not existing_names:
        raise KeyError(f"Can't find the specified secret '{secret.name}'")

    # NOTE(review): `register_secret` passes `encode=False` to
    # `secret_to_dict` while this call relies on the default - confirm
    # both agree.
    serialized_secret = json.dumps(secret_to_dict(secret))
    scoped_name = self._get_scoped_secret_name(secret.name)

    self.CLIENT.put_secret_value(
        SecretId=scoped_name,
        SecretString=serialized_secret,
    )

step_operators special

Initialization of the Sagemaker Step Operator.

sagemaker_step_operator

Implementation of the Sagemaker Step Operator.

SagemakerStepOperator (BaseStepOperator)

Step operator to run a step on Sagemaker.

This class defines code that builds an image with the ZenML entrypoint to run using Sagemaker's Estimator.

Source code in zenml/integrations/aws/step_operators/sagemaker_step_operator.py
class SagemakerStepOperator(BaseStepOperator):
    """Step operator that executes steps as SageMaker jobs.

    The operator declares the Docker builds it needs for the steps assigned
    to it and launches each of those steps through a SageMaker `Estimator`.
    """

    @property
    def config(self) -> SagemakerStepOperatorConfig:
        """Typed access to the configuration of this component.

        Returns:
            The stored configuration, cast to `SagemakerStepOperatorConfig`.
        """
        typed_config = cast(SagemakerStepOperatorConfig, self._config)
        return typed_config

    @property
    def settings_class(self) -> Optional[Type["BaseSettings"]]:
        """Settings class used by the SageMaker step operator.

        Returns:
            The `SagemakerStepOperatorSettings` class.
        """
        return SagemakerStepOperatorSettings

    @property
    def validator(self) -> Optional[StackValidator]:
        """Builds the validator for stacks that use this step operator.

        Returns:
            A validator that requires a container registry and an image
            builder in the stack and that rejects stacks whose artifact
            store or container registry is local.
        """

        def _check_components_are_remote(stack: "Stack") -> Tuple[bool, str]:
            # Each failing check returns `(False, <reason>)`; a stack that
            # passes both checks yields `(True, "")`.
            if stack.artifact_store.config.is_local:
                local_artifact_store_message = (
                    "The SageMaker step operator runs code remotely and "
                    "needs to write files into the artifact store, but the "
                    f"artifact store `{stack.artifact_store.name}` of the "
                    "active stack is local. Please ensure that your stack "
                    "contains a remote artifact store when using the SageMaker "
                    "step operator."
                )
                return False, local_artifact_store_message

            container_registry = stack.container_registry
            assert container_registry is not None

            if container_registry.config.is_local:
                local_registry_message = (
                    "The SageMaker step operator runs code remotely and "
                    "needs to push/pull Docker images, but the "
                    f"container registry `{container_registry.name}` of the "
                    "active stack is local. Please ensure that your stack "
                    "contains a remote container registry when using the "
                    "SageMaker step operator."
                )
                return False, local_registry_message

            return True, ""

        mandatory_components = {
            StackComponentType.CONTAINER_REGISTRY,
            StackComponentType.IMAGE_BUILDER,
        }
        return StackValidator(
            required_components=mandatory_components,
            custom_validation_function=_check_components_are_remote,
        )

    def get_docker_builds(
        self, deployment: "PipelineDeploymentBaseModel"
    ) -> List["BuildConfiguration"]:
        """Collects the Docker builds needed by this step operator.

        One build configuration is produced for every step of the deployment
        whose configured step operator is this component.

        Args:
            deployment: The pipeline deployment for which to get the builds.

        Returns:
            The required Docker builds.
        """
        return [
            BuildConfiguration(
                key=SAGEMAKER_DOCKER_IMAGE_KEY,
                settings=step.config.docker_settings,
                step_name=step_name,
                # The image entrypoint is read from an environment variable
                # that `launch` sets to the actual step command.
                entrypoint=f"${_ENTRYPOINT_ENV_VARIABLE}",
            )
            for step_name, step in deployment.step_configurations.items()
            if step.config.step_operator == self.name
        ]

    def launch(
        self,
        info: "StepRunInfo",
        entrypoint_command: List[str],
    ) -> None:
        """Runs a step on SageMaker.

        Args:
            info: Information about the step run.
            entrypoint_command: Command that executes the step.
        """
        custom_resources_requested = not info.config.resource_settings.empty
        if custom_resources_requested:
            logger.warning(
                "Specifying custom step resources is not supported for "
                "the SageMaker step operator. If you want to run this step "
                "operator on specific resources, you can do so by configuring "
                "a different instance type like this: "
                "`zenml step-operator update %s "
                "--instance_type=<INSTANCE_TYPE>`",
                self.name,
            )

        docker_image = info.get_image(key=SAGEMAKER_DOCKER_IMAGE_KEY)
        # The step command is handed to the container as a single
        # space-joined string through the entrypoint environment variable.
        job_environment = {
            _ENTRYPOINT_ENV_VARIABLE: " ".join(entrypoint_command)
        }

        step_settings = cast(
            SagemakerStepOperatorSettings, self.get_settings(info)
        )

        sagemaker_session = sagemaker.Session(
            default_bucket=self.config.bucket
        )
        estimator = sagemaker.estimator.Estimator(
            docker_image,
            self.config.role,
            environment=job_environment,
            instance_count=1,
            # Fall back to a default instance type if none is configured.
            instance_type=step_settings.instance_type or "ml.m5.large",
            sagemaker_session=sagemaker_session,
        )

        # Sagemaker doesn't allow any underscores in job/experiment/trial names
        job_name = info.run_name.replace("_", "-")

        # An experiment config is only passed along when an experiment name
        # has been set; otherwise an empty dict is used.
        experiment_config = (
            {
                "ExperimentName": step_settings.experiment_name,
                "TrialName": job_name,
            }
            if step_settings.experiment_name
            else {}
        )

        estimator.fit(
            wait=True,
            experiment_config=experiment_config,
            job_name=job_name,
        )
config: SagemakerStepOperatorConfig property readonly

Returns the SagemakerStepOperatorConfig config.

Returns:

Type Description
SagemakerStepOperatorConfig

The configuration.

settings_class: Optional[Type[BaseSettings]] property readonly

Settings class for the SageMaker step operator.

Returns:

Type Description
Optional[Type[BaseSettings]]

The settings class.

validator: Optional[zenml.stack.stack_validator.StackValidator] property readonly

Returns the validator that is used to validate the stack.

Returns:

Type Description
Optional[zenml.stack.stack_validator.StackValidator]

A validator that checks that the stack contains a remote container registry and a remote artifact store.

get_docker_builds(self, deployment)

Gets the Docker builds required for the component.

Parameters:

Name Type Description Default
deployment PipelineDeploymentBaseModel

The pipeline deployment for which to get the builds.

required

Returns:

Type Description
List[BuildConfiguration]

The required Docker builds.

Source code in zenml/integrations/aws/step_operators/sagemaker_step_operator.py
def get_docker_builds(
    self, deployment: "PipelineDeploymentBaseModel"
) -> List["BuildConfiguration"]:
    """Collects the Docker builds needed by this step operator.

    One build configuration is produced for every step of the deployment
    whose configured step operator is this component; steps that use another
    (or no) step operator are ignored.

    Args:
        deployment: The pipeline deployment for which to get the builds.

    Returns:
        The required Docker builds.
    """
    return [
        BuildConfiguration(
            key=SAGEMAKER_DOCKER_IMAGE_KEY,
            settings=step.config.docker_settings,
            step_name=step_name,
            # The entrypoint is the value of an environment variable, so the
            # concrete command can be supplied when the image is run.
            entrypoint=f"${_ENTRYPOINT_ENV_VARIABLE}",
        )
        for step_name, step in deployment.step_configurations.items()
        if step.config.step_operator == self.name
    ]
launch(self, info, entrypoint_command)

Launches a step on SageMaker.

Parameters:

Name Type Description Default
info StepRunInfo

Information about the step run.

required
entrypoint_command List[str]

Command that executes the step.

required
Source code in zenml/integrations/aws/step_operators/sagemaker_step_operator.py
def launch(
    self,
    info: "StepRunInfo",
    entrypoint_command: List[str],
) -> None:
    """Runs a step on SageMaker.

    A warning is logged if custom resource settings were specified for the
    step. The step's image is then run through a SageMaker `Estimator` with
    the step command passed in via an environment variable.

    Args:
        info: Information about the step run.
        entrypoint_command: Command that executes the step.
    """
    custom_resources_requested = not info.config.resource_settings.empty
    if custom_resources_requested:
        logger.warning(
            "Specifying custom step resources is not supported for "
            "the SageMaker step operator. If you want to run this step "
            "operator on specific resources, you can do so by configuring "
            "a different instance type like this: "
            "`zenml step-operator update %s "
            "--instance_type=<INSTANCE_TYPE>`",
            self.name,
        )

    docker_image = info.get_image(key=SAGEMAKER_DOCKER_IMAGE_KEY)
    # The step command is handed to the container as a single space-joined
    # string through the entrypoint environment variable.
    job_environment = {_ENTRYPOINT_ENV_VARIABLE: " ".join(entrypoint_command)}

    step_settings = cast(SagemakerStepOperatorSettings, self.get_settings(info))

    sagemaker_session = sagemaker.Session(default_bucket=self.config.bucket)
    estimator = sagemaker.estimator.Estimator(
        docker_image,
        self.config.role,
        environment=job_environment,
        instance_count=1,
        # Fall back to a default instance type if none is configured.
        instance_type=step_settings.instance_type or "ml.m5.large",
        sagemaker_session=sagemaker_session,
    )

    # Sagemaker doesn't allow any underscores in job/experiment/trial names
    job_name = info.run_name.replace("_", "-")

    # An experiment config is only passed along when an experiment name has
    # been set; otherwise an empty dict is used.
    experiment_config = (
        {
            "ExperimentName": step_settings.experiment_name,
            "TrialName": job_name,
        }
        if step_settings.experiment_name
        else {}
    )

    estimator.fit(
        wait=True,
        experiment_config=experiment_config,
        job_name=job_name,
    )