GCP

zenml.integrations.gcp special

Initialization of the GCP ZenML integration.

The GCP integration submodule provides a way to run ZenML pipelines in a cloud environment. Specifically, it allows the use of cloud artifact stores and provides an io module to handle file operations on Google Cloud Storage (GCS).

The Vertex AI integration submodule provides a way to run ZenML pipelines in a Vertex AI environment.
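
As a minimal usage sketch (assuming the GCP integration is installed, a GCS-backed artifact store is part of the active stack so that gs:// paths resolve, and the bucket name below is a placeholder), file operations on Google Cloud Storage can be routed through ZenML's zenml.io.fileio module:

# Sketch: route file operations on a gs:// path through ZenML's fileio module.
# Assumes the GCP integration is installed and the active stack contains a
# GCS-backed artifact store; the bucket and path are placeholders.
from zenml.io import fileio

artifact_uri = "gs://my-example-bucket/some/path.txt"

with fileio.open(artifact_uri, "wb") as f:
    f.write(b"hello from zenml")

print(fileio.exists(artifact_uri))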

GcpIntegration (Integration)

Definition of Google Cloud Platform integration for ZenML.

Source code in zenml/integrations/gcp/__init__.py
class GcpIntegration(Integration):
    """Definition of Google Cloud Platform integration for ZenML."""

    NAME = GCP
    REQUIREMENTS = [
        "kfp>=2.6.0",
        "gcsfs",
        "google-cloud-secret-manager",
        "google-cloud-container>=2.21.0",
        "google-cloud-artifact-registry>=1.11.3",
        "google-cloud-storage>=2.9.0",
        "google-cloud-aiplatform>=1.34.0",  # includes shapely pin fix
        "google-cloud-build>=3.11.0",
        "kubernetes",
    ]
    REQUIREMENTS_IGNORED_ON_UNINSTALL = ["kubernetes","kfp"]

    @classmethod
    def activate(cls) -> None:
        """Activate the GCP integration."""
        from zenml.integrations.gcp import service_connectors  # noqa

    @classmethod
    def flavors(cls) -> List[Type[Flavor]]:
        """Declare the stack component flavors for the GCP integration.

        Returns:
            List of stack component flavors for this integration.
        """
        from zenml.integrations.gcp.flavors import (
            GCPArtifactStoreFlavor,
            GCPImageBuilderFlavor,
            VertexOrchestratorFlavor,
            VertexStepOperatorFlavor,
        )

        return [
            GCPArtifactStoreFlavor,
            GCPImageBuilderFlavor,
            VertexOrchestratorFlavor,
            VertexStepOperatorFlavor,
        ]

activate() classmethod

Activate the GCP integration.

Source code in zenml/integrations/gcp/__init__.py
@classmethod
def activate(cls) -> None:
    """Activate the GCP integration."""
    from zenml.integrations.gcp import service_connectors  # noqa

flavors() classmethod

Declare the stack component flavors for the GCP integration.

Returns:

Type Description
List[Type[zenml.stack.flavor.Flavor]]

List of stack component flavors for this integration.

Source code in zenml/integrations/gcp/__init__.py
@classmethod
def flavors(cls) -> List[Type[Flavor]]:
    """Declare the stack component flavors for the GCP integration.

    Returns:
        List of stack component flavors for this integration.
    """
    from zenml.integrations.gcp.flavors import (
        GCPArtifactStoreFlavor,
        GCPImageBuilderFlavor,
        VertexOrchestratorFlavor,
        VertexStepOperatorFlavor,
    )

    return [
        GCPArtifactStoreFlavor,
        GCPImageBuilderFlavor,
        VertexOrchestratorFlavor,
        VertexStepOperatorFlavor,
    ]
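
For illustration only (a sketch, assuming the integration's requirements are installed), the declared flavors can be inspected directly:

# Sketch: list the stack component flavors declared by the GCP integration.
from zenml.integrations.gcp import GcpIntegration

for flavor_class in GcpIntegration.flavors():
    # Each entry is a Flavor subclass; its `name` property gives the flavor id.
    print(flavor_class.__name__, flavor_class().name)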

artifact_stores special

Initialization of the GCP Artifact Store.

gcp_artifact_store

Implementation of the GCP Artifact Store.

GCPArtifactStore (BaseArtifactStore, AuthenticationMixin)

Artifact Store for Google Cloud Storage based artifacts.

Source code in zenml/integrations/gcp/artifact_stores/gcp_artifact_store.py
class GCPArtifactStore(BaseArtifactStore, AuthenticationMixin):
    """Artifact Store for Google Cloud Storage based artifacts."""

    _filesystem: Optional[gcsfs.GCSFileSystem] = None

    @property
    def config(self) -> GCPArtifactStoreConfig:
        """Returns the `GCPArtifactStoreConfig` config.

        Returns:
            The configuration.
        """
        return cast(GCPArtifactStoreConfig, self._config)

    def get_credentials(
        self,
    ) -> Optional[Union[Dict[str, Any], gcp_credentials.Credentials]]:
        """Returns the credentials for the GCP Artifact Store if configured.

        Returns:
            The credentials.

        Raises:
            RuntimeError: If the linked connector returns the wrong type of
                client.
        """
        connector = self.get_connector()
        if connector:
            client = connector.connect()
            if not isinstance(client, storage.Client):
                raise RuntimeError(
                    f"Expected a google.cloud.storage.Client while trying to "
                    f"use the linked connector, but got {type(client)}."
                )
            return client._credentials

        secret = self.get_typed_authentication_secret(
            expected_schema_type=GCPSecretSchema
        )
        return secret.get_credential_dict() if secret else None

    @property
    def filesystem(self) -> gcsfs.GCSFileSystem:
        """The gcsfs filesystem to access this artifact store.

        Returns:
            The gcsfs filesystem to access this artifact store.
        """
        # Refresh the credentials also if the connector has expired
        if self._filesystem and not self.connector_has_expired():
            return self._filesystem

        token = self.get_credentials()
        self._filesystem = gcsfs.GCSFileSystem(token=token)

        return self._filesystem

    def open(self, path: PathType, mode: str = "r") -> Any:
        """Open a file at the given path.

        Args:
            path: Path of the file to open.
            mode: Mode in which to open the file. Currently, only
                'rb' and 'wb' to read and write binary files are supported.

        Returns:
            A file-like object that can be used to read or write to the file.
        """
        if mode in ("a", "ab"):
            logger.warning(
                "GCS Filesystem is immutable, so append mode will overwrite existing files."
            )
        return self.filesystem.open(path=path, mode=mode)

    def copyfile(
        self, src: PathType, dst: PathType, overwrite: bool = False
    ) -> None:
        """Copy a file.

        Args:
            src: The path to copy from.
            dst: The path to copy to.
            overwrite: If a file already exists at the destination, this
                method will overwrite it if overwrite=`True` and
                raise a FileExistsError otherwise.

        Raises:
            FileExistsError: If a file already exists at the destination
                and overwrite is not set to `True`.
        """
        if not overwrite and self.filesystem.exists(dst):
            raise FileExistsError(
                f"Unable to copy to destination '{convert_to_str(dst)}', "
                f"file already exists. Set `overwrite=True` to copy anyway."
            )
        # TODO [ENG-151]: Check if it works with overwrite=True or if we need to
        #  manually remove it first
        self.filesystem.copy(path1=src, path2=dst)

    def exists(self, path: PathType) -> bool:
        """Check whether a path exists.

        Args:
            path: The path to check.

        Returns:
            True if the path exists, False otherwise.
        """
        return self.filesystem.exists(path=path)  # type: ignore[no-any-return]

    def glob(self, pattern: PathType) -> List[PathType]:
        """Return all paths that match the given glob pattern.

        The glob pattern may include:
        - '*' to match any number of characters
        - '?' to match a single character
        - '[...]' to match one of the characters inside the brackets
        - '**' as the full name of a path component to match to search
          in subdirectories of any depth (e.g. '/some_dir/**/some_file)

        Args:
            pattern: The glob pattern to match, see details above.

        Returns:
            A list of paths that match the given glob pattern.
        """
        return [
            f"{GCP_PATH_PREFIX}{path}"
            for path in self.filesystem.glob(path=pattern)
        ]

    def isdir(self, path: PathType) -> bool:
        """Check whether a path is a directory.

        Args:
            path: The path to check.

        Returns:
            True if the path is a directory, False otherwise.
        """
        return self.filesystem.isdir(path=path)  # type: ignore[no-any-return]

    def listdir(self, path: PathType) -> List[PathType]:
        """Return a list of files in a directory.

        Args:
            path: The path of the directory to list.

        Returns:
            A list of paths of files in the directory.
        """
        path_without_prefix = convert_to_str(path)
        if path_without_prefix.startswith(GCP_PATH_PREFIX):
            path_without_prefix = path_without_prefix[len(GCP_PATH_PREFIX) :]

        def _extract_basename(file_dict: Dict[str, Any]) -> str:
            """Extracts the basename from a file info dict returned by GCP.

            Args:
                file_dict: A file info dict returned by the GCP filesystem.

            Returns:
                The basename of the file.
            """
            file_path = cast(str, file_dict["name"])
            base_name = file_path[len(path_without_prefix) :]
            return base_name.lstrip("/")

        return [
            _extract_basename(dict_)
            for dict_ in self.filesystem.listdir(path=path)
            # gcsfs.listdir also returns the root directory, so we filter
            # it out here
            if _extract_basename(dict_)
        ]

    def makedirs(self, path: PathType) -> None:
        """Create a directory at the given path.

        If needed also create missing parent directories.

        Args:
            path: The path of the directory to create.
        """
        self.filesystem.makedirs(path=path, exist_ok=True)

    def mkdir(self, path: PathType) -> None:
        """Create a directory at the given path.

        Args:
            path: The path of the directory to create.
        """
        self.filesystem.makedir(path=path)

    def remove(self, path: PathType) -> None:
        """Remove the file at the given path.

        Args:
            path: The path of the file to remove.
        """
        self.filesystem.rm_file(path=path)

    def rename(
        self, src: PathType, dst: PathType, overwrite: bool = False
    ) -> None:
        """Rename source file to destination file.

        Args:
            src: The path of the file to rename.
            dst: The path to rename the source file to.
            overwrite: If a file already exists at the destination, this
                method will overwrite it if overwrite=`True` and
                raise a FileExistsError otherwise.

        Raises:
            FileExistsError: If a file already exists at the destination
                and overwrite is not set to `True`.
        """
        if not overwrite and self.filesystem.exists(dst):
            raise FileExistsError(
                f"Unable to rename file to '{convert_to_str(dst)}', "
                f"file already exists. Set `overwrite=True` to rename anyway."
            )

        # TODO [ENG-152]: Check if it works with overwrite=True or if we need
        #  to manually remove it first
        self.filesystem.rename(path1=src, path2=dst)

    def rmtree(self, path: PathType) -> None:
        """Remove the given directory.

        Args:
            path: The path of the directory to remove.
        """
        self.filesystem.delete(path=path, recursive=True)

    def stat(self, path: PathType) -> Dict[str, Any]:
        """Return stat info for the given path.

        Args:
            path: the path to get stat info for.

        Returns:
            A dictionary with the stat info.
        """
        return self.filesystem.stat(path=path)  # type: ignore[no-any-return]

    def size(self, path: PathType) -> int:
        """Get the size of a file in bytes.

        Args:
            path: The path to the file.

        Returns:
            The size of the file in bytes.
        """
        return self.filesystem.size(path=path)  # type: ignore[no-any-return]

    def walk(
        self,
        top: PathType,
        topdown: bool = True,
        onerror: Optional[Callable[..., None]] = None,
    ) -> Iterable[Tuple[PathType, List[PathType], List[PathType]]]:
        """Return an iterator that walks the contents of the given directory.

        Args:
            top: Path of directory to walk.
            topdown: Unused argument to conform to interface.
            onerror: Unused argument to conform to interface.

        Yields:
            An Iterable of Tuples, each of which contain the path of the current
            directory path, a list of directories inside the current directory
            and a list of files inside the current directory.
        """
        # TODO [ENG-153]: Additional params
        for (
            directory,
            subdirectories,
            files,
        ) in self.filesystem.walk(path=top):
            yield f"{GCP_PATH_PREFIX}{directory}", subdirectories, files
config: GCPArtifactStoreConfig property readonly

Returns the GCPArtifactStoreConfig config.

Returns:

Type Description
GCPArtifactStoreConfig

The configuration.

filesystem: gcsfs.GCSFileSystem property readonly

The gcsfs filesystem to access this artifact store.

Returns:

Type Description
gcsfs.GCSFileSystem

The gcsfs filesystem to access this artifact store.

copyfile(self, src, dst, overwrite=False)

Copy a file.

Parameters:

Name Type Description Default
src Union[bytes, str]

The path to copy from.

required
dst Union[bytes, str]

The path to copy to.

required
overwrite bool

If a file already exists at the destination, this method will overwrite it if overwrite=True and raise a FileExistsError otherwise.

False

Exceptions:

Type Description
FileExistsError

If a file already exists at the destination and overwrite is not set to True.

Source code in zenml/integrations/gcp/artifact_stores/gcp_artifact_store.py
def copyfile(
    self, src: PathType, dst: PathType, overwrite: bool = False
) -> None:
    """Copy a file.

    Args:
        src: The path to copy from.
        dst: The path to copy to.
        overwrite: If a file already exists at the destination, this
            method will overwrite it if overwrite=`True` and
            raise a FileExistsError otherwise.

    Raises:
        FileExistsError: If a file already exists at the destination
            and overwrite is not set to `True`.
    """
    if not overwrite and self.filesystem.exists(dst):
        raise FileExistsError(
            f"Unable to copy to destination '{convert_to_str(dst)}', "
            f"file already exists. Set `overwrite=True` to copy anyway."
        )
    # TODO [ENG-151]: Check if it works with overwrite=True or if we need to
    #  manually remove it first
    self.filesystem.copy(path1=src, path2=dst)
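
A short usage sketch (assuming Client().active_stack.artifact_store resolves to a configured GCPArtifactStore and the gs:// paths are placeholders):

# Sketch: copy a file within the bucket, falling back to overwrite=True if the
# destination already exists. `store` is assumed to be a GCPArtifactStore.
from zenml.client import Client

store = Client().active_stack.artifact_store
src = "gs://my-example-bucket/raw/data.csv"
dst = "gs://my-example-bucket/backup/data.csv"

try:
    store.copyfile(src, dst, overwrite=False)
except FileExistsError:
    # Destination exists; copy anyway by setting overwrite=True.
    store.copyfile(src, dst, overwrite=True)
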
exists(self, path)

Check whether a path exists.

Parameters:

Name Type Description Default
path Union[bytes, str]

The path to check.

required

Returns:

Type Description
bool

True if the path exists, False otherwise.

Source code in zenml/integrations/gcp/artifact_stores/gcp_artifact_store.py
def exists(self, path: PathType) -> bool:
    """Check whether a path exists.

    Args:
        path: The path to check.

    Returns:
        True if the path exists, False otherwise.
    """
    return self.filesystem.exists(path=path)  # type: ignore[no-any-return]
get_credentials(self)

Returns the credentials for the GCP Artifact Store if configured.

Returns:

Type Description
Union[Dict[str, Any], google.oauth2.credentials.Credentials]

The credentials.

Exceptions:

Type Description
RuntimeError

If the linked connector returns the wrong type of client.

Source code in zenml/integrations/gcp/artifact_stores/gcp_artifact_store.py
def get_credentials(
    self,
) -> Optional[Union[Dict[str, Any], gcp_credentials.Credentials]]:
    """Returns the credentials for the GCP Artifact Store if configured.

    Returns:
        The credentials.

    Raises:
        RuntimeError: If the linked connector returns the wrong type of
            client.
    """
    connector = self.get_connector()
    if connector:
        client = connector.connect()
        if not isinstance(client, storage.Client):
            raise RuntimeError(
                f"Expected a google.cloud.storage.Client while trying to "
                f"use the linked connector, but got {type(client)}."
            )
        return client._credentials

    secret = self.get_typed_authentication_secret(
        expected_schema_type=GCPSecretSchema
    )
    return secret.get_credential_dict() if secret else None
glob(self, pattern)

Return all paths that match the given glob pattern.

The glob pattern may include:

- '*' to match any number of characters
- '?' to match a single character
- '[...]' to match one of the characters inside the brackets
- '**' as the full name of a path component to match to search in subdirectories of any depth (e.g. '/some_dir/**/some_file')

Parameters:

Name Type Description Default
pattern Union[bytes, str]

The glob pattern to match, see details above.

required

Returns:

Type Description
List[Union[bytes, str]]

A list of paths that match the given glob pattern.

Source code in zenml/integrations/gcp/artifact_stores/gcp_artifact_store.py
def glob(self, pattern: PathType) -> List[PathType]:
    """Return all paths that match the given glob pattern.

    The glob pattern may include:
    - '*' to match any number of characters
    - '?' to match a single character
    - '[...]' to match one of the characters inside the brackets
    - '**' as the full name of a path component to match to search
      in subdirectories of any depth (e.g. '/some_dir/**/some_file)

    Args:
        pattern: The glob pattern to match, see details above.

    Returns:
        A list of paths that match the given glob pattern.
    """
    return [
        f"{GCP_PATH_PREFIX}{path}"
        for path in self.filesystem.glob(path=pattern)
    ]
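
A short sketch (assuming Client().active_stack.artifact_store is a GCPArtifactStore and the bucket/prefix are placeholders):

# Sketch: find all CSV files at any depth below a prefix.
from zenml.client import Client

store = Client().active_stack.artifact_store  # assumed to be a GCPArtifactStore
for path in store.glob("gs://my-example-bucket/datasets/**/*.csv"):
    print(path)  # results are returned with the gs:// prefix re-attached
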
isdir(self, path)

Check whether a path is a directory.

Parameters:

Name Type Description Default
path Union[bytes, str]

The path to check.

required

Returns:

Type Description
bool

True if the path is a directory, False otherwise.

Source code in zenml/integrations/gcp/artifact_stores/gcp_artifact_store.py
def isdir(self, path: PathType) -> bool:
    """Check whether a path is a directory.

    Args:
        path: The path to check.

    Returns:
        True if the path is a directory, False otherwise.
    """
    return self.filesystem.isdir(path=path)  # type: ignore[no-any-return]
listdir(self, path)

Return a list of files in a directory.

Parameters:

Name Type Description Default
path Union[bytes, str]

The path of the directory to list.

required

Returns:

Type Description
List[Union[bytes, str]]

A list of paths of files in the directory.

Source code in zenml/integrations/gcp/artifact_stores/gcp_artifact_store.py
def listdir(self, path: PathType) -> List[PathType]:
    """Return a list of files in a directory.

    Args:
        path: The path of the directory to list.

    Returns:
        A list of paths of files in the directory.
    """
    path_without_prefix = convert_to_str(path)
    if path_without_prefix.startswith(GCP_PATH_PREFIX):
        path_without_prefix = path_without_prefix[len(GCP_PATH_PREFIX) :]

    def _extract_basename(file_dict: Dict[str, Any]) -> str:
        """Extracts the basename from a file info dict returned by GCP.

        Args:
            file_dict: A file info dict returned by the GCP filesystem.

        Returns:
            The basename of the file.
        """
        file_path = cast(str, file_dict["name"])
        base_name = file_path[len(path_without_prefix) :]
        return base_name.lstrip("/")

    return [
        _extract_basename(dict_)
        for dict_ in self.filesystem.listdir(path=path)
        # gcsfs.listdir also returns the root directory, so we filter
        # it out here
        if _extract_basename(dict_)
    ]
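
A short sketch (same assumptions as above: an active GCPArtifactStore and placeholder paths):

# Sketch: list the entries directly inside a directory (basenames only).
from zenml.client import Client

store = Client().active_stack.artifact_store  # assumed to be a GCPArtifactStore
for name in store.listdir("gs://my-example-bucket/datasets"):
    print(name)
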
makedirs(self, path)

Create a directory at the given path.

If needed also create missing parent directories.

Parameters:

Name Type Description Default
path Union[bytes, str]

The path of the directory to create.

required
Source code in zenml/integrations/gcp/artifact_stores/gcp_artifact_store.py
def makedirs(self, path: PathType) -> None:
    """Create a directory at the given path.

    If needed also create missing parent directories.

    Args:
        path: The path of the directory to create.
    """
    self.filesystem.makedirs(path=path, exist_ok=True)
mkdir(self, path)

Create a directory at the given path.

Parameters:

Name Type Description Default
path Union[bytes, str]

The path of the directory to create.

required
Source code in zenml/integrations/gcp/artifact_stores/gcp_artifact_store.py
def mkdir(self, path: PathType) -> None:
    """Create a directory at the given path.

    Args:
        path: The path of the directory to create.
    """
    self.filesystem.makedir(path=path)
open(self, path, mode='r')

Open a file at the given path.

Parameters:

Name Type Description Default
path Union[bytes, str]

Path of the file to open.

required
mode str

Mode in which to open the file. Currently, only 'rb' and 'wb' to read and write binary files are supported.

'r'

Returns:

Type Description
Any

A file-like object that can be used to read or write to the file.

Source code in zenml/integrations/gcp/artifact_stores/gcp_artifact_store.py
def open(self, path: PathType, mode: str = "r") -> Any:
    """Open a file at the given path.

    Args:
        path: Path of the file to open.
        mode: Mode in which to open the file. Currently, only
            'rb' and 'wb' to read and write binary files are supported.

    Returns:
        A file-like object that can be used to read or write to the file.
    """
    if mode in ("a", "ab"):
        logger.warning(
            "GCS Filesystem is immutable, so append mode will overwrite existing files."
        )
    return self.filesystem.open(path=path, mode=mode)
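
A short sketch (same assumptions: an active GCPArtifactStore, placeholder path), showing the supported binary modes:

# Sketch: write and then read back a binary file ('wb' and 'rb' are the
# supported modes).
from zenml.client import Client

store = Client().active_stack.artifact_store  # assumed to be a GCPArtifactStore
path = "gs://my-example-bucket/tmp/example.bin"

with store.open(path, "wb") as f:
    f.write(b"some bytes")

with store.open(path, "rb") as f:
    print(f.read())
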
remove(self, path)

Remove the file at the given path.

Parameters:

Name Type Description Default
path Union[bytes, str]

The path of the file to remove.

required
Source code in zenml/integrations/gcp/artifact_stores/gcp_artifact_store.py
def remove(self, path: PathType) -> None:
    """Remove the file at the given path.

    Args:
        path: The path of the file to remove.
    """
    self.filesystem.rm_file(path=path)
rename(self, src, dst, overwrite=False)

Rename source file to destination file.

Parameters:

Name Type Description Default
src Union[bytes, str]

The path of the file to rename.

required
dst Union[bytes, str]

The path to rename the source file to.

required
overwrite bool

If a file already exists at the destination, this method will overwrite it if overwrite=True and raise a FileExistsError otherwise.

False

Exceptions:

Type Description
FileExistsError

If a file already exists at the destination and overwrite is not set to True.

Source code in zenml/integrations/gcp/artifact_stores/gcp_artifact_store.py
def rename(
    self, src: PathType, dst: PathType, overwrite: bool = False
) -> None:
    """Rename source file to destination file.

    Args:
        src: The path of the file to rename.
        dst: The path to rename the source file to.
        overwrite: If a file already exists at the destination, this
            method will overwrite it if overwrite=`True` and
            raise a FileExistsError otherwise.

    Raises:
        FileExistsError: If a file already exists at the destination
            and overwrite is not set to `True`.
    """
    if not overwrite and self.filesystem.exists(dst):
        raise FileExistsError(
            f"Unable to rename file to '{convert_to_str(dst)}', "
            f"file already exists. Set `overwrite=True` to rename anyway."
        )

    # TODO [ENG-152]: Check if it works with overwrite=True or if we need
    #  to manually remove it first
    self.filesystem.rename(path1=src, path2=dst)
rmtree(self, path)

Remove the given directory.

Parameters:

Name Type Description Default
path Union[bytes, str]

The path of the directory to remove.

required
Source code in zenml/integrations/gcp/artifact_stores/gcp_artifact_store.py
def rmtree(self, path: PathType) -> None:
    """Remove the given directory.

    Args:
        path: The path of the directory to remove.
    """
    self.filesystem.delete(path=path, recursive=True)
size(self, path)

Get the size of a file in bytes.

Parameters:

Name Type Description Default
path Union[bytes, str]

The path to the file.

required

Returns:

Type Description
int

The size of the file in bytes.

Source code in zenml/integrations/gcp/artifact_stores/gcp_artifact_store.py
def size(self, path: PathType) -> int:
    """Get the size of a file in bytes.

    Args:
        path: The path to the file.

    Returns:
        The size of the file in bytes.
    """
    return self.filesystem.size(path=path)  # type: ignore[no-any-return]
stat(self, path)

Return stat info for the given path.

Parameters:

Name Type Description Default
path Union[bytes, str]

the path to get stat info for.

required

Returns:

Type Description
Dict[str, Any]

A dictionary with the stat info.

Source code in zenml/integrations/gcp/artifact_stores/gcp_artifact_store.py
def stat(self, path: PathType) -> Dict[str, Any]:
    """Return stat info for the given path.

    Args:
        path: the path to get stat info for.

    Returns:
        A dictionary with the stat info.
    """
    return self.filesystem.stat(path=path)  # type: ignore[no-any-return]
walk(self, top, topdown=True, onerror=None)

Return an iterator that walks the contents of the given directory.

Parameters:

Name Type Description Default
top Union[bytes, str]

Path of directory to walk.

required
topdown bool

Unused argument to conform to interface.

True
onerror Optional[Callable[..., NoneType]]

Unused argument to conform to interface.

None

Yields:

Type Description
Iterable[Tuple[Union[bytes, str], List[Union[bytes, str]], List[Union[bytes, str]]]]

An Iterable of Tuples, each of which contain the path of the current directory path, a list of directories inside the current directory and a list of files inside the current directory.

Source code in zenml/integrations/gcp/artifact_stores/gcp_artifact_store.py
def walk(
    self,
    top: PathType,
    topdown: bool = True,
    onerror: Optional[Callable[..., None]] = None,
) -> Iterable[Tuple[PathType, List[PathType], List[PathType]]]:
    """Return an iterator that walks the contents of the given directory.

    Args:
        top: Path of directory to walk.
        topdown: Unused argument to conform to interface.
        onerror: Unused argument to conform to interface.

    Yields:
        An Iterable of Tuples, each of which contain the path of the current
        directory path, a list of directories inside the current directory
        and a list of files inside the current directory.
    """
    # TODO [ENG-153]: Additional params
    for (
        directory,
        subdirectories,
        files,
    ) in self.filesystem.walk(path=top):
        yield f"{GCP_PATH_PREFIX}{directory}", subdirectories, files

constants

Constants for the VertexAI integration.

flavors special

GCP integration flavors.

gcp_artifact_store_flavor

GCP artifact store flavor.

GCPArtifactStoreConfig (BaseArtifactStoreConfig, AuthenticationConfigMixin)

Configuration for GCP Artifact Store.

Source code in zenml/integrations/gcp/flavors/gcp_artifact_store_flavor.py
class GCPArtifactStoreConfig(
    BaseArtifactStoreConfig, AuthenticationConfigMixin
):
    """Configuration for GCP Artifact Store."""

    SUPPORTED_SCHEMES: ClassVar[Set[str]] = {GCP_PATH_PREFIX}
    IS_IMMUTABLE_FILESYSTEM: ClassVar[bool] = True
GCPArtifactStoreFlavor (BaseArtifactStoreFlavor)

Flavor of the GCP artifact store.

Source code in zenml/integrations/gcp/flavors/gcp_artifact_store_flavor.py
class GCPArtifactStoreFlavor(BaseArtifactStoreFlavor):
    """Flavor of the GCP artifact store."""

    @property
    def name(self) -> str:
        """Name of the flavor.

        Returns:
            The name of the flavor.
        """
        return GCP_ARTIFACT_STORE_FLAVOR

    @property
    def service_connector_requirements(
        self,
    ) -> Optional[ServiceConnectorRequirements]:
        """Service connector resource requirements for service connectors.

        Specifies resource requirements that are used to filter the available
        service connector types that are compatible with this flavor.

        Returns:
            Requirements for compatible service connectors, if a service
            connector is required for this flavor.
        """
        return ServiceConnectorRequirements(
            resource_type=GCS_RESOURCE_TYPE,
            resource_id_attr="path",
        )

    @property
    def docs_url(self) -> Optional[str]:
        """A url to point at docs explaining this flavor.

        Returns:
            A flavor docs url.
        """
        return self.generate_default_docs_url()

    @property
    def sdk_docs_url(self) -> Optional[str]:
        """A url to point at SDK docs explaining this flavor.

        Returns:
            A flavor SDK docs url.
        """
        return self.generate_default_sdk_docs_url()

    @property
    def logo_url(self) -> str:
        """A url to represent the flavor in the dashboard.

        Returns:
            The flavor logo.
        """
        return "https://public-flavor-logos.s3.eu-central-1.amazonaws.com/artifact_store/gcp.png"

    @property
    def config_class(self) -> Type[GCPArtifactStoreConfig]:
        """Returns GCPArtifactStoreConfig config class.

        Returns:
                The config class.
        """
        return GCPArtifactStoreConfig

    @property
    def implementation_class(self) -> Type["GCPArtifactStore"]:
        """Implementation class for this flavor.

        Returns:
            The implementation class.
        """
        from zenml.integrations.gcp.artifact_stores import GCPArtifactStore

        return GCPArtifactStore
config_class: Type[zenml.integrations.gcp.flavors.gcp_artifact_store_flavor.GCPArtifactStoreConfig] property readonly

Returns GCPArtifactStoreConfig config class.

Returns:

Type Description
Type[zenml.integrations.gcp.flavors.gcp_artifact_store_flavor.GCPArtifactStoreConfig]

The config class.

docs_url: Optional[str] property readonly

A url to point at docs explaining this flavor.

Returns:

Type Description
Optional[str]

A flavor docs url.

implementation_class: Type[GCPArtifactStore] property readonly

Implementation class for this flavor.

Returns:

Type Description
Type[GCPArtifactStore]

The implementation class.

logo_url: str property readonly

A url to represent the flavor in the dashboard.

Returns:

Type Description
str

The flavor logo.

name: str property readonly

Name of the flavor.

Returns:

Type Description
str

The name of the flavor.

sdk_docs_url: Optional[str] property readonly

A url to point at SDK docs explaining this flavor.

Returns:

Type Description
Optional[str]

A flavor SDK docs url.

service_connector_requirements: Optional[zenml.models.v2.misc.service_connector_type.ServiceConnectorRequirements] property readonly

Service connector resource requirements for service connectors.

Specifies resource requirements that are used to filter the available service connector types that are compatible with this flavor.

Returns:

Type Description
Optional[zenml.models.v2.misc.service_connector_type.ServiceConnectorRequirements]

Requirements for compatible service connectors, if a service connector is required for this flavor.

gcp_image_builder_flavor

Google Cloud image builder flavor.

GCPImageBuilderConfig (BaseImageBuilderConfig, GoogleCredentialsConfigMixin)

Google Cloud Builder image builder configuration.

Attributes:

Name Type Description
cloud_builder_image str

The name of the Docker image to use for the build steps. Defaults to gcr.io/cloud-builders/docker.

network str

The network name to which the build container will be attached while building the Docker image. More information about this: https://cloud.google.com/build/docs/build-config-file-schema#network. Defaults to cloudbuild.

build_timeout int

The timeout of the build in seconds. More information about this parameter: https://cloud.google.com/build/docs/build-config-file-schema#timeout_2 Defaults to 3600.

Source code in zenml/integrations/gcp/flavors/gcp_image_builder_flavor.py
class GCPImageBuilderConfig(
    BaseImageBuilderConfig, GoogleCredentialsConfigMixin
):
    """Google Cloud Builder image builder configuration.

    Attributes:
        cloud_builder_image: The name of the Docker image to use for the build
            steps. Defaults to `gcr.io/cloud-builders/docker`.
        network: The network name to which the build container will be
            attached while building the Docker image. More information about
            this:
            https://cloud.google.com/build/docs/build-config-file-schema#network.
            Defaults to `cloudbuild`.
        build_timeout: The timeout of the build in seconds. More information
            about this parameter:
            https://cloud.google.com/build/docs/build-config-file-schema#timeout_2
            Defaults to `3600`.
    """

    cloud_builder_image: str = DEFAULT_CLOUD_BUILDER_IMAGE
    network: str = DEFAULT_CLOUD_BUILDER_NETWORK
    build_timeout: PositiveInt = DEFAULT_CLOUD_BUILD_TIMEOUT
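
For illustration only (a sketch; image builders are normally registered as stack components rather than instantiated by hand, and the values below simply restate the documented defaults):

# Sketch: instantiate the config directly to inspect its fields.
from zenml.integrations.gcp.flavors.gcp_image_builder_flavor import (
    GCPImageBuilderConfig,
)

config = GCPImageBuilderConfig(
    cloud_builder_image="gcr.io/cloud-builders/docker",
    network="cloudbuild",
    build_timeout=3600,
)
print(config.build_timeout)
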
GCPImageBuilderFlavor (BaseImageBuilderFlavor)

Google Cloud Builder image builder flavor.

Source code in zenml/integrations/gcp/flavors/gcp_image_builder_flavor.py
class GCPImageBuilderFlavor(BaseImageBuilderFlavor):
    """Google Cloud Builder image builder flavor."""

    @property
    def name(self) -> str:
        """The flavor name.

        Returns:
            The name of the flavor.
        """
        return GCP_IMAGE_BUILDER_FLAVOR

    @property
    def service_connector_requirements(
        self,
    ) -> Optional[ServiceConnectorRequirements]:
        """Service connector resource requirements for service connectors.

        Specifies resource requirements that are used to filter the available
        service connector types that are compatible with this flavor.

        Returns:
            Requirements for compatible service connectors, if a service
            connector is required for this flavor.
        """
        return ServiceConnectorRequirements(
            connector_type=GCP_CONNECTOR_TYPE,
            resource_type=GCP_RESOURCE_TYPE,
        )

    @property
    def docs_url(self) -> Optional[str]:
        """A url to point at docs explaining this flavor.

        Returns:
            A flavor docs url.
        """
        return self.generate_default_docs_url()

    @property
    def sdk_docs_url(self) -> Optional[str]:
        """A url to point at SDK docs explaining this flavor.

        Returns:
            A flavor SDK docs url.
        """
        return self.generate_default_sdk_docs_url()

    @property
    def logo_url(self) -> str:
        """A url to represent the flavor in the dashboard.

        Returns:
            The flavor logo.
        """
        return "https://public-flavor-logos.s3.eu-central-1.amazonaws.com/image_builder/gcp.png"

    @property
    def config_class(self) -> Type[BaseImageBuilderConfig]:
        """The config class.

        Returns:
            The config class.
        """
        return GCPImageBuilderConfig

    @property
    def implementation_class(self) -> Type["GCPImageBuilder"]:
        """Implementation class.

        Returns:
            The implementation class.
        """
        from zenml.integrations.gcp.image_builders import GCPImageBuilder

        return GCPImageBuilder
config_class: Type[zenml.image_builders.base_image_builder.BaseImageBuilderConfig] property readonly

The config class.

Returns:

Type Description
Type[zenml.image_builders.base_image_builder.BaseImageBuilderConfig]

The config class.

docs_url: Optional[str] property readonly

A url to point at docs explaining this flavor.

Returns:

Type Description
Optional[str]

A flavor docs url.

implementation_class: Type[GCPImageBuilder] property readonly

Implementation class.

Returns:

Type Description
Type[GCPImageBuilder]

The implementation class.

logo_url: str property readonly

A url to represent the flavor in the dashboard.

Returns:

Type Description
str

The flavor logo.

name: str property readonly

The flavor name.

Returns:

Type Description
str

The name of the flavor.

sdk_docs_url: Optional[str] property readonly

A url to point at SDK docs explaining this flavor.

Returns:

Type Description
Optional[str]

A flavor SDK docs url.

service_connector_requirements: Optional[zenml.models.v2.misc.service_connector_type.ServiceConnectorRequirements] property readonly

Service connector resource requirements for service connectors.

Specifies resource requirements that are used to filter the available service connector types that are compatible with this flavor.

Returns:

Type Description
Optional[zenml.models.v2.misc.service_connector_type.ServiceConnectorRequirements]

Requirements for compatible service connectors, if a service connector is required for this flavor.

vertex_orchestrator_flavor

Vertex orchestrator flavor.

VertexOrchestratorConfig (BaseOrchestratorConfig, GoogleCredentialsConfigMixin, VertexOrchestratorSettings)

Configuration for the Vertex orchestrator.

Attributes:

Name Type Description
location str

Name of GCP region where the pipeline job will be executed. Vertex AI Pipelines is available in the following regions: https://cloud.google.com/vertex-ai/docs/general/locations#feature-availability

pipeline_root Optional[str]

a Cloud Storage URI that will be used by the Vertex AI Pipelines. If not provided but the artifact store in the stack used to execute the pipeline is a zenml.integrations.gcp.artifact_stores.GCPArtifactStore, then a subdirectory of the artifact store will be used.

encryption_spec_key_name Optional[str]

The Cloud KMS resource identifier of the customer managed encryption key used to protect the job. Has the form: projects/<PRJCT>/locations/<REGION>/keyRings/<KR>/cryptoKeys/<KEY> . The key needs to be in the same region as where the compute resource is created.

workload_service_account Optional[str]

the service account for workload run-as account. Users submitting jobs must have act-as permission on this run-as account. If not provided, the Compute Engine default service account for the GCP project in which the pipeline is running is used.

function_service_account Optional[str]

the service account for cloud function run-as account, for scheduled pipelines. This service account must have the act-as permission on the workload_service_account. If not provided, the Compute Engine default service account for the GCP project in which the pipeline is running is used.

scheduler_service_account Optional[str]

the service account used by the Google Cloud Scheduler to trigger and authenticate to the pipeline Cloud Function on a schedule. If not provided, the Compute Engine default service account for the GCP project in which the pipeline is running is used.

network Optional[str]

the full name of the Compute Engine Network to which the job should be peered. For example, projects/12345/global/networks/myVPC If not provided, the job will not be peered with any network.

cpu_limit Optional[str]

The maximum CPU limit for this operator. This string value can be a number (integer value for number of CPUs) as string, or a number followed by "m", which means 1/1000. You can specify at most 96 CPUs. (see. https://cloud.google.com/vertex-ai/docs/pipelines/machine-types)

memory_limit Optional[str]

The maximum memory limit for this operator. This string value can be a number, or a number followed by "K" (kilobyte), "M" (megabyte), or "G" (gigabyte). At most 624GB is supported.

gpu_limit Optional[int]

The GPU limit (positive number) for the operator. For more information about GPU resources, see: https://cloud.google.com/vertex-ai/docs/training/configure-compute#specifying_gpus

Source code in zenml/integrations/gcp/flavors/vertex_orchestrator_flavor.py
class VertexOrchestratorConfig(
    BaseOrchestratorConfig,
    GoogleCredentialsConfigMixin,
    VertexOrchestratorSettings,
):
    """Configuration for the Vertex orchestrator.

    Attributes:
        location: Name of GCP region where the pipeline job will be executed.
            Vertex AI Pipelines is available in the following regions:
            https://cloud.google.com/vertex-ai/docs/general/locations#feature-availability
        pipeline_root: a Cloud Storage URI that will be used by the Vertex AI
            Pipelines. If not provided but the artifact store in the stack used
            to execute the pipeline is a
            `zenml.integrations.gcp.artifact_stores.GCPArtifactStore`,
            then a subdirectory of the artifact store will be used.
        encryption_spec_key_name: The Cloud KMS resource identifier of the
            customer managed encryption key used to protect the job. Has the form:
            `projects/<PRJCT>/locations/<REGION>/keyRings/<KR>/cryptoKeys/<KEY>`
            . The key needs to be in the same region as where the compute
            resource is created.
        workload_service_account: the service account for workload run-as
            account. Users submitting jobs must have act-as permission on this
            run-as account. If not provided, the Compute Engine default service
            account for the GCP project in which the pipeline is running is
            used.
        function_service_account: the service account for cloud function run-as
            account, for scheduled pipelines. This service account must have
            the act-as permission on the workload_service_account.
            If not provided, the Compute Engine default service account for the
            GCP project in which the pipeline is running is used.
        scheduler_service_account: the service account used by the Google Cloud
            Scheduler to trigger and authenticate to the pipeline Cloud Function
            on a schedule. If not provided, the Compute Engine default service
            account for the GCP project in which the pipeline is running is
            used.
        network: the full name of the Compute Engine Network to which the job
            should be peered. For example, `projects/12345/global/networks/myVPC`
            If not provided, the job will not be peered with any network.
        cpu_limit: The maximum CPU limit for this operator. This string value
            can be a number (integer value for number of CPUs) as string,
            or a number followed by "m", which means 1/1000. You can specify
            at most 96 CPUs.
            (see. https://cloud.google.com/vertex-ai/docs/pipelines/machine-types)
        memory_limit: The maximum memory limit for this operator. This string
            value can be a number, or a number followed by "K" (kilobyte),
            "M" (megabyte), or "G" (gigabyte). At most 624GB is supported.
        gpu_limit: The GPU limit (positive number) for the operator.
            For more information about GPU resources, see:
            https://cloud.google.com/vertex-ai/docs/training/configure-compute#specifying_gpus
    """

    location: str
    pipeline_root: Optional[str] = None
    encryption_spec_key_name: Optional[str] = None
    workload_service_account: Optional[str] = None
    function_service_account: Optional[str] = None
    scheduler_service_account: Optional[str] = None
    network: Optional[str] = None

    cpu_limit: Optional[str] = None
    memory_limit: Optional[str] = None
    gpu_limit: Optional[int] = None

    _resource_deprecation = deprecation_utils.deprecate_pydantic_attributes(
        "cpu_limit",
        "memory_limit",
        "gpu_limit",
        "function_service_account",
        "scheduler_service_account",
    )

    @property
    def is_remote(self) -> bool:
        """Checks if this stack component is running remotely.

        This designation is used to determine if the stack component can be
        used with a local ZenML database or if it requires a remote ZenML
        server.

        Returns:
            True if this config is for a remote component, False otherwise.
        """
        return True

    @property
    def is_synchronous(self) -> bool:
        """Whether the orchestrator runs synchronous or not.

        Returns:
            Whether the orchestrator runs synchronous or not.
        """
        return self.synchronous

    @property
    def is_schedulable(self) -> bool:
        """Whether the orchestrator is schedulable or not.

        Returns:
            Whether the orchestrator is schedulable or not.
        """
        return True
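
For illustration only (a sketch; the region and service account are placeholders, and orchestrators are normally registered as stack components rather than instantiated by hand):

# Sketch: a minimal Vertex orchestrator configuration.
from zenml.integrations.gcp.flavors.vertex_orchestrator_flavor import (
    VertexOrchestratorConfig,
)

config = VertexOrchestratorConfig(
    location="europe-west1",
    workload_service_account="my-sa@my-project.iam.gserviceaccount.com",
)
print(config.is_remote, config.is_schedulable)
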
is_remote: bool property readonly

Checks if this stack component is running remotely.

This designation is used to determine if the stack component can be used with a local ZenML database or if it requires a remote ZenML server.

Returns:

Type Description
bool

True if this config is for a remote component, False otherwise.

is_schedulable: bool property readonly

Whether the orchestrator is schedulable or not.

Returns:

Type Description
bool

Whether the orchestrator is schedulable or not.

is_synchronous: bool property readonly

Whether the orchestrator runs synchronous or not.

Returns:

Type Description
bool

Whether the orchestrator runs synchronous or not.

VertexOrchestratorFlavor (BaseOrchestratorFlavor)

Vertex Orchestrator flavor.

Source code in zenml/integrations/gcp/flavors/vertex_orchestrator_flavor.py
class VertexOrchestratorFlavor(BaseOrchestratorFlavor):
    """Vertex Orchestrator flavor."""

    @property
    def name(self) -> str:
        """Name of the orchestrator flavor.

        Returns:
            Name of the orchestrator flavor.
        """
        return GCP_VERTEX_ORCHESTRATOR_FLAVOR

    @property
    def service_connector_requirements(
        self,
    ) -> Optional[ServiceConnectorRequirements]:
        """Service connector resource requirements for service connectors.

        Specifies resource requirements that are used to filter the available
        service connector types that are compatible with this flavor.

        Returns:
            Requirements for compatible service connectors, if a service
            connector is required for this flavor.
        """
        return ServiceConnectorRequirements(
            resource_type=GCP_RESOURCE_TYPE,
        )

    @property
    def docs_url(self) -> Optional[str]:
        """A url to point at docs explaining this flavor.

        Returns:
            A flavor docs url.
        """
        return self.generate_default_docs_url()

    @property
    def sdk_docs_url(self) -> Optional[str]:
        """A url to point at SDK docs explaining this flavor.

        Returns:
            A flavor SDK docs url.
        """
        return self.generate_default_sdk_docs_url()

    @property
    def logo_url(self) -> str:
        """A url to represent the flavor in the dashboard.

        Returns:
            The flavor logo.
        """
        return "https://public-flavor-logos.s3.eu-central-1.amazonaws.com/orchestrator/vertexai.png"

    @property
    def config_class(self) -> Type[VertexOrchestratorConfig]:
        """Returns VertexOrchestratorConfig config class.

        Returns:
                The config class.
        """
        return VertexOrchestratorConfig

    @property
    def implementation_class(self) -> Type["VertexOrchestrator"]:
        """Implementation class for this flavor.

        Returns:
            Implementation class for this flavor.
        """
        from zenml.integrations.gcp.orchestrators import VertexOrchestrator

        return VertexOrchestrator
config_class: Type[zenml.integrations.gcp.flavors.vertex_orchestrator_flavor.VertexOrchestratorConfig] property readonly

Returns VertexOrchestratorConfig config class.

Returns:

Type Description
Type[zenml.integrations.gcp.flavors.vertex_orchestrator_flavor.VertexOrchestratorConfig]

The config class.

docs_url: Optional[str] property readonly

A url to point at docs explaining this flavor.

Returns:

Type Description
Optional[str]

A flavor docs url.

implementation_class: Type[VertexOrchestrator] property readonly

Implementation class for this flavor.

Returns:

Type Description
Type[VertexOrchestrator]

Implementation class for this flavor.

logo_url: str property readonly

A url to represent the flavor in the dashboard.

Returns:

Type Description
str

The flavor logo.

name: str property readonly

Name of the orchestrator flavor.

Returns:

Type Description
str

Name of the orchestrator flavor.

sdk_docs_url: Optional[str] property readonly

A url to point at SDK docs explaining this flavor.

Returns:

Type Description
Optional[str]

A flavor SDK docs url.

service_connector_requirements: Optional[zenml.models.v2.misc.service_connector_type.ServiceConnectorRequirements] property readonly

Service connector resource requirements for service connectors.

Specifies resource requirements that are used to filter the available service connector types that are compatible with this flavor.

Returns:

Type Description
Optional[zenml.models.v2.misc.service_connector_type.ServiceConnectorRequirements]

Requirements for compatible service connectors, if a service connector is required for this flavor.

VertexOrchestratorSettings (BaseSettings)

Settings for the Vertex orchestrator.

Attributes:

Name Type Description
synchronous bool

If True, the client running a pipeline using this orchestrator waits until all steps finish running. If False, the client returns immediately and the pipeline is executed asynchronously. Defaults to True.

labels Dict[str, str]

Labels to assign to the pipeline job.

node_selector_constraint Optional[Tuple[str, str]]

Each constraint is a key-value pair label. For the container to be eligible to run on a node, the node must have each of the constraints applied as labels. For example, a GPU type can be provided by one of the following tuples:
- ("cloud.google.com/gke-accelerator", "NVIDIA_TESLA_A100")
- ("cloud.google.com/gke-accelerator", "NVIDIA_TESLA_K80")
- ("cloud.google.com/gke-accelerator", "NVIDIA_TESLA_P4")
- ("cloud.google.com/gke-accelerator", "NVIDIA_TESLA_P100")
- ("cloud.google.com/gke-accelerator", "NVIDIA_TESLA_T4")
- ("cloud.google.com/gke-accelerator", "NVIDIA_TESLA_V100")
Hint: the selected region (location) must provide the requested accelerator (see https://cloud.google.com/compute/docs/gpus/gpu-regions-zones).

pod_settings Optional[zenml.integrations.kubernetes.pod_settings.KubernetesPodSettings]

Pod settings to apply.

Source code in zenml/integrations/gcp/flavors/vertex_orchestrator_flavor.py
class VertexOrchestratorSettings(BaseSettings):
    """Settings for the Vertex orchestrator.

    Attributes:
        synchronous: If `True`, the client running a pipeline using this
            orchestrator waits until all steps finish running. If `False`,
            the client returns immediately and the pipeline is executed
            asynchronously. Defaults to `True`.
        labels: Labels to assign to the pipeline job.
        node_selector_constraint: Each constraint is a key-value pair label.
            For the container to be eligible to run on a node, the node must have
            each of the constraints appeared as labels.
            For example a GPU type can be providing by one of the following tuples:
                - ("cloud.google.com/gke-accelerator", "NVIDIA_TESLA_A100")
                - ("cloud.google.com/gke-accelerator", "NVIDIA_TESLA_K80")
                - ("cloud.google.com/gke-accelerator", "NVIDIA_TESLA_P4")
                - ("cloud.google.com/gke-accelerator", "NVIDIA_TESLA_P100")
                - ("cloud.google.com/gke-accelerator", "NVIDIA_TESLA_T4")
                - ("cloud.google.com/gke-accelerator", "NVIDIA_TESLA_V100")
            Hint: the selected region (location) must provide the requested accelerator
            (see https://cloud.google.com/compute/docs/gpus/gpu-regions-zones).
        pod_settings: Pod settings to apply.
    """

    labels: Dict[str, str] = {}
    synchronous: bool = True
    node_selector_constraint: Optional[Tuple[str, str]] = None
    pod_settings: Optional[KubernetesPodSettings] = None

    _node_selector_deprecation = (
        deprecation_utils.deprecate_pydantic_attributes(
            "node_selector_constraint"
        )
    )
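
A usage sketch follows. The settings key "orchestrator.vertex" and the top-level zenml.pipeline import reflect the common ZenML pattern and are assumptions here; adjust them to your ZenML version.

# Sketch: pass Vertex-specific settings to a pipeline via the assumed
# "orchestrator.vertex" settings key.
from zenml import pipeline
from zenml.integrations.gcp.flavors.vertex_orchestrator_flavor import (
    VertexOrchestratorSettings,
)

vertex_settings = VertexOrchestratorSettings(
    synchronous=True,
    labels={"team": "ml-platform"},
)

@pipeline(settings={"orchestrator.vertex": vertex_settings})
def my_pipeline() -> None:
    ...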

vertex_step_operator_flavor

Vertex step operator flavor.

VertexStepOperatorConfig (BaseStepOperatorConfig, GoogleCredentialsConfigMixin, VertexStepOperatorSettings)

Configuration for the Vertex step operator.

Attributes:

Name Type Description
region str

Region name, e.g., europe-west1.

encryption_spec_key_name Optional[str]

Encryption spec key name.

network Optional[str]

The full name of the Compute Engine network to which the Job should be peered. For example, projects/12345/global/networks/myVPC

reserved_ip_ranges Optional[str]

A list of names for the reserved ip ranges under the VPC network that can be used for this job. If set, we will deploy the job within the provided ip ranges. Otherwise, the job will be deployed to any ip ranges under the provided VPC network.

service_account Optional[str]

Specifies the service account for workload run-as account. Users submitting jobs must have act-as permission on this run-as account.

Source code in zenml/integrations/gcp/flavors/vertex_step_operator_flavor.py
class VertexStepOperatorConfig(
    BaseStepOperatorConfig,
    GoogleCredentialsConfigMixin,
    VertexStepOperatorSettings,
):
    """Configuration for the Vertex step operator.

    Attributes:
        region: Region name, e.g., `europe-west1`.
        encryption_spec_key_name: Encryption spec key name.
        network: The full name of the Compute Engine network to which the Job should be peered.
            For example, projects/12345/global/networks/myVPC
        reserved_ip_ranges: A list of names for the reserved ip ranges under the VPC network that can be used
            for this job. If set, we will deploy the job within the provided ip ranges. Otherwise, the job
            will be deployed to any ip ranges under the provided VPC network.
        service_account: Specifies the service account for workload run-as account. Users submitting jobs
            must have act-as permission on this run-as account.
    """

    region: str

    # customer managed encryption key resource name
    # will be applied to all Vertex AI resources if set
    encryption_spec_key_name: Optional[str] = None

    network: Optional[str] = None

    reserved_ip_ranges: Optional[str] = None

    service_account: Optional[str] = None

    @property
    def is_remote(self) -> bool:
        """Checks if this stack component is running remotely.

        This designation is used to determine if the stack component can be
        used with a local ZenML database or if it requires a remote ZenML
        server.

        Returns:
            True if this config is for a remote component, False otherwise.
        """
        return True
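
For illustration only (a sketch; the region and service account are placeholders, and step operators are normally registered as stack components rather than instantiated by hand):

# Sketch: a minimal Vertex step operator configuration.
from zenml.integrations.gcp.flavors.vertex_step_operator_flavor import (
    VertexStepOperatorConfig,
)

config = VertexStepOperatorConfig(
    region="europe-west1",
    service_account="my-sa@my-project.iam.gserviceaccount.com",
)
print(config.is_remote)
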
is_remote: bool property readonly

Checks if this stack component is running remotely.

This designation is used to determine if the stack component can be used with a local ZenML database or if it requires a remote ZenML server.

Returns:

Type Description
bool

True if this config is for a remote component, False otherwise.

VertexStepOperatorFlavor (BaseStepOperatorFlavor)

Vertex Step Operator flavor.

Source code in zenml/integrations/gcp/flavors/vertex_step_operator_flavor.py
class VertexStepOperatorFlavor(BaseStepOperatorFlavor):
    """Vertex Step Operator flavor."""

    @property
    def name(self) -> str:
        """Name of the flavor.

        Returns:
            Name of the flavor.
        """
        return GCP_VERTEX_STEP_OPERATOR_FLAVOR

    @property
    def service_connector_requirements(
        self,
    ) -> Optional[ServiceConnectorRequirements]:
        """Service connector resource requirements for service connectors.

        Specifies resource requirements that are used to filter the available
        service connector types that are compatible with this flavor.

        Returns:
            Requirements for compatible service connectors, if a service
            connector is required for this flavor.
        """
        return ServiceConnectorRequirements(
            resource_type=GCP_RESOURCE_TYPE,
        )

    @property
    def docs_url(self) -> Optional[str]:
        """A url to point at docs explaining this flavor.

        Returns:
            A flavor docs url.
        """
        return self.generate_default_docs_url()

    @property
    def sdk_docs_url(self) -> Optional[str]:
        """A url to point at SDK docs explaining this flavor.

        Returns:
            A flavor SDK docs url.
        """
        return self.generate_default_sdk_docs_url()

    @property
    def logo_url(self) -> str:
        """A url to represent the flavor in the dashboard.

        Returns:
            The flavor logo.
        """
        return "https://public-flavor-logos.s3.eu-central-1.amazonaws.com/step_operator/vertexai.png"

    @property
    def config_class(self) -> Type[VertexStepOperatorConfig]:
        """Returns `VertexStepOperatorConfig` config class.

        Returns:
            The config class.
        """
        return VertexStepOperatorConfig

    @property
    def implementation_class(self) -> Type["VertexStepOperator"]:
        """Implementation class for this flavor.

        Returns:
            The implementation class.
        """
        from zenml.integrations.gcp.step_operators import VertexStepOperator

        return VertexStepOperator
config_class: Type[zenml.integrations.gcp.flavors.vertex_step_operator_flavor.VertexStepOperatorConfig] property readonly

Returns VertexStepOperatorConfig config class.

Returns:

Type Description
Type[zenml.integrations.gcp.flavors.vertex_step_operator_flavor.VertexStepOperatorConfig]

The config class.

docs_url: Optional[str] property readonly

A url to point at docs explaining this flavor.

Returns:

Type Description
Optional[str]

A flavor docs url.

implementation_class: Type[VertexStepOperator] property readonly

Implementation class for this flavor.

Returns:

Type Description
Type[VertexStepOperator]

The implementation class.

logo_url: str property readonly

A url to represent the flavor in the dashboard.

Returns:

Type Description
str

The flavor logo.

name: str property readonly

Name of the flavor.

Returns:

Type Description
str

Name of the flavor.

sdk_docs_url: Optional[str] property readonly

A url to point at SDK docs explaining this flavor.

Returns:

Type Description
Optional[str]

A flavor SDK docs url.

service_connector_requirements: Optional[zenml.models.v2.misc.service_connector_type.ServiceConnectorRequirements] property readonly

Service connector resource requirements for service connectors.

Specifies resource requirements that are used to filter the available service connector types that are compatible with this flavor.

Returns:

Type Description
Optional[zenml.models.v2.misc.service_connector_type.ServiceConnectorRequirements]

Requirements for compatible service connectors, if a service connector is required for this flavor.

VertexStepOperatorSettings (BaseSettings)

Settings for the Vertex step operator.

Attributes:

Name Type Description
accelerator_type Optional[str]

Defines which accelerator (GPU, TPU) is used for the job. Check out this table to see which accelerator type and count are compatible with your chosen machine type: https://cloud.google.com/vertex-ai/docs/training/configure-compute#gpu-compatibility-table.

accelerator_count int

Defines the number of accelerators to be used for the job. Check out this table to see which accelerator type and count are compatible with your chosen machine type: https://cloud.google.com/vertex-ai/docs/training/configure-compute#gpu-compatibility-table.

machine_type str

Machine type specified here https://cloud.google.com/vertex-ai/docs/training/configure-compute#machine-types.

boot_disk_size_gb int

Size of the boot disk in GB. (Default: 100) https://cloud.google.com/vertex-ai/docs/training/configure-compute#boot_disk_options

boot_disk_type str

Type of the boot disk. (Default: pd-ssd) https://cloud.google.com/vertex-ai/docs/training/configure-compute#boot_disk_options

Source code in zenml/integrations/gcp/flavors/vertex_step_operator_flavor.py
class VertexStepOperatorSettings(BaseSettings):
    """Settings for the Vertex step operator.

    Attributes:
        accelerator_type: Defines which accelerator (GPU, TPU) is used for the
            job. Check out this table to see which accelerator
            type and count are compatible with your chosen machine type:
            https://cloud.google.com/vertex-ai/docs/training/configure-compute#gpu-compatibility-table.
        accelerator_count: Defines the number of accelerators to be used for
            the job. Check out this table to see which accelerator
            type and count are compatible with your chosen machine type:
            https://cloud.google.com/vertex-ai/docs/training/configure-compute#gpu-compatibility-table.
        machine_type: Machine type specified here
            https://cloud.google.com/vertex-ai/docs/training/configure-compute#machine-types.
        boot_disk_size_gb: Size of the boot disk in GB. (Default: 100)
            https://cloud.google.com/vertex-ai/docs/training/configure-compute#boot_disk_options
        boot_disk_type: Type of the boot disk. (Default: pd-ssd)
            https://cloud.google.com/vertex-ai/docs/training/configure-compute#boot_disk_options

    """

    accelerator_type: Optional[str] = None
    accelerator_count: int = 0
    machine_type: str = "n1-standard-4"
    boot_disk_size_gb: int = 100
    boot_disk_type: str = "pd-ssd"
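
As a usage sketch, the snippet below attaches these settings to a single step so that it runs on a GPU-backed Vertex AI machine. The step operator name ("vertex"), the settings key, and the accelerator values are assumptions that depend on your registered stack, ZenML version, and GCP quota.

from zenml import step
from zenml.integrations.gcp.flavors.vertex_step_operator_flavor import (
    VertexStepOperatorSettings,
)

# Sketch: the settings key may need to be flavor-scoped
# (e.g. "step_operator.vertex") depending on the ZenML version in use.
vertex_settings = VertexStepOperatorSettings(
    machine_type="n1-standard-8",
    accelerator_type="NVIDIA_TESLA_T4",
    accelerator_count=1,
    boot_disk_size_gb=200,
)

@step(step_operator="vertex", settings={"step_operator": vertex_settings})
def train_model() -> None:
    ...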

google_credentials_mixin

Implementation of the Google credentials mixin.

GoogleCredentialsConfigMixin (StackComponentConfig)

Config mixin for Google Cloud Platform credentials.

Attributes:

Name Type Description
project Optional[str]

GCP project name. If None, the project will be inferred from the environment.

service_account_path Optional[str]

path to the service account credentials file to be used for authentication. If not provided, the default credentials will be used.

Source code in zenml/integrations/gcp/google_credentials_mixin.py
class GoogleCredentialsConfigMixin(StackComponentConfig):
    """Config mixin for Google Cloud Platform credentials.

    Attributes:
        project: GCP project name. If `None`, the project will be inferred from
            the environment.
        service_account_path: path to the service account credentials file to be
            used for authentication. If not provided, the default credentials
            will be used.
    """

    project: Optional[str] = None
    service_account_path: Optional[str] = None

GoogleCredentialsMixin (StackComponent)

StackComponent mixin to get Google Cloud Platform credentials.

Source code in zenml/integrations/gcp/google_credentials_mixin.py
class GoogleCredentialsMixin(StackComponent):
    """StackComponent mixin to get Google Cloud Platform credentials."""

    @property
    def config(self) -> GoogleCredentialsConfigMixin:
        """Returns the `GoogleCredentialsConfigMixin` config.

        Returns:
            The configuration.
        """
        return cast(GoogleCredentialsConfigMixin, self._config)

    def _get_authentication(self) -> Tuple["Credentials", str]:
        """Get GCP credentials and the project ID associated with the credentials.

        If `service_account_path` is provided, then the credentials will be
        loaded from the file at that path. Otherwise, the default credentials
        will be used.

        Returns:
            A tuple containing the credentials and the project ID associated
            with the credentials.

        Raises:
            RuntimeError: If the linked connector returns an unexpected type of
                credentials.
        """
        from google.auth import default, load_credentials_from_file
        from google.auth.credentials import Credentials

        from zenml.integrations.gcp.service_connectors import (
            GCPServiceConnector,
        )

        connector = self.get_connector()
        if connector:
            credentials = connector.connect()
            if not isinstance(credentials, Credentials) or not isinstance(
                connector, GCPServiceConnector
            ):
                raise RuntimeError(
                    f"Expected google.auth.credentials.Credentials while "
                    "trying to use the linked connector, but got "
                    f"{type(credentials)}."
                )
            return credentials, connector.config.gcp_project_id

        if self.config.service_account_path:
            credentials, project_id = load_credentials_from_file(
                self.config.service_account_path
            )
        else:
            credentials, project_id = default()

        if self.config.project and self.config.project != project_id:
            logger.warning(
                "Authenticated with project `%s`, but this %s is "
                "configured to use the project `%s`.",
                project_id,
                self.type,
                self.config.project,
            )

        # If the project was set in the configuration, use it. Otherwise, use
        # the project that was used to authenticate.
        project_id = self.config.project if self.config.project else project_id
        return credentials, project_id
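
The credential fallback order implemented by `_get_authentication` can be reproduced outside of ZenML with plain `google-auth` calls. The helper below is a minimal sketch of that order for the no-connector case; the function name and arguments are illustrative only.

from typing import Optional, Tuple

from google.auth import default, load_credentials_from_file
from google.auth.credentials import Credentials


def resolve_gcp_credentials(
    service_account_path: Optional[str] = None,
    project: Optional[str] = None,
) -> Tuple[Credentials, Optional[str]]:
    # Same precedence as the mixin (minus service connectors): an explicit
    # service account file first, then Application Default Credentials.
    if service_account_path:
        credentials, inferred_project = load_credentials_from_file(
            service_account_path
        )
    else:
        credentials, inferred_project = default()
    # An explicitly configured project wins over the inferred one.
    return credentials, project or inferred_project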
config: GoogleCredentialsConfigMixin property readonly

Returns the GoogleCredentialsConfigMixin config.

Returns:

Type Description
GoogleCredentialsConfigMixin

The configuration.

image_builders special

Initialization for the GCP image builder.

gcp_image_builder

Google Cloud Builder image builder implementation.

GCPImageBuilder (BaseImageBuilder, GoogleCredentialsMixin)

Google Cloud Builder image builder implementation.

Source code in zenml/integrations/gcp/image_builders/gcp_image_builder.py
class GCPImageBuilder(BaseImageBuilder, GoogleCredentialsMixin):
    """Google Cloud Builder image builder implementation."""

    @property
    def config(self) -> GCPImageBuilderConfig:
        """The stack component configuration.

        Returns:
            The configuration.
        """
        return cast(GCPImageBuilderConfig, self._config)

    @property
    def is_building_locally(self) -> bool:
        """Whether the image builder builds the images on the client machine.

        Returns:
            True if the image builder builds locally, False otherwise.
        """
        return False

    @property
    def validator(self) -> Optional["StackValidator"]:
        """Validates the stack for the GCP Image Builder.

        The GCP Image Builder requires a remote container registry to push the
        image to, and a GCP Artifact Store to upload the build context, so
        Cloud Build can access it.

        Returns:
            Stack validator.
        """

        def _validate_remote_components(stack: "Stack") -> Tuple[bool, str]:
            assert stack.container_registry

            if (
                stack.container_registry.flavor
                != ContainerRegistryFlavor.GCP.value
            ):
                return False, (
                    "The GCP Image Builder requires a GCP container registry to "
                    "push the image to. Please update your stack to include a "
                    "GCP container registry and try again."
                )

            if stack.artifact_store.flavor != GCP_ARTIFACT_STORE_FLAVOR:
                return False, (
                    "The GCP Image Builder requires a GCP Artifact Store to "
                    "upload the build context, so Cloud Build can access it."
                    "Please update your stack to include a GCP Artifact Store "
                    "and try again."
                )

            return True, ""

        return StackValidator(
            required_components={StackComponentType.CONTAINER_REGISTRY},
            custom_validation_function=_validate_remote_components,
        )

    def build(
        self,
        image_name: str,
        build_context: "BuildContext",
        docker_build_options: Dict[str, Any],
        container_registry: Optional["BaseContainerRegistry"] = None,
    ) -> str:
        """Builds and pushes a Docker image.

        Args:
            image_name: Name of the image to build and push.
            build_context: The build context to use for the image.
            docker_build_options: Docker build options.
            container_registry: Optional container registry to push to.

        Returns:
            The Docker image name with digest.

        Raises:
            RuntimeError: If no container registry is passed.
            RuntimeError: If the Cloud Build build fails.
        """
        if not container_registry:
            raise RuntimeError(
                "The GCP Image Builder requires a container registry to push "
                "the image to. Please provide one and try again."
            )

        logger.info("Using Cloud Build to build image `%s`", image_name)
        cloud_build_context = self._upload_build_context(
            build_context=build_context,
            parent_path_directory_name="cloud-build-contexts",
        )
        build = self._configure_cloud_build(
            image_name=image_name,
            cloud_build_context=cloud_build_context,
            build_options=docker_build_options,
        )
        image_digest = self._run_cloud_build(build=build)
        image_name_without_tag, _ = image_name.rsplit(":", 1)
        image_name_with_digest = f"{image_name_without_tag}@{image_digest}"
        return image_name_with_digest

    def _configure_cloud_build(
        self,
        image_name: str,
        cloud_build_context: str,
        build_options: Dict[str, Any],
    ) -> cloudbuild_v1.Build:
        """Configures the build to be run to generate the Docker image.

        Args:
            image_name: The name of the image to build.
            cloud_build_context: The path to the build context.
            build_options: Docker build options.

        Returns:
            The build to run.
        """
        url_parts = urlparse(cloud_build_context)
        bucket = url_parts.netloc
        object_path = url_parts.path.lstrip("/")
        logger.info(
            "Build context located in bucket `%s` and object path `%s`",
            bucket,
            object_path,
        )

        cloud_builder_image = self.config.cloud_builder_image
        cloud_builder_network_option = f"--network={self.config.network}"
        logger.info(
            "Using Cloud Builder image `%s` to run the steps in the build. "
            "Container will be attached to network using option `%s`.",
            cloud_builder_image,
            cloud_builder_network_option,
        )

        # Convert the docker_build_options dictionary to a list of strings
        docker_build_args = []
        for key, value in build_options.items():
            option = f"--{key}"
            if isinstance(value, list):
                for val in value:
                    docker_build_args.extend([option, val])
            elif value is not None and not isinstance(value, bool):
                docker_build_args.extend([option, value])
            elif value is not False:
                docker_build_args.extend([option])

        return cloudbuild_v1.Build(
            source=cloudbuild_v1.Source(
                storage_source=cloudbuild_v1.StorageSource(
                    bucket=bucket, object=object_path
                ),
            ),
            steps=[
                {
                    "name": cloud_builder_image,
                    "args": [
                        "build",
                        cloud_builder_network_option,
                        "-t",
                        image_name,
                        ".",
                        *docker_build_args,
                    ],
                },
                {
                    "name": cloud_builder_image,
                    "args": ["push", image_name],
                },
            ],
            images=[image_name],
            timeout=f"{self.config.build_timeout}s",
        )

    def _run_cloud_build(self, build: cloudbuild_v1.Build) -> str:
        """Executes the Cloud Build run to build the Docker image.

        Args:
            build: The build to run.

        Returns:
            The Docker image repo digest.

        Raises:
            RuntimeError: If the Cloud Build run has failed.
        """
        credentials, project_id = self._get_authentication()
        client = cloudbuild_v1.CloudBuildClient(credentials=credentials)

        operation = client.create_build(project_id=project_id, build=build)
        log_url = operation.metadata.build.log_url
        logger.info(
            "Running Cloud Build to build the Docker image. Cloud Build logs: `%s`",
            log_url,
        )

        result = operation.result(timeout=self.config.build_timeout)

        if result.status != cloudbuild_v1.Build.Status.SUCCESS:
            raise RuntimeError(
                f"The Cloud Build run to build the Docker image has failed. More "
                f"information can be found in the Cloud Build logs: {log_url}."
            )

        logger.info(
            f"The Docker image has been built successfully. More information can "
            f"be found in the Cloud Build logs: `{log_url}`."
        )

        image_digest: str = result.results.images[0].digest
        return image_digest
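
The way `_configure_cloud_build` flattens the `docker_build_options` dictionary into CLI-style arguments is easiest to see in isolation. The helper below copies that loop into a standalone function; the function name and the sample options are illustrative only.

from typing import Any, Dict, List


def docker_options_to_args(build_options: Dict[str, Any]) -> List[str]:
    # Mirrors the conversion loop in `_configure_cloud_build`: lists repeat
    # the flag for each entry, plain values become flag/value pairs, `True`
    # (and `None`) become bare flags, and `False` drops the flag entirely.
    args: List[str] = []
    for key, value in build_options.items():
        option = f"--{key}"
        if isinstance(value, list):
            for val in value:
                args.extend([option, val])
        elif value is not None and not isinstance(value, bool):
            args.extend([option, value])
        elif value is not False:
            args.append(option)
    return args


print(docker_options_to_args({"build-arg": ["FOO=bar"], "pull": True, "target": "runtime"}))
# ['--build-arg', 'FOO=bar', '--pull', '--target', 'runtime']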
config: GCPImageBuilderConfig property readonly

The stack component configuration.

Returns:

Type Description
GCPImageBuilderConfig

The configuration.

is_building_locally: bool property readonly

Whether the image builder builds the images on the client machine.

Returns:

Type Description
bool

True if the image builder builds locally, False otherwise.

validator: Optional[StackValidator] property readonly

Validates the stack for the GCP Image Builder.

The GCP Image Builder requires a remote container registry to push the image to, and a GCP Artifact Store to upload the build context, so Cloud Build can access it.

Returns:

Type Description
Optional[StackValidator]

Stack validator.

build(self, image_name, build_context, docker_build_options, container_registry=None)

Builds and pushes a Docker image.

Parameters:

Name Type Description Default
image_name str

Name of the image to build and push.

required
build_context BuildContext

The build context to use for the image.

required
docker_build_options Dict[str, Any]

Docker build options.

required
container_registry Optional[BaseContainerRegistry]

Optional container registry to push to.

None

Returns:

Type Description
str

The Docker image name with digest.

Exceptions:

Type Description
RuntimeError

If no container registry is passed.

RuntimeError

If the Cloud Build build fails.

Source code in zenml/integrations/gcp/image_builders/gcp_image_builder.py
def build(
    self,
    image_name: str,
    build_context: "BuildContext",
    docker_build_options: Dict[str, Any],
    container_registry: Optional["BaseContainerRegistry"] = None,
) -> str:
    """Builds and pushes a Docker image.

    Args:
        image_name: Name of the image to build and push.
        build_context: The build context to use for the image.
        docker_build_options: Docker build options.
        container_registry: Optional container registry to push to.

    Returns:
        The Docker image name with digest.

    Raises:
        RuntimeError: If no container registry is passed.
        RuntimeError: If the Cloud Build build fails.
    """
    if not container_registry:
        raise RuntimeError(
            "The GCP Image Builder requires a container registry to push "
            "the image to. Please provide one and try again."
        )

    logger.info("Using Cloud Build to build image `%s`", image_name)
    cloud_build_context = self._upload_build_context(
        build_context=build_context,
        parent_path_directory_name="cloud-build-contexts",
    )
    build = self._configure_cloud_build(
        image_name=image_name,
        cloud_build_context=cloud_build_context,
        build_options=docker_build_options,
    )
    image_digest = self._run_cloud_build(build=build)
    image_name_without_tag, _ = image_name.rsplit(":", 1)
    image_name_with_digest = f"{image_name_without_tag}@{image_digest}"
    return image_name_with_digest

orchestrators special

Initialization for the VertexAI orchestrator.

vertex_orchestrator

Implementation of the VertexAI orchestrator.

VertexOrchestrator (ContainerizedOrchestrator, GoogleCredentialsMixin)

Orchestrator responsible for running pipelines on Vertex AI.

Source code in zenml/integrations/gcp/orchestrators/vertex_orchestrator.py
class VertexOrchestrator(ContainerizedOrchestrator, GoogleCredentialsMixin):
    """Orchestrator responsible for running pipelines on Vertex AI."""

    _pipeline_root: str

    @property
    def config(self) -> VertexOrchestratorConfig:
        """Returns the `VertexOrchestratorConfig` config.

        Returns:
            The configuration.
        """
        return cast(VertexOrchestratorConfig, self._config)

    @property
    def settings_class(self) -> Optional[Type["BaseSettings"]]:
        """Settings class for the Vertex orchestrator.

        Returns:
            The settings class.
        """
        return VertexOrchestratorSettings

    @property
    def validator(self) -> Optional[StackValidator]:
        """Validates that the stack contains a container registry.

        Also validates that the artifact store is not local.

        Returns:
            A StackValidator instance.
        """

        def _validate_stack_requirements(stack: "Stack") -> Tuple[bool, str]:
            """Validates that all the stack components are not local.

            Args:
                stack: The stack to validate.

            Returns:
                A tuple of (is_valid, error_message).
            """
            # Validate that the container registry is not local.
            container_registry = stack.container_registry
            if container_registry and container_registry.config.is_local:
                return False, (
                    f"The Vertex orchestrator does not support local "
                    f"container registries. You should replace the component '"
                    f"{container_registry.name}' "
                    f"{container_registry.type.value} to a remote one."
                )

            # Validate that the rest of the components are not local.
            for stack_comp in stack.components.values():
                # For Forward compatibility a list of components is returned,
                # but only the first item is relevant for now
                # TODO: [server] make sure the ComponentModel actually has
                #  a local_path property or implement similar check
                local_path = stack_comp.local_path
                if not local_path:
                    continue
                return False, (
                    f"The '{stack_comp.name}' {stack_comp.type.value} is a "
                    f"local stack component. The Vertex AI Pipelines "
                    f"orchestrator requires that all the components in the "
                    f"stack used to execute the pipeline have to be not local, "
                    f"because there is no way for Vertex to connect to your "
                    f"local machine. You should use a flavor of "
                    f"{stack_comp.type.value} other than '"
                    f"{stack_comp.flavor}'."
                )

            # If the `pipeline_root` has not been defined in the orchestrator
            # configuration, and the artifact store is not a GCP artifact store,
            # then raise an error.
            if (
                not self.config.pipeline_root
                and stack.artifact_store.flavor != GCP_ARTIFACT_STORE_FLAVOR
            ):
                return False, (
                    f"The attribute `pipeline_root` has not been set and it "
                    f"cannot be generated using the path of the artifact store "
                    f"because it is not a "
                    f"`zenml.integrations.gcp.artifact_store.GCPArtifactStore`."
                    f" To solve this issue, set the `pipeline_root` attribute "
                    f"manually executing the following command: "
                    f"`zenml orchestrator update {stack.orchestrator.name} "
                    f'--pipeline_root="<Cloud Storage URI>"`.'
                )

            return True, ""

        return StackValidator(
            required_components={
                StackComponentType.CONTAINER_REGISTRY,
                StackComponentType.IMAGE_BUILDER,
            },
            custom_validation_function=_validate_stack_requirements,
        )

    @property
    def root_directory(self) -> str:
        """Returns path to the root directory for files for this orchestrator.

        Returns:
            The path to the root directory for all files concerning this
            orchestrator.
        """
        return os.path.join(
            get_global_config_directory(), "vertex", str(self.id)
        )

    @property
    def pipeline_directory(self) -> str:
        """Returns path to directory where kubeflow pipelines files are stored.

        Returns:
            Path to the pipeline directory.
        """
        return os.path.join(self.root_directory, "pipelines")

    def prepare_pipeline_deployment(
        self,
        deployment: "PipelineDeploymentResponse",
        stack: "Stack",
    ) -> None:
        """Build a Docker image and push it to the container registry.

        Args:
            deployment: The pipeline deployment configuration.
            stack: The stack on which the pipeline will be deployed.

        Raises:
            ValueError: If `cron_expression` is not in passed Schedule.
        """
        if deployment.schedule:
            if (
                deployment.schedule.catchup
                or deployment.schedule.interval_second
            ):
                logger.warning(
                    "Vertex orchestrator only uses schedules with the "
                    "`cron_expression` property, with optional `start_time` "
                    "and/or `end_time`. All other properties are ignored."
                )
            if deployment.schedule.cron_expression is None:
                raise ValueError(
                    "Property `cron_expression` must be set when passing "
                    "schedule to a Vertex orchestrator."
                )

    def _create_dynamic_component(
        self,
        image: str,
        command: List[str],
        arguments: List[str],
        component_name: str,
    ) -> dsl.PipelineTask:
        """Creates a dynamic container component for a Vertex pipeline.

        Args:
            image: The image to use for the component.
            command: The command to use for the component.
            arguments: The arguments to use for the component.
            component_name: The name of the component.

        Returns:
            The dynamic container component.
        """

        def dynamic_container_component() -> dsl.ContainerSpec:
            """Dynamic container component.

            Returns:
                The dynamic container component.
            """
            return dsl.ContainerSpec(
                image=image,
                command=command,
                args=arguments,
            )

        # Change the name of the function
        new_container_spec_func = types.FunctionType(
            dynamic_container_component.__code__,
            dynamic_container_component.__globals__,
            name=component_name,
            argdefs=dynamic_container_component.__defaults__,
            closure=dynamic_container_component.__closure__,
        )
        pipeline_task = dsl.container_component(new_container_spec_func)

        return pipeline_task

    def prepare_or_run_pipeline(
        self,
        deployment: "PipelineDeploymentResponse",
        stack: "Stack",
        environment: Dict[str, str],
    ) -> Iterator[Dict[str, MetadataType]]:
        """Creates a KFP JSON pipeline.

        # noqa: DAR402

        This is an intermediary representation of the pipeline which is then
        deployed to Vertex AI Pipelines service.

        How it works:
        -------------
        Before this method is called, the `prepare_pipeline_deployment()` method
        builds a Docker image that contains the code for the pipeline, all steps
        and the context around these files.

        Based on this Docker image, a callable is created which builds a
        dynamic container component for each step (`_create_dynamic_component`)
        using the Kubeflow SDK v2 `dsl.container_component` decorator. The step
        entrypoint command with the entrypoint arguments is the command that
        will be executed by the container created from the previously built
        Docker image.

        This callable is then compiled into a JSON file that is used as the
        intermediary representation of the Kubeflow pipeline.

        This file then is submitted to the Vertex AI Pipelines service for
        execution.

        Args:
            deployment: The pipeline deployment to prepare or run.
            stack: The stack the pipeline will run on.
            environment: Environment variables to set in the orchestration
                environment.

        Raises:
            ValueError: If the attribute `pipeline_root` is not set and it
                cannot be generated using the path of the artifact store in the
                stack because it is not a
                `zenml.integrations.gcp.artifact_store.GCPArtifactStore`. Also
                raised when attempting to schedule a pipeline run without using
                the `zenml.integrations.gcp.artifact_store.GCPArtifactStore`.

        Yields:
            A dictionary of metadata related to the pipeline run.
        """
        orchestrator_run_name = get_orchestrator_run_name(
            pipeline_name=deployment.pipeline_configuration.name
        )
        # If the `pipeline_root` has not been defined in the orchestrator
        # configuration,
        # try to create it from the artifact store if it is a
        # `GCPArtifactStore`.
        if not self.config.pipeline_root:
            artifact_store = stack.artifact_store
            self._pipeline_root = f"{artifact_store.path.rstrip('/')}/vertex_pipeline_root/{deployment.pipeline_configuration.name}/{orchestrator_run_name}"
            logger.info(
                "The attribute `pipeline_root` has not been set in the "
                "orchestrator configuration. One has been generated "
                "automatically based on the path of the `GCPArtifactStore` "
                "artifact store in the stack used to execute the pipeline. "
                "The generated `pipeline_root` is `%s`.",
                self._pipeline_root,
            )
        else:
            self._pipeline_root = self.config.pipeline_root

        def _create_dynamic_pipeline() -> Any:
            """Create a dynamic pipeline including each step.

            Returns:
                pipeline_func
            """
            step_name_to_dynamic_component: Dict[str, Any] = {}

            for step_name, step in deployment.step_configurations.items():
                image = self.get_image(
                    deployment=deployment,
                    step_name=step_name,
                )
                command = StepEntrypointConfiguration.get_entrypoint_command()
                arguments = (
                    StepEntrypointConfiguration.get_entrypoint_arguments(
                        step_name=step_name,
                        deployment_id=deployment.id,
                    )
                )
                dynamic_component = self._create_dynamic_component(
                    image, command, arguments, step_name
                )
                step_settings = cast(
                    VertexOrchestratorSettings, self.get_settings(step)
                )
                pod_settings = step_settings.pod_settings
                if pod_settings:
                    if pod_settings.host_ipc:
                        logger.warning(
                            "Host IPC is set to `True` but not supported in "
                            "this orchestrator. Ignoring..."
                        )
                    if pod_settings.affinity:
                        logger.warning(
                            "Affinity is set but not supported in Vertex with "
                            "Kubeflow Pipelines 2.x. Ignoring..."
                        )
                    if pod_settings.tolerations:
                        logger.warning(
                            "Tolerations are set but not supported in "
                            "Vertex with Kubeflow Pipelines 2.x. Ignoring..."
                        )
                    if pod_settings.volumes:
                        logger.warning(
                            "Volumes are set but not supported in Vertex with "
                            "Kubeflow Pipelines 2.x. Ignoring..."
                        )
                    if pod_settings.volume_mounts:
                        logger.warning(
                            "Volume mounts are set but not supported in "
                            "Vertex with Kubeflow Pipelines 2.x. Ignoring..."
                        )
                    for key in pod_settings.node_selectors:
                        if (
                            key
                            != GKE_ACCELERATOR_NODE_SELECTOR_CONSTRAINT_LABEL
                        ):
                            logger.warning(
                                "Vertex only allows the %s node selector, "
                                "ignoring the node selector %s.",
                                GKE_ACCELERATOR_NODE_SELECTOR_CONSTRAINT_LABEL,
                                key,
                            )

                step_name_to_dynamic_component[step_name] = dynamic_component

            @dsl.pipeline(  # type: ignore[misc]
                display_name=orchestrator_run_name,
            )
            def dynamic_pipeline() -> None:
                """Dynamic pipeline."""
                # iterate through the components one by one
                # (from step_name_to_dynamic_component)
                for (
                    component_name,
                    component,
                ) in step_name_to_dynamic_component.items():
                    # for each component, check to see what other steps are
                    # upstream of it
                    step = deployment.step_configurations[component_name]
                    upstream_step_components = [
                        step_name_to_dynamic_component[upstream_step_name]
                        for upstream_step_name in step.spec.upstream_steps
                    ]
                    task = (
                        component()
                        .set_display_name(
                            name=component_name,
                        )
                        .set_caching_options(enable_caching=False)
                        .set_env_variable(
                            name=ENV_ZENML_VERTEX_RUN_ID,
                            value=dsl.PIPELINE_JOB_NAME_PLACEHOLDER,
                        )
                        .after(*upstream_step_components)
                    )

                    step_settings = cast(
                        VertexOrchestratorSettings, self.get_settings(step)
                    )
                    pod_settings = step_settings.pod_settings

                    node_selector_constraint: Optional[Tuple[str, str]] = None
                    if pod_settings and (
                        GKE_ACCELERATOR_NODE_SELECTOR_CONSTRAINT_LABEL
                        in pod_settings.node_selectors.keys()
                    ):
                        node_selector_constraint = (
                            GKE_ACCELERATOR_NODE_SELECTOR_CONSTRAINT_LABEL,
                            pod_settings.node_selectors[
                                GKE_ACCELERATOR_NODE_SELECTOR_CONSTRAINT_LABEL
                            ],
                        )
                    elif step_settings.node_selector_constraint:
                        node_selector_constraint = (
                            GKE_ACCELERATOR_NODE_SELECTOR_CONSTRAINT_LABEL,
                            step_settings.node_selector_constraint[1],
                        )

                    self._configure_container_resources(
                        dynamic_component=task,
                        resource_settings=step.config.resource_settings,
                        node_selector_constraint=node_selector_constraint,
                    )

            return dynamic_pipeline

        def _update_json_with_environment(
            yaml_file_path: str, environment: Dict[str, str]
        ) -> None:
            """Updates the env section of the steps in the YAML file with the given environment variables.

            Args:
                yaml_file_path: The path to the YAML file to update.
                environment: A dictionary of environment variables to add.
            """
            pipeline_definition = yaml_utils.read_json(yaml_file_path)

            # Iterate through each component and add the environment variables
            for executor in pipeline_definition["deploymentSpec"]["executors"]:
                if (
                    "container"
                    in pipeline_definition["deploymentSpec"]["executors"][
                        executor
                    ]
                ):
                    container = pipeline_definition["deploymentSpec"][
                        "executors"
                    ][executor]["container"]
                    if "env" not in container:
                        container["env"] = []
                    for key, value in environment.items():
                        container["env"].append({"name": key, "value": value})

            yaml_utils.write_json(yaml_file_path, pipeline_definition)

            print(
                f"Updated YAML file with environment variables at {yaml_file_path}"
            )

        # Save the generated pipeline to a file.
        fileio.makedirs(self.pipeline_directory)
        pipeline_file_path = os.path.join(
            self.pipeline_directory,
            f"{orchestrator_run_name}.json",
        )

        # Compile the pipeline using the Kubeflow SDK v2 compiler, which
        # generates a JSON representation of the pipeline that can later be
        # uploaded to the Vertex AI Pipelines service.
        Compiler().compile(
            pipeline_func=_create_dynamic_pipeline(),
            package_path=pipeline_file_path,
            pipeline_name=_clean_pipeline_name(
                deployment.pipeline_configuration.name
            ),
        )

        # Let's update the YAML file with the environment variables
        _update_json_with_environment(pipeline_file_path, environment)

        logger.info(
            "Writing Vertex workflow definition to `%s`.", pipeline_file_path
        )

        settings = cast(
            VertexOrchestratorSettings, self.get_settings(deployment)
        )

        # Using the Google Cloud AIPlatform client, upload and execute the
        # pipeline on the Vertex AI Pipelines service.
        if metadata := self._upload_and_run_pipeline(
            pipeline_name=deployment.pipeline_configuration.name,
            pipeline_file_path=pipeline_file_path,
            run_name=orchestrator_run_name,
            settings=settings,
            schedule=deployment.schedule,
        ):
            yield from metadata

    def _upload_and_run_pipeline(
        self,
        pipeline_name: str,
        pipeline_file_path: str,
        run_name: str,
        settings: VertexOrchestratorSettings,
        schedule: Optional["ScheduleResponse"] = None,
    ) -> Iterator[Dict[str, MetadataType]]:
        """Uploads and run the pipeline on the Vertex AI Pipelines service.

        Args:
            pipeline_name: Name of the pipeline.
            pipeline_file_path: Path of the JSON file containing the compiled
                Kubeflow pipeline (compiled with Kubeflow SDK v2).
            run_name: Orchestrator run name.
            settings: Pipeline level settings for this orchestrator.
            schedule: The schedule the pipeline will run on.

        Raises:
            RuntimeError: If the Vertex Orchestrator fails to provision the
                job or if any other runtime error occurs.

        Yields:
            A dictionary of metadata related to the pipeline run.
        """
        # We have to replace the hyphens in the run name with underscores
        # and lower case the string, because the Vertex AI Pipelines service
        # requires this format.
        job_id = _clean_pipeline_name(run_name)

        # Get the credentials that would be used to create the Vertex AI
        # Pipelines job.
        credentials, project_id = self._get_authentication()

        # Instantiate the Vertex AI Pipelines job
        run = aiplatform.PipelineJob(
            display_name=pipeline_name,
            template_path=pipeline_file_path,
            job_id=job_id,
            pipeline_root=self._pipeline_root,
            parameter_values=None,
            enable_caching=False,
            encryption_spec_key_name=self.config.encryption_spec_key_name,
            labels=settings.labels,
            credentials=credentials,
            project=project_id,
            location=self.config.location,
        )

        if self.config.workload_service_account:
            logger.info(
                "The Vertex AI Pipelines job workload will be executed "
                "using the `%s` "
                "service account.",
                self.config.workload_service_account,
            )
        if self.config.network:
            logger.info(
                "The Vertex AI Pipelines job will be peered with the `%s` "
                "network.",
                self.config.network,
            )

        try:
            if schedule:
                logger.info(
                    "Scheduling job using native Vertex AI Pipelines "
                    "scheduling..."
                )
                run.create_schedule(
                    display_name=schedule.name,
                    cron=schedule.cron_expression,
                    start_time=schedule.utc_start_time,
                    end_time=schedule.utc_end_time,
                    service_account=self.config.workload_service_account,
                    network=self.config.network,
                )

            else:
                logger.info(
                    "No schedule detected. Creating one-off Vertex job..."
                )
                logger.info(
                    "Submitting pipeline job with job_id `%s` to Vertex AI "
                    "Pipelines service.",
                    job_id,
                )

                # Submit the job to Vertex AI Pipelines service.
                run.submit(
                    service_account=self.config.workload_service_account,
                    network=self.config.network,
                )
                logger.info(
                    "View the Vertex AI Pipelines job at %s",
                    run._dashboard_uri(),
                )

                # Yield metadata based on the generated job object
                yield from self.compute_metadata(run)

                if settings.synchronous:
                    logger.info(
                        "Waiting for the Vertex AI Pipelines job to finish..."
                    )
                    run.wait()

        except google_exceptions.ClientError as e:
            logger.error("Failed to create the Vertex AI Pipelines job: %s", e)
            raise RuntimeError(
                f"Failed to create the Vertex AI Pipelines job: {e}"
            )
        except RuntimeError as e:
            logger.error(
                "The Vertex AI Pipelines job execution has failed: %s", e
            )
            raise

    def get_orchestrator_run_id(self) -> str:
        """Returns the active orchestrator run id.

        Raises:
            RuntimeError: If the environment variable specifying the run id
                is not set.

        Returns:
            The orchestrator run id.
        """
        try:
            return os.environ[ENV_ZENML_VERTEX_RUN_ID]
        except KeyError:
            raise RuntimeError(
                "Unable to read run id from environment variable "
                f"{ENV_ZENML_VERTEX_RUN_ID}."
            )

    def get_pipeline_run_metadata(
        self, run_id: UUID
    ) -> Dict[str, "MetadataType"]:
        """Get general component-specific metadata for a pipeline run.

        Args:
            run_id: The ID of the pipeline run.

        Returns:
            A dictionary of metadata.
        """
        run_url = (
            f"https://console.cloud.google.com/vertex-ai/locations/"
            f"{self.config.location}/pipelines/runs/"
            f"{self.get_orchestrator_run_id()}"
        )
        if self.config.project:
            run_url += f"?project={self.config.project}"
        return {
            METADATA_ORCHESTRATOR_URL: Uri(run_url),
        }

    def _configure_container_resources(
        self,
        dynamic_component: dsl.PipelineTask,
        resource_settings: "ResourceSettings",
        node_selector_constraint: Optional[Tuple[str, str]] = None,
    ) -> dsl.PipelineTask:
        """Adds resource requirements to the container.

        Args:
            dynamic_component: The dynamic component to add the resource
                settings to.
            resource_settings: The resource settings to use for this
                container.
            node_selector_constraint: Node selector constraint to apply to
                the container.

        Returns:
            The dynamic component with the resource settings applied.
        """
        # Set optional CPU, RAM and GPU constraints for the pipeline
        cpu_limit = None
        if resource_settings:
            cpu_limit = resource_settings.cpu_count or self.config.cpu_limit

        if cpu_limit is not None:
            dynamic_component = dynamic_component.set_cpu_limit(str(cpu_limit))

        memory_limit = (
            resource_settings.memory[:-1]
            if resource_settings.memory
            else self.config.memory_limit
        )
        if memory_limit is not None:
            dynamic_component = dynamic_component.set_memory_limit(
                memory_limit
            )

        gpu_limit = (
            resource_settings.gpu_count
            if resource_settings.gpu_count is not None
            else self.config.gpu_limit
        )

        if node_selector_constraint:
            _, value = node_selector_constraint
            if gpu_limit is not None and gpu_limit > 0:
                dynamic_component = (
                    dynamic_component.set_accelerator_type(value)
                    .set_accelerator_limit(gpu_limit)
                    .set_gpu_limit(gpu_limit)
                )
            else:
                logger.warning(
                    "Accelerator type %s specified, but the GPU limit is not "
                    "set or set to 0. The accelerator type will be ignored. "
                    "To fix this warning, either remove the specified "
                    "accelerator type or set the `gpu_count` using the "
                    "ResourceSettings (https://docs.zenml.io/how-to/training-with-gpus#specify-resource-requirements-for-steps)."
                )

        return dynamic_component

    def fetch_status(self, run: "PipelineRunResponse") -> ExecutionStatus:
        """Refreshes the status of a specific pipeline run.

        Args:
            run: The run that was executed by this orchestrator.

        Returns:
            the actual status of the pipeline job.

        Raises:
            AssertionError: If the run was not executed by this orchestrator.
            ValueError: If it fetches an unknown state or if we cannot fetch
                the orchestrator run ID.
        """
        # Make sure that the stack exists and is accessible
        if run.stack is None:
            raise ValueError(
                "The stack that the run was executed on is not available "
                "anymore."
            )

        # Make sure that the run belongs to this orchestrator
        assert (
            self.id
            == run.stack.components[StackComponentType.ORCHESTRATOR][0].id
        )

        # Initialize the Vertex client
        credentials, project_id = self._get_authentication()
        aiplatform.init(
            project=project_id,
            location=self.config.location,
            credentials=credentials,
        )

        # Fetch the status of the PipelineJob
        if METADATA_ORCHESTRATOR_RUN_ID in run.run_metadata:
            run_id = run.run_metadata[METADATA_ORCHESTRATOR_RUN_ID].value
        elif run.orchestrator_run_id is not None:
            run_id = run.orchestrator_run_id
        else:
            raise ValueError(
                "Can not find the orchestrator run ID, thus can not fetch "
                "the status."
            )
        status = aiplatform.PipelineJob.get(run_id).state

        # Map the Vertex AI PipelineState values to ZenML ExecutionStatus. All
        # states are defined in google.cloud.aiplatform_v1.types.PipelineState.
        if status in [PipelineState.PIPELINE_STATE_UNSPECIFIED]:
            return run.status
        elif status in [
            PipelineState.PIPELINE_STATE_QUEUED,
            PipelineState.PIPELINE_STATE_PENDING,
        ]:
            return ExecutionStatus.INITIALIZING
        elif status in [
            PipelineState.PIPELINE_STATE_RUNNING,
            PipelineState.PIPELINE_STATE_PAUSED,
        ]:
            return ExecutionStatus.RUNNING
        elif status in [PipelineState.PIPELINE_STATE_SUCCEEDED]:
            return ExecutionStatus.COMPLETED

        elif status in [
            PipelineState.PIPELINE_STATE_FAILED,
            PipelineState.PIPELINE_STATE_CANCELLING,
            PipelineState.PIPELINE_STATE_CANCELLED,
        ]:
            return ExecutionStatus.FAILED
        else:
            raise ValueError("Unknown status for the pipeline job.")

    def compute_metadata(
        self, job: aiplatform.PipelineJob
    ) -> Iterator[Dict[str, MetadataType]]:
        """Generate run metadata based on the corresponding Vertex PipelineJob.

        Args:
            job: The corresponding PipelineJob object.

        Yields:
            A dictionary of metadata related to the pipeline run.
        """
        metadata: Dict[str, MetadataType] = {}

        # Orchestrator Run ID
        if run_id := self._compute_orchestrator_run_id(job):
            metadata[METADATA_ORCHESTRATOR_RUN_ID] = run_id

        # URL to the Vertex's pipeline view
        if orchestrator_url := self._compute_orchestrator_url(job):
            metadata[METADATA_ORCHESTRATOR_URL] = Uri(orchestrator_url)

        # URL to the corresponding Logs Explorer page
        if logs_url := self._compute_orchestrator_logs_url(job):
            metadata[METADATA_ORCHESTRATOR_LOGS_URL] = Uri(logs_url)

        yield metadata

    @staticmethod
    def _compute_orchestrator_url(
        job: aiplatform.PipelineJob,
    ) -> Optional[str]:
        """Generate the Orchestrator Dashboard URL upon pipeline execution.

        Args:
            job: The corresponding PipelineJob object.

        Returns:
             the URL to the dashboard view in Vertex.
        """
        try:
            return str(job._dashboard_uri())
        except Exception as e:
            logger.warning(
                f"There was an issue while extracting the pipeline url: {e}"
            )
            return None

    @staticmethod
    def _compute_orchestrator_logs_url(
        job: aiplatform.PipelineJob,
    ) -> Optional[str]:
        """Generate the Logs Explorer URL upon pipeline execution.

        Args:
            job: The corresponding PipelineJob object.

        Returns:
            the URL querying the pipeline logs in Logs Explorer on GCP.
        """
        try:
            base_url = "https://console.cloud.google.com/logs/query"
            query = f"""
             resource.type="aiplatform.googleapis.com/PipelineJob"
             resource.labels.pipeline_job_id="{job.job_id}"
             """
            encoded_query = urllib.parse.quote(query)
            return f"{base_url}?project={job.project}&query={encoded_query}"

        except Exception as e:
            logger.warning(
                f"There was an issue while extracting the logs url: {e}"
            )
            return None

    @staticmethod
    def _compute_orchestrator_run_id(
        job: aiplatform.PipelineJob,
    ) -> Optional[str]:
        """Fetch the Orchestrator Run ID upon pipeline execution.

        Args:
            job: The corresponding PipelineJob object.

        Returns:
            the Execution ID of the run in Vertex.
        """
        try:
            if job.job_id:
                return str(job.job_id)

            return None
        except Exception as e:
            logger.warning(
                f"There was an issue while extracting the pipeline run ID: {e}"
            )
            return None
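
To show how the orchestrator is typically driven from user code, here is a hedged sketch of attaching Vertex-specific settings to a pipeline and resource requests to a step. The settings keys and the GKE accelerator label are assumptions that may vary with your ZenML version.

from zenml import pipeline, step
from zenml.config import ResourceSettings
from zenml.integrations.gcp.flavors.vertex_orchestrator_flavor import (
    VertexOrchestratorSettings,
)

# Sketch only: keys such as "orchestrator" may need to be flavor-scoped
# (e.g. "orchestrator.vertex") depending on the ZenML version in use.
vertex_settings = VertexOrchestratorSettings(
    synchronous=False,
    labels={"team": "ml"},
    node_selector_constraint=(
        "cloud.google.com/gke-accelerator",  # GKE accelerator node label
        "NVIDIA_TESLA_T4",
    ),
)


@step(
    settings={
        "resources": ResourceSettings(cpu_count=4, memory="16GB", gpu_count=1),
    }
)
def train() -> None:
    ...


@pipeline(settings={"orchestrator": vertex_settings})
def training_pipeline() -> None:
    train()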
config: VertexOrchestratorConfig property readonly

Returns the VertexOrchestratorConfig config.

Returns:

Type Description
VertexOrchestratorConfig

The configuration.

pipeline_directory: str property readonly

Returns path to directory where kubeflow pipelines files are stored.

Returns:

Type Description
str

Path to the pipeline directory.

root_directory: str property readonly

Returns path to the root directory for files for this orchestrator.

Returns:

Type Description
str

The path to the root directory for all files concerning this orchestrator.

settings_class: Optional[Type[BaseSettings]] property readonly

Settings class for the Vertex orchestrator.

Returns:

Type Description
Optional[Type[BaseSettings]]

The settings class.

validator: Optional[zenml.stack.stack_validator.StackValidator] property readonly

Validates that the stack contains a container registry.

Also validates that the artifact store is not local.

Returns:

Type Description
Optional[zenml.stack.stack_validator.StackValidator]

A StackValidator instance.

compute_metadata(self, job)

Generate run metadata based on the corresponding Vertex PipelineJob.

Parameters:

Name Type Description Default
job google.cloud.aiplatform.PipelineJob

The corresponding PipelineJob object.

required

Yields:

Type Description
Iterator[Dict[str, Union[str, int, float, bool, Dict[Any, Any], List[Any], Set[Any], Tuple[Any, ...], zenml.metadata.metadata_types.Uri, zenml.metadata.metadata_types.Path, zenml.metadata.metadata_types.DType, zenml.metadata.metadata_types.StorageSize]]]

A dictionary of metadata related to the pipeline run.

Source code in zenml/integrations/gcp/orchestrators/vertex_orchestrator.py
def compute_metadata(
    self, job: aiplatform.PipelineJob
) -> Iterator[Dict[str, MetadataType]]:
    """Generate run metadata based on the corresponding Vertex PipelineJob.

    Args:
        job: The corresponding PipelineJob object.

    Yields:
        A dictionary of metadata related to the pipeline run.
    """
    metadata: Dict[str, MetadataType] = {}

    # Orchestrator Run ID
    if run_id := self._compute_orchestrator_run_id(job):
        metadata[METADATA_ORCHESTRATOR_RUN_ID] = run_id

    # URL to the Vertex's pipeline view
    if orchestrator_url := self._compute_orchestrator_url(job):
        metadata[METADATA_ORCHESTRATOR_URL] = Uri(orchestrator_url)

    # URL to the corresponding Logs Explorer page
    if logs_url := self._compute_orchestrator_logs_url(job):
        metadata[METADATA_ORCHESTRATOR_LOGS_URL] = Uri(logs_url)

    yield metadata
fetch_status(self, run)

Refreshes the status of a specific pipeline run.

Parameters:

Name Type Description Default
run PipelineRunResponse

The run that was executed by this orchestrator.

required

Returns:

Type Description
ExecutionStatus

the actual status of the pipeline job.

Exceptions:

Type Description
AssertionError

If the run was not executed by this orchestrator.

ValueError

If it fetches an unknown state or if the orchestrator run ID cannot be fetched.

Source code in zenml/integrations/gcp/orchestrators/vertex_orchestrator.py
def fetch_status(self, run: "PipelineRunResponse") -> ExecutionStatus:
    """Refreshes the status of a specific pipeline run.

    Args:
        run: The run that was executed by this orchestrator.

    Returns:
        the actual status of the pipeline job.

    Raises:
        AssertionError: If the run was not executed by this orchestrator.
        ValueError: If it fetches an unknown state or if the orchestrator
            run ID cannot be fetched.
    """
    # Make sure that the stack exists and is accessible
    if run.stack is None:
        raise ValueError(
            "The stack that the run was executed on is not available "
            "anymore."
        )

    # Make sure that the run belongs to this orchestrator
    assert (
        self.id
        == run.stack.components[StackComponentType.ORCHESTRATOR][0].id
    )

    # Initialize the Vertex client
    credentials, project_id = self._get_authentication()
    aiplatform.init(
        project=project_id,
        location=self.config.location,
        credentials=credentials,
    )

    # Fetch the status of the PipelineJob
    if METADATA_ORCHESTRATOR_RUN_ID in run.run_metadata:
        run_id = run.run_metadata[METADATA_ORCHESTRATOR_RUN_ID].value
    elif run.orchestrator_run_id is not None:
        run_id = run.orchestrator_run_id
    else:
        raise ValueError(
            "Can not find the orchestrator run ID, thus can not fetch "
            "the status."
        )
    status = aiplatform.PipelineJob.get(run_id).state

    # Map the Vertex AI `PipelineState` values to the ZenML `ExecutionStatus`.
    if status in [PipelineState.PIPELINE_STATE_UNSPECIFIED]:
        return run.status
    elif status in [
        PipelineState.PIPELINE_STATE_QUEUED,
        PipelineState.PIPELINE_STATE_PENDING,
    ]:
        return ExecutionStatus.INITIALIZING
    elif status in [
        PipelineState.PIPELINE_STATE_RUNNING,
        PipelineState.PIPELINE_STATE_PAUSED,
    ]:
        return ExecutionStatus.RUNNING
    elif status in [PipelineState.PIPELINE_STATE_SUCCEEDED]:
        return ExecutionStatus.COMPLETED

    elif status in [
        PipelineState.PIPELINE_STATE_FAILED,
        PipelineState.PIPELINE_STATE_CANCELLING,
        PipelineState.PIPELINE_STATE_CANCELLED,
    ]:
        return ExecutionStatus.FAILED
    else:
        raise ValueError("Unknown status for the pipeline job.")
get_orchestrator_run_id(self)

Returns the active orchestrator run id.

Exceptions:

Type Description
RuntimeError

If the environment variable specifying the run id is not set.

Returns:

Type Description
str

The orchestrator run id.

Source code in zenml/integrations/gcp/orchestrators/vertex_orchestrator.py
def get_orchestrator_run_id(self) -> str:
    """Returns the active orchestrator run id.

    Raises:
        RuntimeError: If the environment variable specifying the run id
            is not set.

    Returns:
        The orchestrator run id.
    """
    try:
        return os.environ[ENV_ZENML_VERTEX_RUN_ID]
    except KeyError:
        raise RuntimeError(
            "Unable to read run id from environment variable "
            f"{ENV_ZENML_VERTEX_RUN_ID}."
        )
get_pipeline_run_metadata(self, run_id)

Get general component-specific metadata for a pipeline run.

Parameters:

Name Type Description Default
run_id UUID

The ID of the pipeline run.

required

Returns:

Type Description
Dict[str, MetadataType]

A dictionary of metadata.

Source code in zenml/integrations/gcp/orchestrators/vertex_orchestrator.py
def get_pipeline_run_metadata(
    self, run_id: UUID
) -> Dict[str, "MetadataType"]:
    """Get general component-specific metadata for a pipeline run.

    Args:
        run_id: The ID of the pipeline run.

    Returns:
        A dictionary of metadata.
    """
    run_url = (
        f"https://console.cloud.google.com/vertex-ai/locations/"
        f"{self.config.location}/pipelines/runs/"
        f"{self.get_orchestrator_run_id()}"
    )
    if self.config.project:
        run_url += f"?project={self.config.project}"
    return {
        METADATA_ORCHESTRATOR_URL: Uri(run_url),
    }
prepare_or_run_pipeline(self, deployment, stack, environment)

Creates a KFP JSON pipeline.

This is an intermediary representation of the pipeline which is then deployed to the Vertex AI Pipelines service.

How it works:

Before this method is called, the prepare_pipeline_deployment() method builds a Docker image that contains the code for the pipeline, all steps, and the context around these files.

Based on this Docker image, a callable is created which builds container_ops for each step (_construct_kfp_pipeline). The function kfp.components.load_component_from_text is used to create the ContainerOp, because using the dsl.ContainerOp class directly is deprecated with the Kubeflow SDK v2. The step entrypoint command together with the entrypoint arguments is the command that will be executed by the container created from the previously built Docker image.

This callable is then compiled into a JSON file that is used as the intermediary representation of the Kubeflow pipeline.

This file is then submitted to the Vertex AI Pipelines service for execution.

Parameters:

Name Type Description Default
deployment PipelineDeploymentResponse

The pipeline deployment to prepare or run.

required
stack Stack

The stack the pipeline will run on.

required
environment Dict[str, str]

Environment variables to set in the orchestration environment.

required

Exceptions:

Type Description
ValueError

If the attribute pipeline_root is not set and it cannot be generated using the path of the artifact store in the stack because it is not a zenml.integrations.gcp.artifact_store.GCPArtifactStore. Also raised if attempting to schedule a pipeline run without using the zenml.integrations.gcp.artifact_store.GCPArtifactStore.

Yields:

Type Description
Iterator[Dict[str, Union[str, int, float, bool, Dict[Any, Any], List[Any], Set[Any], Tuple[Any, ...], zenml.metadata.metadata_types.Uri, zenml.metadata.metadata_types.Path, zenml.metadata.metadata_types.DType, zenml.metadata.metadata_types.StorageSize]]]

A dictionary of metadata related to the pipeline run.

Source code in zenml/integrations/gcp/orchestrators/vertex_orchestrator.py
def prepare_or_run_pipeline(
    self,
    deployment: "PipelineDeploymentResponse",
    stack: "Stack",
    environment: Dict[str, str],
) -> Iterator[Dict[str, MetadataType]]:
    """Creates a KFP JSON pipeline.

    # noqa: DAR402

    This is an intermediary representation of the pipeline which is then
    deployed to the Vertex AI Pipelines service.

    How it works:
    -------------
    Before this method is called, the `prepare_pipeline_deployment()` method
    builds a Docker image that contains the code for the pipeline, all steps,
    and the context around these files.

    Based on this Docker image, a callable is created which builds
    container_ops for each step (`_construct_kfp_pipeline`). The function
    `kfp.components.load_component_from_text` is used to create the
    `ContainerOp`, because using the `dsl.ContainerOp` class directly is
    deprecated with the Kubeflow SDK v2. The step entrypoint command
    together with the entrypoint arguments is the command that will be
    executed by the container created from the previously built Docker image.

    This callable is then compiled into a JSON file that is used as the
    intermediary representation of the Kubeflow pipeline.

    This file is then submitted to the Vertex AI Pipelines service for
    execution.

    Args:
        deployment: The pipeline deployment to prepare or run.
        stack: The stack the pipeline will run on.
        environment: Environment variables to set in the orchestration
            environment.

    Raises:
        ValueError: If the attribute `pipeline_root` is not set and it
            cannot be generated using the path of the artifact store in the
            stack because it is not a
            `zenml.integrations.gcp.artifact_store.GCPArtifactStore`. Also
            raised if attempting to schedule a pipeline run without using
            the `zenml.integrations.gcp.artifact_store.GCPArtifactStore`.

    Yields:
        A dictionary of metadata related to the pipeline run.
    """
    orchestrator_run_name = get_orchestrator_run_name(
        pipeline_name=deployment.pipeline_configuration.name
    )
    # If the `pipeline_root` has not been defined in the orchestrator
    # configuration, try to create it from the artifact store if it is a
    # `GCPArtifactStore`.
    if not self.config.pipeline_root:
        artifact_store = stack.artifact_store
        self._pipeline_root = f"{artifact_store.path.rstrip('/')}/vertex_pipeline_root/{deployment.pipeline_configuration.name}/{orchestrator_run_name}"
        logger.info(
            "The attribute `pipeline_root` has not been set in the "
            "orchestrator configuration. One has been generated "
            "automatically based on the path of the `GCPArtifactStore` "
            "artifact store in the stack used to execute the pipeline. "
            "The generated `pipeline_root` is `%s`.",
            self._pipeline_root,
        )
    else:
        self._pipeline_root = self.config.pipeline_root

    def _create_dynamic_pipeline() -> Any:
        """Create a dynamic pipeline including each step.

        Returns:
            pipeline_func
        """
        step_name_to_dynamic_component: Dict[str, Any] = {}

        for step_name, step in deployment.step_configurations.items():
            image = self.get_image(
                deployment=deployment,
                step_name=step_name,
            )
            command = StepEntrypointConfiguration.get_entrypoint_command()
            arguments = (
                StepEntrypointConfiguration.get_entrypoint_arguments(
                    step_name=step_name,
                    deployment_id=deployment.id,
                )
            )
            dynamic_component = self._create_dynamic_component(
                image, command, arguments, step_name
            )
            step_settings = cast(
                VertexOrchestratorSettings, self.get_settings(step)
            )
            pod_settings = step_settings.pod_settings
            if pod_settings:
                if pod_settings.host_ipc:
                    logger.warning(
                        "Host IPC is set to `True` but not supported in "
                        "this orchestrator. Ignoring..."
                    )
                if pod_settings.affinity:
                    logger.warning(
                        "Affinity is set but not supported in Vertex with "
                        "Kubeflow Pipelines 2.x. Ignoring..."
                    )
                if pod_settings.tolerations:
                    logger.warning(
                        "Tolerations are set but not supported in "
                        "Vertex with Kubeflow Pipelines 2.x. Ignoring..."
                    )
                if pod_settings.volumes:
                    logger.warning(
                        "Volumes are set but not supported in Vertex with "
                        "Kubeflow Pipelines 2.x. Ignoring..."
                    )
                if pod_settings.volume_mounts:
                    logger.warning(
                        "Volume mounts are set but not supported in "
                        "Vertex with Kubeflow Pipelines 2.x. Ignoring..."
                    )
                for key in pod_settings.node_selectors:
                    if (
                        key
                        != GKE_ACCELERATOR_NODE_SELECTOR_CONSTRAINT_LABEL
                    ):
                        logger.warning(
                            "Vertex only allows the %s node selector, "
                            "ignoring the node selector %s.",
                            GKE_ACCELERATOR_NODE_SELECTOR_CONSTRAINT_LABEL,
                            key,
                        )

            step_name_to_dynamic_component[step_name] = dynamic_component

        @dsl.pipeline(  # type: ignore[misc]
            display_name=orchestrator_run_name,
        )
        def dynamic_pipeline() -> None:
            """Dynamic pipeline."""
            # iterate through the components one by one
            # (from step_name_to_dynamic_component)
            for (
                component_name,
                component,
            ) in step_name_to_dynamic_component.items():
                # for each component, check to see what other steps are
                # upstream of it
                step = deployment.step_configurations[component_name]
                upstream_step_components = [
                    step_name_to_dynamic_component[upstream_step_name]
                    for upstream_step_name in step.spec.upstream_steps
                ]
                task = (
                    component()
                    .set_display_name(
                        name=component_name,
                    )
                    .set_caching_options(enable_caching=False)
                    .set_env_variable(
                        name=ENV_ZENML_VERTEX_RUN_ID,
                        value=dsl.PIPELINE_JOB_NAME_PLACEHOLDER,
                    )
                    .after(*upstream_step_components)
                )

                step_settings = cast(
                    VertexOrchestratorSettings, self.get_settings(step)
                )
                pod_settings = step_settings.pod_settings

                node_selector_constraint: Optional[Tuple[str, str]] = None
                if pod_settings and (
                    GKE_ACCELERATOR_NODE_SELECTOR_CONSTRAINT_LABEL
                    in pod_settings.node_selectors.keys()
                ):
                    node_selector_constraint = (
                        GKE_ACCELERATOR_NODE_SELECTOR_CONSTRAINT_LABEL,
                        pod_settings.node_selectors[
                            GKE_ACCELERATOR_NODE_SELECTOR_CONSTRAINT_LABEL
                        ],
                    )
                elif step_settings.node_selector_constraint:
                    node_selector_constraint = (
                        GKE_ACCELERATOR_NODE_SELECTOR_CONSTRAINT_LABEL,
                        step_settings.node_selector_constraint[1],
                    )

                self._configure_container_resources(
                    dynamic_component=task,
                    resource_settings=step.config.resource_settings,
                    node_selector_constraint=node_selector_constraint,
                )

        return dynamic_pipeline

    def _update_json_with_environment(
        yaml_file_path: str, environment: Dict[str, str]
    ) -> None:
        """Updates the env section of the steps in the YAML file with the given environment variables.

        Args:
            yaml_file_path: The path to the YAML file to update.
            environment: A dictionary of environment variables to add.
        """
        pipeline_definition = yaml_utils.read_json(yaml_file_path)

        # Iterate through each component and add the environment variables
        for executor in pipeline_definition["deploymentSpec"]["executors"]:
            if (
                "container"
                in pipeline_definition["deploymentSpec"]["executors"][
                    executor
                ]
            ):
                container = pipeline_definition["deploymentSpec"][
                    "executors"
                ][executor]["container"]
                if "env" not in container:
                    container["env"] = []
                for key, value in environment.items():
                    container["env"].append({"name": key, "value": value})

        yaml_utils.write_json(yaml_file_path, pipeline_definition)

        print(
            f"Updated YAML file with environment variables at {yaml_file_path}"
        )

    # Save the generated pipeline to a file.
    fileio.makedirs(self.pipeline_directory)
    pipeline_file_path = os.path.join(
        self.pipeline_directory,
        f"{orchestrator_run_name}.json",
    )

    # Compile the pipeline using the Kubeflow SDK v2 compiler, which
    # generates a JSON representation of the pipeline that can later be
    # uploaded to the Vertex AI Pipelines service.
    Compiler().compile(
        pipeline_func=_create_dynamic_pipeline(),
        package_path=pipeline_file_path,
        pipeline_name=_clean_pipeline_name(
            deployment.pipeline_configuration.name
        ),
    )

    # Update the compiled JSON pipeline definition with the environment variables
    _update_json_with_environment(pipeline_file_path, environment)

    logger.info(
        "Writing Vertex workflow definition to `%s`.", pipeline_file_path
    )

    settings = cast(
        VertexOrchestratorSettings, self.get_settings(deployment)
    )

    # Using the Google Cloud AIPlatform client, upload and execute the
    # pipeline on the Vertex AI Pipelines service.
    if metadata := self._upload_and_run_pipeline(
        pipeline_name=deployment.pipeline_configuration.name,
        pipeline_file_path=pipeline_file_path,
        run_name=orchestrator_run_name,
        settings=settings,
        schedule=deployment.schedule,
    ):
        yield from metadata
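
As a usage note (not part of the listing above): the pod settings and node selector handling shown here come from step- or pipeline-level VertexOrchestratorSettings. A minimal, hedged sketch, assuming the settings class lives in the GCP flavors module and that the settings key follows the usual <component>.<flavor> convention; the accelerator type is a placeholder:

from zenml import pipeline, step
from zenml.integrations.gcp.flavors.vertex_orchestrator_flavor import (
    VertexOrchestratorSettings,
)

# Vertex only honors the GKE accelerator node selector; affinity, tolerations,
# volumes and other pod settings are ignored with a warning (see above).
vertex_settings = VertexOrchestratorSettings(
    node_selector_constraint=(
        "cloud.google.com/gke-accelerator",
        "NVIDIA_TESLA_T4",  # placeholder accelerator type
    ),
)

@step(settings={"orchestrator.vertex": vertex_settings})
def train() -> None:
    """Placeholder training step."""

@pipeline(settings={"orchestrator.vertex": vertex_settings})
def my_pipeline() -> None:
    train()
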
prepare_pipeline_deployment(self, deployment, stack)

Build a Docker image and push it to the container registry.

Parameters:

Name Type Description Default
deployment PipelineDeploymentResponse

The pipeline deployment configuration.

required
stack Stack

The stack on which the pipeline will be deployed.

required

Exceptions:

Type Description
ValueError

If cron_expression is not set in the passed Schedule.

Source code in zenml/integrations/gcp/orchestrators/vertex_orchestrator.py
def prepare_pipeline_deployment(
    self,
    deployment: "PipelineDeploymentResponse",
    stack: "Stack",
) -> None:
    """Build a Docker image and push it to the container registry.

    Args:
        deployment: The pipeline deployment configuration.
        stack: The stack on which the pipeline will be deployed.

    Raises:
        ValueError: If `cron_expression` is not set in the passed Schedule.
    """
    if deployment.schedule:
        if (
            deployment.schedule.catchup
            or deployment.schedule.interval_second
        ):
            logger.warning(
                "Vertex orchestrator only uses schedules with the "
                "`cron_expression` property, with optional `start_time` "
                "and/or `end_time`. All other properties are ignored."
            )
        if deployment.schedule.cron_expression is None:
            raise ValueError(
                "Property `cron_expression` must be set when passing "
                "schedule to a Vertex orchestrator."
            )

service_connectors special

ZenML GCP Service Connector.

gcp_service_connector

GCP Service Connector.

The GCP Service Connector implements various authentication methods for GCP services:

  • Implicit authentication (Application Default Credentials)
  • GCP user account credentials
  • Explicit GCP service account key
  • GCP external account credentials (workload identity federation)
  • Short-lived OAuth 2.0 tokens
  • GCP service account impersonation
GCPAuthenticationMethods (StrEnum)

GCP Authentication methods.

Source code in zenml/integrations/gcp/service_connectors/gcp_service_connector.py
class GCPAuthenticationMethods(StrEnum):
    """GCP Authentication methods."""

    IMPLICIT = "implicit"
    USER_ACCOUNT = "user-account"
    SERVICE_ACCOUNT = "service-account"
    EXTERNAL_ACCOUNT = "external-account"
    OAUTH2_TOKEN = "oauth2-token"
    IMPERSONATION = "impersonation"
GCPBaseConfig (AuthenticationConfig)

GCP base configuration.

Source code in zenml/integrations/gcp/service_connectors/gcp_service_connector.py
class GCPBaseConfig(AuthenticationConfig):
    """GCP base configuration."""

    @property
    def gcp_project_id(self) -> str:
        """Get the GCP project ID.

        This method must be implemented by subclasses to ensure that the GCP
        project ID is always available.

        Raises:
            NotImplementedError: If the method is not implemented.
        """
        raise NotImplementedError
gcp_project_id: str property readonly

Get the GCP project ID.

This method must be implemented by subclasses to ensure that the GCP project ID is always available.

Exceptions:

Type Description
NotImplementedError

If the method is not implemented.

GCPBaseProjectIDConfig (GCPBaseConfig)

GCP base configuration with included project ID.

Source code in zenml/integrations/gcp/service_connectors/gcp_service_connector.py
class GCPBaseProjectIDConfig(GCPBaseConfig):
    """GCP base configuration with included project ID."""

    project_id: str = Field(
        title="GCP Project ID where the target resource is located.",
    )

    @property
    def gcp_project_id(self) -> str:
        """Get the GCP project ID.

        Returns:
            The GCP project ID.
        """
        return self.project_id
gcp_project_id: str property readonly

Get the GCP project ID.

Returns:

Type Description
str

The GCP project ID.

GCPExternalAccountConfig (GCPBaseProjectIDConfig, GCPExternalAccountCredentials)

GCP external account configuration.

Source code in zenml/integrations/gcp/service_connectors/gcp_service_connector.py
class GCPExternalAccountConfig(
    GCPBaseProjectIDConfig, GCPExternalAccountCredentials
):
    """GCP external account configuration."""
GCPExternalAccountCredentials (AuthenticationConfig)

GCP external account credentials.

Source code in zenml/integrations/gcp/service_connectors/gcp_service_connector.py
class GCPExternalAccountCredentials(AuthenticationConfig):
    """GCP external account credentials."""

    external_account_json: PlainSerializedSecretStr = Field(
        title="GCP External Account JSON optionally base64 encoded.",
    )

    generate_temporary_tokens: bool = Field(
        default=True,
        title="Generate temporary OAuth 2.0 tokens",
        description="Whether to generate temporary OAuth 2.0 tokens from the "
        "external account key JSON. If set to False, the connector will "
        "distribute the external account JSON to clients instead.",
    )

    @model_validator(mode="before")
    @classmethod
    @before_validator_handler
    def validate_service_account_dict(
        cls, data: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Convert the external account credentials to JSON if given in dict format.

        Args:
            data: The configuration values.

        Returns:
            The validated configuration values.

        Raises:
            ValueError: If the external account credentials JSON is invalid.
        """
        external_account_json = data.get("external_account_json")
        if isinstance(external_account_json, dict):
            data["external_account_json"] = json.dumps(
                data["external_account_json"]
            )
        elif isinstance(external_account_json, str):
            # Check if the external account JSON is base64 encoded and decode it
            if re.match(r"^[A-Za-z0-9+/=]+$", external_account_json):
                try:
                    data["external_account_json"] = base64.b64decode(
                        external_account_json
                    ).decode("utf-8")
                except Exception as e:
                    raise ValueError(
                        f"Failed to decode base64 encoded external account JSON: {e}"
                    )

        return data

    @field_validator("external_account_json")
    @classmethod
    def validate_external_account_json(
        cls, value: PlainSerializedSecretStr
    ) -> PlainSerializedSecretStr:
        """Validate the external account credentials JSON.

        Args:
            value: The external account credentials JSON.

        Returns:
            The validated external account credentials JSON.

        Raises:
            ValueError: If the external account credentials JSON is invalid.
        """
        try:
            external_account_info = json.loads(value.get_secret_value())
        except json.JSONDecodeError as e:
            raise ValueError(
                f"GCP external account credentials is not a valid JSON: {e}"
            )

        # Check that all fields are present
        required_fields = [
            "type",
            "subject_token_type",
            "token_url",
        ]
        # Compute missing fields
        missing_fields = set(required_fields) - set(
            external_account_info.keys()
        )
        if missing_fields:
            raise ValueError(
                f"GCP external account credentials JSON is missing required "
                f'fields: {", ".join(list(missing_fields))}'
            )

        if external_account_info["type"] != "external_account":
            raise ValueError(
                "The JSON does not contain GCP external account credentials. "
                f'The "type" field is set to {external_account_info["type"]} '
                "instead of 'external_account'."
            )

        return value
validate_external_account_json(value) classmethod

Validate the external account credentials JSON.

Parameters:

Name Type Description Default
value PlainSerializedSecretStr

The external account credentials JSON.

required

Returns:

Type Description
PlainSerializedSecretStr

The validated external account credentials JSON.

Exceptions:

Type Description
ValueError

If the external account credentials JSON is invalid.

Source code in zenml/integrations/gcp/service_connectors/gcp_service_connector.py
@field_validator("external_account_json")
@classmethod
def validate_external_account_json(
    cls, value: PlainSerializedSecretStr
) -> PlainSerializedSecretStr:
    """Validate the external account credentials JSON.

    Args:
        value: The external account credentials JSON.

    Returns:
        The validated external account credentials JSON.

    Raises:
        ValueError: If the external account credentials JSON is invalid.
    """
    try:
        external_account_info = json.loads(value.get_secret_value())
    except json.JSONDecodeError as e:
        raise ValueError(
            f"GCP external account credentials is not a valid JSON: {e}"
        )

    # Check that all fields are present
    required_fields = [
        "type",
        "subject_token_type",
        "token_url",
    ]
    # Compute missing fields
    missing_fields = set(required_fields) - set(
        external_account_info.keys()
    )
    if missing_fields:
        raise ValueError(
            f"GCP external account credentials JSON is missing required "
            f'fields: {", ".join(list(missing_fields))}'
        )

    if external_account_info["type"] != "external_account":
        raise ValueError(
            "The JSON does not contain GCP external account credentials. "
            f'The "type" field is set to {external_account_info["type"]} '
            "instead of 'external_account'."
        )

    return value
validate_service_account_dict(data, validation_info) classmethod

Wrapper method to handle the raw data.

Parameters:

Name Type Description Default
cls

the class handler

required
data Any

the raw input data

required
validation_info ValidationInfo

the context of the validation.

required

Returns:

Type Description
Any

the validated data

Source code in zenml/integrations/gcp/service_connectors/gcp_service_connector.py
def before_validator(
    cls: Type[BaseModel], data: Any, validation_info: ValidationInfo
) -> Any:
    """Wrapper method to handle the raw data.

    Args:
        cls: the class handler
        data: the raw input data
        validation_info: the context of the validation.

    Returns:
        the validated data
    """
    data = model_validator_data_handler(
        raw_data=data, base_class=cls, validation_info=validation_info
    )
    return method(cls=cls, data=data)
GCPOAuth2Token (AuthenticationConfig)

GCP OAuth 2.0 token credentials.

Source code in zenml/integrations/gcp/service_connectors/gcp_service_connector.py
class GCPOAuth2Token(AuthenticationConfig):
    """GCP OAuth 2.0 token credentials."""

    token: PlainSerializedSecretStr = Field(
        title="GCP OAuth 2.0 Token",
    )
GCPOAuth2TokenConfig (GCPBaseProjectIDConfig, GCPOAuth2Token)

GCP OAuth 2.0 configuration.

Source code in zenml/integrations/gcp/service_connectors/gcp_service_connector.py
class GCPOAuth2TokenConfig(GCPBaseProjectIDConfig, GCPOAuth2Token):
    """GCP OAuth 2.0 configuration."""

    service_account_email: Optional[str] = Field(
        default=None,
        title="GCP Service Account Email",
        description="The email address of the service account that signed the "
        "token. If not provided, the token is assumed to be issued for a user "
        "account.",
    )
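
As a hedged illustration of how such a configuration could be built, the sketch below derives a short-lived token from Application Default Credentials using google-auth; the project ID is a placeholder and the config class is the one documented above:

import google.auth
from google.auth.transport.requests import Request

from zenml.integrations.gcp.service_connectors.gcp_service_connector import (
    GCPOAuth2TokenConfig,
)

# Obtain credentials from the local environment (ADC) and refresh them to get
# a short-lived OAuth 2.0 access token.
credentials, _ = google.auth.default(
    scopes=["https://www.googleapis.com/auth/cloud-platform"]
)
credentials.refresh(Request())

config = GCPOAuth2TokenConfig(
    project_id="my-gcp-project",  # placeholder project ID
    token=credentials.token,
)
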
GCPServiceAccountConfig (GCPBaseConfig, GCPServiceAccountCredentials)

GCP service account configuration.

Source code in zenml/integrations/gcp/service_connectors/gcp_service_connector.py
class GCPServiceAccountConfig(GCPBaseConfig, GCPServiceAccountCredentials):
    """GCP service account configuration."""

    _project_id: Optional[str] = None

    @property
    def gcp_project_id(self) -> str:
        """Get the GCP project ID.

        When a service account JSON is provided, the project ID can be extracted
        from it instead of being provided explicitly.

        Returns:
            The GCP project ID.
        """
        if self._project_id is None:
            self._project_id = json.loads(
                self.service_account_json.get_secret_value()
            )["project_id"]
            # Guaranteed by the field validator
            assert self._project_id is not None

        return self._project_id
gcp_project_id: str property readonly

Get the GCP project ID.

When a service account JSON is provided, the project ID can be extracted from it instead of being provided explicitly.

Returns:

Type Description
str

The GCP project ID.

model_post_init(/, self, context)

This function is meant to behave like a BaseModel method to initialise private attributes.

It takes context as an argument since that's what pydantic-core passes when calling it.

Parameters:

Name Type Description Default
self BaseModel

The BaseModel instance.

required
context Any

The context.

required
Source code in zenml/integrations/gcp/service_connectors/gcp_service_connector.py
def init_private_attributes(self: BaseModel, context: Any, /) -> None:
    """This function is meant to behave like a BaseModel method to initialise private attributes.

    It takes context as an argument since that's what pydantic-core passes when calling it.

    Args:
        self: The BaseModel instance.
        context: The context.
    """
    if getattr(self, '__pydantic_private__', None) is None:
        pydantic_private = {}
        for name, private_attr in self.__private_attributes__.items():
            default = private_attr.get_default()
            if default is not PydanticUndefined:
                pydantic_private[name] = default
        object_setattr(self, '__pydantic_private__', pydantic_private)
GCPServiceAccountCredentials (AuthenticationConfig)

GCP service account credentials.

Source code in zenml/integrations/gcp/service_connectors/gcp_service_connector.py
class GCPServiceAccountCredentials(AuthenticationConfig):
    """GCP service account credentials."""

    service_account_json: PlainSerializedSecretStr = Field(
        title="GCP Service Account Key JSON optionally base64 encoded.",
    )

    generate_temporary_tokens: bool = Field(
        default=True,
        title="Generate temporary OAuth 2.0 tokens",
        description="Whether to generate temporary OAuth 2.0 tokens from the "
        "service account key JSON. If set to False, the connector will "
        "distribute the service account key JSON to clients instead.",
    )

    @model_validator(mode="before")
    @classmethod
    @before_validator_handler
    def validate_service_account_dict(
        cls, data: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Convert the service account credentials to JSON if given in dict format.

        Args:
            data: The configuration values.

        Returns:
            The validated configuration values.

        Raises:
            ValueError: If the service account credentials JSON is invalid.
        """
        service_account_json = data.get("service_account_json")
        if isinstance(service_account_json, dict):
            data["service_account_json"] = json.dumps(
                data["service_account_json"]
            )
        elif isinstance(service_account_json, str):
            # Check if the service account JSON is base64 encoded and decode it
            if re.match(r"^[A-Za-z0-9+/=]+$", service_account_json):
                try:
                    data["service_account_json"] = base64.b64decode(
                        service_account_json
                    ).decode("utf-8")
                except Exception as e:
                    raise ValueError(
                        f"Failed to decode base64 encoded service account JSON: {e}"
                    )

        return data

    @field_validator("service_account_json")
    @classmethod
    def validate_service_account_json(
        cls, value: PlainSerializedSecretStr
    ) -> PlainSerializedSecretStr:
        """Validate the service account credentials JSON.

        Args:
            value: The service account credentials JSON.

        Returns:
            The validated service account credentials JSON.

        Raises:
            ValueError: If the service account credentials JSON is invalid.
        """
        try:
            service_account_info = json.loads(value.get_secret_value())
        except json.JSONDecodeError as e:
            raise ValueError(
                f"GCP service account credentials is not a valid JSON: {e}"
            )

        # Check that all fields are present
        required_fields = [
            "type",
            "project_id",
            "private_key_id",
            "private_key",
            "client_email",
            "client_id",
            "auth_uri",
            "token_uri",
            "auth_provider_x509_cert_url",
            "client_x509_cert_url",
        ]
        # Compute missing fields
        missing_fields = set(required_fields) - set(
            service_account_info.keys()
        )
        if missing_fields:
            raise ValueError(
                f"GCP service account credentials JSON is missing required "
                f'fields: {", ".join(list(missing_fields))}'
            )

        if service_account_info["type"] != "service_account":
            raise ValueError(
                "The JSON does not contain GCP service account credentials. "
                f'The "type" field is set to {service_account_info["type"]} '
                "instead of 'service_account'."
            )

        return value
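
As an illustration of the validators above, the service account key may be supplied as a dict, a raw JSON string, or a base64-encoded string. A hedged sketch using the GCPServiceAccountConfig class documented above; the key file path is a placeholder:

import base64
import json

from zenml.integrations.gcp.service_connectors.gcp_service_connector import (
    GCPServiceAccountConfig,
)

with open("service-account.json") as f:  # placeholder path to a valid key file
    key_json = f.read()

# Any of the three forms below passes the validators shown above.
config_from_str = GCPServiceAccountConfig(service_account_json=key_json)
config_from_dict = GCPServiceAccountConfig(service_account_json=json.loads(key_json))
config_from_b64 = GCPServiceAccountConfig(
    service_account_json=base64.b64encode(key_json.encode()).decode()
)
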
validate_service_account_dict(data, validation_info) classmethod

Wrapper method to handle the raw data.

Parameters:

Name Type Description Default
cls

the class handler

required
data Any

the raw input data

required
validation_info ValidationInfo

the context of the validation.

required

Returns:

Type Description
Any

the validated data

Source code in zenml/integrations/gcp/service_connectors/gcp_service_connector.py
def before_validator(
    cls: Type[BaseModel], data: Any, validation_info: ValidationInfo
) -> Any:
    """Wrapper method to handle the raw data.

    Args:
        cls: the class handler
        data: the raw input data
        validation_info: the context of the validation.

    Returns:
        the validated data
    """
    data = model_validator_data_handler(
        raw_data=data, base_class=cls, validation_info=validation_info
    )
    return method(cls=cls, data=data)
validate_service_account_json(value) classmethod

Validate the service account credentials JSON.

Parameters:

Name Type Description Default
value PlainSerializedSecretStr

The service account credentials JSON.

required

Returns:

Type Description
PlainSerializedSecretStr

The validated service account credentials JSON.

Exceptions:

Type Description
ValueError

If the service account credentials JSON is invalid.

Source code in zenml/integrations/gcp/service_connectors/gcp_service_connector.py
@field_validator("service_account_json")
@classmethod
def validate_service_account_json(
    cls, value: PlainSerializedSecretStr
) -> PlainSerializedSecretStr:
    """Validate the service account credentials JSON.

    Args:
        value: The service account credentials JSON.

    Returns:
        The validated service account credentials JSON.

    Raises:
        ValueError: If the service account credentials JSON is invalid.
    """
    try:
        service_account_info = json.loads(value.get_secret_value())
    except json.JSONDecodeError as e:
        raise ValueError(
            f"GCP service account credentials is not a valid JSON: {e}"
        )

    # Check that all fields are present
    required_fields = [
        "type",
        "project_id",
        "private_key_id",
        "private_key",
        "client_email",
        "client_id",
        "auth_uri",
        "token_uri",
        "auth_provider_x509_cert_url",
        "client_x509_cert_url",
    ]
    # Compute missing fields
    missing_fields = set(required_fields) - set(
        service_account_info.keys()
    )
    if missing_fields:
        raise ValueError(
            f"GCP service account credentials JSON is missing required "
            f'fields: {", ".join(list(missing_fields))}'
        )

    if service_account_info["type"] != "service_account":
        raise ValueError(
            "The JSON does not contain GCP service account credentials. "
            f'The "type" field is set to {service_account_info["type"]} '
            "instead of 'service_account'."
        )

    return value
GCPServiceAccountImpersonationConfig (GCPServiceAccountConfig)

GCP service account impersonation configuration.

Source code in zenml/integrations/gcp/service_connectors/gcp_service_connector.py
class GCPServiceAccountImpersonationConfig(GCPServiceAccountConfig):
    """GCP service account impersonation configuration."""

    target_principal: str = Field(
        title="GCP Service Account Email to impersonate",
    )
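
For illustration, impersonation builds on a regular service account key and only adds the principal to impersonate. A hedged sketch with placeholder values:

from zenml.integrations.gcp.service_connectors.gcp_service_connector import (
    GCPServiceAccountImpersonationConfig,
)

with open("source-sa.json") as f:  # placeholder: key of the source service account
    source_key_json = f.read()

config = GCPServiceAccountImpersonationConfig(
    service_account_json=source_key_json,
    # Placeholder: the service account whose permissions are actually used.
    target_principal="target-sa@my-gcp-project.iam.gserviceaccount.com",
)
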
model_post_init(/, self, context)

We need to both initialize private attributes and call the user-defined model_post_init method.

Source code in zenml/integrations/gcp/service_connectors/gcp_service_connector.py
def wrapped_model_post_init(self: BaseModel, context: Any, /) -> None:
    """We need to both initialize private attributes and call the user-defined model_post_init
    method.
    """
    init_private_attributes(self, context)
    original_model_post_init(self, context)
GCPServiceConnector (ServiceConnector)

GCP service connector.

Source code in zenml/integrations/gcp/service_connectors/gcp_service_connector.py
class GCPServiceConnector(ServiceConnector):
    """GCP service connector."""

    config: GCPBaseConfig

    _session_cache: Dict[
        Tuple[str, Optional[str], Optional[str]],
        Tuple[
            gcp_credentials.Credentials,
            Optional[datetime.datetime],
        ],
    ] = {}

    @classmethod
    def _get_connector_type(cls) -> ServiceConnectorTypeModel:
        """Get the service connector type specification.

        Returns:
            The service connector type specification.
        """
        return GCP_SERVICE_CONNECTOR_TYPE_SPEC

    def get_session(
        self,
        auth_method: str,
        resource_type: Optional[str] = None,
        resource_id: Optional[str] = None,
    ) -> Tuple[gcp_credentials.Credentials, Optional[datetime.datetime]]:
        """Get a GCP session object with credentials for the specified resource.

        Args:
            auth_method: The authentication method to use.
            resource_type: The resource type to get credentials for.
            resource_id: The resource ID to get credentials for.

        Returns:
            GCP session with credentials for the specified resource and its
            expiration timestamp, if applicable.
        """
        # We maintain a cache of all sessions to avoid re-authenticating
        # multiple times for the same resource
        key = (auth_method, resource_type, resource_id)
        if key in self._session_cache:
            session, expires_at = self._session_cache[key]
            if expires_at is None:
                return session, None

            # Refresh expired sessions
            now = datetime.datetime.now(datetime.timezone.utc)
            expires_at = expires_at.replace(tzinfo=datetime.timezone.utc)
            # check if the token expires in the near future
            if expires_at > now + datetime.timedelta(
                minutes=GCP_SESSION_EXPIRATION_BUFFER
            ):
                return session, expires_at

        logger.debug(
            f"Creating GCP authentication session for auth method "
            f"'{auth_method}', resource type '{resource_type}' and resource ID "
            f"'{resource_id}'..."
        )
        session, expires_at = self._authenticate(
            auth_method, resource_type, resource_id
        )
        self._session_cache[key] = (session, expires_at)
        return session, expires_at

    @classmethod
    def _get_scopes(
        cls,
        resource_type: Optional[str] = None,
        resource_id: Optional[str] = None,
    ) -> List[str]:
        """Get the OAuth 2.0 scopes to use for the specified resource type.

        Args:
            resource_type: The resource type to get scopes for.
            resource_id: The resource ID to get scopes for.

        Returns:
            OAuth 2.0 scopes to use for the specified resource type.
        """
        return [
            "https://www.googleapis.com/auth/cloud-platform",
        ]

    def _authenticate(
        self,
        auth_method: str,
        resource_type: Optional[str] = None,
        resource_id: Optional[str] = None,
    ) -> Tuple[
        gcp_credentials.Credentials,
        Optional[datetime.datetime],
    ]:
        """Authenticate to GCP and return a session with credentials.

        Args:
            auth_method: The authentication method to use.
            resource_type: The resource type to authenticate for.
            resource_id: The resource ID to authenticate for.

        Returns:
            GCP OAuth 2.0 credentials and their expiration time if applicable.

        Raises:
            AuthorizationException: If the authentication fails.
        """
        cfg = self.config
        scopes = self._get_scopes(resource_type, resource_id)
        expires_at: Optional[datetime.datetime] = None
        if auth_method == GCPAuthenticationMethods.IMPLICIT:
            self._check_implicit_auth_method_allowed()

            # Determine the credentials from the environment
            # Override the project ID if specified in the config
            credentials, project_id = google.auth.default(
                scopes=scopes,
            )

        elif auth_method == GCPAuthenticationMethods.OAUTH2_TOKEN:
            assert isinstance(cfg, GCPOAuth2TokenConfig)

            expires_at = self.expires_at
            if expires_at:
                # Remove the UTC timezone
                expires_at = expires_at.replace(tzinfo=None)

            credentials = gcp_credentials.Credentials(
                token=cfg.token.get_secret_value(),
                expiry=expires_at,
                scopes=scopes,
            )

            if cfg.service_account_email:
                credentials.signer_email = cfg.service_account_email
        else:
            if auth_method == GCPAuthenticationMethods.USER_ACCOUNT:
                assert isinstance(cfg, GCPUserAccountConfig)
                credentials = (
                    gcp_credentials.Credentials.from_authorized_user_info(
                        json.loads(cfg.user_account_json.get_secret_value()),
                        scopes=scopes,
                    )
                )
            elif auth_method == GCPAuthenticationMethods.EXTERNAL_ACCOUNT:
                self._check_implicit_auth_method_allowed()

                assert isinstance(cfg, GCPExternalAccountConfig)

                # As a special case, for the AWS external account credential,
                # we use a custom credential class that supports extracting
                # the AWS credentials from the local environment, metadata
                # service or IRSA (if running on AWS EKS).
                account_info = json.loads(
                    cfg.external_account_json.get_secret_value()
                )
                if (
                    account_info.get("subject_token_type")
                    == _AWS_SUBJECT_TOKEN_TYPE
                ):
                    if ZenMLAwsSecurityCredentialsSupplier is not None:
                        account_info["aws_security_credentials_supplier"] = (
                            ZenMLAwsSecurityCredentialsSupplier(
                                account_info.pop("credential_source"),
                            )
                        )
                    credentials = (
                        ZenMLGCPAWSExternalAccountCredentials.from_info(
                            account_info,
                            scopes=scopes,
                        )
                    )
                else:
                    credentials, _ = _get_external_account_credentials(
                        json.loads(
                            cfg.external_account_json.get_secret_value()
                        ),
                        filename="",  # Not used
                        scopes=scopes,
                    )

            else:
                # Service account or impersonation (which is a special case of
                # service account authentication)

                assert isinstance(cfg, GCPServiceAccountConfig)

                credentials = (
                    gcp_service_account.Credentials.from_service_account_info(
                        json.loads(
                            cfg.service_account_json.get_secret_value()
                        ),
                        scopes=scopes,
                    )
                )

                if auth_method == GCPAuthenticationMethods.IMPERSONATION:
                    assert isinstance(
                        cfg, GCPServiceAccountImpersonationConfig
                    )

                    try:
                        credentials = gcp_impersonated_credentials.Credentials(
                            source_credentials=credentials,
                            target_principal=cfg.target_principal,
                            target_scopes=scopes,
                            lifetime=self.expiration_seconds,
                        )
                    except google.auth.exceptions.GoogleAuthError as e:
                        raise AuthorizationException(
                            f"Failed to impersonate service account "
                            f"'{cfg.target_principal}': {e}"
                        )

        if not credentials.valid:
            try:
                with requests.Session() as session:
                    req = Request(session)
                    credentials.refresh(req)
            except google.auth.exceptions.GoogleAuthError as e:
                raise AuthorizationException(
                    f"Could not fetch GCP OAuth2 token: {e}"
                )

        if credentials.expiry:
            # Add the UTC timezone to the expiration time
            expires_at = credentials.expiry.replace(
                tzinfo=datetime.timezone.utc
            )

        return credentials, expires_at

    def _parse_gcs_resource_id(self, resource_id: str) -> str:
        """Validate and convert an GCS resource ID to an GCS bucket name.

        Args:
            resource_id: The resource ID to convert.

        Returns:
            The GCS bucket name.

        Raises:
            ValueError: If the provided resource ID is not a valid GCS bucket
                name or URI.
        """
        # The resource ID could mean different things:
        #
        # - a GCS bucket URI
        # - the GCS bucket name
        #
        # We need to extract the bucket name from the provided resource ID
        bucket_name: Optional[str] = None
        if re.match(
            r"^gs://[a-z0-9][a-z0-9_-]{1,61}[a-z0-9](/.*)*$",
            resource_id,
        ):
            # The resource ID is a GCS bucket URI
            bucket_name = resource_id.split("/")[2]
        elif re.match(
            r"^[a-z0-9][a-z0-9_-]{1,61}[a-z0-9]$",
            resource_id,
        ):
            # The resource ID is the GCS bucket name
            bucket_name = resource_id
        else:
            raise ValueError(
                f"Invalid resource ID for an GCS bucket: {resource_id}. "
                f"Supported formats are:\n"
                f"GCS bucket URI: gs://<bucket-name>\n"
                f"GCS bucket name: <bucket-name>"
            )

        return bucket_name

    def _parse_gar_resource_id(
        self,
        resource_id: str,
    ) -> Tuple[str, Optional[str]]:
        """Validate and convert a GAR resource ID to a Google Artifact Registry ID and name.

        Args:
            resource_id: The resource ID to convert.

        Returns:
            The Google Artifact Registry ID and name. The name is omitted if the
            resource ID is a GCR repository URI.

        Raises:
            ValueError: If the provided resource ID is not a valid GAR
                or GCR repository URI.
        """
        # The resource ID could mean different things:
        #
        # - a GAR repository URI
        # - a GAR repository name
        # - a GCR repository URI (backwards-compatibility)
        #
        # We need to extract the project ID and registry ID from
        # the provided resource ID
        config_project_id = self.config.gcp_project_id
        project_id: Optional[str] = None
        canonical_url: str
        registry_name: Optional[str] = None

        # A Google Artifact Registry URI uses the <location>-docker.pkg.dev
        # domain format with the project ID as the first part of the URL path
        # and the registry name as the second part of the URL path
        if match := re.match(
            r"^(https://)?(([a-z0-9-]+)-docker\.pkg\.dev/([a-z0-9-]+)/([a-z0-9-.]+))(/.+)*$",
            resource_id,
        ):
            # The resource ID is a Google Artifact Registry URI
            project_id = match[4]
            location = match[3]
            repository = match[5]

            # Return the GAR URL without the image name and without the protocol
            canonical_url = match[2]
            registry_name = f"projects/{project_id}/locations/{location}/repositories/{repository}"

        # Alternatively, the Google Artifact Registry name uses the
        # projects/<project-id>/locations/<location>/repositories/<repository-id>
        # format
        elif match := re.match(
            r"^projects/([a-z0-9-]+)/locations/([a-z0-9-]+)/repositories/([a-z0-9-.]+)$",
            resource_id,
        ):
            # The resource ID is a Google Artifact Registry name
            project_id = match[1]
            location = match[2]
            repository = match[3]

            # Return the GAR URL
            canonical_url = (
                f"{location}-docker.pkg.dev/{project_id}/{repository}"
            )
            registry_name = resource_id

        # A legacy GCR repository URI uses one of several hostnames (gcr.io,
        # us.gcr.io, eu.gcr.io, asia.gcr.io) and the project ID is the
        # first part of the URL path
        elif match := re.match(
            r"^(https://)?(((us|eu|asia)\.)?gcr\.io/[a-z0-9-]+)(/.+)*$",
            resource_id,
        ):
            # The resource ID is a legacy GCR repository URI.
            # Return the GAR URL without the image name and without the protocol
            canonical_url = match[2]

        else:
            raise ValueError(
                f"Invalid resource ID for a Google Artifact Registry: "
                f"{resource_id}. Supported formats are:\n"
                f"Google Artifact Registry URI: [https://]<region>-docker.pkg.dev/<project-id>/<registry-id>[/<repository-name>]\n"
                f"Google Artifact Registry name: projects/<project-id>/locations/<location>/repositories/<repository-id>\n"
                f"GCR repository URI: [https://][us.|eu.|asia.]gcr.io/<project-id>[/<repository-name>]"
            )

        # If the connector is configured with a project and the resource ID
        # is a GAR repository URI that specifies a different project,
        # we raise an error
        if project_id and project_id != config_project_id:
            raise ValueError(
                f"The GCP project for the {resource_id} Google Artifact "
                f"Registry '{project_id}' does not match the project "
                f"configured in the connector: '{config_project_id}'."
            )

        return canonical_url, registry_name

    def _parse_gke_resource_id(self, resource_id: str) -> str:
        """Validate and convert an GKE resource ID to a GKE cluster name.

        Args:
            resource_id: The resource ID to convert.

        Returns:
            The GKE cluster name.

        Raises:
            ValueError: If the provided resource ID is not a valid GKE cluster
                name.
        """
        if re.match(
            r"^[a-z0-9]+[a-z0-9_-]*$",
            resource_id,
        ):
            # Assume the resource ID is a GKE cluster name
            cluster_name = resource_id
        else:
            raise ValueError(
                f"Invalid resource ID for a GKE cluster: {resource_id}. "
                f"Supported formats are:\n"
                f"GKE cluster name: <cluster-name>"
            )

        return cluster_name

    def _canonical_resource_id(
        self, resource_type: str, resource_id: str
    ) -> str:
        """Convert a resource ID to its canonical form.

        Args:
            resource_type: The resource type to canonicalize.
            resource_id: The resource ID to canonicalize.

        Returns:
            The canonical resource ID.
        """
        if resource_type == GCS_RESOURCE_TYPE:
            bucket = self._parse_gcs_resource_id(resource_id)
            return f"gs://{bucket}"
        elif resource_type == KUBERNETES_CLUSTER_RESOURCE_TYPE:
            cluster_name = self._parse_gke_resource_id(resource_id)
            return cluster_name
        elif resource_type == DOCKER_REGISTRY_RESOURCE_TYPE:
            registry_id, _ = self._parse_gar_resource_id(
                resource_id,
            )
            return registry_id
        else:
            return resource_id

    def _get_default_resource_id(self, resource_type: str) -> str:
        """Get the default resource ID for a resource type.

        Args:
            resource_type: The type of the resource to get a default resource ID
                for. Only called with resource types that do not support
                multiple instances.

        Returns:
            The default resource ID for the resource type.

        Raises:
            RuntimeError: If the GCR registry ID (GCP account ID)
                cannot be retrieved from GCP because the connector is not
                authorized.
        """
        if resource_type == GCP_RESOURCE_TYPE:
            return self.config.gcp_project_id

        raise RuntimeError(
            f"Default resource ID not supported for '{resource_type}' resource "
            "type."
        )

    def _connect_to_resource(
        self,
        **kwargs: Any,
    ) -> Any:
        """Authenticate and connect to a GCP resource.

        Initialize and return a session or client object depending on the
        connector configuration:

        - initialize and return generic google-auth credentials if the resource
        type is a generic GCP resource
        - initialize and return a google-storage client for a GCS resource type

        For the Docker and Kubernetes resource types, the connector does not
        support connecting to the resource directly. Instead, the connector
        supports generating a connector client object for the resource type
        in question.

        Args:
            kwargs: Additional implementation specific keyword arguments to pass
                to the session or client constructor.

        Returns:
            Generic GCP credentials for GCP generic resources and a
            google-storage GCS client for GCS resources.

        Raises:
            NotImplementedError: If the connector instance does not support
                directly connecting to the indicated resource type.
        """
        resource_type = self.resource_type
        resource_id = self.resource_id

        assert resource_type is not None
        assert resource_id is not None

        # Regardless of the resource type, we must authenticate to GCP first
        # before we can connect to any GCP resource
        credentials, _ = self.get_session(
            self.auth_method,
            resource_type=resource_type,
            resource_id=resource_id,
        )

        if resource_type == GCS_RESOURCE_TYPE:
            # Validate that the resource ID is a valid GCS bucket name
            self._parse_gcs_resource_id(resource_id)

            # Create a GCS client for the bucket
            client = storage.Client(
                project=self.config.gcp_project_id, credentials=credentials
            )
            return client

        if resource_type == GCP_RESOURCE_TYPE:
            return credentials

        raise NotImplementedError(
            f"Connecting to {resource_type} resources is not directly "
            "supported by the GCP connector. Please call the "
            f"`get_connector_client` method to get a {resource_type} connector "
            "instance for the resource."
        )

    def _configure_local_client(
        self,
        **kwargs: Any,
    ) -> None:
        """Configure a local client to authenticate and connect to a resource.

        This method uses the connector's configuration to configure a local
        client or SDK installed on the localhost for the indicated resource.

        Args:
            kwargs: Additional implementation specific keyword arguments to use
                to configure the client.

        Raises:
            NotImplementedError: If the connector instance does not support
                local configuration for the configured resource type or
                authentication method.
            AuthorizationException: If the local client configuration fails.
        """
        resource_type = self.resource_type

        if resource_type in [GCP_RESOURCE_TYPE, GCS_RESOURCE_TYPE]:
            gcloud_config_json: Optional[str] = None

            # There is no way to configure the local gcloud CLI to use
            # temporary OAuth 2.0 tokens. However, we can configure it to use
            # the service account or external account credentials
            if self.auth_method == GCPAuthenticationMethods.SERVICE_ACCOUNT:
                assert isinstance(self.config, GCPServiceAccountConfig)
                # Use the service account credentials JSON to configure the
                # local gcloud CLI
                gcloud_config_json = (
                    self.config.service_account_json.get_secret_value()
                )
            elif self.auth_method == GCPAuthenticationMethods.EXTERNAL_ACCOUNT:
                assert isinstance(self.config, GCPExternalAccountConfig)
                # Use the external account credentials JSON to configure the
                # local gcloud CLI
                gcloud_config_json = (
                    self.config.external_account_json.get_secret_value()
                )

            if gcloud_config_json:
                from google.auth import _cloud_sdk

                if not shutil.which("gcloud"):
                    raise AuthorizationException(
                        "The local gcloud CLI is not installed. Please "
                        "install the gcloud CLI to use this feature."
                    )

                # Write the credentials JSON to a temporary file
                with tempfile.NamedTemporaryFile(
                    mode="w", suffix=".json", delete=True
                ) as f:
                    f.write(gcloud_config_json)
                    f.flush()
                    adc_path = f.name

                    try:
                        # Run the gcloud CLI command to configure the local
                        # gcloud CLI to use the credentials JSON
                        subprocess.run(
                            [
                                "gcloud",
                                "auth",
                                "login",
                                "--quiet",
                                "--cred-file",
                                adc_path,
                            ],
                            check=True,
                            stderr=subprocess.STDOUT,
                            encoding="utf-8",
                            stdout=subprocess.PIPE,
                        )
                    except subprocess.CalledProcessError as e:
                        raise AuthorizationException(
                            f"Failed to configure the local gcloud CLI to use "
                            f"the credentials JSON: {e}\n"
                            f"{e.stdout.decode()}"
                        )

                try:
                    # Run the gcloud CLI command to configure the local gcloud
                    # CLI to use the credentials project ID
                    subprocess.run(
                        [
                            "gcloud",
                            "config",
                            "set",
                            "project",
                            self.config.gcp_project_id,
                        ],
                        check=True,
                        stderr=subprocess.STDOUT,
                        stdout=subprocess.PIPE,
                    )
                except subprocess.CalledProcessError as e:
                    raise AuthorizationException(
                        f"Failed to configure the local gcloud CLI to use "
                        f"the project ID: {e}\n"
                        f"{e.stdout.decode()}"
                    )

                # Dump the service account credentials JSON to
                # the local gcloud application default credentials file
                adc_path = (
                    _cloud_sdk.get_application_default_credentials_path()
                )
                with open(adc_path, "w") as f:
                    f.write(gcloud_config_json)

                logger.info(
                    "Updated the local gcloud CLI and application default "
                    f"credentials file ({adc_path})."
                )

                return

            raise NotImplementedError(
                f"Local gcloud client configuration for resource type "
                f"{resource_type} is only supported if the "
                f"'{GCPAuthenticationMethods.SERVICE_ACCOUNT}' or "
                f"'{GCPAuthenticationMethods.EXTERNAL_ACCOUNT}' "
                f"authentication method is used and only if the generation of "
                f"temporary OAuth 2.0 tokens is disabled by setting the "
                f"'generate_temporary_tokens' option to 'False' in the "
                f"service connector configuration."
            )

        raise NotImplementedError(
            f"Configuring the local client for {resource_type} resources is "
            "not directly supported by the GCP connector. Please call the "
            f"`get_connector_client` method to get a {resource_type} connector "
            "instance for the resource."
        )

    @classmethod
    def _auto_configure(
        cls,
        auth_method: Optional[str] = None,
        resource_type: Optional[str] = None,
        resource_id: Optional[str] = None,
        **kwargs: Any,
    ) -> "GCPServiceConnector":
        """Auto-configure the connector.

        Instantiate a GCP connector with a configuration extracted from the
        authentication configuration available in the environment (e.g.
        environment variables or local GCP client/SDK configuration files).

        Args:
            auth_method: The particular authentication method to use. If not
                specified, the connector implementation must decide which
                authentication method to use or raise an exception.
            resource_type: The type of resource to configure.
            resource_id: The ID of the resource to configure. The
                implementation may choose to either require or ignore this
                parameter if it does not support or detect a resource type that
                supports multiple instances.
            kwargs: Additional implementation specific keyword arguments to use.

        Returns:
            A GCP connector instance configured with authentication credentials
            automatically extracted from the environment.

        Raises:
            NotImplementedError: If the connector implementation does not
                support auto-configuration for the specified authentication
                method.
            AuthorizationException: If no GCP credentials can be loaded from
                the environment.
        """
        auth_config: GCPBaseConfig

        scopes = cls._get_scopes()
        expires_at: Optional[datetime.datetime] = None

        try:
            # Determine the credentials from the environment
            credentials, project_id = google.auth.default(
                scopes=scopes,
            )
        except google.auth.exceptions.GoogleAuthError as e:
            raise AuthorizationException(
                f"No GCP credentials could be detected: {e}"
            )

        if project_id is None:
            raise AuthorizationException(
                "No GCP project ID could be detected. Please set the active "
                "GCP project ID by running 'gcloud config set project'."
            )

        if auth_method == GCPAuthenticationMethods.IMPLICIT:
            auth_config = GCPBaseProjectIDConfig(
                project_id=project_id,
            )
        elif auth_method == GCPAuthenticationMethods.OAUTH2_TOKEN:
            # Refresh the credentials if necessary, to fetch the access token
            if not credentials.valid or not credentials.token:
                try:
                    with requests.Session() as session:
                        req = Request(session)
                        credentials.refresh(req)
                except google.auth.exceptions.GoogleAuthError as e:
                    raise AuthorizationException(
                        f"Could not fetch GCP OAuth2 token: {e}"
                    )

            if not credentials.token:
                raise AuthorizationException(
                    "Could not fetch GCP OAuth2 token"
                )

            auth_config = GCPOAuth2TokenConfig(
                project_id=project_id,
                token=credentials.token,
                service_account_email=credentials.signer_email
                if hasattr(credentials, "signer_email")
                else None,
            )
            if credentials.expiry:
                # Add the UTC timezone to the expiration time
                expires_at = credentials.expiry.replace(
                    tzinfo=datetime.timezone.utc
                )
        else:
            # Check if user account credentials are available
            if isinstance(credentials, gcp_credentials.Credentials):
                if auth_method not in [
                    GCPAuthenticationMethods.USER_ACCOUNT,
                    None,
                ]:
                    raise NotImplementedError(
                        f"Could not perform auto-configuration for "
                        f"authentication method {auth_method}. Only "
                        f"GCP user account credentials have been detected."
                    )
                auth_method = GCPAuthenticationMethods.USER_ACCOUNT
                user_account_json = json.dumps(
                    dict(
                        type="authorized_user",
                        client_id=credentials._client_id,
                        client_secret=credentials._client_secret,
                        refresh_token=credentials.refresh_token,
                    )
                )
                auth_config = GCPUserAccountConfig(
                    project_id=project_id,
                    user_account_json=user_account_json,
                )
            # Check if service account credentials are available
            elif isinstance(credentials, gcp_service_account.Credentials):
                if auth_method not in [
                    GCPAuthenticationMethods.SERVICE_ACCOUNT,
                    None,
                ]:
                    raise NotImplementedError(
                        f"Could not perform auto-configuration for "
                        f"authentication method {auth_method}. Only "
                        f"GCP service account credentials have been detected."
                    )

                auth_method = GCPAuthenticationMethods.SERVICE_ACCOUNT
                service_account_json_file = os.environ.get(
                    "GOOGLE_APPLICATION_CREDENTIALS"
                )
                if service_account_json_file is None:
                    # No explicit service account JSON file was specified in the
                    # environment, meaning that the credentials were loaded from
                    # the GCP application default credentials (ADC) file.
                    from google.auth import _cloud_sdk

                    # Use the location of the gcloud application default
                    # credentials file
                    service_account_json_file = (
                        _cloud_sdk.get_application_default_credentials_path()
                    )

                if not service_account_json_file or not os.path.isfile(
                    service_account_json_file
                ):
                    raise AuthorizationException(
                        "No GCP service account credentials were found in the "
                        "environment or the application default credentials "
                        "path. Please set the GOOGLE_APPLICATION_CREDENTIALS "
                        "environment variable to the path of the service "
                        "account JSON file or run 'gcloud auth application-"
                        "default login' to generate a new ADC file."
                    )
                with open(service_account_json_file, "r") as f:
                    service_account_json = f.read()
                auth_config = GCPServiceAccountConfig(
                    project_id=project_id,
                    service_account_json=service_account_json,
                )
            # Check if external account credentials are available
            elif isinstance(credentials, gcp_external_account.Credentials):
                if auth_method not in [
                    GCPAuthenticationMethods.EXTERNAL_ACCOUNT,
                    None,
                ]:
                    raise NotImplementedError(
                        f"Could not perform auto-configuration for "
                        f"authentication method {auth_method}. Only "
                        f"GCP external account credentials have been detected."
                    )

                auth_method = GCPAuthenticationMethods.EXTERNAL_ACCOUNT
                external_account_json_file = os.environ.get(
                    "GOOGLE_APPLICATION_CREDENTIALS"
                )
                if external_account_json_file is None:
                    # No explicit external account JSON file was specified in the
                    # environment, meaning that the credentials were loaded from
                    # the GCP application default credentials (ADC) file.
                    from google.auth import _cloud_sdk

                    # Use the location of the gcloud application default
                    # credentials file
                    external_account_json_file = (
                        _cloud_sdk.get_application_default_credentials_path()
                    )

                if not external_account_json_file or not os.path.isfile(
                    external_account_json_file
                ):
                    raise AuthorizationException(
                        "No GCP service account credentials were found in the "
                        "environment or the application default credentials "
                        "path. Please set the GOOGLE_APPLICATION_CREDENTIALS "
                        "environment variable to the path of the external "
                        "account JSON file or run 'gcloud auth application-"
                        "default login' to generate a new ADC file."
                    )
                with open(external_account_json_file, "r") as f:
                    external_account_json = f.read()
                auth_config = GCPExternalAccountConfig(
                    project_id=project_id,
                    external_account_json=external_account_json,
                )
            else:
                raise AuthorizationException(
                    "No valid GCP credentials could be detected."
                )

        return cls(
            auth_method=auth_method,
            resource_type=resource_type,
            resource_id=resource_id
            if resource_type not in [GCP_RESOURCE_TYPE, None]
            else None,
            expires_at=expires_at,
            config=auth_config,
        )

    def _verify(
        self,
        resource_type: Optional[str] = None,
        resource_id: Optional[str] = None,
    ) -> List[str]:
        """Verify and list all the resources that the connector can access.

        Args:
            resource_type: The type of the resource to verify. If omitted and
                if the connector supports multiple resource types, the
                implementation must verify that it can authenticate and connect
                to any and all of the supported resource types.
            resource_id: The ID of the resource to connect to. Omitted if a
                resource type is not specified. It has the same value as the
                default resource ID if the supplied resource type doesn't
                support multiple instances. If the supplied resource type does
                allow multiple instances, this parameter may still be omitted
                to fetch a list of resource IDs identifying all the resources
                of the indicated type that the connector can access.

        Returns:
            The list of resources IDs in canonical format identifying the
            resources that the connector can access. This list is empty only
            if the resource type is not specified (i.e. for multi-type
            connectors).

        Raises:
            AuthorizationException: If the connector cannot authenticate or
                access the specified resource.
        """
        # If the resource type is not specified, treat this the
        # same as a generic GCP connector.
        credentials, _ = self.get_session(
            self.auth_method,
            resource_type=resource_type or GCP_RESOURCE_TYPE,
            resource_id=resource_id,
        )

        if not resource_type:
            return []

        if resource_type == GCP_RESOURCE_TYPE:
            assert resource_id is not None
            return [resource_id]

        if resource_type == GCS_RESOURCE_TYPE:
            gcs_client = storage.Client(
                project=self.config.gcp_project_id, credentials=credentials
            )
            if not resource_id:
                # List all GCS buckets
                try:
                    buckets = gcs_client.list_buckets()
                    bucket_names = [bucket.name for bucket in buckets]
                except google.api_core.exceptions.GoogleAPIError as e:
                    msg = f"failed to list GCS buckets: {e}"
                    logger.error(msg)
                    raise AuthorizationException(msg) from e

                return [f"gs://{bucket}" for bucket in bucket_names]
            else:
                # Check if the specified GCS bucket exists
                bucket_name = self._parse_gcs_resource_id(resource_id)
                try:
                    gcs_client.get_bucket(bucket_name)
                    return [resource_id]
                except google.api_core.exceptions.GoogleAPIError as e:
                    msg = f"failed to fetch GCS bucket {bucket_name}: {e}"
                    logger.error(msg)
                    raise AuthorizationException(msg) from e

        if resource_type == DOCKER_REGISTRY_RESOURCE_TYPE:
            # Get a GAR client
            gar_client = artifactregistry_v1.ArtifactRegistryClient(
                credentials=credentials
            )

            if resource_id:
                registry_id, registry_name = self._parse_gar_resource_id(
                    resource_id
                )

                if registry_name is None:
                    # This is a legacy GCR repository URI. We can't verify
                    # the repository access without attempting to connect to it
                    # via Docker/OCI, so just return the resource ID.
                    return [registry_id]

                # Check if the specified GAR registry exists
                try:
                    repository = gar_client.get_repository(
                        name=registry_name,
                    )
                    if repository.format_.name != "DOCKER":
                        raise AuthorizationException(
                            f"Google Artifact Registry '{resource_id}' is not a "
                            "Docker registry."
                        )
                    return [registry_id]
                except google.api_core.exceptions.GoogleAPIError as e:
                    msg = f"Failed to fetch Google Artifact Registry '{registry_id}': {e}"
                    logger.error(msg)
                    raise AuthorizationException(msg) from e

            # For backwards compatibility, we initialize the list of resource
            # IDs with all GCR supported registries for the configured GCP
            # project
            resource_ids: List[str] = [
                f"{location}gcr.io/{self.config.gcp_project_id}"
                for location in ["", "us.", "eu.", "asia."]
            ]

            # List all Google Artifact Registries
            try:
                # First, we need to fetch all the Artifact Registry supported
                # locations
                locations = gar_client.list_locations(
                    request=locations_pb2.ListLocationsRequest(
                        name=f"projects/{self.config.gcp_project_id}"
                    )
                )
                location_names = [
                    locations.locations[i].location_id
                    for i in range(len(locations.locations))
                ]

                # Then, we need to fetch all the repositories in each location
                repository_names: List[str] = []
                for location in location_names:
                    repositories = gar_client.list_repositories(
                        parent=f"projects/{self.config.gcp_project_id}/locations/{location}"
                    )
                    repository_names.extend(
                        [
                            repository.name
                            for repository in repositories
                            if repository.format_.name == "DOCKER"
                        ]
                    )

                for repository_name in repository_names:
                    # Convert the repository name to a canonical GAR URL
                    resource_ids.append(
                        self._parse_gar_resource_id(repository_name)[0]
                    )

            except google.api_core.exceptions.GoogleAPIError as e:
                msg = f"Failed to list Google Artifact Registries: {e}"
                logger.error(msg)
                # TODO: enable when GCR is no longer supported:
                # raise AuthorizationException(msg) from e

            return resource_ids

        if resource_type == KUBERNETES_CLUSTER_RESOURCE_TYPE:
            gke_client = container_v1.ClusterManagerClient(
                credentials=credentials
            )

            # List all GKE clusters
            try:
                clusters = gke_client.list_clusters(
                    parent=f"projects/{self.config.gcp_project_id}/locations/-"
                )
                cluster_names = [cluster.name for cluster in clusters.clusters]
            except google.api_core.exceptions.GoogleAPIError as e:
                msg = f"Failed to list GKE clusters: {e}"
                logger.error(msg)
                raise AuthorizationException(msg) from e

            if not resource_id:
                return cluster_names
            else:
                # Check if the specified GKE cluster exists
                cluster_name = self._parse_gke_resource_id(resource_id)
                if cluster_name not in cluster_names:
                    raise AuthorizationException(
                        f"GKE cluster '{cluster_name}' not found or not "
                        "accessible."
                    )

                return [resource_id]

        return []

    def _get_connector_client(
        self,
        resource_type: str,
        resource_id: str,
    ) -> "ServiceConnector":
        """Get a connector instance that can be used to connect to a resource.

        This method generates a client-side connector instance that can be used
        to connect to a resource of the given type. The client-side connector
        is configured with temporary GCP credentials extracted from the
        current connector and, depending on resource type, it may also be
        of a different connector type:

        - a Kubernetes connector for Kubernetes clusters
        - a Docker connector for Docker registries

        Args:
            resource_type: The type of the resources to connect to.
            resource_id: The ID of a particular resource to connect to.

        Returns:
            A GCP, Kubernetes or Docker connector instance that can be used to
            connect to the specified resource.

        Raises:
            AuthorizationException: If authentication failed.
            ValueError: If the resource type is not supported.
            RuntimeError: If the Kubernetes connector is not installed and the
                resource type is Kubernetes.
        """
        connector_name = ""
        if self.name:
            connector_name = self.name
        if resource_id:
            connector_name += f" ({resource_type} | {resource_id} client)"
        else:
            connector_name += f" ({resource_type} client)"

        logger.debug(f"Getting connector client for {connector_name}")

        credentials, expires_at = self.get_session(
            self.auth_method,
            resource_type=resource_type,
            resource_id=resource_id,
        )

        if resource_type in [GCP_RESOURCE_TYPE, GCS_RESOURCE_TYPE]:
            # By default, use the token extracted from the google credentials
            # object
            auth_method: str = GCPAuthenticationMethods.OAUTH2_TOKEN
            config: GCPBaseConfig = GCPOAuth2TokenConfig(
                project_id=self.config.gcp_project_id,
                token=credentials.token,
                service_account_email=credentials.signer_email
                if hasattr(credentials, "signer_email")
                else None,
            )

            # If the connector is explicitly configured to not generate
            # temporary tokens, use the original config
            if self.auth_method == GCPAuthenticationMethods.USER_ACCOUNT:
                assert isinstance(self.config, GCPUserAccountConfig)
                if not self.config.generate_temporary_tokens:
                    config = self.config
                    auth_method = self.auth_method
                    expires_at = None
            elif self.auth_method == GCPAuthenticationMethods.SERVICE_ACCOUNT:
                assert isinstance(self.config, GCPServiceAccountConfig)
                if not self.config.generate_temporary_tokens:
                    config = self.config
                    auth_method = self.auth_method
                    expires_at = None
            elif self.auth_method == GCPAuthenticationMethods.EXTERNAL_ACCOUNT:
                assert isinstance(self.config, GCPExternalAccountConfig)
                if not self.config.generate_temporary_tokens:
                    config = self.config
                    auth_method = self.auth_method
                    expires_at = None

            # Create a client-side GCP connector instance that is fully formed
            # and ready to use to connect to the specified resource (i.e. has
            # all the necessary configuration and credentials, a resource type
            # and a resource ID where applicable)
            return GCPServiceConnector(
                id=self.id,
                name=connector_name,
                auth_method=auth_method,
                resource_type=resource_type,
                resource_id=resource_id,
                config=config,
                expires_at=expires_at,
            )

        if resource_type == DOCKER_REGISTRY_RESOURCE_TYPE:
            assert resource_id is not None

            registry_id, _ = self._parse_gar_resource_id(resource_id)

            # Create a client-side Docker connector instance with the temporary
            # Docker credentials
            return DockerServiceConnector(
                id=self.id,
                name=connector_name,
                auth_method=DockerAuthenticationMethods.PASSWORD,
                resource_type=resource_type,
                config=DockerConfiguration(
                    username="oauth2accesstoken",
                    password=credentials.token,
                    registry=registry_id,
                ),
                expires_at=expires_at,
            )

        if resource_type == KUBERNETES_CLUSTER_RESOURCE_TYPE:
            assert resource_id is not None

            cluster_name = self._parse_gke_resource_id(resource_id)

            gke_client = container_v1.ClusterManagerClient(
                credentials=credentials
            )

            # List all GKE clusters
            try:
                clusters = gke_client.list_clusters(
                    parent=f"projects/{self.config.gcp_project_id}/locations/-"
                )
                cluster_map = {
                    cluster.name: cluster for cluster in clusters.clusters
                }
            except google.api_core.exceptions.GoogleAPIError as e:
                msg = f"Failed to list GKE clusters: {e}"
                logger.error(msg)
                raise AuthorizationException(msg) from e

            # Find the cluster with the specified name
            if cluster_name not in cluster_map:
                raise AuthorizationException(
                    f"GKE cluster '{cluster_name}' not found or not "
                    "accessible."
                )

            cluster = cluster_map[cluster_name]

            # get cluster details
            cluster_server = cluster.endpoint
            cluster_ca_cert = cluster.master_auth.cluster_ca_certificate
            bearer_token = credentials.token

            # Create a client-side Kubernetes connector instance with the
            # temporary Kubernetes credentials
            try:
                # Import libraries only when needed
                from zenml.integrations.kubernetes.service_connectors.kubernetes_service_connector import (
                    KubernetesAuthenticationMethods,
                    KubernetesServiceConnector,
                    KubernetesTokenConfig,
                )
            except ImportError as e:
                raise RuntimeError(
                    f"The Kubernetes Service Connector functionality could not "
                    f"be used due to missing dependencies: {e}"
                )
            return KubernetesServiceConnector(
                id=self.id,
                name=connector_name,
                auth_method=KubernetesAuthenticationMethods.TOKEN,
                resource_type=resource_type,
                config=KubernetesTokenConfig(
                    cluster_name=f"gke_{self.config.gcp_project_id}_{cluster_name}",
                    certificate_authority=cluster_ca_cert,
                    server=f"https://{cluster_server}",
                    token=bearer_token,
                ),
                expires_at=expires_at,
            )

        raise ValueError(f"Unsupported resource type: {resource_type}")
get_session(self, auth_method, resource_type=None, resource_id=None)

Get a GCP session object with credentials for the specified resource.

Parameters:

Name Type Description Default
auth_method str

The authentication method to use.

required
resource_type Optional[str]

The resource type to get credentials for.

None
resource_id Optional[str]

The resource ID to get credentials for.

None

Returns:

Type Description
Tuple[google.oauth2.credentials.Credentials, Optional[datetime.datetime]]

GCP session with credentials for the specified resource and its expiration timestamp, if applicable.

Source code in zenml/integrations/gcp/service_connectors/gcp_service_connector.py
def get_session(
    self,
    auth_method: str,
    resource_type: Optional[str] = None,
    resource_id: Optional[str] = None,
) -> Tuple[gcp_credentials.Credentials, Optional[datetime.datetime]]:
    """Get a GCP session object with credentials for the specified resource.

    Args:
        auth_method: The authentication method to use.
        resource_type: The resource type to get credentials for.
        resource_id: The resource ID to get credentials for.

    Returns:
        GCP session with credentials for the specified resource and its
        expiration timestamp, if applicable.
    """
    # We maintain a cache of all sessions to avoid re-authenticating
    # multiple times for the same resource
    key = (auth_method, resource_type, resource_id)
    if key in self._session_cache:
        session, expires_at = self._session_cache[key]
        if expires_at is None:
            return session, None

        # Refresh expired sessions
        now = datetime.datetime.now(datetime.timezone.utc)
        expires_at = expires_at.replace(tzinfo=datetime.timezone.utc)
        # check if the token expires in the near future
        if expires_at > now + datetime.timedelta(
            minutes=GCP_SESSION_EXPIRATION_BUFFER
        ):
            return session, expires_at

    logger.debug(
        f"Creating GCP authentication session for auth method "
        f"'{auth_method}', resource type '{resource_type}' and resource ID "
        f"'{resource_id}'..."
    )
    session, expires_at = self._authenticate(
        auth_method, resource_type, resource_id
    )
    self._session_cache[key] = (session, expires_at)
    return session, expires_at
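As a quick illustration of the caching described above, two calls to get_session with the same (auth_method, resource_type, resource_id) key return the same session object as long as the cached token is not within the expiration buffer. The sketch assumes an already authorized GCPServiceConnector instance named connector and reuses its configured authentication method; the resource type string is an assumed value.

# Sketch only: `connector` is assumed to be an authorized GCPServiceConnector.
key_args = dict(
    auth_method=connector.auth_method,
    resource_type="gcs-bucket",   # assumed value of GCS_RESOURCE_TYPE
    resource_id="gs://my-bucket",
)

credentials, expires_at = connector.get_session(**key_args)
credentials_again, _ = connector.get_session(**key_args)

# Served from `self._session_cache` unless the first token is close to expiry
# (within GCP_SESSION_EXPIRATION_BUFFER minutes).
assert credentials is credentials_again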
model_post_init(self, context, /)

This function is meant to behave like a BaseModel method to initialise private attributes.

It takes context as an argument since that's what pydantic-core passes when calling it.

Parameters:

Name Type Description Default
self BaseModel

The BaseModel instance.

required
context Any

The context.

required
Source code in zenml/integrations/gcp/service_connectors/gcp_service_connector.py
def init_private_attributes(self: BaseModel, context: Any, /) -> None:
    """This function is meant to behave like a BaseModel method to initialise private attributes.

    It takes context as an argument since that's what pydantic-core passes when calling it.

    Args:
        self: The BaseModel instance.
        context: The context.
    """
    if getattr(self, '__pydantic_private__', None) is None:
        pydantic_private = {}
        for name, private_attr in self.__private_attributes__.items():
            default = private_attr.get_default()
            if default is not PydanticUndefined:
                pydantic_private[name] = default
        object_setattr(self, '__pydantic_private__', pydantic_private)
GCPUserAccountConfig (GCPBaseProjectIDConfig, GCPUserAccountCredentials)

GCP user account configuration.

Source code in zenml/integrations/gcp/service_connectors/gcp_service_connector.py
class GCPUserAccountConfig(GCPBaseProjectIDConfig, GCPUserAccountCredentials):
    """GCP user account configuration."""
GCPUserAccountCredentials (AuthenticationConfig)

GCP user account credentials.

Source code in zenml/integrations/gcp/service_connectors/gcp_service_connector.py
class GCPUserAccountCredentials(AuthenticationConfig):
    """GCP user account credentials."""

    user_account_json: PlainSerializedSecretStr = Field(
        title="GCP User Account Credentials JSON optionally base64 encoded.",
    )

    generate_temporary_tokens: bool = Field(
        default=True,
        title="Generate temporary OAuth 2.0 tokens",
        description="Whether to generate temporary OAuth 2.0 tokens from the "
        "user account credentials JSON. If set to False, the connector will "
        "distribute the user account credentials JSON to clients instead.",
    )

    @model_validator(mode="before")
    @classmethod
    @before_validator_handler
    def validate_user_account_dict(
        cls, data: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Convert the user account credentials to JSON if given in dict format.

        Args:
            data: The configuration values.

        Returns:
            The validated configuration values.

        Raises:
            ValueError: If the user account credentials JSON is invalid.
        """
        user_account_json = data.get("user_account_json")
        if isinstance(user_account_json, dict):
            data["user_account_json"] = json.dumps(data["user_account_json"])
        elif isinstance(user_account_json, str):
            # Check if the user account JSON is base64 encoded and decode it
            if re.match(r"^[A-Za-z0-9+/=]+$", user_account_json):
                try:
                    data["user_account_json"] = base64.b64decode(
                        user_account_json
                    ).decode("utf-8")
                except Exception as e:
                    raise ValueError(
                        f"Failed to decode base64 encoded user account JSON: {e}"
                    )
        return data

    @field_validator("user_account_json")
    @classmethod
    def validate_user_account_json(
        cls, value: PlainSerializedSecretStr
    ) -> PlainSerializedSecretStr:
        """Validate the user account credentials JSON.

        Args:
            value: The user account credentials JSON.

        Returns:
            The validated user account credentials JSON.

        Raises:
            ValueError: If the user account credentials JSON is invalid.
        """
        try:
            user_account_info = json.loads(value.get_secret_value())
        except json.JSONDecodeError as e:
            raise ValueError(
                f"GCP user account credentials is not a valid JSON: {e}"
            )

        # Check that all fields are present
        required_fields = [
            "type",
            "refresh_token",
            "client_secret",
            "client_id",
        ]
        # Compute missing fields
        missing_fields = set(required_fields) - set(user_account_info.keys())
        if missing_fields:
            raise ValueError(
                f"GCP user account credentials JSON is missing required "
                f'fields: {", ".join(list(missing_fields))}'
            )

        if user_account_info["type"] != "authorized_user":
            raise ValueError(
                "The JSON does not contain GCP user account credentials. The "
                f'"type" field is set to {user_account_info["type"]} '
                "instead of 'authorized_user'."
            )

        return value
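The before-validator above also accepts the credentials as a base64-encoded JSON string and decodes it before the field validator runs. A small sketch of that normalization, using placeholder values:

import base64
import json

from zenml.integrations.gcp.service_connectors.gcp_service_connector import (
    GCPUserAccountCredentials,
)

raw = {
    "type": "authorized_user",
    "client_id": "<client-id>",
    "client_secret": "<client-secret>",
    "refresh_token": "<refresh-token>",
}

# The base64-encoded JSON matches the validator's character-set check and is
# decoded back to the raw JSON string before `validate_user_account_json` runs.
encoded = base64.b64encode(json.dumps(raw).encode("utf-8")).decode("utf-8")
credentials = GCPUserAccountCredentials(user_account_json=encoded)

assert json.loads(credentials.user_account_json.get_secret_value()) == raw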
validate_user_account_dict(data, validation_info) classmethod

Wrapper method to handle the raw data.

Parameters:

Name Type Description Default
cls

the class handler

required
data Any

the raw input data

required
validation_info ValidationInfo

the context of the validation.

required

Returns:

Type Description
Any

the validated data

Source code in zenml/integrations/gcp/service_connectors/gcp_service_connector.py
def before_validator(
    cls: Type[BaseModel], data: Any, validation_info: ValidationInfo
) -> Any:
    """Wrapper method to handle the raw data.

    Args:
        cls: the class handler
        data: the raw input data
        validation_info: the context of the validation.

    Returns:
        the validated data
    """
    data = model_validator_data_handler(
        raw_data=data, base_class=cls, validation_info=validation_info
    )
    return method(cls=cls, data=data)
validate_user_account_json(value) classmethod

Validate the user account credentials JSON.

Parameters:

Name Type Description Default
value Annotated[pydantic.types.SecretStr, PlainSerializer(func=<lambda>, return_type=PydanticUndefined, when_used='json')]

The user account credentials JSON.

required

Returns:

Type Description
Annotated[pydantic.types.SecretStr, PlainSerializer(func=<lambda>, return_type=PydanticUndefined, when_used='json')]

The validated user account credentials JSON.

Exceptions:

Type Description
ValueError

If the user account credentials JSON is invalid.

Source code in zenml/integrations/gcp/service_connectors/gcp_service_connector.py
@field_validator("user_account_json")
@classmethod
def validate_user_account_json(
    cls, value: PlainSerializedSecretStr
) -> PlainSerializedSecretStr:
    """Validate the user account credentials JSON.

    Args:
        value: The user account credentials JSON.

    Returns:
        The validated user account credentials JSON.

    Raises:
        ValueError: If the user account credentials JSON is invalid.
    """
    try:
        user_account_info = json.loads(value.get_secret_value())
    except json.JSONDecodeError as e:
        raise ValueError(
            f"GCP user account credentials is not a valid JSON: {e}"
        )

    # Check that all fields are present
    required_fields = [
        "type",
        "refresh_token",
        "client_secret",
        "client_id",
    ]
    # Compute missing fields
    missing_fields = set(required_fields) - set(user_account_info.keys())
    if missing_fields:
        raise ValueError(
            f"GCP user account credentials JSON is missing required "
            f'fields: {", ".join(list(missing_fields))}'
        )

    if user_account_info["type"] != "authorized_user":
        raise ValueError(
            "The JSON does not contain GCP user account credentials. The "
            f'"type" field is set to {user_account_info["type"]} '
            "instead of 'authorized_user'."
        )

    return value
ZenMLAwsSecurityCredentialsSupplier (_DefaultAwsSecurityCredentialsSupplier)

An improved version of the GCP external account credential supplier for AWS.

The original GCP external account credential supplier only provides rudimentary support for extracting AWS credentials from environment variables or the AWS metadata service. This version improves on that by using the boto3 library itself (if available), which uses the entire range of implicit authentication features packed into it.

Without this improvement, sts.AssumeRoleWithWebIdentity authentication is not supported for EKS pods and the EC2 attached role credentials are used instead (see: https://medium.com/@derek10cloud/gcp-workload-identity-federation-doesnt-yet-support-eks-irsa-in-aws-a3c71877671a).

Source code in zenml/integrations/gcp/service_connectors/gcp_service_connector.py
class ZenMLAwsSecurityCredentialsSupplier(
    _DefaultAwsSecurityCredentialsSupplier  # type: ignore[misc]
):
    """An improved version of the GCP external account credential supplier for AWS.

    The original GCP external account credential supplier only provides
    rudimentary support for extracting AWS credentials from environment
    variables or the AWS metadata service. This version improves on that by
    using the boto3 library itself (if available), which uses the entire range
    of implicit authentication features packed into it.

    Without this improvement, `sts.AssumeRoleWithWebIdentity` authentication is
    not supported for EKS pods and the EC2 attached role credentials are
    used instead (see: https://medium.com/@derek10cloud/gcp-workload-identity-federation-doesnt-yet-support-eks-irsa-in-aws-a3c71877671a).
    """

    def get_aws_security_credentials(
        self, context: Any, request: Any
    ) -> gcp_aws.AwsSecurityCredentials:
        """Get the security credentials from the local environment.

        This method is a copy of the original method from the
        `google.auth.aws._DefaultAwsSecurityCredentialsSupplier` class. It has
        been modified to use the boto3 library to extract the AWS credentials
        from the local environment.

        Args:
            context: The context to use to get the security credentials.
            request: The request to use to get the security credentials.

        Returns:
            The AWS temporary security credentials.
        """
        try:
            import boto3

            session = boto3.Session()
            credentials = session.get_credentials()
            if credentials is not None:
                creds = credentials.get_frozen_credentials()
                return gcp_aws.AwsSecurityCredentials(
                    creds.access_key,
                    creds.secret_key,
                    creds.token,
                )
        except ImportError:
            pass

        logger.debug(
            "Failed to extract AWS credentials from the local environment "
            "using the boto3 library. Falling back to the original "
            "implementation."
        )

        return super().get_aws_security_credentials(context, request)

    def get_aws_region(self, context: Any, request: Any) -> str:
        """Get the AWS region from the local environment.

        This method is a copy of the original method from the
        `google.auth.aws._DefaultAwsSecurityCredentialsSupplier` class. It has
        been modified to use the boto3 library to extract the AWS
        region from the local environment.

        Args:
            context: The context to use to get the security credentials.
            request: The request to use to get the security credentials.

        Returns:
            The AWS region.
        """
        try:
            import boto3

            session = boto3.Session()
            if session.region_name:
                return session.region_name  # type: ignore[no-any-return]
        except ImportError:
            pass

        logger.debug(
            "Failed to extract AWS region from the local environment "
            "using the boto3 library. Falling back to the original "
            "implementation."
        )

        return super().get_aws_region(  # type: ignore[no-any-return]
            context, request
        )
get_aws_region(self, context, request)

Get the AWS region from the local environment.

This method is a copy of the original method from the google.auth.aws._DefaultAwsSecurityCredentialsSupplier class. It has been modified to use the boto3 library to extract the AWS region from the local environment.

Parameters:

Name Type Description Default
context Any

The context to use to get the security credentials.

required
request Any

The request to use to get the security credentials.

required

Returns:

Type Description
str

The AWS region.

Source code in zenml/integrations/gcp/service_connectors/gcp_service_connector.py
def get_aws_region(self, context: Any, request: Any) -> str:
    """Get the AWS region from the local environment.

    This method is a copy of the original method from the
    `google.auth.aws._DefaultAwsSecurityCredentialsSupplier` class. It has
    been modified to use the boto3 library to extract the AWS
    region from the local environment.

    Args:
        context: The context to use to get the security credentials.
        request: The request to use to get the security credentials.

    Returns:
        The AWS region.
    """
    try:
        import boto3

        session = boto3.Session()
        if session.region_name:
            return session.region_name  # type: ignore[no-any-return]
    except ImportError:
        pass

    logger.debug(
        "Failed to extract AWS region from the local environment "
        "using the boto3 library. Falling back to the original "
        "implementation."
    )

    return super().get_aws_region(  # type: ignore[no-any-return]
        context, request
    )
get_aws_security_credentials(self, context, request)

Get the security credentials from the local environment.

This method is a copy of the original method from the google.auth.aws._DefaultAwsSecurityCredentialsSupplier class. It has been modified to use the boto3 library to extract the AWS credentials from the local environment.

Parameters:

Name Type Description Default
context Any

The context to use to get the security credentials.

required
request Any

The request to use to get the security credentials.

required

Returns:

Type Description
google.auth.aws.AwsSecurityCredentials

The AWS temporary security credentials.

Source code in zenml/integrations/gcp/service_connectors/gcp_service_connector.py
def get_aws_security_credentials(
    self, context: Any, request: Any
) -> gcp_aws.AwsSecurityCredentials:
    """Get the security credentials from the local environment.

    This method is a copy of the original method from the
    `google.auth.aws._DefaultAwsSecurityCredentialsSupplier` class. It has
    been modified to use the boto3 library to extract the AWS credentials
    from the local environment.

    Args:
        context: The context to use to get the security credentials.
        request: The request to use to get the security credentials.

    Returns:
        The AWS temporary security credentials.
    """
    try:
        import boto3

        session = boto3.Session()
        credentials = session.get_credentials()
        if credentials is not None:
            creds = credentials.get_frozen_credentials()
            return gcp_aws.AwsSecurityCredentials(
                creds.access_key,
                creds.secret_key,
                creds.token,
            )
    except ImportError:
        pass

    logger.debug(
        "Failed to extract AWS credentials from the local environment "
        "using the boto3 library. Falling back to the original "
        "implementation."
    )

    return super().get_aws_security_credentials(context, request)
ZenMLGCPAWSExternalAccountCredentials (Credentials)

An improved version of the GCP external account credential for AWS.

The original GCP external account credential only provides rudimentary support for extracting AWS credentials from environment variables or the AWS metadata service. This version improves on that by using the boto3 library itself (if available), which uses the entire range of implicit authentication features packed into it.

Without this improvement, sts.AssumeRoleWithWebIdentity authentication is not supported for EKS pods and the EC2 attached role credentials are used instead (see: https://medium.com/@derek10cloud/gcp-workload-identity-federation-doesnt-yet-support-eks-irsa-in-aws-a3c71877671a).

IMPORTANT: subclassing this class only works with the google-auth library version lower than 2.29.0. Starting from version 2.29.0, the AWS logic has been moved to a separate google.auth.aws._DefaultAwsSecurityCredentialsSupplier class that can be subclassed instead and supplied as the aws_security_credentials_supplier parameter to the google.auth.aws.Credentials class.

Source code in zenml/integrations/gcp/service_connectors/gcp_service_connector.py
class ZenMLGCPAWSExternalAccountCredentials(gcp_aws.Credentials):  # type: ignore[misc]
    """An improved version of the GCP external account credential for AWS.

    The original GCP external account credential only provides rudimentary
    support for extracting AWS credentials from environment variables or the
    AWS metadata service. This version improves on that by using the boto3
    library itself (if available), which uses the entire range of implicit
    authentication features packed into it.

    Without this improvement, `sts.AssumeRoleWithWebIdentity` authentication is
    not supported for EKS pods and the EC2 attached role credentials are
    used instead (see: https://medium.com/@derek10cloud/gcp-workload-identity-federation-doesnt-yet-support-eks-irsa-in-aws-a3c71877671a).

    IMPORTANT: subclassing this class only works with the `google-auth` library
    version lower than 2.29.0. Starting from version 2.29.0, the AWS logic
    has been moved to a separate `google.auth.aws._DefaultAwsSecurityCredentialsSupplier`
    class that can be subclassed instead and supplied as the
    `aws_security_credentials_supplier` parameter to the
    `google.auth.aws.Credentials` class.
    """

    def _get_security_credentials(
        self, request: Any, imdsv2_session_token: Any
    ) -> Dict[str, Any]:
        """Get the security credentials from the local environment.

        This method is a copy of the original method from the
        `google.auth._default` module. It has been modified to use the boto3
        library to extract the AWS credentials from the local environment.

        Args:
            request: The request to use to get the security credentials.
            imdsv2_session_token: The IMDSv2 session token to use to get the
                security credentials.

        Returns:
            The AWS temporary security credentials.
        """
        try:
            import boto3

            session = boto3.Session()
            credentials = session.get_credentials()
            if credentials is not None:
                creds = credentials.get_frozen_credentials()
                return {
                    "access_key_id": creds.access_key,
                    "secret_access_key": creds.secret_key,
                    "security_token": creds.token,
                }
        except ImportError:
            pass

        logger.debug(
            "Failed to extract AWS credentials from the local environment "
            "using the boto3 library. Falling back to the original "
            "implementation."
        )

        return super()._get_security_credentials(  # type: ignore[no-any-return]
            request, imdsv2_session_token
        )
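Given the version constraint spelled out in the docstring, a caller would pick between the two helpers based on the installed google-auth release. The sketch below is one possible way to express that choice and is an assumption, not code taken from the connector.

import google.auth
from packaging import version

from zenml.integrations.gcp.service_connectors.gcp_service_connector import (
    ZenMLAwsSecurityCredentialsSupplier,
    ZenMLGCPAWSExternalAccountCredentials,
)

# Assumed selection logic: on google-auth releases older than 2.29.0 the
# Credentials subclass is the viable hook, while newer releases expect the
# supplier to be passed to `google.auth.aws.Credentials` via the
# `aws_security_credentials_supplier` parameter.
if version.parse(google.auth.__version__) < version.parse("2.29.0"):
    aws_credentials_helper = ZenMLGCPAWSExternalAccountCredentials
else:
    aws_credentials_helper = ZenMLAwsSecurityCredentialsSupplier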

step_operators special

Initialization for the VertexAI Step Operator.

vertex_step_operator

Implementation of a VertexAI step operator.

Code heavily inspired by TFX Implementation: https://github.com/tensorflow/tfx/blob/master/tfx/extensions/google_cloud_ai_platform/training_clients.py

VertexStepOperator (BaseStepOperator, GoogleCredentialsMixin)

Step operator to run a step on Vertex AI.

This class defines code that can set up a Vertex AI environment and run the ZenML entrypoint command in it.

Source code in zenml/integrations/gcp/step_operators/vertex_step_operator.py
class VertexStepOperator(BaseStepOperator, GoogleCredentialsMixin):
    """Step operator to run a step on Vertex AI.

    This class defines code that can set up a Vertex AI environment and run the
    ZenML entrypoint command in it.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Initializes the step operator and validates the accelerator type.

        Args:
            *args: Variable length argument list.
            **kwargs: Arbitrary keyword arguments.
        """
        super().__init__(*args, **kwargs)

    @property
    def config(self) -> VertexStepOperatorConfig:
        """Returns the `VertexStepOperatorConfig` config.

        Returns:
            The configuration.
        """
        return cast(VertexStepOperatorConfig, self._config)

    @property
    def settings_class(self) -> Optional[Type["BaseSettings"]]:
        """Settings class for the Vertex step operator.

        Returns:
            The settings class.
        """
        return VertexStepOperatorSettings

    @property
    def validator(self) -> Optional[StackValidator]:
        """Validates the stack.

        Returns:
            A validator that checks that the stack contains a remote container
            registry and a remote artifact store.
        """

        def _validate_remote_components(stack: "Stack") -> Tuple[bool, str]:
            if stack.artifact_store.config.is_local:
                return False, (
                    "The Vertex step operator runs code remotely and "
                    "needs to write files into the artifact store, but the "
                    f"artifact store `{stack.artifact_store.name}` of the "
                    "active stack is local. Please ensure that your stack "
                    "contains a remote artifact store when using the Vertex "
                    "step operator."
                )

            container_registry = stack.container_registry
            assert container_registry is not None

            if container_registry.config.is_local:
                return False, (
                    "The Vertex step operator runs code remotely and "
                    "needs to push/pull Docker images, but the "
                    f"container registry `{container_registry.name}` of the "
                    "active stack is local. Please ensure that your stack "
                    "contains a remote container registry when using the "
                    "Vertex step operator."
                )

            return True, ""

        return StackValidator(
            required_components={
                StackComponentType.CONTAINER_REGISTRY,
                StackComponentType.IMAGE_BUILDER,
            },
            custom_validation_function=_validate_remote_components,
        )

    def get_docker_builds(
        self, deployment: "PipelineDeploymentBase"
    ) -> List["BuildConfiguration"]:
        """Gets the Docker builds required for the component.

        Args:
            deployment: The pipeline deployment for which to get the builds.

        Returns:
            The required Docker builds.
        """
        builds = []
        for step_name, step in deployment.step_configurations.items():
            if step.config.step_operator == self.name:
                build = BuildConfiguration(
                    key=VERTEX_DOCKER_IMAGE_KEY,
                    settings=step.config.docker_settings,
                    step_name=step_name,
                )
                builds.append(build)

        return builds

    def launch(
        self,
        info: "StepRunInfo",
        entrypoint_command: List[str],
        environment: Dict[str, str],
    ) -> None:
        """Launches a step on VertexAI.

        Args:
            info: Information about the step run.
            entrypoint_command: Command that executes the step.
            environment: Environment variables to set in the step operator
                environment.

        Raises:
            RuntimeError: If the run fails.
        """
        resource_settings = info.config.resource_settings
        if resource_settings.cpu_count or resource_settings.memory:
            logger.warning(
                "Specifying cpus or memory is not supported for "
                "the Vertex step operator. If you want to run this step "
                "operator on specific resources, you can do so by configuring "
                "a different machine_type type like this: "
                "`zenml step-operator update %s "
                "--machine_type=<MACHINE_TYPE>`",
                self.name,
            )
        settings = cast(VertexStepOperatorSettings, self.get_settings(info))
        validate_accelerator_type(settings.accelerator_type)

        job_labels = {"source": f"zenml-{__version__.replace('.', '_')}"}

        # Step 1: Authenticate with Google
        credentials, project_id = self._get_authentication()

        # Step 2: Resolve the Docker image built for this step
        image_name = info.get_image(key=VERTEX_DOCKER_IMAGE_KEY)

        # Step 3: Launch the job
        # The AI Platform services require regional API endpoints.
        client_options = {
            "api_endpoint": self.config.region + VERTEX_ENDPOINT_SUFFIX
        }
        # Initialize client that will be used to create and send requests.
        # This client only needs to be created once, and can be reused for
        # multiple requests.
        client = aiplatform.gapic.JobServiceClient(
            credentials=credentials, client_options=client_options
        )
        accelerator_count = (
            resource_settings.gpu_count or settings.accelerator_count
        )
        custom_job = {
            "display_name": info.run_name,
            "job_spec": {
                "worker_pool_specs": [
                    {
                        "machine_spec": {
                            "machine_type": settings.machine_type,
                            "accelerator_type": settings.accelerator_type,
                            "accelerator_count": accelerator_count
                            if settings.accelerator_type
                            else 0,
                        },
                        "replica_count": 1,
                        "container_spec": {
                            "image_uri": image_name,
                            "command": entrypoint_command,
                            "args": [],
                            "env": [
                                {"name": key, "value": value}
                                for key, value in environment.items()
                            ],
                        },
                        "disk_spec": {
                            "boot_disk_type": settings.boot_disk_type,
                            "boot_disk_size_gb": settings.boot_disk_size_gb,
                        },
                    }
                ],
                "service_account": self.config.service_account,
                "network": self.config.network,
                "reserved_ip_ranges": (
                    self.config.reserved_ip_ranges.split(",")
                    if self.config.reserved_ip_ranges
                    else []
                ),
            },
            "labels": job_labels,
            "encryption_spec": {
                "kmsKeyName": self.config.encryption_spec_key_name
            }
            if self.config.encryption_spec_key_name
            else {},
        }
        logger.debug("Vertex AI Job=%s", custom_job)

        parent = f"projects/{project_id}/locations/{self.config.region}"
        logger.info(
            "Submitting custom job='%s', path='%s' to Vertex AI Training.",
            custom_job["display_name"],
            parent,
        )
        info.force_write_logs()
        response = client.create_custom_job(
            parent=parent, custom_job=custom_job
        )
        logger.debug("Vertex AI response:", response)

        # Step 4: Monitor the job

        # Monitors the long-running operation by polling the job state
        # periodically, and retries the polling when a transient connectivity
        # issue is encountered.
        #
        # Long-running operation monitoring:
        #   The possible states of "get job" response can be found at
        #   https://cloud.google.com/ai-platform/training/docs/reference/rest/v1/projects.jobs#State
        #   where SUCCEEDED/FAILED/CANCELED are considered to be final states.
        #   The following logic will keep polling the state of the job until
        #   the job enters a final state.
        #
        # During the polling, if a connection error was encountered, the GET
        # request will be retried by recreating the Python API client to
        # refresh the lifecycle of the connection being used. See
        # https://github.com/googleapis/google-api-python-client/issues/218
        # for a detailed description of the problem. If the error persists for
        # CONNECTION_ERROR_RETRY_LIMIT consecutive attempts, the function
        # will raise RuntimeError.
        retry_count = 0
        job_id = response.name

        while response.state not in VERTEX_JOB_STATES_COMPLETED:
            time.sleep(POLLING_INTERVAL_IN_SECONDS)
            try:
                response = client.get_custom_job(name=job_id)
                retry_count = 0
            # Handle transient connection errors and credential expiration by
            # recreating the Python API client.
            except (ConnectionError, ServerError) as err:
                if retry_count < CONNECTION_ERROR_RETRY_LIMIT:
                    retry_count += 1
                    logger.warning(
                        f"Error encountered when polling job "
                        f"{job_id}: {err}\nRetrying...",
                    )
                    # This call will refresh the credentials if they expired.
                    credentials, project_id = self._get_authentication()
                    # Recreate the Python API client.
                    client = aiplatform.gapic.JobServiceClient(
                        credentials=credentials, client_options=client_options
                    )
                else:
                    logger.exception(
                        "Request failed after %s retries.",
                        CONNECTION_ERROR_RETRY_LIMIT,
                    )
                    raise RuntimeError(
                        f"Request failed after {CONNECTION_ERROR_RETRY_LIMIT} "
                        f"retries: {err}"
                    )
            if response.state in VERTEX_JOB_STATES_FAILED:
                err_msg = (
                    "Job '{}' did not succeed.  Detailed response {}.".format(
                        job_id, response
                    )
                )
                logger.error(err_msg)
                raise RuntimeError(err_msg)

        # Cloud training complete
        logger.info("Job '%s' successful.", job_id)
config: VertexStepOperatorConfig property readonly

Returns the VertexStepOperatorConfig config.

Returns:

Type Description
VertexStepOperatorConfig

The configuration.

settings_class: Optional[Type[BaseSettings]] property readonly

Settings class for the Vertex step operator.

Returns:

Type Description
Optional[Type[BaseSettings]]

The settings class.
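
As an illustration, the sketch below attaches these settings to a single step. The import path and the "step_operator" settings key are assumptions based on ZenML's usual flavor layout (the key may also be written as step_operator.<flavor>); the field names mirror those used in the launch() source further down.

from zenml import step
from zenml.integrations.gcp.flavors import (  # assumed export location
    VertexStepOperatorSettings,
)

vertex_settings = VertexStepOperatorSettings(
    machine_type="n1-standard-8",
    accelerator_type="NVIDIA_TESLA_T4",
    accelerator_count=1,
    boot_disk_size_gb=200,
)


@step(step_operator="vertex", settings={"step_operator": vertex_settings})
def train() -> None:
    ...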

validator: Optional[zenml.stack.stack_validator.StackValidator] property readonly

Validates the stack.

Returns:

Type Description
Optional[zenml.stack.stack_validator.StackValidator]

A validator that checks that the stack contains a remote container registry and a remote artifact store.

__init__(self, *args, **kwargs) special

Initializes the step operator.

Parameters:

Name Type Description Default
*args Any

Variable length argument list.

()
**kwargs Any

Arbitrary keyword arguments.

{}
Source code in zenml/integrations/gcp/step_operators/vertex_step_operator.py
def __init__(self, *args: Any, **kwargs: Any) -> None:
    """Initializes the step operator and validates the accelerator type.

    Args:
        *args: Variable length argument list.
        **kwargs: Arbitrary keyword arguments.
    """
    super().__init__(*args, **kwargs)
get_docker_builds(self, deployment)

Gets the Docker builds required for the component.

Parameters:

Name Type Description Default
deployment PipelineDeploymentBase

The pipeline deployment for which to get the builds.

required

Returns:

Type Description
List[BuildConfiguration]

The required Docker builds.

Source code in zenml/integrations/gcp/step_operators/vertex_step_operator.py
def get_docker_builds(
    self, deployment: "PipelineDeploymentBase"
) -> List["BuildConfiguration"]:
    """Gets the Docker builds required for the component.

    Args:
        deployment: The pipeline deployment for which to get the builds.

    Returns:
        The required Docker builds.
    """
    builds = []
    for step_name, step in deployment.step_configurations.items():
        if step.config.step_operator == self.name:
            build = BuildConfiguration(
                key=VERTEX_DOCKER_IMAGE_KEY,
                settings=step.config.docker_settings,
                step_name=step_name,
            )
            builds.append(build)

    return builds
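
The docker_settings picked up here come from each step's configuration. Below is a hedged sketch of how per-step Docker settings could be attached so that a dedicated image is built for the Vertex-executed step; the extra requirement is purely illustrative.

from zenml import step
from zenml.config import DockerSettings

# Illustrative per-step Docker settings; the resulting image is what
# get_docker_builds() registers under VERTEX_DOCKER_IMAGE_KEY.
docker_settings = DockerSettings(requirements=["scikit-learn"])


@step(step_operator="vertex", settings={"docker": docker_settings})
def train() -> None:
    ...
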
launch(self, info, entrypoint_command, environment)

Launches a step on VertexAI.

Parameters:

Name Type Description Default
info StepRunInfo

Information about the step run.

required
entrypoint_command List[str]

Command that executes the step.

required
environment Dict[str, str]

Environment variables to set in the step operator environment.

required

Exceptions:

Type Description
RuntimeError

If the run fails.

Source code in zenml/integrations/gcp/step_operators/vertex_step_operator.py
def launch(
    self,
    info: "StepRunInfo",
    entrypoint_command: List[str],
    environment: Dict[str, str],
) -> None:
    """Launches a step on VertexAI.

    Args:
        info: Information about the step run.
        entrypoint_command: Command that executes the step.
        environment: Environment variables to set in the step operator
            environment.

    Raises:
        RuntimeError: If the run fails.
    """
    resource_settings = info.config.resource_settings
    if resource_settings.cpu_count or resource_settings.memory:
        logger.warning(
            "Specifying cpus or memory is not supported for "
            "the Vertex step operator. If you want to run this step "
            "operator on specific resources, you can do so by configuring "
            "a different machine_type type like this: "
            "`zenml step-operator update %s "
            "--machine_type=<MACHINE_TYPE>`",
            self.name,
        )
    settings = cast(VertexStepOperatorSettings, self.get_settings(info))
    validate_accelerator_type(settings.accelerator_type)

    job_labels = {"source": f"zenml-{__version__.replace('.', '_')}"}

    # Step 1: Authenticate with Google
    credentials, project_id = self._get_authentication()

    # Step 2: Resolve the Docker image built for this step
    image_name = info.get_image(key=VERTEX_DOCKER_IMAGE_KEY)

    # Step 3: Launch the job
    # The AI Platform services require regional API endpoints.
    client_options = {
        "api_endpoint": self.config.region + VERTEX_ENDPOINT_SUFFIX
    }
    # Initialize client that will be used to create and send requests.
    # This client only needs to be created once, and can be reused for
    # multiple requests.
    client = aiplatform.gapic.JobServiceClient(
        credentials=credentials, client_options=client_options
    )
    accelerator_count = (
        resource_settings.gpu_count or settings.accelerator_count
    )
    custom_job = {
        "display_name": info.run_name,
        "job_spec": {
            "worker_pool_specs": [
                {
                    "machine_spec": {
                        "machine_type": settings.machine_type,
                        "accelerator_type": settings.accelerator_type,
                        "accelerator_count": accelerator_count
                        if settings.accelerator_type
                        else 0,
                    },
                    "replica_count": 1,
                    "container_spec": {
                        "image_uri": image_name,
                        "command": entrypoint_command,
                        "args": [],
                        "env": [
                            {"name": key, "value": value}
                            for key, value in environment.items()
                        ],
                    },
                    "disk_spec": {
                        "boot_disk_type": settings.boot_disk_type,
                        "boot_disk_size_gb": settings.boot_disk_size_gb,
                    },
                }
            ],
            "service_account": self.config.service_account,
            "network": self.config.network,
            "reserved_ip_ranges": (
                self.config.reserved_ip_ranges.split(",")
                if self.config.reserved_ip_ranges
                else []
            ),
        },
        "labels": job_labels,
        "encryption_spec": {
            "kmsKeyName": self.config.encryption_spec_key_name
        }
        if self.config.encryption_spec_key_name
        else {},
    }
    logger.debug("Vertex AI Job=%s", custom_job)

    parent = f"projects/{project_id}/locations/{self.config.region}"
    logger.info(
        "Submitting custom job='%s', path='%s' to Vertex AI Training.",
        custom_job["display_name"],
        parent,
    )
    info.force_write_logs()
    response = client.create_custom_job(
        parent=parent, custom_job=custom_job
    )
    logger.debug("Vertex AI response:", response)

    # Step 4: Monitor the job

    # Monitors the long-running operation by polling the job state
    # periodically, and retries the polling when a transient connectivity
    # issue is encountered.
    #
    # Long-running operation monitoring:
    #   The possible states of "get job" response can be found at
    #   https://cloud.google.com/ai-platform/training/docs/reference/rest/v1/projects.jobs#State
    #   where SUCCEEDED/FAILED/CANCELED are considered to be final states.
    #   The following logic will keep polling the state of the job until
    #   the job enters a final state.
    #
    # During the polling, if a connection error was encountered, the GET
    # request will be retried by recreating the Python API client to
    # refresh the lifecycle of the connection being used. See
    # https://github.com/googleapis/google-api-python-client/issues/218
    # for a detailed description of the problem. If the error persists for
    # CONNECTION_ERROR_RETRY_LIMIT consecutive attempts, the function
    # will raise RuntimeError.
    retry_count = 0
    job_id = response.name

    while response.state not in VERTEX_JOB_STATES_COMPLETED:
        time.sleep(POLLING_INTERVAL_IN_SECONDS)
        try:
            response = client.get_custom_job(name=job_id)
            retry_count = 0
        # Handle transient connection errors and credential expiration by
        # recreating the Python API client.
        except (ConnectionError, ServerError) as err:
            if retry_count < CONNECTION_ERROR_RETRY_LIMIT:
                retry_count += 1
                logger.warning(
                    f"Error encountered when polling job "
                    f"{job_id}: {err}\nRetrying...",
                )
                # This call will refresh the credentials if they expired.
                credentials, project_id = self._get_authentication()
                # Recreate the Python API client.
                client = aiplatform.gapic.JobServiceClient(
                    credentials=credentials, client_options=client_options
                )
            else:
                logger.exception(
                    "Request failed after %s retries.",
                    CONNECTION_ERROR_RETRY_LIMIT,
                )
                raise RuntimeError(
                    f"Request failed after {CONNECTION_ERROR_RETRY_LIMIT} "
                    f"retries: {err}"
                )
        if response.state in VERTEX_JOB_STATES_FAILED:
            err_msg = (
                "Job '{}' did not succeed.  Detailed response {}.".format(
                    job_id, response
                )
            )
            logger.error(err_msg)
            raise RuntimeError(err_msg)

    # Cloud training complete
    logger.info("Job '%s' successful.", job_id)
validate_accelerator_type(accelerator_type=None)

Validates that the accelerator type is valid.

Parameters:

Name Type Description Default
accelerator_type Optional[str]

The accelerator type to validate.

None

Exceptions:

Type Description
ValueError

If the accelerator type is not valid.

Source code in zenml/integrations/gcp/step_operators/vertex_step_operator.py
def validate_accelerator_type(accelerator_type: Optional[str] = None) -> None:
    """Validates that the accelerator type is valid.

    Args:
        accelerator_type: The accelerator type to validate.

    Raises:
        ValueError: If the accelerator type is not valid.
    """
    accepted_vals = list(aiplatform.gapic.AcceleratorType.__members__.keys())
    if accelerator_type and accelerator_type.upper() not in accepted_vals:
        raise ValueError(
            f"Accelerator must be one of the following: {accepted_vals}"
        )
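
For example, accepted values are the member names of aiplatform.gapic.AcceleratorType, and the check is case-insensitive because the input is upper-cased first:

from zenml.integrations.gcp.step_operators.vertex_step_operator import (
    validate_accelerator_type,
)

# Passes: NVIDIA_TESLA_T4 is a member of aiplatform.gapic.AcceleratorType.
validate_accelerator_type("nvidia_tesla_t4")

# Also passes: None means no accelerator was requested.
validate_accelerator_type(None)

# Raises ValueError: not a known accelerator type.
validate_accelerator_type("tpu_v99")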