Gcp
zenml.integrations.gcp
special
Initialization of the GCP ZenML integration.
The GCP integration submodule provides a way to run ZenML pipelines in a cloud
environment. Specifically, it allows the use of cloud artifact stores
and provides an `io` module to handle file operations on Google Cloud Storage
(GCS).
Additionally, the GCP Secrets Manager integration submodule provides a way to access GCP Secret Manager from within your ZenML pipeline runs.
The Vertex AI integration submodule provides a way to run ZenML pipelines in a Vertex AI environment.
GcpIntegration (Integration)
Definition of Google Cloud Platform integration for ZenML.
Source code in zenml/integrations/gcp/__init__.py
class GcpIntegration(Integration):
    """Definition of Google Cloud Platform integration for ZenML.

    Declares the pip requirements of the GCP integration and the stack
    component flavors it contributes to ZenML.
    """

    NAME = GCP
    REQUIREMENTS = [
        "kfp==1.8.22",  # Only 1.x version that supports pyyaml 6
        "gcsfs",
        "google-cloud-secret-manager",
        "google-cloud-container>=2.21.0",
        "google-cloud-storage>=2.9.0",
        "google-cloud-aiplatform>=1.21.0",  # includes shapely pin fix
        "google-cloud-scheduler>=2.7.3",
        "google-cloud-functions>=1.8.3",
        "google-cloud-build>=3.11.0",
        "kubernetes",
    ]

    @staticmethod
    def activate() -> None:
        """Activate the GCP integration.

        The import is kept only for its side effects. The imported name is
        never used, which is why the line carries a `noqa` marker.
        """
        from zenml.integrations.gcp import (  # noqa
            service_connectors as _service_connectors,
        )

    @classmethod
    def flavors(cls) -> List[Type[Flavor]]:
        """Declare the stack component flavors for the GCP integration.

        The flavor classes are imported inside the method rather than at
        module level.

        Returns:
            List of stack component flavors for this integration.
        """
        from zenml.integrations.gcp.flavors import (
            GCPArtifactStoreFlavor,
            GCPImageBuilderFlavor,
            GCPSecretsManagerFlavor,
            VertexOrchestratorFlavor,
            VertexStepOperatorFlavor,
        )

        gcp_flavors: List[Type[Flavor]] = [
            GCPArtifactStoreFlavor,
            GCPImageBuilderFlavor,
            GCPSecretsManagerFlavor,
            VertexOrchestratorFlavor,
            VertexStepOperatorFlavor,
        ]
        return gcp_flavors
activate()
staticmethod
Activate the GCP integration.
Source code in zenml/integrations/gcp/__init__.py
@staticmethod
def activate() -> None:
    """Activate the GCP integration.

    The import is kept only for its side effects. The imported name is
    never used, which is why the line carries a `noqa` marker.
    """
    from zenml.integrations.gcp import (  # noqa
        service_connectors as _service_connectors,
    )
flavors()
classmethod
Declare the stack component flavors for the GCP integration.
Returns:
Type | Description |
---|---|
List[Type[zenml.stack.flavor.Flavor]] |
List of stack component flavors for this integration. |
Source code in zenml/integrations/gcp/__init__.py
@classmethod
def flavors(cls) -> List[Type[Flavor]]:
    """Declare the stack component flavors for the GCP integration.

    The flavor classes are imported inside the method rather than at
    module level.

    Returns:
        List of stack component flavors for this integration.
    """
    from zenml.integrations.gcp.flavors import (
        GCPArtifactStoreFlavor,
        GCPImageBuilderFlavor,
        GCPSecretsManagerFlavor,
        VertexOrchestratorFlavor,
        VertexStepOperatorFlavor,
    )

    gcp_flavors: List[Type[Flavor]] = [
        GCPArtifactStoreFlavor,
        GCPImageBuilderFlavor,
        GCPSecretsManagerFlavor,
        VertexOrchestratorFlavor,
        VertexStepOperatorFlavor,
    ]
    return gcp_flavors
artifact_stores
special
Initialization of the GCP Artifact Store.
gcp_artifact_store
Implementation of the GCP Artifact Store.
GCPArtifactStore (BaseArtifactStore, AuthenticationMixin)
Artifact Store for Google Cloud Storage based artifacts.
Source code in zenml/integrations/gcp/artifact_stores/gcp_artifact_store.py
class GCPArtifactStore(BaseArtifactStore, AuthenticationMixin):
    """Artifact Store for Google Cloud Storage based artifacts.

    Every file operation is delegated to a `gcsfs.GCSFileSystem`. The
    filesystem is created on first use and authenticated either through a
    linked service connector or through an authentication secret.
    """

    _filesystem: Optional[gcsfs.GCSFileSystem] = None

    @property
    def config(self) -> GCPArtifactStoreConfig:
        """Returns the `GCPArtifactStoreConfig` config.

        Returns:
            The configuration.
        """
        return cast(GCPArtifactStoreConfig, self._config)

    def get_credentials(
        self,
    ) -> Optional[Union[Dict[str, Any], gcp_credentials.Credentials]]:
        """Returns the credentials for the GCP Artifact Store if configured.

        A linked service connector takes precedence. Otherwise the
        credentials are read from a `GCPSecretSchema` authentication secret.

        Returns:
            The credentials, or `None` if neither a connector nor a secret
            is configured.

        Raises:
            RuntimeError: If the linked connector returns the wrong type of
                client.
        """
        connector = self.get_connector()
        if connector:
            client = connector.connect()
            if not isinstance(client, storage.Client):
                raise RuntimeError(
                    f"Expected a google.cloud.storage.Client while trying to "
                    f"use the linked connector, but got {type(client)}."
                )
            # The storage client does not expose its credentials publicly.
            return client._credentials

        secret = self.get_authentication_secret(
            expected_schema_type=GCPSecretSchema
        )
        return secret.get_credential_dict() if secret else None

    @property
    def filesystem(self) -> gcsfs.GCSFileSystem:
        """The gcsfs filesystem to access this artifact store.

        The filesystem is cached. It is rebuilt with fresh credentials when
        the linked connector has expired.

        Returns:
            The gcsfs filesystem to access this artifact store.
        """
        # Refresh the credentials also if the connector has expired
        if self._filesystem and not self.connector_has_expired():
            return self._filesystem

        token = self.get_credentials()
        self._filesystem = gcsfs.GCSFileSystem(token=token)
        return self._filesystem

    def open(self, path: PathType, mode: str = "r") -> Any:
        """Open a file at the given path.

        Args:
            path: Path of the file to open.
            mode: Mode in which to open the file, e.g. 'rb' / 'wb' to read
                and write binary files. The value is forwarded to gcsfs
                unchanged. NOTE(review): which text modes work depends on
                gcsfs; confirm before relying on the 'r' default.

        Returns:
            A file-like object that can be used to read or write to the file.
        """
        return self.filesystem.open(path=path, mode=mode)

    def copyfile(
        self, src: PathType, dst: PathType, overwrite: bool = False
    ) -> None:
        """Copy a file.

        Args:
            src: The path to copy from.
            dst: The path to copy to.
            overwrite: If a file already exists at the destination, this
                method will overwrite it if overwrite=`True` and
                raise a FileExistsError otherwise.

        Raises:
            FileExistsError: If a file already exists at the destination
                and overwrite is not set to `True`.
        """
        if not overwrite and self.filesystem.exists(dst):
            raise FileExistsError(
                f"Unable to copy to destination '{convert_to_str(dst)}', "
                f"file already exists. Set `overwrite=True` to copy anyway."
            )

        # TODO [ENG-151]: Check if it works with overwrite=True or if we need to
        #  manually remove it first
        self.filesystem.copy(path1=src, path2=dst)

    def exists(self, path: PathType) -> bool:
        """Check whether a path exists.

        Args:
            path: The path to check.

        Returns:
            True if the path exists, False otherwise.
        """
        return self.filesystem.exists(path=path)  # type: ignore[no-any-return]

    def glob(self, pattern: PathType) -> List[PathType]:
        """Return all paths that match the given glob pattern.

        The glob pattern may include:
        - '*' to match any number of characters
        - '?' to match a single character
        - '[...]' to match one of the characters inside the brackets
        - '**' as the full name of a path component to match to search
          in subdirectories of any depth (e.g. '/some_dir/**/some_file)

        Args:
            pattern: The glob pattern to match, see details above.

        Returns:
            A list of paths that match the given glob pattern, each with the
            GCS path prefix prepended.
        """
        return [
            f"{GCP_PATH_PREFIX}{path}"
            for path in self.filesystem.glob(path=pattern)
        ]

    def isdir(self, path: PathType) -> bool:
        """Check whether a path is a directory.

        Args:
            path: The path to check.

        Returns:
            True if the path is a directory, False otherwise.
        """
        return self.filesystem.isdir(path=path)  # type: ignore[no-any-return]

    def listdir(self, path: PathType) -> List[PathType]:
        """Return a list of files in a directory.

        Args:
            path: The path of the directory to list.

        Returns:
            A list of the names of the entries in the directory, relative to
            `path`.
        """
        path_without_prefix = convert_to_str(path)
        if path_without_prefix.startswith(GCP_PATH_PREFIX):
            path_without_prefix = path_without_prefix[len(GCP_PATH_PREFIX) :]

        def _extract_basename(file_dict: Dict[str, Any]) -> str:
            """Extracts the basename from a file info dict returned by GCP.

            Args:
                file_dict: A file info dict returned by the GCP filesystem.

            Returns:
                The basename of the file.
            """
            file_path = cast(str, file_dict["name"])
            base_name = file_path[len(path_without_prefix) :]
            return base_name.lstrip("/")

        # Compute every basename exactly once.
        basenames = (
            _extract_basename(file_info)
            for file_info in self.filesystem.listdir(path=path)
        )
        # gcsfs.listdir also returns the root directory, which maps to an
        # empty basename, so we filter it out here
        return [name for name in basenames if name]

    def makedirs(self, path: PathType) -> None:
        """Create a directory at the given path.

        If needed also create missing parent directories.

        Args:
            path: The path of the directory to create.
        """
        self.filesystem.makedirs(path=path, exist_ok=True)

    def mkdir(self, path: PathType) -> None:
        """Create a directory at the given path.

        Args:
            path: The path of the directory to create.
        """
        self.filesystem.makedir(path=path)

    def remove(self, path: PathType) -> None:
        """Remove the file at the given path.

        Args:
            path: The path of the file to remove.
        """
        self.filesystem.rm_file(path=path)

    def rename(
        self, src: PathType, dst: PathType, overwrite: bool = False
    ) -> None:
        """Rename source file to destination file.

        Args:
            src: The path of the file to rename.
            dst: The path to rename the source file to.
            overwrite: If a file already exists at the destination, this
                method will overwrite it if overwrite=`True` and
                raise a FileExistsError otherwise.

        Raises:
            FileExistsError: If a file already exists at the destination
                and overwrite is not set to `True`.
        """
        if not overwrite and self.filesystem.exists(dst):
            raise FileExistsError(
                f"Unable to rename file to '{convert_to_str(dst)}', "
                f"file already exists. Set `overwrite=True` to rename anyway."
            )

        # TODO [ENG-152]: Check if it works with overwrite=True or if we need
        #  to manually remove it first
        self.filesystem.rename(path1=src, path2=dst)

    def rmtree(self, path: PathType) -> None:
        """Remove the given directory.

        Args:
            path: The path of the directory to remove.
        """
        self.filesystem.delete(path=path, recursive=True)

    def stat(self, path: PathType) -> Dict[str, Any]:
        """Return stat info for the given path.

        Args:
            path: the path to get stat info for.

        Returns:
            A dictionary with the stat info.
        """
        return self.filesystem.stat(path=path)  # type: ignore[no-any-return]

    def size(self, path: PathType) -> int:
        """Get the size of a file in bytes.

        Args:
            path: The path to the file.

        Returns:
            The size of the file in bytes.
        """
        return self.filesystem.size(path=path)  # type: ignore[no-any-return]

    def walk(
        self,
        top: PathType,
        topdown: bool = True,
        onerror: Optional[Callable[..., None]] = None,
    ) -> Iterable[Tuple[PathType, List[PathType], List[PathType]]]:
        """Return an iterator that walks the contents of the given directory.

        Args:
            top: Path of directory to walk.
            topdown: Unused argument to conform to interface.
            onerror: Unused argument to conform to interface.

        Yields:
            An Iterable of Tuples, each of which contain the path of the current
            directory path, a list of directories inside the current directory
            and a list of files inside the current directory.
        """
        # TODO [ENG-153]: Additional params
        for (
            directory,
            subdirectories,
            files,
        ) in self.filesystem.walk(path=top):
            yield f"{GCP_PATH_PREFIX}{directory}", subdirectories, files
config: GCPArtifactStoreConfig
property
readonly
Returns the `GCPArtifactStoreConfig` config.
Returns:
Type | Description |
---|---|
GCPArtifactStoreConfig |
The configuration. |
filesystem: GCSFileSystem
property
readonly
The gcsfs filesystem to access this artifact store.
Returns:
Type | Description |
---|---|
GCSFileSystem |
The gcsfs filesystem to access this artifact store. |
copyfile(self, src, dst, overwrite=False)
Copy a file.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
src |
Union[bytes, str] |
The path to copy from. |
required |
dst |
Union[bytes, str] |
The path to copy to. |
required |
overwrite |
bool |
If a file already exists at the destination, this
method will overwrite it if `overwrite=True` and raise a `FileExistsError` otherwise. |
False |
Exceptions:
Type | Description |
---|---|
FileExistsError |
If a file already exists at the destination
and `overwrite` is not set to `True`. |
Source code in zenml/integrations/gcp/artifact_stores/gcp_artifact_store.py
def copyfile(
    self, src: PathType, dst: PathType, overwrite: bool = False
) -> None:
    """Copy the file at `src` to `dst`.

    Args:
        src: Source path of the file to copy.
        dst: Destination path of the copy.
        overwrite: Whether an existing file at `dst` may be replaced. When
            `False` and `dst` already exists, the copy is refused.

    Raises:
        FileExistsError: If `dst` already exists and `overwrite` is not
            `True`.
    """
    # The existence check is skipped entirely when overwriting is allowed.
    if overwrite or not self.filesystem.exists(dst):
        # TODO [ENG-151]: Check if it works with overwrite=True or if we need to
        #  manually remove it first
        self.filesystem.copy(path1=src, path2=dst)
        return

    raise FileExistsError(
        f"Unable to copy to destination '{convert_to_str(dst)}', "
        f"file already exists. Set `overwrite=True` to copy anyway."
    )
exists(self, path)
Check whether a path exists.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
path |
Union[bytes, str] |
The path to check. |
required |
Returns:
Type | Description |
---|---|
bool |
True if the path exists, False otherwise. |
Source code in zenml/integrations/gcp/artifact_stores/gcp_artifact_store.py
def exists(self, path: PathType) -> bool:
    """Report whether `path` exists in the artifact store.

    Args:
        path: Location to look up.

    Returns:
        `True` when something exists at `path`, `False` otherwise.
    """
    gcs = self.filesystem
    return gcs.exists(path=path)  # type: ignore[no-any-return]
get_credentials(self)
Returns the credentials for the GCP Artifact Store if configured.
Returns:
Type | Description |
---|---|
Union[Dict[str, Any], google.oauth2.credentials.Credentials] |
The credentials. |
Exceptions:
Type | Description |
---|---|
RuntimeError |
If the linked connector returns the wrong type of client. |
Source code in zenml/integrations/gcp/artifact_stores/gcp_artifact_store.py
def get_credentials(
    self,
) -> Optional[Union[Dict[str, Any], gcp_credentials.Credentials]]:
    """Resolve the credentials used to talk to Google Cloud Storage.

    A linked service connector takes precedence. Without one, the
    credentials come from a `GCPSecretSchema` authentication secret, if
    configured.

    Returns:
        The connector client's credentials, the secret's credential dict,
        or `None` when neither source is available.

    Raises:
        RuntimeError: If the linked connector returns the wrong type of
            client.
    """
    connector = self.get_connector()
    if not connector:
        secret = self.get_authentication_secret(
            expected_schema_type=GCPSecretSchema
        )
        if secret:
            return secret.get_credential_dict()
        return None

    gcs_client = connector.connect()
    if isinstance(gcs_client, storage.Client):
        # The storage client does not expose its credentials publicly.
        return gcs_client._credentials

    raise RuntimeError(
        f"Expected a google.cloud.storage.Client while trying to "
        f"use the linked connector, but got {type(gcs_client)}."
    )
glob(self, pattern)
Return all paths that match the given glob pattern.
The glob pattern may include:
- `*` to match any number of characters
- `?` to match a single character
- `[...]` to match one of the characters inside the brackets
- `**` as the full name of a path component to match to search in subdirectories of any depth (e.g. `/some_dir/**/some_file`)
Parameters:
Name | Type | Description | Default |
---|---|---|---|
pattern |
Union[bytes, str] |
The glob pattern to match, see details above. |
required |
Returns:
Type | Description |
---|---|
List[Union[bytes, str]] |
A list of paths that match the given glob pattern. |
Source code in zenml/integrations/gcp/artifact_stores/gcp_artifact_store.py
def glob(self, pattern: PathType) -> List[PathType]:
    """List every path matching a glob pattern.

    Supported wildcards:
    - '*' matches any run of characters
    - '?' matches exactly one character
    - '[...]' matches any single character listed in the brackets
    - '**' used as a whole path component matches subdirectories at any
      depth (e.g. '/some_dir/**/some_file)

    Args:
        pattern: Glob pattern as described above.

    Returns:
        The matching paths, each with the GCS path prefix prepended.
    """
    matched_paths = self.filesystem.glob(path=pattern)
    return [f"{GCP_PATH_PREFIX}{matched}" for matched in matched_paths]
isdir(self, path)
Check whether a path is a directory.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
path |
Union[bytes, str] |
The path to check. |
required |
Returns:
Type | Description |
---|---|
bool |
True if the path is a directory, False otherwise. |
Source code in zenml/integrations/gcp/artifact_stores/gcp_artifact_store.py
def isdir(self, path: PathType) -> bool:
    """Report whether `path` refers to a directory.

    Args:
        path: Location to inspect.

    Returns:
        `True` for a directory, `False` otherwise.
    """
    gcs = self.filesystem
    return gcs.isdir(path=path)  # type: ignore[no-any-return]
listdir(self, path)
Return a list of files in a directory.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
path |
Union[bytes, str] |
The path of the directory to list. |
required |
Returns:
Type | Description |
---|---|
List[Union[bytes, str]] |
A list of paths of files in the directory. |
Source code in zenml/integrations/gcp/artifact_stores/gcp_artifact_store.py
def listdir(self, path: PathType) -> List[PathType]:
    """Return a list of files in a directory.

    Args:
        path: The path of the directory to list.

    Returns:
        A list of the names of the entries in the directory, relative to
        `path`.
    """
    path_without_prefix = convert_to_str(path)
    if path_without_prefix.startswith(GCP_PATH_PREFIX):
        path_without_prefix = path_without_prefix[len(GCP_PATH_PREFIX) :]

    def _extract_basename(file_dict: Dict[str, Any]) -> str:
        """Extracts the basename from a file info dict returned by GCP.

        Args:
            file_dict: A file info dict returned by the GCP filesystem.

        Returns:
            The basename of the file.
        """
        file_path = cast(str, file_dict["name"])
        base_name = file_path[len(path_without_prefix) :]
        return base_name.lstrip("/")

    # Compute every basename exactly once (previously it was computed twice:
    # once for the filter and once for the result).
    basenames = (
        _extract_basename(file_info)
        for file_info in self.filesystem.listdir(path=path)
    )
    # gcsfs.listdir also returns the root directory, which maps to an empty
    # basename, so we filter it out here
    return [name for name in basenames if name]
makedirs(self, path)
Create a directory at the given path.
If needed also create missing parent directories.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
path |
Union[bytes, str] |
The path of the directory to create. |
required |
Source code in zenml/integrations/gcp/artifact_stores/gcp_artifact_store.py
def makedirs(self, path: PathType) -> None:
    """Create the directory `path`, including any missing parents.

    An already existing directory is not an error.

    Args:
        path: Directory to create.
    """
    gcs = self.filesystem
    gcs.makedirs(path=path, exist_ok=True)
mkdir(self, path)
Create a directory at the given path.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
path |
Union[bytes, str] |
The path of the directory to create. |
required |
Source code in zenml/integrations/gcp/artifact_stores/gcp_artifact_store.py
def mkdir(self, path: PathType) -> None:
    """Create the directory `path`.

    Args:
        path: Directory to create.
    """
    gcs = self.filesystem
    gcs.makedir(path=path)
open(self, path, mode='r')
Open a file at the given path.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
path |
Union[bytes, str] |
Path of the file to open. |
required |
mode |
str |
Mode in which to open the file. Currently, only 'rb' and 'wb' to read and write binary files are supported. |
'r' |
Returns:
Type | Description |
---|---|
Any |
A file-like object that can be used to read or write to the file. |
Source code in zenml/integrations/gcp/artifact_stores/gcp_artifact_store.py
def open(self, path: PathType, mode: str = "r") -> Any:
    """Open the file at `path` and return a file-like handle.

    Args:
        path: Location of the file to open.
        mode: Open mode, e.g. 'rb' / 'wb' for binary reading and writing.
            It is handed to gcsfs as-is. NOTE(review): which text modes
            work depends on gcsfs; confirm before relying on the 'r' default.

    Returns:
        A file-like object for reading from or writing to the file.
    """
    gcs = self.filesystem
    return gcs.open(path=path, mode=mode)
remove(self, path)
Remove the file at the given path.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
path |
Union[bytes, str] |
The path of the file to remove. |
required |
Source code in zenml/integrations/gcp/artifact_stores/gcp_artifact_store.py
def remove(self, path: PathType) -> None:
    """Delete the single file located at `path`.

    Args:
        path: File to delete.
    """
    gcs = self.filesystem
    gcs.rm_file(path=path)
rename(self, src, dst, overwrite=False)
Rename source file to destination file.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
src |
Union[bytes, str] |
The path of the file to rename. |
required |
dst |
Union[bytes, str] |
The path to rename the source file to. |
required |
overwrite |
bool |
If a file already exists at the destination, this
method will overwrite it if `overwrite=True` and raise a `FileExistsError` otherwise. |
False |
Exceptions:
Type | Description |
---|---|
FileExistsError |
If a file already exists at the destination
and `overwrite` is not set to `True`. |
Source code in zenml/integrations/gcp/artifact_stores/gcp_artifact_store.py
def rename(
    self, src: PathType, dst: PathType, overwrite: bool = False
) -> None:
    """Move the file at `src` to `dst`.

    Args:
        src: Current path of the file.
        dst: New path for the file.
        overwrite: Whether an existing file at `dst` may be replaced. When
            `False` and `dst` already exists, the rename is refused.

    Raises:
        FileExistsError: If `dst` already exists and `overwrite` is not
            `True`.
    """
    # The existence check is skipped entirely when overwriting is allowed.
    if overwrite or not self.filesystem.exists(dst):
        # TODO [ENG-152]: Check if it works with overwrite=True or if we need
        #  to manually remove it first
        self.filesystem.rename(path1=src, path2=dst)
        return

    raise FileExistsError(
        f"Unable to rename file to '{convert_to_str(dst)}', "
        f"file already exists. Set `overwrite=True` to rename anyway."
    )
rmtree(self, path)
Remove the given directory.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
path |
Union[bytes, str] |
The path of the directory to remove. |
required |
Source code in zenml/integrations/gcp/artifact_stores/gcp_artifact_store.py
def rmtree(self, path: PathType) -> None:
    """Recursively delete the directory `path` and its contents.

    Args:
        path: Directory to delete.
    """
    gcs = self.filesystem
    gcs.delete(path=path, recursive=True)
size(self, path)
Get the size of a file in bytes.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
path |
Union[bytes, str] |
The path to the file. |
required |
Returns:
Type | Description |
---|---|
int |
The size of the file in bytes. |
Source code in zenml/integrations/gcp/artifact_stores/gcp_artifact_store.py
def size(self, path: PathType) -> int:
    """Return the size of the file at `path`, in bytes.

    Args:
        path: File to measure.

    Returns:
        Number of bytes in the file.
    """
    gcs = self.filesystem
    return gcs.size(path=path)  # type: ignore[no-any-return]
stat(self, path)
Return stat info for the given path.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
path |
Union[bytes, str] |
the path to get stat info for. |
required |
Returns:
Type | Description |
---|---|
Dict[str, Any] |
A dictionary with the stat info. |
Source code in zenml/integrations/gcp/artifact_stores/gcp_artifact_store.py
def stat(self, path: PathType) -> Dict[str, Any]:
    """Fetch the metadata that gcsfs reports for `path`.

    Args:
        path: Location to inspect.

    Returns:
        The stat information as a dictionary.
    """
    gcs = self.filesystem
    return gcs.stat(path=path)  # type: ignore[no-any-return]
walk(self, top, topdown=True, onerror=None)
Return an iterator that walks the contents of the given directory.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
top |
Union[bytes, str] |
Path of directory to walk. |
required |
topdown |
bool |
Unused argument to conform to interface. |
True |
onerror |
Optional[Callable[..., NoneType]] |
Unused argument to conform to interface. |
None |
Yields:
Type | Description |
---|---|
Iterable[Tuple[Union[bytes, str], List[Union[bytes, str]], List[Union[bytes, str]]]] |
An Iterable of Tuples, each of which contain the path of the current directory path, a list of directories inside the current directory and a list of files inside the current directory. |
Source code in zenml/integrations/gcp/artifact_stores/gcp_artifact_store.py
def walk(
    self,
    top: PathType,
    topdown: bool = True,
    onerror: Optional[Callable[..., None]] = None,
) -> Iterable[Tuple[PathType, List[PathType], List[PathType]]]:
    """Iterate over the directory tree rooted at `top`.

    Args:
        top: Root directory of the walk.
        topdown: Accepted for interface compatibility only; ignored.
        onerror: Accepted for interface compatibility only; ignored.

    Yields:
        One `(directory, subdirectories, files)` tuple per visited
        directory. The directory path carries the GCS path prefix; the
        subdirectory and file lists are passed through from gcsfs.
    """
    # TODO [ENG-153]: Additional params
    yield from (
        (f"{GCP_PATH_PREFIX}{current_dir}", child_dirs, child_files)
        for current_dir, child_dirs, child_files in self.filesystem.walk(
            path=top
        )
    )
constants
Constants for the VertexAI integration.
flavors
special
GCP integration flavors.
gcp_artifact_store_flavor
GCP artifact store flavor.
GCPArtifactStoreConfig (BaseArtifactStoreConfig, AuthenticationConfigMixin)
pydantic-model
Configuration for GCP Artifact Store.
Source code in zenml/integrations/gcp/flavors/gcp_artifact_store_flavor.py
class GCPArtifactStoreConfig(
    BaseArtifactStoreConfig,
    AuthenticationConfigMixin,
):
    """Configuration for GCP Artifact Store.

    Declares the GCS path prefix as the only supported scheme. The check
    against `SUPPORTED_SCHEMES` presumably lives in
    `BaseArtifactStoreConfig` (not visible here; confirm).
    """

    # Annotated as ClassVar so it is a class-level constant, not a config
    # field.
    SUPPORTED_SCHEMES: ClassVar[Set[str]] = {GCP_PATH_PREFIX}
GCPArtifactStoreFlavor (BaseArtifactStoreFlavor)
Flavor of the GCP artifact store.
Source code in zenml/integrations/gcp/flavors/gcp_artifact_store_flavor.py
class GCPArtifactStoreFlavor(BaseArtifactStoreFlavor):
    """Flavor of the GCP artifact store.

    Bundles the flavor name, config and implementation classes, service
    connector requirements, and documentation/logo URLs for the Google
    Cloud Storage artifact store.
    """

    @property
    def name(self) -> str:
        """Identifier of this flavor.

        Returns:
            The name of the flavor.
        """
        return GCP_ARTIFACT_STORE_FLAVOR

    @property
    def service_connector_requirements(
        self,
    ) -> Optional[ServiceConnectorRequirements]:
        """Requirements a service connector must satisfy for this flavor.

        These requirements are used to filter the available service
        connector types down to those compatible with this flavor.

        Returns:
            Requirements for compatible service connectors, if a service
            connector is required for this flavor.
        """
        # NOTE(review): `resource_id_attr="path"` presumably names the config
        # attribute that identifies the bucket; confirm against
        # ServiceConnectorRequirements.
        requirements = ServiceConnectorRequirements(
            resource_type="gcs-bucket",
            resource_id_attr="path",
        )
        return requirements

    @property
    def docs_url(self) -> Optional[str]:
        """Link to the documentation for this flavor.

        Returns:
            A flavor docs url.
        """
        default_docs_url = self.generate_default_docs_url()
        return default_docs_url

    @property
    def sdk_docs_url(self) -> Optional[str]:
        """Link to the SDK documentation for this flavor.

        Returns:
            A flavor SDK docs url.
        """
        default_sdk_docs_url = self.generate_default_sdk_docs_url()
        return default_sdk_docs_url

    @property
    def logo_url(self) -> str:
        """Image URL used to represent the flavor in the dashboard.

        Returns:
            The flavor logo.
        """
        logo = "https://public-flavor-logos.s3.eu-central-1.amazonaws.com/artifact_store/gcp.png"
        return logo

    @property
    def config_class(self) -> Type[GCPArtifactStoreConfig]:
        """Config class used by this flavor.

        Returns:
            The config class.
        """
        return GCPArtifactStoreConfig

    @property
    def implementation_class(self) -> Type["GCPArtifactStore"]:
        """Artifact store class that implements this flavor.

        The class is imported inside the property rather than at module
        level.

        Returns:
            The implementation class.
        """
        from zenml.integrations.gcp.artifact_stores import GCPArtifactStore

        return GCPArtifactStore
config_class: Type[zenml.integrations.gcp.flavors.gcp_artifact_store_flavor.GCPArtifactStoreConfig]
property
readonly
Returns GCPArtifactStoreConfig config class.
Returns:
Type | Description |
---|---|
Type[zenml.integrations.gcp.flavors.gcp_artifact_store_flavor.GCPArtifactStoreConfig] |
The config class. |
docs_url: Optional[str]
property
readonly
A url to point at docs explaining this flavor.
Returns:
Type | Description |
---|---|
Optional[str] |
A flavor docs url. |
implementation_class: Type[GCPArtifactStore]
property
readonly
Implementation class for this flavor.
Returns:
Type | Description |
---|---|
Type[GCPArtifactStore] |
The implementation class. |
logo_url: str
property
readonly
A url to represent the flavor in the dashboard.
Returns:
Type | Description |
---|---|
str |
The flavor logo. |
name: str
property
readonly
Name of the flavor.
Returns:
Type | Description |
---|---|
str |
The name of the flavor. |
sdk_docs_url: Optional[str]
property
readonly
A url to point at SDK docs explaining this flavor.
Returns:
Type | Description |
---|---|
Optional[str] |
A flavor SDK docs url. |
service_connector_requirements: Optional[zenml.models.service_connector_models.ServiceConnectorRequirements]
property
readonly
Service connector resource requirements for service connectors.
Specifies resource requirements that are used to filter the available service connector types that are compatible with this flavor.
Returns:
Type | Description |
---|---|
Optional[zenml.models.service_connector_models.ServiceConnectorRequirements] |
Requirements for compatible service connectors, if a service connector is required for this flavor. |
gcp_image_builder_flavor
Google Cloud image builder flavor.
GCPImageBuilderConfig (BaseImageBuilderConfig, GoogleCredentialsConfigMixin)
pydantic-model
Google Cloud Builder image builder configuration.
Attributes:
Name | Type | Description |
---|---|---|
cloud_builder_image |
str |
The name of the Docker image to use for the build
steps. Defaults to `gcr.io/cloud-builders/docker`. |
network |
str |
The network name to which the build container will be
attached while building the Docker image. More information about
this:
https://cloud.google.com/build/docs/build-config-file-schema#network.
Defaults to `cloudbuild`. |
build_timeout |
PositiveInt |
The timeout of the build in seconds. More information
about this parameter:
https://cloud.google.com/build/docs/build-config-file-schema#timeout_2
Defaults to `3600`. |
Source code in zenml/integrations/gcp/flavors/gcp_image_builder_flavor.py
class GCPImageBuilderConfig(
    BaseImageBuilderConfig,
    GoogleCredentialsConfigMixin,
):
    """Google Cloud Builder image builder configuration.

    Attributes:
        cloud_builder_image: Docker image that runs the build steps.
            Defaults to `gcr.io/cloud-builders/docker`.
        network: Name of the network the build container is attached to
            while the Docker image is built. See
            https://cloud.google.com/build/docs/build-config-file-schema#network.
            Defaults to `cloudbuild`.
        build_timeout: Build timeout in seconds. See
            https://cloud.google.com/build/docs/build-config-file-schema#timeout_2
            Defaults to `3600`.
    """

    # Field order is kept as-is; the defaults come from module constants.
    cloud_builder_image: str = DEFAULT_CLOUD_BUILDER_IMAGE
    network: str = DEFAULT_CLOUD_BUILDER_NETWORK
    build_timeout: PositiveInt = DEFAULT_CLOUD_BUILD_TIMEOUT
GCPImageBuilderFlavor (BaseImageBuilderFlavor)
Google Cloud Builder image builder flavor.
Source code in zenml/integrations/gcp/flavors/gcp_image_builder_flavor.py
class GCPImageBuilderFlavor(BaseImageBuilderFlavor):
    """Google Cloud Builder image builder flavor."""

    @property
    def name(self) -> str:
        """The flavor name.

        Returns:
            The name of the flavor.
        """
        return GCP_IMAGE_BUILDER_FLAVOR

    @property
    def service_connector_requirements(
        self,
    ) -> Optional[ServiceConnectorRequirements]:
        """Service connector resource requirements for service connectors.

        Specifies resource requirements that are used to filter the available
        service connector types that are compatible with this flavor.

        Returns:
            Requirements for compatible service connectors, if a service
            connector is required for this flavor.
        """
        return ServiceConnectorRequirements(
            connector_type="gcp",
            resource_type="gcp-generic",
        )

    @property
    def docs_url(self) -> Optional[str]:
        """A url to point at docs explaining this flavor.

        Returns:
            A flavor docs url.
        """
        return self.generate_default_docs_url()

    @property
    def sdk_docs_url(self) -> Optional[str]:
        """A url to point at SDK docs explaining this flavor.

        Returns:
            A flavor SDK docs url.
        """
        return self.generate_default_sdk_docs_url()

    @property
    def logo_url(self) -> str:
        """A url to represent the flavor in the dashboard.

        Returns:
            The flavor logo.
        """
        return "https://public-flavor-logos.s3.eu-central-1.amazonaws.com/image_builder/gcp.png"

    @property
    def config_class(self) -> Type[GCPImageBuilderConfig]:
        """The config class.

        The return type is narrowed to the concrete `GCPImageBuilderConfig`,
        matching what is actually returned and how `GCPArtifactStoreFlavor`
        annotates its `config_class`. `GCPImageBuilderConfig` is a subclass of
        `BaseImageBuilderConfig`, so existing callers are unaffected.

        Returns:
            The config class.
        """
        return GCPImageBuilderConfig

    @property
    def implementation_class(self) -> Type["GCPImageBuilder"]:
        """Implementation class.

        The class is imported inside the property rather than at module
        level.

        Returns:
            The implementation class.
        """
        from zenml.integrations.gcp.image_builders import GCPImageBuilder

        return GCPImageBuilder
config_class: Type[zenml.image_builders.base_image_builder.BaseImageBuilderConfig]
property
readonly
The config class.
Returns:
Type | Description |
---|---|
Type[zenml.image_builders.base_image_builder.BaseImageBuilderConfig] |
The config class. |
docs_url: Optional[str]
property
readonly
A url to point at docs explaining this flavor.
Returns:
Type | Description |
---|---|
Optional[str] |
A flavor docs url. |
implementation_class: Type[GCPImageBuilder]
property
readonly
Implementation class.
Returns:
Type | Description |
---|---|
Type[GCPImageBuilder] |
The implementation class. |
logo_url: str
property
readonly
A url to represent the flavor in the dashboard.
Returns:
Type | Description |
---|---|
str |
The flavor logo. |
name: str
property
readonly
The flavor name.
Returns:
Type | Description |
---|---|
str |
The name of the flavor. |
sdk_docs_url: Optional[str]
property
readonly
A url to point at SDK docs explaining this flavor.
Returns:
Type | Description |
---|---|
Optional[str] |
A flavor SDK docs url. |
service_connector_requirements: Optional[zenml.models.service_connector_models.ServiceConnectorRequirements]
property
readonly
Service connector resource requirements for service connectors.
Specifies resource requirements that are used to filter the available service connector types that are compatible with this flavor.
Returns:
Type | Description |
---|---|
Optional[zenml.models.service_connector_models.ServiceConnectorRequirements] |
Requirements for compatible service connectors, if a service connector is required for this flavor. |
gcp_secrets_manager_flavor
GCP secrets manager flavor.
GCPSecretsManagerConfig (BaseSecretsManagerConfig)
pydantic-model
Configuration for the GCP Secrets Manager.
Attributes:
Name | Type | Description |
---|---|---|
project_id |
str |
This is necessary to access the correct GCP project. The project_id of your GCP project space that contains the Secret Manager. |
Source code in zenml/integrations/gcp/flavors/gcp_secrets_manager_flavor.py
class GCPSecretsManagerConfig(BaseSecretsManagerConfig):
    """Configuration for the GCP Secrets Manager.

    Secrets managers of this flavor support scoping. A configured namespace
    is checked with `validate_gcp_secret_name_or_namespace`.

    Attributes:
        project_id: ID of the GCP project whose Secret Manager should be
            used. This is necessary to access the correct GCP project.
    """

    SUPPORTS_SCOPING: ClassVar[bool] = True
    project_id: str

    @classmethod
    def _validate_scope(
        cls,
        scope: SecretsManagerScope,
        namespace: Optional[str],
    ) -> None:
        """Validate the scope and namespace value.

        Only the namespace is checked; the scope value itself is not
        inspected here.

        Args:
            scope: Scope value.
            namespace: Optional namespace value. Empty or `None` namespaces
                are accepted without further checks.

        Raises:
            ValueError: Propagated from
                `validate_gcp_secret_name_or_namespace` if the namespace is
                not a valid GCP secret name/namespace.
        """
        if not namespace:
            return
        validate_gcp_secret_name_or_namespace(namespace)
GCPSecretsManagerFlavor (BaseSecretsManagerFlavor)
Class for the GCPSecretsManagerFlavor.
Source code in zenml/integrations/gcp/flavors/gcp_secrets_manager_flavor.py
class GCPSecretsManagerFlavor(BaseSecretsManagerFlavor):
    """Secrets manager flavor backed by the GCP Secret Manager."""

    @property
    def name(self) -> str:
        """Identifier under which this flavor is registered.

        Returns:
            The name of the flavor.
        """
        return GCP_SECRETS_MANAGER_FLAVOR

    @property
    def config_class(self) -> Type[GCPSecretsManagerConfig]:
        """Configuration model used by this flavor.

        Returns:
            The config class.
        """
        return GCPSecretsManagerConfig

    @property
    def implementation_class(self) -> Type["GCPSecretsManager"]:
        """Secrets manager class that implements this flavor.

        Returns:
            The implementation class.
        """
        # Imported inside the property, presumably to defer loading the
        # implementation module until it is actually needed.
        from zenml.integrations.gcp.secrets_manager import GCPSecretsManager

        return GCPSecretsManager

    @property
    def docs_url(self) -> Optional[str]:
        """Link to the user documentation of this flavor.

        Returns:
            A flavor docs url.
        """
        return self.generate_default_docs_url()

    @property
    def sdk_docs_url(self) -> Optional[str]:
        """Link to the SDK documentation of this flavor.

        Returns:
            A flavor SDK docs url.
        """
        return self.generate_default_sdk_docs_url()

    @property
    def logo_url(self) -> str:
        """Logo shown for this flavor in the dashboard.

        Returns:
            The flavor logo.
        """
        return (
            "https://public-flavor-logos.s3.eu-central-1.amazonaws.com/"
            "secrets_managers/gcp.png"
        )
config_class: Type[zenml.integrations.gcp.flavors.gcp_secrets_manager_flavor.GCPSecretsManagerConfig]
property
readonly
Returns GCPSecretsManagerConfig config class.
Returns:
Type | Description |
---|---|
Type[zenml.integrations.gcp.flavors.gcp_secrets_manager_flavor.GCPSecretsManagerConfig] |
The config class. |
docs_url: Optional[str]
property
readonly
A url to point at docs explaining this flavor.
Returns:
Type | Description |
---|---|
Optional[str] |
A flavor docs url. |
implementation_class: Type[GCPSecretsManager]
property
readonly
Implementation class for this flavor.
Returns:
Type | Description |
---|---|
Type[GCPSecretsManager] |
The implementation class. |
logo_url: str
property
readonly
A url to represent the flavor in the dashboard.
Returns:
Type | Description |
---|---|
str |
The flavor logo. |
name: str
property
readonly
Name of the flavor.
Returns:
Type | Description |
---|---|
str |
The name of the flavor. |
sdk_docs_url: Optional[str]
property
readonly
A url to point at SDK docs explaining this flavor.
Returns:
Type | Description |
---|---|
Optional[str] |
A flavor SDK docs url. |
validate_gcp_secret_name_or_namespace(name)
Validate a secret name or namespace.
A Google secret ID is a string with a maximum length of 255 characters and can contain uppercase and lowercase letters, numerals, and the hyphen (-) and underscore (_) characters. For scoped secrets, we have to limit the size of the name and namespace even further to allow space for both in the Google secret ID.
Because secret names and namespaces are also stored as labels, they are additionally subject to the restrictions Google imposes on label values: at most 63 characters, containing only lowercase letters, numerals, and the hyphen (-) and underscore (_) characters.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
name |
str |
the secret name or namespace |
required |
Exceptions:
Type | Description |
---|---|
ValueError |
if the secret name or namespace is invalid |
Source code in zenml/integrations/gcp/flavors/gcp_secrets_manager_flavor.py
def validate_gcp_secret_name_or_namespace(name: str) -> None:
    """Check that a string is usable as a GCP secret name or namespace.

    Google secret IDs may be up to 255 characters long and may contain
    upper- and lowercase letters, digits, hyphens (-) and underscores (_).
    Scoped secrets combine name and namespace in one secret ID, so both
    have to be shorter than that limit.

    Names and namespaces are additionally stored as labels, so they must
    also satisfy Google's label-value rules: at most 63 characters, using
    only lowercase letters, digits, hyphens (-) and underscores (_).

    Args:
        name: the secret name or namespace

    Raises:
        ValueError: if the value is empty, contains a disallowed character,
            or is longer than 63 characters.
    """
    only_allowed_chars = re.fullmatch(r"[a-z0-9_\-]+", name) is not None
    if not only_allowed_chars:
        raise ValueError(
            f"Invalid secret name or namespace '{name}'. Must contain "
            f"only lowercase alphanumeric characters and the hyphen (-) and "
            f"underscore (_) characters."
        )

    exceeds_label_limit = bool(name) and len(name) > 63
    if exceeds_label_limit:
        raise ValueError(
            f"Invalid secret name or namespace '{name}'. The length is "
            f"limited to maximum 63 characters."
        )
vertex_orchestrator_flavor
Vertex orchestrator flavor.
VertexOrchestratorConfig (BaseOrchestratorConfig, GoogleCredentialsConfigMixin, VertexOrchestratorSettings)
pydantic-model
Configuration for the Vertex orchestrator.
Attributes:
Name | Type | Description |
---|---|---|
location |
str |
Name of GCP region where the pipeline job will be executed. Vertex AI Pipelines is available in the following regions: https://cloud.google.com/vertex-ai/docs/general/locations#feature-availability |
pipeline_root |
Optional[str] |
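a Cloud Storage URI that will be used by the Vertex AI Pipelines. If not provided but the artifact store in the stack used to execute the pipeline is a `zenml.integrations.gcp.artifact_stores.GCPArtifactStore`, then a subdirectory of the artifact store will be used. |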
encryption_spec_key_name |
Optional[str] |
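The Cloud KMS resource identifier of the customer managed encryption key used to protect the job. Has the form: `projects/<PRJCT>/locations/<REGION>/keyRings/<KR>/cryptoKeys/<KEY>`. The key needs to be in the same region as where the compute resource is created. |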
workload_service_account |
Optional[str] |
the service account for workload run-as account. Users submitting jobs must have act-as permission on this run-as account. If not provided, the Compute Engine default service account for the GCP project in which the pipeline is running is used. |
function_service_account |
Optional[str] |
the service account for cloud function run-as account, for scheduled pipelines. This service account must have the act-as permission on the workload_service_account. If not provided, the Compute Engine default service account for the GCP project in which the pipeline is running is used. |
scheduler_service_account |
Optional[str] |
the service account used by the Google Cloud Scheduler to trigger and authenticate to the pipeline Cloud Function on a schedule. If not provided, the Compute Engine default service account for the GCP project in which the pipeline is running is used. |
network |
Optional[str] |
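the full name of the Compute Engine Network to which the job should be peered. For example, `projects/12345/global/networks/myVPC`. If not provided, the job will not be peered with any network. |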
cpu_limit |
Optional[str] |
The maximum CPU limit for this operator. This string value can be a number (integer value for number of CPUs) as string, or a number followed by "m", which means 1/1000. You can specify at most 96 CPUs. (see. https://cloud.google.com/vertex-ai/docs/pipelines/machine-types) |
memory_limit |
Optional[str] |
The maximum memory limit for this operator. This string value can be a number, or a number followed by "K" (kilobyte), "M" (megabyte), or "G" (gigabyte). At most 624GB is supported. |
gpu_limit |
Optional[int] |
The GPU limit (positive number) for the operator. For more information about GPU resources, see: https://cloud.google.com/vertex-ai/docs/training/configure-compute#specifying_gpus |
Source code in zenml/integrations/gcp/flavors/vertex_orchestrator_flavor.py
class VertexOrchestratorConfig(  # type: ignore[misc] # https://github.com/pydantic/pydantic/issues/4173
    BaseOrchestratorConfig,
    GoogleCredentialsConfigMixin,
    VertexOrchestratorSettings,
):
    """Configuration for the Vertex orchestrator.

    Combines the base orchestrator configuration, the Google credentials
    mixin and the `VertexOrchestratorSettings` with the Vertex-specific
    attributes below.

    Attributes:
        location: Name of GCP region where the pipeline job will be executed.
            Vertex AI Pipelines is available in the following regions:
            https://cloud.google.com/vertex-ai/docs/general/locations#feature-availability
        pipeline_root: a Cloud Storage URI that will be used by the Vertex AI
            Pipelines. If not provided but the artifact store in the stack used
            to execute the pipeline is a
            `zenml.integrations.gcp.artifact_stores.GCPArtifactStore`,
            then a subdirectory of the artifact store will be used.
        encryption_spec_key_name: The Cloud KMS resource identifier of the
            customer managed encryption key used to protect the job. Has the
            form:
            `projects/<PRJCT>/locations/<REGION>/keyRings/<KR>/cryptoKeys/<KEY>`.
            The key needs to be in the same region as where the compute
            resource is created.
        workload_service_account: the service account for workload run-as
            account. Users submitting jobs must have act-as permission on this
            run-as account. If not provided, the Compute Engine default service
            account for the GCP project in which the pipeline is running is
            used.
        function_service_account: the service account for cloud function run-as
            account, for scheduled pipelines. This service account must have
            the act-as permission on the workload_service_account.
            If not provided, the Compute Engine default service account for the
            GCP project in which the pipeline is running is used.
        scheduler_service_account: the service account used by the Google Cloud
            Scheduler to trigger and authenticate to the pipeline Cloud Function
            on a schedule. If not provided, the Compute Engine default service
            account for the GCP project in which the pipeline is running is
            used.
        network: the full name of the Compute Engine Network to which the job
            should be peered. For example, `projects/12345/global/networks/myVPC`.
            If not provided, the job will not be peered with any network.
        cpu_limit: Deprecated. The maximum CPU limit for this operator. This
            string value can be a number (integer value for number of CPUs) as
            string, or a number followed by "m", which means 1/1000. You can
            specify at most 96 CPUs.
            (see https://cloud.google.com/vertex-ai/docs/pipelines/machine-types)
        memory_limit: Deprecated. The maximum memory limit for this operator.
            This string value can be a number, or a number followed by "K"
            (kilobyte), "M" (megabyte), or "G" (gigabyte). At most 624GB is
            supported.
        gpu_limit: Deprecated. The GPU limit (positive number) for the
            operator. For more information about GPU resources, see:
            https://cloud.google.com/vertex-ai/docs/training/configure-compute#specifying_gpus
    """

    # Required: the region the pipeline job runs in.
    location: str
    # Everything below is optional and defaults to `None`.
    pipeline_root: Optional[str] = None
    encryption_spec_key_name: Optional[str] = None
    workload_service_account: Optional[str] = None
    function_service_account: Optional[str] = None
    scheduler_service_account: Optional[str] = None
    network: Optional[str] = None
    cpu_limit: Optional[str] = None
    memory_limit: Optional[str] = None
    gpu_limit: Optional[int] = None

    # Registers `cpu_limit`, `memory_limit` and `gpu_limit` as deprecated
    # pydantic attributes via `deprecation_utils`.
    _resource_deprecation = deprecation_utils.deprecate_pydantic_attributes(
        "cpu_limit", "memory_limit", "gpu_limit"
    )

    @property
    def is_remote(self) -> bool:
        """Checks if this stack component is running remotely.

        This designation is used to determine if the stack component can be
        used with a local ZenML database or if it requires a remote ZenML
        server. The Vertex orchestrator is always reported as remote.

        Returns:
            True if this config is for a remote component, False otherwise.
        """
        return True
is_remote: bool
property
readonly
Checks if this stack component is running remotely.
This designation is used to determine if the stack component can be used with a local ZenML database or if it requires a remote ZenML server.
Returns:
Type | Description |
---|---|
bool |
True if this config is for a remote component, False otherwise. |
VertexOrchestratorFlavor (BaseOrchestratorFlavor)
Vertex Orchestrator flavor.
Source code in zenml/integrations/gcp/flavors/vertex_orchestrator_flavor.py
class VertexOrchestratorFlavor(BaseOrchestratorFlavor):
    """Orchestrator flavor that runs pipelines on Vertex AI Pipelines."""

    @property
    def name(self) -> str:
        """Identifier under which this orchestrator flavor is registered.

        Returns:
            Name of the orchestrator flavor.
        """
        return GCP_VERTEX_ORCHESTRATOR_FLAVOR

    @property
    def config_class(self) -> Type[VertexOrchestratorConfig]:
        """Configuration model used by this flavor.

        Returns:
            The config class.
        """
        return VertexOrchestratorConfig

    @property
    def implementation_class(self) -> Type["VertexOrchestrator"]:
        """Orchestrator class that implements this flavor.

        Returns:
            Implementation class for this flavor.
        """
        # Imported inside the property, presumably to defer loading the
        # implementation module until it is actually needed.
        from zenml.integrations.gcp.orchestrators import VertexOrchestrator

        return VertexOrchestrator

    @property
    def service_connector_requirements(
        self,
    ) -> Optional[ServiceConnectorRequirements]:
        """Requirements used to filter compatible service connectors.

        Any connector that provides the `gcp-generic` resource type is
        considered compatible; no connector type is enforced.

        Returns:
            Requirements for compatible service connectors, if a service
            connector is required for this flavor.
        """
        requirements = ServiceConnectorRequirements(
            resource_type="gcp-generic",
        )
        return requirements

    @property
    def docs_url(self) -> Optional[str]:
        """Link to the user documentation of this flavor.

        Returns:
            A flavor docs url.
        """
        return self.generate_default_docs_url()

    @property
    def sdk_docs_url(self) -> Optional[str]:
        """Link to the SDK documentation of this flavor.

        Returns:
            A flavor SDK docs url.
        """
        return self.generate_default_sdk_docs_url()

    @property
    def logo_url(self) -> str:
        """Logo shown for this flavor in the dashboard.

        Returns:
            The flavor logo.
        """
        return (
            "https://public-flavor-logos.s3.eu-central-1.amazonaws.com/"
            "orchestrator/vertexai.png"
        )
config_class: Type[zenml.integrations.gcp.flavors.vertex_orchestrator_flavor.VertexOrchestratorConfig]
property
readonly
Returns VertexOrchestratorConfig config class.
Returns:
Type | Description |
---|---|
Type[zenml.integrations.gcp.flavors.vertex_orchestrator_flavor.VertexOrchestratorConfig] |
The config class. |
docs_url: Optional[str]
property
readonly
A url to point at docs explaining this flavor.
Returns:
Type | Description |
---|---|
Optional[str] |
A flavor docs url. |
implementation_class: Type[VertexOrchestrator]
property
readonly
Implementation class for this flavor.
Returns:
Type | Description |
---|---|
Type[VertexOrchestrator] |
Implementation class for this flavor. |
logo_url: str
property
readonly
A url to represent the flavor in the dashboard.
Returns:
Type | Description |
---|---|
str |
The flavor logo. |
name: str
property
readonly
Name of the orchestrator flavor.
Returns:
Type | Description |
---|---|
str |
Name of the orchestrator flavor. |
sdk_docs_url: Optional[str]
property
readonly
A url to point at SDK docs explaining this flavor.
Returns:
Type | Description |
---|---|
Optional[str] |
A flavor SDK docs url. |
service_connector_requirements: Optional[zenml.models.service_connector_models.ServiceConnectorRequirements]
property
readonly
Service connector resource requirements for service connectors.
Specifies resource requirements that are used to filter the available service connector types that are compatible with this flavor.
Returns:
Type | Description |
---|---|
Optional[zenml.models.service_connector_models.ServiceConnectorRequirements] |
Requirements for compatible service connectors, if a service connector is required for this flavor. |
VertexOrchestratorSettings (BaseSettings)
pydantic-model
Settings for the Vertex orchestrator.
Attributes:
Name | Type | Description |
---|---|---|
synchronous |
bool |
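If `True`, running a pipeline using this orchestrator will block until all steps have finished running on the Vertex AI Pipelines service. |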
labels |
Dict[str, str] |
Labels to assign to the pipeline job. |
node_selector_constraint |
Optional[Tuple[str, str]] |
Each constraint is a key-value pair label. For the container to be eligible to run on a node, the node must have each of the constraints as labels. For example, a GPU type can be provided by one of the following tuples: - ("cloud.google.com/gke-accelerator", "NVIDIA_TESLA_A100") - ("cloud.google.com/gke-accelerator", "NVIDIA_TESLA_K80") - ("cloud.google.com/gke-accelerator", "NVIDIA_TESLA_P4") - ("cloud.google.com/gke-accelerator", "NVIDIA_TESLA_P100") - ("cloud.google.com/gke-accelerator", "NVIDIA_TESLA_T4") - ("cloud.google.com/gke-accelerator", "NVIDIA_TESLA_V100") Hint: the selected region (location) must provide the requested accelerator (see https://cloud.google.com/compute/docs/gpus/gpu-regions-zones). |
pod_settings |
Optional[zenml.integrations.kubernetes.pod_settings.KubernetesPodSettings] |
Pod settings to apply. |
Source code in zenml/integrations/gcp/flavors/vertex_orchestrator_flavor.py
class VertexOrchestratorSettings(BaseSettings):
    """Settings for the Vertex orchestrator.

    Attributes:
        synchronous: If `True`, running a pipeline using this orchestrator will
            block until all steps have finished running on the Vertex AI
            Pipelines service. Defaults to `False`.
        labels: Labels to assign to the pipeline job. Defaults to no labels.
        node_selector_constraint: Deprecated. Each constraint is a key-value
            pair label. For the container to be eligible to run on a node, the
            node must have each of the constraints as labels.
            For example, a GPU type can be provided by one of the following
            tuples:
                - ("cloud.google.com/gke-accelerator", "NVIDIA_TESLA_A100")
                - ("cloud.google.com/gke-accelerator", "NVIDIA_TESLA_K80")
                - ("cloud.google.com/gke-accelerator", "NVIDIA_TESLA_P4")
                - ("cloud.google.com/gke-accelerator", "NVIDIA_TESLA_P100")
                - ("cloud.google.com/gke-accelerator", "NVIDIA_TESLA_T4")
                - ("cloud.google.com/gke-accelerator", "NVIDIA_TESLA_V100")
            Hint: the selected region (location) must provide the requested
            accelerator
            (see https://cloud.google.com/compute/docs/gpus/gpu-regions-zones).
        pod_settings: Pod settings to apply.
    """

    # NOTE(review): mutable default; relies on pydantic copying field
    # defaults per instance - confirm for the pydantic version in use.
    labels: Dict[str, str] = {}
    synchronous: bool = False
    node_selector_constraint: Optional[Tuple[str, str]] = None
    pod_settings: Optional[KubernetesPodSettings] = None

    # Registers `node_selector_constraint` as a deprecated pydantic attribute
    # via `deprecation_utils`.
    _node_selector_deprecation = (
        deprecation_utils.deprecate_pydantic_attributes(
            "node_selector_constraint"
        )
    )
vertex_step_operator_flavor
Vertex step operator flavor.
VertexStepOperatorConfig (BaseStepOperatorConfig, GoogleCredentialsConfigMixin, VertexStepOperatorSettings)
pydantic-model
Configuration for the Vertex step operator.
Attributes:
Name | Type | Description |
---|---|---|
region |
str |
Region name, e.g., `europe-west1`. |
encryption_spec_key_name |
Optional[str] |
Encryption spec key name. |
Source code in zenml/integrations/gcp/flavors/vertex_step_operator_flavor.py
class VertexStepOperatorConfig(  # type: ignore[misc] # https://github.com/pydantic/pydantic/issues/4173
    BaseStepOperatorConfig,
    GoogleCredentialsConfigMixin,
    VertexStepOperatorSettings,
):
    """Configuration for the Vertex step operator.

    Combines the base step operator configuration, the Google credentials
    mixin and the `VertexStepOperatorSettings` with the attributes below.

    Attributes:
        region: Region name, e.g., `europe-west1`.
        encryption_spec_key_name: Encryption spec key name (the resource name
            of a customer managed encryption key). Optional; defaults to
            `None`.
    """

    # Required: the GCP region in which the step runs.
    region: str

    # customer managed encryption key resource name
    # will be applied to all Vertex AI resources if set
    encryption_spec_key_name: Optional[str] = None

    @property
    def is_remote(self) -> bool:
        """Checks if this stack component is running remotely.

        This designation is used to determine if the stack component can be
        used with a local ZenML database or if it requires a remote ZenML
        server. The Vertex step operator is always reported as remote.

        Returns:
            True if this config is for a remote component, False otherwise.
        """
        return True
is_remote: bool
property
readonly
Checks if this stack component is running remotely.
This designation is used to determine if the stack component can be used with a local ZenML database or if it requires a remote ZenML server.
Returns:
Type | Description |
---|---|
bool |
True if this config is for a remote component, False otherwise. |
VertexStepOperatorFlavor (BaseStepOperatorFlavor)
Vertex Step Operator flavor.
Source code in zenml/integrations/gcp/flavors/vertex_step_operator_flavor.py
class VertexStepOperatorFlavor(BaseStepOperatorFlavor):
    """Step operator flavor that runs individual steps on Vertex AI."""

    @property
    def name(self) -> str:
        """Identifier under which this flavor is registered.

        Returns:
            Name of the flavor.
        """
        return GCP_VERTEX_STEP_OPERATOR_FLAVOR

    @property
    def config_class(self) -> Type[VertexStepOperatorConfig]:
        """Configuration model used by this flavor.

        Returns:
            The config class.
        """
        return VertexStepOperatorConfig

    @property
    def implementation_class(self) -> Type["VertexStepOperator"]:
        """Step operator class that implements this flavor.

        Returns:
            The implementation class.
        """
        # Imported inside the property, presumably to defer loading the
        # implementation module until it is actually needed.
        from zenml.integrations.gcp.step_operators import VertexStepOperator

        return VertexStepOperator

    @property
    def service_connector_requirements(
        self,
    ) -> Optional[ServiceConnectorRequirements]:
        """Requirements used to filter compatible service connectors.

        Any connector that provides the `gcp-generic` resource type is
        considered compatible; no connector type is enforced.

        Returns:
            Requirements for compatible service connectors, if a service
            connector is required for this flavor.
        """
        requirements = ServiceConnectorRequirements(
            resource_type="gcp-generic",
        )
        return requirements

    @property
    def docs_url(self) -> Optional[str]:
        """Link to the user documentation of this flavor.

        Returns:
            A flavor docs url.
        """
        return self.generate_default_docs_url()

    @property
    def sdk_docs_url(self) -> Optional[str]:
        """Link to the SDK documentation of this flavor.

        Returns:
            A flavor SDK docs url.
        """
        return self.generate_default_sdk_docs_url()

    @property
    def logo_url(self) -> str:
        """Logo shown for this flavor in the dashboard.

        Returns:
            The flavor logo.
        """
        return (
            "https://public-flavor-logos.s3.eu-central-1.amazonaws.com/"
            "step_operator/vertexai.png"
        )
config_class: Type[zenml.integrations.gcp.flavors.vertex_step_operator_flavor.VertexStepOperatorConfig]
property
readonly
Returns VertexStepOperatorConfig config class.
Returns:
Type | Description |
---|---|
Type[zenml.integrations.gcp.flavors.vertex_step_operator_flavor.VertexStepOperatorConfig] |
The config class. |
docs_url: Optional[str]
property
readonly
A url to point at docs explaining this flavor.
Returns:
Type | Description |
---|---|
Optional[str] |
A flavor docs url. |
implementation_class: Type[VertexStepOperator]
property
readonly
Implementation class for this flavor.
Returns:
Type | Description |
---|---|
Type[VertexStepOperator] |
The implementation class. |
logo_url: str
property
readonly
A url to represent the flavor in the dashboard.
Returns:
Type | Description |
---|---|
str |
The flavor logo. |
name: str
property
readonly
Name of the flavor.
Returns:
Type | Description |
---|---|
str |
Name of the flavor. |
sdk_docs_url: Optional[str]
property
readonly
A url to point at SDK docs explaining this flavor.
Returns:
Type | Description |
---|---|
Optional[str] |
A flavor SDK docs url. |
service_connector_requirements: Optional[zenml.models.service_connector_models.ServiceConnectorRequirements]
property
readonly
Service connector resource requirements for service connectors.
Specifies resource requirements that are used to filter the available service connector types that are compatible with this flavor.
Returns:
Type | Description |
---|---|
Optional[zenml.models.service_connector_models.ServiceConnectorRequirements] |
Requirements for compatible service connectors, if a service connector is required for this flavor. |
VertexStepOperatorSettings (BaseSettings)
pydantic-model
Settings for the Vertex step operator.
Attributes:
Name | Type | Description |
---|---|---|
accelerator_type |
Optional[str] |
Defines which accelerator (GPU, TPU) is used for the job. Check out this table to see which accelerator type and count are compatible with your chosen machine type: https://cloud.google.com/vertex-ai/docs/training/configure-compute#gpu-compatibility-table. |
accelerator_count |
int |
Defines the number of accelerators to be used for the job. Check out this table to see which accelerator type and count are compatible with your chosen machine type: https://cloud.google.com/vertex-ai/docs/training/configure-compute#gpu-compatibility-table. |
machine_type |
str |
Machine type specified here https://cloud.google.com/vertex-ai/docs/training/configure-compute#machine-types. |
Source code in zenml/integrations/gcp/flavors/vertex_step_operator_flavor.py
class VertexStepOperatorSettings(BaseSettings):
    """Settings for the Vertex step operator.

    Attributes:
        accelerator_type: Defines which accelerator (GPU, TPU) is used for the
            job. Check out this table to see which accelerator type and count
            are compatible with your chosen machine type:
            https://cloud.google.com/vertex-ai/docs/training/configure-compute#gpu-compatibility-table.
            Defaults to `None`.
        accelerator_count: Defines the number of accelerators to be used for
            the job. Check out this table to see which accelerator type and
            count are compatible with your chosen machine type:
            https://cloud.google.com/vertex-ai/docs/training/configure-compute#gpu-compatibility-table.
            Defaults to `0`.
        machine_type: Machine type specified here
            https://cloud.google.com/vertex-ai/docs/training/configure-compute#machine-types.
            Defaults to `"n1-standard-4"`.
    """

    accelerator_type: Optional[str] = None
    accelerator_count: int = 0
    machine_type: str = "n1-standard-4"
google_cloud_function
Utils for the Google Cloud Functions API.
create_cloud_function(directory_path, upload_path, project, location, function_name, credentials=None, function_service_account_email=None, timeout=1800)
Create google cloud function from specified directory path.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
directory_path |
str |
Local path to directory where function code resides. |
required |
upload_path |
str |
GCS path where to upload the function code. |
required |
project |
str |
GCP project ID. |
required |
location |
str |
GCP location name. |
required |
function_name |
str |
Name of the function to create. |
required |
credentials |
Optional[Credentials] |
Credentials to use for GCP services. |
None |
function_service_account_email |
Optional[str] |
The service account email the function will run with. |
None |
timeout |
int |
Timeout in seconds. |
1800 |
Returns:
Type | Description |
---|---|
str |
URI of the created cloud function. |
Exceptions:
Type | Description |
---|---|
TimeoutError |
If function times out. |
RuntimeError |
If scheduling runs into a problem. |
Source code in zenml/integrations/gcp/google_cloud_function.py
def create_cloud_function(
    directory_path: str,
    upload_path: str,
    project: str,
    location: str,
    function_name: str,
    credentials: Optional["Credentials"] = None,
    function_service_account_email: Optional[str] = None,
    timeout: int = 1800,
) -> str:
    """Create google cloud function from specified directory path.

    The contents of `directory_path` are zipped and uploaded to
    `upload_path`. A Cloud Functions (v2) function with the entry point
    `trigger_vertex_job` is created from that archive. The function is then
    polled every 5 seconds until it is no longer in the `DEPLOYING` state.

    Args:
        directory_path: Local path to directory where function code resides.
        upload_path: GCS path where to upload the function code.
        project: GCP project ID.
        location: GCP location name.
        function_name: Name of the function to create. Underscores are
            replaced by hyphens to form the function ID.
        credentials: Credentials to use for GCP services.
        function_service_account_email: The service account email the
            function will run with. If not set, no service config is sent.
        timeout: Maximum number of seconds to wait for the deployment to
            finish.

    Returns:
        str: URI of the created cloud function.

    Raises:
        TimeoutError: If the function is still deploying after `timeout`
            seconds.
        RuntimeError: If the deployment ends in any state other than
            `ACTIVE`.
    """
    # Underscores are replaced with hyphens, presumably because Cloud
    # Function IDs do not accept them.
    sanitized_function_name = function_name.replace("_", "-")
    parent = f"projects/{project}/locations/{location}"
    function_full_name = f"{parent}/functions/{sanitized_function_name}"
    logger.info(f"Creating Google Cloud Function: {function_full_name}")

    storage_source = upload_directory(directory_path, upload_path)

    # One client is reused for the creation request and every status poll.
    client = get_cloud_functions_api(credentials=credentials)

    # Make the request
    client.create_function(
        request=CreateFunctionRequest(
            parent=parent,
            function_id=sanitized_function_name,
            function=Function(
                name=function_full_name,
                build_config=BuildConfig(
                    entry_point="trigger_vertex_job",
                    runtime="python38",
                    source=Source(storage_source=storage_source),
                ),
                service_config=ServiceConfig(
                    service_account_email=function_service_account_email
                )
                if function_service_account_email
                else None,
            ),
        )
    )

    logger.info(
        "Creating cloud function to run pipeline... This might take a few "
        "minutes. Please do not exit the program at this point..."
    )

    # A monotonic clock keeps the elapsed-time measurement independent of
    # wall-clock adjustments.
    start_time = time.monotonic()
    while True:
        response = client.get_function(
            request=GetFunctionRequest(name=function_full_name)
        )
        # Stop as soon as the deployment has finished (successfully or not).
        # Checking before sleeping avoids a needless wait and a spurious
        # TimeoutError for a function that is already deployed.
        if response.state != Function.State.DEPLOYING:
            break
        if time.monotonic() - start_time > timeout:
            raise TimeoutError("Timed out waiting for function to deploy!")
        logger.info("Still creating... sleeping for 5 seconds...")
        time.sleep(5)

    if response.state != Function.State.ACTIVE:
        error_messages = ", ".join(
            msg.message for msg in response.state_messages
        )
        raise RuntimeError(
            f"Scheduling failed with the following messages: {error_messages}"
        )

    logger.info(f"Done! Function available at {response.service_config.uri}")
    return str(response.service_config.uri)
get_cloud_functions_api(credentials=None)
Gets the cloud functions API resource client.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
credentials |
Optional[Credentials] |
Google cloud credentials. |
None |
Returns:
Type | Description |
---|---|
FunctionServiceClient |
Cloud Functions V2 Client. |
Source code in zenml/integrations/gcp/google_cloud_function.py
def get_cloud_functions_api(
    credentials: Optional["Credentials"] = None,
) -> functions_v2.FunctionServiceClient:
    """Build a client for the Cloud Functions (v2) API.

    A new client instance is created on every call.

    Args:
        credentials: Google cloud credentials. Passed straight to the client
            constructor; when `None`, no explicit credentials are given
            (presumably the library then falls back to its default
            credential discovery - confirm).

    Returns:
        Cloud Functions V2 Client.
    """
    client = functions_v2.FunctionServiceClient(credentials=credentials)
    return client
upload_directory(directory_path, upload_path)
Uploads local directory to remote one.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
upload_path |
str |
GCS path where to upload the zipped function code. |
required |
directory_path |
str |
Local path of directory to upload. |
required |
Returns:
Type | Description |
---|---|
StorageSource |
Storage source (https://cloud.google.com/functions/docs/reference/rest/v2/projects.locations.functions#StorageSource). |
Source code in zenml/integrations/gcp/google_cloud_function.py
def upload_directory(
    directory_path: str,
    upload_path: str,
) -> StorageSource:
    """Zips a local directory and uploads the archive to GCS.

    Args:
        upload_path: GCS path (`gs://<bucket>/<object>`) where the zipped
            function code is uploaded.
        directory_path: Local path of the directory to upload.

    Returns:
        Storage source (https://cloud.google.com/functions/docs/reference/rest/v2/projects.locations.functions#StorageSource)
        pointing at the uploaded archive.

    Raises:
        ValueError: If `upload_path` does not contain both a bucket name and
            an object path.
    """
    # Parse the destination up front, so an invalid path fails before any
    # work is done. Only a leading `gs://` is a scheme; `str.replace` would
    # also strip occurrences inside the object path.
    prefix = "gs://"
    path = upload_path[len(prefix):] if upload_path.startswith(prefix) else upload_path
    bucket, _, object_path = path.partition("/")
    if not bucket or not object_path:
        raise ValueError(
            f"Invalid upload path `{upload_path}`: expected a path of the "
            "form `gs://<bucket>/<object>`."
        )

    # Create the temporary file and close it right away. The archive is then
    # written through a single fresh handle: holding two open handles on the
    # same temporary file fails on some platforms (e.g. Windows).
    with tempfile.NamedTemporaryFile(delete=False) as tmp:
        archive_path = tmp.name
    try:
        with zipfile.ZipFile(archive_path, "w", zipfile.ZIP_DEFLATED) as archive:
            zipdir(directory_path, archive)
        fileio.copy(archive_path, upload_path, overwrite=True)
    finally:
        # Always remove the local archive, even if zipping or uploading failed.
        os.remove(archive_path)

    return StorageSource(
        bucket=bucket,
        object_=object_path,
    )
zipdir(path, ziph)
Zips a directory using a `zipfile.ZipFile` object.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
path |
str |
Path of the directory to zip. |
required |
ziph |
ZipFile |
A `zipfile.ZipFile` file object. |
required |
Source code in zenml/integrations/gcp/google_cloud_function.py
def zipdir(path: str, ziph: zipfile.ZipFile) -> None:
    """Zips a directory using a `zipfile.ZipFile` object.

    Every file below `path` is stored under its path relative to `path`, so
    the directory layout is preserved inside the archive. Files named
    `__init__.py` are skipped at every level.

    Args:
        path: Path of the directory to zip.
        ziph: A `zipfile.ZipFile` object opened for writing.
    """
    for root, _, files in os.walk(path):
        for file in files:
            if file == "__init__.py":
                continue
            file_path = os.path.join(root, file)
            # Store entries relative to `path`. Using only the basename would
            # flatten nested directories and create duplicate archive entries
            # when two subdirectories contain files with the same name.
            ziph.write(file_path, os.path.relpath(file_path, path))
google_cloud_scheduler
Utils for the Google Cloud Scheduler API.
create_scheduler_job(project, region, http_uri, service_account_email, body, credentials=None, schedule='* * * * *', time_zone='Etc/UTC')
Creates a Google Cloud Scheduler job.
Job periodically sends POST request to the specified HTTP URI on a schedule.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
project |
str |
GCP project ID. |
required |
region |
str |
GCP region. |
required |
http_uri |
str |
HTTP URI of the cloud function to call. |
required |
service_account_email |
str |
Service account email to use to authenticate to the Google Cloud Function through an OIDC token. |
required |
body |
Dict[str, Union[Dict[str, str], bool, str, None]] |
The body of values to send to the cloud function in the POST call. |
required |
schedule |
str |
Cron expression of the schedule. Defaults to `"* * * * *"`. |
'* * * * *' |
time_zone |
str |
Time zone of the schedule. Defaults to "Etc/UTC". |
'Etc/UTC' |
credentials |
Optional[Credentials] |
Credentials to use for GCP services. |
None |
Source code in zenml/integrations/gcp/google_cloud_scheduler.py
def create_scheduler_job(
    project: str,
    region: str,
    http_uri: str,
    service_account_email: str,
    body: Dict[str, Union[Dict[str, str], bool, str, None]],
    credentials: Optional["Credentials"] = None,
    schedule: str = "* * * * *",
    time_zone: str = "Etc/UTC",
) -> None:
    """Creates a Google Cloud Scheduler job.

    Job periodically sends POST request to the specified HTTP URI on a schedule.

    Args:
        project: GCP project ID.
        region: GCP region.
        http_uri: HTTP URI of the cloud function to call.
        service_account_email: Service account email to use to authenticate to
            the Google Cloud Function through an OIDC token.
        body: The body of values to send to the cloud function in the POST
            call. It is serialized to JSON.
        schedule: Cron expression of the schedule. Defaults to "* * * * *".
        time_zone: Time zone of the schedule. Defaults to "Etc/UTC".
        credentials: Credentials to use for GCP services.
    """
    client = scheduler.CloudSchedulerClient(credentials=credentials)
    # Jobs are created under the fully qualified location path.
    parent = f"projects/{project}/locations/{region}"
    job = client.create_job(
        request=CreateJobRequest(
            parent=parent,
            job=Job(
                http_target=HttpTarget(
                    uri=http_uri,
                    # The payload is sent as a JSON-encoded request body.
                    body=json.dumps(body).encode(),
                    http_method=HttpMethod.POST,
                    # The call to the function is authenticated with an OIDC
                    # token for this service account.
                    oidc_token=OidcToken(
                        service_account_email=service_account_email
                    ),
                ),
                schedule=schedule,
                time_zone=time_zone,
            ),
        )
    )
    # Log through this module's logger instead of the root logger, and let
    # logging format the (potentially large) response only if DEBUG is enabled.
    logging.getLogger(__name__).debug(
        "Created scheduler job. Response: %s", job
    )
google_credentials_mixin
Implementation of the Google credentials mixin.
GoogleCredentialsConfigMixin (StackComponentConfig)
pydantic-model
Config mixin for Google Cloud Platform credentials.
Attributes:
Name | Type | Description |
---|---|---|
project |
Optional[str] |
GCP project name. If `None`, the project will be inferred from the environment. |
service_account_path |
Optional[str] |
path to the service account credentials file to be used for authentication. If not provided, the default credentials will be used. |
Source code in zenml/integrations/gcp/google_credentials_mixin.py
class GoogleCredentialsConfigMixin(StackComponentConfig):
    """Configuration mixin holding Google Cloud Platform credential settings.

    Attributes:
        project: Name of the GCP project. Leave it as `None` to infer the
            project from the environment.
        service_account_path: Path to a service account credentials file
            used for authentication. If omitted, the default credentials
            are used instead.
    """

    # Explicit GCP project; `None` means "infer it from the environment".
    project: Optional[str] = None
    # Service account key file; `None` falls back to the default credentials.
    service_account_path: Optional[str] = None
GoogleCredentialsMixin (StackComponent)
StackComponent mixin to get Google Cloud Platform credentials.
Source code in zenml/integrations/gcp/google_credentials_mixin.py
class GoogleCredentialsMixin(StackComponent):
    """StackComponent mixin to get Google Cloud Platform credentials."""

    @property
    def config(self) -> GoogleCredentialsConfigMixin:
        """Returns the `GoogleCredentialsConfigMixin` config.

        Returns:
            The configuration.
        """
        return cast(GoogleCredentialsConfigMixin, self._config)

    def _get_authentication(self) -> Tuple["Credentials", str]:
        """Get GCP credentials and the project ID associated with the credentials.

        Credentials are resolved in this order:

        1. If a service connector is linked to this component, it must be a
           GCP service connector. Its credentials and configured project ID
           are returned as-is.
        2. Otherwise, if `service_account_path` is configured, the
           credentials are loaded from the file at that path.
        3. Otherwise, the default credentials are used.

        In cases 2 and 3, a project set in the configuration takes precedence
        over the project associated with the credentials. A warning is logged
        if the two differ.

        Returns:
            A tuple containing the credentials and the project ID associated to
            the credentials.

        Raises:
            RuntimeError: If the linked connector is not a GCP service
                connector, or if it returns an unexpected type of credentials.
        """
        from google.auth import default, load_credentials_from_file
        from google.auth.credentials import Credentials

        from zenml.integrations.gcp.service_connectors import (
            GCPServiceConnector,
        )

        connector = self.get_connector()
        if connector:
            # Check the connector type before connecting. A wrongly linked
            # connector is then reported as such, instead of as a credentials
            # type mismatch.
            if not isinstance(connector, GCPServiceConnector):
                raise RuntimeError(
                    "Expected a GCP service connector to be linked to this "
                    f"{self.type}, but got a connector of type "
                    f"{type(connector)}."
                )
            credentials = connector.connect()
            if not isinstance(credentials, Credentials):
                raise RuntimeError(
                    f"Expected google.auth.credentials.Credentials while "
                    "trying to use the linked connector, but got "
                    f"{type(credentials)}."
                )
            return credentials, connector.config.project_id

        if self.config.service_account_path:
            credentials, project_id = load_credentials_from_file(
                self.config.service_account_path
            )
        else:
            credentials, project_id = default()

        if self.config.project and self.config.project != project_id:
            logger.warning(
                "Authenticated with project `%s`, but this %s is "
                "configured to use the project `%s`.",
                project_id,
                self.type,
                self.config.project,
            )

        # If the project was set in the configuration, use it. Otherwise, use
        # the project that was used to authenticate.
        project_id = self.config.project if self.config.project else project_id
        return credentials, project_id
config: GoogleCredentialsConfigMixin
property
readonly
Returns the `GoogleCredentialsConfigMixin` config.
Returns:
Type | Description |
---|---|
GoogleCredentialsConfigMixin |
The configuration. |
image_builders
special
Initialization for the GCP image builder.
gcp_image_builder
Google Cloud Builder image builder implementation.
GCPImageBuilder (BaseImageBuilder, GoogleCredentialsMixin)
Google Cloud Builder image builder implementation.
Source code in zenml/integrations/gcp/image_builders/gcp_image_builder.py
class GCPImageBuilder(BaseImageBuilder, GoogleCredentialsMixin):
    """Google Cloud Builder image builder implementation."""

    @property
    def config(self) -> GCPImageBuilderConfig:
        """The stack component configuration.

        Returns:
            The configuration.
        """
        return cast(GCPImageBuilderConfig, self._config)

    @property
    def is_building_locally(self) -> bool:
        """Whether the image builder builds the images on the client machine.

        Returns:
            True if the image builder builds locally, False otherwise.
        """
        return False

    @property
    def validator(self) -> Optional["StackValidator"]:
        """Validates the stack for the GCP Image Builder.

        The GCP Image Builder requires a GCP container registry to push the
        image to, and a GCP Artifact Store to upload the build context, so
        Cloud Build can access it.

        Returns:
            Stack validator.
        """

        def _validate_remote_components(stack: "Stack") -> Tuple[bool, str]:
            # Presumably guaranteed by `required_components` below.
            assert stack.container_registry
            if (
                stack.container_registry.flavor
                != ContainerRegistryFlavor.GCP.value
            ):
                return False, (
                    "The GCP Image Builder requires a GCP container registry to "
                    "push the image to. Please update your stack to include a "
                    "GCP container registry and try again."
                )
            if stack.artifact_store.flavor != GCP_ARTIFACT_STORE_FLAVOR:
                return False, (
                    "The GCP Image Builder requires a GCP Artifact Store to "
                    "upload the build context, so Cloud Build can access it. "
                    "Please update your stack to include a GCP Artifact Store "
                    "and try again."
                )
            return True, ""

        return StackValidator(
            required_components={StackComponentType.CONTAINER_REGISTRY},
            custom_validation_function=_validate_remote_components,
        )

    def build(
        self,
        image_name: str,
        build_context: "BuildContext",
        docker_build_options: Dict[str, Any],
        container_registry: Optional["BaseContainerRegistry"] = None,
    ) -> str:
        """Builds and pushes a Docker image.

        Args:
            image_name: Name of the image to build and push.
            build_context: The build context to use for the image.
            docker_build_options: Docker build options.
            container_registry: Optional container registry to push to.

        Returns:
            The Docker image name with digest.

        Raises:
            RuntimeError: If no container registry is passed.
            RuntimeError: If the Cloud Build build fails.
        """
        if not container_registry:
            raise RuntimeError(
                "The GCP Image Builder requires a container registry to push "
                "the image to. Please provide one and try again."
            )

        logger.info("Using Cloud Build to build image `%s`", image_name)
        cloud_build_context = self._upload_build_context(
            build_context=build_context,
            parent_path_directory_name="cloud-build-contexts",
        )
        build = self._configure_cloud_build(
            image_name=image_name,
            cloud_build_context=cloud_build_context,
            build_options=docker_build_options,
        )
        image_digest = self._run_cloud_build(build=build)

        # Replace the tag (if any) with the digest. Only a ":" after the last
        # "/" separates a tag; an earlier ":" belongs to a registry host port
        # such as `host:5000/repo`.
        image_name_without_tag = image_name
        tag_separator = image_name.rfind(":")
        if tag_separator > image_name.rfind("/"):
            image_name_without_tag = image_name[:tag_separator]
        return f"{image_name_without_tag}@{image_digest}"

    def _configure_cloud_build(
        self,
        image_name: str,
        cloud_build_context: str,
        build_options: Dict[str, Any],
    ) -> cloudbuild_v1.Build:
        """Configures the build to be run to generate the Docker image.

        Args:
            image_name: The name of the image to build.
            cloud_build_context: The path to the build context.
            build_options: Docker build options. Each key is passed as a CLI
                argument followed by its value. List values repeat the key
                once per element.

        Returns:
            The build to run.
        """
        url_parts = urlparse(cloud_build_context)
        bucket = url_parts.netloc
        object_path = url_parts.path.lstrip("/")
        logger.info(
            "Build context located in bucket `%s` and object path `%s`",
            bucket,
            object_path,
        )

        cloud_builder_image = self.config.cloud_builder_image
        cloud_builder_network_option = f"--network={self.config.network}"
        logger.info(
            "Using Cloud Builder image `%s` to run the steps in the build. "
            "Container will be attached to network using option `%s`.",
            cloud_builder_image,
            cloud_builder_network_option,
        )

        # Flatten the build options into CLI arguments. Cloud Build step
        # arguments are strings, so values are converted with `str`.
        docker_build_args = []
        for key, value in build_options.items():
            values = value if isinstance(value, list) else [value]
            for val in values:
                docker_build_args.extend([key, str(val)])

        return cloudbuild_v1.Build(
            source=cloudbuild_v1.Source(
                storage_source=cloudbuild_v1.StorageSource(
                    bucket=bucket, object=object_path
                ),
            ),
            steps=[
                {
                    "name": cloud_builder_image,
                    "args": [
                        "build",
                        cloud_builder_network_option,
                        "-t",
                        image_name,
                        # Pass the flattened option/value pairs. Unpacking the
                        # options dict itself would only pass its keys and
                        # silently drop every value.
                        *docker_build_args,
                        ".",
                    ],
                },
                {
                    "name": cloud_builder_image,
                    "args": ["push", image_name],
                },
            ],
            images=[image_name],
            timeout=f"{self.config.build_timeout}s",
        )

    def _run_cloud_build(self, build: cloudbuild_v1.Build) -> str:
        """Executes the Cloud Build run to build the Docker image.

        Args:
            build: The build to run.

        Returns:
            The Docker image repo digest.

        Raises:
            RuntimeError: If the Cloud Build run has failed.
        """
        credentials, project_id = self._get_authentication()
        client = cloudbuild_v1.CloudBuildClient(credentials=credentials)

        operation = client.create_build(project_id=project_id, build=build)
        log_url = operation.metadata.build.log_url
        logger.info(
            "Running Cloud Build to build the Docker image. Cloud Build logs: `%s`",
            log_url,
        )

        # Block until the build finishes or the configured timeout expires.
        result = operation.result(timeout=self.config.build_timeout)

        if result.status != cloudbuild_v1.Build.Status.SUCCESS:
            raise RuntimeError(
                f"The Cloud Build run to build the Docker image has failed. More "
                f"information can be found in the Cloud Build logs: {log_url}."
            )

        logger.info(
            "The Docker image has been built successfully. More information can "
            "be found in the Cloud Build logs: `%s`.",
            log_url,
        )

        image_digest: str = result.results.images[0].digest
        return image_digest
config: GCPImageBuilderConfig
property
readonly
The stack component configuration.
Returns:
Type | Description |
---|---|
GCPImageBuilderConfig |
The configuration. |
is_building_locally: bool
property
readonly
Whether the image builder builds the images on the client machine.
Returns:
Type | Description |
---|---|
bool |
True if the image builder builds locally, False otherwise. |
validator: Optional[StackValidator]
property
readonly
Validates the stack for the GCP Image Builder.
The GCP Image Builder requires a remote container registry to push the image to, and a GCP Artifact Store to upload the build context, so Cloud Build can access it.
Returns:
Type | Description |
---|---|
Optional[StackValidator] |
Stack validator. |
build(self, image_name, build_context, docker_build_options, container_registry=None)
Builds and pushes a Docker image.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
image_name |
str |
Name of the image to build and push. |
required |
build_context |
BuildContext |
The build context to use for the image. |
required |
docker_build_options |
Dict[str, Any] |
Docker build options. |
required |
container_registry |
Optional[BaseContainerRegistry] |
Optional container registry to push to. |
None |
Returns:
Type | Description |
---|---|
str |
The Docker image name with digest. |
Exceptions:
Type | Description |
---|---|
RuntimeError |
If no container registry is passed. |
RuntimeError |
If the Cloud Build build fails. |
Source code in zenml/integrations/gcp/image_builders/gcp_image_builder.py
def build(
    self,
    image_name: str,
    build_context: "BuildContext",
    docker_build_options: Dict[str, Any],
    container_registry: Optional["BaseContainerRegistry"] = None,
) -> str:
    """Builds and pushes a Docker image.

    Args:
        image_name: Name of the image to build and push.
        build_context: The build context to use for the image.
        docker_build_options: Docker build options.
        container_registry: Optional container registry to push to.

    Returns:
        The Docker image name with digest.

    Raises:
        RuntimeError: If no container registry is passed.
        RuntimeError: If the Cloud Build build fails.
    """
    if not container_registry:
        raise RuntimeError(
            "The GCP Image Builder requires a container registry to push "
            "the image to. Please provide one and try again."
        )

    logger.info("Using Cloud Build to build image `%s`", image_name)
    cloud_build_context = self._upload_build_context(
        build_context=build_context,
        parent_path_directory_name="cloud-build-contexts",
    )
    build = self._configure_cloud_build(
        image_name=image_name,
        cloud_build_context=cloud_build_context,
        build_options=docker_build_options,
    )
    image_digest = self._run_cloud_build(build=build)

    # Replace the tag (if any) with the digest. Only a ":" after the last "/"
    # separates a tag; an earlier ":" belongs to a registry host port such as
    # `host:5000/repo`.
    image_name_without_tag = image_name
    tag_separator = image_name.rfind(":")
    if tag_separator > image_name.rfind("/"):
        image_name_without_tag = image_name[:tag_separator]
    return f"{image_name_without_tag}@{image_digest}"
orchestrators
special
Initialization for the VertexAI orchestrator.
vertex_orchestrator
Implementation of the VertexAI orchestrator.
VertexOrchestrator (ContainerizedOrchestrator, GoogleCredentialsMixin)
Orchestrator responsible for running pipelines on Vertex AI.
Source code in zenml/integrations/gcp/orchestrators/vertex_orchestrator.py
class VertexOrchestrator(ContainerizedOrchestrator, GoogleCredentialsMixin):
"""Orchestrator responsible for running pipelines on Vertex AI."""
_pipeline_root: str
@property
def config(self) -> VertexOrchestratorConfig:
    """The Vertex orchestrator configuration.

    Returns:
        The `VertexOrchestratorConfig` of this orchestrator.
    """
    vertex_config = cast(VertexOrchestratorConfig, self._config)
    return vertex_config
@property
def settings_class(self) -> Optional[Type["BaseSettings"]]:
    """The settings class used by the Vertex orchestrator.

    Returns:
        `VertexOrchestratorSettings`.
    """
    settings_cls = VertexOrchestratorSettings
    return settings_cls
@property
def validator(self) -> Optional[StackValidator]:
    """Builds the validator for stacks that use the Vertex orchestrator.

    The stack must contain a container registry and an image builder. A
    stack is rejected when:

    - its container registry is local,
    - any of its components has a local path, or
    - `pipeline_root` is not configured and the artifact store is not a
      GCP artifact store, so the root cannot be derived from it.

    Returns:
        A StackValidator instance.
    """

    def _validate_stack_requirements(stack: "Stack") -> Tuple[bool, str]:
        """Checks that the stack can be executed remotely on Vertex AI.

        Args:
            stack: The stack to validate.

        Returns:
            A tuple of (is_valid, error_message).
        """
        registry = stack.container_registry
        if registry and registry.config.is_local:
            return False, (
                "The Vertex orchestrator does not support local "
                "container registries. You should replace the component '"
                f"{registry.name}' "
                f"{registry.type.value} to a remote one."
            )

        # Vertex cannot reach the local machine, so report the first
        # component that has a local path.
        # TODO: [server] make sure the ComponentModel actually has
        # a local_path property or implement similar check
        local_component = next(
            (
                component
                for component in stack.components.values()
                if component.local_path
            ),
            None,
        )
        if local_component is not None:
            return False, (
                f"The '{local_component.name}' "
                f"{local_component.type.value} is a "
                "local stack component. The Vertex AI Pipelines "
                "orchestrator requires that all the components in the "
                "stack used to execute the pipeline have to be not local, "
                "because there is no way for Vertex to connect to your "
                "local machine. You should use a flavor of "
                f"{local_component.type.value} other than '"
                f"{local_component.flavor}'."
            )

        # `pipeline_root` can only be derived automatically from a GCP
        # artifact store.
        if (
            self.config.pipeline_root
            or stack.artifact_store.flavor == GCP_ARTIFACT_STORE_FLAVOR
        ):
            return True, ""
        return False, (
            "The attribute `pipeline_root` has not been set and it "
            "cannot be generated using the path of the artifact store "
            "because it is not a "
            "`zenml.integrations.gcp.artifact_store.GCPArtifactStore`."
            " To solve this issue, set the `pipeline_root` attribute "
            "manually executing the following command: "
            f"`zenml orchestrator update {stack.orchestrator.name} "
            '--pipeline_root="<Cloud Storage URI>"`.'
        )

    required = {
        StackComponentType.CONTAINER_REGISTRY,
        StackComponentType.IMAGE_BUILDER,
    }
    return StackValidator(
        required_components=required,
        custom_validation_function=_validate_stack_requirements,
    )
@property
def root_directory(self) -> str:
    """Directory that holds all files written by this orchestrator.

    Returns:
        `<global config directory>/vertex/<orchestrator id>`.
    """
    vertex_dir = os.path.join(get_global_config_directory(), "vertex")
    return os.path.join(vertex_dir, str(self.id))
@property
def pipeline_directory(self) -> str:
    """Directory where compiled pipeline definition files are stored.

    Returns:
        The `pipelines` subdirectory of `root_directory`.
    """
    root = self.root_directory
    return os.path.join(root, "pipelines")
def prepare_pipeline_deployment(
    self,
    deployment: "PipelineDeploymentResponseModel",
    stack: "Stack",
) -> None:
    """Validates the schedule of a deployment before it runs on Vertex.

    Vertex only honors the `cron_expression` of a schedule. A warning is
    logged when any other schedule property is set, and an error is raised
    when the cron expression is missing. Deployments without a schedule
    are left untouched.

    Args:
        deployment: The pipeline deployment configuration.
        stack: The stack on which the pipeline will be deployed.

    Raises:
        ValueError: If `cron_expression` is not in passed Schedule.
    """
    schedule = deployment.schedule
    if not schedule:
        return

    ignored_properties = (
        schedule.catchup,
        schedule.start_time,
        schedule.end_time,
        schedule.interval_second,
    )
    if any(ignored_properties):
        logger.warning(
            "Vertex orchestrator only uses schedules with the "
            "`cron_expression` property. All other properties "
            "are ignored."
        )

    if schedule.cron_expression is None:
        raise ValueError(
            "Property `cron_expression` must be set when passing "
            "schedule to a Vertex orchestrator."
        )
def _configure_container_resources(
    self,
    container_op: dsl.ContainerOp,
    resource_settings: "ResourceSettings",
    node_selector_constraint: Optional[Tuple[str, str]] = None,
) -> None:
    """Applies CPU, memory, GPU and node selector settings to a container.

    Values from the step's resource settings are used when set. Otherwise
    the orchestrator config's limits are used.

    Args:
        container_op: The kubeflow container operation to configure.
        resource_settings: The resource settings to use for this
            container.
        node_selector_constraint: Node selector constraint to apply to
            the container.
    """
    cpu_limit = resource_settings.cpu_count or self.config.cpu_limit
    if cpu_limit is not None:
        container_op = container_op.set_cpu_limit(str(cpu_limit))

    memory_limit = self.config.memory_limit
    if resource_settings.memory:
        # Drops the last character of the memory string. Presumably this
        # turns e.g. "4GB" into the "4G" form KFP accepts -- confirm.
        memory_limit = resource_settings.memory[:-1]
    if memory_limit is not None:
        container_op = container_op.set_memory_limit(memory_limit)

    gpu_limit = resource_settings.gpu_count
    if gpu_limit is None:
        gpu_limit = self.config.gpu_limit
    if gpu_limit is not None and gpu_limit > 0:
        container_op = container_op.set_gpu_limit(gpu_limit)

    if not node_selector_constraint:
        return
    constraint_label, value = node_selector_constraint
    # The accelerator constraint is dropped when the GPU limit is explicitly 0.
    skip_constraint = (
        constraint_label == GKE_ACCELERATOR_NODE_SELECTOR_CONSTRAINT_LABEL
        and gpu_limit == 0
    )
    if not skip_constraint:
        container_op.add_node_selector_constraint(constraint_label, value)
def prepare_or_run_pipeline(
self,
deployment: "PipelineDeploymentResponseModel",
stack: "Stack",
environment: Dict[str, str],
) -> Any:
"""Creates a KFP JSON pipeline.
# noqa: DAR402
This is an intermediary representation of the pipeline which is then
deployed to Vertex AI Pipelines service.
How it works:
-------------
For every step, the Docker image that contains the pipeline code is
looked up via `self.get_image()`.
Based on this Docker image a callable is created which builds
container_ops for each step (`_construct_kfp_pipeline`). The function
`kfp.components.load_component_from_text` is used to create the
`ContainerOp`, because using the `dsl.ContainerOp` class directly is
deprecated when using the Kubeflow SDK v2. The step entrypoint command
with the entrypoint arguments is the command that will be executed by
the container created using the previously created Docker image.
This callable is then compiled into a JSON file that is used as the
intermediary representation of the Kubeflow pipeline.
This file then is submitted to the Vertex AI Pipelines service for
execution: scheduled through Cloud Scheduler and Cloud Functions if the
deployment has a schedule, or run once otherwise.
Args:
deployment: The pipeline deployment to prepare or run.
stack: The stack the pipeline will run on.
environment: Environment variables to set in the orchestration
environment.
Raises:
ValueError: If the deployment is scheduled and either the artifact
store of the stack is not a
`zenml.integrations.gcp.artifact_store.GCPArtifactStore`, or no
service account can be determined for the Cloud Scheduler job
(raised by `_upload_and_schedule_pipeline`).
"""
orchestrator_run_name = get_orchestrator_run_name(
pipeline_name=deployment.pipeline_configuration.name
)
# If the `pipeline_root` has not been defined in the orchestrator
# configuration, derive it from the artifact store path. The artifact
# store flavor is not checked here; the stack validator of this
# orchestrator rejects non-GCP artifact stores when `pipeline_root`
# is unset.
if not self.config.pipeline_root:
artifact_store = stack.artifact_store
self._pipeline_root = f"{artifact_store.path.rstrip('/')}/vertex_pipeline_root/{deployment.pipeline_configuration.name}/{orchestrator_run_name}"
logger.info(
"The attribute `pipeline_root` has not been set in the "
"orchestrator configuration. One has been generated "
"automatically based on the path of the `GCPArtifactStore` "
"artifact store in the stack used to execute the pipeline. "
"The generated `pipeline_root` is `%s`.",
self._pipeline_root,
)
else:
self._pipeline_root = self.config.pipeline_root
# Closure handed to the KFP v2 compiler below. It reads `deployment`,
# `environment` and `self` from the enclosing scope.
def _construct_kfp_pipeline() -> None:
"""Create a `ContainerOp` for each step.
This should contain the name of the Docker image and configures the
entrypoint of the Docker image to run the step.
Additionally, this gives each `ContainerOp` information about its
direct downstream steps.
If this callable is passed to the `compile()` method of
`KFPV2Compiler` all `dsl.ContainerOp` instances will be
automatically added to a singular `dsl.Pipeline` instance.
"""
command = StepEntrypointConfiguration.get_entrypoint_command()
step_name_to_container_op: Dict[str, dsl.ContainerOp] = {}
for step_name, step in deployment.step_configurations.items():
image = self.get_image(
deployment=deployment, step_name=step_name
)
arguments = (
StepEntrypointConfiguration.get_entrypoint_arguments(
step_name=step_name, deployment_id=deployment.id
)
)
# Create the `ContainerOp` for the step. Using the
# `dsl.ContainerOp`
# class directly is deprecated when using the Kubeflow SDK v2.
# NOTE(review): the component spec below is YAML inside an
# f-string, so its indentation is significant.
container_op = kfp.components.load_component_from_text(
f"""
name: {step_name}
implementation:
container:
image: {image}
command: {command + arguments}"""
)()
# Expose the Vertex pipeline job name to the step container.
container_op.set_env_variable(
name=ENV_ZENML_VERTEX_RUN_ID,
value=dslv2.PIPELINE_JOB_NAME_PLACEHOLDER,
)
for key, value in environment.items():
container_op.set_env_variable(name=key, value=value)
# Set upstream tasks as a dependency of the current step
# NOTE(review): assumes `step_configurations` yields upstream steps
# before their dependents; otherwise this lookup raises KeyError.
for upstream_step_name in step.spec.upstream_steps:
upstream_container_op = step_name_to_container_op[
upstream_step_name
]
container_op.after(upstream_container_op)
settings = cast(
VertexOrchestratorSettings,
self.get_settings(step),
)
if settings.pod_settings:
apply_pod_settings(
container_op=container_op,
settings=settings.pod_settings,
)
self._configure_container_resources(
container_op=container_op,
resource_settings=step.config.resource_settings,
node_selector_constraint=settings.node_selector_constraint,
)
# Vertex-side caching is always disabled for ZenML steps.
container_op.set_caching_options(enable_caching=False)
step_name_to_container_op[step_name] = container_op
# Save the generated pipeline to a file.
fileio.makedirs(self.pipeline_directory)
pipeline_file_path = os.path.join(
self.pipeline_directory,
f"{orchestrator_run_name}.json",
)
# Compile the pipeline using the Kubeflow SDK V2 compiler that allows
# to generate a JSON representation of the pipeline that can be later
# upload to Vertex AI Pipelines service.
KFPV2Compiler().compile(
pipeline_func=_construct_kfp_pipeline,
package_path=pipeline_file_path,
pipeline_name=_clean_pipeline_name(
deployment.pipeline_configuration.name
),
)
# Logged after the file has been written by `compile` above.
logger.info(
"Writing Vertex workflow definition to `%s`.", pipeline_file_path
)
# Pipeline-level settings (e.g. labels, `synchronous`) for the submission.
settings = cast(
VertexOrchestratorSettings, self.get_settings(deployment)
)
if deployment.schedule:
logger.info(
"Scheduling job using Google Cloud Scheduler and Google "
"Cloud Functions..."
)
self._upload_and_schedule_pipeline(
pipeline_name=deployment.pipeline_configuration.name,
run_name=orchestrator_run_name,
stack=stack,
schedule=deployment.schedule,
pipeline_file_path=pipeline_file_path,
settings=settings,
)
else:
logger.info("No schedule detected. Creating one-off vertex job...")
# Using the Google Cloud AIPlatform client, upload and execute the
# pipeline
# on the Vertex AI Pipelines service.
self._upload_and_run_pipeline(
pipeline_name=deployment.pipeline_configuration.name,
pipeline_file_path=pipeline_file_path,
run_name=orchestrator_run_name,
settings=settings,
)
def _upload_and_schedule_pipeline(
    self,
    pipeline_name: str,
    run_name: str,
    stack: "Stack",
    schedule: "Schedule",
    pipeline_file_path: str,
    settings: VertexOrchestratorSettings,
) -> None:
    """Uploads a compiled pipeline and schedules it on GCP.

    The steps are:

    1. Copy the compiled pipeline JSON into the GCP artifact store.
    2. Deploy a Cloud Function built from the `vertex_scheduler` package.
    3. Create a Cloud Scheduler job that calls that function on the
       schedule's cron expression, with the pipeline job parameters as
       its body.

    Args:
        pipeline_name: Name of the pipeline.
        run_name: Orchestrator run name.
        stack: The stack the pipeline will run on.
        schedule: The schedule the pipeline will run on.
        pipeline_file_path: Path of the JSON file containing the compiled
            Kubeflow pipeline (compiled with Kubeflow SDK v2).
        settings: Pipeline level settings for this orchestrator.

    Raises:
        ValueError: If the artifact store of the stack is not a
            `zenml.integrations.gcp.artifact_store.GCPArtifactStore`, or if
            no service account can be determined for the Cloud Scheduler
            job.
    """
    artifact_store = stack.artifact_store
    if artifact_store.flavor != GCP_ARTIFACT_STORE_FLAVOR:
        raise ValueError(
            "Currently, the Vertex AI orchestrator only supports "
            "scheduled runs in combination with an artifact store of "
            f"flavor: {GCP_ARTIFACT_STORE_FLAVOR}. The current stacks "
            f"artifact store is of flavor: {artifact_store.flavor}. "
            "Please update your stack accordingly."
        )

    credentials, project_id = self._get_authentication()

    # Identity used by Cloud Scheduler to call the function, in order of
    # preference: the explicitly configured scheduler account, the service
    # account behind the credentials, then the function/workload accounts.
    scheduler_service_account_email: Optional[str]
    if self.config.scheduler_service_account:
        scheduler_service_account_email = self.config.scheduler_service_account
    elif hasattr(credentials, "signer_email"):
        scheduler_service_account_email = credentials.signer_email
    else:
        scheduler_service_account_email = (
            self.config.function_service_account
            or self.config.workload_service_account
        )
    if not scheduler_service_account_email:
        raise ValueError(
            "A GCP service account is required to schedule a pipeline run. "
            "The credentials used to authenticate with GCP do not have a "
            "service account associated with them and a service account "
            "was not configured in the `scheduler_service_account` field "
            "of the orchestrator config. Please update your orchestrator "
            "configuration or credentials accordingly."
        )

    base_uri = (
        f"{artifact_store.path.rstrip('/')}/vertex_scheduled_pipelines/"
        f"{pipeline_name}/{run_name}"
    )
    pipeline_uri = f"{base_uri}/vertex_pipeline.json"
    fileio.copy(pipeline_file_path, pipeline_uri)
    logger.info(
        "The scheduled pipeline representation has been "
        "automatically copied to this path of the `GCPArtifactStore`: "
        f"{pipeline_uri}",
    )

    function_uri = create_cloud_function(
        directory_path=vertex_scheduler.__path__[0],  # fixed path
        upload_path=f"{base_uri}/code.zip",
        project=project_id,
        location=self.config.location,
        function_name=run_name,
        credentials=credentials,
        function_service_account_email=self.config.function_service_account,
    )

    create_scheduler_job(
        project=project_id,
        region=self.config.location,
        http_uri=function_uri,
        body={
            TEMPLATE_PATH: pipeline_uri,
            JOB_ID: _clean_pipeline_name(pipeline_name),
            PIPELINE_ROOT: self._pipeline_root,
            PARAMETER_VALUES: None,
            ENABLE_CACHING: False,
            ENCRYPTION_SPEC_KEY_NAME: self.config.encryption_spec_key_name,
            LABELS: settings.labels,
            PROJECT: project_id,
            LOCATION: self.config.location,
            WORKLOAD_SERVICE_ACCOUNT: self.config.workload_service_account,
            NETWORK: self.config.network,
        },
        schedule=str(schedule.cron_expression),
        credentials=credentials,
        service_account_email=scheduler_service_account_email,
    )
def _upload_and_run_pipeline(
    self,
    pipeline_name: str,
    pipeline_file_path: str,
    run_name: str,
    settings: VertexOrchestratorSettings,
) -> None:
    """Submit a compiled pipeline to Vertex AI Pipelines and optionally wait.

    Failures are logged rather than propagated to the caller: a
    `google_exceptions.ClientError` raised while submitting or waiting is
    logged as a warning, a `RuntimeError` is logged as an error.

    Args:
        pipeline_name: Name of the pipeline, used as the job display name.
        pipeline_file_path: Path of the JSON file containing the compiled
            Kubeflow pipeline (compiled with Kubeflow SDK v2).
        run_name: Orchestrator run name, normalized into the Vertex job ID.
        settings: Pipeline level settings for this orchestrator. Its
            `labels` are attached to the job and `synchronous` decides
            whether to block until the job has finished.
    """
    # Vertex AI Pipelines only accepts job IDs in a restricted format, so
    # the run name is normalized through `_clean_pipeline_name` first.
    vertex_job_id = _clean_pipeline_name(run_name)

    # Credentials and project that will own the Vertex AI Pipelines job.
    credentials, project_id = self._get_authentication()

    pipeline_job = aiplatform.PipelineJob(
        display_name=pipeline_name,
        template_path=pipeline_file_path,
        job_id=vertex_job_id,
        pipeline_root=self._pipeline_root,
        parameter_values=None,
        enable_caching=False,
        encryption_spec_key_name=self.config.encryption_spec_key_name,
        labels=settings.labels,
        credentials=credentials,
        project=project_id,
        location=self.config.location,
    )
    logger.info(
        "Submitting pipeline job with job_id `%s` to Vertex AI Pipelines "
        "service.",
        vertex_job_id,
    )

    try:
        service_account = self.config.workload_service_account
        network = self.config.network
        if service_account:
            logger.info(
                "The Vertex AI Pipelines job workload will be executed "
                "using the `%s` service account.",
                service_account,
            )
        if network:
            logger.info(
                "The Vertex AI Pipelines job will be peered with the `%s` "
                "network.",
                network,
            )
        pipeline_job.submit(service_account=service_account, network=network)
        logger.info(
            "View the Vertex AI Pipelines job at %s",
            pipeline_job._dashboard_uri(),
        )
        if not settings.synchronous:
            return
        logger.info("Waiting for the Vertex AI Pipelines job to finish...")
        pipeline_job.wait()
    except google_exceptions.ClientError as err:
        logger.warning("Failed to create the Vertex AI Pipelines job: %s", err)
    except RuntimeError as err:
        logger.error(
            "The Vertex AI Pipelines job execution has failed: %s", err
        )
def get_orchestrator_run_id(self) -> str:
    """Returns the active orchestrator run id.

    The id is read from the `ENV_ZENML_VERTEX_RUN_ID` environment variable.
    This variable is expected to be set inside the step containers.

    Raises:
        RuntimeError: If the environment variable specifying the run id
            is not set.

    Returns:
        The orchestrator run id.
    """
    try:
        return os.environ[ENV_ZENML_VERTEX_RUN_ID]
    except KeyError:
        # `from None`: the bare KeyError adds nothing beyond this message and
        # would otherwise show up as a confusing "during handling" traceback.
        raise RuntimeError(
            "Unable to read run id from environment variable "
            f"{ENV_ZENML_VERTEX_RUN_ID}."
        ) from None
def get_pipeline_run_metadata(
    self, run_id: UUID
) -> Dict[str, "MetadataType"]:
    """Build orchestrator metadata (the Vertex AI console URL) for a run.

    Args:
        run_id: The ID of the pipeline run. Not used here; the URL is built
            from the active orchestrator run id instead.

    Returns:
        A dictionary mapping `METADATA_ORCHESTRATOR_URL` to the Google Cloud
        console URL of the Vertex AI pipeline run.
    """
    location = self.config.location
    vertex_run_id = self.get_orchestrator_run_id()
    project = self.config.project
    # Only pin the console to a project when one is configured.
    project_query = f"?project={project}" if project else ""
    run_url = (
        "https://console.cloud.google.com/vertex-ai/locations/"
        f"{location}/pipelines/runs/{vertex_run_id}{project_query}"
    )
    return {METADATA_ORCHESTRATOR_URL: Uri(run_url)}
config: VertexOrchestratorConfig
property
readonly
Returns the `VertexOrchestratorConfig` config.
Returns:
Type | Description |
---|---|
VertexOrchestratorConfig |
The configuration. |
pipeline_directory: str
property
readonly
Returns path to directory where kubeflow pipelines files are stored.
Returns:
Type | Description |
---|---|
str |
Path to the pipeline directory. |
root_directory: str
property
readonly
Returns path to the root directory for files for this orchestrator.
Returns:
Type | Description |
---|---|
str |
The path to the root directory for all files concerning this orchestrator. |
settings_class: Optional[Type[BaseSettings]]
property
readonly
Settings class for the Vertex orchestrator.
Returns:
Type | Description |
---|---|
Optional[Type[BaseSettings]] |
The settings class. |
validator: Optional[zenml.stack.stack_validator.StackValidator]
property
readonly
Validates that the stack contains a container registry.
Also validates that the artifact store is not local.
Returns:
Type | Description |
---|---|
Optional[zenml.stack.stack_validator.StackValidator] |
A StackValidator instance. |
get_orchestrator_run_id(self)
Returns the active orchestrator run id.
Exceptions:
Type | Description |
---|---|
RuntimeError |
If the environment variable specifying the run id is not set. |
Returns:
Type | Description |
---|---|
str |
The orchestrator run id. |
Source code in zenml/integrations/gcp/orchestrators/vertex_orchestrator.py
def get_orchestrator_run_id(self) -> str:
    """Returns the active orchestrator run id.

    The id is read from the `ENV_ZENML_VERTEX_RUN_ID` environment variable.
    This variable is expected to be set inside the step containers.

    Raises:
        RuntimeError: If the environment variable specifying the run id
            is not set.

    Returns:
        The orchestrator run id.
    """
    try:
        return os.environ[ENV_ZENML_VERTEX_RUN_ID]
    except KeyError:
        # `from None`: the bare KeyError adds nothing beyond this message and
        # would otherwise show up as a confusing "during handling" traceback.
        raise RuntimeError(
            "Unable to read run id from environment variable "
            f"{ENV_ZENML_VERTEX_RUN_ID}."
        ) from None
get_pipeline_run_metadata(self, run_id)
Get general component-specific metadata for a pipeline run.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
run_id |
UUID |
The ID of the pipeline run. |
required |
Returns:
Type | Description |
---|---|
Dict[str, MetadataType] |
A dictionary of metadata. |
Source code in zenml/integrations/gcp/orchestrators/vertex_orchestrator.py
def get_pipeline_run_metadata(
    self, run_id: UUID
) -> Dict[str, "MetadataType"]:
    """Build orchestrator metadata (the Vertex AI console URL) for a run.

    Args:
        run_id: The ID of the pipeline run. Not used here; the URL is built
            from the active orchestrator run id instead.

    Returns:
        A dictionary mapping `METADATA_ORCHESTRATOR_URL` to the Google Cloud
        console URL of the Vertex AI pipeline run.
    """
    location = self.config.location
    vertex_run_id = self.get_orchestrator_run_id()
    project = self.config.project
    # Only pin the console to a project when one is configured.
    project_query = f"?project={project}" if project else ""
    run_url = (
        "https://console.cloud.google.com/vertex-ai/locations/"
        f"{location}/pipelines/runs/{vertex_run_id}{project_query}"
    )
    return {METADATA_ORCHESTRATOR_URL: Uri(run_url)}
prepare_or_run_pipeline(self, deployment, stack, environment)
Creates a KFP JSON pipeline.
noqa: DAR402
This is an intermediary representation of the pipeline which is then deployed to the Vertex AI Pipelines service.
How it works:
Before this method is called, the `prepare_pipeline_deployment()` method builds a Docker image that contains the code for the pipeline, all steps, and the context around these files.
Based on this Docker image, a callable is created which builds container_ops for each step (`_construct_kfp_pipeline`). The function `kfp.components.load_component_from_text` is used to create the `ContainerOp`, because using the `dsl.ContainerOp` class directly is deprecated when using the Kubeflow SDK v2. The step entrypoint command with the entrypoint arguments is the command that will be executed by the container created using the previously created Docker image.
This callable is then compiled into a JSON file that is used as the intermediary representation of the Kubeflow pipeline.
This file is then submitted to the Vertex AI Pipelines service for execution.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
deployment |
PipelineDeploymentResponseModel |
The pipeline deployment to prepare or run. |
required |
stack |
Stack |
The stack the pipeline will run on. |
required |
environment |
Dict[str, str] |
Environment variables to set in the orchestration environment. |
required |
Exceptions:
Type | Description |
---|---|
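ValueError |
If the attribute `pipeline_root` is not set and it cannot be generated using the path of the artifact store in the stack because it is not a `zenml.integrations.gcp.artifact_store.GCPArtifactStore`. Also gets raised if attempting to schedule a pipeline run without using the `zenml.integrations.gcp.artifact_store.GCPArtifactStore`. |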
Source code in zenml/integrations/gcp/orchestrators/vertex_orchestrator.py
def prepare_or_run_pipeline(
    self,
    deployment: "PipelineDeploymentResponseModel",
    stack: "Stack",
    environment: Dict[str, str],
) -> Any:
    """Creates a KFP JSON pipeline.

    # noqa: DAR402

    This is an intermediary representation of the pipeline which is then
    deployed to Vertex AI Pipelines service.

    How it works:
    -------------
    Before this method is called the `prepare_pipeline_deployment()` method
    builds a Docker image that contains the code for the pipeline, all steps
    and the context around these files.

    Based on this Docker image a callable is created which builds
    container_ops for each step (`_construct_kfp_pipeline`). The function
    `kfp.components.load_component_from_text` is used to create the
    `ContainerOp`, because using the `dsl.ContainerOp` class directly is
    deprecated when using the Kubeflow SDK v2. The step entrypoint command
    with the entrypoint arguments is the command that will be executed by
    the container created using the previously created Docker image.

    This callable is then compiled into a JSON file that is used as the
    intermediary representation of the Kubeflow pipeline.

    The file is then either scheduled (via `_upload_and_schedule_pipeline`)
    when the deployment has a schedule, or submitted as a one-off job (via
    `_upload_and_run_pipeline`) to the Vertex AI Pipelines service.

    Args:
        deployment: The pipeline deployment to prepare or run.
        stack: The stack the pipeline will run on.
        environment: Environment variables to set in the orchestration
            environment.

    Raises:
        ValueError: Not raised directly by this method; the `# noqa: DAR402`
            marker silences the linter check for that. It can propagate from
            `_upload_and_schedule_pipeline` when scheduling, e.g. when no
            scheduler service account is available. NOTE(review): this body
            does not check that the artifact store is a
            `zenml.integrations.gcp.artifact_store.GCPArtifactStore` before
            deriving `pipeline_root` from its path; confirm that is enforced
            elsewhere.
    """
    orchestrator_run_name = get_orchestrator_run_name(
        pipeline_name=deployment.pipeline_configuration.name
    )
    # If `pipeline_root` is not set in the orchestrator configuration, derive
    # it from the artifact store path. NOTE(review): the artifact store type
    # is not checked here, although the log message below names it a
    # `GCPArtifactStore`.
    if not self.config.pipeline_root:
        artifact_store = stack.artifact_store
        self._pipeline_root = f"{artifact_store.path.rstrip('/')}/vertex_pipeline_root/{deployment.pipeline_configuration.name}/{orchestrator_run_name}"
        logger.info(
            "The attribute `pipeline_root` has not been set in the "
            "orchestrator configuration. One has been generated "
            "automatically based on the path of the `GCPArtifactStore` "
            "artifact store in the stack used to execute the pipeline. "
            "The generated `pipeline_root` is `%s`.",
            self._pipeline_root,
        )
    else:
        self._pipeline_root = self.config.pipeline_root

    def _construct_kfp_pipeline() -> None:
        """Create a `ContainerOp` for each step.

        This should contain the name of the Docker image and configures the
        entrypoint of the Docker image to run the step.

        Additionally, this gives each `ContainerOp` information about its
        direct downstream steps.

        If this callable is passed to the `compile()` method of
        `KFPV2Compiler` all `dsl.ContainerOp` instances will be
        automatically added to a singular `dsl.Pipeline` instance.
        """
        command = StepEntrypointConfiguration.get_entrypoint_command()
        # Ops created so far, so later steps can declare `.after()` on them.
        step_name_to_container_op: Dict[str, dsl.ContainerOp] = {}

        for step_name, step in deployment.step_configurations.items():
            image = self.get_image(
                deployment=deployment, step_name=step_name
            )
            arguments = (
                StepEntrypointConfiguration.get_entrypoint_arguments(
                    step_name=step_name, deployment_id=deployment.id
                )
            )

            # Create the `ContainerOp` for the step. Using the
            # `dsl.ContainerOp` class directly is deprecated when using the
            # Kubeflow SDK v2.
            # NOTE(review): the component spec is YAML built with an
            # f-string. `command + arguments` is embedded via the Python
            # list repr, and the nesting of `container` under
            # `implementation` depends on the indentation inside this
            # literal. Confirm both survive formatting.
            container_op = kfp.components.load_component_from_text(
                f"""
name: {step_name}
implementation:
container:
image: {image}
command: {command + arguments}"""
            )()

            # Expose the Vertex pipeline job name to the step as its run id;
            # `get_orchestrator_run_id` reads this variable back.
            container_op.set_env_variable(
                name=ENV_ZENML_VERTEX_RUN_ID,
                value=dslv2.PIPELINE_JOB_NAME_PLACEHOLDER,
            )

            for key, value in environment.items():
                container_op.set_env_variable(name=key, value=value)

            # Set upstream tasks as a dependency of the current step.
            # NOTE(review): this lookup assumes every upstream step was
            # already visited, i.e. `step_configurations` iterates in
            # dependency order; otherwise it raises KeyError. Confirm.
            for upstream_step_name in step.spec.upstream_steps:
                upstream_container_op = step_name_to_container_op[
                    upstream_step_name
                ]
                container_op.after(upstream_container_op)

            settings = cast(
                VertexOrchestratorSettings,
                self.get_settings(step),
            )
            if settings.pod_settings:
                apply_pod_settings(
                    container_op=container_op,
                    settings=settings.pod_settings,
                )

            self._configure_container_resources(
                container_op=container_op,
                resource_settings=step.config.resource_settings,
                node_selector_constraint=settings.node_selector_constraint,
            )
            # Caching is disabled per step; the submitted job is also
            # created with `enable_caching=False`.
            container_op.set_caching_options(enable_caching=False)

            step_name_to_container_op[step_name] = container_op

    # Save the generated pipeline to a file.
    fileio.makedirs(self.pipeline_directory)
    pipeline_file_path = os.path.join(
        self.pipeline_directory,
        f"{orchestrator_run_name}.json",
    )

    # Compile the pipeline using the Kubeflow SDK V2 compiler, which
    # generates a JSON representation of the pipeline that can later be
    # uploaded to the Vertex AI Pipelines service.
    KFPV2Compiler().compile(
        pipeline_func=_construct_kfp_pipeline,
        package_path=pipeline_file_path,
        pipeline_name=_clean_pipeline_name(
            deployment.pipeline_configuration.name
        ),
    )

    logger.info(
        "Writing Vertex workflow definition to `%s`.", pipeline_file_path
    )

    # Pipeline-level settings (as opposed to the per-step settings above).
    settings = cast(
        VertexOrchestratorSettings, self.get_settings(deployment)
    )
    if deployment.schedule:
        logger.info(
            "Scheduling job using Google Cloud Scheduler and Google "
            "Cloud Functions..."
        )
        self._upload_and_schedule_pipeline(
            pipeline_name=deployment.pipeline_configuration.name,
            run_name=orchestrator_run_name,
            stack=stack,
            schedule=deployment.schedule,
            pipeline_file_path=pipeline_file_path,
            settings=settings,
        )
    else:
        logger.info("No schedule detected. Creating one-off vertex job...")
        # Using the Google Cloud AIPlatform client, upload and execute the
        # pipeline on the Vertex AI Pipelines service.
        self._upload_and_run_pipeline(
            pipeline_name=deployment.pipeline_configuration.name,
            pipeline_file_path=pipeline_file_path,
            run_name=orchestrator_run_name,
            settings=settings,
        )
prepare_pipeline_deployment(self, deployment, stack)
Build a Docker image and push it to the container registry.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
deployment |
PipelineDeploymentResponseModel |
The pipeline deployment configuration. |
required |
stack |
Stack |
The stack on which the pipeline will be deployed. |
required |
Exceptions:
Type | Description |
---|---|
ValueError |
If `cron_expression` is not set in the passed schedule. |
Source code in zenml/integrations/gcp/orchestrators/vertex_orchestrator.py
def prepare_pipeline_deployment(
    self,
    deployment: "PipelineDeploymentResponseModel",
    stack: "Stack",
) -> None:
    """Validate the deployment's schedule before the pipeline is deployed.

    Vertex schedules are created through Google Cloud Scheduler from a cron
    expression, so only `cron_expression` is honoured. Other schedule
    properties trigger a warning that they are ignored.

    Args:
        deployment: The pipeline deployment configuration.
        stack: The stack on which the pipeline will be deployed (unused).

    Raises:
        ValueError: If a schedule is passed without a non-empty
            `cron_expression`.
    """
    schedule = deployment.schedule
    if not schedule:
        return

    # Validate before warning, so an invalid schedule is not preceded by a
    # misleading "other properties are ignored" message. An empty string is
    # rejected as well, since it cannot be used as a cron schedule.
    if not schedule.cron_expression:
        raise ValueError(
            "Property `cron_expression` must be set when passing "
            "schedule to a Vertex orchestrator."
        )

    if (
        schedule.catchup
        or schedule.start_time
        or schedule.end_time
        or schedule.interval_second
    ):
        logger.warning(
            "Vertex orchestrator only uses schedules with the "
            "`cron_expression` property. All other properties "
            "are ignored."
        )
vertex_scheduler
special
Loading the vertex scheduler package.
main
Entrypoint for the scheduled vertex job.
trigger_vertex_job(request)
Processes the incoming HTTP request.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
request |
Request |
HTTP request object. |
required |
Returns:
Type | Description |
---|---|
str |
The response text or any set of values that can be turned into a Response. |
Source code in zenml/integrations/gcp/orchestrators/vertex_scheduler/main.py
def trigger_vertex_job(request: "Request") -> str:
    """Submit a Vertex AI pipeline job described by an incoming HTTP request.

    The request body is decoded as UTF-8 JSON. It is keyed by the payload
    constants (`JOB_ID`, `TEMPLATE_PATH`, `PIPELINE_ROOT`,
    `PARAMETER_VALUES`, `ENABLE_CACHING`, `ENCRYPTION_SPEC_KEY_NAME`,
    `LABELS`, `PROJECT`, `LOCATION`, `WORKLOAD_SERVICE_ACCOUNT`,
    `NETWORK`). Each trigger gets a unique job ID: the given job id plus
    `-scheduled-` and a random 8-digit hex suffix.

    Args:
        request: HTTP request object.

    Returns:
        A confirmation message naming the submitted job.
    """
    payload = json.loads(request.data.decode("utf-8"))

    job_prefix = payload[JOB_ID]
    random_suffix = random.Random().getrandbits(32)
    display_name = f"{job_prefix}-scheduled-{random_suffix:08x}"

    pipeline_job = aiplatform.PipelineJob(
        display_name=display_name,
        template_path=payload[TEMPLATE_PATH],
        job_id=display_name,
        pipeline_root=payload[PIPELINE_ROOT],
        parameter_values=payload[PARAMETER_VALUES],
        enable_caching=payload[ENABLE_CACHING],
        encryption_spec_key_name=payload[ENCRYPTION_SPEC_KEY_NAME],
        labels=payload[LABELS],
        project=payload[PROJECT],
        location=payload[LOCATION],
    )

    service_account = payload[WORKLOAD_SERVICE_ACCOUNT]
    network = payload[NETWORK]

    if service_account:
        logging.info(
            "The Vertex AI Pipelines job workload will be executed "
            "using the `%s` service account.",
            service_account,
        )
    if network:
        logging.info(
            "The Vertex AI Pipelines job will be peered with the `%s` "
            "network.",
            network,
        )

    pipeline_job.submit(service_account=service_account, network=network)
    return f"{display_name} submitted!"
secrets_manager
special
ZenML integration for GCP Secrets Manager.
The GCP Secrets Manager allows your pipeline to directly access the GCP secrets manager and use its secrets at runtime.
gcp_secrets_manager
Implementation of the GCP Secrets Manager.
GCPSecretsManager (BaseSecretsManager)
Class to interact with the GCP secrets manager.
Source code in zenml/integrations/gcp/secrets_manager/gcp_secrets_manager.py
class GCPSecretsManager(BaseSecretsManager):
    """Class to interact with the GCP secrets manager.

    A single `secretmanager.SecretManagerServiceClient` is created lazily and
    stored in the `CLIENT` class variable, so it is shared by all instances.
    """

    CLIENT: ClassVar[Any] = None

    @property
    def config(self) -> GCPSecretsManagerConfig:
        """Returns the `GCPSecretsManagerConfig` config.

        Returns:
            The configuration.
        """
        return cast(GCPSecretsManagerConfig, self._config)

    @classmethod
    def _ensure_client_connected(cls) -> None:
        """Create the shared Secret Manager client if it does not exist yet."""
        if cls.CLIENT is None:
            cls.CLIENT = secretmanager.SecretManagerServiceClient()

    @property
    def parent_name(self) -> str:
        """Construct the GCP parent path to the secret manager.

        Returns:
            The parent path to the secret manager (`projects/<project_id>`).
        """
        return f"projects/{self.config.project_id}"

    def _convert_secret_content(
        self, secret: BaseSecretSchema
    ) -> Dict[str, str]:
        """Convert the secret content into a Google compatible representation.

        This method implements two currently supported modes of adapting between
        the naming schemas used for ZenML secrets and Google secrets:

        * for a scoped Secrets Manager, a Google secret is created for each
        ZenML secret with a name that reflects the ZenML secret name and scope
        and a value that contains all its key-value pairs in JSON format.

        * for an unscoped (i.e. legacy) Secrets Manager, this method creates
        multiple Google secret entries for a single ZenML secret by adding the
        secret name to the key name of each secret key-value pair. This allows
        using the same key across multiple secrets. This is only kept for
        backwards compatibility and will be removed some time in the future.

        Args:
            secret: The ZenML secret

        Returns:
            A dictionary with the Google secret name as key and the secret
            contents as value.
        """
        if self.config.scope == SecretsManagerScope.NONE:
            # legacy per-key secret mapping
            return {f"{secret.name}_{k}": v for k, v in secret.content.items()}

        return {
            self._get_scoped_secret_name(
                secret.name, separator=ZENML_GCP_SECRET_SCOPE_PATH_SEPARATOR
            ): json.dumps(secret_to_dict(secret)),
        }

    def _get_secret_labels(
        self, secret: BaseSecretSchema
    ) -> List[Tuple[str, str]]:
        """Return a list of Google secret label values for a given secret.

        Args:
            secret: the secret object

        Returns:
            A list of Google secret label values
        """
        if self.config.scope == SecretsManagerScope.NONE:
            # legacy per-key secret labels
            return [
                (ZENML_GROUP_KEY, secret.name),
                (ZENML_SCHEMA_NAME, secret.TYPE),
            ]

        metadata = self._get_secret_metadata(secret)
        return list(metadata.items())

    def _get_secret_scope_filters(
        self,
        secret_name: Optional[str] = None,
    ) -> str:
        """Return a Google filter expression for the entire scope or just a scoped secret.

        These filters can be used when querying the Google Secrets Manager
        for all secrets or for a single secret available in the configured
        scope (see https://cloud.google.com/secret-manager/docs/filtering).

        Args:
            secret_name: Optional secret name to include in the scope metadata.

        Returns:
            Google filter expression uniquely identifying all secrets
            or a named secret within the configured scope.
        """
        if self.config.scope == SecretsManagerScope.NONE:
            # legacy per-key secret label filters
            if secret_name:
                return f"labels.{ZENML_GROUP_KEY}={secret_name}"
            else:
                return f"labels.{ZENML_GROUP_KEY}:*"

        metadata = self._get_secret_scope_metadata(secret_name)
        filters = [f"labels.{label}={v}" for (label, v) in metadata.items()]
        if secret_name:
            filters.append(f"name:{secret_name}")

        return " AND ".join(filters)

    def _list_secrets(self, secret_name: Optional[str] = None) -> List[str]:
        """List all secrets matching a name.

        This method lists all the secrets in the current scope without loading
        their contents. An optional secret name can be supplied to filter out
        all but a single secret identified by name.

        Args:
            secret_name: Optional secret name to filter for.

        Returns:
            A list of secret names in the current scope and the optional
            secret name.
        """
        self._ensure_client_connected()

        # The label carrying the ZenML secret name depends on the scope mode.
        name_label = (
            ZENML_GROUP_KEY
            if self.config.scope == SecretsManagerScope.NONE
            else ZENML_SECRET_NAME_LABEL
        )

        set_of_secrets = set()
        for secret in self.CLIENT.list_secrets(
            request={
                "parent": self.parent_name,
                "filter": self._get_secret_scope_filters(secret_name),
            }
        ):
            # `.get` skips secrets lacking the label instead of raising
            # KeyError; missing names are filtered out just below.
            name = secret.labels.get(name_label)
            # filter by secret name, if one was given
            if name and (not secret_name or name == secret_name):
                set_of_secrets.add(name)

        return list(set_of_secrets)

    def register_secret(self, secret: BaseSecretSchema) -> None:
        """Registers a new secret.

        Args:
            secret: the secret to register

        Raises:
            SecretExistsError: if the secret already exists
        """
        validate_gcp_secret_name_or_namespace(secret.name)
        self._ensure_client_connected()

        if self._list_secrets(secret.name):
            raise SecretExistsError(
                f"A Secret with the name {secret.name} already exists"
            )

        adjusted_content = self._convert_secret_content(secret)
        # The labels depend only on the ZenML secret, not on the individual
        # Google secret, so compute them once outside the loop.
        labels = self._get_secret_labels(secret)
        for k, v in adjusted_content.items():
            # Create the secret, this only creates an empty secret with the
            # supplied name.
            gcp_secret = self.CLIENT.create_secret(
                request={
                    "parent": self.parent_name,
                    "secret_id": k,
                    "secret": {
                        "replication": {"automatic": {}},
                        "labels": labels,
                    },
                }
            )

            logger.debug("Created empty secret: %s", gcp_secret.name)

            # Store the actual value as the first version of the secret.
            self.CLIENT.add_secret_version(
                request={
                    "parent": gcp_secret.name,
                    "payload": {"data": str(v).encode()},
                }
            )

            logger.debug("Added value to secret.")

    def get_secret(self, secret_name: str) -> BaseSecretSchema:
        """Get a secret by its name.

        In legacy (unscoped) mode the ZenML secret is reassembled from one
        Google secret per key. In scoped mode it is a single Google secret
        whose latest version holds the JSON-serialized ZenML secret.

        Args:
            secret_name: the name of the secret to get

        Returns:
            The secret.

        Raises:
            KeyError: if the secret does not exist
        """
        validate_gcp_secret_name_or_namespace(secret_name)
        self._ensure_client_connected()

        zenml_secret: Optional[BaseSecretSchema] = None

        if self.config.scope == SecretsManagerScope.NONE:
            # Legacy secrets are mapped to multiple Google secrets, one for
            # each secret key
            secret_contents = {}
            zenml_schema_name = ""

            # List all Google secrets belonging to this ZenML secret.
            for google_secret in self.CLIENT.list_secrets(
                request={
                    "parent": self.parent_name,
                    "filter": self._get_secret_scope_filters(secret_name),
                }
            ):
                secret_version_name = google_secret.name + "/versions/latest"

                response = self.CLIENT.access_secret_version(
                    request={"name": secret_version_name}
                )

                secret_value = response.payload.data.decode("UTF-8")
                # Presumably strips the `<secret_name>_` prefix that
                # `_convert_secret_content` adds in legacy mode.
                secret_key = remove_group_name_from_key(
                    google_secret.name.split("/")[-1], secret_name
                )

                secret_contents[secret_key] = secret_value

                zenml_schema_name = google_secret.labels[ZENML_SCHEMA_NAME]

            if not secret_contents:
                raise KeyError(
                    f"Can't find the specified secret '{secret_name}'"
                )

            secret_contents["name"] = secret_name

            secret_schema = SecretSchemaClassRegistry.get_class(
                secret_schema=zenml_schema_name
            )
            zenml_secret = secret_schema(**secret_contents)
        else:
            # Scoped secrets are mapped 1-to-1 with Google secrets

            google_secret_name = self.CLIENT.secret_path(
                self.config.project_id,
                self._get_scoped_secret_name(
                    secret_name,
                    separator=ZENML_GCP_SECRET_SCOPE_PATH_SEPARATOR,
                ),
            )

            try:
                # fetch the secret resource itself (needed for its labels)
                google_secret = self.CLIENT.get_secret(name=google_secret_name)
            except google_exceptions.NotFound as e:
                raise KeyError(
                    f"Can't find the specified secret '{secret_name}'"
                ) from e

            # make sure the secret has the correct scope labels to filter out
            # unscoped secrets with similar names
            scope_labels = self._get_secret_scope_metadata(secret_name)
            # all scope labels need to be included in the google secret labels,
            # otherwise the secret does not belong to the current scope
            if not scope_labels.items() <= google_secret.labels.items():
                raise KeyError(
                    f"Can't find the specified secret '{secret_name}'"
                )

            try:
                # fetch the latest secret version
                response = self.CLIENT.access_secret_version(
                    name=f"{google_secret_name}/versions/latest"
                )
            except google_exceptions.NotFound as e:
                raise KeyError(
                    f"Can't find the specified secret '{secret_name}'"
                ) from e

            secret_value = response.payload.data.decode("UTF-8")
            zenml_secret = secret_from_dict(
                json.loads(secret_value), secret_name=secret_name
            )

        return zenml_secret

    def get_all_secret_keys(self) -> List[str]:
        """Get all secret keys.

        Returns:
            A list of all secret keys
        """
        return self._list_secrets()

    def update_secret(self, secret: BaseSecretSchema) -> None:
        """Update an existing secret by creating new versions of the existing secrets.

        Args:
            secret: the secret to update

        Raises:
            KeyError: if the secret does not exist
        """
        validate_gcp_secret_name_or_namespace(secret.name)
        self._ensure_client_connected()

        if not self._list_secrets(secret.name):
            raise KeyError(f"Can't find the specified secret '{secret.name}'")

        adjusted_content = self._convert_secret_content(secret)

        for k, v in adjusted_content.items():
            # Add a new version holding the updated value to the existing
            # Google secret named `k`.
            # NOTE(review): in legacy (unscoped) mode, a key added by this
            # update has no existing Google secret, so `add_secret_version`
            # would likely fail, and keys removed from the secret are left
            # behind. Confirm intended behavior.
            google_secret_name = self.CLIENT.secret_path(
                self.config.project_id, k
            )
            payload = {"data": str(v).encode()}
            self.CLIENT.add_secret_version(
                request={"parent": google_secret_name, "payload": payload}
            )

    def delete_secret(self, secret_name: str) -> None:
        """Delete an existing secret by name.

        Args:
            secret_name: the name of the secret to delete

        Raises:
            KeyError: if the secret no longer exists
        """
        validate_gcp_secret_name_or_namespace(secret_name)
        self._ensure_client_connected()

        if not self._list_secrets(secret_name):
            raise KeyError(f"Can't find the specified secret '{secret_name}'")

        # Go through all gcp secrets and delete the ones with the secret_name
        # as label.
        for secret in self.CLIENT.list_secrets(
            request={
                "parent": self.parent_name,
                "filter": self._get_secret_scope_filters(secret_name),
            }
        ):
            self.CLIENT.delete_secret(request={"name": secret.name})

    def delete_all_secrets(self) -> None:
        """Delete all existing secrets in the configured scope."""
        self._ensure_client_connected()

        # List all secrets in the current scope.
        for secret in self.CLIENT.list_secrets(
            request={
                "parent": self.parent_name,
                "filter": self._get_secret_scope_filters(),
            }
        ):
            # Lazy %-formatting instead of an f-string in the logging call.
            logger.info("Deleting Google secret %s", secret.name)
            self.CLIENT.delete_secret(request={"name": secret.name})
config: GCPSecretsManagerConfig
property
readonly
Returns the `GCPSecretsManagerConfig` config.
Returns:
Type | Description |
---|---|
GCPSecretsManagerConfig |
The configuration. |
parent_name: str
property
readonly
Construct the GCP parent path to the secret manager.
Returns:
Type | Description |
---|---|
str |
The parent path to the secret manager |
delete_all_secrets(self)
Delete all existing secrets.
Source code in zenml/integrations/gcp/secrets_manager/gcp_secrets_manager.py
def delete_all_secrets(self) -> None:
    """Delete all existing secrets in the configured scope."""
    self._ensure_client_connected()

    # List all secrets in the current scope.
    for secret in self.CLIENT.list_secrets(
        request={
            "parent": self.parent_name,
            "filter": self._get_secret_scope_filters(),
        }
    ):
        # Lazy %-formatting instead of an f-string in the logging call.
        logger.info("Deleting Google secret %s", secret.name)
        self.CLIENT.delete_secret(request={"name": secret.name})
delete_secret(self, secret_name)
Delete an existing secret by name.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
secret_name |
str |
the name of the secret to delete |
required |
Exceptions:
Type | Description |
---|---|
KeyError |
if the secret no longer exists |
Source code in zenml/integrations/gcp/secrets_manager/gcp_secrets_manager.py
def delete_secret(self, secret_name: str) -> None:
    """Delete the secret called `secret_name`.

    Every Google secret matched by the scope filter for this name is
    removed. In legacy (unscoped) mode that is one Google secret per key.

    Args:
        secret_name: the name of the secret to delete

    Raises:
        KeyError: if the secret no longer exists
    """
    validate_gcp_secret_name_or_namespace(secret_name)
    self._ensure_client_connected()

    existing_names = self._list_secrets(secret_name)
    if not existing_names:
        raise KeyError(f"Can't find the specified secret '{secret_name}'")

    list_request = {
        "parent": self.parent_name,
        "filter": self._get_secret_scope_filters(secret_name),
    }
    for google_secret in self.CLIENT.list_secrets(request=list_request):
        self.CLIENT.delete_secret(request={"name": google_secret.name})
get_all_secret_keys(self)
Get all secret keys.
Returns:
Type | Description |
---|---|
List[str] |
A list of all secret keys |
Source code in zenml/integrations/gcp/secrets_manager/gcp_secrets_manager.py
def get_all_secret_keys(self) -> List[str]:
    """Return the names of all secrets visible in the configured scope.

    Returns:
        A list of all secret keys
    """
    all_names: List[str] = self._list_secrets()
    return all_names
get_secret(self, secret_name)
Get a secret by its name.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
secret_name |
str |
the name of the secret to get |
required |
Returns:
Type | Description |
---|---|
BaseSecretSchema |
The secret. |
Exceptions:
Type | Description |
---|---|
KeyError |
if the secret does not exist |
Source code in zenml/integrations/gcp/secrets_manager/gcp_secrets_manager.py
def get_secret(self, secret_name: str) -> BaseSecretSchema:
    """Get a secret by its name.

    In legacy (unscoped) mode the ZenML secret is reassembled from one
    Google secret per key. In scoped mode it is a single Google secret
    whose latest version holds the JSON-serialized ZenML secret.

    Args:
        secret_name: the name of the secret to get

    Returns:
        The secret.

    Raises:
        KeyError: if the secret does not exist
    """
    validate_gcp_secret_name_or_namespace(secret_name)
    self._ensure_client_connected()

    zenml_secret: Optional[BaseSecretSchema] = None

    if self.config.scope == SecretsManagerScope.NONE:
        # Legacy secrets are mapped to multiple Google secrets, one for
        # each secret key
        secret_contents = {}
        zenml_schema_name = ""

        # List all Google secrets belonging to this ZenML secret.
        for google_secret in self.CLIENT.list_secrets(
            request={
                "parent": self.parent_name,
                "filter": self._get_secret_scope_filters(secret_name),
            }
        ):
            secret_version_name = google_secret.name + "/versions/latest"

            response = self.CLIENT.access_secret_version(
                request={"name": secret_version_name}
            )

            secret_value = response.payload.data.decode("UTF-8")
            # Presumably strips the `<secret_name>_` prefix used for legacy
            # per-key Google secret names.
            secret_key = remove_group_name_from_key(
                google_secret.name.split("/")[-1], secret_name
            )

            secret_contents[secret_key] = secret_value

            zenml_schema_name = google_secret.labels[ZENML_SCHEMA_NAME]

        if not secret_contents:
            raise KeyError(
                f"Can't find the specified secret '{secret_name}'"
            )

        secret_contents["name"] = secret_name

        secret_schema = SecretSchemaClassRegistry.get_class(
            secret_schema=zenml_schema_name
        )
        zenml_secret = secret_schema(**secret_contents)
    else:
        # Scoped secrets are mapped 1-to-1 with Google secrets

        google_secret_name = self.CLIENT.secret_path(
            self.config.project_id,
            self._get_scoped_secret_name(
                secret_name,
                separator=ZENML_GCP_SECRET_SCOPE_PATH_SEPARATOR,
            ),
        )

        try:
            # fetch the secret resource itself (needed for its labels)
            google_secret = self.CLIENT.get_secret(name=google_secret_name)
        except google_exceptions.NotFound as e:
            raise KeyError(
                f"Can't find the specified secret '{secret_name}'"
            ) from e

        # make sure the secret has the correct scope labels to filter out
        # unscoped secrets with similar names
        scope_labels = self._get_secret_scope_metadata(secret_name)
        # all scope labels need to be included in the google secret labels,
        # otherwise the secret does not belong to the current scope
        if not scope_labels.items() <= google_secret.labels.items():
            raise KeyError(
                f"Can't find the specified secret '{secret_name}'"
            )

        try:
            # fetch the latest secret version
            response = self.CLIENT.access_secret_version(
                name=f"{google_secret_name}/versions/latest"
            )
        except google_exceptions.NotFound as e:
            raise KeyError(
                f"Can't find the specified secret '{secret_name}'"
            ) from e

        secret_value = response.payload.data.decode("UTF-8")
        zenml_secret = secret_from_dict(
            json.loads(secret_value), secret_name=secret_name
        )

    return zenml_secret
register_secret(self, secret)
Registers a new secret.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
secret |
BaseSecretSchema |
the secret to register |
required |
Exceptions:
Type | Description |
---|---|
SecretExistsError |
if the secret already exists |
Source code in zenml/integrations/gcp/secrets_manager/gcp_secrets_manager.py
def register_secret(self, secret: BaseSecretSchema) -> None:
    """Registers a new secret.

    Every key/value pair returned by ``_convert_secret_content`` becomes its
    own Google secret. The empty secret is created first, carrying the labels
    from ``_get_secret_labels``. The value is then stored as its first version.

    Args:
        secret: the secret to register

    Raises:
        SecretExistsError: if the secret already exists
    """
    validate_gcp_secret_name_or_namespace(secret.name)
    self._ensure_client_connected()

    existing_secrets = self._list_secrets(secret.name)
    if existing_secrets:
        raise SecretExistsError(
            f"A Secret with the name {secret.name} already exists"
        )

    for secret_id, raw_value in self._convert_secret_content(secret).items():
        create_request = {
            "parent": self.parent_name,
            "secret_id": secret_id,
            "secret": {
                "replication": {"automatic": {}},
                "labels": self._get_secret_labels(secret),
            },
        }
        # This only creates an empty secret container with the given ID;
        # the value is attached as a version right after.
        created_secret = self.CLIENT.create_secret(request=create_request)
        logger.debug("Created empty secret: %s", created_secret.name)

        version_request = {
            "parent": created_secret.name,
            "payload": {"data": str(raw_value).encode()},
        }
        self.CLIENT.add_secret_version(request=version_request)
        logger.debug("Added value to secret.")
update_secret(self, secret)
Update an existing secret by creating new versions of the existing secrets.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
secret |
BaseSecretSchema |
the secret to update |
required |
Exceptions:
Type | Description |
---|---|
KeyError |
if the secret does not exist |
Source code in zenml/integrations/gcp/secrets_manager/gcp_secrets_manager.py
def update_secret(self, secret: BaseSecretSchema) -> None:
    """Update an existing secret by creating new versions of the existing secrets.

    Each key/value pair produced by ``_convert_secret_content`` is written as
    a new version of the Google secret whose ID is that key, in the
    configured project.

    NOTE(review): only the existence of *some* Google secret for
    ``secret.name`` is checked via ``_list_secrets``. Secrets are created
    per key only in ``register_secret``, so a key that was not part of the
    originally registered content presumably has no Google secret to add a
    version to (the client call would likely fail). Keys dropped from the
    new content are left untouched. Confirm whether callers rely on the key
    set never changing.

    Args:
        secret: the secret to update

    Raises:
        KeyError: if the secret does not exist
    """
    validate_gcp_secret_name_or_namespace(secret.name)
    self._ensure_client_connected()

    if not self._list_secrets(secret.name):
        raise KeyError(f"Can't find the specified secret '{secret.name}'")

    adjusted_content = self._convert_secret_content(secret)

    for k, v in adjusted_content.items():
        # Add a new version to the existing Google secret named ``k``; the
        # secret itself is not (re)created here.
        google_secret_name = self.CLIENT.secret_path(
            self.config.project_id, k
        )
        payload = {"data": str(v).encode()}
        self.CLIENT.add_secret_version(
            request={"parent": google_secret_name, "payload": payload}
        )
remove_group_name_from_key(combined_key_name, group_name)
Removes the secret group name from the secret key.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
combined_key_name |
str |
Full name as it is within the gcp secrets manager |
required |
group_name |
str |
Group name (the ZenML Secret name) |
required |
Returns:
Type | Description |
---|---|
str |
The cleaned key |
Exceptions:
Type | Description |
---|---|
RuntimeError |
If the group name is not found in the key |
Source code in zenml/integrations/gcp/secrets_manager/gcp_secrets_manager.py
def remove_group_name_from_key(combined_key_name: str, group_name: str) -> str:
    """Removes the secret group name from the secret key.

    The combined key is expected to look like ``<group_name>_<key>``. The
    ``<group_name>_`` prefix is stripped and the remainder returned.

    Args:
        combined_key_name: Full name as it is within the gcp secrets manager
        group_name: Group name (the ZenML Secret name)

    Returns:
        The cleaned key

    Raises:
        RuntimeError: If the group name is not found in the key
    """
    expected_prefix = f"{group_name}_"
    if not combined_key_name.startswith(expected_prefix):
        raise RuntimeError(
            f"Key-name `{combined_key_name}` does not have the "
            f"prefix `{group_name}`. Key could not be "
            f"extracted."
        )
    return combined_key_name[len(expected_prefix):]
service_connectors
special
ZenML GCP Service Connector.
gcp_service_connector
GCP Service Connector.
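The GCP Service Connector implements several authentication methods for GCP services:
- Implicit authentication using credentials discovered in the local environment
- GCP user account credentials
- Explicit GCP service account key
- Pre-issued GCP OAuth 2.0 token
- GCP service account impersonation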
GCPAuthenticationMethods (StrEnum)
GCP Authentication methods.
Source code in zenml/integrations/gcp/service_connectors/gcp_service_connector.py
class GCPAuthenticationMethods(StrEnum):
    """GCP Authentication methods."""

    # Credentials discovered from the local environment via
    # `google.auth.default()`.
    IMPLICIT = "implicit"
    # Credentials built from an authorized-user JSON held in the config.
    USER_ACCOUNT = "user-account"
    # Credentials built from a service account key JSON held in the config.
    SERVICE_ACCOUNT = "service-account"
    # A pre-issued OAuth 2.0 access token held in the config.
    OAUTH2_TOKEN = "oauth2-token"
    # Service account credentials used to impersonate another service
    # account (the configured target principal).
    IMPERSONATION = "impersonation"
GCPBaseConfig (AuthenticationConfig)
pydantic-model
GCP base configuration.
Source code in zenml/integrations/gcp/service_connectors/gcp_service_connector.py
class GCPBaseConfig(AuthenticationConfig):
    """GCP base configuration.

    Shared by every GCP authentication configuration. On its own it is the
    configuration used for implicit (environment) authentication.
    """

    # Project used by the connector, e.g. as the default ID of the generic
    # GCP resource and to build the default `gcr.io/<project>` registry ID.
    project_id: str = Field(
        title="GCP Project ID where the target resource is located.",
    )
GCPOAuth2Token (AuthenticationConfig)
pydantic-model
GCP OAuth 2.0 token credentials.
Source code in zenml/integrations/gcp/service_connectors/gcp_service_connector.py
class GCPOAuth2Token(AuthenticationConfig):
    """GCP OAuth 2.0 token credentials."""

    # Pre-issued OAuth 2.0 access token, kept as a secret value. The
    # connector passes it verbatim as the `token` of the google-auth
    # credentials it builds.
    token: SecretStr = Field(
        title="GCP OAuth 2.0 Token",
    )
GCPOAuth2TokenConfig (GCPBaseConfig, GCPOAuth2Token)
pydantic-model
GCP OAuth 2.0 configuration.
Source code in zenml/integrations/gcp/service_connectors/gcp_service_connector.py
class GCPOAuth2TokenConfig(GCPBaseConfig, GCPOAuth2Token):
    """GCP OAuth 2.0 configuration.

    Combines the project ID from `GCPBaseConfig` with the OAuth 2.0 token
    from `GCPOAuth2Token`.
    """

    # When set, the connector attaches this e-mail to the credentials it
    # builds from the token, as their `signer_email`.
    service_account_email: Optional[str] = Field(
        default=None,
        title="GCP Service Account Email",
        description="The email address of the service account that signed the "
        "token. If not provided, the token is assumed to be issued for a user "
        "account.",
    )
service_account_email: Optional[str]
pydantic-field
The email address of the service account that signed the token. If not provided, the token is assumed to be issued for a user account.
GCPServiceAccountConfig (GCPBaseConfig, GCPServiceAccountCredentials)
pydantic-model
GCP service account configuration.
Source code in zenml/integrations/gcp/service_connectors/gcp_service_connector.py
class GCPServiceAccountConfig(GCPBaseConfig, GCPServiceAccountCredentials):
    """GCP service account configuration.

    Combines the project ID from `GCPBaseConfig` with the service account key
    JSON and the `generate_temporary_tokens` flag from
    `GCPServiceAccountCredentials`. It adds no fields of its own.
    """
GCPServiceAccountCredentials (AuthenticationConfig)
pydantic-model
GCP service account credentials.
Source code in zenml/integrations/gcp/service_connectors/gcp_service_connector.py
class GCPServiceAccountCredentials(AuthenticationConfig):
    """GCP service account credentials."""

    # Full service account key file contents (JSON), kept as a secret value.
    service_account_json: SecretStr = Field(
        title="GCP Service Account Key JSON",
    )
    generate_temporary_tokens: bool = Field(
        default=True,
        title="Generate temporary OAuth 2.0 tokens",
        description="Whether to generate temporary OAuth 2.0 tokens from the "
        "service account key JSON. If set to False, the connector will "
        "distribute the service account key JSON to clients instead.",
    )

    @validator("service_account_json")
    def validate_service_account_json(cls, v: SecretStr) -> SecretStr:
        """Validate the service account credentials JSON.

        The value must be a JSON object that contains every field of a GCP
        service account key file and whose ``type`` is ``service_account``.
        Every failure is raised as a ``ValueError`` so that pydantic can
        report it as a validation error.

        Args:
            v: The service account credentials JSON.

        Returns:
            The validated service account credentials JSON.

        Raises:
            ValueError: If the service account credentials JSON is invalid.
        """
        try:
            service_account_info = json.loads(v.get_secret_value())
        except json.JSONDecodeError as e:
            raise ValueError(
                f"GCP service account credentials is not a valid JSON: {e}"
            ) from e

        # Valid JSON is not necessarily an object (e.g. "[]", "null", "42").
        # Without this check the field lookups below would fail with an
        # AttributeError, which pydantic does not turn into a validation error.
        if not isinstance(service_account_info, dict):
            raise ValueError(
                "GCP service account credentials JSON must be an object, "
                f"not {type(service_account_info).__name__}"
            )

        # Check that all fields are present
        required_fields = [
            "type",
            "project_id",
            "private_key_id",
            "private_key",
            "client_email",
            "client_id",
            "auth_uri",
            "token_uri",
            "auth_provider_x509_cert_url",
            "client_x509_cert_url",
        ]

        # Sorted so that the error message is deterministic.
        missing_fields = sorted(
            set(required_fields) - set(service_account_info.keys())
        )
        if missing_fields:
            raise ValueError(
                f"GCP service account credentials JSON is missing required "
                f'fields: {", ".join(missing_fields)}'
            )

        if service_account_info["type"] != "service_account":
            raise ValueError(
                "The JSON does not contain GCP service account credentials. "
                f'The "type" field is set to {service_account_info["type"]} '
                "instead of 'service_account'."
            )

        return v
generate_temporary_tokens: bool
pydantic-field
Whether to generate temporary OAuth 2.0 tokens from the service account key JSON. If set to False, the connector will distribute the service account key JSON to clients instead.
validate_service_account_json(v)
classmethod
Validate the service account credentials JSON.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
v |
SecretStr |
The service account credentials JSON. |
required |
Returns:
Type | Description |
---|---|
SecretStr |
The validated service account credentials JSON. |
Exceptions:
Type | Description |
---|---|
ValueError |
If the service account credentials JSON is invalid. |
Source code in zenml/integrations/gcp/service_connectors/gcp_service_connector.py
@validator("service_account_json")
def validate_service_account_json(cls, v: SecretStr) -> SecretStr:
    """Validate the service account credentials JSON.

    The value must be a JSON object that contains every field of a GCP
    service account key file and whose ``type`` is ``service_account``.
    Every failure is raised as a ``ValueError`` so that pydantic can report
    it as a validation error.

    Args:
        v: The service account credentials JSON.

    Returns:
        The validated service account credentials JSON.

    Raises:
        ValueError: If the service account credentials JSON is invalid.
    """
    try:
        service_account_info = json.loads(v.get_secret_value())
    except json.JSONDecodeError as e:
        raise ValueError(
            f"GCP service account credentials is not a valid JSON: {e}"
        ) from e

    # Valid JSON is not necessarily an object (e.g. "[]", "null", "42").
    # Without this check the field lookups below would fail with an
    # AttributeError, which pydantic does not turn into a validation error.
    if not isinstance(service_account_info, dict):
        raise ValueError(
            "GCP service account credentials JSON must be an object, "
            f"not {type(service_account_info).__name__}"
        )

    # Check that all fields are present
    required_fields = [
        "type",
        "project_id",
        "private_key_id",
        "private_key",
        "client_email",
        "client_id",
        "auth_uri",
        "token_uri",
        "auth_provider_x509_cert_url",
        "client_x509_cert_url",
    ]

    # Sorted so that the error message is deterministic.
    missing_fields = sorted(
        set(required_fields) - set(service_account_info.keys())
    )
    if missing_fields:
        raise ValueError(
            f"GCP service account credentials JSON is missing required "
            f'fields: {", ".join(missing_fields)}'
        )

    if service_account_info["type"] != "service_account":
        raise ValueError(
            "The JSON does not contain GCP service account credentials. "
            f'The "type" field is set to {service_account_info["type"]} '
            "instead of 'service_account'."
        )

    return v
GCPServiceAccountImpersonationConfig (GCPServiceAccountConfig)
pydantic-model
GCP service account impersonation configuration.
Source code in zenml/integrations/gcp/service_connectors/gcp_service_connector.py
class GCPServiceAccountImpersonationConfig(GCPServiceAccountConfig):
    """GCP service account impersonation configuration.

    The service account key inherited from `GCPServiceAccountConfig` is used
    as the source credentials for impersonating `target_principal`.
    """

    # E-mail of the service account whose identity is assumed.
    target_principal: str = Field(
        title="GCP Service Account Email to impersonate",
    )
GCPServiceConnector (ServiceConnector)
pydantic-model
GCP service connector.
Source code in zenml/integrations/gcp/service_connectors/gcp_service_connector.py
class GCPServiceConnector(ServiceConnector):
"""GCP service connector."""
config: GCPBaseConfig
_session_cache: Dict[
Tuple[str, Optional[str], Optional[str]],
Tuple[
gcp_credentials.Credentials,
Optional[datetime.datetime],
],
] = {}
@classmethod
def _get_connector_type(cls) -> ServiceConnectorTypeModel:
    """Return the static specification of the GCP service connector type.

    Returns:
        The ``GCP_SERVICE_CONNECTOR_TYPE_SPEC`` connector type specification.
    """
    connector_type_spec = GCP_SERVICE_CONNECTOR_TYPE_SPEC
    return connector_type_spec
def get_session(
    self,
    auth_method: str,
    resource_type: Optional[str] = None,
    resource_id: Optional[str] = None,
) -> Tuple[gcp_service_account.Credentials, Optional[datetime.datetime]]:
    """Get a GCP session object with credentials for the specified resource.

    Sessions are memoized in ``self._session_cache``, keyed on
    ``(auth_method, resource_type, resource_id)``.
    - A cached session without an expiration time is reused as-is.
    - A cached session with an expiration time (treated as UTC) is reused
      only while that time lies in the future.
    - Otherwise a fresh session is obtained from ``_authenticate`` and
      stored back in the cache.

    NOTE(review): the cache key does not include the connector's own
    configuration, and ``_session_cache`` is declared with a ``{}`` default
    at class level. If that dict is shared between connector instances,
    two connectors with different credentials could be served each other's
    sessions. Confirm how the base model treats this private attribute.

    Args:
        auth_method: The authentication method to use.
        resource_type: The resource type to get credentials for.
        resource_id: The resource ID to get credentials for.

    Returns:
        GCP session with credentials for the specified resource and its
        expiration timestamp, if applicable.
    """
    cache_key = (auth_method, resource_type, resource_id)
    cached_entry = self._session_cache.get(cache_key)

    if cached_entry is not None:
        cached_session, cached_expiry = cached_entry
        if cached_expiry is None:
            return cached_session, None

        # Compare as timezone-aware UTC; expired sessions fall through and
        # are re-created below.
        cached_expiry = cached_expiry.replace(tzinfo=datetime.timezone.utc)
        if cached_expiry > datetime.datetime.now(datetime.timezone.utc):
            return cached_session, cached_expiry

    logger.debug(
        f"Creating GCP authentication session for auth method "
        f"'{auth_method}', resource type '{resource_type}' and resource ID "
        f"'{resource_id}'..."
    )
    fresh_session, fresh_expiry = self._authenticate(
        auth_method, resource_type, resource_id
    )
    self._session_cache[cache_key] = (fresh_session, fresh_expiry)
    return fresh_session, fresh_expiry
@classmethod
def _get_scopes(
cls,
resource_type: Optional[str] = None,
resource_id: Optional[str] = None,
) -> List[str]:
"""Get the OAuth 2.0 scopes to use for the specified resource type.
Args:
resource_type: The resource type to get scopes for.
resource_id: The resource ID to get scopes for.
Returns:
OAuth 2.0 scopes to use for the specified resource type.
"""
return [
"https://www.googleapis.com/auth/cloud-platform",
]
def _authenticate(
    self,
    auth_method: str,
    resource_type: Optional[str] = None,
    resource_id: Optional[str] = None,
) -> Tuple[gcp_service_account.Credentials, Optional[datetime.datetime]]:
    """Authenticate to GCP and return a session with credentials.

    Credentials that are not yet valid (e.g. freshly built from a key file)
    are refreshed over HTTP to obtain an access token.

    Args:
        auth_method: The authentication method to use.
        resource_type: The resource type to authenticate for.
        resource_id: The resource ID to authenticate for.

    Returns:
        GCP OAuth 2.0 credentials and their expiration time if applicable.
        Whenever the credentials carry an expiry, the returned expiration
        time is timezone-aware (UTC).

    Raises:
        AuthorizationException: If the authentication fails.
    """
    cfg = self.config
    scopes = self._get_scopes(resource_type, resource_id)
    expires_at: Optional[datetime.datetime] = None

    if auth_method == GCPAuthenticationMethods.IMPLICIT:
        self._check_implicit_auth_method_allowed()

        # Determine the credentials from the environment. The project ID
        # detected alongside them is not used here. Failures are reported
        # as AuthorizationException, consistent with _auto_configure.
        try:
            credentials, _ = google.auth.default(
                scopes=scopes,
            )
        except google.auth.exceptions.GoogleAuthError as e:
            raise AuthorizationException(
                f"No GCP credentials could be detected: {e}"
            ) from e

    elif auth_method == GCPAuthenticationMethods.OAUTH2_TOKEN:
        assert isinstance(cfg, GCPOAuth2TokenConfig)

        expires_at = self.expires_at
        if expires_at:
            # google-auth expects a naive datetime representing UTC. Convert
            # aware values to UTC first so that the wall-clock time is not
            # silently shifted when the timezone is dropped.
            if expires_at.tzinfo is not None:
                expires_at = expires_at.astimezone(datetime.timezone.utc)
            expires_at = expires_at.replace(tzinfo=None)

        credentials = gcp_credentials.Credentials(
            token=cfg.token.get_secret_value(),
            expiry=expires_at,
            scopes=scopes,
        )

        if cfg.service_account_email:
            credentials.signer_email = cfg.service_account_email
    else:
        if auth_method == GCPAuthenticationMethods.USER_ACCOUNT:
            assert isinstance(cfg, GCPUserAccountConfig)
            credentials = (
                gcp_credentials.Credentials.from_authorized_user_info(
                    json.loads(cfg.user_account_json.get_secret_value()),
                    scopes=scopes,
                )
            )
        else:
            # Both the service account and impersonation methods start from
            # the configured service account key.
            assert isinstance(cfg, GCPServiceAccountConfig)

            credentials = (
                gcp_service_account.Credentials.from_service_account_info(
                    json.loads(
                        cfg.service_account_json.get_secret_value()
                    ),
                    scopes=scopes,
                )
            )

        if auth_method == GCPAuthenticationMethods.IMPERSONATION:
            assert isinstance(cfg, GCPServiceAccountImpersonationConfig)

            try:
                credentials = gcp_impersonated_credentials.Credentials(
                    source_credentials=credentials,
                    target_principal=cfg.target_principal,
                    target_scopes=scopes,
                    lifetime=self.expiration_seconds,
                )
            except google.auth.exceptions.GoogleAuthError as e:
                raise AuthorizationException(
                    f"Failed to impersonate service account "
                    f"'{cfg.target_principal}': {e}"
                ) from e

    if not credentials.valid:
        try:
            with requests.Session() as session:
                req = Request(session)
                credentials.refresh(req)
        except google.auth.exceptions.GoogleAuthError as e:
            raise AuthorizationException(
                f"Could not fetch GCP OAuth2 token: {e}"
            ) from e

    if credentials.expiry:
        # google-auth reports the expiry as naive UTC; attach the UTC
        # timezone so callers always get an aware datetime.
        expires_at = credentials.expiry.replace(
            tzinfo=datetime.timezone.utc
        )

    return credentials, expires_at
def _parse_gcs_resource_id(self, resource_id: str) -> str:
"""Validate and convert an GCS resource ID to an GCS bucket name.
Args:
resource_id: The resource ID to convert.
Returns:
The GCS bucket name.
Raises:
ValueError: If the provided resource ID is not a valid GCS bucket
name or URI.
"""
# The resource ID could mean different things:
#
# - an GCS bucket URI
# - the GCS bucket name
#
# We need to extract the bucket name from the provided resource ID
bucket_name: Optional[str] = None
if re.match(
r"^gs://[a-z0-9][a-z0-9_-]{1,61}[a-z0-9](/.*)*$",
resource_id,
):
# The resource ID is an GCS bucket URI
bucket_name = resource_id.split("/")[2]
elif re.match(
r"^[a-z0-9][a-z0-9_-]{1,61}[a-z0-9]$",
resource_id,
):
# The resource ID is the GCS bucket name
bucket_name = resource_id
else:
raise ValueError(
f"Invalid resource ID for an GCS bucket: {resource_id}. "
f"Supported formats are:\n"
f"GCS bucket URI: gcs://<bucket-name>\n"
f"GCS bucket name: <bucket-name>"
)
return bucket_name
def _parse_gcr_resource_id(
self,
resource_id: str,
) -> str:
"""Validate and convert an GCR resource ID to an GCR registry ID.
Args:
resource_id: The resource ID to convert.
Returns:
The GCR registry ID.
Raises:
ValueError: If the provided resource ID is not a valid GCR
repository URI.
"""
# The resource ID could mean different things:
#
# - an GCR repository URI
#
# We need to extract the project ID and registry ID from
# the provided resource ID
config_project_id = self.config.project_id
project_id: Optional[str] = None
# A GCR repository URI uses one of several hostnames (gcr.io, us.gcr.io,
# eu.gcr.io, asia.gcr.io etc.) and the project ID is the first part of
# the URL path
if re.match(
r"^(https://)?([a-z]+.)*gcr.io/[a-z0-9-]+(/.+)*$",
resource_id,
):
# The resource ID is a GCR repository URI
if resource_id.startswith("https://"):
project_id = resource_id.split("/")[3]
else:
project_id = resource_id.split("/")[1]
else:
raise ValueError(
f"Invalid resource ID for a GCR registry: {resource_id}. "
f"Supported formats are:\n"
f"GCR repository URI: [https://][us.|eu.|asia.]gcr.io/<project-id>[/<repository-name>]"
)
# If the connector is configured with a project and the resource ID
# is an GCR repository URI that specifies a different project,
# we raise an error
if project_id and project_id != config_project_id:
raise ValueError(
f"The GCP project for the {resource_id} GCR repository "
f"'{project_id}' does not match the project configured in "
f"the connector: '{config_project_id}'."
)
return f"gcr.io/{project_id}"
def _parse_gke_resource_id(self, resource_id: str) -> str:
"""Validate and convert an GKE resource ID to a GKE cluster name.
Args:
resource_id: The resource ID to convert.
Returns:
The GKE cluster name.
Raises:
ValueError: If the provided resource ID is not a valid GKE cluster
name.
"""
if re.match(
r"^[a-z0-9]+[a-z0-9_-]*$",
resource_id,
):
# Assume the resource ID is an GKE cluster name
cluster_name = resource_id
else:
raise ValueError(
f"Invalid resource ID for a GKE cluster: {resource_id}. "
f"Supported formats are:\n"
f"GKE cluster name: <cluster-name>"
)
return cluster_name
def _canonical_resource_id(
    self, resource_type: str, resource_id: str
) -> str:
    """Convert a resource ID to its canonical form.

    - GCS buckets become ``gs://<bucket>``.
    - GKE clusters become the bare cluster name.
    - GCR registries become ``gcr.io/<project>``.
    - Any other resource type is returned unchanged.

    Args:
        resource_type: The resource type to canonicalize.
        resource_id: The resource ID to canonicalize.

    Returns:
        The canonical resource ID.
    """
    if resource_type == GCS_RESOURCE_TYPE:
        return f"gs://{self._parse_gcs_resource_id(resource_id)}"

    if resource_type == KUBERNETES_CLUSTER_RESOURCE_TYPE:
        return self._parse_gke_resource_id(resource_id)

    if resource_type == DOCKER_REGISTRY_RESOURCE_TYPE:
        return self._parse_gcr_resource_id(resource_id)

    # Other resource types have no canonical form beyond the given ID.
    return resource_id
def _get_default_resource_id(self, resource_type: str) -> str:
    """Get the default resource ID for a resource type.

    Only two resource types have a default:
    - the generic GCP resource, whose default is the configured project ID;
    - the Docker registry, whose default is the ``gcr.io/<project-id>``
      registry of the configured project.

    Args:
        resource_type: The type of the resource to get a default resource ID
            for. Only called with resource types that do not support
            multiple instances.

    Returns:
        The default resource ID for the resource type.

    Raises:
        RuntimeError: If the resource type has no default resource ID.
    """
    if resource_type == GCP_RESOURCE_TYPE:
        default_id = self.config.project_id
    elif resource_type == DOCKER_REGISTRY_RESOURCE_TYPE:
        default_id = f"gcr.io/{self.config.project_id}"
    else:
        raise RuntimeError(
            f"Default resource ID not supported for '{resource_type}' resource "
            "type."
        )
    return default_id
def _connect_to_resource(
    self,
    **kwargs: Any,
) -> Any:
    """Authenticate and connect to a GCP resource.

    What is returned depends on the connector's resource type:
    - generic GCP resource: the google-auth credentials themselves;
    - GCS bucket: a google-storage client for the configured project.

    Docker and Kubernetes resources cannot be connected to directly. Use
    ``get_connector_client`` to obtain a connector for those instead.

    Args:
        kwargs: Additional implementation specific keyword arguments to pass
            to the session or client constructor.

    Returns:
        Generic GCP credentials for GCP generic resources and a
        google-storage GCS client for GCS resources.

    Raises:
        NotImplementedError: If the connector instance does not support
            directly connecting to the indicated resource type.
    """
    resource_type, resource_id = self.resource_type, self.resource_id
    assert resource_type is not None
    assert resource_id is not None

    # Authentication happens first, for every resource type, including those
    # that end up being rejected below.
    credentials, _ = self.get_session(
        self.auth_method,
        resource_type=resource_type,
        resource_id=resource_id,
    )

    if resource_type == GCS_RESOURCE_TYPE:
        # Only validates the bucket name; the parsed value is not needed.
        self._parse_gcs_resource_id(resource_id)
        return storage.Client(
            project=self.config.project_id, credentials=credentials
        )

    if resource_type == GCP_RESOURCE_TYPE:
        return credentials

    raise NotImplementedError(
        f"Connecting to {resource_type} resources is not directly "
        "supported by the GCP connector. Please call the "
        f"`get_connector_client` method to get a {resource_type} connector "
        "instance for the resource."
    )
def _configure_local_client(
    self,
    **kwargs: Any,
) -> None:
    """Configure a local client to authenticate and connect to a resource.

    This method uses the connector's configuration to configure a local
    client or SDK installed on the localhost for the indicated resource.

    For the generic GCP and GCS resource types, the connector's user account
    or service account credentials JSON is written to the local gcloud
    application default credentials (ADC) file.
    - The file's parent directory is created if it does not exist yet.
    - The file is restricted to owner read/write, since it holds secrets.

    Args:
        kwargs: Additional implementation specific keyword arguments to use
            to configure the client.

    Raises:
        NotImplementedError: If the connector instance does not support
            local configuration for the configured resource type or
            authentication method.
    """
    resource_type = self.resource_type

    if resource_type in [GCP_RESOURCE_TYPE, GCS_RESOURCE_TYPE]:
        gcloud_config_json: Optional[str] = None

        # There is no way to configure the local gcloud CLI to use
        # temporary OAuth 2.0 tokens. However, we can configure it to use
        # the user account credentials or service account credentials
        if self.auth_method == GCPAuthenticationMethods.USER_ACCOUNT:
            assert isinstance(self.config, GCPUserAccountConfig)
            # Use the user account credentials JSON to configure the
            # local gcloud CLI
            gcloud_config_json = (
                self.config.user_account_json.get_secret_value()
            )
        elif self.auth_method == GCPAuthenticationMethods.SERVICE_ACCOUNT:
            assert isinstance(self.config, GCPServiceAccountConfig)
            # Use the service account credentials JSON to configure the
            # local gcloud CLI
            gcloud_config_json = (
                self.config.service_account_json.get_secret_value()
            )

        if gcloud_config_json:
            from google.auth import _cloud_sdk

            # Dump the user account or service account credentials JSON to
            # the local gcloud application default credentials file
            adc_path = _cloud_sdk.get_application_default_credentials_path()

            # The gcloud configuration directory does not exist yet on a
            # machine where gcloud has never been initialized.
            adc_dir = os.path.dirname(adc_path)
            if adc_dir:
                os.makedirs(adc_dir, exist_ok=True)

            # Create the file owner-only (0o600) and also tighten an already
            # existing file, because it contains long-lived credentials.
            fd = os.open(
                adc_path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600
            )
            with os.fdopen(fd, "w", encoding="utf-8") as f:
                os.chmod(adc_path, 0o600)
                f.write(gcloud_config_json)

            logger.info(
                f"Updated the local gcloud default application "
                f"credentials file at '{adc_path}'"
            )

            return

        raise NotImplementedError(
            f"Local gcloud client configuration for resource type "
            f"{resource_type} is only supported if the "
            f"'{GCPAuthenticationMethods.USER_ACCOUNT}' or "
            f"'{GCPAuthenticationMethods.SERVICE_ACCOUNT}' authentication "
            f"method is used and only if the generation of temporary OAuth "
            f"2.0 tokens is disabled by setting the "
            f"'generate_temporary_tokens' option to 'False' in the "
            f"service connector configuration."
        )

    raise NotImplementedError(
        f"Configuring the local client for {resource_type} resources is "
        "not directly supported by the GCP connector. Please call the "
        f"`get_connector_client` method to get a {resource_type} connector "
        "instance for the resource."
    )
@classmethod
def _auto_configure(
    cls,
    auth_method: Optional[str] = None,
    resource_type: Optional[str] = None,
    resource_id: Optional[str] = None,
    **kwargs: Any,
) -> "GCPServiceConnector":
    """Auto-configure the connector.

    Instantiate a GCP connector with a configuration extracted from the
    authentication configuration available in the environment (e.g.
    environment variables or local GCP client/SDK configuration files).

    Args:
        auth_method: The particular authentication method to use. If not
            specified, the connector implementation must decide which
            authentication method to use or raise an exception.
        resource_type: The type of resource to configure.
        resource_id: The ID of the resource to configure. The
            implementation may choose to either require or ignore this
            parameter if it does not support or detect an resource type that
            supports multiple instances.
        kwargs: Additional implementation specific keyword arguments to use.

    Returns:
        A GCP connector instance configured with authentication credentials
        automatically extracted from the environment.

    Raises:
        NotImplementedError: If the connector implementation does not
            support auto-configuration for the specified authentication
            method.
        AuthorizationException: If no GCP credentials can be loaded from
            the environment.
    """
    auth_config: GCPBaseConfig
    scopes = cls._get_scopes()
    expires_at: Optional[datetime.datetime] = None

    try:
        # Determine the credentials from the environment
        credentials, project_id = google.auth.default(
            scopes=scopes,
        )
    except google.auth.exceptions.GoogleAuthError as e:
        raise AuthorizationException(
            f"No GCP credentials could be detected: {e}"
        ) from e

    if project_id is None:
        raise AuthorizationException(
            "No GCP project ID could be detected. Please set the active "
            "GCP project ID by running 'gcloud config set project'."
        )

    if auth_method == GCPAuthenticationMethods.IMPLICIT:
        cls._check_implicit_auth_method_allowed()

        auth_config = GCPBaseConfig(
            project_id=project_id,
        )
    elif auth_method == GCPAuthenticationMethods.OAUTH2_TOKEN:
        # Refresh the credentials if necessary, to fetch the access token
        if not credentials.valid or not credentials.token:
            try:
                with requests.Session() as session:
                    req = Request(session)
                    credentials.refresh(req)
            except google.auth.exceptions.GoogleAuthError as e:
                raise AuthorizationException(
                    f"Could not fetch GCP OAuth2 token: {e}"
                ) from e

        if not credentials.token:
            raise AuthorizationException(
                "Could not fetch GCP OAuth2 token"
            )

        auth_config = GCPOAuth2TokenConfig(
            project_id=project_id,
            token=credentials.token,
            service_account_email=credentials.signer_email
            if hasattr(credentials, "signer_email")
            else None,
        )
        if credentials.expiry:
            # Add the UTC timezone to the expiration time
            expires_at = credentials.expiry.replace(
                tzinfo=datetime.timezone.utc
            )
    else:
        # Check if user account credentials are available
        if isinstance(credentials, gcp_credentials.Credentials):
            if auth_method not in [
                GCPAuthenticationMethods.USER_ACCOUNT,
                None,
            ]:
                raise NotImplementedError(
                    f"Could not perform auto-configuration for "
                    f"authentication method {auth_method}. Only "
                    f"GCP user account credentials have been detected."
                )
            auth_method = GCPAuthenticationMethods.USER_ACCOUNT
            # Rebuild an "authorized_user" JSON document from the detected
            # credentials, using their public properties rather than the
            # private _client_id/_client_secret attributes.
            user_account_json = json.dumps(
                dict(
                    type="authorized_user",
                    client_id=credentials.client_id,
                    client_secret=credentials.client_secret,
                    refresh_token=credentials.refresh_token,
                )
            )
            auth_config = GCPUserAccountConfig(
                project_id=project_id,
                user_account_json=user_account_json,
            )
        # Check if service account credentials are available
        elif isinstance(credentials, gcp_service_account.Credentials):
            if auth_method not in [
                GCPAuthenticationMethods.SERVICE_ACCOUNT,
                None,
            ]:
                raise NotImplementedError(
                    f"Could not perform auto-configuration for "
                    f"authentication method {auth_method}. Only "
                    f"GCP service account credentials have been detected."
                )

            auth_method = GCPAuthenticationMethods.SERVICE_ACCOUNT
            service_account_json_file = os.environ.get(
                "GOOGLE_APPLICATION_CREDENTIALS"
            )
            if service_account_json_file is None:
                # Shouldn't happen since google.auth.default() should
                # already have loaded the credentials from the environment
                raise AuthorizationException(
                    "No GCP service account credentials found in the "
                    "environment. Please set the "
                    "GOOGLE_APPLICATION_CREDENTIALS environment variable "
                    "to the path of the service account JSON file."
                )
            # JSON files are UTF-8; do not depend on the locale encoding.
            with open(service_account_json_file, "r", encoding="utf-8") as f:
                service_account_json = f.read()

            auth_config = GCPServiceAccountConfig(
                project_id=project_id,
                service_account_json=service_account_json,
            )
        else:
            raise AuthorizationException(
                "No valid GCP credentials could be detected."
            )

    return cls(
        auth_method=auth_method,
        resource_type=resource_type,
        resource_id=resource_id
        if resource_type not in [GCP_RESOURCE_TYPE, None]
        else None,
        expires_at=expires_at,
        config=auth_config,
    )
def _verify(
    self,
    resource_type: Optional[str] = None,
    resource_id: Optional[str] = None,
) -> List[str]:
    """Verify and list all the resources that the connector can access.

    Authentication is always attempted first; without a resource type it is
    done as for a generic GCP resource and is the only check performed.

    Args:
        resource_type: The type of the resource to verify. If omitted and
            if the connector supports multiple resource types, the
            implementation must verify that it can authenticate and connect
            to any and all of the supported resource types.
        resource_id: The ID of the resource to connect to. Omitted if a
            resource type is not specified. It has the same value as the
            default resource ID if the supplied resource type doesn't
            support multiple instances. If the supplied resource type does
            allows multiple instances, this parameter may still be omitted
            to fetch a list of resource IDs identifying all the resources
            of the indicated type that the connector can access.

    Returns:
        The list of resources IDs in canonical format identifying the
        resources that the connector can access. This list is empty only
        if the resource type is not specified (i.e. for multi-type
        connectors).

    Raises:
        AuthorizationException: If the connector cannot authenticate or
            access the specified resource.
    """
    credentials, _ = self.get_session(
        self.auth_method,
        resource_type=resource_type or GCP_RESOURCE_TYPE,
        resource_id=resource_id,
    )

    if not resource_type:
        return []

    if resource_type == GCP_RESOURCE_TYPE:
        assert resource_id is not None
        return [resource_id]

    if resource_type == GCS_RESOURCE_TYPE:
        storage_client = storage.Client(
            project=self.config.project_id, credentials=credentials
        )

        if resource_id:
            # Check if the specified GCS bucket exists
            bucket_name = self._parse_gcs_resource_id(resource_id)
            try:
                storage_client.get_bucket(bucket_name)
            except google.api_core.exceptions.GoogleAPIError as e:
                msg = f"failed to fetch GCS bucket {bucket_name}: {e}"
                logger.error(msg)
                raise AuthorizationException(msg) from e
            return [resource_id]

        # List all GCS buckets
        try:
            names = [b.name for b in storage_client.list_buckets()]
        except google.api_core.exceptions.GoogleAPIError as e:
            msg = f"failed to list GCS buckets: {e}"
            logger.error(msg)
            raise AuthorizationException(msg) from e
        return [f"gs://{name}" for name in names]

    if resource_type == DOCKER_REGISTRY_RESOURCE_TYPE:
        assert resource_id is not None
        # No way to verify a GCR registry without attempting to
        # connect to it via Docker/OCI, so just return the resource ID.
        return [resource_id]

    if resource_type == KUBERNETES_CLUSTER_RESOURCE_TYPE:
        cluster_client = container_v1.ClusterManagerClient(
            credentials=credentials
        )

        # List all GKE clusters in every location of the project
        try:
            listing = cluster_client.list_clusters(
                parent=f"projects/{self.config.project_id}/locations/-"
            )
            known_clusters = [c.name for c in listing.clusters]
        except google.api_core.exceptions.GoogleAPIError as e:
            msg = f"Failed to list GKE clusters: {e}"
            logger.error(msg)
            raise AuthorizationException(msg) from e

        if not resource_id:
            return known_clusters

        # Check if the specified GKE cluster exists
        wanted_cluster = self._parse_gke_resource_id(resource_id)
        if wanted_cluster not in known_clusters:
            raise AuthorizationException(
                f"GKE cluster '{wanted_cluster}' not found or not "
                "accessible."
            )
        return [resource_id]

    return []
def _get_connector_client(
    self,
    resource_type: str,
    resource_id: str,
) -> "ServiceConnector":
    """Get a connector instance that can be used to connect to a resource.

    This method generates a client-side connector instance that can be used
    to connect to a resource of the given type. The client-side connector
    is configured with temporary GCP credentials extracted from the
    current connector and, depending on resource type, it may also be
    of a different connector type:

    - a Kubernetes connector for Kubernetes clusters
    - a Docker connector for Docker registries

    Args:
        resource_type: The type of the resources to connect to.
        resource_id: The ID of a particular resource to connect to.

    Returns:
        A GCP, Kubernetes or Docker connector instance that can be used to
        connect to the specified resource.

    Raises:
        AuthorizationException: If authentication failed.
        ValueError: If the resource type is not supported.
        RuntimeError: If the Kubernetes connector is not installed and the
            resource type is Kubernetes.
    """
    # Human-readable name of the client-side connector:
    # "<name> (<resource type> | <resource id> client)". When this
    # connector has no name, the suffix is used on its own so the result
    # does not start with a stray space.
    if resource_id:
        suffix = f"({resource_type} | {resource_id} client)"
    else:
        suffix = f"({resource_type} client)"
    connector_name = f"{self.name} {suffix}" if self.name else suffix

    logger.debug(f"Getting connector client for {connector_name}")

    credentials, expires_at = self.get_session(
        self.auth_method,
        resource_type=resource_type,
        resource_id=resource_id,
    )

    if resource_type in [GCP_RESOURCE_TYPE, GCS_RESOURCE_TYPE]:
        # By default, use the token extracted from the google credentials
        # object
        auth_method: str = GCPAuthenticationMethods.OAUTH2_TOKEN
        config: GCPBaseConfig = GCPOAuth2TokenConfig(
            project_id=self.config.project_id,
            token=credentials.token,
            service_account_email=credentials.signer_email
            if hasattr(credentials, "signer_email")
            else None,
        )

        # If the connector is explicitly configured to not generate
        # temporary tokens, use the original config (and drop the
        # expiration, since the original credentials are handed out)
        if self.auth_method == GCPAuthenticationMethods.USER_ACCOUNT:
            assert isinstance(self.config, GCPUserAccountConfig)
            if not self.config.generate_temporary_tokens:
                config = self.config
                auth_method = self.auth_method
                expires_at = None
        elif self.auth_method == GCPAuthenticationMethods.SERVICE_ACCOUNT:
            assert isinstance(self.config, GCPServiceAccountConfig)
            if not self.config.generate_temporary_tokens:
                config = self.config
                auth_method = self.auth_method
                expires_at = None

        # Create a client-side GCP connector instance that is fully formed
        # and ready to use to connect to the specified resource (i.e. has
        # all the necessary configuration and credentials, a resource type
        # and a resource ID where applicable)
        return GCPServiceConnector(
            id=self.id,
            name=connector_name,
            auth_method=auth_method,
            resource_type=resource_type,
            resource_id=resource_id,
            config=config,
            expires_at=expires_at,
        )

    if resource_type == DOCKER_REGISTRY_RESOURCE_TYPE:
        assert resource_id is not None

        registry_id = self._parse_gcr_resource_id(resource_id)

        # Create a client-side Docker connector instance with the temporary
        # Docker credentials (the OAuth 2.0 token is used as password)
        return DockerServiceConnector(
            id=self.id,
            name=connector_name,
            auth_method=DockerAuthenticationMethods.PASSWORD,
            resource_type=resource_type,
            config=DockerConfiguration(
                username="oauth2accesstoken",
                password=credentials.token,
                registry=registry_id,
            ),
            expires_at=expires_at,
        )

    if resource_type == KUBERNETES_CLUSTER_RESOURCE_TYPE:
        assert resource_id is not None

        cluster_name = self._parse_gke_resource_id(resource_id)

        gke_client = container_v1.ClusterManagerClient(
            credentials=credentials
        )

        # List all GKE clusters in all locations of the project
        try:
            clusters = gke_client.list_clusters(
                parent=f"projects/{self.config.project_id}/locations/-"
            )
            cluster_map = {
                cluster.name: cluster for cluster in clusters.clusters
            }
        except google.api_core.exceptions.GoogleAPIError as e:
            msg = f"Failed to list GKE clusters: {e}"
            logger.error(msg)
            raise AuthorizationException(msg) from e

        # Find the cluster with the specified name
        if cluster_name not in cluster_map:
            raise AuthorizationException(
                f"GKE cluster '{cluster_name}' not found or not "
                "accessible."
            )

        cluster = cluster_map[cluster_name]

        # get cluster details
        cluster_server = cluster.endpoint
        cluster_ca_cert = cluster.master_auth.cluster_ca_certificate

        bearer_token = credentials.token

        # Create a client-side Kubernetes connector instance with the
        # temporary Kubernetes credentials
        try:
            # Import libraries only when needed
            from zenml.integrations.kubernetes.service_connectors.kubernetes_service_connector import (
                KubernetesAuthenticationMethods,
                KubernetesServiceConnector,
                KubernetesTokenConfig,
            )
        except ImportError as e:
            # Chain the original ImportError so the missing module is
            # visible in the traceback.
            raise RuntimeError(
                f"The Kubernetes Service Connector functionality could not "
                f"be used due to missing dependencies: {e}"
            ) from e
        return KubernetesServiceConnector(
            id=self.id,
            name=connector_name,
            auth_method=KubernetesAuthenticationMethods.TOKEN,
            resource_type=resource_type,
            config=KubernetesTokenConfig(
                cluster_name=f"gke_{self.config.project_id}_{cluster_name}",
                certificate_authority=cluster_ca_cert,
                server=f"https://{cluster_server}",
                token=bearer_token,
            ),
            expires_at=expires_at,
        )

    raise ValueError(f"Unsupported resource type: {resource_type}")
get_session(self, auth_method, resource_type=None, resource_id=None)
Get a GCP session object with credentials for the specified resource.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
auth_method |
str |
The authentication method to use. |
required |
resource_type |
Optional[str] |
The resource type to get credentials for. |
None |
resource_id |
Optional[str] |
The resource ID to get credentials for. |
None |
Returns:
Type | Description |
---|---|
Tuple[google.oauth2.service_account.Credentials, Optional[datetime.datetime]] |
GCP session with credentials for the specified resource and its expiration timestamp, if applicable. |
Source code in zenml/integrations/gcp/service_connectors/gcp_service_connector.py
def get_session(
    self,
    auth_method: str,
    resource_type: Optional[str] = None,
    resource_id: Optional[str] = None,
) -> Tuple[gcp_service_account.Credentials, Optional[datetime.datetime]]:
    """Get a GCP session object with credentials for the specified resource.

    Sessions are cached per ``(auth_method, resource_type, resource_id)``
    key. A cached session is returned as-is when it has no expiration, or
    when its expiration is still in the future. Otherwise a new session is
    created with ``self._authenticate`` and stored in the cache.

    Args:
        auth_method: The authentication method to use.
        resource_type: The resource type to get credentials for.
        resource_id: The resource ID to get credentials for.

    Returns:
        GCP session with credentials for the specified resource and its
        expiration timestamp, if applicable.
    """
    # We maintain a cache of all sessions to avoid re-authenticating
    # multiple times for the same resource
    key = (auth_method, resource_type, resource_id)
    if key in self._session_cache:
        session, expires_at = self._session_cache[key]
        if expires_at is None:
            return session, None

        # Naive expiration timestamps are interpreted as UTC. Aware ones
        # are converted (not overwritten) to UTC, so that a non-UTC
        # timestamp is not shifted by its UTC offset.
        if expires_at.tzinfo is None:
            expires_at = expires_at.replace(tzinfo=datetime.timezone.utc)
        else:
            expires_at = expires_at.astimezone(datetime.timezone.utc)

        # Reuse the session only while it has not expired
        now = datetime.datetime.now(datetime.timezone.utc)
        if expires_at > now:
            return session, expires_at

    logger.debug(
        f"Creating GCP authentication session for auth method "
        f"'{auth_method}', resource type '{resource_type}' and resource ID "
        f"'{resource_id}'..."
    )
    session, expires_at = self._authenticate(
        auth_method, resource_type, resource_id
    )
    self._session_cache[key] = (session, expires_at)
    return session, expires_at
GCPUserAccountConfig (GCPBaseConfig, GCPUserAccountCredentials)
pydantic-model
GCP user account configuration.
Source code in zenml/integrations/gcp/service_connectors/gcp_service_connector.py
class GCPUserAccountConfig(GCPBaseConfig, GCPUserAccountCredentials):
    """GCP user account configuration.

    Combines the fields inherited from `GCPBaseConfig` and
    `GCPUserAccountCredentials`; no additional fields are declared here.
    """
GCPUserAccountCredentials (AuthenticationConfig)
pydantic-model
GCP user account credentials.
Source code in zenml/integrations/gcp/service_connectors/gcp_service_connector.py
class GCPUserAccountCredentials(AuthenticationConfig):
    """GCP user account credentials.

    Holds a user account credentials JSON document. The validator requires
    an ``authorized_user`` document. A flag selects whether the connector
    hands out temporary OAuth 2.0 tokens or the credentials JSON itself.
    """

    user_account_json: SecretStr = Field(
        title="GCP User Account Credentials JSON",
    )
    generate_temporary_tokens: bool = Field(
        default=True,
        title="Generate temporary OAuth 2.0 tokens",
        description="Whether to generate temporary OAuth 2.0 tokens from the "
        "user account credentials JSON. If set to False, the connector will "
        "distribute the user account credentials JSON to clients instead.",
    )

    @validator("user_account_json")
    def validate_user_account_json(cls, v: SecretStr) -> SecretStr:
        """Validate the user account credentials JSON.

        The value must be a JSON object that contains the ``type``,
        ``refresh_token``, ``client_secret`` and ``client_id`` fields,
        and ``type`` must be ``authorized_user``.

        Args:
            v: The user account credentials JSON.

        Returns:
            The validated user account credentials JSON.

        Raises:
            ValueError: If the user account credentials JSON is invalid.
        """
        try:
            user_account_info = json.loads(v.get_secret_value())
        except json.JSONDecodeError as e:
            raise ValueError(
                f"GCP user account credentials is not a valid JSON: {e}"
            ) from e

        # The checks below treat the parsed value as a mapping. Reject any
        # other JSON value (list, string, number, ...) with a ValueError,
        # instead of letting an AttributeError/TypeError escape the
        # validator.
        if not isinstance(user_account_info, dict):
            raise ValueError(
                "GCP user account credentials JSON must be a JSON object."
            )

        # Check that all fields are present
        required_fields = [
            "type",
            "refresh_token",
            "client_secret",
            "client_id",
        ]
        # Compute missing fields in a deterministic order, so the error
        # message is stable
        missing_fields = [
            field for field in required_fields if field not in user_account_info
        ]
        if missing_fields:
            raise ValueError(
                f"GCP user account credentials JSON is missing required "
                f'fields: {", ".join(missing_fields)}'
            )

        if user_account_info["type"] != "authorized_user":
            raise ValueError(
                "The JSON does not contain GCP user account credentials. The "
                f'"type" field is set to {user_account_info["type"]} '
                "instead of 'authorized_user'."
            )

        return v
generate_temporary_tokens: bool
pydantic-field
Whether to generate temporary OAuth 2.0 tokens from the user account credentials JSON. If set to False, the connector will distribute the user account credentials JSON to clients instead.
validate_user_account_json(v)
classmethod
Validate the user account credentials JSON.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
v |
SecretStr |
The user account credentials JSON. |
required |
Returns:
Type | Description |
---|---|
SecretStr |
The validated user account credentials JSON. |
Exceptions:
Type | Description |
---|---|
ValueError |
If the user account credentials JSON is invalid. |
Source code in zenml/integrations/gcp/service_connectors/gcp_service_connector.py
@validator("user_account_json")
def validate_user_account_json(cls, v: SecretStr) -> SecretStr:
    """Validate the user account credentials JSON.

    The value must be a JSON object that contains the ``type``,
    ``refresh_token``, ``client_secret`` and ``client_id`` fields, and
    ``type`` must be ``authorized_user``.

    Args:
        v: The user account credentials JSON.

    Returns:
        The validated user account credentials JSON.

    Raises:
        ValueError: If the user account credentials JSON is invalid.
    """
    try:
        user_account_info = json.loads(v.get_secret_value())
    except json.JSONDecodeError as e:
        raise ValueError(
            f"GCP user account credentials is not a valid JSON: {e}"
        ) from e

    # The checks below treat the parsed value as a mapping. Reject any
    # other JSON value (list, string, number, ...) with a ValueError,
    # instead of letting an AttributeError/TypeError escape the validator.
    if not isinstance(user_account_info, dict):
        raise ValueError(
            "GCP user account credentials JSON must be a JSON object."
        )

    # Check that all fields are present
    required_fields = [
        "type",
        "refresh_token",
        "client_secret",
        "client_id",
    ]
    # Compute missing fields in a deterministic order, so the error message
    # is stable
    missing_fields = [
        field for field in required_fields if field not in user_account_info
    ]
    if missing_fields:
        raise ValueError(
            f"GCP user account credentials JSON is missing required "
            f'fields: {", ".join(missing_fields)}'
        )

    if user_account_info["type"] != "authorized_user":
        raise ValueError(
            "The JSON does not contain GCP user account credentials. The "
            f'"type" field is set to {user_account_info["type"]} '
            "instead of 'authorized_user'."
        )

    return v
step_operators
special
Initialization for the VertexAI Step Operator.
vertex_step_operator
Implementation of a VertexAI step operator.
Code heavily inspired by TFX Implementation: https://github.com/tensorflow/tfx/blob/master/tfx/extensions/google_cloud_ai_platform/training_clients.py
VertexStepOperator (BaseStepOperator, GoogleCredentialsMixin)
Step operator to run a step on Vertex AI.
This class defines code that can set up a Vertex AI environment and run the ZenML entrypoint command in it.
Source code in zenml/integrations/gcp/step_operators/vertex_step_operator.py
class VertexStepOperator(BaseStepOperator, GoogleCredentialsMixin):
    """Step operator to run a step on Vertex AI.

    This class defines code that can set up a Vertex AI environment and run
    the ZenML entrypoint command in it.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Initializes the step operator.

        All arguments are forwarded unchanged to the base class
        constructors. The accelerator type is not checked here; it is
        validated in `launch` for each step.

        Args:
            *args: Variable length argument list.
            **kwargs: Arbitrary keyword arguments.
        """
        super().__init__(*args, **kwargs)

    @property
    def config(self) -> VertexStepOperatorConfig:
        """Returns the `VertexStepOperatorConfig` config.

        Returns:
            The configuration.
        """
        return cast(VertexStepOperatorConfig, self._config)

    @property
    def settings_class(self) -> Optional[Type["BaseSettings"]]:
        """Settings class for the Vertex step operator.

        Returns:
            The settings class.
        """
        return VertexStepOperatorSettings

    @property
    def validator(self) -> Optional[StackValidator]:
        """Validates the stack.

        Returns:
            A validator that checks that the stack contains a remote container
            registry and a remote artifact store.
        """

        def _validate_remote_components(stack: "Stack") -> Tuple[bool, str]:
            if stack.artifact_store.config.is_local:
                return False, (
                    "The Vertex step operator runs code remotely and "
                    "needs to write files into the artifact store, but the "
                    f"artifact store `{stack.artifact_store.name}` of the "
                    "active stack is local. Please ensure that your stack "
                    "contains a remote artifact store when using the Vertex "
                    "step operator."
                )

            # The container registry is a required component (see below),
            # so it is always present when this function runs.
            container_registry = stack.container_registry
            assert container_registry is not None

            if container_registry.config.is_local:
                return False, (
                    "The Vertex step operator runs code remotely and "
                    "needs to push/pull Docker images, but the "
                    f"container registry `{container_registry.name}` of the "
                    "active stack is local. Please ensure that your stack "
                    "contains a remote container registry when using the "
                    "Vertex step operator."
                )

            return True, ""

        return StackValidator(
            required_components={
                StackComponentType.CONTAINER_REGISTRY,
                StackComponentType.IMAGE_BUILDER,
            },
            custom_validation_function=_validate_remote_components,
        )

    def get_docker_builds(
        self, deployment: "PipelineDeploymentBaseModel"
    ) -> List["BuildConfiguration"]:
        """Gets the Docker builds required for the component.

        One build is requested for every step that is configured to run
        with this step operator.

        Args:
            deployment: The pipeline deployment for which to get the builds.

        Returns:
            The required Docker builds.
        """
        return [
            BuildConfiguration(
                key=VERTEX_DOCKER_IMAGE_KEY,
                settings=step.config.docker_settings,
                step_name=step_name,
            )
            for step_name, step in deployment.step_configurations.items()
            if step.config.step_operator == self.name
        ]

    def launch(
        self,
        info: "StepRunInfo",
        entrypoint_command: List[str],
        environment: Dict[str, str],
    ) -> None:
        """Launches a step on VertexAI.

        Submits a Vertex AI custom job that runs the entrypoint command in
        the step's Docker image. It then polls the job until it reaches a
        completed state.

        Args:
            info: Information about the step run.
            entrypoint_command: Command that executes the step.
            environment: Environment variables to set in the step operator
                environment.

        Raises:
            RuntimeError: If the run fails.
            ConnectionError: If the run fails due to a connection error.
        """
        resource_settings = info.config.resource_settings
        if resource_settings.cpu_count or resource_settings.memory:
            logger.warning(
                "Specifying cpus or memory is not supported for "
                "the Vertex step operator. If you want to run this step "
                "operator on specific resources, you can do so by configuring "
                "a different machine_type type like this: "
                "`zenml step-operator update %s "
                "--machine_type=<MACHINE_TYPE>`",
                self.name,
            )
        settings = cast(VertexStepOperatorSettings, self.get_settings(info))
        validate_accelerator_type(settings.accelerator_type)

        job_labels = {"source": f"zenml-{__version__.replace('.', '_')}"}

        # Step 1: Authenticate with Google
        credentials, project_id = self._get_authentication()

        image_name = info.get_image(key=VERTEX_DOCKER_IMAGE_KEY)

        # Step 2: Launch the job
        # The AI Platform services require regional API endpoints.
        client_options = {
            "api_endpoint": self.config.region + VERTEX_ENDPOINT_SUFFIX
        }
        # Initialize client that will be used to create and send requests.
        # This client only needs to be created once, and can be reused for
        # multiple requests.
        client = aiplatform.gapic.JobServiceClient(
            credentials=credentials, client_options=client_options
        )
        # A GPU count from the resource settings takes precedence over the
        # Vertex accelerator count; accelerators are only requested when an
        # accelerator type is configured.
        accelerator_count = (
            resource_settings.gpu_count or settings.accelerator_count
        )
        custom_job = {
            "display_name": info.run_name,
            "job_spec": {
                "worker_pool_specs": [
                    {
                        "machine_spec": {
                            "machine_type": settings.machine_type,
                            "accelerator_type": settings.accelerator_type,
                            "accelerator_count": accelerator_count
                            if settings.accelerator_type
                            else 0,
                        },
                        "replica_count": 1,
                        "container_spec": {
                            "image_uri": image_name,
                            "command": entrypoint_command,
                            "args": [],
                            "env": [
                                {"name": key, "value": value}
                                for key, value in environment.items()
                            ],
                        },
                    }
                ]
            },
            "labels": job_labels,
            "encryption_spec": {
                "kmsKeyName": self.config.encryption_spec_key_name
            }
            if self.config.encryption_spec_key_name
            else {},
        }
        logger.debug("Vertex AI Job=%s", custom_job)

        parent = f"projects/{project_id}/locations/{self.config.region}"
        logger.info(
            "Submitting custom job='%s', path='%s' to Vertex AI Training.",
            custom_job["display_name"],
            parent,
        )
        response = client.create_custom_job(
            parent=parent, custom_job=custom_job
        )
        # Use a %s placeholder: passing an argument without one makes the
        # logging module fail to format the record.
        logger.debug("Vertex AI response: %s", response)

        # Step 3: Monitor the job
        # Monitors the long-running operation by polling the job state
        # periodically, and retries the polling when a transient connectivity
        # issue is encountered.
        #
        # Long-running operation monitoring:
        # The possible states of "get job" response can be found at
        # https://cloud.google.com/ai-platform/training/docs/reference/rest/v1/projects.jobs#State
        # where SUCCEEDED/FAILED/CANCELED are considered to be final states.
        # The following logic will keep polling the state of the job until
        # the job enters a final state.
        #
        # During the polling, if a connection error was encountered, the GET
        # request will be retried by recreating the Python API client to
        # refresh the lifecycle of the connection being used. See
        # https://github.com/googleapis/google-api-python-client/issues/218
        # for a detailed description of the problem. If the error persists for
        # CONNECTION_ERROR_RETRY_LIMIT consecutive attempts, the function
        # will raise ConnectionError.
        retry_count = 0
        job_id = response.name

        while response.state not in VERTEX_JOB_STATES_COMPLETED:
            time.sleep(POLLING_INTERVAL_IN_SECONDS)
            try:
                response = client.get_custom_job(name=job_id)
                retry_count = 0
            # Handle transient connection error.
            except ConnectionError as err:
                if retry_count < CONNECTION_ERROR_RETRY_LIMIT:
                    retry_count += 1
                    logger.warning(
                        "ConnectionError (%s) encountered when polling job: "
                        "%s. Trying to recreate the API client.",
                        err,
                        job_id,
                    )
                    # Recreate the Python API client with the same
                    # credentials as the original one, instead of falling
                    # back to whatever default credentials the environment
                    # provides.
                    client = aiplatform.gapic.JobServiceClient(
                        credentials=credentials, client_options=client_options
                    )
                else:
                    logger.error(
                        "Request failed after %s retries.",
                        CONNECTION_ERROR_RETRY_LIMIT,
                    )
                    raise

        if response.state in VERTEX_JOB_STATES_FAILED:
            err_msg = (
                "Job '{}' did not succeed. Detailed response {}.".format(
                    job_id, response
                )
            )
            logger.error(err_msg)
            raise RuntimeError(err_msg)

        # Cloud training complete
        logger.info("Job '%s' successful.", job_id)
config: VertexStepOperatorConfig
property
readonly
Returns the VertexStepOperatorConfig config.
Returns:
Type | Description |
---|---|
VertexStepOperatorConfig |
The configuration. |
settings_class: Optional[Type[BaseSettings]]
property
readonly
Settings class for the Vertex step operator.
Returns:
Type | Description |
---|---|
Optional[Type[BaseSettings]] |
The settings class. |
validator: Optional[zenml.stack.stack_validator.StackValidator]
property
readonly
Validates the stack.
Returns:
Type | Description |
---|---|
Optional[zenml.stack.stack_validator.StackValidator] |
A validator that checks that the stack contains a remote container registry and a remote artifact store. |
__init__(self, *args, **kwargs)
special
Initializes the step operator by forwarding all arguments to the base class constructors. The accelerator type is validated later, when a step is launched.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
*args |
Any |
Variable length argument list. |
() |
**kwargs |
Any |
Arbitrary keyword arguments. |
{} |
Source code in zenml/integrations/gcp/step_operators/vertex_step_operator.py
def __init__(self, *args: Any, **kwargs: Any) -> None:
    """Initializes the step operator.

    All arguments are forwarded unchanged to the base class constructors.
    The accelerator type is not checked here; it is validated in `launch`
    for each step.

    Args:
        *args: Variable length argument list.
        **kwargs: Arbitrary keyword arguments.
    """
    super().__init__(*args, **kwargs)
get_docker_builds(self, deployment)
Gets the Docker builds required for the component.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
deployment |
PipelineDeploymentBaseModel |
The pipeline deployment for which to get the builds. |
required |
Returns:
Type | Description |
---|---|
List[BuildConfiguration] |
The required Docker builds. |
Source code in zenml/integrations/gcp/step_operators/vertex_step_operator.py
def get_docker_builds(
    self, deployment: "PipelineDeploymentBaseModel"
) -> List["BuildConfiguration"]:
    """Collect the Docker builds needed by steps that use this operator.

    A step is selected when its configured step operator equals this
    component's name. Each selected step contributes one build under
    `VERTEX_DOCKER_IMAGE_KEY`, using that step's Docker settings.

    Args:
        deployment: The pipeline deployment for which to get the builds.

    Returns:
        The required Docker builds, in step-configuration iteration order.
    """
    return [
        BuildConfiguration(
            key=VERTEX_DOCKER_IMAGE_KEY,
            settings=step_cfg.config.docker_settings,
            step_name=name,
        )
        for name, step_cfg in deployment.step_configurations.items()
        if step_cfg.config.step_operator == self.name
    ]
launch(self, info, entrypoint_command, environment)
Launches a step on VertexAI.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
info |
StepRunInfo |
Information about the step run. |
required |
entrypoint_command |
List[str] |
Command that executes the step. |
required |
environment |
Dict[str, str] |
Environment variables to set in the step operator environment. |
required |
Exceptions:
Type | Description |
---|---|
RuntimeError |
If the run fails. |
ConnectionError |
If the run fails due to a connection error. |
Source code in zenml/integrations/gcp/step_operators/vertex_step_operator.py
def launch(
    self,
    info: "StepRunInfo",
    entrypoint_command: List[str],
    environment: Dict[str, str],
) -> None:
    """Launches a step on VertexAI.

    Submits a Vertex AI custom job that runs the entrypoint command in the
    step's Docker image. It then polls the job until it reaches a completed
    state.

    Args:
        info: Information about the step run.
        entrypoint_command: Command that executes the step.
        environment: Environment variables to set in the step operator
            environment.

    Raises:
        RuntimeError: If the run fails.
        ConnectionError: If the run fails due to a connection error.
    """
    resource_settings = info.config.resource_settings
    if resource_settings.cpu_count or resource_settings.memory:
        logger.warning(
            "Specifying cpus or memory is not supported for "
            "the Vertex step operator. If you want to run this step "
            "operator on specific resources, you can do so by configuring "
            "a different machine_type type like this: "
            "`zenml step-operator update %s "
            "--machine_type=<MACHINE_TYPE>`",
            self.name,
        )
    settings = cast(VertexStepOperatorSettings, self.get_settings(info))
    validate_accelerator_type(settings.accelerator_type)

    job_labels = {"source": f"zenml-{__version__.replace('.', '_')}"}

    # Step 1: Authenticate with Google
    credentials, project_id = self._get_authentication()

    image_name = info.get_image(key=VERTEX_DOCKER_IMAGE_KEY)

    # Step 2: Launch the job
    # The AI Platform services require regional API endpoints.
    client_options = {
        "api_endpoint": self.config.region + VERTEX_ENDPOINT_SUFFIX
    }
    # Initialize client that will be used to create and send requests.
    # This client only needs to be created once, and can be reused for
    # multiple requests.
    client = aiplatform.gapic.JobServiceClient(
        credentials=credentials, client_options=client_options
    )
    # A GPU count from the resource settings takes precedence over the
    # Vertex accelerator count; accelerators are only requested when an
    # accelerator type is configured.
    accelerator_count = (
        resource_settings.gpu_count or settings.accelerator_count
    )
    custom_job = {
        "display_name": info.run_name,
        "job_spec": {
            "worker_pool_specs": [
                {
                    "machine_spec": {
                        "machine_type": settings.machine_type,
                        "accelerator_type": settings.accelerator_type,
                        "accelerator_count": accelerator_count
                        if settings.accelerator_type
                        else 0,
                    },
                    "replica_count": 1,
                    "container_spec": {
                        "image_uri": image_name,
                        "command": entrypoint_command,
                        "args": [],
                        "env": [
                            {"name": key, "value": value}
                            for key, value in environment.items()
                        ],
                    },
                }
            ]
        },
        "labels": job_labels,
        "encryption_spec": {
            "kmsKeyName": self.config.encryption_spec_key_name
        }
        if self.config.encryption_spec_key_name
        else {},
    }
    logger.debug("Vertex AI Job=%s", custom_job)

    parent = f"projects/{project_id}/locations/{self.config.region}"
    logger.info(
        "Submitting custom job='%s', path='%s' to Vertex AI Training.",
        custom_job["display_name"],
        parent,
    )
    response = client.create_custom_job(
        parent=parent, custom_job=custom_job
    )
    # Use a %s placeholder: passing an argument without one makes the
    # logging module fail to format the record.
    logger.debug("Vertex AI response: %s", response)

    # Step 3: Monitor the job
    # Monitors the long-running operation by polling the job state
    # periodically, and retries the polling when a transient connectivity
    # issue is encountered.
    #
    # Long-running operation monitoring:
    # The possible states of "get job" response can be found at
    # https://cloud.google.com/ai-platform/training/docs/reference/rest/v1/projects.jobs#State
    # where SUCCEEDED/FAILED/CANCELED are considered to be final states.
    # The following logic will keep polling the state of the job until
    # the job enters a final state.
    #
    # During the polling, if a connection error was encountered, the GET
    # request will be retried by recreating the Python API client to
    # refresh the lifecycle of the connection being used. See
    # https://github.com/googleapis/google-api-python-client/issues/218
    # for a detailed description of the problem. If the error persists for
    # CONNECTION_ERROR_RETRY_LIMIT consecutive attempts, the function
    # will raise ConnectionError.
    retry_count = 0
    job_id = response.name

    while response.state not in VERTEX_JOB_STATES_COMPLETED:
        time.sleep(POLLING_INTERVAL_IN_SECONDS)
        try:
            response = client.get_custom_job(name=job_id)
            retry_count = 0
        # Handle transient connection error.
        except ConnectionError as err:
            if retry_count < CONNECTION_ERROR_RETRY_LIMIT:
                retry_count += 1
                logger.warning(
                    "ConnectionError (%s) encountered when polling job: "
                    "%s. Trying to recreate the API client.",
                    err,
                    job_id,
                )
                # Recreate the Python API client with the same credentials
                # as the original one, instead of falling back to whatever
                # default credentials the environment provides.
                client = aiplatform.gapic.JobServiceClient(
                    credentials=credentials, client_options=client_options
                )
            else:
                logger.error(
                    "Request failed after %s retries.",
                    CONNECTION_ERROR_RETRY_LIMIT,
                )
                raise

    if response.state in VERTEX_JOB_STATES_FAILED:
        err_msg = (
            "Job '{}' did not succeed. Detailed response {}.".format(
                job_id, response
            )
        )
        logger.error(err_msg)
        raise RuntimeError(err_msg)

    # Cloud training complete
    logger.info("Job '%s' successful.", job_id)
validate_accelerator_type(accelerator_type=None)
Validates that the accelerator type is valid.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
accelerator_type |
Optional[str] |
The accelerator type to validate. |
None |
Exceptions:
Type | Description |
---|---|
ValueError |
If the accelerator type is not valid. |
Source code in zenml/integrations/gcp/step_operators/vertex_step_operator.py
def validate_accelerator_type(accelerator_type: Optional[str] = None) -> None:
    """Check an accelerator type against the values known to Vertex AI.

    The given value is upper-cased and compared with the member names of
    `aiplatform.gapic.AcceleratorType`. A missing or empty value is
    accepted without further checks.

    Args:
        accelerator_type: The accelerator type to validate.

    Raises:
        ValueError: If the accelerator type is not valid.
    """
    allowed = [*aiplatform.gapic.AcceleratorType.__members__]
    if not accelerator_type:
        return
    if accelerator_type.upper() in allowed:
        return
    raise ValueError(
        f"Accelerator must be one of the following: {allowed}"
    )