
Baseten

zenml.integrations.baseten

Baseten integration for running steps as Baseten training jobs.

Attributes

BASETEN = 'baseten' module-attribute

BASETEN_STEP_OPERATOR_FLAVOR = 'baseten' module-attribute

Classes

BasetenIntegration

Bases: Integration

Definition of the Baseten integration for ZenML.

Methods:
flavors() -> List[Type[Flavor]] classmethod

Declare the stack component flavors for the Baseten integration.

Returns:

Type Description
List[Type[Flavor]]

List of new stack component flavors.

Source code in src/zenml/integrations/baseten/__init__.py
@classmethod
def flavors(cls) -> List[Type[Flavor]]:
    """Declare the stack component flavors for the Baseten integration.

    Returns:
        List of new stack component flavors.
    """
    from zenml.integrations.baseten.flavors import (
        BasetenStepOperatorFlavor,
    )

    return [BasetenStepOperatorFlavor]

Flavor

Class for ZenML Flavors.

Attributes
config_class: Type[StackComponentConfig] abstractmethod property

Returns StackComponentConfig config class.

Returns:

Type Description
Type[StackComponentConfig]

The config class.

config_schema: Dict[str, Any] property

The config schema for a flavor.

Returns:

Type Description
Dict[str, Any]

The config schema.

display_name: Optional[str] property

The display name of the flavor.

By default, converts the technical name to a human-readable format. For example, "vm_kubernetes" becomes "VM Kubernetes". Flavors can override this to provide custom display names.

Returns:

Type Description
Optional[str]

The display name of the flavor.

docs_url: Optional[str] property

A URL pointing to the docs that explain this flavor.

Returns:

Type Description
Optional[str]

A flavor docs url.

implementation_class: Type[StackComponent] abstractmethod property

Implementation class for this flavor.

Returns:

Type Description
Type[StackComponent]

The implementation class for this flavor.

logo_url: Optional[str] property

A URL of the logo that represents the flavor in the dashboard.

Returns:

Type Description
Optional[str]

The flavor logo.

name: str abstractmethod property

The flavor name.

Returns:

Type Description
str

The flavor name.

sdk_docs_url: Optional[str] property

A URL pointing to the SDK docs that explain this flavor.

Returns:

Type Description
Optional[str]

A flavor SDK docs url.

service_connector_requirements: Optional[ServiceConnectorRequirements] property

Resource requirements for service connectors.

Specifies resource requirements that are used to filter the available service connector types that are compatible with this flavor.

Returns:

Type Description
Optional[ServiceConnectorRequirements]

Requirements for compatible service connectors, if a service connector is required for this flavor.

type: StackComponentType abstractmethod property

The stack component type.

Returns:

Type Description
StackComponentType

The stack component type.

Methods:
from_model(flavor_model: FlavorResponse) -> Flavor classmethod

Loads a flavor from a model.

Parameters:

Name Type Description Default
flavor_model FlavorResponse

The model to load from.

required

Raises:

Type Description
CustomFlavorImportError

If the custom flavor can't be imported.

ImportError

If the flavor can't be imported.

Returns:

Type Description
Flavor

The loaded flavor.

Source code in src/zenml/stack/flavor.py
@classmethod
def from_model(cls, flavor_model: FlavorResponse) -> "Flavor":
    """Loads a flavor from a model.

    Args:
        flavor_model: The model to load from.

    Raises:
        CustomFlavorImportError: If the custom flavor can't be imported.
        ImportError: If the flavor can't be imported.

    Returns:
        The loaded flavor.
    """
    try:
        _, flavor = validate_flavor_source(
            source=flavor_model.source,
            component_type=flavor_model.type,
            validate_component_classes=False,
        )
    except (TypeError, ValueError) as err:
        if flavor_model.is_custom:
            flavor_module, _, _ = flavor_model.source.rpartition(".")
            expected_file_path = os.path.join(
                source_utils.get_source_root(),
                flavor_module.replace(".", os.path.sep),
            )
            raise CustomFlavorImportError(
                f"Couldn't import custom flavor {flavor_model.name}: "
                f"{err}. Make sure the custom flavor class "
                f"`{flavor_model.source}` is importable. If it is part of "
                "a library, make sure it is installed. If "
                "it is a local code file, make sure it exists at "
                f"`{expected_file_path}.py`."
            ) from err
        else:
            raise ImportError(
                f"Couldn't import flavor {flavor_model.name}: {err}"
            ) from err
    return flavor
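
The error message above tells the user where a local custom flavor file is expected to live. A minimal sketch of that path derivation follows. The helper name `expected_flavor_file`, the source string, and the source root are hypothetical values chosen for illustration and are not part of ZenML:

```python
import os


def expected_flavor_file(source: str, source_root: str) -> str:
    """Mirror the path derivation used in the error message above."""
    # Drop the class name: "pkg.module.ClassName" -> "pkg.module".
    module, _, _ = source.rpartition(".")
    # Turn the dotted module path into a relative file path.
    relative = module.replace(".", os.path.sep)
    return os.path.join(source_root, relative) + ".py"


# Hypothetical inputs, for illustration only.
path = expected_flavor_file("my_flavors.custom.MyFlavor", "/repo")
print(path)
```

Because the path is built with `os.path.sep`, the printed result depends on the operating system.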
generate_default_docs_url() -> str

Generate the docs URL for built-in and integration flavors.

Note that this method is not useful for custom flavors, which have no docs in the main ZenML docs.

Returns:

Type Description
str

The complete URL to the ZenML documentation.

Source code in src/zenml/stack/flavor.py
def generate_default_docs_url(self) -> str:
    """Generate the doc urls for all inbuilt and integration flavors.

    Note that this method is not going to be useful for custom flavors,
    which do not have any docs in the main zenml docs.

    Returns:
        The complete url to the zenml documentation
    """
    component_type = self.type.plural.replace("_", "-")
    name = self.name.replace("_", "-")

    base = "https://docs.zenml.io"
    return f"{base}/stack-components/{component_type}/{name}"
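
The URL construction above can be tried in isolation. This sketch reimplements the same string handling as a free function; `default_docs_url` is a hypothetical name, and the plural form `"step_operators"` is an assumption about what `self.type.plural` returns for a step operator:

```python
def default_docs_url(type_plural: str, flavor_name: str) -> str:
    """Build the docs URL the same way the method above does."""
    # Underscores become hyphens in both path segments.
    component_type = type_plural.replace("_", "-")
    name = flavor_name.replace("_", "-")

    base = "https://docs.zenml.io"
    return f"{base}/stack-components/{component_type}/{name}"


# Assumed inputs for a Baseten step operator flavor.
url = default_docs_url("step_operators", "baseten")
print(url)
```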
generate_default_sdk_docs_url() -> str

Generate SDK docs url for a flavor.

Returns:

Type Description
str

The complete url to the zenml SDK docs

Source code in src/zenml/stack/flavor.py
def generate_default_sdk_docs_url(self) -> str:
    """Generate SDK docs url for a flavor.

    Returns:
        The complete url to the zenml SDK docs
    """
    from zenml import __version__

    base = f"https://sdkdocs.zenml.io/{__version__}"

    component_type = self.type.plural

    if "zenml.integrations" in self.__module__:
        # Get integration name out of module path which will look something
        #  like this "zenml.integrations.<integration>....
        integration = self.__module__.split(
            "zenml.integrations.", maxsplit=1
        )[1].split(".")[0]

        # Get the config class name to point to the specific class
        config_class_name = self.config_class.__name__

        return (
            f"{base}/integration_code_docs"
            f"/integrations-{integration}"
            f"#zenml.integrations.{integration}.flavors.{config_class_name}"
        )

    else:
        return (
            f"{base}/core_code_docs/core-{component_type}/"
            f"#{self.__module__}"
        )
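
The integration name is recovered purely from the module path. The sketch below isolates that parsing step; `integration_from_module` is a hypothetical helper, and unlike the method above it checks for the prefix with a trailing dot so the split always has a second element:

```python
from typing import Optional


def integration_from_module(module_path: str) -> Optional[str]:
    """Return the integration name embedded in a module path, if any."""
    prefix = "zenml.integrations."
    if prefix not in module_path:
        # Core (non-integration) module: no integration name to extract.
        return None
    # Keep what follows the prefix, then take the first dotted segment.
    return module_path.split(prefix, maxsplit=1)[1].split(".")[0]


name = integration_from_module(
    "zenml.integrations.baseten.flavors.baseten_step_operator_flavor"
)
print(name)
```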
to_model(integration: Optional[str] = None, is_custom: bool = True) -> FlavorRequest

Converts a flavor to a model.

Parameters:

Name Type Description Default
integration Optional[str]

The integration to use for the model.

None
is_custom bool

Whether the flavor is a custom flavor.

True

Returns:

Type Description
FlavorRequest

The model.

Source code in src/zenml/stack/flavor.py
def to_model(
    self,
    integration: Optional[str] = None,
    is_custom: bool = True,
) -> FlavorRequest:
    """Converts a flavor to a model.

    Args:
        integration: The integration to use for the model.
        is_custom: Whether the flavor is a custom flavor.

    Returns:
        The model.
    """
    connector_requirements = self.service_connector_requirements
    connector_type = (
        connector_requirements.connector_type
        if connector_requirements
        else None
    )
    resource_type = (
        connector_requirements.resource_type
        if connector_requirements
        else None
    )
    resource_id_attr = (
        connector_requirements.resource_id_attr
        if connector_requirements
        else None
    )

    model = FlavorRequest(
        name=self.name,
        display_name=self.display_name,
        type=self.type,
        source=source_utils.resolve(self.__class__).import_path,
        config_schema=self.config_schema,
        connector_type=connector_type,
        connector_resource_type=resource_type,
        connector_resource_id_attr=resource_id_attr,
        integration=integration,
        logo_url=self.logo_url,
        docs_url=self.docs_url,
        sdk_docs_url=self.sdk_docs_url,
        is_custom=is_custom,
    )
    return model

Integration

Base class for integrations in ZenML.

Methods:
activate() -> None classmethod

Abstract method to activate the integration.

Source code in src/zenml/integrations/integration.py
@classmethod
def activate(cls) -> None:
    """Abstract method to activate the integration."""
check_installation() -> bool classmethod

Method to check whether the required packages are installed.

Returns:

Type Description
bool

True if all required packages are installed, False otherwise.

Source code in src/zenml/integrations/integration.py
@classmethod
def check_installation(cls) -> bool:
    """Method to check whether the required packages are installed.

    Returns:
        True if all required packages are installed, False otherwise.
    """
    for requirement in cls.get_requirements():
        parsed_requirement = Requirement(requirement)

        if not requirement_installed(parsed_requirement):
            logger.debug(
                "Requirement '%s' for integration '%s' is not installed "
                "or installed with the wrong version.",
                requirement,
                cls.NAME,
            )
            return False

        dependencies = get_dependencies(parsed_requirement)

        for dependency in dependencies:
            if not requirement_installed(dependency):
                logger.debug(
                    "Requirement '%s' for integration '%s' is not "
                    "installed or installed with the wrong version.",
                    dependency,
                    cls.NAME,
                )
                return False

    logger.debug(
        f"Integration '{cls.NAME}' is installed correctly with "
        f"requirements {cls.get_requirements()}."
    )
    return True
flavors() -> List[Type[Flavor]] classmethod

Abstract method to declare new stack component flavors.

Returns:

Type Description
List[Type[Flavor]]

A list of new stack component flavors.

Source code in src/zenml/integrations/integration.py
@classmethod
def flavors(cls) -> List[Type[Flavor]]:
    """Abstract method to declare new stack component flavors.

    Returns:
        A list of new stack component flavors.
    """
    return []
get_requirements(target_os: Optional[str] = None, python_version: Optional[str] = None) -> List[str] classmethod

Method to get the requirements for the integration.

Parameters:

Name Type Description Default
target_os Optional[str]

The target operating system to get the requirements for.

None
python_version Optional[str]

The Python version to use for the requirements.

None

Returns:

Type Description
List[str]

A list of requirements.

Source code in src/zenml/integrations/integration.py
@classmethod
def get_requirements(
    cls,
    target_os: Optional[str] = None,
    python_version: Optional[str] = None,
) -> List[str]:
    """Method to get the requirements for the integration.

    Args:
        target_os: The target operating system to get the requirements for.
        python_version: The Python version to use for the requirements.

    Returns:
        A list of requirements.
    """
    return cls.REQUIREMENTS
get_uninstall_requirements(target_os: Optional[str] = None) -> List[str] classmethod

Method to get the uninstall requirements for the integration.

Parameters:

Name Type Description Default
target_os Optional[str]

The target operating system to get the requirements for.

None

Returns:

Type Description
List[str]

A list of requirements.

Source code in src/zenml/integrations/integration.py
@classmethod
def get_uninstall_requirements(
    cls, target_os: Optional[str] = None
) -> List[str]:
    """Method to get the uninstall requirements for the integration.

    Args:
        target_os: The target operating system to get the requirements for.

    Returns:
        A list of requirements.
    """
    ret = []
    for each in cls.get_requirements(target_os=target_os):
        is_ignored = False
        for ignored in cls.REQUIREMENTS_IGNORED_ON_UNINSTALL:
            if each.startswith(ignored):
                is_ignored = True
                break
        if not is_ignored:
            ret.append(each)
    return ret
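
The filtering above is a prefix match against each requirement string. A compact sketch of the same behavior follows; the function name and the requirement strings are hypothetical and do not reflect the Baseten integration's real requirements:

```python
from typing import List, Sequence


def filter_uninstall_requirements(
    requirements: Sequence[str], ignored_prefixes: Sequence[str]
) -> List[str]:
    """Keep only requirements that match none of the ignored prefixes."""
    return [
        requirement
        for requirement in requirements
        if not any(requirement.startswith(p) for p in ignored_prefixes)
    ]


# Hypothetical requirement list, for illustration only.
kept = filter_uninstall_requirements(
    ["requests>=2.0", "pandas>=2.0", "numpy<2"], ["pandas"]
)
print(kept)
```

Since the match is by prefix, an ignored entry such as `pandas` would also exclude a requirement like `pandas-stubs`.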

Modules

baseten_api

Minimal REST client for the Baseten Training API.

Classes
BasetenApiClient(api_key: str, base_url: str = BASETEN_API_BASE_URL)

Thin client for the Baseten Training REST API.

Initialize the client.

Parameters:

Name Type Description Default
api_key str

Baseten API key used for authentication.

required
base_url str

Base URL of the Baseten API.

BASETEN_API_BASE_URL
Source code in src/zenml/integrations/baseten/baseten_api.py
def __init__(
    self, api_key: str, base_url: str = BASETEN_API_BASE_URL
) -> None:
    """Initialize the client.

    Args:
        api_key: Baseten API key used for authentication.
        base_url: Base URL of the Baseten API.
    """
    self._api_key = api_key
    self._base_url = base_url.rstrip("/")
    self._session = _build_session()
Methods:
get_job_status(project_id: str, job_id: str) -> Optional[str]

Get the current status of a training job.

Parameters:

Name Type Description Default
project_id str

The Baseten training project id.

required
job_id str

The Baseten training job id.

required

Returns:

Type Description
Optional[str]

The Baseten job status string, or None if the job no longer exists (HTTP 404). Raises requests.HTTPError if the request fails for any other reason.

Source code in src/zenml/integrations/baseten/baseten_api.py
def get_job_status(self, project_id: str, job_id: str) -> Optional[str]:
    """Get the current status of a training job.

    Args:
        project_id: The Baseten training project id.
        job_id: The Baseten training job id.

    Returns:
        The Baseten job status string, or None if the job no longer
        exists (HTTP 404). Raises ``requests.HTTPError`` if the request
        fails for any other reason.
    """
    response = self._session.get(
        self._job_url(project_id, job_id),
        headers=self._headers,
        timeout=_REQUEST_TIMEOUT,
    )
    if response.status_code == 404:
        return None
    response.raise_for_status()
    job = response.json().get("training_job", {})
    return cast(Optional[str], job.get("current_status"))
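
A caller would typically poll `get_job_status` until the job finishes. The sketch below shows one way to do that under stated assumptions: the status callable is injected so the example runs without network access, `wait_for_terminal_status` is a hypothetical helper, and the status strings are placeholders rather than Baseten's actual values:

```python
import time
from typing import Callable, Collection, Optional


def wait_for_terminal_status(
    get_status: Callable[[], Optional[str]],
    terminal_statuses: Collection[str],
    poll_interval: float = 5.0,
    max_polls: int = 100,
    sleep: Callable[[float], None] = time.sleep,
) -> Optional[str]:
    """Poll until the job reaches a terminal status or disappears."""
    for _ in range(max_polls):
        status = get_status()
        # None mirrors the HTTP 404 case: the job no longer exists.
        if status is None or status in terminal_statuses:
            return status
        sleep(poll_interval)
    raise TimeoutError(f"No terminal status after {max_polls} polls.")


# Fake status sequence with placeholder names; no real API is called.
fake_statuses = iter(["PENDING", "RUNNING", "DONE"])
final = wait_for_terminal_status(
    get_status=lambda: next(fake_statuses),
    terminal_statuses={"DONE", "FAILED"},
    sleep=lambda _seconds: None,  # skip real waiting in this example
)
print(final)
```

With a real client, `get_status` could be a closure over `client.get_job_status(project_id, job_id)`.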
stop_job(project_id: str, job_id: str) -> None

Stop a running training job.

Raises requests.HTTPError if the stop request fails.

Parameters:

Name Type Description Default
project_id str

The Baseten training project id.

required
job_id str

The Baseten training job id.

required
Source code in src/zenml/integrations/baseten/baseten_api.py
def stop_job(self, project_id: str, job_id: str) -> None:
    """Stop a running training job.

    Raises ``requests.HTTPError`` if the stop request fails.

    Args:
        project_id: The Baseten training project id.
        job_id: The Baseten training job id.
    """
    response = self._session.post(
        f"{self._job_url(project_id, job_id)}/stop",
        headers=self._headers,
        json={},
        timeout=_REQUEST_TIMEOUT,
    )
    response.raise_for_status()

flavors

Baseten integration flavors.

Classes
BasetenStepOperatorConfig(warn_about_plain_text_secrets: bool = False, **kwargs: Any)

Bases: BaseStepOperatorConfig

Configuration for the Baseten step operator.

Source code in src/zenml/stack/stack_component.py
def __init__(
    self, warn_about_plain_text_secrets: bool = False, **kwargs: Any
) -> None:
    """Ensures that secret references don't clash with pydantic validation.

    StackComponents allow the specification of all their string attributes
    using secret references of the form `{{secret_name.key}}`. This however
    is only possible when the stack component does not perform any explicit
    validation of this attribute using pydantic validators. If this were
    the case, the validation would run on the secret reference and would
    fail or in the worst case, modify the secret reference and lead to
    unexpected behavior. This method ensures that no attributes that require
    custom pydantic validation are set as secret references.

    Args:
        warn_about_plain_text_secrets: If true, then warns about using
            plain-text secrets.
        **kwargs: Arguments to initialize this stack component.

    Raises:
        ValueError: If an attribute that requires custom pydantic validation
            is passed as a secret reference, or if the `name` attribute
            was passed as a secret reference.
    """
    for key, value in kwargs.items():
        try:
            field = self.__class__.model_fields[key]
        except KeyError:
            # Value for a private attribute or non-existing field, this
            # will fail during the upcoming pydantic validation
            continue

        if value is None:
            continue

        if not secret_utils.is_secret_reference(value):
            if (
                secret_utils.is_secret_field(field)
                and warn_about_plain_text_secrets
            ):
                logger.warning(
                    "You specified a plain-text value for the sensitive "
                    f"attribute `{key}` for a `{self.__class__.__name__}` "
                    "stack component. This is currently only a warning, "
                    "but future versions of ZenML will require you to pass "
                    "in sensitive information as secrets. Check out the "
                    "documentation on how to configure your stack "
                    "components with secrets here: "
                    "https://docs.zenml.io/deploying-zenml/deploying-zenml/secret-management"
                )
            continue

        if pydantic_utils.has_validators(
            pydantic_class=self.__class__, field_name=key
        ):
            raise ValueError(
                f"Passing the stack component attribute `{key}` as a "
                "secret reference is not allowed as additional validation "
                "is required for this attribute."
            )

    super().__init__(**kwargs)
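
The constructor above branches on whether a value is a secret reference of the form `{{secret_name.key}}`. The following sketch shows what such a check could look like. The regular expression and the function name are assumptions for illustration; ZenML's own check lives in `secret_utils` and may differ:

```python
import re

# Hypothetical pattern: "{{", a secret name, a dot, a key, then "}}".
_SECRET_REFERENCE = re.compile(
    r"\{\{\s*(?P<name>[^.{}\s]+)\.(?P<key>[^{}\s]+)\s*\}\}"
)


def looks_like_secret_reference(value: object) -> bool:
    """Return True if the value has the `{{secret_name.key}}` shape."""
    return (
        isinstance(value, str)
        and _SECRET_REFERENCE.fullmatch(value) is not None
    )


print(looks_like_secret_reference("{{baseten.api_key}}"))
print(looks_like_secret_reference("plain-text-value"))
```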
Attributes
is_remote: bool property

Whether the step operator runs steps remotely.

Returns:

Type Description
bool

True

BasetenStepOperatorFlavor

Bases: BaseStepOperatorFlavor

Baseten step operator flavor.

Attributes
config_class: Type[BasetenStepOperatorConfig] property

Returns the config class for this flavor.

Returns:

Type Description
Type[BasetenStepOperatorConfig]

The config class.

docs_url: Optional[str] property

A URL to point at docs explaining this flavor.

Returns:

Type Description
Optional[str]

A flavor docs url.

implementation_class: Type[BasetenStepOperator] property

Implementation class for this flavor.

Returns:

Type Description
Type[BasetenStepOperator]

The implementation class.

logo_url: str property

A URL to represent the flavor in the dashboard.

Returns:

Type Description
str

The flavor logo.

name: str property

Name of the flavor.

Returns:

Type Description
str

The name of the flavor.

sdk_docs_url: Optional[str] property

A URL to point at SDK docs explaining this flavor.

Returns:

Type Description
Optional[str]

A flavor SDK docs url.

BasetenStepOperatorSettings(warn_about_plain_text_secrets: bool = False, **kwargs: Any)

Bases: BaseSettings

Settings for the Baseten step operator.

Source code in src/zenml/config/secret_reference_mixin.py
def __init__(
    self, warn_about_plain_text_secrets: bool = False, **kwargs: Any
) -> None:
    """Ensures that secret references are only passed for valid fields.

    This method ensures that secret references are not passed for fields
    that explicitly prevent them or require pydantic validation.

    Args:
        warn_about_plain_text_secrets: If true, then warns about using plain-text secrets.
        **kwargs: Arguments to initialize this object.

    Raises:
        ValueError: If an attribute that requires custom pydantic validation
            or an attribute which explicitly disallows secret references
            is passed as a secret reference.
    """
    for key, value in kwargs.items():
        try:
            field = self.__class__.model_fields[key]
        except KeyError:
            # Value for a private attribute or non-existing field, this
            # will fail during the upcoming pydantic validation
            continue

        if value is None:
            continue

        if not secret_utils.is_secret_reference(value):
            if (
                secret_utils.is_secret_field(field)
                and warn_about_plain_text_secrets
            ):
                logger.warning(
                    "You specified a plain-text value for the sensitive "
                    f"attribute `{key}`. This is currently only a warning, "
                    "but future versions of ZenML will require you to pass "
                    "in sensitive information as secrets. Check out the "
                    "documentation on how to configure values with secrets "
                    "here: https://docs.zenml.io/deploying-zenml/deploying-zenml/secret-management"
                )
            continue

        if secret_utils.is_clear_text_field(field):
            raise ValueError(
                f"Passing the `{key}` attribute as a secret reference is "
                "not allowed."
            )

        requires_validation = has_validators(
            pydantic_class=self.__class__, field_name=key
        )
        if requires_validation:
            raise ValueError(
                f"Passing the attribute `{key}` as a secret reference is "
                "not allowed as additional validation is required for "
                "this attribute."
            )

    super().__init__(**kwargs)
Modules
baseten_step_operator_flavor

Baseten step operator flavor.

Classes
BasetenStepOperatorConfig(warn_about_plain_text_secrets: bool = False, **kwargs: Any)

Bases: BaseStepOperatorConfig

Configuration for the Baseten step operator.

Source code in src/zenml/stack/stack_component.py
def __init__(
    self, warn_about_plain_text_secrets: bool = False, **kwargs: Any
) -> None:
    """Ensures that secret references don't clash with pydantic validation.

    StackComponents allow the specification of all their string attributes
    using secret references of the form `{{secret_name.key}}`. This however
    is only possible when the stack component does not perform any explicit
    validation of this attribute using pydantic validators. If this were
    the case, the validation would run on the secret reference and would
    fail or in the worst case, modify the secret reference and lead to
    unexpected behavior. This method ensures that no attributes that require
    custom pydantic validation are set as secret references.

    Args:
        warn_about_plain_text_secrets: If true, then warns about using
            plain-text secrets.
        **kwargs: Arguments to initialize this stack component.

    Raises:
        ValueError: If an attribute that requires custom pydantic validation
            is passed as a secret reference, or if the `name` attribute
            was passed as a secret reference.
    """
    for key, value in kwargs.items():
        try:
            field = self.__class__.model_fields[key]
        except KeyError:
            # Value for a private attribute or non-existing field, this
            # will fail during the upcoming pydantic validation
            continue

        if value is None:
            continue

        if not secret_utils.is_secret_reference(value):
            if (
                secret_utils.is_secret_field(field)
                and warn_about_plain_text_secrets
            ):
                logger.warning(
                    "You specified a plain-text value for the sensitive "
                    f"attribute `{key}` for a `{self.__class__.__name__}` "
                    "stack component. This is currently only a warning, "
                    "but future versions of ZenML will require you to pass "
                    "in sensitive information as secrets. Check out the "
                    "documentation on how to configure your stack "
                    "components with secrets here: "
                    "https://docs.zenml.io/deploying-zenml/deploying-zenml/secret-management"
                )
            continue

        if pydantic_utils.has_validators(
            pydantic_class=self.__class__, field_name=key
        ):
            raise ValueError(
                f"Passing the stack component attribute `{key}` as a "
                "secret reference is not allowed as additional validation "
                "is required for this attribute."
            )

    super().__init__(**kwargs)
Attributes
is_remote: bool property

Whether the step operator runs steps remotely.

Returns:

Type Description
bool

True

BasetenStepOperatorFlavor

Bases: BaseStepOperatorFlavor

Baseten step operator flavor.

Attributes
config_class: Type[BasetenStepOperatorConfig] property

Returns the config class for this flavor.

Returns:

Type Description
Type[BasetenStepOperatorConfig]

The config class.

docs_url: Optional[str] property

A URL to point at docs explaining this flavor.

Returns:

Type Description
Optional[str]

A flavor docs url.

implementation_class: Type[BasetenStepOperator] property

Implementation class for this flavor.

Returns:

Type Description
Type[BasetenStepOperator]

The implementation class.

logo_url: str property

A URL to represent the flavor in the dashboard.

Returns:

Type Description
str

The flavor logo.

name: str property

Name of the flavor.

Returns:

Type Description
str

The name of the flavor.

sdk_docs_url: Optional[str] property

A URL to point at SDK docs explaining this flavor.

Returns:

Type Description
Optional[str]

A flavor SDK docs url.

BasetenStepOperatorSettings(warn_about_plain_text_secrets: bool = False, **kwargs: Any)

Bases: BaseSettings

Settings for the Baseten step operator.

Source code in src/zenml/config/secret_reference_mixin.py
def __init__(
    self, warn_about_plain_text_secrets: bool = False, **kwargs: Any
) -> None:
    """Ensures that secret references are only passed for valid fields.

    This method ensures that secret references are not passed for fields
    that explicitly prevent them or require pydantic validation.

    Args:
        warn_about_plain_text_secrets: If true, then warns about using plain-text secrets.
        **kwargs: Arguments to initialize this object.

    Raises:
        ValueError: If an attribute that requires custom pydantic validation
            or an attribute which explicitly disallows secret references
            is passed as a secret reference.
    """
    for key, value in kwargs.items():
        try:
            field = self.__class__.model_fields[key]
        except KeyError:
            # Value for a private attribute or non-existing field, this
            # will fail during the upcoming pydantic validation
            continue

        if value is None:
            continue

        if not secret_utils.is_secret_reference(value):
            if (
                secret_utils.is_secret_field(field)
                and warn_about_plain_text_secrets
            ):
                logger.warning(
                    "You specified a plain-text value for the sensitive "
                    f"attribute `{key}`. This is currently only a warning, "
                    "but future versions of ZenML will require you to pass "
                    "in sensitive information as secrets. Check out the "
                    "documentation on how to configure values with secrets "
                    "here: https://docs.zenml.io/deploying-zenml/deploying-zenml/secret-management"
                )
            continue

        if secret_utils.is_clear_text_field(field):
            raise ValueError(
                f"Passing the `{key}` attribute as a secret reference is "
                "not allowed."
            )

        requires_validation = has_validators(
            pydantic_class=self.__class__, field_name=key
        )
        if requires_validation:
            raise ValueError(
                f"Passing the attribute `{key}` as a secret reference is "
                "not allowed as additional validation is required for "
                "this attribute."
            )

    super().__init__(**kwargs)

step_operators

Baseten step operator.

Classes
BasetenStepOperator(*args: Any, **kwargs: Any)

Bases: BaseStepOperator

Step operator that runs a step as a Baseten training job.

Initialize the step operator.

Parameters:

Name Type Description Default
*args Any

Positional arguments forwarded to the base class.

()
**kwargs Any

Keyword arguments forwarded to the base class.

{}
Source code in src/zenml/integrations/baseten/step_operators/baseten_step_operator.py
def __init__(self, *args: Any, **kwargs: Any) -> None:
    """Initialize the step operator.

    Args:
        *args: Positional arguments forwarded to the base class.
        **kwargs: Keyword arguments forwarded to the base class.
    """
    super().__init__(*args, **kwargs)
    self._api: Optional["BasetenApiClient"] = None
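
The constructor above only sets `_api` to `None`; the client is built on first access. A generic sketch of that lazy-construction pattern follows. `LazyClientHolder` and its factory argument are hypothetical stand-ins and do not show how the real step operator builds its `BasetenApiClient`:

```python
from typing import Any, Callable, List, Optional


class LazyClientHolder:
    """Build an expensive client on first access, then reuse it."""

    def __init__(self, factory: Callable[[], Any]) -> None:
        self._factory = factory
        self._client: Optional[Any] = None

    @property
    def client(self) -> Any:
        # Construct the client only once, on first use.
        if self._client is None:
            self._client = self._factory()
        return self._client


calls: List[int] = []


def make_client() -> object:
    """Stand-in factory that records how often it is called."""
    calls.append(1)
    return object()


holder = LazyClientHolder(make_client)
first = holder.client
second = holder.client
print(first is second, len(calls))
```

Deferring construction means no API key lookup or HTTP session setup is needed until a method actually talks to the API.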
Attributes
api: BasetenApiClient property

Lazily constructed Baseten REST API client.

Returns:

Type Description
BasetenApiClient

The Baseten REST API client.

config: BasetenStepOperatorConfig property

Get the Baseten step operator configuration.

Returns:

Type Description
BasetenStepOperatorConfig

The Baseten step operator configuration.

settings_class: Optional[Type[BaseSettings]] property

Get the settings class for the Baseten step operator.

Returns:

Type Description
Optional[Type[BaseSettings]]

The Baseten step operator settings class.

validator: Optional[StackValidator] property

Get the stack validator for the Baseten step operator.

Returns:

Type Description
Optional[StackValidator]

The stack validator.

Methods:
cancel(step_run: StepRunResponse) -> None

Cancel a submitted Baseten training job.

Parameters:

Name Type Description Default
step_run StepRunResponse

The step run.

required
Source code in src/zenml/integrations/baseten/step_operators/baseten_step_operator.py
def cancel(self, step_run: "StepRunResponse") -> None:
    """Cancel a submitted Baseten training job.

    Args:
        step_run: The step run.
    """
    ids = self._extract_job_and_project_id(step_run)
    if ids is None:
        return
    project_id, job_id = ids
    self.api.stop_job(project_id, job_id)
get_docker_builds(snapshot: PipelineSnapshotBase) -> List[BuildConfiguration]

Get the Docker build configurations for the Baseten step operator.

Parameters:

Name Type Description Default
snapshot PipelineSnapshotBase

The pipeline snapshot.

required

Returns:

Type Description
List[BuildConfiguration]

A list of Docker build configurations.

Source code in src/zenml/integrations/baseten/step_operators/baseten_step_operator.py
def get_docker_builds(
    self, snapshot: "PipelineSnapshotBase"
) -> List["BuildConfiguration"]:
    """Get the Docker build configurations for the Baseten step operator.

    Args:
        snapshot: The pipeline snapshot.

    Returns:
        A list of Docker build configurations.
    """
    builds = []
    for step_name, step in snapshot.step_configurations.items():
        if step.config.uses_step_operator(self.name):
            builds.append(
                BuildConfiguration(
                    key=BASETEN_STEP_OPERATOR_DOCKER_IMAGE_KEY,
                    settings=step.config.docker_settings,
                    step_name=step_name,
                )
            )

    return builds
get_status(step_run: StepRunResponse) -> ExecutionStatus

Get the status of a submitted Baseten training job.

Parameters:

Name Type Description Default
step_run StepRunResponse

The step run.

required

Returns:

Type Description
ExecutionStatus

The execution status. Returns FAILED if the job ids are missing or the job no longer exists.

Source code in src/zenml/integrations/baseten/step_operators/baseten_step_operator.py
def get_status(self, step_run: "StepRunResponse") -> ExecutionStatus:
    """Get the status of a submitted Baseten training job.

    Args:
        step_run: The step run.

    Returns:
        The execution status. Returns FAILED if the job ids are missing or
        the job no longer exists.
    """
    ids = self._extract_job_and_project_id(step_run)
    if ids is None:
        return ExecutionStatus.FAILED

    project_id, job_id = ids
    state = self.api.get_job_status(project_id, job_id)
    if state is None:
        return ExecutionStatus.FAILED

    return _BASETEN_STATE_TO_EXECUTION_STATUS.get(
        state, ExecutionStatus.RUNNING
    )
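
The final `.get` call means that any state missing from the mapping is reported as still running rather than as failed. The lookup can be sketched as follows. This is a minimal illustration, not the integration's code: the state strings, the `STATE_TO_STATUS` table, and the stand-in `ExecutionStatus` enum are all assumptions, since the real `_BASETEN_STATE_TO_EXECUTION_STATUS` mapping is not shown on this page.

```python
from enum import Enum
from typing import Dict, Optional


class ExecutionStatus(Enum):
    """Stand-in for ZenML's ExecutionStatus; members are illustrative."""

    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"


# Hypothetical state names; the real Baseten state strings may differ.
STATE_TO_STATUS: Dict[str, ExecutionStatus] = {
    "TRAINING_JOB_COMPLETED": ExecutionStatus.COMPLETED,
    "TRAINING_JOB_FAILED": ExecutionStatus.FAILED,
}


def resolve_status(state: Optional[str]) -> ExecutionStatus:
    """Map a remote job state to an execution status."""
    if state is None:
        # No state was returned, so the job is treated as gone.
        return ExecutionStatus.FAILED
    # States that are not in the table are treated as still running.
    return STATE_TO_STATUS.get(state, ExecutionStatus.RUNNING)
```

Defaulting unknown states to RUNNING keeps a poller waiting when the remote side introduces a new intermediate state, instead of failing the step early.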
submit(info: StepRunInfo, entrypoint_command: List[str], environment: Dict[str, str]) -> None

Submit a step run as a Baseten training job.

Parameters:

Name Type Description Default
info StepRunInfo

The step run information.

required
entrypoint_command List[str]

The entrypoint command for the step.

required
environment Dict[str, str]

The environment variables for the step.

required

Raises:

Type Description
RuntimeError

If multi-node execution is requested for a regular step.
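
The condition behind this error can be sketched in isolation. The helper below is hypothetical (`check_node_count` is not part of ZenML) and only mirrors the rule stated above: more than one node is accepted only when the step carries its own command.

```python
from typing import Optional, Sequence


def check_node_count(
    node_count: int, command: Optional[Sequence[str]]
) -> None:
    """Reject multi-node requests for steps without an explicit command."""
    # A step with a command owns its distributed launch; a regular step
    # would run the same entrypoint on every node.
    is_command_step = command is not None
    if node_count > 1 and not is_command_step:
        raise RuntimeError(
            f"node_count={node_count} is only supported for command steps."
        )


# Single-node regular step: accepted.
check_node_count(1, None)
# Multi-node step that supplies its own launcher command: accepted.
check_node_count(2, ["bash", "-lc", "torchrun train.py"])
```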

Source code in src/zenml/integrations/baseten/step_operators/baseten_step_operator.py
def submit(
    self,
    info: "StepRunInfo",
    entrypoint_command: List[str],
    environment: Dict[str, str],
) -> None:
    """Submit a step run as a Baseten training job.

    Args:
        info: The step run information.
        entrypoint_command: The entrypoint command for the step.
        environment: The environment variables for the step.

    Raises:
        RuntimeError: If multi-node execution is requested for a regular
            step.
    """
    settings = cast(BasetenStepOperatorSettings, self.get_settings(info))
    is_command_step = info.config.command is not None

    # A multi-node job runs the same entrypoint on every node. A regular
    # step would therefore duplicate its artifacts, outputs and logs across
    # nodes, so only command steps (which own their distributed launch) may
    # scale out. This runs at submit time because dynamic pipelines execute
    # from a server-side snapshot where no compile-time hook is available.
    if settings.node_count > 1 and not is_command_step:
        raise RuntimeError(
            f"The step `{info.pipeline_step_name}` requests "
            f"node_count={settings.node_count} but is a regular step. "
            "Running a regular step on multiple nodes would duplicate its "
            "artifacts, outputs and logs across every node. Use a "
            "CommandStep that owns its distributed launch instead. Wrap "
            "the launcher in a shell so Baseten's per-node variables "
            "(BT_GROUP_SIZE, BT_NODE_RANK, BT_LEADER_ADDR, BT_NUM_GPUS) "
            "expand on each node, e.g. "
            '`CommandStep(command=["bash", "-lc", "torchrun '
            "--nnodes=$BT_GROUP_SIZE --node-rank=$BT_NODE_RANK "
            "--master-addr=$BT_LEADER_ADDR --master-port=29500 "
            '--nproc-per-node=$BT_NUM_GPUS train.py"], '
            f'step_operator="{self.name}")`. Keep node_count=1 for '
            "regular steps; see the Baseten step-operator docs "
            "(multi-node distributed training) for the full pattern."
        )

    image_name = info.get_image(key=BASETEN_STEP_OPERATOR_DOCKER_IMAGE_KEY)

    # Only pass cpu_count/memory when set, since truss
    # `Compute` types them as int/str and applies its own defaults when
    # omitted. gpu_count defaults to 1 only when unset (0 stays 0).
    resources = info.config.resource_settings
    gpu_count = (
        resources.gpu_count if resources.gpu_count is not None else 1
    )
    compute_kwargs: Dict[str, Any] = {
        "node_count": settings.node_count,
        "accelerator": truss_config.AcceleratorSpec(
            accelerator=settings.accelerator,
            count=gpu_count,
        ),
    }
    if resources.cpu_count is not None:
        compute_kwargs["cpu_count"] = int(resources.cpu_count)
    if resources.memory is not None:
        compute_kwargs["memory"] = resources.memory

    # Cache / checkpointing are opt-in (default disabled); only attach the
    # config objects when enabled so the job keeps Baseten's defaults.
    runtime_kwargs: Dict[str, Any] = {
        "start_commands": [shell_join(entrypoint_command)],
        "environment_variables": self._build_environment(
            environment=environment,
            secrets=settings.secrets,
        ),
    }
    if settings.enable_cache:
        runtime_kwargs["cache_config"] = definitions.CacheConfig(
            enabled=True,
            enable_legacy_hf_mount=settings.cache_enable_legacy_hf_mount,
            require_cache_affinity=settings.cache_require_affinity,
        )
    if settings.enable_checkpointing:
        runtime_kwargs["checkpointing_config"] = (
            definitions.CheckpointingConfig(enabled=True)
        )

    project = definitions.TrainingProject(
        name=self.config.project,
        job=definitions.TrainingJob(
            image=definitions.Image(
                base_image=image_name, docker_auth=self._docker_auth()
            ),
            compute=definitions.Compute(**compute_kwargs),
            runtime=definitions.Runtime(**runtime_kwargs),
            # The ZenML image is the entire environment; do not extract an
            # uploaded working directory on top of it.
            enable_baseten_workdir=False,
        ),
    )

    # Push from an empty temp dir so the local working directory is not
    # uploaded: the Docker image is the single source of truth for code.
    self._configure_truss_remote()
    try:
        with tempfile.TemporaryDirectory() as tmpdir:
            result = push(
                config=project,
                source_dir=Path(tmpdir),
                remote=BASETEN_REMOTE_NAME,
            )
    except Exception as e:
        raise RuntimeError(
            _explain_submit_error(e, settings, is_command_step)
        ) from e

    job_id = result["id"]
    project_id = result["training_project_id"]
    metadata: Dict[str, Any] = {
        BASETEN_JOB_ID_METADATA_KEY: job_id,
        BASETEN_PROJECT_ID_METADATA_KEY: project_id,
        BASETEN_LOGS_URL_METADATA_KEY: Uri(
            f"{BASETEN_REMOTE_URL}/training/project/{project_id}"
            f"/logs/{job_id}"
        ),
    }

    try:
        publish_step_run_metadata(info.step_run_id, {self.id: metadata})
        info.step_run.run_metadata.update(metadata)
    except Exception:
        logger.error(
            "Failed to persist Baseten job ids for step `%s`. The job is "
            "running on Baseten (job_id=%s, project_id=%s) but status "
            "checks and cancellation will not work without these ids. Stop "
            "it from the Baseten dashboard if needed.",
            info.pipeline_step_name,
            job_id,
            project_id,
        )
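
The resource handling in `submit` distinguishes an unset GPU count from an explicit zero, and forwards CPU and memory only when they are set. A minimal sketch of that logic, using a hypothetical `Resources` dataclass in place of the step's resource settings and a flat `gpu_count` key in place of the truss accelerator spec:

```python
from dataclasses import dataclass
from typing import Any, Dict, Optional


@dataclass
class Resources:
    """Hypothetical stand-in for the step's resource settings."""

    cpu_count: Optional[float] = None
    gpu_count: Optional[int] = None
    memory: Optional[str] = None


def build_compute_kwargs(
    resources: Resources, node_count: int = 1
) -> Dict[str, Any]:
    """Collect only the compute options that were explicitly set."""
    # An explicit 0 is kept; only an unset value falls back to 1.
    gpu_count = resources.gpu_count if resources.gpu_count is not None else 1
    kwargs: Dict[str, Any] = {
        "node_count": node_count,
        "gpu_count": gpu_count,
    }
    # Omitted keys let the downstream library apply its own defaults.
    if resources.cpu_count is not None:
        kwargs["cpu_count"] = int(resources.cpu_count)
    if resources.memory is not None:
        kwargs["memory"] = resources.memory
    return kwargs
```

Testing against `None` rather than truthiness is what lets a CPU-only step request `gpu_count=0` without being bumped to one GPU.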
Modules
baseten_step_operator

Baseten step operator implementation.

Classes
BasetenStepOperator(*args: Any, **kwargs: Any)

Bases: BaseStepOperator

Step operator that runs a step as a Baseten training job.

Initialize the step operator.

Parameters:

Name Type Description Default
*args Any

Positional arguments forwarded to the base class.

()
**kwargs Any

Keyword arguments forwarded to the base class.

{}
Source code in src/zenml/integrations/baseten/step_operators/baseten_step_operator.py
def __init__(self, *args: Any, **kwargs: Any) -> None:
    """Initialize the step operator.

    Args:
        *args: Positional arguments forwarded to the base class.
        **kwargs: Keyword arguments forwarded to the base class.
    """
    super().__init__(*args, **kwargs)
    self._api: Optional["BasetenApiClient"] = None
Attributes
api: BasetenApiClient property

Lazily constructed Baseten REST API client.

Returns:

Type Description
BasetenApiClient

The Baseten REST API client.

config: BasetenStepOperatorConfig property

Get the Baseten step operator configuration.

Returns:

Type Description
BasetenStepOperatorConfig

The Baseten step operator configuration.

settings_class: Optional[Type[BaseSettings]] property

Get the settings class for the Baseten step operator.

Returns:

Type Description
Optional[Type[BaseSettings]]

The Baseten step operator settings class.

validator: Optional[StackValidator] property

Get the stack validator for the Baseten step operator.

Returns:

Type Description
Optional[StackValidator]

The stack validator.

Methods:
cancel(step_run: StepRunResponse) -> None

Cancel a submitted Baseten training job.

Parameters:

Name Type Description Default
step_run StepRunResponse

The step run.

required
Source code in src/zenml/integrations/baseten/step_operators/baseten_step_operator.py
def cancel(self, step_run: "StepRunResponse") -> None:
    """Cancel a submitted Baseten training job.

    Args:
        step_run: The step run.
    """
    ids = self._extract_job_and_project_id(step_run)
    if ids is None:
        return
    project_id, job_id = ids
    self.api.stop_job(project_id, job_id)
get_docker_builds(snapshot: PipelineSnapshotBase) -> List[BuildConfiguration]

Get the Docker build configurations for the Baseten step operator.

Parameters:

Name Type Description Default
snapshot PipelineSnapshotBase

The pipeline snapshot.

required

Returns:

Type Description
List[BuildConfiguration]

A list of Docker build configurations.

Source code in src/zenml/integrations/baseten/step_operators/baseten_step_operator.py
def get_docker_builds(
    self, snapshot: "PipelineSnapshotBase"
) -> List["BuildConfiguration"]:
    """Get the Docker build configurations for the Baseten step operator.

    Args:
        snapshot: The pipeline snapshot.

    Returns:
        A list of Docker build configurations.
    """
    builds = []
    for step_name, step in snapshot.step_configurations.items():
        if step.config.uses_step_operator(self.name):
            builds.append(
                BuildConfiguration(
                    key=BASETEN_STEP_OPERATOR_DOCKER_IMAGE_KEY,
                    settings=step.config.docker_settings,
                    step_name=step_name,
                )
            )

    return builds
get_status(step_run: StepRunResponse) -> ExecutionStatus

Get the status of a submitted Baseten training job.

Parameters:

Name Type Description Default
step_run StepRunResponse

The step run.

required

Returns:

Type Description
ExecutionStatus

The execution status. Returns FAILED if the job ids are missing or the job no longer exists.

Source code in src/zenml/integrations/baseten/step_operators/baseten_step_operator.py
def get_status(self, step_run: "StepRunResponse") -> ExecutionStatus:
    """Get the status of a submitted Baseten training job.

    Args:
        step_run: The step run.

    Returns:
        The execution status. Returns FAILED if the job ids are missing or
        the job no longer exists.
    """
    ids = self._extract_job_and_project_id(step_run)
    if ids is None:
        return ExecutionStatus.FAILED

    project_id, job_id = ids
    state = self.api.get_job_status(project_id, job_id)
    if state is None:
        return ExecutionStatus.FAILED

    return _BASETEN_STATE_TO_EXECUTION_STATUS.get(
        state, ExecutionStatus.RUNNING
    )
submit(info: StepRunInfo, entrypoint_command: List[str], environment: Dict[str, str]) -> None

Submit a step run as a Baseten training job.

Parameters:

Name Type Description Default
info StepRunInfo

The step run information.

required
entrypoint_command List[str]

The entrypoint command for the step.

required
environment Dict[str, str]

The environment variables for the step.

required

Raises:

Type Description
RuntimeError

If multi-node execution is requested for a regular step.

Source code in src/zenml/integrations/baseten/step_operators/baseten_step_operator.py
def submit(
    self,
    info: "StepRunInfo",
    entrypoint_command: List[str],
    environment: Dict[str, str],
) -> None:
    """Submit a step run as a Baseten training job.

    Args:
        info: The step run information.
        entrypoint_command: The entrypoint command for the step.
        environment: The environment variables for the step.

    Raises:
        RuntimeError: If multi-node execution is requested for a regular
            step.
    """
    settings = cast(BasetenStepOperatorSettings, self.get_settings(info))
    is_command_step = info.config.command is not None

    # A multi-node job runs the same entrypoint on every node. A regular
    # step would therefore duplicate its artifacts, outputs and logs across
    # nodes, so only command steps (which own their distributed launch) may
    # scale out. This runs at submit time because dynamic pipelines execute
    # from a server-side snapshot where no compile-time hook is available.
    if settings.node_count > 1 and not is_command_step:
        raise RuntimeError(
            f"The step `{info.pipeline_step_name}` requests "
            f"node_count={settings.node_count} but is a regular step. "
            "Running a regular step on multiple nodes would duplicate its "
            "artifacts, outputs and logs across every node. Use a "
            "CommandStep that owns its distributed launch instead. Wrap "
            "the launcher in a shell so Baseten's per-node variables "
            "(BT_GROUP_SIZE, BT_NODE_RANK, BT_LEADER_ADDR, BT_NUM_GPUS) "
            "expand on each node, e.g. "
            '`CommandStep(command=["bash", "-lc", "torchrun '
            "--nnodes=$BT_GROUP_SIZE --node-rank=$BT_NODE_RANK "
            "--master-addr=$BT_LEADER_ADDR --master-port=29500 "
            '--nproc-per-node=$BT_NUM_GPUS train.py"], '
            f'step_operator="{self.name}")`. Keep node_count=1 for '
            "regular steps; see the Baseten step-operator docs "
            "(multi-node distributed training) for the full pattern."
        )

    image_name = info.get_image(key=BASETEN_STEP_OPERATOR_DOCKER_IMAGE_KEY)

    # Only pass cpu_count/memory when set, since truss
    # `Compute` types them as int/str and applies its own defaults when
    # omitted. gpu_count defaults to 1 only when unset (0 stays 0).
    resources = info.config.resource_settings
    gpu_count = (
        resources.gpu_count if resources.gpu_count is not None else 1
    )
    compute_kwargs: Dict[str, Any] = {
        "node_count": settings.node_count,
        "accelerator": truss_config.AcceleratorSpec(
            accelerator=settings.accelerator,
            count=gpu_count,
        ),
    }
    if resources.cpu_count is not None:
        compute_kwargs["cpu_count"] = int(resources.cpu_count)
    if resources.memory is not None:
        compute_kwargs["memory"] = resources.memory

    # Cache / checkpointing are opt-in (default disabled); only attach the
    # config objects when enabled so the job keeps Baseten's defaults.
    runtime_kwargs: Dict[str, Any] = {
        "start_commands": [shell_join(entrypoint_command)],
        "environment_variables": self._build_environment(
            environment=environment,
            secrets=settings.secrets,
        ),
    }
    if settings.enable_cache:
        runtime_kwargs["cache_config"] = definitions.CacheConfig(
            enabled=True,
            enable_legacy_hf_mount=settings.cache_enable_legacy_hf_mount,
            require_cache_affinity=settings.cache_require_affinity,
        )
    if settings.enable_checkpointing:
        runtime_kwargs["checkpointing_config"] = (
            definitions.CheckpointingConfig(enabled=True)
        )

    project = definitions.TrainingProject(
        name=self.config.project,
        job=definitions.TrainingJob(
            image=definitions.Image(
                base_image=image_name, docker_auth=self._docker_auth()
            ),
            compute=definitions.Compute(**compute_kwargs),
            runtime=definitions.Runtime(**runtime_kwargs),
            # The ZenML image is the entire environment; do not extract an
            # uploaded working directory on top of it.
            enable_baseten_workdir=False,
        ),
    )

    # Push from an empty temp dir so the local working directory is not
    # uploaded: the Docker image is the single source of truth for code.
    self._configure_truss_remote()
    try:
        with tempfile.TemporaryDirectory() as tmpdir:
            result = push(
                config=project,
                source_dir=Path(tmpdir),
                remote=BASETEN_REMOTE_NAME,
            )
    except Exception as e:
        raise RuntimeError(
            _explain_submit_error(e, settings, is_command_step)
        ) from e

    job_id = result["id"]
    project_id = result["training_project_id"]
    metadata: Dict[str, Any] = {
        BASETEN_JOB_ID_METADATA_KEY: job_id,
        BASETEN_PROJECT_ID_METADATA_KEY: project_id,
        BASETEN_LOGS_URL_METADATA_KEY: Uri(
            f"{BASETEN_REMOTE_URL}/training/project/{project_id}"
            f"/logs/{job_id}"
        ),
    }

    try:
        publish_step_run_metadata(info.step_run_id, {self.id: metadata})
        info.step_run.run_metadata.update(metadata)
    except Exception:
        logger.error(
            "Failed to persist Baseten job ids for step `%s`. The job is "
            "running on Baseten (job_id=%s, project_id=%s) but status "
            "checks and cancellation will not work without these ids. Stop "
            "it from the Baseten dashboard if needed.",
            info.pipeline_step_name,
            job_id,
            project_id,
        )
Functions: