
zenml.integrations.aws

Integrates multiple AWS Tools as Stack Components.

The AWS integration provides a way for our users to manage their secrets through AWS, as well as a way to use the AWS container registry. Additionally, the SageMaker integration submodule provides a way to run ZenML steps in SageMaker.
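
As a quick orientation, here is a minimal sketch of how the integration class itself can be used, assuming the integration's requirements are installed:

from zenml.integrations.aws import AWSIntegration

# Verify that all packages required by the integration are installed,
# then activate it and list the stack component flavors it contributes.
if AWSIntegration.check_installation():
    AWSIntegration.activate()
    for flavor_class in AWSIntegration.flavors():
        print(flavor_class().name)

In everyday use this happens implicitly, for example when running `zenml integration install aws` and registering stack components with the flavors documented below.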

Attributes

AWS = 'aws' module-attribute

AWS_CONNECTOR_TYPE = 'aws' module-attribute

AWS_CONTAINER_REGISTRY_FLAVOR = 'aws' module-attribute

AWS_IMAGE_BUILDER_FLAVOR = 'aws' module-attribute

AWS_RESOURCE_TYPE = 'aws-generic' module-attribute

AWS_SAGEMAKER_ORCHESTRATOR_FLAVOR = 'sagemaker' module-attribute

AWS_SAGEMAKER_STEP_OPERATOR_FLAVOR = 'sagemaker' module-attribute

AWS_SECRET_MANAGER_FLAVOR = 'aws' module-attribute

S3_RESOURCE_TYPE = 's3-bucket' module-attribute

Classes

AWSIntegration

Bases: Integration

Definition of AWS integration for ZenML.

Functions
activate() -> None classmethod

Activate the AWS integration.

Source code in src/zenml/integrations/aws/__init__.py
@classmethod
def activate(cls) -> None:
    """Activate the AWS integration."""
    from zenml.integrations.aws import service_connectors  # noqa
flavors() -> List[Type[Flavor]] classmethod

Declare the stack component flavors for the AWS integration.

Returns:

Type Description
List[Type[Flavor]]

List of stack component flavors for this integration.

Source code in src/zenml/integrations/aws/__init__.py
@classmethod
def flavors(cls) -> List[Type[Flavor]]:
    """Declare the stack component flavors for the AWS integration.

    Returns:
        List of stack component flavors for this integration.
    """
    from zenml.integrations.aws.flavors import (
        AWSContainerRegistryFlavor,
        AWSImageBuilderFlavor,
        SagemakerOrchestratorFlavor,
        SagemakerStepOperatorFlavor,
    )

    return [
        AWSContainerRegistryFlavor,
        AWSImageBuilderFlavor,
        SagemakerStepOperatorFlavor,
        SagemakerOrchestratorFlavor,
    ]

Flavor

Class for ZenML Flavors.

Attributes
config_class: Type[StackComponentConfig] abstractmethod property

Returns StackComponentConfig config class.

Returns:

Type Description
Type[StackComponentConfig]

The config class.

config_schema: Dict[str, Any] property

The config schema for a flavor.

Returns:

Type Description
Dict[str, Any]

The config schema.

docs_url: Optional[str] property

A url to point at docs explaining this flavor.

Returns:

Type Description
Optional[str]

A flavor docs url.

implementation_class: Type[StackComponent] abstractmethod property

Implementation class for this flavor.

Returns:

Type Description
Type[StackComponent]

The implementation class for this flavor.

logo_url: Optional[str] property

A url to represent the flavor in the dashboard.

Returns:

Type Description
Optional[str]

The flavor logo.

name: str abstractmethod property

The flavor name.

Returns:

Type Description
str

The flavor name.

sdk_docs_url: Optional[str] property

A url to point at SDK docs explaining this flavor.

Returns:

Type Description
Optional[str]

A flavor SDK docs url.

service_connector_requirements: Optional[ServiceConnectorRequirements] property

Service connector resource requirements for service connectors.

Specifies resource requirements that are used to filter the available service connector types that are compatible with this flavor.

Returns:

Type Description
Optional[ServiceConnectorRequirements]

Requirements for compatible service connectors, if a service connector is required for this flavor.

type: StackComponentType abstractmethod property

The stack component type.

Returns:

Type Description
StackComponentType

The stack component type.

Functions
from_model(flavor_model: FlavorResponse) -> Flavor classmethod

Loads a flavor from a model.

Parameters:

Name Type Description Default
flavor_model FlavorResponse

The model to load from.

required

Raises:

Type Description
CustomFlavorImportError

If the custom flavor can't be imported.

ImportError

If the flavor can't be imported.

Returns:

Type Description
Flavor

The loaded flavor.

Source code in src/zenml/stack/flavor.py
@classmethod
def from_model(cls, flavor_model: FlavorResponse) -> "Flavor":
    """Loads a flavor from a model.

    Args:
        flavor_model: The model to load from.

    Raises:
        CustomFlavorImportError: If the custom flavor can't be imported.
        ImportError: If the flavor can't be imported.

    Returns:
        The loaded flavor.
    """
    try:
        flavor = source_utils.load(flavor_model.source)()
    except (ModuleNotFoundError, ImportError, NotImplementedError) as err:
        if flavor_model.is_custom:
            flavor_module, _ = flavor_model.source.rsplit(".", maxsplit=1)
            expected_file_path = os.path.join(
                source_utils.get_source_root(),
                flavor_module.replace(".", os.path.sep),
            )
            raise CustomFlavorImportError(
                f"Couldn't import custom flavor {flavor_model.name}: "
                f"{err}. Make sure the custom flavor class "
                f"`{flavor_model.source}` is importable. If it is part of "
                "a library, make sure it is installed. If "
                "it is a local code file, make sure it exists at "
                f"`{expected_file_path}.py`."
            )
        else:
            raise ImportError(
                f"Couldn't import flavor {flavor_model.name}: {err}"
            )
    return cast(Flavor, flavor)
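
A usage sketch; the `Client().get_flavor` lookup by name is illustrative and assumes exactly one matching flavor is registered on the server:

from zenml.client import Client
from zenml.stack.flavor import Flavor

# Fetch the FlavorResponse model from the server, then re-import the
# flavor class from its source path and instantiate it.
flavor_model = Client().get_flavor("aws")
flavor = Flavor.from_model(flavor_model)
print(flavor.name, flavor.type)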
generate_default_docs_url() -> str

Generate the doc urls for all inbuilt and integration flavors.

Note that this method is not going to be useful for custom flavors, which do not have any docs in the main zenml docs.

Returns:

Type Description
str

The complete url to the zenml documentation

Source code in src/zenml/stack/flavor.py
def generate_default_docs_url(self) -> str:
    """Generate the doc urls for all inbuilt and integration flavors.

    Note that this method is not going to be useful for custom flavors,
    which do not have any docs in the main zenml docs.

    Returns:
        The complete url to the zenml documentation
    """
    from zenml import __version__

    component_type = self.type.plural.replace("_", "-")
    name = self.name.replace("_", "-")

    try:
        is_latest = is_latest_zenml_version()
    except RuntimeError:
        # We assume in error cases that we are on the latest version
        is_latest = True

    if is_latest:
        base = "https://docs.zenml.io"
    else:
        base = f"https://zenml-io.gitbook.io/zenml-legacy-documentation/v/{__version__}"
    return f"{base}/stack-components/{component_type}/{name}"
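
For example, for the AWS container registry flavor this generates a URL of roughly the following shape (the exact path depends on the installed ZenML version):

from zenml.integrations.aws.flavors import AWSContainerRegistryFlavor

# "container_registry" pluralizes to "container-registries" and the
# flavor name is "aws", so on the latest version this prints e.g.
# https://docs.zenml.io/stack-components/container-registries/aws
print(AWSContainerRegistryFlavor().generate_default_docs_url())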
generate_default_sdk_docs_url() -> str

Generate SDK docs url for a flavor.

Returns:

Type Description
str

The complete url to the zenml SDK docs

Source code in src/zenml/stack/flavor.py
def generate_default_sdk_docs_url(self) -> str:
    """Generate SDK docs url for a flavor.

    Returns:
        The complete url to the zenml SDK docs
    """
    from zenml import __version__

    base = f"https://sdkdocs.zenml.io/{__version__}"

    component_type = self.type.plural

    if "zenml.integrations" in self.__module__:
        # Get integration name out of module path which will look something
        #  like this "zenml.integrations.<integration>....
        integration = self.__module__.split(
            "zenml.integrations.", maxsplit=1
        )[1].split(".")[0]

        return (
            f"{base}/integration_code_docs"
            f"/integrations-{integration}/#{self.__module__}"
        )

    else:
        return (
            f"{base}/core_code_docs/core-{component_type}/"
            f"#{self.__module__}"
        )
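
Analogously, a sketch for an integration flavor; the `<version>` placeholder stands for the installed ZenML version:

from zenml.integrations.aws.flavors import AWSContainerRegistryFlavor

# Integration flavors resolve into the integration code docs, e.g.
# https://sdkdocs.zenml.io/<version>/integration_code_docs/integrations-aws/#<module path>
print(AWSContainerRegistryFlavor().generate_default_sdk_docs_url())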
to_model(integration: Optional[str] = None, is_custom: bool = True) -> FlavorRequest

Converts a flavor to a model.

Parameters:

Name Type Description Default
integration Optional[str]

The integration to use for the model.

None
is_custom bool

Whether the flavor is a custom flavor.

True

Returns:

Type Description
FlavorRequest

The model.

Source code in src/zenml/stack/flavor.py
def to_model(
    self,
    integration: Optional[str] = None,
    is_custom: bool = True,
) -> FlavorRequest:
    """Converts a flavor to a model.

    Args:
        integration: The integration to use for the model.
        is_custom: Whether the flavor is a custom flavor.

    Returns:
        The model.
    """
    connector_requirements = self.service_connector_requirements
    connector_type = (
        connector_requirements.connector_type
        if connector_requirements
        else None
    )
    resource_type = (
        connector_requirements.resource_type
        if connector_requirements
        else None
    )
    resource_id_attr = (
        connector_requirements.resource_id_attr
        if connector_requirements
        else None
    )

    model = FlavorRequest(
        name=self.name,
        type=self.type,
        source=source_utils.resolve(self.__class__).import_path,
        config_schema=self.config_schema,
        connector_type=connector_type,
        connector_resource_type=resource_type,
        connector_resource_id_attr=resource_id_attr,
        integration=integration,
        logo_url=self.logo_url,
        docs_url=self.docs_url,
        sdk_docs_url=self.sdk_docs_url,
        is_custom=is_custom,
    )
    return model
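
A sketch of converting a built-in integration flavor into a request model, for example when flavors are synced to the server:

from zenml.integrations.aws.flavors import AWSContainerRegistryFlavor

# Built-in integration flavors are registered with is_custom=False and
# the name of the integration that ships them.
request = AWSContainerRegistryFlavor().to_model(
    integration="aws", is_custom=False
)
print(request.name, request.source)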

Integration

Base class for integration in ZenML.

Functions
activate() -> None classmethod

Abstract method to activate the integration.

Source code in src/zenml/integrations/integration.py
@classmethod
def activate(cls) -> None:
    """Abstract method to activate the integration."""
check_installation() -> bool classmethod

Method to check whether the required packages are installed.

Returns:

Type Description
bool

True if all required packages are installed, False otherwise.

Source code in src/zenml/integrations/integration.py
@classmethod
def check_installation(cls) -> bool:
    """Method to check whether the required packages are installed.

    Returns:
        True if all required packages are installed, False otherwise.
    """
    for r in cls.get_requirements():
        try:
            # First check if the base package is installed
            dist = pkg_resources.get_distribution(r)

            # Next, check if the dependencies (including extras) are
            # installed
            deps: List[Requirement] = []

            _, extras = parse_requirement(r)
            if extras:
                extra_list = extras[1:-1].split(",")
                for extra in extra_list:
                    try:
                        requirements = dist.requires(extras=[extra])  # type: ignore[arg-type]
                    except pkg_resources.UnknownExtra as e:
                        logger.debug(f"Unknown extra: {str(e)}")
                        return False
                    deps.extend(requirements)
            else:
                deps = dist.requires()

            for ri in deps:
                try:
                    # Remove the "extra == ..." part from the requirement string
                    cleaned_req = re.sub(
                        r"; extra == \"\w+\"", "", str(ri)
                    )
                    pkg_resources.get_distribution(cleaned_req)
                except pkg_resources.DistributionNotFound as e:
                    logger.debug(
                        f"Unable to find required dependency "
                        f"'{e.req}' for requirement '{r}' "
                        f"necessary for integration '{cls.NAME}'."
                    )
                    return False
                except pkg_resources.VersionConflict as e:
                    logger.debug(
                        f"Package version '{e.dist}' does not match "
                        f"version '{e.req}' required by '{r}' "
                        f"necessary for integration '{cls.NAME}'."
                    )
                    return False

        except pkg_resources.DistributionNotFound as e:
            logger.debug(
                f"Unable to find required package '{e.req}' for "
                f"integration {cls.NAME}."
            )
            return False
        except pkg_resources.VersionConflict as e:
            logger.debug(
                f"Package version '{e.dist}' does not match version "
                f"'{e.req}' necessary for integration {cls.NAME}."
            )
            return False

    logger.debug(
        f"Integration {cls.NAME} is installed correctly with "
        f"requirements {cls.get_requirements()}."
    )
    return True
flavors() -> List[Type[Flavor]] classmethod

Abstract method to declare new stack component flavors.

Returns:

Type Description
List[Type[Flavor]]

A list of new stack component flavors.

Source code in src/zenml/integrations/integration.py
@classmethod
def flavors(cls) -> List[Type[Flavor]]:
    """Abstract method to declare new stack component flavors.

    Returns:
        A list of new stack component flavors.
    """
    return []
get_requirements(target_os: Optional[str] = None, python_version: Optional[str] = None) -> List[str] classmethod

Method to get the requirements for the integration.

Parameters:

Name Type Description Default
target_os Optional[str]

The target operating system to get the requirements for.

None
python_version Optional[str]

The Python version to use for the requirements.

None

Returns:

Type Description
List[str]

A list of requirements.

Source code in src/zenml/integrations/integration.py
@classmethod
def get_requirements(
    cls,
    target_os: Optional[str] = None,
    python_version: Optional[str] = None,
) -> List[str]:
    """Method to get the requirements for the integration.

    Args:
        target_os: The target operating system to get the requirements for.
        python_version: The Python version to use for the requirements.

    Returns:
        A list of requirements.
    """
    return cls.REQUIREMENTS
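
For example (the exact requirement pins depend on the ZenML version):

from zenml.integrations.aws import AWSIntegration

# The raw pip requirement specifiers declared by the integration.
print(AWSIntegration.get_requirements())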
get_uninstall_requirements(target_os: Optional[str] = None) -> List[str] classmethod

Method to get the uninstall requirements for the integration.

Parameters:

Name Type Description Default
target_os Optional[str]

The target operating system to get the requirements for.

None

Returns:

Type Description
List[str]

A list of requirements.

Source code in src/zenml/integrations/integration.py
@classmethod
def get_uninstall_requirements(
    cls, target_os: Optional[str] = None
) -> List[str]:
    """Method to get the uninstall requirements for the integration.

    Args:
        target_os: The target operating system to get the requirements for.

    Returns:
        A list of requirements.
    """
    ret = []
    for each in cls.get_requirements(target_os=target_os):
        is_ignored = False
        for ignored in cls.REQUIREMENTS_IGNORED_ON_UNINSTALL:
            if each.startswith(ignored):
                is_ignored = True
                break
        if not is_ignored:
            ret.append(each)
    return ret
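
A sketch of the difference between the install and uninstall requirement lists:

from zenml.integrations.aws import AWSIntegration

# Requirements listed in REQUIREMENTS_IGNORED_ON_UNINSTALL (for example,
# packages shared with other integrations) are kept on uninstall.
install_reqs = set(AWSIntegration.get_requirements())
uninstall_reqs = set(AWSIntegration.get_uninstall_requirements())
print(install_reqs - uninstall_reqs)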
plugin_flavors() -> List[Type[BasePluginFlavor]] classmethod

Abstract method to declare new plugin flavors.

Returns:

Type Description
List[Type[BasePluginFlavor]]

A list of new plugin flavors.

Source code in src/zenml/integrations/integration.py
@classmethod
def plugin_flavors(cls) -> List[Type["BasePluginFlavor"]]:
    """Abstract method to declare new plugin flavors.

    Returns:
        A list of new plugin flavors.
    """
    return []

Modules

container_registries

Initialization of AWS Container Registry integration.

Classes
AWSContainerRegistry(name: str, id: UUID, config: StackComponentConfig, flavor: str, type: StackComponentType, user: Optional[UUID], created: datetime, updated: datetime, labels: Optional[Dict[str, Any]] = None, connector_requirements: Optional[ServiceConnectorRequirements] = None, connector: Optional[UUID] = None, connector_resource_id: Optional[str] = None, *args: Any, **kwargs: Any)

Bases: BaseContainerRegistry

Class for AWS Container Registry.

Source code in src/zenml/stack/stack_component.py
def __init__(
    self,
    name: str,
    id: UUID,
    config: StackComponentConfig,
    flavor: str,
    type: StackComponentType,
    user: Optional[UUID],
    created: datetime,
    updated: datetime,
    labels: Optional[Dict[str, Any]] = None,
    connector_requirements: Optional[ServiceConnectorRequirements] = None,
    connector: Optional[UUID] = None,
    connector_resource_id: Optional[str] = None,
    *args: Any,
    **kwargs: Any,
):
    """Initializes a StackComponent.

    Args:
        name: The name of the component.
        id: The unique ID of the component.
        config: The config of the component.
        flavor: The flavor of the component.
        type: The type of the component.
        user: The ID of the user who created the component.
        created: The creation time of the component.
        updated: The last update time of the component.
        labels: The labels of the component.
        connector_requirements: The requirements for the connector.
        connector: The ID of a connector linked to the component.
        connector_resource_id: The custom resource ID to access through
            the connector.
        *args: Additional positional arguments.
        **kwargs: Additional keyword arguments.

    Raises:
        ValueError: If a secret reference is passed as name.
    """
    if secret_utils.is_secret_reference(name):
        raise ValueError(
            "Passing the `name` attribute of a stack component as a "
            "secret reference is not allowed."
        )

    self.id = id
    self.name = name
    self._config = config
    self.flavor = flavor
    self.type = type
    self.user = user
    self.created = created
    self.updated = updated
    self.labels = labels
    self.connector_requirements = connector_requirements
    self.connector = connector
    self.connector_resource_id = connector_resource_id
    self._connector_instance: Optional[ServiceConnector] = None
Attributes
config: AWSContainerRegistryConfig property

Returns the AWSContainerRegistryConfig config.

Returns:

Type Description
AWSContainerRegistryConfig

The configuration.

post_registration_message: Optional[str] property

Optional message printed after the stack component is registered.

Returns:

Type Description
Optional[str]

Info message regarding docker repositories in AWS.

Functions
prepare_image_push(image_name: str) -> None

Logs a warning message if trying to push an image for which no repository exists.

Parameters:

Name Type Description Default
image_name str

Name of the docker image that will be pushed.

required

Raises:

Type Description
ValueError

If the docker image name is invalid.

Source code in src/zenml/integrations/aws/container_registries/aws_container_registry.py
def prepare_image_push(self, image_name: str) -> None:
    """Logs warning message if trying to push an image for which no repository exists.

    Args:
        image_name: Name of the docker image that will be pushed.

    Raises:
        ValueError: If the docker image name is invalid.
    """
    # Find repository name from image name
    match = re.search(f"{self.config.uri}/(.*):.*", image_name)
    if not match:
        raise ValueError(f"Invalid docker image name '{image_name}'.")
    repo_name = match.group(1)

    client = self._get_ecr_client()
    try:
        response = client.describe_repositories()
    except (BotoCoreError, ClientError):
        logger.warning(
            "Amazon ECR requires you to create a repository before you can "
            f"push an image to it. ZenML is trying to push the image "
            f"{image_name} but could not find any repositories because "
            "your local AWS credentials are not set. We will try to push "
            "anyway, but in case it fails you need to create a repository "
            f"named `{repo_name}`."
        )
        return

    try:
        repo_uris: List[str] = [
            repository["repositoryUri"]
            for repository in response["repositories"]
        ]
    except (KeyError, ClientError) as e:
        # invalid boto response, let's hope for the best and just push
        logger.debug("Error while trying to fetch ECR repositories: %s", e)
        return

    repo_exists = any(
        image_name.startswith(f"{uri}:") for uri in repo_uris
    )
    if not repo_exists:
        logger.warning(
            "Amazon ECR requires you to create a repository before you can "
            f"push an image to it. ZenML is trying to push the image "
            f"{image_name} but could only detect the following "
            f"repositories: {repo_uris}. We will try to push anyway, but "
            f"in case it fails you need to create a repository named "
            f"`{repo_name}`."
        )
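
To illustrate the repository-name extraction above with hypothetical values:

import re

# Hypothetical registry URI and image name; the repository component
# between the registry URI and the tag must already exist in ECR.
registry_uri = "715803424592.dkr.ecr.us-east-1.amazonaws.com"
image_name = f"{registry_uri}/zenml/my-pipeline:latest"
match = re.search(f"{registry_uri}/(.*):.*", image_name)
assert match and match.group(1) == "zenml/my-pipeline"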
Modules
aws_container_registry

Implementation of the AWS container registry integration.

Classes

AWSContainerRegistry

Bases: BaseContainerRegistry

Class for AWS Container Registry. This is the same class re-exported by the container_registries package and documented in full above; see that section for the constructor, attributes, and prepare_image_push.

Functions

flavors

AWS integration flavors.

Classes
AWSContainerRegistryConfig(warn_about_plain_text_secrets: bool = False, **kwargs: Any)

Bases: BaseContainerRegistryConfig

Configuration for AWS Container Registry.

Source code in src/zenml/stack/stack_component.py
def __init__(
    self, warn_about_plain_text_secrets: bool = False, **kwargs: Any
) -> None:
    """Ensures that secret references don't clash with pydantic validation.

    StackComponents allow the specification of all their string attributes
    using secret references of the form `{{secret_name.key}}`. This however
    is only possible when the stack component does not perform any explicit
    validation of this attribute using pydantic validators. If this were
    the case, the validation would run on the secret reference and would
    fail or in the worst case, modify the secret reference and lead to
    unexpected behavior. This method ensures that no attributes that require
    custom pydantic validation are set as secret references.

    Args:
        warn_about_plain_text_secrets: If true, then warns about using
            plain-text secrets.
        **kwargs: Arguments to initialize this stack component.

    Raises:
        ValueError: If an attribute that requires custom pydantic validation
            is passed as a secret reference, or if the `name` attribute
            was passed as a secret reference.
    """
    for key, value in kwargs.items():
        try:
            field = self.__class__.model_fields[key]
        except KeyError:
            # Value for a private attribute or non-existing field, this
            # will fail during the upcoming pydantic validation
            continue

        if value is None:
            continue

        if not secret_utils.is_secret_reference(value):
            if (
                secret_utils.is_secret_field(field)
                and warn_about_plain_text_secrets
            ):
                logger.warning(
                    "You specified a plain-text value for the sensitive "
                    f"attribute `{key}` for a `{self.__class__.__name__}` "
                    "stack component. This is currently only a warning, "
                    "but future versions of ZenML will require you to pass "
                    "in sensitive information as secrets. Check out the "
                    "documentation on how to configure your stack "
                    "components with secrets here: "
                    "https://docs.zenml.io/getting-started/deploying-zenml/secret-management"
                )
            continue

        if pydantic_utils.has_validators(
            pydantic_class=self.__class__, field_name=key
        ):
            raise ValueError(
                f"Passing the stack component attribute `{key}` as a "
                "secret reference is not allowed as additional validation "
                "is required for this attribute."
            )

    super().__init__(**kwargs)
Functions
validate_aws_uri(uri: str) -> str classmethod

Validates that the URI is in the correct format.

Parameters:

Name Type Description Default
uri str

URI to validate.

required

Returns:

Type Description
str

URI in the correct format.

Raises:

Type Description
ValueError

If the URI contains a slash character.

Source code in src/zenml/integrations/aws/flavors/aws_container_registry_flavor.py
@field_validator("uri")
@classmethod
def validate_aws_uri(cls, uri: str) -> str:
    """Validates that the URI is in the correct format.

    Args:
        uri: URI to validate.

    Returns:
        URI in the correct format.

    Raises:
        ValueError: If the URI contains a slash character.
    """
    if "/" in uri:
        raise ValueError(
            "Property `uri` can not contain a `/`. An example of a valid "
            "URI is: `715803424592.dkr.ecr.us-east-1.amazonaws.com`"
        )

    return uri
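
A sketch of the validator in action, assuming `uri` is the only required field of the config:

from pydantic import ValidationError

from zenml.integrations.aws.flavors import AWSContainerRegistryConfig

# Valid: a bare ECR registry URI without a repository path.
AWSContainerRegistryConfig(uri="715803424592.dkr.ecr.us-east-1.amazonaws.com")

# Invalid: the `/` is rejected by the validator, which pydantic surfaces
# as a ValidationError.
try:
    AWSContainerRegistryConfig(
        uri="715803424592.dkr.ecr.us-east-1.amazonaws.com/zenml"
    )
except ValidationError as err:
    print(err)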
AWSContainerRegistryFlavor

Bases: BaseContainerRegistryFlavor

AWS Container Registry flavor.

Attributes
config_class: Type[AWSContainerRegistryConfig] property

Config class for this flavor.

Returns:

Type Description
Type[AWSContainerRegistryConfig]

The config class.

docs_url: Optional[str] property

A url to point at docs explaining this flavor.

Returns:

Type Description
Optional[str]

A flavor docs url.

implementation_class: Type[AWSContainerRegistry] property

Implementation class.

Returns:

Type Description
Type[AWSContainerRegistry]

The implementation class.

logo_url: str property

A url to represent the flavor in the dashboard.

Returns:

Type Description
str

The flavor logo.

name: str property

Name of the flavor.

Returns:

Type Description
str

The name of the flavor.

sdk_docs_url: Optional[str] property

A url to point at SDK docs explaining this flavor.

Returns:

Type Description
Optional[str]

A flavor SDK docs url.

service_connector_requirements: Optional[ServiceConnectorRequirements] property

Service connector resource requirements for service connectors.

Specifies resource requirements that are used to filter the available service connector types that are compatible with this flavor.

Returns:

Type Description
Optional[ServiceConnectorRequirements]

Requirements for compatible service connectors, if a service connector is required for this flavor.

AWSImageBuilderConfig(warn_about_plain_text_secrets: bool = False, **kwargs: Any)

Bases: BaseImageBuilderConfig

AWS CodeBuild image builder configuration.

Attributes:

Name Type Description
code_build_project str

The name of an existing AWS CodeBuild project to use to build the image. The CodeBuild project must exist in the AWS account and region inferred from the AWS service connector credentials or implicitly from the local AWS config.

build_image str

The Docker image to use for the AWS CodeBuild environment. The image must have Docker installed and be able to run Docker commands. The default image is bentolor/docker-dind-awscli. This can be customized to use a mirror, if needed, in case the Docker Hub image is not accessible or rate-limited.

custom_env_vars Optional[Dict[str, str]]

Custom environment variables to pass to the AWS CodeBuild build.

compute_type str

The compute type to use for the AWS CodeBuild build. The default is BUILD_GENERAL1_SMALL.

implicit_container_registry_auth bool

Whether to use implicit authentication to authenticate the AWS CodeBuild build to the container registry when pushing container images. If set to False, the container registry credentials must be explicitly configured for the container registry stack component, or the container registry stack component must be linked to a service connector. NOTE: When implicit_container_registry_auth is set to False, the container registry credentials will be passed to the AWS CodeBuild build as environment variables. This is not recommended for production use unless your service connector is configured to generate short-lived credentials.

Source code in src/zenml/stack/stack_component.py (the inherited `__init__` is identical to the listing shown above under AWSContainerRegistryConfig)
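
A configuration sketch using the attributes documented above; the CodeBuild project name is hypothetical and must already exist in your AWS account:

from zenml.integrations.aws.flavors import AWSImageBuilderConfig

config = AWSImageBuilderConfig(
    code_build_project="zenml-image-builds",  # hypothetical project name
    compute_type="BUILD_GENERAL1_SMALL",  # the documented default
    custom_env_vars={"DOCKER_BUILDKIT": "1"},  # optional extras
)
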
AWSImageBuilderFlavor

Bases: BaseImageBuilderFlavor

AWS CodeBuild image builder flavor.

Attributes
config_class: Type[BaseImageBuilderConfig] property

The config class.

Returns:

Type Description
Type[BaseImageBuilderConfig]

The config class.

docs_url: Optional[str] property

A url to point at docs explaining this flavor.

Returns:

Type Description
Optional[str]

A flavor docs url.

implementation_class: Type[AWSImageBuilder] property

Implementation class.

Returns:

Type Description
Type[AWSImageBuilder]

The implementation class.

logo_url: str property

A url to represent the flavor in the dashboard.

Returns:

Type Description
str

The flavor logo.

name: str property

The flavor name.

Returns:

Type Description
str

The name of the flavor.

sdk_docs_url: Optional[str] property

A url to point at SDK docs explaining this flavor.

Returns:

Type Description
Optional[str]

A flavor SDK docs url.

service_connector_requirements: Optional[ServiceConnectorRequirements] property

Service connector resource requirements for service connectors.

Specifies resource requirements that are used to filter the available service connector types that are compatible with this flavor.

Returns:

Type Description
Optional[ServiceConnectorRequirements]

Requirements for compatible service connectors, if a service connector is required for this flavor.

SagemakerOrchestratorConfig(warn_about_plain_text_secrets: bool = False, **kwargs: Any)

Bases: BaseOrchestratorConfig, SagemakerOrchestratorSettings

Config for the Sagemaker orchestrator.

There are three ways to authenticate to AWS:

- by connecting a ServiceConnector to the orchestrator,
- by configuring explicit AWS credentials aws_access_key_id, aws_secret_access_key, and an optional aws_auth_role_arn,
- if none of the above are provided, credentials will be loaded from the default AWS config.

Attributes:

Name Type Description
execution_role str

The IAM role ARN to use for the pipeline.

scheduler_role Optional[str]

The ARN of the IAM role that will be assumed by the EventBridge service to launch Sagemaker pipelines (For more details regarding the required permissions, please check: https://docs.zenml.io/stack-components/orchestrators/sagemaker#required-iam-permissions-for-schedules)

aws_access_key_id Optional[str]

The AWS access key ID to use to authenticate to AWS. If not provided, the value from the default AWS config will be used.

aws_secret_access_key Optional[str]

The AWS secret access key to use to authenticate to AWS. If not provided, the value from the default AWS config will be used.

aws_profile Optional[str]

The AWS profile to use for authentication if not using service connectors or explicit credentials. If not provided, the default profile will be used.

aws_auth_role_arn Optional[str]

The ARN of an intermediate IAM role to assume when authenticating to AWS.

region Optional[str]

The AWS region where the processing job will be run. If not provided, the value from the default AWS config will be used.

bucket Optional[str]

Name of the S3 bucket to use for storing artifacts from the job run. If not provided, a default bucket will be created based on the following format: "sagemaker-{region}-{aws-account-id}".

Source code in src/zenml/stack/stack_component.py (the inherited `__init__` is identical to the listing shown above under AWSContainerRegistryConfig)
Attributes
is_remote: bool property

Checks if this stack component is running remotely.

This designation is used to determine if the stack component can be used with a local ZenML database or if it requires a remote ZenML server.

Returns:

Type Description
bool

True if this config is for a remote component, False otherwise.

is_schedulable: bool property

Whether the orchestrator is schedulable or not.

Returns:

Type Description
bool

Whether the orchestrator is schedulable or not.

is_synchronous: bool property

Whether the orchestrator runs synchronously or not.

Returns:

Type Description
bool

Whether the orchestrator runs synchronously or not.
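
A configuration sketch with hypothetical role ARN, region, and bucket name; credentials left unset fall back to a connected service connector or the default AWS config, as described above:

from zenml.integrations.aws.flavors import SagemakerOrchestratorConfig

config = SagemakerOrchestratorConfig(
    execution_role="arn:aws:iam::123456789012:role/zenml-sagemaker",  # hypothetical
    region="us-east-1",
    bucket="my-zenml-sagemaker-artifacts",  # hypothetical
)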

SagemakerOrchestratorFlavor

Bases: BaseOrchestratorFlavor

Flavor for the Sagemaker orchestrator.

Attributes
config_class: Type[SagemakerOrchestratorConfig] property

Returns SagemakerOrchestratorConfig config class.

Returns:

Type Description
Type[SagemakerOrchestratorConfig]

The config class.

docs_url: Optional[str] property

A url to point at docs explaining this flavor.

Returns:

Type Description
Optional[str]

A flavor docs url.

implementation_class: Type[SagemakerOrchestrator] property

Implementation class.

Returns:

Type Description
Type[SagemakerOrchestrator]

The implementation class.

logo_url: str property

A url to represent the flavor in the dashboard.

Returns:

Type Description
str

The flavor logo.

name: str property

Name of the flavor.

Returns:

Type Description
str

The name of the flavor.

sdk_docs_url: Optional[str] property

A url to point at SDK docs explaining this flavor.

Returns:

Type Description
Optional[str]

A flavor SDK docs url.

service_connector_requirements: Optional[ServiceConnectorRequirements] property

Service connector resource requirements for service connectors.

Specifies resource requirements that are used to filter the available service connector types that are compatible with this flavor.

Returns:

Type Description
Optional[ServiceConnectorRequirements]

Requirements for compatible service connectors, if a service connector is required for this flavor.

SagemakerStepOperatorConfig(warn_about_plain_text_secrets: bool = False, **kwargs: Any)

Bases: BaseStepOperatorConfig, SagemakerStepOperatorSettings

Config for the Sagemaker step operator.

Attributes:

Name Type Description
role str

The role that has to be assigned to the jobs which are running in Sagemaker.

bucket Optional[str]

Name of the S3 bucket to use for storing artifacts from the job run. If not provided, a default bucket will be created based on the following format: "sagemaker-{region}-{aws-account-id}".
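
A minimal configuration sketch with a hypothetical role ARN:

from zenml.integrations.aws.flavors import SagemakerStepOperatorConfig

config = SagemakerStepOperatorConfig(
    role="arn:aws:iam::123456789012:role/zenml-sagemaker",  # hypothetical
)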

Source code in src/zenml/stack/stack_component.py (the inherited `__init__` is identical to the listing shown above under AWSContainerRegistryConfig)
Attributes
is_remote: bool property

Checks if this stack component is running remotely.

This designation is used to determine if the stack component can be used with a local ZenML database or if it requires a remote ZenML server.

Returns:

Type Description
bool

True if this config is for a remote component, False otherwise.

SagemakerStepOperatorFlavor

Bases: BaseStepOperatorFlavor

Flavor for the Sagemaker step operator.

Attributes
config_class: Type[SagemakerStepOperatorConfig] property

Returns SagemakerStepOperatorConfig config class.

Returns:

Type Description
Type[SagemakerStepOperatorConfig]

The config class.

docs_url: Optional[str] property

A url to point at docs explaining this flavor.

Returns:

Type Description
Optional[str]

A flavor docs url.

implementation_class: Type[SagemakerStepOperator] property

Implementation class.

Returns:

Type Description
Type[SagemakerStepOperator]

The implementation class.

logo_url: str property

A url to represent the flavor in the dashboard.

Returns:

Type Description
str

The flavor logo.

name: str property

Name of the flavor.

Returns:

Type Description
str

The name of the flavor.

sdk_docs_url: Optional[str] property

A url to point at SDK docs explaining this flavor.

Returns:

Type Description
Optional[str]

A flavor SDK docs url.

service_connector_requirements: Optional[ServiceConnectorRequirements] property

Service connector resource requirements for service connectors.

Specifies resource requirements that are used to filter the available service connector types that are compatible with this flavor.

Returns:

Type Description
Optional[ServiceConnectorRequirements]

Requirements for compatible service connectors, if a service connector is required for this flavor.

Modules
aws_container_registry_flavor

AWS container registry flavor.

Classes
AWSContainerRegistryConfig(warn_about_plain_text_secrets: bool = False, **kwargs: Any)

Bases: BaseContainerRegistryConfig

Configuration for AWS Container Registry.

Source code in src/zenml/stack/stack_component.py
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
def __init__(
    self, warn_about_plain_text_secrets: bool = False, **kwargs: Any
) -> None:
    """Ensures that secret references don't clash with pydantic validation.

    StackComponents allow the specification of all their string attributes
    using secret references of the form `{{secret_name.key}}`. This however
    is only possible when the stack component does not perform any explicit
    validation of this attribute using pydantic validators. If this were
    the case, the validation would run on the secret reference and would
    fail or in the worst case, modify the secret reference and lead to
    unexpected behavior. This method ensures that no attributes that require
    custom pydantic validation are set as secret references.

    Args:
        warn_about_plain_text_secrets: If true, then warns about using
            plain-text secrets.
        **kwargs: Arguments to initialize this stack component.

    Raises:
        ValueError: If an attribute that requires custom pydantic validation
            is passed as a secret reference, or if the `name` attribute
            was passed as a secret reference.
    """
    for key, value in kwargs.items():
        try:
            field = self.__class__.model_fields[key]
        except KeyError:
            # Value for a private attribute or non-existing field, this
            # will fail during the upcoming pydantic validation
            continue

        if value is None:
            continue

        if not secret_utils.is_secret_reference(value):
            if (
                secret_utils.is_secret_field(field)
                and warn_about_plain_text_secrets
            ):
                logger.warning(
                    "You specified a plain-text value for the sensitive "
                    f"attribute `{key}` for a `{self.__class__.__name__}` "
                    "stack component. This is currently only a warning, "
                    "but future versions of ZenML will require you to pass "
                    "in sensitive information as secrets. Check out the "
                    "documentation on how to configure your stack "
                    "components with secrets here: "
                    "https://docs.zenml.io/getting-started/deploying-zenml/secret-management"
                )
            continue

        if pydantic_utils.has_validators(
            pydantic_class=self.__class__, field_name=key
        ):
            raise ValueError(
                f"Passing the stack component attribute `{key}` as a "
                "secret reference is not allowed as additional validation "
                "is required for this attribute."
            )

    super().__init__(**kwargs)
Functions
validate_aws_uri(uri: str) -> str classmethod

Validates that the URI is in the correct format.

Parameters:

Name Type Description Default
uri str

URI to validate.

required

Returns:

Type Description
str

URI in the correct format.

Raises:

Type Description
ValueError

If the URI contains a slash character.

Source code in src/zenml/integrations/aws/flavors/aws_container_registry_flavor.py
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
@field_validator("uri")
@classmethod
def validate_aws_uri(cls, uri: str) -> str:
    """Validates that the URI is in the correct format.

    Args:
        uri: URI to validate.

    Returns:
        URI in the correct format.

    Raises:
        ValueError: If the URI contains a slash character.
    """
    if "/" in uri:
        raise ValueError(
            "Property `uri` can not contain a `/`. An example of a valid "
            "URI is: `715803424592.dkr.ecr.us-east-1.amazonaws.com`"
        )

    return uri
AWSContainerRegistryFlavor

Bases: BaseContainerRegistryFlavor

AWS Container Registry flavor.

Attributes
config_class: Type[AWSContainerRegistryConfig] property

Config class for this flavor.

Returns:

Type Description
Type[AWSContainerRegistryConfig]

The config class.

docs_url: Optional[str] property

A url to point at docs explaining this flavor.

Returns:

Type Description
Optional[str]

A flavor docs url.

implementation_class: Type[AWSContainerRegistry] property

Implementation class.

Returns:

Type Description
Type[AWSContainerRegistry]

The implementation class.

logo_url: str property

A url to represent the flavor in the dashboard.

Returns:

Type Description
str

The flavor logo.

name: str property

Name of the flavor.

Returns:

Type Description
str

The name of the flavor.

sdk_docs_url: Optional[str] property

A url to point at SDK docs explaining this flavor.

Returns:

Type Description
Optional[str]

A flavor SDK docs url.

service_connector_requirements: Optional[ServiceConnectorRequirements] property

Service connector resource requirements for service connectors.

Specifies resource requirements that are used to filter the available service connector types that are compatible with this flavor.

Returns:

Type Description
Optional[ServiceConnectorRequirements]

Requirements for compatible service connectors, if a service connector is required for this flavor.

aws_image_builder_flavor

AWS Code Build image builder flavor.

Classes
AWSImageBuilderConfig(warn_about_plain_text_secrets: bool = False, **kwargs: Any)

Bases: BaseImageBuilderConfig

AWS Code Build image builder configuration.

Attributes:

Name Type Description
code_build_project str

The name of an existing AWS CodeBuild project to use to build the image. The CodeBuild project must exist in the AWS account and region inferred from the AWS service connector credentials or implicitly from the local AWS config.

build_image str

The Docker image to use for the AWS CodeBuild environment. The image must have Docker installed and be able to run Docker commands. The default image is bentolor/docker-dind-awscli. This can be customized to use a mirror, if needed, in case the Docker Hub image is not accessible or rate-limited.

custom_env_vars Optional[Dict[str, str]]

Custom environment variables to pass to the AWS CodeBuild build.

compute_type str

The compute type to use for the AWS CodeBuild build. The default is BUILD_GENERAL1_SMALL.

implicit_container_registry_auth bool

Whether to use implicit authentication to authenticate the AWS Code Build build to the container registry when pushing container images. If set to False, the container registry credentials must be explicitly configured for the container registry stack component or the container registry stack component must be linked to a service connector. NOTE: When implicit_container_registry_auth is set to False, the container registry credentials will be passed to the AWS Code Build build as environment variables. This is not recommended for production use unless your service connector is configured to generate short-lived credentials.

Source code in src/zenml/stack/stack_component.py
def __init__(
    self, warn_about_plain_text_secrets: bool = False, **kwargs: Any
) -> None:
    """Ensures that secret references don't clash with pydantic validation.

    StackComponents allow the specification of all their string attributes
    using secret references of the form `{{secret_name.key}}`. This however
    is only possible when the stack component does not perform any explicit
    validation of this attribute using pydantic validators. If this were
    the case, the validation would run on the secret reference and would
    fail or in the worst case, modify the secret reference and lead to
    unexpected behavior. This method ensures that no attributes that require
    custom pydantic validation are set as secret references.

    Args:
        warn_about_plain_text_secrets: If true, then warns about using
            plain-text secrets.
        **kwargs: Arguments to initialize this stack component.

    Raises:
        ValueError: If an attribute that requires custom pydantic validation
            is passed as a secret reference, or if the `name` attribute
            was passed as a secret reference.
    """
    for key, value in kwargs.items():
        try:
            field = self.__class__.model_fields[key]
        except KeyError:
            # Value for a private attribute or non-existing field, this
            # will fail during the upcoming pydantic validation
            continue

        if value is None:
            continue

        if not secret_utils.is_secret_reference(value):
            if (
                secret_utils.is_secret_field(field)
                and warn_about_plain_text_secrets
            ):
                logger.warning(
                    "You specified a plain-text value for the sensitive "
                    f"attribute `{key}` for a `{self.__class__.__name__}` "
                    "stack component. This is currently only a warning, "
                    "but future versions of ZenML will require you to pass "
                    "in sensitive information as secrets. Check out the "
                    "documentation on how to configure your stack "
                    "components with secrets here: "
                    "https://docs.zenml.io/getting-started/deploying-zenml/secret-management"
                )
            continue

        if pydantic_utils.has_validators(
            pydantic_class=self.__class__, field_name=key
        ):
            raise ValueError(
                f"Passing the stack component attribute `{key}` as a "
                "secret reference is not allowed as additional validation "
                "is required for this attribute."
            )

    super().__init__(**kwargs)
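A hedged configuration sketch for this image builder; the CodeBuild project name and custom environment variables below are placeholders:

from zenml.integrations.aws.flavors.aws_image_builder_flavor import (
    AWSImageBuilderConfig,
)

config = AWSImageBuilderConfig(
    code_build_project="zenml-builds",      # must already exist in AWS
    compute_type="BUILD_GENERAL1_SMALL",    # the documented default
    custom_env_vars={"PIP_DEFAULT_TIMEOUT": "60"},
    implicit_container_registry_auth=True,
)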
AWSImageBuilderFlavor

Bases: BaseImageBuilderFlavor

AWS Code Build image builder flavor.

Attributes
config_class: Type[BaseImageBuilderConfig] property

The config class.

Returns:

Type Description
Type[BaseImageBuilderConfig]

The config class.

docs_url: Optional[str] property

A url to point at docs explaining this flavor.

Returns:

Type Description
Optional[str]

A flavor docs url.

implementation_class: Type[AWSImageBuilder] property

Implementation class.

Returns:

Type Description
Type[AWSImageBuilder]

The implementation class.

logo_url: str property

A url to represent the flavor in the dashboard.

Returns:

Type Description
str

The flavor logo.

name: str property

The flavor name.

Returns:

Type Description
str

The name of the flavor.

sdk_docs_url: Optional[str] property

A url to point at SDK docs explaining this flavor.

Returns:

Type Description
Optional[str]

A flavor SDK docs url.

service_connector_requirements: Optional[ServiceConnectorRequirements] property

Service connector resource requirements for service connectors.

Specifies resource requirements that are used to filter the available service connector types that are compatible with this flavor.

Returns:

Type Description
Optional[ServiceConnectorRequirements]

Requirements for compatible service connectors, if a service connector is required for this flavor.

sagemaker_orchestrator_flavor

Amazon SageMaker orchestrator flavor.

Classes
SagemakerOrchestratorConfig(warn_about_plain_text_secrets: bool = False, **kwargs: Any)

Bases: BaseOrchestratorConfig, SagemakerOrchestratorSettings

Config for the Sagemaker orchestrator.

There are three ways to authenticate to AWS:

- by connecting a ServiceConnector to the orchestrator,
- by configuring explicit AWS credentials aws_access_key_id, aws_secret_access_key, and an optional aws_auth_role_arn,
- if none of the above are provided, credentials will be loaded from the default AWS config.

Attributes:

Name Type Description
execution_role str

The IAM role ARN to use for the pipeline.

scheduler_role Optional[str]

The ARN of the IAM role that will be assumed by the EventBridge service to launch Sagemaker pipelines (For more details regarding the required permissions, please check: https://docs.zenml.io/stack-components/orchestrators/sagemaker#required-iam-permissions-for-schedules)

aws_access_key_id Optional[str]

The AWS access key ID to use to authenticate to AWS. If not provided, the value from the default AWS config will be used.

aws_secret_access_key Optional[str]

The AWS secret access key to use to authenticate to AWS. If not provided, the value from the default AWS config will be used.

aws_profile Optional[str]

The AWS profile to use for authentication if not using service connectors or explicit credentials. If not provided, the default profile will be used.

aws_auth_role_arn Optional[str]

The ARN of an intermediate IAM role to assume when authenticating to AWS.

region Optional[str]

The AWS region where the processing job will be run. If not provided, the value from the default AWS config will be used.

bucket Optional[str]

Name of the S3 bucket to use for storing artifacts from the job run. If not provided, a default bucket will be created based on the following format: "sagemaker-{region}-{aws-account-id}".

Source code in src/zenml/stack/stack_component.py
def __init__(
    self, warn_about_plain_text_secrets: bool = False, **kwargs: Any
) -> None:
    """Ensures that secret references don't clash with pydantic validation.

    StackComponents allow the specification of all their string attributes
    using secret references of the form `{{secret_name.key}}`. This however
    is only possible when the stack component does not perform any explicit
    validation of this attribute using pydantic validators. If this were
    the case, the validation would run on the secret reference and would
    fail or in the worst case, modify the secret reference and lead to
    unexpected behavior. This method ensures that no attributes that require
    custom pydantic validation are set as secret references.

    Args:
        warn_about_plain_text_secrets: If true, then warns about using
            plain-text secrets.
        **kwargs: Arguments to initialize this stack component.

    Raises:
        ValueError: If an attribute that requires custom pydantic validation
            is passed as a secret reference, or if the `name` attribute
            was passed as a secret reference.
    """
    for key, value in kwargs.items():
        try:
            field = self.__class__.model_fields[key]
        except KeyError:
            # Value for a private attribute or non-existing field, this
            # will fail during the upcoming pydantic validation
            continue

        if value is None:
            continue

        if not secret_utils.is_secret_reference(value):
            if (
                secret_utils.is_secret_field(field)
                and warn_about_plain_text_secrets
            ):
                logger.warning(
                    "You specified a plain-text value for the sensitive "
                    f"attribute `{key}` for a `{self.__class__.__name__}` "
                    "stack component. This is currently only a warning, "
                    "but future versions of ZenML will require you to pass "
                    "in sensitive information as secrets. Check out the "
                    "documentation on how to configure your stack "
                    "components with secrets here: "
                    "https://docs.zenml.io/getting-started/deploying-zenml/secret-management"
                )
            continue

        if pydantic_utils.has_validators(
            pydantic_class=self.__class__, field_name=key
        ):
            raise ValueError(
                f"Passing the stack component attribute `{key}` as a "
                "secret reference is not allowed as additional validation "
                "is required for this attribute."
            )

    super().__init__(**kwargs)
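A sketch of the explicit-credentials option described above. The ARN is a placeholder, and the secret references assume a ZenML secret named aws_creds exists; with a service connector or the default AWS config, none of the credential fields are needed:

from zenml.integrations.aws.flavors.sagemaker_orchestrator_flavor import (
    SagemakerOrchestratorConfig,
)

config = SagemakerOrchestratorConfig(
    execution_role="arn:aws:iam::123456789012:role/zenml-sagemaker",
    aws_access_key_id="{{aws_creds.access_key_id}}",          # secret reference
    aws_secret_access_key="{{aws_creds.secret_access_key}}",  # secret reference
    region="us-east-1",
)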
Attributes
is_remote: bool property

Checks if this stack component is running remotely.

This designation is used to determine if the stack component can be used with a local ZenML database or if it requires a remote ZenML server.

Returns:

Type Description
bool

True if this config is for a remote component, False otherwise.

is_schedulable: bool property

Whether the orchestrator is schedulable or not.

Returns:

Type Description
bool

Whether the orchestrator is schedulable or not.

is_synchronous: bool property

Whether the orchestrator runs synchronously or not.

Returns:

Type Description
bool

Whether the orchestrator runs synchronously or not.

SagemakerOrchestratorFlavor

Bases: BaseOrchestratorFlavor

Flavor for the Sagemaker orchestrator.

Attributes
config_class: Type[SagemakerOrchestratorConfig] property

Returns SagemakerOrchestratorConfig config class.

Returns:

Type Description
Type[SagemakerOrchestratorConfig]

The config class.

docs_url: Optional[str] property

A url to point at docs explaining this flavor.

Returns:

Type Description
Optional[str]

A flavor docs url.

implementation_class: Type[SagemakerOrchestrator] property

Implementation class.

Returns:

Type Description
Type[SagemakerOrchestrator]

The implementation class.

logo_url: str property

A url to represent the flavor in the dashboard.

Returns:

Type Description
str

The flavor logo.

name: str property

Name of the flavor.

Returns:

Type Description
str

The name of the flavor.

sdk_docs_url: Optional[str] property

A url to point at SDK docs explaining this flavor.

Returns:

Type Description
Optional[str]

A flavor SDK docs url.

service_connector_requirements: Optional[ServiceConnectorRequirements] property

Service connector resource requirements for service connectors.

Specifies resource requirements that are used to filter the available service connector types that are compatible with this flavor.

Returns:

Type Description
Optional[ServiceConnectorRequirements]

Requirements for compatible service connectors, if a service connector is required for this flavor.

SagemakerOrchestratorSettings(warn_about_plain_text_secrets: bool = False, **kwargs: Any)

Bases: BaseSettings

Settings for the Sagemaker orchestrator.

Attributes:

Name Type Description
synchronous bool

If True, the client running a pipeline using this orchestrator waits until all steps finish running. If False, the client returns immediately and the pipeline is executed asynchronously. Defaults to True.

instance_type Optional[str]

The instance type to use for the processing job.

execution_role Optional[str]

The IAM role to use for the step execution.

processor_role Optional[str]

DEPRECATED: use execution_role instead.

volume_size_in_gb int

The size of the EBS volume to use for the processing job.

max_runtime_in_seconds int

The maximum runtime in seconds for the processing job.

tags Dict[str, str]

Tags to apply to the Processor/Estimator assigned to the step.

pipeline_tags Dict[str, str]

Tags to apply to the pipeline via the sagemaker.workflow.pipeline.Pipeline.create method.

processor_tags Dict[str, str]

DEPRECATED: use tags instead.

keep_alive_period_in_seconds Optional[int]

The time in seconds after which the provisioned instance will be terminated if not used. This only applies to the TrainingStep type, which cannot be used if output_data_s3_uri is set to a Dict[str, str].

use_training_step Optional[bool]

Whether to use the TrainingStep type. The TrainingStep type cannot be used if output_data_s3_uri is set to a Dict[str, str] or if output_data_s3_mode != "EndOfJob".

processor_args Dict[str, Any]

Arguments that are directly passed to the SageMaker Processor for a specific step, allowing for overriding the default settings provided when configuring the component. See https://sagemaker.readthedocs.io/en/stable/api/training/processing.html#sagemaker.processing.Processor for a full list of arguments. For processor_args.instance_type, check https://docs.aws.amazon.com/sagemaker/latest/dg/notebooks-available-instance-types.html for a list of available instance types.

environment Dict[str, str]

Environment variables to pass to the container.

estimator_args Dict[str, Any]

Arguments that are directly passed to the SageMaker Estimator for a specific step, allowing for overriding the default settings provided when configuring the component. See https://sagemaker.readthedocs.io/en/stable/api/training/estimators.html#sagemaker.estimator.Estimator for a full list of arguments. For a list of available instance types, check https://docs.aws.amazon.com/sagemaker/latest/dg/cmn-info-instance-types.html.

input_data_s3_mode str

How data is made available to the container. Two possible input modes: File, Pipe.

input_data_s3_uri Optional[Union[str, Dict[str, str]]]

S3 URI where data is located if not locally, e.g. s3://my-bucket/my-data/train. How data will be made available to the container is configured with input_data_s3_mode. Two possible input types:
- str: S3 location where training data is saved.
- Dict[str, str]: (ChannelName, S3Location) pairs that represent channels (e.g. training, validation, testing) where specific parts of the data are saved in S3.

output_data_s3_mode str

How data is uploaded to the S3 bucket. Two possible output modes: EndOfJob, Continuous.

output_data_s3_uri Optional[Union[str, Dict[str, str]]]

S3 URI where data is uploaded after or during the processing run, e.g. s3://my-bucket/my-data/output. How data will be made available to the container is configured with output_data_s3_mode. Two possible types:
- str: S3 location where data will be uploaded from a local folder named /opt/ml/processing/output/data.
- Dict[str, str]: (ChannelName, S3Location) pairs that represent channels (e.g. output_one, output_two) where specific parts of the data are stored locally for S3 upload. Data must be available locally in /opt/ml/processing/output/data/.

Source code in src/zenml/config/secret_reference_mixin.py
def __init__(
    self, warn_about_plain_text_secrets: bool = False, **kwargs: Any
) -> None:
    """Ensures that secret references are only passed for valid fields.

    This method ensures that secret references are not passed for fields
    that explicitly prevent them or require pydantic validation.

    Args:
        warn_about_plain_text_secrets: If true, then warns about using plain-text secrets.
        **kwargs: Arguments to initialize this object.

    Raises:
        ValueError: If an attribute that requires custom pydantic validation
            or an attribute which explicitly disallows secret references
            is passed as a secret reference.
    """
    for key, value in kwargs.items():
        try:
            field = self.__class__.model_fields[key]
        except KeyError:
            # Value for a private attribute or non-existing field, this
            # will fail during the upcoming pydantic validation
            continue

        if value is None:
            continue

        if not secret_utils.is_secret_reference(value):
            if (
                secret_utils.is_secret_field(field)
                and warn_about_plain_text_secrets
            ):
                logger.warning(
                    "You specified a plain-text value for the sensitive "
                    f"attribute `{key}`. This is currently only a warning, "
                    "but future versions of ZenML will require you to pass "
                    "in sensitive information as secrets. Check out the "
                    "documentation on how to configure values with secrets "
                    "here: https://docs.zenml.io/getting-started/deploying-zenml/secret-management"
                )
            continue

        if secret_utils.is_clear_text_field(field):
            raise ValueError(
                f"Passing the `{key}` attribute as a secret reference is "
                "not allowed."
            )

        requires_validation = has_validators(
            pydantic_class=self.__class__, field_name=key
        )
        if requires_validation:
            raise ValueError(
                f"Passing the attribute `{key}` as a secret reference is "
                "not allowed as additional validation is required for "
                "this attribute."
            )

    super().__init__(**kwargs)
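A hedged usage sketch: attaching these settings to a pipeline. The "orchestrator.sagemaker" settings key follows ZenML's <component-type>.<flavor> convention; the pipeline and all values below are placeholders:

from zenml import pipeline
from zenml.integrations.aws.flavors.sagemaker_orchestrator_flavor import (
    SagemakerOrchestratorSettings,
)

sagemaker_settings = SagemakerOrchestratorSettings(
    instance_type="ml.m5.xlarge",
    volume_size_in_gb=50,
    environment={"LOG_LEVEL": "INFO"},
)

@pipeline(settings={"orchestrator.sagemaker": sagemaker_settings})
def my_pipeline() -> None:
    ...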
Functions
validate_model(data: Dict[str, Any]) -> Dict[str, Any]

Check if model is configured correctly.

Parameters:

Name Type Description Default
data Dict[str, Any]

The model data.

required

Returns:

Type Description
Dict[str, Any]

The validated model data.

Raises:

Type Description
ValueError

If the model is configured incorrectly.

Source code in src/zenml/integrations/aws/flavors/sagemaker_orchestrator_flavor.py
@model_validator(mode="before")
def validate_model(cls, data: Dict[str, Any]) -> Dict[str, Any]:
    """Check if model is configured correctly.

    Args:
        data: The model data.

    Returns:
        The validated model data.

    Raises:
        ValueError: If the model is configured incorrectly.
    """
    use_training_step = data.get("use_training_step", True)
    output_data_s3_uri = data.get("output_data_s3_uri", None)
    output_data_s3_mode = data.get(
        "output_data_s3_mode", DEFAULT_OUTPUT_DATA_S3_MODE
    )
    if use_training_step and (
        isinstance(output_data_s3_uri, dict)
        or (
            isinstance(output_data_s3_uri, str)
            and (output_data_s3_mode != DEFAULT_OUTPUT_DATA_S3_MODE)
        )
    ):
        raise ValueError(
            "`use_training_step=True` is not supported when `output_data_s3_uri` is a dict or "
            f"when `output_data_s3_mode` is not '{DEFAULT_OUTPUT_DATA_S3_MODE}'."
        )
    instance_type = data.get("instance_type", None)
    if instance_type is None:
        if use_training_step:
            data["instance_type"] = DEFAULT_TRAINING_INSTANCE_TYPE
        else:
            data["instance_type"] = DEFAULT_PROCESSING_INSTANCE_TYPE
    return data
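To illustrate the validator above (a sketch; bucket and channel names are placeholders):

from zenml.integrations.aws.flavors.sagemaker_orchestrator_flavor import (
    SagemakerOrchestratorSettings,
)

# Raises ValueError: a dict-valued `output_data_s3_uri` is incompatible
# with the TrainingStep type.
# SagemakerOrchestratorSettings(
#     use_training_step=True,
#     output_data_s3_uri={"output_one": "s3://my-bucket/out"},
# )

# Valid: disable the training step instead; since no `instance_type` is
# given, the default processing instance type is filled in.
settings = SagemakerOrchestratorSettings(
    use_training_step=False,
    output_data_s3_uri={"output_one": "s3://my-bucket/out"},
)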
Functions

Modules
sagemaker_step_operator_flavor

Amazon SageMaker step operator flavor.

Classes
SagemakerStepOperatorConfig(warn_about_plain_text_secrets: bool = False, **kwargs: Any)

Bases: BaseStepOperatorConfig, SagemakerStepOperatorSettings

Config for the Sagemaker step operator.

Attributes:

Name Type Description
role str

The role that has to be assigned to the jobs running in Sagemaker.

bucket Optional[str]

Name of the S3 bucket to use for storing artifacts from the job run. If not provided, a default bucket will be created based on the following format: "sagemaker-{region}-{aws-account-id}".

Source code in src/zenml/stack/stack_component.py
def __init__(
    self, warn_about_plain_text_secrets: bool = False, **kwargs: Any
) -> None:
    """Ensures that secret references don't clash with pydantic validation.

    StackComponents allow the specification of all their string attributes
    using secret references of the form `{{secret_name.key}}`. This however
    is only possible when the stack component does not perform any explicit
    validation of this attribute using pydantic validators. If this were
    the case, the validation would run on the secret reference and would
    fail or in the worst case, modify the secret reference and lead to
    unexpected behavior. This method ensures that no attributes that require
    custom pydantic validation are set as secret references.

    Args:
        warn_about_plain_text_secrets: If true, then warns about using
            plain-text secrets.
        **kwargs: Arguments to initialize this stack component.

    Raises:
        ValueError: If an attribute that requires custom pydantic validation
            is passed as a secret reference, or if the `name` attribute
            was passed as a secret reference.
    """
    for key, value in kwargs.items():
        try:
            field = self.__class__.model_fields[key]
        except KeyError:
            # Value for a private attribute or non-existing field, this
            # will fail during the upcoming pydantic validation
            continue

        if value is None:
            continue

        if not secret_utils.is_secret_reference(value):
            if (
                secret_utils.is_secret_field(field)
                and warn_about_plain_text_secrets
            ):
                logger.warning(
                    "You specified a plain-text value for the sensitive "
                    f"attribute `{key}` for a `{self.__class__.__name__}` "
                    "stack component. This is currently only a warning, "
                    "but future versions of ZenML will require you to pass "
                    "in sensitive information as secrets. Check out the "
                    "documentation on how to configure your stack "
                    "components with secrets here: "
                    "https://docs.zenml.io/getting-started/deploying-zenml/secret-management"
                )
            continue

        if pydantic_utils.has_validators(
            pydantic_class=self.__class__, field_name=key
        ):
            raise ValueError(
                f"Passing the stack component attribute `{key}` as a "
                "secret reference is not allowed as additional validation "
                "is required for this attribute."
            )

    super().__init__(**kwargs)
Attributes
is_remote: bool property

Checks if this stack component is running remotely.

This designation is used to determine if the stack component can be used with a local ZenML database or if it requires a remote ZenML server.

Returns:

Type Description
bool

True if this config is for a remote component, False otherwise.

SagemakerStepOperatorFlavor

Bases: BaseStepOperatorFlavor

Flavor for the Sagemaker step operator.

Attributes
config_class: Type[SagemakerStepOperatorConfig] property

Returns SagemakerStepOperatorConfig config class.

Returns:

Type Description
Type[SagemakerStepOperatorConfig]

The config class.

docs_url: Optional[str] property

A url to point at docs explaining this flavor.

Returns:

Type Description
Optional[str]

A flavor docs url.

implementation_class: Type[SagemakerStepOperator] property

Implementation class.

Returns:

Type Description
Type[SagemakerStepOperator]

The implementation class.

logo_url: str property

A url to represent the flavor in the dashboard.

Returns:

Type Description
str

The flavor logo.

name: str property

Name of the flavor.

Returns:

Type Description
str

The name of the flavor.

sdk_docs_url: Optional[str] property

A url to point at SDK docs explaining this flavor.

Returns:

Type Description
Optional[str]

A flavor SDK docs url.

service_connector_requirements: Optional[ServiceConnectorRequirements] property

Service connector resource requirements for service connectors.

Specifies resource requirements that are used to filter the available service connector types that are compatible with this flavor.

Returns:

Type Description
Optional[ServiceConnectorRequirements]

Requirements for compatible service connectors, if a service connector is required for this flavor.

SagemakerStepOperatorSettings(warn_about_plain_text_secrets: bool = False, **kwargs: Any)

Bases: BaseSettings

Settings for the Sagemaker step operator.

Attributes:

Name Type Description
experiment_name Optional[str]

The name for the experiment to which the job will be associated. If not provided, the job runs will be independent.

input_data_s3_uri Optional[Union[str, Dict[str, str]]]

S3 URI where training data is located if not locally, e.g. s3://my-bucket/my-data/train. How data will be made available to the container is configured with estimator_args.input_mode. Two possible input types:
- str: S3 location where training data is saved.
- Dict[str, str]: (ChannelName, S3Location) pairs that represent channels (e.g. training, validation, testing) where specific parts of the data are saved in S3.

estimator_args Dict[str, Any]

Arguments that are directly passed to the SageMaker Estimator. See https://sagemaker.readthedocs.io/en/stable/api/training/estimators.html#sagemaker.estimator.Estimator for a full list of arguments. For estimator_args.instance_type, check https://docs.aws.amazon.com/sagemaker/latest/dg/notebooks-available-instance-types.html for a list of available instance types.

environment Dict[str, str]

Environment variables to pass to the container.

Source code in src/zenml/config/secret_reference_mixin.py
def __init__(
    self, warn_about_plain_text_secrets: bool = False, **kwargs: Any
) -> None:
    """Ensures that secret references are only passed for valid fields.

    This method ensures that secret references are not passed for fields
    that explicitly prevent them or require pydantic validation.

    Args:
        warn_about_plain_text_secrets: If true, then warns about using plain-text secrets.
        **kwargs: Arguments to initialize this object.

    Raises:
        ValueError: If an attribute that requires custom pydantic validation
            or an attribute which explicitly disallows secret references
            is passed as a secret reference.
    """
    for key, value in kwargs.items():
        try:
            field = self.__class__.model_fields[key]
        except KeyError:
            # Value for a private attribute or non-existing field, this
            # will fail during the upcoming pydantic validation
            continue

        if value is None:
            continue

        if not secret_utils.is_secret_reference(value):
            if (
                secret_utils.is_secret_field(field)
                and warn_about_plain_text_secrets
            ):
                logger.warning(
                    "You specified a plain-text value for the sensitive "
                    f"attribute `{key}`. This is currently only a warning, "
                    "but future versions of ZenML will require you to pass "
                    "in sensitive information as secrets. Check out the "
                    "documentation on how to configure values with secrets "
                    "here: https://docs.zenml.io/getting-started/deploying-zenml/secret-management"
                )
            continue

        if secret_utils.is_clear_text_field(field):
            raise ValueError(
                f"Passing the `{key}` attribute as a secret reference is "
                "not allowed."
            )

        requires_validation = has_validators(
            pydantic_class=self.__class__, field_name=key
        )
        if requires_validation:
            raise ValueError(
                f"Passing the attribute `{key}` as a secret reference is "
                "not allowed as additional validation is required for "
                "this attribute."
            )

    super().__init__(**kwargs)
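A hedged sketch of step-operator settings using channel-based input data; bucket names, channels, and values below are placeholders:

from zenml.integrations.aws.flavors.sagemaker_step_operator_flavor import (
    SagemakerStepOperatorSettings,
)

settings = SagemakerStepOperatorSettings(
    experiment_name="my-experiment",
    input_data_s3_uri={
        "training": "s3://my-bucket/data/train",
        "validation": "s3://my-bucket/data/val",
    },
    estimator_args={"instance_type": "ml.m5.large"},
    environment={"LOG_LEVEL": "INFO"},
)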
Modules

image_builders

Initialization for the AWS image builder.

Classes
AWSImageBuilder(name: str, id: UUID, config: StackComponentConfig, flavor: str, type: StackComponentType, user: Optional[UUID], created: datetime, updated: datetime, labels: Optional[Dict[str, Any]] = None, connector_requirements: Optional[ServiceConnectorRequirements] = None, connector: Optional[UUID] = None, connector_resource_id: Optional[str] = None, *args: Any, **kwargs: Any)

Bases: BaseImageBuilder

AWS Code Build image builder implementation.

Source code in src/zenml/stack/stack_component.py
def __init__(
    self,
    name: str,
    id: UUID,
    config: StackComponentConfig,
    flavor: str,
    type: StackComponentType,
    user: Optional[UUID],
    created: datetime,
    updated: datetime,
    labels: Optional[Dict[str, Any]] = None,
    connector_requirements: Optional[ServiceConnectorRequirements] = None,
    connector: Optional[UUID] = None,
    connector_resource_id: Optional[str] = None,
    *args: Any,
    **kwargs: Any,
):
    """Initializes a StackComponent.

    Args:
        name: The name of the component.
        id: The unique ID of the component.
        config: The config of the component.
        flavor: The flavor of the component.
        type: The type of the component.
        user: The ID of the user who created the component.
        created: The creation time of the component.
        updated: The last update time of the component.
        labels: The labels of the component.
        connector_requirements: The requirements for the connector.
        connector: The ID of a connector linked to the component.
        connector_resource_id: The custom resource ID to access through
            the connector.
        *args: Additional positional arguments.
        **kwargs: Additional keyword arguments.

    Raises:
        ValueError: If a secret reference is passed as name.
    """
    if secret_utils.is_secret_reference(name):
        raise ValueError(
            "Passing the `name` attribute of a stack component as a "
            "secret reference is not allowed."
        )

    self.id = id
    self.name = name
    self._config = config
    self.flavor = flavor
    self.type = type
    self.user = user
    self.created = created
    self.updated = updated
    self.labels = labels
    self.connector_requirements = connector_requirements
    self.connector = connector
    self.connector_resource_id = connector_resource_id
    self._connector_instance: Optional[ServiceConnector] = None
Attributes
code_build_client: Any property

The authenticated AWS Code Build client to use for interacting with AWS services.

Returns:

Type Description
Any

The authenticated AWS Code Build client.

Raises:

Type Description
RuntimeError

If the AWS Code Build client cannot be created.

config: AWSImageBuilderConfig property

The stack component configuration.

Returns:

Type Description
AWSImageBuilderConfig

The configuration.

is_building_locally: bool property

Whether the image builder builds the images on the client machine.

Returns:

Type Description
bool

True if the image builder builds locally, False otherwise.

validator: Optional[StackValidator] property

Validates the stack for the AWS Code Build Image Builder.

The AWS Code Build Image Builder requires a container registry to push the image to and an S3 Artifact Store to upload the build context, so AWS Code Build can access it.

Returns:

Type Description
Optional[StackValidator]

Stack validator.

Functions
build(image_name: str, build_context: BuildContext, docker_build_options: Dict[str, Any], container_registry: Optional[BaseContainerRegistry] = None) -> str

Builds and pushes a Docker image.

Parameters:

Name Type Description Default
image_name str

Name of the image to build and push.

required
build_context BuildContext

The build context to use for the image.

required
docker_build_options Dict[str, Any]

Docker build options.

required
container_registry Optional[BaseContainerRegistry]

Optional container registry to push to.

None

Returns:

Type Description
str

The Docker image name with digest.

Raises:

Type Description
RuntimeError

If no container registry is passed.

RuntimeError

If the Code Build build fails.

Source code in src/zenml/integrations/aws/image_builders/aws_image_builder.py
    def build(
        self,
        image_name: str,
        build_context: "BuildContext",
        docker_build_options: Dict[str, Any],
        container_registry: Optional["BaseContainerRegistry"] = None,
    ) -> str:
        """Builds and pushes a Docker image.

        Args:
            image_name: Name of the image to build and push.
            build_context: The build context to use for the image.
            docker_build_options: Docker build options.
            container_registry: Optional container registry to push to.

        Returns:
            The Docker image name with digest.

        Raises:
            RuntimeError: If no container registry is passed.
            RuntimeError: If the Code Build build fails.
        """
        if not container_registry:
            raise RuntimeError(
                "The AWS Image Builder requires a container registry to push "
                "the image to. Please provide one and try again."
            )

        logger.info("Using AWS Code Build to build image `%s`", image_name)
        cloud_build_context = self._upload_build_context(
            build_context=build_context,
            parent_path_directory_name=f"code-build-contexts/{str(self.id)}",
            archive_type=ArchiveType.ZIP,
        )

        url_parts = urlparse(cloud_build_context)
        bucket = url_parts.netloc
        object_path = url_parts.path.lstrip("/")
        logger.info(
            "Build context located in bucket `%s` and object path `%s`",
            bucket,
            object_path,
        )

        # Pass authentication credentials as environment variables, if
        # the container registry has credentials and if implicit authentication
        # is disabled
        environment_variables_override: Dict[str, str] = {}
        pre_build_commands = []
        if not self.config.implicit_container_registry_auth:
            credentials = container_registry.credentials
            if credentials:
                environment_variables_override = {
                    "CONTAINER_REGISTRY_USERNAME": credentials[0],
                    "CONTAINER_REGISTRY_PASSWORD": credentials[1],
                }
                pre_build_commands = [
                    "echo Logging in to container registry",
                    'echo "$CONTAINER_REGISTRY_PASSWORD" | docker login --username "$CONTAINER_REGISTRY_USERNAME" --password-stdin '
                    f"{container_registry.config.uri}",
                ]
        elif container_registry.flavor == AWS_CONTAINER_REGISTRY_FLAVOR:
            pre_build_commands = [
                "echo Logging in to EKS",
                f"aws ecr get-login-password --region {self.code_build_client._client_config.region_name} | docker login --username AWS --password-stdin {container_registry.config.uri}",
            ]

        # Convert the docker_build_options dictionary to a list of strings
        docker_build_args = ""
        for key, value in docker_build_options.items():
            option = f"--{key}"
            if isinstance(value, list):
                for val in value:
                    docker_build_args += f"{option} {val} "
            elif value is not None and not isinstance(value, bool):
                docker_build_args += f"{option} {value} "
            elif value is not False:
                docker_build_args += f"{option} "

        pre_build_commands_str = "\n".join(
            [f"            - {command}" for command in pre_build_commands]
        )

        # Generate and use a unique tag for the Docker image. This is easier
        # than trying to parse the image digest from the Code Build logs.
        build_id = str(uuid4())
        # Replace the tag in the image name with the unique build ID
        repo_name = image_name.split(":")[0]
        alt_image_name = f"{repo_name}:{build_id}"

        buildspec = f"""
version: 0.2
phases:
    pre_build:
        commands:
{pre_build_commands_str}
    build:
        commands:
            - echo Build started on `date`
            - echo Building the Docker image...
            - docker build -t {image_name} . {docker_build_args}
            - echo Build completed on `date`
    post_build:
        commands:
            - echo Pushing the Docker image...
            - docker push {image_name}
            - docker tag {image_name} {alt_image_name}
            - docker push {alt_image_name}
            - echo Pushed the Docker image
artifacts:
    files:
        - '**/*'
"""

        if self.config.custom_env_vars:
            environment_variables_override.update(self.config.custom_env_vars)

        environment_variables_override_list = [
            {
                "name": key,
                "value": value,
                "type": "PLAINTEXT",
            }
            for key, value in environment_variables_override.items()
        ]

        # Override the build project with the parameters needed to run a
        # docker-in-docker build, as covered here: https://docs.aws.amazon.com/codebuild/latest/userguide/sample-docker-section.html
        response = self.code_build_client.start_build(
            projectName=self.config.code_build_project,
            environmentTypeOverride="LINUX_CONTAINER",
            imageOverride=self.config.build_image,
            computeTypeOverride=self.config.compute_type,
            privilegedModeOverride=False,
            sourceTypeOverride="S3",
            sourceLocationOverride=f"{bucket}/{object_path}",
            buildspecOverride=buildspec,
            environmentVariablesOverride=environment_variables_override_list,
            # no artifacts
            artifactsOverride={"type": "NO_ARTIFACTS"},
        )

        build_arn = response["build"]["arn"]

        # Parse the AWS region, account, codebuild project and build name from the ARN
        aws_region, aws_account, build = build_arn.split(":", maxsplit=5)[3:6]
        codebuild_project = build.split("/")[1].split(":")[0]

        logs_url = f"https://{aws_region}.console.aws.amazon.com/codesuite/codebuild/{aws_account}/projects/{codebuild_project}/{build}/log"
        logger.info(
            f"Running Code Build to build the Docker image. Cloud Build logs: `{logs_url}`",
        )

        # Wait for the build to complete
        code_build_id = response["build"]["id"]
        while True:
            build_status = self.code_build_client.batch_get_builds(
                ids=[code_build_id]
            )
            build = build_status["builds"][0]
            status = build["buildStatus"]
            if status in [
                "SUCCEEDED",
                "FAILED",
                "FAULT",
                "TIMED_OUT",
                "STOPPED",
            ]:
                break
            time.sleep(10)

        if status != "SUCCEEDED":
            raise RuntimeError(
                f"The Code Build run to build the Docker image has failed. More "
                f"information can be found in the Cloud Build logs: {logs_url}."
            )

        logger.info(
            f"The Docker image has been built successfully. More information can "
            f"be found in the Cloud Build logs: `{logs_url}`."
        )

        return alt_image_name
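As a follow-up, a standalone sketch of the docker_build_options conversion performed in build() above: list values repeat the flag once per entry, non-boolean scalars append their value, and True emits a bare flag (the options dict below is a placeholder):

docker_build_options = {
    "build-arg": ["FOO=bar", "BAZ=qux"],
    "platform": "linux/amd64",
    "pull": True,
}

docker_build_args = ""
for key, value in docker_build_options.items():
    option = f"--{key}"
    if isinstance(value, list):
        # Repeat the flag for each list entry.
        for val in value:
            docker_build_args += f"{option} {val} "
    elif value is not None and not isinstance(value, bool):
        # Append scalar values after the flag.
        docker_build_args += f"{option} {value} "
    elif value is not False:
        # `True` emits a bare flag; `False` is dropped.
        docker_build_args += f"{option} "

print(docker_build_args)
# prints: --build-arg FOO=bar --build-arg BAZ=qux --platform linux/amd64 --pull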
Modules
aws_image_builder

AWS Code Build image builder implementation.

Classes
AWSImageBuilder(name: str, id: UUID, config: StackComponentConfig, flavor: str, type: StackComponentType, user: Optional[UUID], created: datetime, updated: datetime, labels: Optional[Dict[str, Any]] = None, connector_requirements: Optional[ServiceConnectorRequirements] = None, connector: Optional[UUID] = None, connector_resource_id: Optional[str] = None, *args: Any, **kwargs: Any)

Bases: BaseImageBuilder

AWS Code Build image builder implementation.

Source code in src/zenml/stack/stack_component.py
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
def __init__(
    self,
    name: str,
    id: UUID,
    config: StackComponentConfig,
    flavor: str,
    type: StackComponentType,
    user: Optional[UUID],
    created: datetime,
    updated: datetime,
    labels: Optional[Dict[str, Any]] = None,
    connector_requirements: Optional[ServiceConnectorRequirements] = None,
    connector: Optional[UUID] = None,
    connector_resource_id: Optional[str] = None,
    *args: Any,
    **kwargs: Any,
):
    """Initializes a StackComponent.

    Args:
        name: The name of the component.
        id: The unique ID of the component.
        config: The config of the component.
        flavor: The flavor of the component.
        type: The type of the component.
        user: The ID of the user who created the component.
        created: The creation time of the component.
        updated: The last update time of the component.
        labels: The labels of the component.
        connector_requirements: The requirements for the connector.
        connector: The ID of a connector linked to the component.
        connector_resource_id: The custom resource ID to access through
            the connector.
        *args: Additional positional arguments.
        **kwargs: Additional keyword arguments.

    Raises:
        ValueError: If a secret reference is passed as name.
    """
    if secret_utils.is_secret_reference(name):
        raise ValueError(
            "Passing the `name` attribute of a stack component as a "
            "secret reference is not allowed."
        )

    self.id = id
    self.name = name
    self._config = config
    self.flavor = flavor
    self.type = type
    self.user = user
    self.created = created
    self.updated = updated
    self.labels = labels
    self.connector_requirements = connector_requirements
    self.connector = connector
    self.connector_resource_id = connector_resource_id
    self._connector_instance: Optional[ServiceConnector] = None
Attributes
code_build_client: Any property

The authenticated AWS Code Build client to use for interacting with AWS services.

Returns:

Type Description
Any

The authenticated AWS Code Build client.

Raises:

Type Description
RuntimeError

If the AWS Code Build client cannot be created.

config: AWSImageBuilderConfig property

The stack component configuration.

Returns:

Type Description
AWSImageBuilderConfig

The configuration.

is_building_locally: bool property

Whether the image builder builds the images on the client machine.

Returns:

Type Description
bool

True if the image builder builds locally, False otherwise.

validator: Optional[StackValidator] property

Validates the stack for the AWS Code Build Image Builder.

The AWS Code Build Image Builder requires a container registry to push the image to and an S3 Artifact Store to upload the build context, so AWS Code Build can access it.

Returns:

Type Description
Optional[StackValidator]

Stack validator.

Functions
build(image_name: str, build_context: BuildContext, docker_build_options: Dict[str, Any], container_registry: Optional[BaseContainerRegistry] = None) -> str

Builds and pushes a Docker image.

Parameters:

Name Type Description Default
image_name str

Name of the image to build and push.

required
build_context BuildContext

The build context to use for the image.

required
docker_build_options Dict[str, Any]

Docker build options.

required
container_registry Optional[BaseContainerRegistry]

Optional container registry to push to.

None

Returns:

Type Description
str

The Docker image name with digest.

Raises:

Type Description
RuntimeError

If no container registry is passed.

RuntimeError

If the Cloud Build build fails.

Source code in src/zenml/integrations/aws/image_builders/aws_image_builder.py
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
    def build(
        self,
        image_name: str,
        build_context: "BuildContext",
        docker_build_options: Dict[str, Any],
        container_registry: Optional["BaseContainerRegistry"] = None,
    ) -> str:
        """Builds and pushes a Docker image.

        Args:
            image_name: Name of the image to build and push.
            build_context: The build context to use for the image.
            docker_build_options: Docker build options.
            container_registry: Optional container registry to push to.

        Returns:
            The Docker image name with digest.

        Raises:
            RuntimeError: If no container registry is passed.
            RuntimeError: If the Cloud Build build fails.
        """
        if not container_registry:
            raise RuntimeError(
                "The AWS Image Builder requires a container registry to push "
                "the image to. Please provide one and try again."
            )

        logger.info("Using AWS Code Build to build image `%s`", image_name)
        cloud_build_context = self._upload_build_context(
            build_context=build_context,
            parent_path_directory_name=f"code-build-contexts/{str(self.id)}",
            archive_type=ArchiveType.ZIP,
        )

        url_parts = urlparse(cloud_build_context)
        bucket = url_parts.netloc
        object_path = url_parts.path.lstrip("/")
        logger.info(
            "Build context located in bucket `%s` and object path `%s`",
            bucket,
            object_path,
        )

        # Pass authentication credentials as environment variables, if
        # the container registry has credentials and if implicit authentication
        # is disabled
        environment_variables_override: Dict[str, str] = {}
        pre_build_commands = []
        if not self.config.implicit_container_registry_auth:
            credentials = container_registry.credentials
            if credentials:
                environment_variables_override = {
                    "CONTAINER_REGISTRY_USERNAME": credentials[0],
                    "CONTAINER_REGISTRY_PASSWORD": credentials[1],
                }
                pre_build_commands = [
                    "echo Logging in to container registry",
                    'echo "$CONTAINER_REGISTRY_PASSWORD" | docker login --username "$CONTAINER_REGISTRY_USERNAME" --password-stdin '
                    f"{container_registry.config.uri}",
                ]
        elif container_registry.flavor == AWS_CONTAINER_REGISTRY_FLAVOR:
            pre_build_commands = [
                "echo Logging in to EKS",
                f"aws ecr get-login-password --region {self.code_build_client._client_config.region_name} | docker login --username AWS --password-stdin {container_registry.config.uri}",
            ]

        # Convert the docker_build_options dictionary to a list of strings
        docker_build_args = ""
        for key, value in docker_build_options.items():
            option = f"--{key}"
            if isinstance(value, list):
                for val in value:
                    docker_build_args += f"{option} {val} "
            elif value is not None and not isinstance(value, bool):
                docker_build_args += f"{option} {value} "
            elif value is not False:
                docker_build_args += f"{option} "

        pre_build_commands_str = "\n".join(
            [f"            - {command}" for command in pre_build_commands]
        )

        # Generate and use a unique tag for the Docker image. This is easier
        # than trying to parse the image digest from the Code Build logs.
        build_id = str(uuid4())
        # Replace the tag in the image name with the unique build ID
        repo_name = image_name.split(":")[0]
        alt_image_name = f"{repo_name}:{build_id}"

        buildspec = f"""
version: 0.2
phases:
    pre_build:
        commands:
{pre_build_commands_str}
    build:
        commands:
            - echo Build started on `date`
            - echo Building the Docker image...
            - docker build -t {image_name} . {docker_build_args}
            - echo Build completed on `date`
    post_build:
        commands:
            - echo Pushing the Docker image...
            - docker push {image_name}
            - docker tag {image_name} {alt_image_name}
            - docker push {alt_image_name}
            - echo Pushed the Docker image
artifacts:
    files:
        - '**/*'
"""

        if self.config.custom_env_vars:
            environment_variables_override.update(self.config.custom_env_vars)

        environment_variables_override_list = [
            {
                "name": key,
                "value": value,
                "type": "PLAINTEXT",
            }
            for key, value in environment_variables_override.items()
        ]

        # Override the build project with the parameters needed to run a
        # docker-in-docker build, as covered here: https://docs.aws.amazon.com/codebuild/latest/userguide/sample-docker-section.html
        response = self.code_build_client.start_build(
            projectName=self.config.code_build_project,
            environmentTypeOverride="LINUX_CONTAINER",
            imageOverride=self.config.build_image,
            computeTypeOverride=self.config.compute_type,
            privilegedModeOverride=False,
            sourceTypeOverride="S3",
            sourceLocationOverride=f"{bucket}/{object_path}",
            buildspecOverride=buildspec,
            environmentVariablesOverride=environment_variables_override_list,
            # no artifacts
            artifactsOverride={"type": "NO_ARTIFACTS"},
        )

        build_arn = response["build"]["arn"]

        # Parse the AWS region, account, codebuild project and build name from the ARN
        aws_region, aws_account, build = build_arn.split(":", maxsplit=5)[3:6]
        codebuild_project = build.split("/")[1].split(":")[0]

        logs_url = f"https://{aws_region}.console.aws.amazon.com/codesuite/codebuild/{aws_account}/projects/{codebuild_project}/{build}/log"
        logger.info(
            f"Running Code Build to build the Docker image. Cloud Build logs: `{logs_url}`",
        )

        # Wait for the build to complete
        code_build_id = response["build"]["id"]
        while True:
            build_status = self.code_build_client.batch_get_builds(
                ids=[code_build_id]
            )
            build = build_status["builds"][0]
            status = build["buildStatus"]
            if status in [
                "SUCCEEDED",
                "FAILED",
                "FAULT",
                "TIMED_OUT",
                "STOPPED",
            ]:
                break
            time.sleep(10)

        if status != "SUCCEEDED":
            raise RuntimeError(
                f"The CodeBuild run to build the Docker image has failed. More "
                f"information can be found in the CodeBuild logs: {logs_url}."
            )

        logger.info(
            f"The Docker image has been built successfully. More information can "
            f"be found in the CodeBuild logs: `{logs_url}`."
        )

        return alt_image_name
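
For illustration, the unique-tag logic above amounts to the following (the image name below is hypothetical):

from uuid import uuid4

# Hypothetical ECR image name; only the repository part before the tag is kept.
image_name = "123456789012.dkr.ecr.us-east-1.amazonaws.com/zenml:latest"
repo_name = image_name.split(":")[0]
alt_image_name = f"{repo_name}:{uuid4()}"
# e.g. "123456789012.dkr.ecr.us-east-1.amazonaws.com/zenml:7f3a..." -- this is
# the extra tag pushed by the buildspec and returned by the method above.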
Functions

orchestrators

AWS Sagemaker orchestrator.
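
As a hedged sketch, the orchestrator can be registered programmatically roughly as follows (the component name and role ARN are illustrative, and Client.create_stack_component is assumed to accept these arguments):

from zenml.client import Client
from zenml.enums import StackComponentType

Client().create_stack_component(
    name="sagemaker-orchestrator",  # hypothetical component name
    flavor="sagemaker",
    component_type=StackComponentType.ORCHESTRATOR,
    configuration={
        # Hypothetical execution role; see the execution_role config used
        # by this orchestrator below.
        "execution_role": "arn:aws:iam::123456789012:role/zenml-sagemaker",
    },
)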

Classes
SagemakerOrchestrator(name: str, id: UUID, config: StackComponentConfig, flavor: str, type: StackComponentType, user: Optional[UUID], created: datetime, updated: datetime, labels: Optional[Dict[str, Any]] = None, connector_requirements: Optional[ServiceConnectorRequirements] = None, connector: Optional[UUID] = None, connector_resource_id: Optional[str] = None, *args: Any, **kwargs: Any)

Bases: ContainerizedOrchestrator

Orchestrator responsible for running pipelines on Sagemaker.

Source code in src/zenml/stack/stack_component.py
def __init__(
    self,
    name: str,
    id: UUID,
    config: StackComponentConfig,
    flavor: str,
    type: StackComponentType,
    user: Optional[UUID],
    created: datetime,
    updated: datetime,
    labels: Optional[Dict[str, Any]] = None,
    connector_requirements: Optional[ServiceConnectorRequirements] = None,
    connector: Optional[UUID] = None,
    connector_resource_id: Optional[str] = None,
    *args: Any,
    **kwargs: Any,
):
    """Initializes a StackComponent.

    Args:
        name: The name of the component.
        id: The unique ID of the component.
        config: The config of the component.
        flavor: The flavor of the component.
        type: The type of the component.
        user: The ID of the user who created the component.
        created: The creation time of the component.
        updated: The last update time of the component.
        labels: The labels of the component.
        connector_requirements: The requirements for the connector.
        connector: The ID of a connector linked to the component.
        connector_resource_id: The custom resource ID to access through
            the connector.
        *args: Additional positional arguments.
        **kwargs: Additional keyword arguments.

    Raises:
        ValueError: If a secret reference is passed as name.
    """
    if secret_utils.is_secret_reference(name):
        raise ValueError(
            "Passing the `name` attribute of a stack component as a "
            "secret reference is not allowed."
        )

    self.id = id
    self.name = name
    self._config = config
    self.flavor = flavor
    self.type = type
    self.user = user
    self.created = created
    self.updated = updated
    self.labels = labels
    self.connector_requirements = connector_requirements
    self.connector = connector
    self.connector_resource_id = connector_resource_id
    self._connector_instance: Optional[ServiceConnector] = None
Attributes
config: SagemakerOrchestratorConfig property

Returns the SagemakerOrchestratorConfig config.

Returns:

Type Description
SagemakerOrchestratorConfig

The configuration.

settings_class: Optional[Type[BaseSettings]] property

Settings class for the Sagemaker orchestrator.

Returns:

Type Description
Optional[Type[BaseSettings]]

The settings class.
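
For example, a minimal sketch of attaching these settings to a step (the field values are illustrative; the import path is assumed from the AWS integration's flavors module):

from zenml import step
from zenml.integrations.aws.flavors.sagemaker_orchestrator_flavor import (
    SagemakerOrchestratorSettings,
)

sagemaker_settings = SagemakerOrchestratorSettings(
    instance_type="ml.m5.xlarge",  # hypothetical instance type
    use_training_step=False,  # run the step as a Processor instead of an Estimator
    environment={"MY_VAR": "value"},
)

@step(settings={"orchestrator.sagemaker": sagemaker_settings})
def my_step() -> None:
    ...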

validator: Optional[StackValidator] property

Validates the stack.

In the remote case, this checks that the stack contains a container registry, an image builder, and only remote components.

Returns:

Type Description
Optional[StackValidator]

A StackValidator instance.

Functions
compute_metadata(execution_arn: str, settings: SagemakerOrchestratorSettings) -> Iterator[Dict[str, MetadataType]]

Generate run metadata based on the generated Sagemaker Execution.

Parameters:

Name Type Description Default
execution_arn str

The ARN of the pipeline execution.

required
settings SagemakerOrchestratorSettings

The Sagemaker orchestrator settings.

required

Yields:

Type Description
Dict[str, MetadataType]

A dictionary of metadata related to the pipeline run.

Source code in src/zenml/integrations/aws/orchestrators/sagemaker_orchestrator.py
def compute_metadata(
    self,
    execution_arn: str,
    settings: SagemakerOrchestratorSettings,
) -> Iterator[Dict[str, MetadataType]]:
    """Generate run metadata based on the generated Sagemaker Execution.

    Args:
        execution_arn: The ARN of the pipeline execution.
        settings: The Sagemaker orchestrator settings.

    Yields:
        A dictionary of metadata related to the pipeline run.
    """
    # Orchestrator Run ID
    metadata: Dict[str, MetadataType] = {
        "pipeline_execution_arn": execution_arn,
        METADATA_ORCHESTRATOR_RUN_ID: execution_arn,
    }

    # URL to the Sagemaker's pipeline view
    if orchestrator_url := self._compute_orchestrator_url(
        execution_arn=execution_arn
    ):
        metadata[METADATA_ORCHESTRATOR_URL] = Uri(orchestrator_url)

    # URL to the corresponding CloudWatch page
    if logs_url := self._compute_orchestrator_logs_url(
        execution_arn=execution_arn, settings=settings
    ):
        metadata[METADATA_ORCHESTRATOR_LOGS_URL] = Uri(logs_url)

    yield metadata
fetch_status(run: PipelineRunResponse) -> ExecutionStatus

Refreshes the status of a specific pipeline run.

Parameters:

Name Type Description Default
run PipelineRunResponse

The run that was executed by this orchestrator.

required

Returns:

Type Description
ExecutionStatus

the actual status of the pipeline job.

Raises:

Type Description
AssertionError

If the run was not executed by this orchestrator.

ValueError

If it fetches an unknown state or if we cannot fetch the orchestrator run ID.

Source code in src/zenml/integrations/aws/orchestrators/sagemaker_orchestrator.py
def fetch_status(self, run: "PipelineRunResponse") -> ExecutionStatus:
    """Refreshes the status of a specific pipeline run.

    Args:
        run: The run that was executed by this orchestrator.

    Returns:
        the actual status of the pipeline job.

    Raises:
        AssertionError: If the run was not executed by this orchestrator.
        ValueError: If it fetches an unknown state or if we cannot fetch
            the orchestrator run ID.
    """
    # Make sure that the stack exists and is accessible
    if run.stack is None:
        raise ValueError(
            "The stack that the run was executed on is not available "
            "anymore."
        )

    # Make sure that the run belongs to this orchestrator
    assert (
        self.id
        == run.stack.components[StackComponentType.ORCHESTRATOR][0].id
    )

    # Initialize the Sagemaker client
    session = self._get_sagemaker_session()
    sagemaker_client = session.sagemaker_client

    # Fetch the status of the _PipelineExecution
    if METADATA_ORCHESTRATOR_RUN_ID in run.run_metadata:
        run_id = run.run_metadata[METADATA_ORCHESTRATOR_RUN_ID]
    elif run.orchestrator_run_id is not None:
        run_id = run.orchestrator_run_id
    else:
        raise ValueError(
            "Can not find the orchestrator run ID, thus can not fetch "
            "the status."
        )
    status = sagemaker_client.describe_pipeline_execution(
        PipelineExecutionArn=run_id
    )["PipelineExecutionStatus"]

    # Map the possible SageMaker statuses to ZenML ExecutionStatus. Potential values:
    # https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_DescribePipelineExecution.html
    if status in ["Executing", "Stopping"]:
        return ExecutionStatus.RUNNING
    elif status in ["Stopped", "Failed"]:
        return ExecutionStatus.FAILED
    elif status in ["Succeeded"]:
        return ExecutionStatus.COMPLETED
    else:
        raise ValueError("Unknown status for the pipeline execution.")
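
The status mapping implemented above can be summarized as a lookup table (assuming the set of PipelineExecutionStatus values handled in the code is complete):

from zenml.enums import ExecutionStatus

SAGEMAKER_TO_ZENML_STATUS = {
    "Executing": ExecutionStatus.RUNNING,
    "Stopping": ExecutionStatus.RUNNING,
    "Stopped": ExecutionStatus.FAILED,
    "Failed": ExecutionStatus.FAILED,
    "Succeeded": ExecutionStatus.COMPLETED,
}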
generate_schedule_metadata(schedule_arn: str) -> Dict[str, str] staticmethod

Attaches metadata to the ZenML Schedules.

Parameters:

Name Type Description Default
schedule_arn str

The trigger ARN that is generated on the AWS side.

required

Returns:

Type Description
Dict[str, str]

a dictionary containing metadata related to the schedule.

Source code in src/zenml/integrations/aws/orchestrators/sagemaker_orchestrator.py
@staticmethod
def generate_schedule_metadata(schedule_arn: str) -> Dict[str, str]:
    """Attaches metadata to the ZenML Schedules.

    Args:
        schedule_arn: The trigger ARN that is generated on the AWS side.

    Returns:
        a dictionary containing metadata related to the schedule.
    """
    region, name = dissect_schedule_arn(schedule_arn=schedule_arn)

    return {
        "trigger_url": (
            f"https://{region}.console.aws.amazon.com/scheduler/home"
            f"?region={region}#schedules/{name}"
        ),
    }
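
For illustration, with a hypothetical region and schedule name the generated trigger_url looks like this:

region, name = "eu-west-1", "zenml-schedule"  # hypothetical values
trigger_url = (
    f"https://{region}.console.aws.amazon.com/scheduler/home"
    f"?region={region}#schedules/{name}"
)
# -> https://eu-west-1.console.aws.amazon.com/scheduler/home?region=eu-west-1#schedules/zenml-schedule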
get_orchestrator_run_id() -> str

Returns the run id of the active orchestrator run.

Important: This needs to be a unique ID and return the same value for all steps of a pipeline run.

Returns:

Type Description
str

The orchestrator run id.

Raises:

Type Description
RuntimeError

If the run id cannot be read from the environment.

Source code in src/zenml/integrations/aws/orchestrators/sagemaker_orchestrator.py
def get_orchestrator_run_id(self) -> str:
    """Returns the run id of the active orchestrator run.

    Important: This needs to be a unique ID and return the same value for
    all steps of a pipeline run.

    Returns:
        The orchestrator run id.

    Raises:
        RuntimeError: If the run id cannot be read from the environment.
    """
    try:
        return os.environ[ENV_ZENML_SAGEMAKER_RUN_ID]
    except KeyError:
        raise RuntimeError(
            "Unable to read run id from environment variable "
            f"{ENV_ZENML_SAGEMAKER_RUN_ID}."
        )
get_pipeline_run_metadata(run_id: UUID) -> Dict[str, MetadataType]

Get general component-specific metadata for a pipeline run.

Parameters:

Name Type Description Default
run_id UUID

The ID of the pipeline run.

required

Returns:

Type Description
Dict[str, MetadataType]

A dictionary of metadata.

Source code in src/zenml/integrations/aws/orchestrators/sagemaker_orchestrator.py
def get_pipeline_run_metadata(
    self, run_id: UUID
) -> Dict[str, "MetadataType"]:
    """Get general component-specific metadata for a pipeline run.

    Args:
        run_id: The ID of the pipeline run.

    Returns:
        A dictionary of metadata.
    """
    execution_arn = os.environ[ENV_ZENML_SAGEMAKER_RUN_ID]

    run_metadata: Dict[str, "MetadataType"] = {}

    settings = cast(
        SagemakerOrchestratorSettings,
        self.get_settings(Client().get_pipeline_run(run_id)),
    )

    for metadata in self.compute_metadata(
        execution_arn=execution_arn,
        settings=settings,
    ):
        run_metadata.update(metadata)

    return run_metadata
prepare_or_run_pipeline(deployment: PipelineDeploymentResponse, stack: Stack, environment: Dict[str, str], placeholder_run: Optional[PipelineRunResponse] = None) -> Iterator[Dict[str, MetadataType]]

Prepares or runs a pipeline on Sagemaker.

Parameters:

Name Type Description Default
deployment PipelineDeploymentResponse

The deployment to prepare or run.

required
stack Stack

The stack to run on.

required
environment Dict[str, str]

Environment variables to set in the orchestration environment.

required
placeholder_run Optional[PipelineRunResponse]

An optional placeholder run for the deployment.

None

Raises:

Type Description
RuntimeError

If there is an error creating or scheduling the pipeline.

TypeError

If the network_config passed is not compatible with the AWS SageMaker NetworkConfig class.

ValueError

If the schedule is not valid.

Yields:

Type Description
Dict[str, MetadataType]

A dictionary of metadata related to the pipeline run.

Source code in src/zenml/integrations/aws/orchestrators/sagemaker_orchestrator.py
def prepare_or_run_pipeline(
    self,
    deployment: "PipelineDeploymentResponse",
    stack: "Stack",
    environment: Dict[str, str],
    placeholder_run: Optional["PipelineRunResponse"] = None,
) -> Iterator[Dict[str, MetadataType]]:
    """Prepares or runs a pipeline on Sagemaker.

    Args:
        deployment: The deployment to prepare or run.
        stack: The stack to run on.
        environment: Environment variables to set in the orchestration
            environment.
        placeholder_run: An optional placeholder run for the deployment.

    Raises:
        RuntimeError: If there is an error creating or scheduling the
            pipeline.
        TypeError: If the network_config passed is not compatible with the
            AWS SageMaker NetworkConfig class.
        ValueError: If the schedule is not valid.

    Yields:
        A dictionary of metadata related to the pipeline run.
    """
    # sagemaker requires pipelineName to use alphanum and hyphens only
    unsanitized_orchestrator_run_name = get_orchestrator_run_name(
        pipeline_name=deployment.pipeline_configuration.name
    )
    # replace all non-alphanum and non-hyphens with hyphens
    orchestrator_run_name = re.sub(
        r"[^a-zA-Z0-9\-]", "-", unsanitized_orchestrator_run_name
    )

    session = self._get_sagemaker_session()

    # Sagemaker does not allow environment variables longer than 256
    # characters to be passed to Processor steps. If an environment variable
    # is longer than 256 characters, we split it into multiple environment
    # variables (chunks) and re-construct it on the other side using the
    # custom entrypoint configuration.
    split_environment_variables(
        size_limit=SAGEMAKER_PROCESSOR_STEP_ENV_VAR_SIZE_LIMIT,
        env=environment,
    )

    sagemaker_steps = []
    for step_name, step in deployment.step_configurations.items():
        image = self.get_image(deployment=deployment, step_name=step_name)
        command = SagemakerEntrypointConfiguration.get_entrypoint_command()
        arguments = (
            SagemakerEntrypointConfiguration.get_entrypoint_arguments(
                step_name=step_name, deployment_id=deployment.id
            )
        )
        entrypoint = command + arguments

        step_settings = cast(
            SagemakerOrchestratorSettings, self.get_settings(step)
        )

        environment[ENV_ZENML_SAGEMAKER_RUN_ID] = (
            ExecutionVariables.PIPELINE_EXECUTION_ARN
        )

        if step_settings.environment:
            step_environment = step_settings.environment.copy()
            # Sagemaker does not allow environment variables longer than 256
            # characters to be passed to Processor steps. If an environment variable
            # is longer than 256 characters, we split it into multiple environment
            # variables (chunks) and re-construct it on the other side using the
            # custom entrypoint configuration.
            split_environment_variables(
                size_limit=SAGEMAKER_PROCESSOR_STEP_ENV_VAR_SIZE_LIMIT,
                env=step_environment,
            )
            environment.update(step_environment)

        use_training_step = (
            step_settings.use_training_step
            if step_settings.use_training_step is not None
            else (
                self.config.use_training_step
                if self.config.use_training_step is not None
                else True
            )
        )

        # Retrieve Executor arguments provided in the Step settings.
        if use_training_step:
            args_for_step_executor = step_settings.estimator_args or {}
            args_for_step_executor.setdefault(
                "volume_size", step_settings.volume_size_in_gb
            )
            args_for_step_executor.setdefault(
                "max_run", step_settings.max_runtime_in_seconds
            )
        else:
            args_for_step_executor = step_settings.processor_args or {}
            args_for_step_executor.setdefault(
                "volume_size_in_gb", step_settings.volume_size_in_gb
            )
            args_for_step_executor.setdefault(
                "max_runtime_in_seconds",
                step_settings.max_runtime_in_seconds,
            )

        # Set default values from configured orchestrator Component to
        # arguments to be used when they are not present in processor_args.
        args_for_step_executor.setdefault(
            "role",
            step_settings.execution_role or self.config.execution_role,
        )

        tags = step_settings.tags
        args_for_step_executor.setdefault(
            "tags",
            (
                [
                    {"Key": key, "Value": value}
                    for key, value in tags.items()
                ]
                if tags
                else None
            ),
        )

        args_for_step_executor.setdefault(
            "instance_type", step_settings.instance_type
        )

        # Set values that cannot be overwritten
        args_for_step_executor["image_uri"] = image
        args_for_step_executor["instance_count"] = 1
        args_for_step_executor["sagemaker_session"] = session
        args_for_step_executor["base_job_name"] = orchestrator_run_name

        # Convert network_config to sagemaker.network.NetworkConfig if
        # present
        network_config = args_for_step_executor.get("network_config")

        if network_config and isinstance(network_config, dict):
            try:
                args_for_step_executor["network_config"] = NetworkConfig(
                    **network_config
                )
            except TypeError:
                # If the network_config passed is not compatible with the
                # NetworkConfig class, raise a more informative error.
                raise TypeError(
                    "Expected a sagemaker.network.NetworkConfig "
                    "compatible object for the network_config argument, "
                    "but the network_config processor argument is invalid."
                    "See https://sagemaker.readthedocs.io/en/stable/api/utility/network.html#sagemaker.network.NetworkConfig "
                    "for more information about the NetworkConfig class."
                )

        # Construct S3 inputs to container for step
        inputs = None

        if step_settings.input_data_s3_uri is None:
            pass
        elif isinstance(step_settings.input_data_s3_uri, str):
            inputs = [
                ProcessingInput(
                    source=step_settings.input_data_s3_uri,
                    destination="/opt/ml/processing/input/data",
                    s3_input_mode=step_settings.input_data_s3_mode,
                )
            ]
        elif isinstance(step_settings.input_data_s3_uri, dict):
            inputs = []
            for channel, s3_uri in step_settings.input_data_s3_uri.items():
                inputs.append(
                    ProcessingInput(
                        source=s3_uri,
                        destination=f"/opt/ml/processing/input/data/{channel}",
                        s3_input_mode=step_settings.input_data_s3_mode,
                    )
                )

        # Construct S3 outputs from container for step
        outputs = None
        output_path = None

        if step_settings.output_data_s3_uri is None:
            pass
        elif isinstance(step_settings.output_data_s3_uri, str):
            if use_training_step:
                output_path = step_settings.output_data_s3_uri
            else:
                outputs = [
                    ProcessingOutput(
                        source="/opt/ml/processing/output/data",
                        destination=step_settings.output_data_s3_uri,
                        s3_upload_mode=step_settings.output_data_s3_mode,
                    )
                ]
        elif isinstance(step_settings.output_data_s3_uri, dict):
            outputs = []
            for (
                channel,
                s3_uri,
            ) in step_settings.output_data_s3_uri.items():
                outputs.append(
                    ProcessingOutput(
                        source=f"/opt/ml/processing/output/data/{channel}",
                        destination=s3_uri,
                        s3_upload_mode=step_settings.output_data_s3_mode,
                    )
                )

        # Convert environment to a dict of strings
        environment = {
            key: str(value)
            if not isinstance(value, ExecutionVariable)
            else value
            for key, value in environment.items()
        }

        if use_training_step:
            # Create Estimator and TrainingStep
            estimator = sagemaker.estimator.Estimator(
                keep_alive_period_in_seconds=step_settings.keep_alive_period_in_seconds,
                output_path=output_path,
                environment=environment,
                container_entry_point=entrypoint,
                **args_for_step_executor,
            )
            sagemaker_step = TrainingStep(
                name=step_name,
                depends_on=step.spec.upstream_steps,
                inputs=inputs,
                estimator=estimator,
            )
        else:
            # Create Processor and ProcessingStep
            processor = sagemaker.processing.Processor(
                entrypoint=entrypoint,
                env=environment,
                **args_for_step_executor,
            )

            sagemaker_step = ProcessingStep(
                name=step_name,
                processor=processor,
                depends_on=step.spec.upstream_steps,
                inputs=inputs,
                outputs=outputs,
            )

        sagemaker_steps.append(sagemaker_step)

    # Create the pipeline
    pipeline = Pipeline(
        name=orchestrator_run_name,
        steps=sagemaker_steps,
        sagemaker_session=session,
    )

    settings = cast(
        SagemakerOrchestratorSettings, self.get_settings(deployment)
    )

    pipeline.create(
        role_arn=self.config.execution_role,
        tags=[
            {"Key": key, "Value": value}
            for key, value in settings.pipeline_tags.items()
        ]
        if settings.pipeline_tags
        else None,
    )

    # Handle scheduling if specified
    if deployment.schedule:
        if settings.synchronous:
            logger.warning(
                "The 'synchronous' setting is ignored for scheduled "
                "pipelines since they run independently of the "
                "deployment process."
            )

        schedule_name = orchestrator_run_name
        next_execution = None
        start_date = (
            to_utc_timezone(deployment.schedule.start_time)
            if deployment.schedule.start_time
            else None
        )

        # Create PipelineSchedule based on schedule type
        if deployment.schedule.cron_expression:
            cron_exp = self._validate_cron_expression(
                deployment.schedule.cron_expression
            )
            schedule = PipelineSchedule(
                name=schedule_name,
                cron=cron_exp,
                start_date=start_date,
                enabled=True,
            )
        elif deployment.schedule.interval_second:
            # This is necessary because SageMaker's PipelineSchedule rate
            # expressions require minutes as the minimum time unit.
            # Even if a user specifies an interval of less than 60 seconds,
            # it will be rounded up to 1 minute.
            minutes = max(
                1,
                int(
                    deployment.schedule.interval_second.total_seconds()
                    / 60
                ),
            )
            schedule = PipelineSchedule(
                name=schedule_name,
                rate=(minutes, "minutes"),
                start_date=start_date,
                enabled=True,
            )
            next_execution = (
                deployment.schedule.start_time or utc_now_tz_aware()
            ) + deployment.schedule.interval_second
        else:
            # One-time schedule
            execution_time = (
                deployment.schedule.run_once_start_time
                or deployment.schedule.start_time
            )
            if not execution_time:
                raise ValueError(
                    "A start time must be specified for one-time "
                    "schedule execution"
                )
            schedule = PipelineSchedule(
                name=schedule_name,
                at=to_utc_timezone(execution_time),
                enabled=True,
            )
            next_execution = execution_time

        # Get the current role ARN if not explicitly configured
        if self.config.scheduler_role is None:
            logger.info(
                "No scheduler_role configured. Trying to extract it from "
                "the client side authentication."
            )
            sts = session.boto_session.client("sts")
            try:
                scheduler_role_arn = sts.get_caller_identity()["Arn"]
                # If this is a user ARN, try to get the role ARN
                if ":user/" in scheduler_role_arn:
                    logger.warning(
                        f"Using IAM user credentials "
                        f"({scheduler_role_arn}). For production "
                        "environments, it's recommended to use IAM roles "
                        "instead."
                    )
                # If this is an assumed role, extract the role ARN
                elif ":assumed-role/" in scheduler_role_arn:
                    # Convert assumed-role ARN format to role ARN format
                    # From: arn:aws:sts::123456789012:assumed-role/role-name/session-name
                    # To: arn:aws:iam::123456789012:role/role-name
                    scheduler_role_arn = re.sub(
                        r"arn:aws:sts::(\d+):assumed-role/([^/]+)/.*",
                        r"arn:aws:iam::\1:role/\2",
                        scheduler_role_arn,
                    )
                elif ":role/" not in scheduler_role_arn:
                    raise RuntimeError(
                        f"Unexpected credential type "
                        f"({scheduler_role_arn}). Please use IAM "
                        f"roles for SageMaker pipeline scheduling."
                    )
                else:
                    raise RuntimeError(
                        "The ARN of the caller identity "
                        f"`{scheduler_role_arn}` does not "
                        "include a user or a proper role."
                    )
            except Exception:
                raise RuntimeError(
                    "Failed to get the current role ARN. This means that "
                    "your client-side credentials are not configured "
                    "correctly to schedule SageMaker pipelines. For more "
                    "information, please check: "
                    "https://docs.zenml.io/stack-components/orchestrators/sagemaker#required-iam-permissions-for-schedules"
                )
        else:
            scheduler_role_arn = self.config.scheduler_role

        # Attach schedule to pipeline
        triggers = pipeline.put_triggers(
            triggers=[schedule],
            role_arn=scheduler_role_arn,
        )
        logger.info(f"The schedule ARN is: {triggers[0]}")

        try:
            from zenml.models import RunMetadataResource

            schedule_metadata = self.generate_schedule_metadata(
                schedule_arn=triggers[0]
            )

            Client().create_run_metadata(
                metadata=schedule_metadata,  # type: ignore[arg-type]
                resources=[
                    RunMetadataResource(
                        id=deployment.schedule.id,
                        type=MetadataResourceTypes.SCHEDULE,
                    )
                ],
            )
        except Exception as e:
            logger.debug(
                "There was an error attaching metadata to the "
                f"schedule: {e}"
            )

        logger.info(
            f"Successfully scheduled pipeline with name: {schedule_name}\n"
            + (
                f"First execution will occur at: "
                f"{next_execution.strftime('%Y-%m-%d %H:%M:%S UTC')}"
                if next_execution
                else f"Using cron expression: "
                f"{deployment.schedule.cron_expression}"
            )
            + (
                f" (and every {minutes} minutes after)"
                if deployment.schedule.interval_second
                else ""
            )
        )
        logger.info(
            "\n\nTo cancel the schedule, you can execute the "
            "following command:\n"
        )
        logger.info(
            f"`aws scheduler delete-schedule --name {schedule_name}`"
        )
    else:
        # Execute the pipeline immediately if no schedule is specified
        execution = pipeline.start()
        logger.warning(
            "Steps can take 5-15 minutes to start running "
            "when using the Sagemaker Orchestrator."
        )

        # Yield metadata based on the generated execution object
        yield from self.compute_metadata(
            execution_arn=execution.arn, settings=settings
        )

        # mainly for testing purposes, we wait for the pipeline to finish
        if settings.synchronous:
            logger.info(
                "Executing synchronously. Waiting for pipeline to "
                "finish... \n"
                "At this point you can `Ctrl-C` out without cancelling the "
                "execution."
            )
            try:
                execution.wait(
                    delay=POLLING_DELAY, max_attempts=MAX_POLLING_ATTEMPTS
                )
                logger.info("Pipeline completed successfully.")
            except WaiterError:
                raise RuntimeError(
                    "Timed out while waiting for pipeline execution to "
                    "finish. For long-running pipelines we recommend "
                    "configuring your orchestrator for asynchronous "
                    "execution. The following command does this for you: \n"
                    f"`zenml orchestrator update {self.name} "
                    f"--synchronous=False`"
                )
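
The environment-variable splitting used above can be illustrated with a minimal sketch (this demonstrates the chunking idea only; it is not the actual split_environment_variables implementation):

def split_env(env: dict, size_limit: int = 256) -> None:
    """Split values longer than size_limit into numbered chunk variables."""
    for key in list(env):
        value = env[key]
        if len(value) > size_limit:
            chunks = [
                value[i : i + size_limit]
                for i in range(0, len(value), size_limit)
            ]
            del env[key]
            for index, chunk in enumerate(chunks):
                env[f"{key}_CHUNK_{index}"] = chunk

env = {"LONG": "x" * 600}
split_env(env)
assert set(env) == {"LONG_CHUNK_0", "LONG_CHUNK_1", "LONG_CHUNK_2"}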
Modules
sagemaker_orchestrator

Implementation of the SageMaker orchestrator.

Classes
SagemakerOrchestrator(name: str, id: UUID, config: StackComponentConfig, flavor: str, type: StackComponentType, user: Optional[UUID], created: datetime, updated: datetime, labels: Optional[Dict[str, Any]] = None, connector_requirements: Optional[ServiceConnectorRequirements] = None, connector: Optional[UUID] = None, connector_resource_id: Optional[str] = None, *args: Any, **kwargs: Any)

Bases: ContainerizedOrchestrator

Orchestrator responsible for running pipelines on Sagemaker.

                "No scheduler_role configured. Trying to extract it from "
                "the client side authentication."
            )
            sts = session.boto_session.client("sts")
            try:
                scheduler_role_arn = sts.get_caller_identity()["Arn"]
                # If this is a user ARN, try to get the role ARN
                if ":user/" in scheduler_role_arn:
                    logger.warning(
                        f"Using IAM user credentials "
                        f"({scheduler_role_arn}). For production "
                        "environments, it's recommended to use IAM roles "
                        "instead."
                    )
                # If this is an assumed role, extract the role ARN
                elif ":assumed-role/" in scheduler_role_arn:
                    # Convert assumed-role ARN format to role ARN format
                    # From: arn:aws:sts::123456789012:assumed-role/role-name/session-name
                    # To: arn:aws:iam::123456789012:role/role-name
                    scheduler_role_arn = re.sub(
                        r"arn:aws:sts::(\d+):assumed-role/([^/]+)/.*",
                        r"arn:aws:iam::\1:role/\2",
                        scheduler_role_arn,
                    )
                elif ":role/" not in scheduler_role_arn:
                    raise RuntimeError(
                        f"Unexpected credential type "
                        f"({scheduler_role_arn}). Please use IAM "
                        f"roles for SageMaker pipeline scheduling."
                    )
                else:
                    raise RuntimeError(
                        "The ARN of the caller identity "
                        f"`{scheduler_role_arn}` does not "
                        "include a user or a proper role."
                    )
            except Exception as e:
                raise RuntimeError(
                    "Failed to get the current role ARN. This means that "
                    "your client-side credentials are not configured "
                    "correctly to schedule SageMaker pipelines. For more "
                    "information, please check: "
                    "https://docs.zenml.io/stack-components/orchestrators/sagemaker#required-iam-permissions-for-schedules"
                ) from e
        else:
            scheduler_role_arn = self.config.scheduler_role

        # Attach schedule to pipeline
        triggers = pipeline.put_triggers(
            triggers=[schedule],
            role_arn=scheduler_role_arn,
        )
        logger.info(f"The schedule ARN is: {triggers[0]}")

        try:
            from zenml.models import RunMetadataResource

            schedule_metadata = self.generate_schedule_metadata(
                schedule_arn=triggers[0]
            )

            Client().create_run_metadata(
                metadata=schedule_metadata,  # type: ignore[arg-type]
                resources=[
                    RunMetadataResource(
                        id=deployment.schedule.id,
                        type=MetadataResourceTypes.SCHEDULE,
                    )
                ],
            )
        except Exception as e:
            logger.debug(
                "There was an error attaching metadata to the "
                f"schedule: {e}"
            )

        logger.info(
            f"Successfully scheduled pipeline with name: {schedule_name}\n"
            + (
                f"First execution will occur at: "
                f"{next_execution.strftime('%Y-%m-%d %H:%M:%S UTC')}"
                if next_execution
                else f"Using cron expression: "
                f"{deployment.schedule.cron_expression}"
            )
            + (
                f" (and every {minutes} minutes after)"
                if deployment.schedule.interval_second
                else ""
            )
        )
        logger.info(
            "\n\nIn order to cancel the schedule, you can use execute "
            "the following command:\n"
        )
        logger.info(
            f"`aws scheduler delete-schedule --name {schedule_name}`"
        )
    else:
        # Execute the pipeline immediately if no schedule is specified
        execution = pipeline.start()
        logger.warning(
            "Steps can take 5-15 minutes to start running "
            "when using the Sagemaker Orchestrator."
        )

        # Yield metadata based on the generated execution object
        yield from self.compute_metadata(
            execution_arn=execution.arn, settings=settings
        )

        # mainly for testing purposes, we wait for the pipeline to finish
        if settings.synchronous:
            logger.info(
                "Executing synchronously. Waiting for pipeline to "
                "finish... \n"
                "At this point you can `Ctrl-C` out without cancelling the "
                "execution."
            )
            try:
                execution.wait(
                    delay=POLLING_DELAY, max_attempts=MAX_POLLING_ATTEMPTS
                )
                logger.info("Pipeline completed successfully.")
            except WaiterError:
                raise RuntimeError(
                    "Timed out while waiting for pipeline execution to "
                    "finish. For long-running pipelines we recommend "
                    "configuring your orchestrator for asynchronous "
                    "execution. The following command does this for you: \n"
                    f"`zenml orchestrator update {self.name} "
                    f"--synchronous=False`"
                )
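For orientation, here is a minimal client-side sketch of how a schedule ends up on this code path: attaching a `Schedule` to a pipeline is what populates `deployment.schedule` above. The pipeline name, the cron expression format, and the `with_options` usage are illustrative assumptions, not taken from this module:

from zenml import pipeline
from zenml.config.schedule import Schedule

@pipeline
def my_pipeline() -> None:
    ...

# Attach a cron schedule; the orchestrator validates the expression
# and translates it into an EventBridge Scheduler trigger via
# PipelineSchedule(cron=...).
my_pipeline.with_options(
    schedule=Schedule(cron_expression="0 6 * * ? *")
)()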
Functions
dissect_pipeline_execution_arn(pipeline_execution_arn: str) -> Tuple[Optional[str], Optional[str], Optional[str]]

Extract region name, pipeline name, and execution id from the ARN.

Parameters:

Name Type Description Default
pipeline_execution_arn str

the pipeline execution ARN

required

Returns:

Type Description
Tuple[Optional[str], Optional[str], Optional[str]]

Region Name, Pipeline Name, Execution ID in order

Source code in src/zenml/integrations/aws/orchestrators/sagemaker_orchestrator.py
def dissect_pipeline_execution_arn(
    pipeline_execution_arn: str,
) -> Tuple[Optional[str], Optional[str], Optional[str]]:
    """Extract region name, pipeline name, and execution id from the ARN.

    Args:
        pipeline_execution_arn: the pipeline execution ARN

    Returns:
        Region Name, Pipeline Name, Execution ID in order
    """
    # Extract region_name
    region_match = re.search(r"sagemaker:(.*?):", pipeline_execution_arn)
    region_name = region_match.group(1) if region_match else None

    # Extract pipeline_name
    pipeline_match = re.search(
        r"pipeline/(.*?)/execution", pipeline_execution_arn
    )
    pipeline_name = pipeline_match.group(1) if pipeline_match else None

    # Extract execution_id
    execution_match = re.search(r"execution/(.*)", pipeline_execution_arn)
    execution_id = execution_match.group(1) if execution_match else None

    return region_name, pipeline_name, execution_id
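A quick usage sketch of this helper; the ARN below is a made-up example:

arn = (
    "arn:aws:sagemaker:eu-west-1:123456789012:"
    "pipeline/my-pipeline/execution/abc123"
)
region, pipeline_name, execution_id = dissect_pipeline_execution_arn(arn)
# region == "eu-west-1", pipeline_name == "my-pipeline",
# execution_id == "abc123"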
dissect_schedule_arn(schedule_arn: str) -> Tuple[Optional[str], Optional[str]]

Extracts the region and the name from an EventBridge schedule ARN.

Parameters:

Name Type Description Default
schedule_arn str

The ARN of the EventBridge schedule.

required

Returns:

Type Description
Tuple[Optional[str], Optional[str]]

Region Name, Schedule Name (including the group name)

Raises:

Type Description
ValueError

If the input is not a properly formatted ARN.

Source code in src/zenml/integrations/aws/orchestrators/sagemaker_orchestrator.py
def dissect_schedule_arn(
    schedule_arn: str,
) -> Tuple[Optional[str], Optional[str]]:
    """Extracts the region and the name from an EventBridge schedule ARN.

    Args:
        schedule_arn: The ARN of the EventBridge schedule.

    Returns:
        Region Name, Schedule Name (including the group name)

    Raises:
        ValueError: If the input is not a properly formatted ARN.
    """
    # Split the ARN into parts
    arn_parts = schedule_arn.split(":")

    # Validate ARN structure
    if len(arn_parts) < 6 or not arn_parts[5].startswith("schedule/"):
        raise ValueError("Invalid EventBridge schedule ARN format.")

    # Extract the region
    region = arn_parts[3]

    # Extract the group name and schedule name
    name = arn_parts[5].split("schedule/")[1]

    return region, name
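A corresponding usage sketch with a made-up EventBridge schedule ARN; note that the returned name keeps the group prefix:

arn = "arn:aws:scheduler:eu-west-1:123456789012:schedule/default/my-schedule"
region, name = dissect_schedule_arn(arn)
# region == "eu-west-1", name == "default/my-schedule"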
sagemaker_orchestrator_entrypoint_config

Entrypoint configuration for ZenML Sagemaker pipeline steps.

Classes
SagemakerEntrypointConfiguration(arguments: List[str])

Bases: StepEntrypointConfiguration

Entrypoint configuration for ZenML Sagemaker pipeline steps.

The only purpose of this entrypoint configuration is to reconstruct the environment variables that exceed the maximum length of 256 characters allowed for Sagemaker Processor steps from their individual components.

Source code in src/zenml/entrypoints/base_entrypoint_configuration.py
def __init__(self, arguments: List[str]):
    """Initializes the entrypoint configuration.

    Args:
        arguments: Command line arguments to configure this object.
    """
    self.entrypoint_args = self._parse_arguments(arguments)
Functions
run() -> None

Runs the step.

Source code in src/zenml/integrations/aws/orchestrators/sagemaker_orchestrator_entrypoint_config.py
def run(self) -> None:
    """Runs the step."""
    # Reconstruct the environment variables that exceed the maximum length
    # of 256 characters from their individual chunks
    reconstruct_environment_variables()

    # Run the step
    super().run()
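To make the chunking idea concrete, here is a minimal, self-contained sketch of splitting an oversized environment variable into indexed chunks and reassembling it on the other side. The helper names and the chunk-key format are illustrative assumptions; ZenML's actual utilities may differ in naming and details:

from typing import Dict, List

SIZE_LIMIT = 256  # value-length limit for Sagemaker Processor steps

def split_env_var(
    env: Dict[str, str], key: str, size_limit: int = SIZE_LIMIT
) -> None:
    """Replace an oversized value with indexed chunk variables."""
    value = env.pop(key)
    for start in range(0, len(value), size_limit):
        env[f"{key}_CHUNK_{start // size_limit}"] = value[
            start : start + size_limit
        ]

def reconstruct_env_var(env: Dict[str, str], key: str) -> None:
    """Reassemble the original value from its indexed chunks."""
    parts: List[str] = []
    index = 0
    while f"{key}_CHUNK_{index}" in env:
        parts.append(env.pop(f"{key}_CHUNK_{index}"))
        index += 1
    if parts:
        env[key] = "".join(parts)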
Functions

service_connectors

AWS Service Connector.

Classes
AWSServiceConnector(**kwargs: Any)

Bases: ServiceConnector

AWS service connector.

Source code in src/zenml/service_connectors/service_connector.py
def __init__(self, **kwargs: Any) -> None:
    """Initialize a new service connector instance.

    Args:
        kwargs: Additional keyword arguments to pass to the base class
            constructor.
    """
    super().__init__(**kwargs)

    # Convert the resource ID to its canonical form. For resource types
    # that don't support multiple instances:
    # - if a resource ID is not provided, we use the default resource ID for
    # the resource type
    # - if a resource ID is provided, we verify that it matches the default
    # resource ID for the resource type
    if self.resource_type:
        try:
            self.resource_id = self._validate_resource_id(
                self.resource_type, self.resource_id
            )
        except AuthorizationException as e:
            error = (
                f"Authorization error validating resource ID "
                f"{self.resource_id} for resource type "
                f"{self.resource_type}: {e}"
            )
            # Log an exception if debug logging is enabled
            if logger.isEnabledFor(logging.DEBUG):
                logger.exception(error)
            else:
                logger.warning(error)

            self.resource_id = None
Attributes
account_id: str property

Get the AWS account ID.

Returns:

Type Description
str

The AWS account ID.

Raises:

Type Description
AuthorizationException

If the AWS account ID could not be determined.

Functions
get_boto3_session(auth_method: str, resource_type: Optional[str] = None, resource_id: Optional[str] = None) -> Tuple[boto3.Session, Optional[datetime.datetime]]

Get a boto3 session for the specified resource.

Parameters:

Name Type Description Default
auth_method str

The authentication method to use.

required
resource_type Optional[str]

The resource type to get a boto3 session for.

None
resource_id Optional[str]

The resource ID to get a boto3 session for.

None

Returns:

Type Description
Session

A boto3 session for the specified resource and its expiration

Optional[datetime]

timestamp, if applicable.

Source code in src/zenml/integrations/aws/service_connectors/aws_service_connector.py
def get_boto3_session(
    self,
    auth_method: str,
    resource_type: Optional[str] = None,
    resource_id: Optional[str] = None,
) -> Tuple[boto3.Session, Optional[datetime.datetime]]:
    """Get a boto3 session for the specified resource.

    Args:
        auth_method: The authentication method to use.
        resource_type: The resource type to get a boto3 session for.
        resource_id: The resource ID to get a boto3 session for.

    Returns:
        A boto3 session for the specified resource and its expiration
        timestamp, if applicable.
    """
    # We maintain a cache of all sessions to avoid re-authenticating
    # multiple times for the same resource
    key = (auth_method, resource_type, resource_id)
    if key in self._session_cache:
        session, expires_at = self._session_cache[key]
        if expires_at is None:
            return session, None

        # Refresh expired sessions
        now = utc_now_tz_aware()
        expires_at = expires_at.replace(tzinfo=datetime.timezone.utc)
        # check if the token expires in the near future
        if expires_at > now + datetime.timedelta(
            minutes=BOTO3_SESSION_EXPIRATION_BUFFER
        ):
            return session, expires_at

    logger.debug(
        f"Creating boto3 session for auth method '{auth_method}', "
        f"resource type '{resource_type}' and resource ID "
        f"'{resource_id}'..."
    )
    session, expires_at = self._authenticate(
        auth_method, resource_type, resource_id
    )
    self._session_cache[key] = (session, expires_at)
    return session, expires_at
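A hedged usage sketch, assuming `connector` is an already-registered and configured AWSServiceConnector instance (the resource type string matches the integration's `AWS_RESOURCE_TYPE` constant):

session, expires_at = connector.get_boto3_session(
    auth_method=connector.auth_method,
    resource_type="aws-generic",
)
s3_client = session.client("s3")
# A repeated call with the same (auth_method, resource_type, resource_id)
# key is served from the session cache until the credentials approach
# expiration.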
get_ecr_client() -> BaseClient

Get an ECR client.

Raises:

Type Description
ValueError

If the service connector is not able to instantiate an ECR client.

Returns:

Type Description
BaseClient

An ECR client.

Source code in src/zenml/integrations/aws/service_connectors/aws_service_connector.py
def get_ecr_client(self) -> BaseClient:
    """Get an ECR client.

    Raises:
        ValueError: If the service connector is not able to instantiate an
            ECR client.

    Returns:
        An ECR client.
    """
    if self.resource_type and self.resource_type not in {
        AWS_RESOURCE_TYPE,
        DOCKER_REGISTRY_RESOURCE_TYPE,
    }:
        raise ValueError(
            f"Unable to instantiate ECR client for a connector that is "
            f"configured to provide access to a '{self.resource_type}' "
            "resource type."
        )

    session, _ = self.get_boto3_session(
        auth_method=self.auth_method,
        resource_type=DOCKER_REGISTRY_RESOURCE_TYPE,
        resource_id=self.config.region,
    )
    return session.client(
        "ecr",
        region_name=self.config.region,
        endpoint_url=self.config.endpoint_url,
    )
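A hedged usage sketch, again assuming a configured `connector` instance; `get_authorization_token` is the standard boto3 ECR call used to obtain Docker login credentials:

ecr_client = connector.get_ecr_client()
token = ecr_client.get_authorization_token()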
Modules
aws_service_connector

AWS Service Connector.

The AWS Service Connector implements various authentication methods for AWS services:

  • Explicit AWS secret key (access key, secret key)
  • Explicit AWS STS tokens (access key, secret key, session token)
  • IAM roles (i.e. generating temporary STS tokens on the fly by assuming an IAM role; see the sketch after this list)
  • IAM user federation tokens
  • STS Session tokens
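To make the IAM-role method concrete, here is a minimal boto3 sketch of what "generating temporary STS tokens by assuming an IAM role" amounts to. The role ARN and session name are placeholders; the connector additionally handles caching, expiration, and session policies internally:

import boto3

sts = boto3.client("sts")
response = sts.assume_role(
    RoleArn="arn:aws:iam::123456789012:role/my-zenml-role",
    RoleSessionName="zenml-connector-session",
    DurationSeconds=3600,
)
credentials = response["Credentials"]

# Build a session from the temporary credentials.
session = boto3.Session(
    aws_access_key_id=credentials["AccessKeyId"],
    aws_secret_access_key=credentials["SecretAccessKey"],
    aws_session_token=credentials["SessionToken"],
)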
Classes
AWSAuthenticationMethods

Bases: StrEnum

AWS Authentication methods.

AWSBaseConfig

Bases: AuthenticationConfig

AWS base configuration.

AWSImplicitConfig

Bases: AWSBaseConfig, AWSSessionPolicy

AWS implicit configuration.

AWSSecretKey

Bases: AuthenticationConfig

AWS secret key credentials.

AWSSecretKeyConfig

Bases: AWSBaseConfig, AWSSecretKey

AWS secret key authentication configuration.

AWSServiceConnector(**kwargs: Any)

Bases: ServiceConnector

AWS service connector.

AWSSessionPolicy

Bases: AuthenticationConfig

AWS session IAM policy configuration.

FederationTokenAuthenticationConfig

Bases: AWSSecretKeyConfig, AWSSessionPolicy

AWS federation token authentication config.

IAMRoleAuthenticationConfig

Bases: AWSSecretKeyConfig, AWSSessionPolicy

AWS IAM authentication config.

STSToken

Bases: AWSSecretKey

AWS STS token.

STSTokenConfig

Bases: AWSBaseConfig, STSToken

AWS STS token authentication configuration.

SessionTokenAuthenticationConfig

Bases: AWSSecretKeyConfig

AWS session token authentication config.

Functions

step_operators

Initialization of the Sagemaker Step Operator.

Classes
SagemakerStepOperator(name: str, id: UUID, config: StackComponentConfig, flavor: str, type: StackComponentType, user: Optional[UUID], created: datetime, updated: datetime, labels: Optional[Dict[str, Any]] = None, connector_requirements: Optional[ServiceConnectorRequirements] = None, connector: Optional[UUID] = None, connector_resource_id: Optional[str] = None, *args: Any, **kwargs: Any)

Bases: BaseStepOperator

Step operator to run a step on Sagemaker.

This class defines code that builds an image with the ZenML entrypoint to run using Sagemaker's Estimator.

Source code in src/zenml/stack/stack_component.py
def __init__(
    self,
    name: str,
    id: UUID,
    config: StackComponentConfig,
    flavor: str,
    type: StackComponentType,
    user: Optional[UUID],
    created: datetime,
    updated: datetime,
    labels: Optional[Dict[str, Any]] = None,
    connector_requirements: Optional[ServiceConnectorRequirements] = None,
    connector: Optional[UUID] = None,
    connector_resource_id: Optional[str] = None,
    *args: Any,
    **kwargs: Any,
):
    """Initializes a StackComponent.

    Args:
        name: The name of the component.
        id: The unique ID of the component.
        config: The config of the component.
        flavor: The flavor of the component.
        type: The type of the component.
        user: The ID of the user who created the component.
        created: The creation time of the component.
        updated: The last update time of the component.
        labels: The labels of the component.
        connector_requirements: The requirements for the connector.
        connector: The ID of a connector linked to the component.
        connector_resource_id: The custom resource ID to access through
            the connector.
        *args: Additional positional arguments.
        **kwargs: Additional keyword arguments.

    Raises:
        ValueError: If a secret reference is passed as name.
    """
    if secret_utils.is_secret_reference(name):
        raise ValueError(
            "Passing the `name` attribute of a stack component as a "
            "secret reference is not allowed."
        )

    self.id = id
    self.name = name
    self._config = config
    self.flavor = flavor
    self.type = type
    self.user = user
    self.created = created
    self.updated = updated
    self.labels = labels
    self.connector_requirements = connector_requirements
    self.connector = connector
    self.connector_resource_id = connector_resource_id
    self._connector_instance: Optional[ServiceConnector] = None
Attributes
config: SagemakerStepOperatorConfig property

Returns the SagemakerStepOperatorConfig config.

Returns:

Type Description
SagemakerStepOperatorConfig

The configuration.

entrypoint_config_class: Type[StepOperatorEntrypointConfiguration] property

Returns the entrypoint configuration class for this step operator.

Returns:

Type Description
Type[StepOperatorEntrypointConfiguration]

The entrypoint configuration class for this step operator.

settings_class: Optional[Type[BaseSettings]] property

Settings class for the SageMaker step operator.

Returns:

Type Description
Optional[Type[BaseSettings]]

The settings class.

validator: Optional[StackValidator] property

Validates the stack.

Returns:

Type Description
Optional[StackValidator]

A validator that checks that the stack contains a remote container

Optional[StackValidator]

registry and a remote artifact store.

Functions
get_docker_builds(deployment: PipelineDeploymentBase) -> List[BuildConfiguration]

Gets the Docker builds required for the component.

Parameters:

Name Type Description Default
deployment PipelineDeploymentBase

The pipeline deployment for which to get the builds.

required

Returns:

Type Description
List[BuildConfiguration]

The required Docker builds.

Source code in src/zenml/integrations/aws/step_operators/sagemaker_step_operator.py
def get_docker_builds(
    self, deployment: "PipelineDeploymentBase"
) -> List["BuildConfiguration"]:
    """Gets the Docker builds required for the component.

    Args:
        deployment: The pipeline deployment for which to get the builds.

    Returns:
        The required Docker builds.
    """
    builds = []
    for step_name, step in deployment.step_configurations.items():
        if step.config.step_operator == self.name:
            build = BuildConfiguration(
                key=SAGEMAKER_DOCKER_IMAGE_KEY,
                settings=step.config.docker_settings,
                step_name=step_name,
                entrypoint=f"${_ENTRYPOINT_ENV_VARIABLE}",
            )
            builds.append(build)

    return builds
launch(info: StepRunInfo, entrypoint_command: List[str], environment: Dict[str, str]) -> None

Launches a step on SageMaker.

Parameters:

Name Type Description Default
info StepRunInfo

Information about the step run.

required
entrypoint_command List[str]

Command that executes the step.

required
environment Dict[str, str]

Environment variables to set in the step operator environment.

required

Raises:

Type Description
RuntimeError

If the connector returns an object that is not a boto3.Session.

Source code in src/zenml/integrations/aws/step_operators/sagemaker_step_operator.py
def launch(
    self,
    info: "StepRunInfo",
    entrypoint_command: List[str],
    environment: Dict[str, str],
) -> None:
    """Launches a step on SageMaker.

    Args:
        info: Information about the step run.
        entrypoint_command: Command that executes the step.
        environment: Environment variables to set in the step operator
            environment.

    Raises:
        RuntimeError: If the connector returns an object that is not a
            `boto3.Session`.
    """
    if not info.config.resource_settings.empty:
        logger.warning(
            "Specifying custom step resources is not supported for "
            "the SageMaker step operator. If you want to run this step "
            "operator on specific resources, you can do so by configuring "
            "a different instance type like this: "
            "`zenml step-operator update %s "
            "--instance_type=<INSTANCE_TYPE>`",
            self.name,
        )

    settings = cast(SagemakerStepOperatorSettings, self.get_settings(info))

    if settings.environment:
        environment.update(settings.environment)

    # Sagemaker does not allow environment variables longer than 512
    # characters to be passed to Estimator steps. If an environment variable
    # is longer than 512 characters, we split it into multiple environment
    # variables (chunks) and re-construct it on the other side using the
    # custom entrypoint configuration.
    split_environment_variables(
        env=environment,
        size_limit=SAGEMAKER_ESTIMATOR_STEP_ENV_VAR_SIZE_LIMIT,
    )

    image_name = info.get_image(key=SAGEMAKER_DOCKER_IMAGE_KEY)
    environment[_ENTRYPOINT_ENV_VARIABLE] = " ".join(entrypoint_command)

    # Get and default fill SageMaker estimator arguments for full ZenML support
    estimator_args = settings.estimator_args

    # Get authenticated session
    # Option 1: Service connector
    boto_session: boto3.Session
    if connector := self.get_connector():
        boto_session = connector.connect()
        if not isinstance(boto_session, boto3.Session):
            raise RuntimeError(
                f"Expected to receive a `boto3.Session` object from the "
                f"linked connector, but got type `{type(boto_session)}`."
            )
    # Option 2: Implicit configuration
    else:
        boto_session = boto3.Session()

    session = sagemaker.Session(
        boto_session=boto_session, default_bucket=self.config.bucket
    )

    estimator_args.setdefault(
        "instance_type", settings.instance_type or "ml.m5.large"
    )

    # Convert environment to a dict of strings
    environment = {key: str(value) for key, value in environment.items()}

    estimator_args["environment"] = environment
    estimator_args["instance_count"] = 1
    estimator_args["sagemaker_session"] = session

    # Create Estimator
    estimator = sagemaker.estimator.Estimator(
        image_name, self.config.role, **estimator_args
    )

    # SageMaker allows at most 63 characters for a job name; ZenML caps it at 60 (55-char base plus a 5-char random suffix) as a safety margin.
    step_name = Client().get_run_step(info.step_run_id).name
    training_job_name = f"{info.pipeline.name}-{step_name}"[:55]
    suffix = random_str(4)
    unique_training_job_name = f"{training_job_name}-{suffix}"

    # Sagemaker doesn't allow any underscores in job/experiment/trial names
    sanitized_training_job_name = unique_training_job_name.replace(
        "_", "-"
    )

    # Construct training input object, if necessary
    inputs = None

    if isinstance(settings.input_data_s3_uri, str):
        inputs = sagemaker.inputs.TrainingInput(
            s3_data=settings.input_data_s3_uri
        )
    elif isinstance(settings.input_data_s3_uri, dict):
        inputs = {}
        for channel, s3_uri in settings.input_data_s3_uri.items():
            inputs[channel] = sagemaker.inputs.TrainingInput(
                s3_data=s3_uri
            )

    experiment_config = {}
    if settings.experiment_name:
        experiment_config = {
            "ExperimentName": settings.experiment_name,
            "TrialName": sanitized_training_job_name,
        }
    info.force_write_logs()
    estimator.fit(
        wait=True,
        inputs=inputs,
        experiment_config=experiment_config,
        job_name=sanitized_training_job_name,
    )
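A hedged usage sketch of routing a step to this operator and tuning the instance type through its settings class. The operator name, the import path, and the instance type are illustrative assumptions:

from zenml import step
from zenml.integrations.aws.flavors import SagemakerStepOperatorSettings

@step(
    step_operator="sagemaker",  # name of the registered step operator
    settings={
        "step_operator": SagemakerStepOperatorSettings(
            instance_type="ml.m5.xlarge",
        )
    },
)
def train() -> None:
    ...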
Modules
sagemaker_step_operator

Implementation of the Sagemaker Step Operator.

Classes
SagemakerStepOperator(name: str, id: UUID, config: StackComponentConfig, flavor: str, type: StackComponentType, user: Optional[UUID], created: datetime, updated: datetime, labels: Optional[Dict[str, Any]] = None, connector_requirements: Optional[ServiceConnectorRequirements] = None, connector: Optional[UUID] = None, connector_resource_id: Optional[str] = None, *args: Any, **kwargs: Any)

Bases: BaseStepOperator

Step operator to run a step on Sagemaker.

This class defines code that builds an image with the ZenML entrypoint to run using Sagemaker's Estimator.

Functions
sagemaker_step_operator_entrypoint_config

Entrypoint configuration for ZenML Sagemaker step operator.

Classes
SagemakerEntrypointConfiguration(arguments: List[str])

Bases: StepOperatorEntrypointConfiguration

Entrypoint configuration for ZenML Sagemaker step operator.

The only purpose of this entrypoint configuration is to reconstruct the environment variables that exceed the maximum length of 512 characters allowed for Sagemaker Estimator steps from their individual components.

Source code in src/zenml/entrypoints/base_entrypoint_configuration.py
def __init__(self, arguments: List[str]):
    """Initializes the entrypoint configuration.

    Args:
        arguments: Command line arguments to configure this object.
    """
    self.entrypoint_args = self._parse_arguments(arguments)
Functions
run() -> None

Runs the step.

Source code in src/zenml/integrations/aws/step_operators/sagemaker_step_operator_entrypoint_config.py
def run(self) -> None:
    """Runs the step."""
    # Reconstruct the environment variables that exceed the maximum length
    # of 512 characters from their individual chunks
    reconstruct_environment_variables()

    # Run the step
    super().run()
Functions