Skip to content

Zen Server

zenml.zen_server special

ZenML Server Implementation.

The ZenML Server is a centralized service meant for use in a collaborative setting in which stacks, stack components, flavors, pipelines and pipeline runs can be shared over the network with other users.

You can use the zenml server up command to spin up ZenML server instances that are either running locally as daemon processes or docker containers, or to deploy a ZenML server remotely on a managed cloud platform. The other CLI commands in the same zenml server group can be used to manage the server instances deployed from your local machine.

To connect the local ZenML client to one of the managed ZenML servers, call zenml server connect with the name of the server you want to connect to.

auth

Authentication module for ZenML server.

AuthContext (BaseModel)

The authentication context.

Source code in zenml/zen_server/auth.py
class AuthContext(BaseModel):
    """The authentication context.

    Bundles the authenticated account with the credential material that was
    used to authenticate the current request. Which of the optional fields
    are populated depends on the authentication path that produced it.
    """

    # The authenticated account. Service accounts authenticated with an API
    # key are converted to a UserResponse before being stored here.
    user: UserResponse
    # The decoded JWT access token, set when authenticating with a token.
    access_token: Optional[JWTToken] = None
    # The raw (still encoded) access token string, set alongside access_token.
    encoded_access_token: Optional[str] = None
    # The OAuth2 authorized device, set for device-based authentication.
    device: Optional[OAuthDeviceInternalResponse] = None
    # The service account API key, set for API-key-based authentication.
    api_key: Optional[APIKeyInternalResponse] = None

CookieOAuth2TokenBearer (OAuth2PasswordBearer)

OAuth2 token bearer authentication scheme that uses a cookie.

Source code in zenml/zen_server/auth.py
class CookieOAuth2TokenBearer(OAuth2PasswordBearer):
    """OAuth2 token bearer authentication scheme that uses a cookie.

    The token is looked up in the server's auth cookie first. If the cookie
    is absent or empty, the regular ``OAuth2PasswordBearer`` handling of the
    Authorization header is used as a fallback.
    """

    async def __call__(self, request: Request) -> Optional[str]:
        """Extract the bearer token from the request.

        Args:
            request: The request.

        Returns:
            The bearer token extracted from the request cookie or header.
        """
        # First, try to get the token from the cookie
        authorization = request.cookies.get(
            server_config().get_auth_cookie_name()
        )
        if authorization:
            # This runs for every cookie-authenticated request, so keep it at
            # debug level to avoid flooding the server logs.
            logger.debug("Got token from cookie")
            return authorization

        # If the token is not present in the cookie, try to get it from the
        # Authorization header
        return await super().__call__(request)
__call__(self, request) async special

Extract the bearer token from the request.

Parameters:

Name Type Description Default
request Request

The request.

required

Returns:

Type Description
Optional[str]

The bearer token extracted from the request cookie or header.

Source code in zenml/zen_server/auth.py
async def __call__(self, request: Request) -> Optional[str]:
    """Extract the bearer token from the request.

    The server's auth cookie is checked first. If it is absent or empty, the
    regular ``OAuth2PasswordBearer`` handling of the Authorization header is
    used as a fallback.

    Args:
        request: The request.

    Returns:
        The bearer token extracted from the request cookie or header.
    """
    # First, try to get the token from the cookie
    authorization = request.cookies.get(
        server_config().get_auth_cookie_name()
    )
    if authorization:
        # This runs for every cookie-authenticated request, so keep it at
        # debug level to avoid flooding the server logs.
        logger.debug("Got token from cookie")
        return authorization

    # If the token is not present in the cookie, try to get it from the
    # Authorization header
    return await super().__call__(request)

authenticate_api_key(api_key)

Implement service account API key authentication.

Parameters:

Name Type Description Default
api_key str

The service account API key.

required

Returns:

Type Description
AuthContext

The authentication context reflecting the authenticated service account.

Exceptions:

Type Description
AuthorizationException

If the service account could not be authorized.

Source code in zenml/zen_server/auth.py
def authenticate_api_key(
    api_key: str,
) -> AuthContext:
    """Authenticate a request using a service account API key.

    The key is decoded, then looked up and verified. The owning service
    account is returned as the authenticated account.

    Args:
        api_key: The service account API key.

    Returns:
        The authentication context reflecting the authenticated service account.

    Raises:
        AuthorizationException: If the service account could not be authorized.
    """
    try:
        decoded = APIKey.decode_api_key(api_key)
    except ValueError:
        message = "Authentication error: error decoding API key"
        logger.exception(message)
        raise AuthorizationException(message)

    verified_key = _fetch_and_verify_api_key(
        api_key_id=decoded.id, key_to_verify=decoded.key
    )

    # Much of the code base still expects the authenticated account to be a
    # UserResponse, which is a superset of ServiceAccountResponse, so the
    # service account is converted into a user model before it is returned.
    return AuthContext(
        user=verified_key.service_account.to_user_model(),
        api_key=verified_key,
    )

authenticate_credentials(user_name_or_id=None, password=None, access_token=None, activation_token=None)

Verify if user authentication credentials are valid.

This function can be used to validate all supplied user credentials to cover a range of possibilities:

  • username only - only when the no-auth scheme is used
  • username+password - for basic HTTP authentication or the OAuth2 password grant
  • access token (with embedded user id) - after successful authentication using one of the supported grants
  • username+activation token - for user activation

Parameters:

Name Type Description Default
user_name_or_id Union[str, uuid.UUID]

The username or user ID.

None
password Optional[str]

The password.

None
access_token Optional[str]

The access token.

None
activation_token Optional[str]

The activation token.

None

Returns:

Type Description
AuthContext

The authenticated account details.

Exceptions:

Type Description
AuthorizationException

If the credentials are invalid.

Source code in zenml/zen_server/auth.py
def authenticate_credentials(
    user_name_or_id: Optional[Union[str, UUID]] = None,
    password: Optional[str] = None,
    access_token: Optional[str] = None,
    activation_token: Optional[str] = None,
) -> AuthContext:
    """Verify if user authentication credentials are valid.

    This function can be used to validate all supplied user credentials to
    cover a range of possibilities:

     * username only - only when the no-auth scheme is used
     * username+password - for basic HTTP authentication or the OAuth2 password
       grant
     * access token (with embedded user id) - after successful authentication
       using one of the supported grants
     * username+activation token - for user activation

    At most one credential is checked per call, in this order of precedence:
    password, activation token, access token. When an access token is
    checked, the returned context is built from the account embedded in the
    token, not from ``user_name_or_id``.

    Args:
        user_name_or_id: The username or user ID.
        password: The password.
        access_token: The access token.
        activation_token: The activation token.

    Returns:
        The authenticated account details.

    Raises:
        AuthorizationException: If the credentials are invalid.
    """
    # `user` is the internal auth model used to check a password or an
    # activation token; `auth_context` is the value ultimately returned.
    user: Optional[UserAuthModel] = None
    auth_context: Optional[AuthContext] = None
    if user_name_or_id:
        try:
            # NOTE: this method will not return a user if the user name or ID
            # identifies a service account instead of a regular user. This
            # is intentional because service accounts are not allowed to
            # be used to authenticate to the API using a username and password,
            # or an activation token.
            user = zen_store().get_auth_user(user_name_or_id)
            user_model = zen_store().get_user(
                user_name_or_id=user_name_or_id, include_private=True
            )
            auth_context = AuthContext(user=user_model)
        except KeyError:
            # even when the user does not exist, we still want to execute the
            # password/token verification to protect against response discrepancy
            # attacks (https://cwe.mitre.org/data/definitions/204.html)
            logger.exception(
                f"Authentication error: error retrieving account "
                f"{user_name_or_id}"
            )
            pass

    if password is not None:
        # `user` may be None here if the lookup above failed. The password is
        # still verified against it, so the same code path runs whether or
        # not the account exists.
        # NOTE(review): presumably verify_password returns False for a None
        # user - confirm in UserAuthModel.
        if not UserAuthModel.verify_password(password, user):
            error = "Authentication error: invalid username or password"
            logger.error(error)
            raise AuthorizationException(error)
        # The active check only runs after a successful password check.
        if user and not user.active:
            error = f"Authentication error: user {user.name} is not active"
            logger.error(error)
            raise AuthorizationException(error)

    elif activation_token is not None:
        # Unlike the password branch, no `active` check is performed here.
        if not UserAuthModel.verify_activation_token(activation_token, user):
            error = (
                f"Authentication error: invalid activation token for user "
                f"{user_name_or_id}"
            )
            logger.error(error)
            raise AuthorizationException(error)

    elif access_token is not None:
        try:
            decoded_token = JWTToken.decode_token(
                token=access_token,
            )
        except AuthorizationException as e:
            error = f"Authentication error: error decoding access token: {e}."
            logger.exception(error)
            raise AuthorizationException(error)

        # The account is resolved from the user ID embedded in the token.
        try:
            user_model = zen_store().get_user(
                user_name_or_id=decoded_token.user_id, include_private=True
            )
        except KeyError:
            error = (
                f"Authentication error: error retrieving token account "
                f"{decoded_token.user_id}"
            )
            logger.error(error)
            raise AuthorizationException(error)

        if not user_model.active:
            error = (
                f"Authentication error: account {user_model.name} is not "
                f"active"
            )
            logger.error(error)
            raise AuthorizationException(error)

        api_key_model: Optional[APIKeyInternalResponse] = None
        if decoded_token.api_key_id:
            # The API token was generated from an API key. We still have to
            # verify if the API key hasn't been deactivated or deleted in the
            # meantime.
            api_key_model = _fetch_and_verify_api_key(decoded_token.api_key_id)

        device_model: Optional[OAuthDeviceInternalResponse] = None
        if decoded_token.device_id:
            # Access tokens that have been issued for a device are only valid
            # for that device, so we need to check if the device ID matches any
            # of the valid devices in the database.
            try:
                device_model = zen_store().get_internal_authorized_device(
                    device_id=decoded_token.device_id
                )
            except KeyError:
                error = (
                    f"Authentication error: error retrieving token device "
                    f"{decoded_token.device_id}"
                )
                logger.error(error)
                raise AuthorizationException(error)

            # The device must be owned by the account the token was issued to.
            if (
                device_model.user is None
                or device_model.user.id != user_model.id
            ):
                error = (
                    f"Authentication error: device {decoded_token.device_id} "
                    f"does not belong to user {user_model.name}"
                )
                logger.error(error)
                raise AuthorizationException(error)

            if device_model.status != OAuthDeviceStatus.ACTIVE:
                error = (
                    f"Authentication error: device {decoded_token.device_id} "
                    f"is not active"
                )
                logger.error(error)
                raise AuthorizationException(error)

            # NOTE(review): compares against naive datetime.utcnow(); assumes
            # device_model.expires is stored as naive UTC - confirm.
            if (
                device_model.expires
                and datetime.utcnow() >= device_model.expires
            ):
                error = (
                    f"Authentication error: device {decoded_token.device_id} "
                    "has expired"
                )
                logger.error(error)
                raise AuthorizationException(error)

            # All device checks passed: update the device's last login.
            zen_store().update_internal_authorized_device(
                device_id=device_model.id,
                update=OAuthDeviceInternalUpdate(
                    update_last_login=True,
                ),
            )

        # Replaces any context built from user_name_or_id above.
        auth_context = AuthContext(
            user=user_model,
            access_token=decoded_token,
            encoded_access_token=access_token,
            device=device_model,
            api_key=api_key_model,
        )

    else:
        # IMPORTANT: the ONLY way we allow the authentication process to
        # continue without any credentials (i.e. no password, activation
        # token or access token) is if authentication is explicitly disabled
        # by setting the auth_scheme to NO_AUTH.
        if server_config().auth_scheme != AuthScheme.NO_AUTH:
            error = "Authentication error: no credentials provided"
            logger.error(error)
            raise AuthorizationException(error)

    # No context is available if, e.g., the account lookup above failed, or
    # no username was supplied in no-auth mode.
    if not auth_context:
        error = "Authentication error: invalid credentials"
        logger.error(error)
        raise AuthorizationException(error)

    return auth_context

authenticate_device(client_id, device_code)

Verify if device authorization credentials are valid.

Parameters:

Name Type Description Default
client_id UUID

The OAuth2 client ID.

required
device_code str

The device code.

required

Returns:

Type Description
AuthContext

The authenticated account details.

Exceptions:

Type Description
OAuthError

If the device authorization credentials are invalid.

Source code in zenml/zen_server/auth.py
def authenticate_device(client_id: UUID, device_code: str) -> AuthContext:
    """Verify if device authorization credentials are valid.

    This is the part of the OAuth2 device code grant flow where a client
    device continuously polls the server to check whether the user has
    authorized it. On success, the device is moved to the ACTIVE state and
    given an expiration.

    Args:
        client_id: The OAuth2 client ID.
        device_code: The device code.

    Returns:
        The authenticated account details.

    Raises:
        OAuthError: If the device authorization credentials are invalid.
    """
    # The following needs to happen to successfully authenticate the device
    # and return a valid access token:
    #
    # 1. the client ID must match a device in the DB, and the device code
    # must be valid for that device
    # 2. the device must be in the VERIFIED state, meaning that the user
    # has successfully authorized the device via the user code but the
    # device client hasn't fetched the associated API access token yet.
    # 3. the device must not be expired

    config = server_config()
    store = zen_store()

    try:
        device_model = store.get_internal_authorized_device(
            client_id=client_id
        )
    except KeyError:
        error = (
            f"Authentication error: error retrieving device with client ID "
            f"{client_id}"
        )
        logger.error(error)
        raise OAuthError(
            error="invalid_client",
            error_description=error,
        )

    if device_model.status != OAuthDeviceStatus.VERIFIED:
        error = (
            f"Authentication error: device with client ID {client_id} is "
            f"{device_model.status.value}."
        )
        logger.error(error)
        # Map the device state onto the OAuth2 device flow error codes.
        if device_model.status == OAuthDeviceStatus.PENDING:
            oauth_error = "authorization_pending"
        elif device_model.status == OAuthDeviceStatus.LOCKED:
            oauth_error = "access_denied"
        else:
            oauth_error = "expired_token"
        raise OAuthError(
            error=oauth_error,
            error_description=error,
        )

    if device_model.expires and datetime.utcnow() >= device_model.expires:
        error = (
            f"Authentication error: device for client ID {client_id} has "
            "expired"
        )
        logger.error(error)
        raise OAuthError(
            error="expired_token",
            error_description=error,
        )

    # Check the device code
    if not device_model.verify_device_code(device_code):
        # If the device code is invalid, increment the failed auth attempts
        # counter and lock the device if the maximum number of failed auth
        # attempts has been reached.
        failed_auth_attempts = device_model.failed_auth_attempts + 1
        lock_device = (
            failed_auth_attempts >= config.max_failed_device_auth_attempts
        )
        update = OAuthDeviceInternalUpdate(
            failed_auth_attempts=failed_auth_attempts
        )
        if lock_device:
            update.locked = True

        store.update_internal_authorized_device(
            device_id=device_model.id,
            update=update,
        )

        if lock_device:
            error = (
                f"Authentication error: device for client ID {client_id} "
                "has been locked due to too many failed authentication "
                "attempts."
            )
        else:
            error = (
                f"Authentication error: device for client ID {client_id} "
                "has an invalid device code."
            )

        logger.error(error)
        raise OAuthError(
            error="access_denied",
            error_description=error,
        )

    # The device is valid, so we can return the user associated with it.
    # This is the one and only time we return an AuthContext authorized by
    # a device code in order to be exchanged for an access token. Subsequent
    # requests to the API will be authenticated using the access token.
    #
    # Update the device state to ACTIVE and set an expiration date for it
    # past which it can no longer be used for authentication. The expiration
    # date also determines the expiration date of the access token issued
    # for this device.
    # NOTE(review): when JWT expiration is disabled, expires_in stays 0;
    # presumably the store interprets that as "no expiration" - confirm.
    expires_in: int = 0
    if config.jwt_token_expire_minutes:
        if device_model.trusted_device:
            expires_in = config.trusted_device_expiration_minutes or 0
        else:
            expires_in = config.device_expiration_minutes or 0

    update = OAuthDeviceInternalUpdate(
        status=OAuthDeviceStatus.ACTIVE,
        expires_in=expires_in * 60,
    )
    # Reuse the same store handle as the rest of this function.
    device_model = store.update_internal_authorized_device(
        device_id=device_model.id,
        update=update,
    )

    # This can never happen because the VERIFIED state is only set if
    # a user verified and has been associated with the device.
    assert device_model.user is not None

    return AuthContext(user=device_model.user, device=device_model)

authenticate_external_user(external_access_token)

Implement external authentication.

Parameters:

Name Type Description Default
external_access_token str

The access token used to authenticate the user to the external authenticator.

required

Returns:

Type Description
AuthContext

The authentication context reflecting the authenticated user.

Exceptions:

Type Description
AuthorizationException

If the external user could not be authorized.

Source code in zenml/zen_server/auth.py
def authenticate_external_user(external_access_token: str) -> AuthContext:
    """Implement external authentication.

    The external access token is used to fetch the user's information from
    the configured external authenticator. The matching ZenML user is then
    updated, or created if it does not exist yet.

    Args:
        external_access_token: The access token used to authenticate the user
            to the external authenticator.

    Returns:
        The authentication context reflecting the authenticated user.

    Raises:
        AuthorizationException: If the external user could not be authorized,
            including when no external user info URL is configured.
    """
    config = server_config()
    store = zen_store()

    # Explicit check instead of `assert`, which is stripped when Python runs
    # with -O; raise the exception type this function documents.
    if config.external_user_info_url is None:
        error = (
            "Authentication error: no external user info URL is configured."
        )
        logger.error(error)
        raise AuthorizationException(error)

    # Use the external access token to extract the user information and
    # permissions

    # Get the user information from the external authenticator
    user_info_url = config.external_user_info_url
    headers = {"Authorization": "Bearer " + external_access_token}
    query_params = dict(server_id=str(config.get_external_server_id()))

    try:
        auth_response = requests.get(
            user_info_url,
            headers=headers,
            params=urlencode(query_params),
            timeout=EXTERNAL_AUTHENTICATOR_TIMEOUT,
        )
    except Exception as e:
        # Deliberately broad: any failure to reach the external
        # authenticator is reported as an authorization failure.
        logger.exception(
            f"Error fetching user information from external authenticator: "
            f"{e}"
        )
        raise AuthorizationException(
            "Error fetching user information from external authenticator."
        ) from e

    external_user: Optional[ExternalUserModel] = None

    if 200 <= auth_response.status_code < 300:
        try:
            payload = auth_response.json()
        except requests.exceptions.JSONDecodeError as e:
            logger.exception(
                "Error decoding JSON response from external authenticator."
            )
            raise AuthorizationException(
                "Unknown external authenticator error"
            ) from e

        if isinstance(payload, dict):
            try:
                external_user = ExternalUserModel.model_validate(payload)
            except Exception as e:
                # Best effort: an invalid payload leaves external_user unset,
                # which is reported as an error right below.
                logger.exception(
                    f"Error parsing user information from external "
                    f"authenticator: {e}"
                )

    elif auth_response.status_code in [401, 403]:
        raise AuthorizationException("Not authorized to access this server.")
    elif auth_response.status_code == 404:
        raise AuthorizationException(
            "External authenticator did not recognize this server."
        )
    else:
        logger.error(
            f"Error fetching user information from external authenticator. "
            f"Status code: {auth_response.status_code}, "
            f"Response: {auth_response.text}"
        )
        raise AuthorizationException(
            "Error fetching user information from external authenticator."
        )

    if not external_user:
        raise AuthorizationException("Unknown external authenticator error")

    # With an external user object, we can now authenticate the user against
    # the ZenML server

    # Check if the external user already exists in the ZenML server database
    # If not, create a new user. If yes, update the existing user.
    try:
        user = store.get_external_user(user_id=external_user.id)

        # Update the user information
        user = store.update_user(
            user_id=user.id,
            user_update=UserUpdate(
                name=external_user.email,
                full_name=external_user.name or "",
                email_opted_in=True,
                active=True,
                email=external_user.email,
                is_admin=external_user.is_admin,
            ),
        )
    except KeyError:
        logger.info(
            f"External user with ID {external_user.id} not found in ZenML "
            f"server database. Creating a new user."
        )
        user = store.create_user(
            UserRequest(
                name=external_user.email,
                full_name=external_user.name or "",
                external_user_id=external_user.id,
                email_opted_in=True,
                active=True,
                email=external_user.email,
                is_admin=external_user.is_admin,
            )
        )

        # Link the newly created user to the external identity in analytics.
        with AnalyticsContext() as context:
            context.user_id = user.id
            context.identify(
                traits={
                    "email": external_user.email,
                    "source": "external_auth",
                }
            )
            context.alias(user_id=external_user.id, previous_id=user.id)

    return AuthContext(user=user)

authentication_provider()

Returns the authentication provider.

Returns:

Type Description
Callable[..., zenml.zen_server.auth.AuthContext]

The authentication provider.

Exceptions:

Type Description
ValueError

If the authentication scheme is not supported.

Source code in zenml/zen_server/auth.py
def authentication_provider() -> Callable[..., AuthContext]:
    """Select the authentication dependency for the configured auth scheme.

    Returns:
        The authentication provider.

    Raises:
        ValueError: If the authentication scheme is not supported.
    """
    scheme = server_config().auth_scheme

    if scheme == AuthScheme.NO_AUTH:
        return no_authentication
    if scheme == AuthScheme.HTTP_BASIC:
        return http_authentication
    if scheme in (AuthScheme.OAUTH2_PASSWORD_BEARER, AuthScheme.EXTERNAL):
        # Both schemes authenticate API requests with JWT bearer tokens.
        return oauth2_authentication

    raise ValueError(f"Unknown authentication scheme: {scheme}")

authorize(token=Depends(CookieOAuth2TokenBearer))

Authenticates any request to the ZenML server with OAuth2 JWT tokens.

Parameters:

Name Type Description Default
token str

The JWT bearer token to be authenticated.

Depends(CookieOAuth2TokenBearer)

Returns:

Type Description
AuthContext

The authentication context reflecting the authenticated user.

Exceptions:

Type Description
HTTPException

If the JWT token could not be authorized.

Source code in zenml/zen_server/auth.py
def oauth2_authentication(
    token: str = Depends(
        CookieOAuth2TokenBearer(
            tokenUrl=server_config().root_url_path + API + VERSION_1 + LOGIN,
        )
    ),
) -> AuthContext:
    """Authenticate a ZenML server request with an OAuth2 JWT bearer token.

    Args:
        token: The JWT bearer token to be authenticated.

    Returns:
        The authentication context reflecting the authenticated user.

    Raises:
        HTTPException: If the JWT token could not be authorized.
    """
    try:
        return authenticate_credentials(access_token=token)
    except AuthorizationException as e:
        # Surface the failure as a 401 that advertises the bearer scheme.
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=str(e),
            headers={"WWW-Authenticate": "Bearer"},
        )

get_auth_context()

Returns the current authentication context.

Returns:

Type Description
Optional[AuthContext]

The authentication context.

Source code in zenml/zen_server/auth.py
def get_auth_context() -> Optional["AuthContext"]:
    """Return the currently stored authentication context.

    Returns:
        The authentication context.
    """
    return _auth_context.get()

http_authentication(credentials=Depends(HTTPBasic))

Authenticates any request to the ZenML Server with basic HTTP authentication.

Parameters:

Name Type Description Default
credentials HTTPBasicCredentials

HTTP basic auth credentials passed to the request.

Depends(HTTPBasic)

Returns:

Type Description
AuthContext

The authentication context reflecting the authenticated user.

Exceptions:

Type Description
HTTPException

If the credentials are invalid.

Source code in zenml/zen_server/auth.py
def http_authentication(
    credentials: HTTPBasicCredentials = Depends(HTTPBasic()),
) -> AuthContext:
    """Authenticate a ZenML Server request using basic HTTP authentication.

    Args:
        credentials: HTTP basic auth credentials passed to the request.

    Returns:
        The authentication context reflecting the authenticated user.

    Raises:
        HTTPException: If the credentials are invalid.
    """
    try:
        auth_context = authenticate_credentials(
            user_name_or_id=credentials.username,
            password=credentials.password,
        )
    except AuthorizationException as e:
        # Surface the failure as a 401 that advertises the basic scheme.
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=str(e),
            headers={"WWW-Authenticate": "Basic"},
        )
    return auth_context

no_authentication()

Doesn't authenticate requests to the ZenML server.

Returns:

Type Description
AuthContext

The authentication context reflecting the default user.

Source code in zenml/zen_server/auth.py
def no_authentication() -> AuthContext:
    """Authenticate every request as the default user, without credentials.

    Returns:
        The authentication context reflecting the default user.
    """
    # A username without any credential is only accepted by
    # authenticate_credentials when the NO_AUTH scheme is configured.
    default_user_name = DEFAULT_USERNAME
    return authenticate_credentials(user_name_or_id=default_user_name)

oauth2_authentication(token=Depends(CookieOAuth2TokenBearer))

Authenticates any request to the ZenML server with OAuth2 JWT tokens.

Parameters:

Name Type Description Default
token str

The JWT bearer token to be authenticated.

Depends(CookieOAuth2TokenBearer)

Returns:

Type Description
AuthContext

The authentication context reflecting the authenticated user.

Exceptions:

Type Description
HTTPException

If the JWT token could not be authorized.

Source code in zenml/zen_server/auth.py
def oauth2_authentication(
    token: str = Depends(
        CookieOAuth2TokenBearer(
            tokenUrl=server_config().root_url_path + API + VERSION_1 + LOGIN,
        )
    ),
) -> AuthContext:
    """Authenticate a ZenML server request with an OAuth2 JWT bearer token.

    Args:
        token: The JWT bearer token to be authenticated.

    Returns:
        The authentication context reflecting the authenticated user.

    Raises:
        HTTPException: If the JWT token could not be authorized.
    """
    try:
        return authenticate_credentials(access_token=token)
    except AuthorizationException as e:
        # Surface the failure as a 401 that advertises the bearer scheme.
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=str(e),
            headers={"WWW-Authenticate": "Bearer"},
        )

set_auth_context(auth_context)

Sets the current authentication context.

Parameters:

Name Type Description Default
auth_context AuthContext

The authentication context.

required

Returns:

Type Description
AuthContext

The authentication context.

Source code in zenml/zen_server/auth.py
def set_auth_context(auth_context: "AuthContext") -> "AuthContext":
    """Store the given authentication context as the current one.

    Args:
        auth_context: The authentication context.

    Returns:
        The same authentication context that was passed in.
    """
    context_to_store = auth_context
    _auth_context.set(context_to_store)
    return context_to_store

cloud_utils

Utilities for anything concerning the cloud control plane backend.

ZenMLCloudConfiguration (BaseModel)

ZenML Pro RBAC configuration.

Source code in zenml/zen_server/cloud_utils.py
class ZenMLCloudConfiguration(BaseModel):
    """ZenML Pro RBAC configuration.

    Connection settings for the ZenML Pro control plane, usually loaded from
    environment variables via `from_environment`.
    """

    # Base URL of the control plane API; trailing slashes are stripped.
    api_url: str

    # OAuth2 / Auth0 settings - presumably used to obtain access tokens for
    # the control plane API (the token exchange is not shown here).
    oauth2_client_id: str
    oauth2_client_secret: str
    oauth2_audience: str
    auth0_domain: str

    @field_validator("api_url")
    @classmethod
    def _strip_trailing_slashes_url(cls, url: str) -> str:
        """Strip any trailing slashes on the API URL.

        Args:
            url: The API URL.

        Returns:
            The API URL with potential trailing slashes removed.
        """
        return url.rstrip("/")

    @classmethod
    def from_environment(cls) -> "ZenMLCloudConfiguration":
        """Get the RBAC configuration from environment variables.

        Every non-empty environment variable whose name starts with
        ``ZENML_CLOUD_RBAC_ENV_PREFIX`` contributes one field. The field name
        is the rest of the variable name after the prefix, lower-cased.

        Returns:
            The RBAC configuration.
        """
        prefix_length = len(ZENML_CLOUD_RBAC_ENV_PREFIX)
        env_config: Dict[str, Any] = {}
        for k, v in os.environ.items():
            # Empty values are treated as unset.
            if v == "":
                continue
            if k.startswith(ZENML_CLOUD_RBAC_ENV_PREFIX):
                env_config[k[prefix_length:].lower()] = v

        # Use `cls` so that subclasses get instances of themselves.
        return cls(**env_config)

    model_config = ConfigDict(
        # Allow extra attributes from configs of previous ZenML versions to
        # permit downgrading
        extra="allow"
    )
from_environment() classmethod

Get the RBAC configuration from environment variables.

Returns:

Type Description
ZenMLCloudConfiguration

The RBAC configuration.

Source code in zenml/zen_server/cloud_utils.py
@classmethod
def from_environment(cls) -> "ZenMLCloudConfiguration":
    """Get the RBAC configuration from environment variables.

    Every non-empty environment variable whose name starts with
    ``ZENML_CLOUD_RBAC_ENV_PREFIX`` contributes one field. The field name is
    the rest of the variable name after the prefix, lower-cased.

    Returns:
        The RBAC configuration.
    """
    prefix_length = len(ZENML_CLOUD_RBAC_ENV_PREFIX)
    env_config: Dict[str, Any] = {}
    for k, v in os.environ.items():
        # Empty values are treated as unset.
        if v == "":
            continue
        if k.startswith(ZENML_CLOUD_RBAC_ENV_PREFIX):
            env_config[k[prefix_length:].lower()] = v

    # Use `cls` so that subclasses get instances of themselves.
    return cls(**env_config)

ZenMLCloudConnection

Class to use for communication between server and control plane.

Source code in zenml/zen_server/cloud_utils.py
class ZenMLCloudConnection:
    """Class to use for communication between server and control plane."""

    def __init__(self) -> None:
        """Initialize the connection to the ZenML Pro control plane.

        The configuration is read from environment variables immediately;
        the authenticated HTTP session is only created on first use of the
        ``session`` property.
        """
        self._config = ZenMLCloudConfiguration.from_environment()
        self._session: Optional[requests.Session] = None

    def get(
        self, endpoint: str, params: Optional[Dict[str, Any]]
    ) -> requests.Response:
        """Send a GET request using the active session.

        On a 401 response the session (and thus the auth token) is dropped
        and the request is retried once with a freshly authenticated session.

        Args:
            endpoint: The endpoint to send the request to. This will be appended
                to the base URL.
            params: Parameters to include in the request.

        Raises:
            RuntimeError: If the request failed.
            SubscriptionUpgradeRequiredError: In case the current subscription
                tier is insufficient for the attempted operation.

        Returns:
            The response.
        """
        url = self._config.api_url + endpoint

        response = self.session.get(url=url, params=params, timeout=7)
        if response.status_code == 401:
            # Refresh the auth token and try again
            self._clear_session()
            response = self.session.get(url=url, params=params, timeout=7)

        try:
            response.raise_for_status()
        except requests.HTTPError as e:
            if response.status_code == 402:
                try:
                    details: Any = response.json()
                except ValueError:
                    # A non-JSON 402 body must not mask the subscription
                    # error with a decoding error.
                    details = response.text
                raise SubscriptionUpgradeRequiredError(details) from e
            raise RuntimeError(
                f"Failed with the following error {response} {response.text}"
            ) from e

        return response

    def post(
        self,
        endpoint: str,
        params: Optional[Dict[str, Any]] = None,
        data: Optional[Dict[str, Any]] = None,
    ) -> requests.Response:
        """Send a POST request using the active session.

        On a 401 response the session (and thus the auth token) is dropped
        and the request is retried once with a freshly authenticated session.

        Args:
            endpoint: The endpoint to send the request to. This will be appended
                to the base URL.
            params: Parameters to include in the request.
            data: Data to include in the request (sent as a JSON body).

        Raises:
            RuntimeError: If the request failed.

        Returns:
            The response.
        """
        url = self._config.api_url + endpoint

        response = self.session.post(
            url=url, params=params, json=data, timeout=7
        )
        if response.status_code == 401:
            # Refresh the auth token and try again
            self._clear_session()
            response = self.session.post(
                url=url, params=params, json=data, timeout=7
            )

        try:
            response.raise_for_status()
        except requests.HTTPError as e:
            raise RuntimeError(
                f"Failed while trying to contact the central zenml pro "
                f"service: {e}"
            ) from e

        return response

    @property
    def session(self) -> requests.Session:
        """Authenticate to the ZenML Pro Management Plane.

        The session is created lazily and cached. It is only stored on the
        instance once it is fully configured, so a failed token fetch does
        not leave an unauthenticated session behind.

        Returns:
            A requests session with the authentication token.
        """
        if self._session is None:
            # Set up the session's connection pool size to match the server's
            # thread pool size. This allows the server to cache one connection
            # per thread, which means we can keep connections open for longer
            # and avoid the overhead of setting up a new connection for each
            # request.
            conn_pool_size = server_config().thread_pool_size

            # Fetch the token first: if this raises, nothing is cached.
            token = self._fetch_auth_token()

            session = requests.Session()
            session.headers.update({"Authorization": "Bearer " + token})

            retries = Retry(total=5, backoff_factor=0.1)
            session.mount(
                "https://",
                HTTPAdapter(
                    max_retries=retries,
                    # We only use one connection pool to be cached because we
                    # only communicate with one remote server (the control
                    # plane)
                    pool_connections=1,
                    pool_maxsize=conn_pool_size,
                ),
            )
            self._session = session

        return self._session

    def _clear_session(self) -> None:
        """Clear the authentication session.

        The next access to ``session`` fetches a new auth token and builds a
        new session.
        """
        self._session = None

    def _fetch_auth_token(self) -> str:
        """Fetch an auth token for the Cloud API from auth0.

        Uses the OAuth2 client-credentials grant with the configured client
        ID, secret and audience.

        Raises:
            RuntimeError: If the auth token can't be fetched.

        Returns:
            Auth token.
        """
        # Get an auth token from auth0
        auth0_url = f"https://{self._config.auth0_domain}/oauth/token"
        headers = {"content-type": "application/x-www-form-urlencoded"}
        payload = {
            "client_id": self._config.oauth2_client_id,
            "client_secret": self._config.oauth2_client_secret,
            "audience": self._config.oauth2_audience,
            "grant_type": "client_credentials",
        }
        try:
            response = requests.post(
                auth0_url, headers=headers, data=payload, timeout=7
            )
            response.raise_for_status()
            # Decoding is inside the guard so a malformed body surfaces as the
            # documented RuntimeError rather than a JSON decoding error.
            body = response.json()
        except Exception as e:
            raise RuntimeError(
                f"Error fetching auth token from auth0: {e}"
            ) from e

        access_token = (
            body.get("access_token", "") if isinstance(body, dict) else ""
        )

        if not access_token or not isinstance(access_token, str):
            raise RuntimeError("Could not fetch auth token from auth0.")

        return access_token
session: Session property readonly

Authenticate to the ZenML Pro Management Plane.

Returns:

Type Description
Session

A requests session with the authentication token.

__init__(self) special

Initialize the connection to the ZenML Pro control plane (the configuration is read from environment variables).

Source code in zenml/zen_server/cloud_utils.py
def __init__(self) -> None:
    """Set up the connection to the ZenML Pro control plane.

    The configuration is loaded from environment variables right away; no
    HTTP session exists yet (it starts out as ``None``).
    """
    self._config = ZenMLCloudConfiguration.from_environment()
    # Populated on demand with an authenticated session.
    self._session: Optional[requests.Session] = None
get(self, endpoint, params)

Send a GET request using the active session.

Parameters:

- `endpoint` (`str`, required): The endpoint to send the request to. This will be appended to the base URL.
- `params` (`Optional[Dict[str, Any]]`, required): Parameters to include in the request.

Exceptions:

- `RuntimeError`: If the request failed.
- `SubscriptionUpgradeRequiredError`: In case the current subscription tier is insufficient for the attempted operation.

Returns:

Type Description
Response

The response.

Source code in zenml/zen_server/cloud_utils.py
def get(
    self, endpoint: str, params: Optional[Dict[str, Any]]
) -> requests.Response:
    """Send a GET request using the active session.

    On a 401 response the session (and thus the auth token) is dropped and
    the request is retried once with a freshly authenticated session.

    Args:
        endpoint: The endpoint to send the request to. This will be appended
            to the base URL.
        params: Parameters to include in the request.

    Raises:
        RuntimeError: If the request failed.
        SubscriptionUpgradeRequiredError: In case the current subscription
            tier is insufficient for the attempted operation.

    Returns:
        The response.
    """
    url = self._config.api_url + endpoint

    response = self.session.get(url=url, params=params, timeout=7)
    if response.status_code == 401:
        # Refresh the auth token and try again
        self._clear_session()
        response = self.session.get(url=url, params=params, timeout=7)

    try:
        response.raise_for_status()
    except requests.HTTPError as e:
        if response.status_code == 402:
            try:
                details: Any = response.json()
            except ValueError:
                # A non-JSON 402 body must not mask the subscription error
                # with a decoding error.
                details = response.text
            raise SubscriptionUpgradeRequiredError(details) from e
        raise RuntimeError(
            f"Failed with the following error {response} {response.text}"
        ) from e

    return response
post(self, endpoint, params=None, data=None)

Send a POST request using the active session.

Parameters:

Name Type Description Default
endpoint str

The endpoint to send the request to. This will be appended to the base URL.

required
params Optional[Dict[str, Any]]

Parameters to include in the request.

None
data Optional[Dict[str, Any]]

Data to include in the request.

None

Exceptions:

Type Description
RuntimeError

If the request failed.

Returns:

Type Description
Response

The response.

Source code in zenml/zen_server/cloud_utils.py
def post(
    self,
    endpoint: str,
    params: Optional[Dict[str, Any]] = None,
    data: Optional[Dict[str, Any]] = None,
) -> requests.Response:
    """Send a POST request using the active session.

    On a 401 response the session (and thus the auth token) is dropped and
    the request is retried once with a freshly authenticated session.

    Args:
        endpoint: The endpoint to send the request to. This will be appended
            to the base URL.
        params: Parameters to include in the request.
        data: Data to include in the request (sent as a JSON body).

    Raises:
        RuntimeError: If the request failed.

    Returns:
        The response.
    """
    url = self._config.api_url + endpoint

    response = self.session.post(
        url=url, params=params, json=data, timeout=7
    )
    if response.status_code == 401:
        # Refresh the auth token and try again
        self._clear_session()
        response = self.session.post(
            url=url, params=params, json=data, timeout=7
        )

    try:
        response.raise_for_status()
    except requests.HTTPError as e:
        # Chain explicitly so the original HTTP error stays the cause.
        raise RuntimeError(
            f"Failed while trying to contact the central zenml pro "
            f"service: {e}"
        ) from e

    return response

cloud_connection()

Return the initialized cloud connection.

Returns:

Type Description
ZenMLCloudConnection

The cloud connection.

Source code in zenml/zen_server/cloud_utils.py
def cloud_connection() -> ZenMLCloudConnection:
    """Get the module-level cloud connection, creating it on the first call.

    Returns:
        The cloud connection.
    """
    global _cloud_connection
    if _cloud_connection is not None:
        return _cloud_connection

    _cloud_connection = ZenMLCloudConnection()
    return _cloud_connection

deploy special

ZenML server deployments.

base_provider

Base ZenML server provider class.

BaseServerProvider (ABC)

Base ZenML server provider class.

All ZenML server providers must extend and implement this base class.

Source code in zenml/zen_server/deploy/base_provider.py
class BaseServerProvider(ABC):
    """Base ZenML server provider class.

    All ZenML server providers must extend and implement this base class.
    """

    TYPE: ClassVar[ServerProviderType]
    CONFIG_TYPE: ClassVar[Type[ServerDeploymentConfig]] = (
        ServerDeploymentConfig
    )

    @classmethod
    def register_as_provider(cls) -> None:
        """Register the class as a server provider."""
        # Local import; presumably avoids a circular import between the
        # deployer and its providers (TODO confirm).
        from zenml.zen_server.deploy.deployer import ServerDeployer

        ServerDeployer.register_provider(cls)

    @classmethod
    def _convert_config(
        cls, config: ServerDeploymentConfig
    ) -> ServerDeploymentConfig:
        """Convert a generic server deployment config into a provider specific config.

        Args:
            config: The generic server deployment config.

        Returns:
            The provider specific server deployment config.

        Raises:
            ServerDeploymentConfigurationError: If the configuration is not
                valid.
        """
        if isinstance(config, cls.CONFIG_TYPE):
            return config
        try:
            return cls.CONFIG_TYPE(**config.model_dump())
        except ValidationError as e:
            raise ServerDeploymentConfigurationError(
                f"Invalid configuration for provider {cls.TYPE.value}: {e}"
            ) from e

    def _get_existing_service(self, server_name: str) -> BaseService:
        """Look up the service of an existing deployment or raise.

        Args:
            server_name: The server deployment name.

        Returns:
            The service instance.

        Raises:
            ServerDeploymentNotFoundError: If a deployment with the given
                name doesn't exist.
        """
        try:
            return self._get_service(server_name)
        except KeyError as e:
            raise ServerDeploymentNotFoundError(
                f"ZenML server deployment with name '{server_name}' was not "
                f"found"
            ) from e

    def deploy_server(
        self,
        config: ServerDeploymentConfig,
        timeout: Optional[int] = None,
    ) -> ServerDeployment:
        """Deploy a new ZenML server.

        Args:
            config: The generic server deployment configuration.
            timeout: The timeout in seconds to wait until the deployment is
                successful. If not supplied, the default timeout value specified
                by the provider is used.

        Returns:
            The newly created server deployment.

        Raises:
            ServerDeploymentExistsError: If a deployment with the same name
                already exists.
        """
        try:
            self._get_service(config.name)
        except KeyError:
            pass
        else:
            raise ServerDeploymentExistsError(
                f"ZenML server deployment with name '{config.name}' already "
                f"exists"
            )

        # convert the generic deployment config to a provider specific
        # deployment config
        config = self._convert_config(config)
        service = self._create_service(config, timeout)
        return self._get_deployment(service)

    def update_server(
        self,
        config: ServerDeploymentConfig,
        timeout: Optional[int] = None,
    ) -> ServerDeployment:
        """Update an existing ZenML server deployment.

        If the new configuration equals the current one, the service is only
        (re)started; otherwise it is updated.

        Args:
            config: The new generic server deployment configuration.
            timeout: The timeout in seconds to wait until the update is
                successful. If not supplied, the default timeout value specified
                by the provider is used.

        Returns:
            The updated server deployment.

        Raises:
            ServerDeploymentNotFoundError: If a deployment with the given name
                doesn't exist.
        """
        service = self._get_existing_service(config.name)

        # convert the generic deployment config to a provider specific
        # deployment config
        config = self._convert_config(config)
        old_config = self._get_deployment_config(service)

        if old_config == config:
            logger.info(
                f"The {config.name} ZenML server is already configured with "
                f"the same parameters."
            )
            service = self._start_service(service, timeout)
        else:
            logger.info(f"Updating the {config.name} ZenML server.")
            service = self._update_service(service, config, timeout)

        return self._get_deployment(service)

    def remove_server(
        self,
        config: ServerDeploymentConfig,
        timeout: Optional[int] = None,
    ) -> None:
        """Tears down and removes all resources and files associated with a ZenML server deployment.

        Args:
            config: The generic server deployment configuration.
            timeout: The timeout in seconds to wait until the server is
                removed. If not supplied, the default timeout value specified
                by the provider is used.

        Raises:
            ServerDeploymentNotFoundError: If a deployment with the given name
                doesn't exist.
        """
        service = self._get_existing_service(config.name)

        logger.info(f"Removing the {config.name} ZenML server.")
        self._delete_service(service, timeout)

    def get_server(
        self,
        config: ServerDeploymentConfig,
    ) -> ServerDeployment:
        """Retrieve information about a ZenML server deployment.

        Args:
            config: The generic server deployment configuration.

        Returns:
            The server deployment.

        Raises:
            ServerDeploymentNotFoundError: If a deployment with the given name
                doesn't exist.
        """
        service = self._get_existing_service(config.name)
        return self._get_deployment(service)

    def list_servers(self) -> List[ServerDeployment]:
        """List all server deployments managed by this provider.

        Returns:
            The list of server deployments.
        """
        return [
            self._get_deployment(service) for service in self._list_services()
        ]

    def get_server_logs(
        self,
        config: ServerDeploymentConfig,
        follow: bool = False,
        tail: Optional[int] = None,
    ) -> Generator[str, bool, None]:
        """Retrieve the logs of a ZenML server.

        Args:
            config: The generic server deployment configuration.
            follow: if True, the logs will be streamed as they are written
            tail: only retrieve the last NUM lines of log output.

        Returns:
            A generator that can be accessed to get the service logs.

        Raises:
            ServerDeploymentNotFoundError: If a deployment with the given name
                doesn't exist.
        """
        service = self._get_existing_service(config.name)
        return service.get_logs(follow=follow, tail=tail)

    def _get_deployment_status(
        self, service: BaseService
    ) -> ServerDeploymentStatus:
        """Get the status of a server deployment from its service.

        The deployment counts as connected when the service is running and
        the global store configuration points at its endpoint URI.

        Args:
            service: The server deployment service.

        Returns:
            The status of the server deployment.
        """
        gc = GlobalConfiguration()
        url: Optional[str] = None
        if service.is_running:
            # all services must have an endpoint
            assert service.endpoint is not None

            url = service.endpoint.status.uri
        connected = url is not None and gc.store_configuration.url == url

        return ServerDeploymentStatus(
            url=url,
            status=service.status.state,
            status_message=service.status.last_error,
            connected=connected,
        )

    def _get_deployment(self, service: BaseService) -> ServerDeployment:
        """Get the server deployment associated with a service.

        Args:
            service: The service.

        Returns:
            The server deployment.
        """
        config = self._get_deployment_config(service)

        return ServerDeployment(
            config=config,
            status=self._get_deployment_status(service),
        )

    @classmethod
    @abstractmethod
    def _get_service_configuration(
        cls,
        server_config: ServerDeploymentConfig,
    ) -> Tuple[
        ServiceConfig,
        ServiceEndpointConfig,
        ServiceEndpointHealthMonitorConfig,
    ]:
        """Construct the service configuration from a server deployment configuration.

        Args:
            server_config: server deployment configuration.

        Returns:
            The service, service endpoint and endpoint monitor configuration.
        """

    @abstractmethod
    def _create_service(
        self,
        config: ServerDeploymentConfig,
        timeout: Optional[int] = None,
    ) -> BaseService:
        """Create, start and return a service instance for a ZenML server deployment.

        Args:
            config: The server deployment configuration.
            timeout: The timeout in seconds to wait until the service is
                running. If not supplied, a default timeout value specified
                by the provider implementation should be used.

        Returns:
            The service instance.
        """

    @abstractmethod
    def _update_service(
        self,
        service: BaseService,
        config: ServerDeploymentConfig,
        timeout: Optional[int] = None,
    ) -> BaseService:
        """Update an existing service instance for a ZenML server deployment.

        Args:
            service: The service instance.
            config: The new server deployment configuration.
            timeout: The timeout in seconds to wait until the updated service is
                running. If not supplied, a default timeout value specified
                by the provider implementation should be used.

        Returns:
            The updated service instance.
        """

    @abstractmethod
    def _start_service(
        self,
        service: BaseService,
        timeout: Optional[int] = None,
    ) -> BaseService:
        """Start a service instance for a ZenML server deployment.

        Args:
            service: The service instance.
            timeout: The timeout in seconds to wait until the service is
                running. If not supplied, a default timeout value specified
                by the provider implementation should be used.

        Returns:
            The updated service instance.
        """

    @abstractmethod
    def _stop_service(
        self,
        service: BaseService,
        timeout: Optional[int] = None,
    ) -> BaseService:
        """Stop a service instance for a ZenML server deployment.

        Args:
            service: The service instance.
            timeout: The timeout in seconds to wait until the service is
                stopped. If not supplied, a default timeout value specified
                by the provider implementation should be used.

        Returns:
            The updated service instance.
        """

    @abstractmethod
    def _delete_service(
        self,
        service: BaseService,
        timeout: Optional[int] = None,
    ) -> None:
        """Remove a service instance for a ZenML server deployment.

        Args:
            service: The service instance.
            timeout: The timeout in seconds to wait until the service is
                removed. If not supplied, a default timeout value specified
                by the provider implementation should be used.
        """

    @abstractmethod
    def _get_service(self, server_name: str) -> BaseService:
        """Get the service instance associated with a ZenML server deployment.

        Args:
            server_name: The server deployment name.

        Returns:
            The service instance.

        Raises:
            KeyError: If the server deployment is not found.
        """

    @abstractmethod
    def _list_services(self) -> List[BaseService]:
        """Get all service instances for all deployed ZenML servers.

        Returns:
            A list of service instances.
        """

    @abstractmethod
    def _get_deployment_config(
        self, service: BaseService
    ) -> ServerDeploymentConfig:
        """Recreate the server deployment config from a service instance.

        Args:
            service: The service instance.

        Returns:
            The server deployment config.
        """
CONFIG_TYPE (BaseModel) — defaults to `ServerDeploymentConfig`, documented below

Generic server deployment configuration.

All server deployment configurations should inherit from this class and handle extra attributes as provider specific attributes.

Attributes:

Name Type Description
name str

Name of the server deployment.

provider ServerProviderType

The server provider type.

Source code in zenml/zen_server/deploy/base_provider.py
class ServerDeploymentConfig(BaseModel):
    """Generic server deployment configuration.

    Every concrete server deployment configuration derives from this class.
    Unknown attributes are accepted at this level and are treated as
    provider specific attributes by the subclasses.

    Attributes:
        name: Name of the server deployment.
        provider: The server provider type.
    """

    model_config = ConfigDict(
        # Re-validate values on assignment; required so that mutable and
        # immutable attributes can be mixed.
        validate_assignment=True,
        # Extra attributes are permitted here; validating them is the job of
        # the concrete subclasses.
        extra="allow",
    )

    name: str
    provider: ServerProviderType
deploy_server(self, config, timeout=None)

Deploy a new ZenML server.

Parameters:

Name Type Description Default
config ServerDeploymentConfig

The generic server deployment configuration.

required
timeout Optional[int]

The timeout in seconds to wait until the deployment is successful. If not supplied, the default timeout value specified by the provider is used.

None

Returns:

Type Description
ServerDeployment

The newly created server deployment.

Exceptions:

Type Description
ServerDeploymentExistsError

If a deployment with the same name already exists.

Source code in zenml/zen_server/deploy/base_provider.py
def deploy_server(
    self,
    config: ServerDeploymentConfig,
    timeout: Optional[int] = None,
) -> ServerDeployment:
    """Create a brand-new ZenML server deployment.

    Args:
        config: The generic server deployment configuration.
        timeout: The timeout in seconds to wait until the deployment is
            successful. If not supplied, the default timeout value specified
            by the provider is used.

    Returns:
        The newly created server deployment.

    Raises:
        ServerDeploymentExistsError: If a deployment with the same name
            already exists.
    """
    # ``_get_service`` signals "no such deployment" with a KeyError.
    name_taken = True
    try:
        self._get_service(config.name)
    except KeyError:
        name_taken = False

    if name_taken:
        raise ServerDeploymentExistsError(
            f"ZenML server deployment with name '{config.name}' already "
            f"exists"
        )

    provider_config = self._convert_config(config)
    new_service = self._create_service(provider_config, timeout)
    return self._get_deployment(new_service)
get_server(self, config)

Retrieve information about a ZenML server deployment.

Parameters:

Name Type Description Default
config ServerDeploymentConfig

The generic server deployment configuration.

required

Returns:

Type Description
ServerDeployment

The server deployment.

Exceptions:

Type Description
ServerDeploymentNotFoundError

If a deployment with the given name doesn't exist.

Source code in zenml/zen_server/deploy/base_provider.py
def get_server(
    self,
    config: ServerDeploymentConfig,
) -> ServerDeployment:
    """Retrieve information about a ZenML server deployment.

    Args:
        config: The generic server deployment configuration.

    Returns:
        The server deployment.

    Raises:
        ServerDeploymentNotFoundError: If a deployment with the given name
            doesn't exist.
    """
    try:
        service = self._get_service(config.name)
    except KeyError as e:
        # Chain explicitly: the KeyError is the direct cause.
        raise ServerDeploymentNotFoundError(
            f"ZenML server deployment with name '{config.name}' was not "
            f"found"
        ) from e

    return self._get_deployment(service)
get_server_logs(self, config, follow=False, tail=None)

Retrieve the logs of a ZenML server.

Parameters:

Name Type Description Default
config ServerDeploymentConfig

The generic server deployment configuration.

required
follow bool

if True, the logs will be streamed as they are written

False
tail Optional[int]

only retrieve the last NUM lines of log output.

None

Returns:

Type Description
Generator[str, bool, NoneType]

A generator that can be accessed to get the service logs.

Exceptions:

Type Description
ServerDeploymentNotFoundError

If a deployment with the given name doesn't exist.

Source code in zenml/zen_server/deploy/base_provider.py
def get_server_logs(
    self,
    config: ServerDeploymentConfig,
    follow: bool = False,
    tail: Optional[int] = None,
) -> Generator[str, bool, None]:
    """Retrieve the logs of a ZenML server.

    Args:
        config: The generic server deployment configuration.
        follow: if True, the logs will be streamed as they are written
        tail: only retrieve the last NUM lines of log output.

    Returns:
        A generator that can be accessed to get the service logs.

    Raises:
        ServerDeploymentNotFoundError: If a deployment with the given name
            doesn't exist.
    """
    try:
        service = self._get_service(config.name)
    except KeyError as e:
        # Chain explicitly: the KeyError is the direct cause.
        raise ServerDeploymentNotFoundError(
            f"ZenML server deployment with name '{config.name}' was not "
            f"found"
        ) from e

    return service.get_logs(follow=follow, tail=tail)
list_servers(self)

List all server deployments managed by this provider.

Returns:

Type Description
List[zenml.zen_server.deploy.deployment.ServerDeployment]

The list of server deployments.

Source code in zenml/zen_server/deploy/base_provider.py
def list_servers(self) -> List[ServerDeployment]:
    """Return the deployment of every server this provider manages.

    Returns:
        The list of server deployments.
    """
    services = self._list_services()
    return list(map(self._get_deployment, services))
register_as_provider() classmethod

Register the class as a server provider.

Source code in zenml/zen_server/deploy/base_provider.py
@classmethod
def register_as_provider(cls) -> None:
    """Add this provider class to the ``ServerDeployer`` registry."""
    # Imported inside the method rather than at module level (presumably to
    # avoid a circular import — TODO confirm).
    from zenml.zen_server.deploy.deployer import (
        ServerDeployer as _ServerDeployer,
    )

    _ServerDeployer.register_provider(cls)
remove_server(self, config, timeout=None)

Tears down and removes all resources and files associated with a ZenML server deployment.

Parameters:

Name Type Description Default
config ServerDeploymentConfig

The generic server deployment configuration.

required
timeout Optional[int]

The timeout in seconds to wait until the server is removed. If not supplied, the default timeout value specified by the provider is used.

None

Exceptions:

Type Description
ServerDeploymentNotFoundError

If a deployment with the given name doesn't exist.

Source code in zenml/zen_server/deploy/base_provider.py
def remove_server(
    self,
    config: ServerDeploymentConfig,
    timeout: Optional[int] = None,
) -> None:
    """Tears down and removes all resources and files associated with a ZenML server deployment.

    Args:
        config: The generic server deployment configuration.
        timeout: The timeout in seconds to wait until the server is
            removed. If not supplied, the default timeout value specified
            by the provider is used.

    Raises:
        ServerDeploymentNotFoundError: If a deployment with the given name
            doesn't exist.
    """
    try:
        service = self._get_service(config.name)
    except KeyError as e:
        # Chain explicitly: the KeyError is the direct cause.
        raise ServerDeploymentNotFoundError(
            f"ZenML server deployment with name '{config.name}' was not "
            f"found"
        ) from e

    logger.info(f"Removing the {config.name} ZenML server.")
    self._delete_service(service, timeout)
update_server(self, config, timeout=None)

Update an existing ZenML server deployment.

Parameters:

Name Type Description Default
config ServerDeploymentConfig

The new generic server deployment configuration.

required
timeout Optional[int]

The timeout in seconds to wait until the update is successful. If not supplied, the default timeout value specified by the provider is used.

None

Returns:

Type Description
ServerDeployment

The updated server deployment.

Exceptions:

Type Description
ServerDeploymentNotFoundError

If a deployment with the given name doesn't exist.

Source code in zenml/zen_server/deploy/base_provider.py
def update_server(
    self,
    config: ServerDeploymentConfig,
    timeout: Optional[int] = None,
) -> ServerDeployment:
    """Update an existing ZenML server deployment.

    The generic configuration is first converted into the provider specific
    form. If that equals the deployment's current configuration, the
    existing service is handed to ``_start_service`` instead of
    ``_update_service``.

    Args:
        config: The new generic server deployment configuration.
        timeout: The timeout in seconds to wait until the update is
            successful. If not supplied, the default timeout value specified
            by the provider is used.

    Returns:
        The updated server deployment.

    Raises:
        ServerDeploymentNotFoundError: If a deployment with the given name
            doesn't exist.
    """
    try:
        service = self._get_service(config.name)
    except KeyError as err:
        # ``_get_service`` reports an unknown deployment name with a
        # ``KeyError``; translate it into the public not-found error while
        # keeping the original exception as the explicit cause.
        raise ServerDeploymentNotFoundError(
            f"ZenML server deployment with name '{config.name}' was not "
            f"found"
        ) from err

    # convert the generic deployment config to a provider specific
    # deployment config so it can be compared with the stored one
    config = self._convert_config(config)
    old_config = self._get_deployment_config(service)

    if old_config == config:
        # The configuration is unchanged: no update is needed, only a start.
        logger.info(
            f"The {config.name} ZenML server is already configured with "
            f"the same parameters."
        )
        service = self._start_service(service, timeout)
    else:
        logger.info(f"Updating the {config.name} ZenML server.")
        service = self._update_service(service, config, timeout)

    return self._get_deployment(service)

deployer

ZenML server deployer singleton implementation.

ServerDeployer

Server deployer singleton.

This class is responsible for managing the various server provider implementations and for directing server deployment lifecycle requests to the responsible provider. It acts as a facade built on top of the various server providers.

Source code in zenml/zen_server/deploy/deployer.py
class ServerDeployer(metaclass=SingletonMetaClass):
    """Server deployer singleton.

    This class is responsible for managing the various server provider
    implementations and for directing server deployment lifecycle requests to
    the responsible provider. It acts as a facade built on top of the various
    server providers.
    """

    # Registry of provider instances keyed by provider type. It is a class
    # level ClassVar, populated through ``register_provider``.
    _providers: ClassVar[Dict[ServerProviderType, BaseServerProvider]] = {}

    @classmethod
    def register_provider(cls, provider: Type[BaseServerProvider]) -> None:
        """Register a server provider.

        The provider class is instantiated once, at registration time, and
        the instance is stored under the provider's ``TYPE``.

        Args:
            provider: The server provider to register.

        Raises:
            TypeError: If a provider with the same type is already registered.
        """
        if provider.TYPE in cls._providers:
            raise TypeError(
                f"Server provider '{provider.TYPE}' is already registered."
            )
        logger.debug(f"Registering server provider '{provider.TYPE}'.")
        cls._providers[provider.TYPE] = provider()

    @classmethod
    def get_provider(
        cls, provider_type: ServerProviderType
    ) -> BaseServerProvider:
        """Get the server provider associated with a provider type.

        Args:
            provider_type: The server provider type.

        Returns:
            The server provider associated with the provider type.

        Raises:
            ServerProviderNotFoundError: If no provider is registered for the
                given provider type.
        """
        if provider_type not in cls._providers:
            raise ServerProviderNotFoundError(
                f"Server provider '{provider_type}' is not registered."
            )
        return cls._providers[provider_type]

    def deploy_server(
        self,
        config: ServerDeploymentConfig,
        timeout: Optional[int] = None,
    ) -> ServerDeployment:
        """Deploy a new ZenML server or update an existing deployment.

        If a deployment with ``config.name`` already exists, the request is
        delegated to ``update_server``.

        Args:
            config: The server deployment configuration.
            timeout: The timeout in seconds to wait until the deployment is
                successful. If not supplied, the default timeout value specified
                by the provider is used.

        Returns:
            The server deployment.
        """
        # We do this here to ensure that the zenml store is always initialized
        # before the server is deployed. This is necessary because the server
        # may require access to the local store configuration or database.
        gc = GlobalConfiguration()

        _ = gc.zen_store

        try:
            self.get_server(config.name)
        except ServerDeploymentNotFoundError:
            pass
        else:
            return self.update_server(config=config, timeout=timeout)

        provider_name = config.provider.value
        provider = self.get_provider(config.provider)

        logger.info(
            f"Deploying a {provider_name} ZenML server with name "
            f"'{config.name}'."
        )
        return provider.deploy_server(config, timeout=timeout)

    def update_server(
        self,
        config: ServerDeploymentConfig,
        timeout: Optional[int] = None,
    ) -> ServerDeployment:
        """Update an existing ZenML server deployment.

        Args:
            config: The new server deployment configuration.
            timeout: The timeout in seconds to wait until the deployment is
                successful. It is forwarded unchanged to the provider; if not
                supplied, the default timeout value specified by the provider
                is used.

        Returns:
            The updated server deployment.

        Raises:
            ServerDeploymentNotFoundError: If no deployment with the given
                name exists.
            ServerProviderNotFoundError: If no provider is registered for the
                requested provider type.
            ServerDeploymentExistsError: If an existing deployment with the same
                name but a different provider type is found.
        """
        # this will also raise ServerDeploymentNotFoundError if the server
        # does not exist
        existing_server = self.get_server(config.name)

        provider = self.get_provider(config.provider)
        existing_provider = existing_server.config.provider

        if existing_provider != config.provider:
            raise ServerDeploymentExistsError(
                f"A server deployment with the same name '{config.name}' but "
                f"with a different provider '{existing_provider.value}' "
                f"is already provisioned. Please choose a different name or "
                f"tear down the existing deployment."
            )

        return provider.update_server(config, timeout=timeout)

    def remove_server(
        self,
        server_name: str,
        timeout: Optional[int] = None,
    ) -> None:
        """Tears down and removes all resources and files associated with a ZenML server deployment.

        If the client is currently connected to the server, it is first
        disconnected on a best-effort basis.

        Args:
            server_name: The server deployment name.
            timeout: The timeout in seconds to wait until the deployment is
                successfully torn down. If not supplied, a provider specific
                default timeout value is used.

        Raises:
            ServerDeploymentNotFoundError: If no deployment with the given
                name exists.
        """
        # this will also raise ServerDeploymentNotFoundError if the server
        # does not exist
        server = self.get_server(server_name)

        provider_name = server.config.provider.value
        provider = self.get_provider(server.config.provider)

        if self.is_connected_to_server(server_name):
            try:
                self.disconnect_from_server(server_name)
            except Exception as e:
                # A failed disconnect must not block the teardown itself.
                logger.warning(f"Failed to disconnect from the server: {e}")

        logger.info(
            f"Tearing down the '{server_name}' {provider_name} ZenML server."
        )
        provider.remove_server(server.config, timeout=timeout)

    def is_connected_to_server(self, server_name: str) -> bool:
        """Check if the ZenML client is currently connected to a ZenML server.

        The check compares the URL of the active store configuration with the
        URL reported in the deployment status.

        Args:
            server_name: The server deployment name.

        Returns:
            True if the ZenML client is connected to the ZenML server, False
            otherwise.

        Raises:
            ServerDeploymentNotFoundError: If no deployment with the given
                name exists.
        """
        # this will also raise ServerDeploymentNotFoundError if the server
        # does not exist
        server = self.get_server(server_name)

        gc = GlobalConfiguration()
        return (
            server.status is not None
            and server.status.url is not None
            and gc.store_configuration.url == server.status.url
        )

    def connect_to_server(
        self,
        server_name: str,
        username: str,
        password: str,
        verify_ssl: Union[bool, str] = True,
    ) -> None:
        """Connect to a ZenML server instance.

        Args:
            server_name: The server deployment name.
            username: The username to use to connect to the server.
            password: The password to use to connect to the server.
            verify_ssl: Either a boolean, in which case it controls whether we
                verify the server's TLS certificate, or a string, in which case
                it must be a path to a CA bundle to use or the CA bundle value
                itself.

        Raises:
            ServerDeploymentError: If the ZenML server is not running or
                is unreachable.
        """
        # this will also raise ServerDeploymentNotFoundError if the server
        # does not exist
        server = self.get_server(server_name)
        provider_name = server.config.provider.value

        gc = GlobalConfiguration()
        if not server.status or not server.status.url:
            raise ServerDeploymentError(
                f"The {provider_name} {server_name} ZenML "
                f"server is not currently running or is unreachable."
            )

        store_config = RestZenStoreConfiguration(
            url=server.status.url,
            username=username,
            password=password,
            verify_ssl=verify_ssl,
        )

        # Skip reconfiguring the store if the exact same configuration
        # (URL, credentials, TLS settings) is already active.
        if gc.store_configuration == store_config:
            logger.info(
                f"ZenML is already connected to the '{server_name}' "
                f"{provider_name} ZenML server."
            )
            return

        logger.info(
            f"Connecting ZenML to the '{server_name}' "
            f"{provider_name} ZenML server ({store_config.url})."
        )

        gc.set_store(store_config)

        logger.info(
            f"Connected ZenML to the '{server_name}' "
            f"{provider_name} ZenML server ({store_config.url})."
        )

    def disconnect_from_server(
        self,
        server_name: Optional[str] = None,
    ) -> None:
        """Disconnect from a ZenML server instance.

        Nothing happens if the active store is not a REST store.

        Args:
            server_name: The server deployment name. If supplied, the deployer
                will check if the ZenML client is indeed connected to the server
                and disconnect only if that is the case. Otherwise the deployer
                will disconnect from any ZenML server.
        """
        gc = GlobalConfiguration()
        store_cfg = gc.store_configuration

        if store_cfg.type != StoreType.REST:
            logger.info("ZenML is not currently connected to a ZenML server.")
            return

        if server_name:
            # this will also raise ServerDeploymentNotFoundError if the server
            # does not exist
            server = self.get_server(server_name)
            provider_name = server.config.provider.value

            if not self.is_connected_to_server(server_name):
                logger.info(
                    f"ZenML is not currently connected to the '{server_name}' "
                    f"{provider_name} ZenML server."
                )
                return

            logger.info(
                f"Disconnecting ZenML from the '{server_name}' "
                f"{provider_name} ZenML server ({store_cfg.url})."
            )
        else:
            logger.info(
                f"Disconnecting ZenML from the {store_cfg.url} ZenML server."
            )

        gc.set_default_store()

        logger.info("Disconnected ZenML from the ZenML server.")

    def get_server(
        self,
        server_name: str,
    ) -> ServerDeployment:
        """Get a server deployment.

        Registered providers are queried in registration order and the first
        one that knows a deployment with this name wins.

        Args:
            server_name: The server deployment name.

        Returns:
            The requested server deployment.

        Raises:
            ServerDeploymentNotFoundError: If no server deployment with the
                given name is found.
        """
        for provider in self._providers.values():
            try:
                return provider.get_server(
                    ServerDeploymentConfig(
                        name=server_name, provider=provider.TYPE
                    )
                )
            except ServerDeploymentNotFoundError:
                pass

        raise ServerDeploymentNotFoundError(
            f"Server deployment '{server_name}' not found."
        )

    def list_servers(
        self,
        server_name: Optional[str] = None,
        provider_type: Optional[ServerProviderType] = None,
    ) -> List[ServerDeployment]:
        """List all server deployments.

        Args:
            server_name: The server deployment name to filter by.
            provider_type: The server provider type to filter by.

        Returns:
            The list of server deployments.
        """
        providers: List[BaseServerProvider] = []
        if provider_type:
            providers = [self.get_provider(provider_type)]
        else:
            providers = list(self._providers.values())

        servers: List[ServerDeployment] = []
        for provider in providers:
            if server_name:
                try:
                    servers.append(
                        provider.get_server(
                            ServerDeploymentConfig(
                                name=server_name,
                                provider=provider.TYPE,
                            )
                        )
                    )
                except ServerDeploymentNotFoundError:
                    pass
            else:
                servers.extend(provider.list_servers())

        return servers

    def get_server_logs(
        self,
        server_name: str,
        follow: bool = False,
        tail: Optional[int] = None,
    ) -> Generator[str, bool, None]:
        """Retrieve the logs of a ZenML server.

        Args:
            server_name: The server deployment name.
            follow: If True, the logs are streamed as they are written.
            tail: If set, only the last ``tail`` lines of log output are
                retrieved.

        Returns:
            A generator that can be accessed to get the service logs.
        """
        # this will also raise ServerDeploymentNotFoundError if the server
        # does not exist
        server = self.get_server(server_name)

        provider_name = server.config.provider.value
        provider = self.get_provider(server.config.provider)

        logger.info(
            f"Fetching logs from the '{server_name}' {provider_name} ZenML "
            f"server..."
        )
        return provider.get_server_logs(
            server.config, follow=follow, tail=tail
        )
connect_to_server(self, server_name, username, password, verify_ssl=True)

Connect to a ZenML server instance.

Parameters:

Name Type Description Default
server_name str

The server deployment name.

required
username str

The username to use to connect to the server.

required
password str

The password to use to connect to the server.

required
verify_ssl Union[bool, str]

Either a boolean, in which case it controls whether we verify the server's TLS certificate, or a string, in which case it must be a path to a CA bundle to use or the CA bundle value itself.

True

Exceptions:

Type Description
ServerDeploymentError

If the ZenML server is not running or is unreachable.

Source code in zenml/zen_server/deploy/deployer.py
def connect_to_server(
    self,
    server_name: str,
    username: str,
    password: str,
    verify_ssl: Union[bool, str] = True,
) -> None:
    """Point the local ZenML client at a running ZenML server deployment.

    Args:
        server_name: Name of the server deployment to connect to.
        username: Username used to authenticate against the server.
        password: Password used to authenticate against the server.
        verify_ssl: A boolean toggling TLS certificate verification, or a
            string holding either a path to a CA bundle or the CA bundle
            contents.

    Raises:
        ServerDeploymentError: If the ZenML server is not running or
            is unreachable.
    """
    # Unknown names surface as ServerDeploymentNotFoundError from get_server.
    deployment = self.get_server(server_name)
    provider_label = deployment.config.provider.value

    global_cfg = GlobalConfiguration()
    status = deployment.status
    if not status or not status.url:
        raise ServerDeploymentError(
            f"The {provider_label} {server_name} ZenML "
            f"server is not currently running or is unreachable."
        )

    target = RestZenStoreConfiguration(
        url=status.url,
        username=username,
        password=password,
        verify_ssl=verify_ssl,
    )

    # Human readable description reused by every log line below.
    description = f"'{server_name}' {provider_label} ZenML server"

    already_connected = global_cfg.store_configuration == target
    if already_connected:
        logger.info(f"ZenML is already connected to the {description}.")
        return

    logger.info(f"Connecting ZenML to the {description} ({target.url}).")
    global_cfg.set_store(target)
    logger.info(f"Connected ZenML to the {description} ({target.url}).")
deploy_server(self, config, timeout=None)

Deploy a new ZenML server or update an existing deployment.

Parameters:

Name Type Description Default
config ServerDeploymentConfig

The server deployment configuration.

required
timeout Optional[int]

The timeout in seconds to wait until the deployment is successful. If not supplied, the default timeout value specified by the provider is used.

None

Returns:

Type Description
ServerDeployment

The server deployment.

Source code in zenml/zen_server/deploy/deployer.py
def deploy_server(
    self,
    config: ServerDeploymentConfig,
    timeout: Optional[int] = None,
) -> ServerDeployment:
    """Create a ZenML server deployment, or update it if it already exists.

    Args:
        config: Configuration describing the server deployment.
        timeout: Seconds to wait for the deployment to succeed. When omitted,
            the provider's own default timeout applies.

    Returns:
        The resulting server deployment.
    """
    # Initialize the zen store before anything is deployed: the server may
    # need the local store configuration or database to be in place.
    global_cfg = GlobalConfiguration()
    _ = global_cfg.zen_store

    try:
        self.get_server(config.name)
    except ServerDeploymentNotFoundError:
        already_deployed = False
    else:
        already_deployed = True

    if already_deployed:
        return self.update_server(config=config, timeout=timeout)

    provider_label = config.provider.value
    target_provider = self.get_provider(config.provider)

    logger.info(
        f"Deploying a {provider_label} ZenML server with name "
        f"'{config.name}'."
    )
    return target_provider.deploy_server(config, timeout=timeout)
disconnect_from_server(self, server_name=None)

Disconnect from a ZenML server instance.

Parameters:

Name Type Description Default
server_name Optional[str]

The server deployment name. If supplied, the deployer will check if the ZenML client is indeed connected to the server and disconnect only if that is the case. Otherwise the deployer will disconnect from any ZenML server.

None
Source code in zenml/zen_server/deploy/deployer.py
def disconnect_from_server(
    self,
    server_name: Optional[str] = None,
) -> None:
    """Detach the local ZenML client from a ZenML server.

    Args:
        server_name: Optional deployment name. When given, the client is only
            disconnected if it is actually connected to that deployment;
            when omitted, the client is disconnected from whichever ZenML
            server it is using.
    """
    global_cfg = GlobalConfiguration()
    current = global_cfg.store_configuration

    # Only a REST store means "connected to a server"; anything else is a
    # no-op.
    if current.type != StoreType.REST:
        logger.info("ZenML is not currently connected to a ZenML server.")
        return

    if not server_name:
        logger.info(
            f"Disconnecting ZenML from the {current.url} ZenML server."
        )
    else:
        # Unknown names surface as ServerDeploymentNotFoundError.
        deployment = self.get_server(server_name)
        provider_label = deployment.config.provider.value

        if not self.is_connected_to_server(server_name):
            logger.info(
                f"ZenML is not currently connected to the '{server_name}' "
                f"{provider_label} ZenML server."
            )
            return

        logger.info(
            f"Disconnecting ZenML from the '{server_name}' "
            f"{provider_label} ZenML server ({current.url})."
        )

    global_cfg.set_default_store()
    logger.info("Disconnected ZenML from the ZenML server.")
get_provider(provider_type) classmethod

Get the server provider associated with a provider type.

Parameters:

Name Type Description Default
provider_type ServerProviderType

The server provider type.

required

Returns:

Type Description
BaseServerProvider

The server provider associated with the provider type.

Exceptions:

Type Description
ServerProviderNotFoundError

If no provider is registered for the given provider type.

Source code in zenml/zen_server/deploy/deployer.py
@classmethod
def get_provider(
    cls, provider_type: ServerProviderType
) -> BaseServerProvider:
    """Look up the registered provider instance for a provider type.

    Args:
        provider_type: The server provider type to look up.

    Returns:
        The provider instance registered for ``provider_type``.

    Raises:
        ServerProviderNotFoundError: If nothing is registered for the
            requested provider type.
    """
    registry = cls._providers
    if provider_type in registry:
        return registry[provider_type]
    raise ServerProviderNotFoundError(
        f"Server provider '{provider_type}' is not registered."
    )
get_server(self, server_name)

Get a server deployment.

Parameters:

Name Type Description Default
server_name str

The server deployment name.

required

Returns:

Type Description
ServerDeployment

The requested server deployment.

Exceptions:

Type Description
ServerDeploymentNotFoundError

If no server deployment with the given name is found.

Source code in zenml/zen_server/deploy/deployer.py
def get_server(
    self,
    server_name: str,
) -> ServerDeployment:
    """Find a server deployment by name across all registered providers.

    Providers are asked in registration order; the first one that knows a
    deployment with this name provides the result.

    Args:
        server_name: Name of the server deployment.

    Returns:
        The matching server deployment.

    Raises:
        ServerDeploymentNotFoundError: If none of the providers knows a
            deployment with that name.
    """
    for candidate in self._providers.values():
        lookup = ServerDeploymentConfig(
            name=server_name, provider=candidate.TYPE
        )
        try:
            return candidate.get_server(lookup)
        except ServerDeploymentNotFoundError:
            continue

    raise ServerDeploymentNotFoundError(
        f"Server deployment '{server_name}' not found."
    )
get_server_logs(self, server_name, follow=False, tail=None)

Retrieve the logs of a ZenML server.

Parameters:

Name Type Description Default
server_name str

The server deployment name.

required
follow bool

If True, the logs are streamed as they are written.

False
tail Optional[int]

If set, only the last `tail` lines of log output are retrieved.

None

Returns:

Type Description
Generator[str, bool, NoneType]

A generator that can be accessed to get the service logs.

Source code in zenml/zen_server/deploy/deployer.py
def get_server_logs(
    self,
    server_name: str,
    follow: bool = False,
    tail: Optional[int] = None,
) -> Generator[str, bool, None]:
    """Fetch the log output of a ZenML server deployment.

    Args:
        server_name: Name of the server deployment.
        follow: If True, keep streaming log lines as they are written.
        tail: If set, only the last ``tail`` lines of log output are
            returned.

    Returns:
        A generator yielding the service log lines.
    """
    # Unknown names surface as ServerDeploymentNotFoundError.
    deployment = self.get_server(server_name)
    deployment_config = deployment.config

    provider_label = deployment_config.provider.value
    owning_provider = self.get_provider(deployment_config.provider)

    logger.info(
        f"Fetching logs from the '{server_name}' {provider_label} ZenML "
        f"server..."
    )
    return owning_provider.get_server_logs(
        deployment_config, follow=follow, tail=tail
    )
is_connected_to_server(self, server_name)

Check if the ZenML client is currently connected to a ZenML server.

Parameters:

Name Type Description Default
server_name str

The server deployment name.

required

Returns:

Type Description
bool

True if the ZenML client is connected to the ZenML server, False otherwise.

Source code in zenml/zen_server/deploy/deployer.py
def is_connected_to_server(self, server_name: str) -> bool:
    """Tell whether the local client currently talks to a given server.

    Args:
        server_name: Name of the server deployment.

    Returns:
        True when the active store URL equals the URL reported by the
        deployment's status, False otherwise.
    """
    # Unknown names surface as ServerDeploymentNotFoundError.
    deployment = self.get_server(server_name)
    global_cfg = GlobalConfiguration()

    status = deployment.status
    if status is None or status.url is None:
        return False
    return global_cfg.store_configuration.url == status.url
list_servers(self, server_name=None, provider_type=None)

List all server deployments.

Parameters:

Name Type Description Default
server_name Optional[str]

The server deployment name to filter by.

None
provider_type Optional[zenml.enums.ServerProviderType]

The server provider type to filter by.

None

Returns:

Type Description
List[zenml.zen_server.deploy.deployment.ServerDeployment]

The list of server deployments.

Source code in zenml/zen_server/deploy/deployer.py
def list_servers(
    self,
    server_name: Optional[str] = None,
    provider_type: Optional[ServerProviderType] = None,
) -> List[ServerDeployment]:
    """Collect server deployments, optionally filtered by name and provider.

    Args:
        server_name: Only include deployments with this name.
        provider_type: Only query the provider of this type.

    Returns:
        The matching server deployments.
    """
    if provider_type:
        candidates = [self.get_provider(provider_type)]
    else:
        candidates = list(self._providers.values())

    if not server_name:
        return [
            deployment
            for candidate in candidates
            for deployment in candidate.list_servers()
        ]

    matches: List[ServerDeployment] = []
    for candidate in candidates:
        lookup = ServerDeploymentConfig(
            name=server_name,
            provider=candidate.TYPE,
        )
        try:
            matches.append(candidate.get_server(lookup))
        except ServerDeploymentNotFoundError:
            continue
    return matches
register_provider(provider) classmethod

Register a server provider.

Parameters:

Name Type Description Default
provider Type[zenml.zen_server.deploy.base_provider.BaseServerProvider]

The server provider to register.

required

Exceptions:

Type Description
TypeError

If a provider with the same type is already registered.

Source code in zenml/zen_server/deploy/deployer.py
@classmethod
def register_provider(cls, provider: Type[BaseServerProvider]) -> None:
    """Add a server provider class to the deployer's registry.

    The class is instantiated once and stored under its ``TYPE``.

    Args:
        provider: The server provider class to register.

    Raises:
        TypeError: If a provider of the same type is already registered.
    """
    provider_type = provider.TYPE
    if provider_type in cls._providers:
        raise TypeError(
            f"Server provider '{provider_type}' is already registered."
        )

    logger.debug(f"Registering server provider '{provider_type}'.")
    cls._providers[provider_type] = provider()
remove_server(self, server_name, timeout=None)

Tears down and removes all resources and files associated with a ZenML server deployment.

Parameters:

Name Type Description Default
server_name str

The server deployment name.

required
timeout Optional[int]

The timeout in seconds to wait until the deployment is successfully torn down. If not supplied, a provider specific default timeout value is used.

None
Source code in zenml/zen_server/deploy/deployer.py
def remove_server(
    self,
    server_name: str,
    timeout: Optional[int] = None,
) -> None:
    """Tears down and removes all resources and files associated with a ZenML server deployment.

    If the client is connected to the server, a disconnect is attempted
    first; a failed disconnect is only logged as a warning.

    Args:
        server_name: Name of the server deployment to remove.
        timeout: Seconds to wait for the teardown to finish. When omitted, a
            provider specific default timeout applies.
    """
    # Unknown names surface as ServerDeploymentNotFoundError.
    deployment = self.get_server(server_name)
    deployment_config = deployment.config

    provider_label = deployment_config.provider.value
    owning_provider = self.get_provider(deployment_config.provider)

    if self.is_connected_to_server(server_name):
        try:
            self.disconnect_from_server(server_name)
        except Exception as exc:
            logger.warning(f"Failed to disconnect from the server: {exc}")

    logger.info(
        f"Tearing down the '{server_name}' {provider_label} ZenML server."
    )
    owning_provider.remove_server(deployment_config, timeout=timeout)
update_server(self, config, timeout=None)

Update an existing ZenML server deployment.

Parameters:

Name Type Description Default
config ServerDeploymentConfig

The new server deployment configuration.

required
timeout Optional[int]

The timeout in seconds to wait until the deployment is successful. It is forwarded to the server provider; if not supplied, the default timeout value specified by the provider is used.

None

Returns:

Type Description
ServerDeployment

The updated server deployment.

Exceptions:

Type Description
ServerDeploymentExistsError

If an existing deployment with the same name but a different provider type is found.

Source code in zenml/zen_server/deploy/deployer.py
def update_server(
    self,
    config: ServerDeploymentConfig,
    timeout: Optional[int] = None,
) -> ServerDeployment:
    """Update an existing ZenML server deployment.

    Args:
        config: The new server deployment configuration.
        timeout: The timeout in seconds to wait until the deployment is
            successful. It is forwarded unchanged to the provider; if not
            supplied, the default timeout value specified by the provider
            is used.

    Returns:
        The updated server deployment.

    Raises:
        ServerDeploymentNotFoundError: If no deployment with the given name
            exists.
        ServerProviderNotFoundError: If no provider is registered for the
            requested provider type.
        ServerDeploymentExistsError: If an existing deployment with the same
            name but a different provider type is found.
    """
    # this will also raise ServerDeploymentNotFoundError if the server
    # does not exist
    existing_server = self.get_server(config.name)

    provider = self.get_provider(config.provider)
    existing_provider = existing_server.config.provider

    if existing_provider != config.provider:
        raise ServerDeploymentExistsError(
            f"A server deployment with the same name '{config.name}' but "
            f"with a different provider '{existing_provider.value}' "
            f"is already provisioned. Please choose a different name or "
            f"tear down the existing deployment."
        )

    return provider.update_server(config, timeout=timeout)

deployment

Zen Server deployment definitions.

ServerDeployment (BaseModel)

Server deployment.

Attributes:

Name Type Description
config ServerDeploymentConfig

The server deployment configuration.

status Optional[zenml.zen_server.deploy.deployment.ServerDeploymentStatus]

The server deployment status.

Source code in zenml/zen_server/deploy/deployment.py
class ServerDeployment(BaseModel):
    """A ZenML server deployment: its configuration and last known status.

    Attributes:
        config: The server deployment configuration.
        status: The server deployment status.
    """

    config: ServerDeploymentConfig
    status: Optional[ServerDeploymentStatus] = None

    @property
    def is_running(self) -> bool:
        """Report whether the deployment has an active server.

        Returns:
            True only when a status is available and its state is
            `ServiceState.ACTIVE`, False otherwise.
        """
        current = self.status
        if current is None:
            return False
        return current.status == ServiceState.ACTIVE
is_running: bool property readonly

Check if the server is running.

Returns:

Type Description
bool

Whether the server is running.

ServerDeploymentConfig (BaseModel)

Generic server deployment configuration.

All server deployment configurations should inherit from this class and handle extra attributes as provider specific attributes.

Attributes:

Name Type Description
name str

Name of the server deployment.

provider ServerProviderType

The server provider type.

Source code in zenml/zen_server/deploy/deployment.py
class ServerDeploymentConfig(BaseModel):
    """Base configuration shared by every ZenML server deployment.

    Provider-specific configuration classes derive from this one. Extra
    attributes are accepted at this level and are expected to be validated
    by the concrete subclasses as provider specific attributes.

    Attributes:
        name: Name of the server deployment.
        provider: The server provider type.
    """

    name: str
    provider: ServerProviderType

    model_config = ConfigDict(
        # Unknown attributes are tolerated here; the concrete provider
        # configuration classes are responsible for validating them.
        extra="allow",
        # Re-validate whenever an attribute is assigned, which is needed to
        # support a mix of mutable and immutable attributes.
        validate_assignment=True,
    )
ServerDeploymentStatus (BaseModel)

Server deployment status.

Ideally this should convey the following information:

  • whether the server's deployment is managed by this client (i.e. if the server was deployed with zenml up)
  • for a managed deployment, the status of the deployment/tear-down, e.g. not deployed, deploying, running, deleting, deployment timeout/error, tear-down timeout/error etc.
  • for an unmanaged deployment, the operational status (i.e. whether the server is reachable)
  • the URL of the server

Attributes:

Name Type Description
status ServiceState

The status of the server deployment.

status_message Optional[str]

A message describing the last status.

connected bool

Whether the client is currently connected to this server.

url Optional[str]

The URL of the server.

Source code in zenml/zen_server/deploy/deployment.py
class ServerDeploymentStatus(BaseModel):
    """Server deployment status.

    Ideally this should convey the following information:

    * whether the server's deployment is managed by this client (i.e. if
    the server was deployed with `zenml up`)
    * for a managed deployment, the status of the deployment/tear-down, e.g.
    not deployed, deploying, running, deleting, deployment timeout/error,
    tear-down timeout/error etc.
    * for an unmanaged deployment, the operational status (i.e. whether the
    server is reachable)
    * the URL of the server

    Attributes:
        status: The status of the server deployment.
        status_message: A message describing the last status.
        connected: Whether the client is currently connected to this server.
        url: The URL of the server.
        ca_crt: Optional CA certificate associated with the server
            (presumably used to verify the server's TLS certificate when
            connecting to `url` -- TODO confirm).
    """

    status: ServiceState
    status_message: Optional[str] = None
    connected: bool
    url: Optional[str] = None
    ca_crt: Optional[str] = None

docker special

ZenML Server Docker Deployment.

docker_provider

Zen Server docker deployer implementation.

DockerServerProvider (BaseServerProvider)

Docker ZenML server provider.

Source code in zenml/zen_server/deploy/docker/docker_provider.py
class DockerServerProvider(BaseServerProvider):
    """Docker ZenML server provider.

    Manages at most one docker ZenML server deployment at a time (the service
    is created with `singleton=True`). The deployment is backed by a
    `DockerZenServer` service whose runtime files live under
    `DockerZenServer.config_path()`.
    """

    TYPE: ClassVar[ServerProviderType] = ServerProviderType.DOCKER
    CONFIG_TYPE: ClassVar[Type[ServerDeploymentConfig]] = (
        DockerServerDeploymentConfig
    )

    @classmethod
    def _get_service_configuration(
        cls,
        server_config: ServerDeploymentConfig,
    ) -> Tuple[
        ServiceConfig,
        ServiceEndpointConfig,
        ServiceEndpointHealthMonitorConfig,
    ]:
        """Construct the service configuration from a server deployment configuration.

        Args:
            server_config: server deployment configuration.

        Returns:
            The service, service endpoint and endpoint monitor configuration.
        """
        assert isinstance(server_config, DockerServerDeploymentConfig)

        return (
            DockerZenServerConfig(
                root_runtime_path=DockerZenServer.config_path(),
                singleton=True,
                image=server_config.image,
                name=server_config.name,
                server=server_config,
            ),
            ContainerServiceEndpointConfig(
                protocol=ServiceEndpointProtocol.HTTP,
                port=server_config.port,
                allocate_port=False,
            ),
            # Health is checked with a HEAD request against the server's
            # healthcheck URL path.
            HTTPEndpointHealthMonitorConfig(
                healthcheck_uri_path=ZEN_SERVER_HEALTHCHECK_URL_PATH,
                use_head_request=True,
            ),
        )

    def _create_service(
        self,
        config: ServerDeploymentConfig,
        timeout: Optional[int] = None,
    ) -> BaseService:
        """Create, start and return the docker ZenML server deployment service.

        Args:
            config: The server deployment configuration.
            timeout: The timeout in seconds to wait until the service is
                running. Defaults to `DOCKER_ZENML_SERVER_DEFAULT_TIMEOUT`.

        Returns:
            The service instance.

        Raises:
            RuntimeError: If a docker service is already running.
        """
        assert isinstance(config, DockerServerDeploymentConfig)

        if timeout is None:
            timeout = DOCKER_ZENML_SERVER_DEFAULT_TIMEOUT

        # Only one docker server deployment may exist at a time.
        existing_service = DockerZenServer.get_service()
        if existing_service:
            raise RuntimeError(
                f"A docker ZenML server with name '{existing_service.config.name}' "
                f"is already running. Please stop it first before starting a "
                f"new one."
            )

        (
            service_config,
            endpoint_cfg,
            monitor_cfg,
        ) = self._get_service_configuration(config)
        endpoint = ContainerServiceEndpoint(
            config=endpoint_cfg,
            monitor=HTTPEndpointHealthMonitor(
                config=monitor_cfg,
            ),
        )
        service = DockerZenServer(
            uuid=uuid4(), config=service_config, endpoint=endpoint
        )

        service.start(timeout=timeout)
        return service

    def _update_service(
        self,
        service: BaseService,
        config: ServerDeploymentConfig,
        timeout: Optional[int] = None,
    ) -> BaseService:
        """Update the docker ZenML server deployment service.

        The service is stopped, its service/endpoint/monitor configurations
        are replaced, and it is started again.

        Args:
            service: The service instance.
            config: The new server deployment configuration.
            timeout: The timeout in seconds to wait until the updated service is
                running. Defaults to `DOCKER_ZENML_SERVER_DEFAULT_TIMEOUT`.

        Returns:
            The updated service instance.
        """
        if timeout is None:
            timeout = DOCKER_ZENML_SERVER_DEFAULT_TIMEOUT

        (
            new_config,
            new_endpoint_cfg,
            new_monitor_cfg,
        ) = self._get_service_configuration(config)

        assert service.endpoint
        assert service.endpoint.monitor

        service.stop(timeout=timeout)
        (
            service.config,
            service.endpoint.config,
            service.endpoint.monitor.config,
        ) = (
            new_config,
            new_endpoint_cfg,
            new_monitor_cfg,
        )
        service.start(timeout=timeout)

        return service

    def _start_service(
        self,
        service: BaseService,
        timeout: Optional[int] = None,
    ) -> BaseService:
        """Start the docker ZenML server deployment service.

        Args:
            service: The service instance.
            timeout: The timeout in seconds to wait until the service is
                running. Defaults to `DOCKER_ZENML_SERVER_DEFAULT_TIMEOUT`.

        Returns:
            The updated service instance.
        """
        if timeout is None:
            timeout = DOCKER_ZENML_SERVER_DEFAULT_TIMEOUT

        service.start(timeout=timeout)
        return service

    def _stop_service(
        self,
        service: BaseService,
        timeout: Optional[int] = None,
    ) -> BaseService:
        """Stop the docker ZenML server deployment service.

        Args:
            service: The service instance.
            timeout: The timeout in seconds to wait until the service is
                stopped. Defaults to `DOCKER_ZENML_SERVER_DEFAULT_TIMEOUT`.

        Returns:
            The updated service instance.
        """
        if timeout is None:
            timeout = DOCKER_ZENML_SERVER_DEFAULT_TIMEOUT

        service.stop(timeout=timeout)
        return service

    def _delete_service(
        self,
        service: BaseService,
        timeout: Optional[int] = None,
    ) -> None:
        """Remove the docker ZenML server deployment service.

        Stops the service and then deletes the whole docker server runtime
        directory (`DockerZenServer.config_path()`).

        Args:
            service: The service instance.
            timeout: The timeout in seconds to wait until the service is
                removed. Defaults to `DOCKER_ZENML_SERVER_DEFAULT_TIMEOUT`.
        """
        assert isinstance(service, DockerZenServer)

        if timeout is None:
            timeout = DOCKER_ZENML_SERVER_DEFAULT_TIMEOUT

        service.stop(timeout=timeout)
        shutil.rmtree(DockerZenServer.config_path())

    def _get_service(self, server_name: str) -> BaseService:
        """Get the docker ZenML server deployment service.

        Args:
            server_name: The server deployment name.

        Returns:
            The service instance.

        Raises:
            KeyError: If the server deployment is not found.
        """
        service = DockerZenServer.get_service()
        if service is None:
            raise KeyError("The docker ZenML server is not deployed.")

        if service.config.name != server_name:
            raise KeyError(
                "The docker ZenML server is deployed but with a different name."
            )

        return service

    def _list_services(self) -> List[BaseService]:
        """Get all service instances for all deployed ZenML servers.

        Returns:
            A list with the single docker server service, or an empty list
            if none is deployed.
        """
        service = DockerZenServer.get_service()
        if service:
            return [service]
        return []

    def _get_deployment_config(
        self, service: BaseService
    ) -> ServerDeploymentConfig:
        """Recreate the server deployment configuration from a service instance.

        Args:
            service: The service instance.

        Returns:
            The server deployment configuration.
        """
        server = cast(DockerZenServer, service)
        return server.config.server
CONFIG_TYPE (ServerDeploymentConfig)

Docker server deployment configuration.

Attributes:

Name Type Description
port int

The TCP port number where the server is accepting connections.

image str

The Docker image to use for the server.

Source code in zenml/zen_server/deploy/docker/docker_provider.py
class DockerServerDeploymentConfig(ServerDeploymentConfig):
    """Docker server deployment configuration.

    Attributes:
        port: The TCP port number where the server is accepting connections.
        image: The Docker image to use for the server.
        store: Optional store configuration for the server (presumably
            overrides the default store the server would otherwise use --
            TODO confirm).
        use_legacy_dashboard: Whether the server should serve the legacy
            dashboard. The value is forwarded to the server container as
            the `ENV_ZENML_SERVER_USE_LEGACY_DASHBOARD` environment variable.
    """

    port: int = 8238
    image: str = DOCKER_ZENML_SERVER_DEFAULT_IMAGE
    store: Optional[StoreConfiguration] = None
    use_legacy_dashboard: bool = DEFAULT_ZENML_SERVER_USE_LEGACY_DASHBOARD

    # Unlike the permissive base class, reject any unknown attributes.
    model_config = ConfigDict(extra="forbid")
docker_zen_server

Service implementation for the ZenML docker server deployment.

DockerServerDeploymentConfig (ServerDeploymentConfig)

Docker server deployment configuration.

Attributes:

Name Type Description
port int

The TCP port number where the server is accepting connections.

image str

The Docker image to use for the server.

Source code in zenml/zen_server/deploy/docker/docker_zen_server.py
class DockerServerDeploymentConfig(ServerDeploymentConfig):
    """Docker server deployment configuration.

    Attributes:
        port: The TCP port number where the server is accepting connections.
        image: The Docker image to use for the server.
        store: Optional store configuration for the server (presumably
            overrides the default store the server would otherwise use --
            TODO confirm).
        use_legacy_dashboard: Whether the server should serve the legacy
            dashboard. The value is forwarded to the server container as
            the `ENV_ZENML_SERVER_USE_LEGACY_DASHBOARD` environment variable.
    """

    port: int = 8238
    image: str = DOCKER_ZENML_SERVER_DEFAULT_IMAGE
    store: Optional[StoreConfiguration] = None
    use_legacy_dashboard: bool = DEFAULT_ZENML_SERVER_USE_LEGACY_DASHBOARD

    # Unlike the permissive base class, reject any unknown attributes.
    model_config = ConfigDict(extra="forbid")
DockerZenServer (ContainerService)

Service that can be used to start a docker ZenServer.

Attributes:

Name Type Description
config DockerZenServerConfig

service configuration

endpoint ContainerServiceEndpoint

service endpoint

Source code in zenml/zen_server/deploy/docker/docker_zen_server.py
class DockerZenServer(ContainerService):
    """Container service that runs a ZenML server inside Docker.

    Attributes:
        config: service configuration
        endpoint: service endpoint
    """

    SERVICE_TYPE = ServiceType(
        name="docker_zenml_server",
        type="zen_server",
        flavor="docker",
        description="Docker ZenML server deployment",
    )

    config: DockerZenServerConfig
    endpoint: ContainerServiceEndpoint

    @classmethod
    def config_path(cls) -> str:
        """Return the runtime directory of the docker ZenML server.

        Returns:
            The `zen_server/docker` sub-directory of the global config
            directory.
        """
        base_dir = get_global_config_directory()
        return os.path.join(base_dir, "zen_server", "docker")

    @property
    def _global_config_path(self) -> str:
        """Return the global configuration directory used by this server.

        Returns:
            The `SERVICE_CONTAINER_GLOBAL_CONFIG_DIR` sub-directory of the
            server runtime directory.
        """
        runtime_dir = self.config_path()
        return os.path.join(runtime_dir, SERVICE_CONTAINER_GLOBAL_CONFIG_DIR)

    @classmethod
    def get_service(cls) -> Optional["DockerZenServer"]:
        """Load the persisted docker ZenML server service, if there is one.

        Returns:
            The service deserialized from `service.json` in the runtime
            directory, or None when that file does not exist.
        """
        service_file = os.path.join(cls.config_path(), "service.json")
        try:
            with open(service_file, "r") as handle:
                loaded = DockerZenServer.from_json(handle.read())
        except FileNotFoundError:
            return None
        return cast("DockerZenServer", loaded)

    def _get_container_cmd(self) -> Tuple[List[str], Dict[str, str]]:
        """Build the command and environment for the server container.

        Extends the inherited command so that, inside the container, ZenML
        uses the copied global configuration rather than the one mounted
        from the local host.

        Returns:
            The container launch command and its environment variables, in
            the formats accepted by subprocess.Popen.
        """
        global_config = GlobalConfiguration()

        cmd, env = super()._get_container_cmd()
        env.update(
            {
                ENV_ZENML_CONFIG_PATH: os.path.join(
                    SERVICE_CONTAINER_PATH,
                    SERVICE_CONTAINER_GLOBAL_CONFIG_DIR,
                ),
                ENV_ZENML_SERVER_DEPLOYMENT_TYPE: ServerDeploymentType.DOCKER,
                ENV_ZENML_ANALYTICS_OPT_IN: str(
                    global_config.analytics_opt_in
                ),
                # Reuse the client's local stores path (mounted into the
                # container by the parent class) so the server's default
                # store points at the same local SQLite database.
                ENV_ZENML_LOCAL_STORES_PATH: os.path.join(
                    SERVICE_CONTAINER_GLOBAL_CONFIG_PATH,
                    LOCAL_STORES_DIRECTORY_NAME,
                ),
                ENV_ZENML_DISABLE_DATABASE_MIGRATION: "True",
                ENV_ZENML_SERVER_USE_LEGACY_DASHBOARD: str(
                    self.config.server.use_legacy_dashboard
                ),
                ENV_ZENML_SERVER_AUTO_ACTIVATE: "True",
            }
        )

        return cmd, env

    def provision(self) -> None:
        """Provision the service by delegating to the parent implementation."""
        super().provision()

    def run(self) -> None:
        """Run the ZenML Server as a blocking uvicorn process.

        Raises:
            ValueError: if started with a global configuration that connects to
                another ZenML server.
        """
        import uvicorn

        if GlobalConfiguration().store_configuration.type == StoreType.REST:
            raise ValueError(
                "The ZenML server cannot be started with REST store type."
            )
        logger.info(
            "Starting ZenML Server as blocking process... press CTRL+C once "
            "to stop it."
        )

        self.endpoint.prepare_for_start()

        try:
            uvicorn.run(
                ZEN_SERVER_ENTRYPOINT,
                host="0.0.0.0",  # nosec
                port=self.endpoint.config.port or 8000,
                log_level="info",
                server_header=False,
            )
        except KeyboardInterrupt:
            logger.info("ZenML Server stopped. Resuming normal execution.")
config_path() classmethod

Path to the directory where the docker ZenML server files are located.

Returns:

Type Description
str

Path to the docker ZenML server runtime directory.

Source code in zenml/zen_server/deploy/docker/docker_zen_server.py
@classmethod
def config_path(cls) -> str:
    """Return the runtime directory of the docker ZenML server.

    Returns:
        The `zen_server/docker` sub-directory of the global config directory.
    """
    base_dir = get_global_config_directory()
    return os.path.join(base_dir, "zen_server", "docker")
get_service() classmethod

Load and return the docker ZenML server service, if present.

Returns:

Type Description
Optional[DockerZenServer]

The docker ZenML server service or None, if the docker server deployment is not found.

Source code in zenml/zen_server/deploy/docker/docker_zen_server.py
@classmethod
def get_service(cls) -> Optional["DockerZenServer"]:
    """Load the persisted docker ZenML server service, if there is one.

    Returns:
        The service deserialized from `service.json` in the runtime
        directory, or None when that file does not exist.
    """
    service_file = os.path.join(cls.config_path(), "service.json")
    try:
        with open(service_file, "r") as handle:
            loaded = DockerZenServer.from_json(handle.read())
    except FileNotFoundError:
        return None
    return cast("DockerZenServer", loaded)
model_post_init(self, _ModelMetaclass__context)

We need to both initialize private attributes and call the user-defined model_post_init method.

Source code in zenml/zen_server/deploy/docker/docker_zen_server.py
def wrapped_model_post_init(self: BaseModel, __context: Any) -> None:
    """We need to both initialize private attributes and call the user-defined model_post_init
    method.

    `init_private_attributes` and `original_model_post_init` come from the
    enclosing scope, which is not visible here. Private attributes are set up
    first (presumably so that the user-defined hook can already access them
    -- verify against the pydantic source).
    """
    init_private_attributes(self, __context)
    original_model_post_init(self, __context)
provision(self)

Provision the service.

Source code in zenml/zen_server/deploy/docker/docker_zen_server.py
def provision(self) -> None:
    """Provision the service.

    Delegates entirely to the parent class implementation; this override adds
    no behavior of its own.
    """
    super().provision()
run(self)

Run the ZenML Server.

Exceptions:

Type Description
ValueError

if started with a global configuration that connects to another ZenML server.

Source code in zenml/zen_server/deploy/docker/docker_zen_server.py
def run(self) -> None:
    """Run the ZenML Server as a blocking uvicorn process.

    Raises:
        ValueError: if started with a global configuration that connects to
            another ZenML server.
    """
    import uvicorn

    if GlobalConfiguration().store_configuration.type == StoreType.REST:
        raise ValueError(
            "The ZenML server cannot be started with REST store type."
        )
    logger.info(
        "Starting ZenML Server as blocking process... press CTRL+C once "
        "to stop it."
    )

    self.endpoint.prepare_for_start()

    try:
        uvicorn.run(
            ZEN_SERVER_ENTRYPOINT,
            host="0.0.0.0",  # nosec
            port=self.endpoint.config.port or 8000,
            log_level="info",
            server_header=False,
        )
    except KeyboardInterrupt:
        logger.info("ZenML Server stopped. Resuming normal execution.")
DockerZenServerConfig (ContainerServiceConfig)

Docker Zen server configuration.

Attributes:

Name Type Description
server DockerServerDeploymentConfig

The deployment configuration.

Source code in zenml/zen_server/deploy/docker/docker_zen_server.py
class DockerZenServerConfig(ContainerServiceConfig):
    """Service configuration for the docker ZenML server.

    Extends the generic container service configuration with the docker
    server deployment configuration.

    Attributes:
        server: The deployment configuration.
    """

    server: DockerServerDeploymentConfig

exceptions

ZenML server deployment exceptions.

ServerDeploymentConfigurationError (ServerDeploymentError)

Raised when there is a ZenML server deployment configuration error.

Source code in zenml/zen_server/deploy/exceptions.py
class ServerDeploymentConfigurationError(ServerDeploymentError):
    """Raised when there is a ZenML server deployment configuration error."""
ServerDeploymentError (ZenMLBaseException)

Base exception class for all ZenML server deployment related errors.

Source code in zenml/zen_server/deploy/exceptions.py
class ServerDeploymentError(ZenMLBaseException):
    """Common base class for every ZenML server deployment related error."""
ServerDeploymentExistsError (ServerDeploymentError)

Raised when trying to deploy a new ZenML server with the same name.

Source code in zenml/zen_server/deploy/exceptions.py
class ServerDeploymentExistsError(ServerDeploymentError):
    """Raised when a ZenML server deployment with the same name already exists."""
ServerDeploymentNotFoundError (ServerDeploymentError)

Raised when trying to fetch a ZenML server deployment that doesn't exist.

Source code in zenml/zen_server/deploy/exceptions.py
class ServerDeploymentNotFoundError(ServerDeploymentError):
    """Raised when a requested ZenML server deployment does not exist."""
ServerProviderNotFoundError (ServerDeploymentError)

Raised when using a ZenML server provider that doesn't exist.

Source code in zenml/zen_server/deploy/exceptions.py
class ServerProviderNotFoundError(ServerDeploymentError):
    """Raised when a requested ZenML server provider does not exist."""

local special

ZenML Server Local Deployment.

local_provider

Zen Server local provider implementation.

LocalServerProvider (BaseServerProvider)

Local ZenML server provider.

Source code in zenml/zen_server/deploy/local/local_provider.py
class LocalServerProvider(BaseServerProvider):
    """Local ZenML server provider.

    Manages at most one local ZenML server deployment at a time (the service
    is created with `singleton=True`). The deployment is backed by a
    `LocalZenServer` service whose runtime files live under
    `LocalZenServer.config_path()`.
    """

    TYPE: ClassVar[ServerProviderType] = ServerProviderType.LOCAL
    CONFIG_TYPE: ClassVar[Type[ServerDeploymentConfig]] = (
        LocalServerDeploymentConfig
    )

    @staticmethod
    def check_local_server_dependencies() -> None:
        """Check if local server dependencies are installed.

        Raises:
            RuntimeError: If the dependencies are not installed.
        """
        try:
            # Make sure the ZenML Server dependencies are installed
            import fastapi  # noqa
            import jwt  # noqa
            import multipart  # noqa
            import uvicorn  # noqa
        except ImportError as e:
            # Unable to import the ZenML Server dependencies. Chain the
            # original ImportError so the missing module stays visible in
            # the traceback.
            raise RuntimeError(
                "The local ZenML server provider is unavailable because the "
                "ZenML server requirements seems to be unavailable on your machine. "
                "This is probably because ZenML was installed without the optional "
                "ZenML Server dependencies. To install the missing dependencies "
                f'run `pip install "zenml[server]=={__version__}"`.'
            ) from e

    @classmethod
    def _get_service_configuration(
        cls,
        server_config: ServerDeploymentConfig,
    ) -> Tuple[
        ServiceConfig,
        ServiceEndpointConfig,
        ServiceEndpointHealthMonitorConfig,
    ]:
        """Construct the service configuration from a server deployment configuration.

        Args:
            server_config: server deployment configuration.

        Returns:
            The service, service endpoint and endpoint monitor configuration.
        """
        assert isinstance(server_config, LocalServerDeploymentConfig)
        return (
            LocalZenServerConfig(
                root_runtime_path=LocalZenServer.config_path(),
                singleton=True,
                name=server_config.name,
                blocking=server_config.blocking,
                server=server_config,
            ),
            LocalDaemonServiceEndpointConfig(
                protocol=ServiceEndpointProtocol.HTTP,
                ip_address=str(server_config.ip_address),
                port=server_config.port,
                allocate_port=False,
            ),
            # Health is checked with a HEAD request against the server's
            # healthcheck URL path.
            HTTPEndpointHealthMonitorConfig(
                healthcheck_uri_path=ZEN_SERVER_HEALTHCHECK_URL_PATH,
                use_head_request=True,
            ),
        )

    def _create_service(
        self,
        config: ServerDeploymentConfig,
        timeout: Optional[int] = None,
    ) -> BaseService:
        """Create, start and return the local ZenML server deployment service.

        Args:
            config: The server deployment configuration.
            timeout: The timeout in seconds to wait until the service is
                running. Defaults to `LOCAL_ZENML_SERVER_DEFAULT_TIMEOUT`.

        Returns:
            The service instance.

        Raises:
            RuntimeError: If a local service is already running.
        """
        assert isinstance(config, LocalServerDeploymentConfig)

        if timeout is None:
            timeout = LOCAL_ZENML_SERVER_DEFAULT_TIMEOUT

        self.check_local_server_dependencies()
        # Only one local server deployment may exist at a time.
        existing_service = LocalZenServer.get_service()
        if existing_service:
            raise RuntimeError(
                f"A local ZenML server with name '{existing_service.config.name}' "
                f"is already running. Please stop it first before starting a "
                f"new one."
            )

        (
            service_config,
            endpoint_cfg,
            monitor_cfg,
        ) = self._get_service_configuration(config)
        endpoint = LocalDaemonServiceEndpoint(
            config=endpoint_cfg,
            monitor=HTTPEndpointHealthMonitor(
                config=monitor_cfg,
            ),
        )
        service = LocalZenServer(
            uuid=uuid4(), config=service_config, endpoint=endpoint
        )
        service.start(timeout=timeout)
        return service

    def _update_service(
        self,
        service: BaseService,
        config: ServerDeploymentConfig,
        timeout: Optional[int] = None,
    ) -> BaseService:
        """Update the local ZenML server deployment service.

        The service is stopped, its service/endpoint/monitor configurations
        are replaced, and it is started again.

        Args:
            service: The service instance.
            config: The new server deployment configuration.
            timeout: The timeout in seconds to wait until the updated service is
                running. Defaults to `LOCAL_ZENML_SERVER_DEFAULT_TIMEOUT`.

        Returns:
            The updated service instance.
        """
        if timeout is None:
            timeout = LOCAL_ZENML_SERVER_DEFAULT_TIMEOUT

        (
            new_config,
            new_endpoint_cfg,
            new_monitor_cfg,
        ) = self._get_service_configuration(config)

        assert service.endpoint
        assert service.endpoint.monitor
        service.stop(timeout=timeout)
        (
            service.config,
            service.endpoint.config,
            service.endpoint.monitor.config,
        ) = (
            new_config,
            new_endpoint_cfg,
            new_monitor_cfg,
        )
        service.start(timeout=timeout)

        return service

    def _start_service(
        self,
        service: BaseService,
        timeout: Optional[int] = None,
    ) -> BaseService:
        """Start the local ZenML server deployment service.

        Args:
            service: The service instance.
            timeout: The timeout in seconds to wait until the service is
                running. Defaults to `LOCAL_ZENML_SERVER_DEFAULT_TIMEOUT`.

        Returns:
            The updated service instance.
        """
        if timeout is None:
            timeout = LOCAL_ZENML_SERVER_DEFAULT_TIMEOUT

        service.start(timeout=timeout)
        return service

    def _stop_service(
        self,
        service: BaseService,
        timeout: Optional[int] = None,
    ) -> BaseService:
        """Stop the local ZenML server deployment service.

        Args:
            service: The service instance.
            timeout: The timeout in seconds to wait until the service is
                stopped. Defaults to `LOCAL_ZENML_SERVER_DEFAULT_TIMEOUT`.

        Returns:
            The updated service instance.
        """
        if timeout is None:
            timeout = LOCAL_ZENML_SERVER_DEFAULT_TIMEOUT

        service.stop(timeout=timeout)
        return service

    def _delete_service(
        self,
        service: BaseService,
        timeout: Optional[int] = None,
    ) -> None:
        """Remove the local ZenML server deployment service.

        Stops the service and then deletes the whole local server runtime
        directory (`LocalZenServer.config_path()`).

        Args:
            service: The service instance.
            timeout: The timeout in seconds to wait until the service is
                removed. Defaults to `LOCAL_ZENML_SERVER_DEFAULT_TIMEOUT`.
        """
        assert isinstance(service, LocalZenServer)

        if timeout is None:
            timeout = LOCAL_ZENML_SERVER_DEFAULT_TIMEOUT

        service.stop(timeout=timeout)
        shutil.rmtree(LocalZenServer.config_path())

    def _get_service(self, server_name: str) -> BaseService:
        """Get the local ZenML server deployment service.

        Args:
            server_name: The server deployment name.

        Returns:
            The service instance.

        Raises:
            KeyError: If the server deployment is not found.
        """
        service = LocalZenServer.get_service()
        if service is None:
            raise KeyError("The local ZenML server is not deployed.")

        if service.config.name != server_name:
            raise KeyError(
                "The local ZenML server is deployed but with a different name."
            )

        return service

    def _list_services(self) -> List[BaseService]:
        """Get all service instances for all deployed ZenML servers.

        Returns:
            A list with the single local server service, or an empty list if
            none is deployed.
        """
        service = LocalZenServer.get_service()
        if service:
            return [service]
        return []

    def _get_deployment_config(
        self, service: BaseService
    ) -> ServerDeploymentConfig:
        """Recreate the server deployment configuration from a service instance.

        Args:
            service: The service instance.

        Returns:
            The server deployment configuration.
        """
        server = cast(LocalZenServer, service)
        return server.config.server
CONFIG_TYPE (ServerDeploymentConfig)

Local server deployment configuration.

Attributes:

Name Type Description
port int

The TCP port number where the server is accepting connections.

ip_address Union[ipaddress.IPv4Address, ipaddress.IPv6Address]

The IP address where the server is reachable.

blocking bool

Run the server in blocking mode instead of using a daemon process.

Source code in zenml/zen_server/deploy/local/local_provider.py
class LocalServerDeploymentConfig(ServerDeploymentConfig):
    """Local server deployment configuration.

    Attributes:
        port: The TCP port number where the server is accepting connections.
        ip_address: The IP address where the server is reachable. IPv4 is
            tried first when validating (``union_mode="left_to_right"``).
        blocking: Run the server in blocking mode instead of using a daemon
            process.
        store: Optional store configuration for the server; ``None`` by
            default. NOTE(review): presumably overrides the default store
            the server would use - confirm against the provider code.
        use_legacy_dashboard: Forwarded to the server process through the
            ``ENV_ZENML_SERVER_USE_LEGACY_DASHBOARD`` environment variable.
    """

    port: int = 8237
    ip_address: Union[ipaddress.IPv4Address, ipaddress.IPv6Address] = Field(
        default=ipaddress.IPv4Address(DEFAULT_LOCAL_SERVICE_IP_ADDRESS),
        union_mode="left_to_right",
    )
    blocking: bool = False
    store: Optional[StoreConfiguration] = None
    use_legacy_dashboard: bool = DEFAULT_ZENML_SERVER_USE_LEGACY_DASHBOARD

    # Reject unknown keys so that typos in deployment configs fail loudly.
    model_config = ConfigDict(extra="forbid")
check_local_server_dependencies() staticmethod

Check if local server dependencies are installed.

Exceptions:

Type Description
RuntimeError

If the dependencies are not installed.

Source code in zenml/zen_server/deploy/local/local_provider.py
@staticmethod
def check_local_server_dependencies() -> None:
    """Check if local server dependencies are installed.

    The check attempts to import the modules the local server relies on
    (``fastapi``, ``jwt``, ``multipart`` and ``uvicorn``).

    Raises:
        RuntimeError: If the dependencies are not installed. The original
            ``ImportError`` is chained as the cause, so the name of the
            missing module is visible in the traceback.
    """
    try:
        # Make sure the ZenML Server dependencies are installed
        import fastapi  # noqa
        import jwt  # noqa
        import multipart  # noqa
        import uvicorn  # noqa
    except ImportError as e:
        # Unable to import the ZenML Server dependencies. Chain the original
        # error instead of discarding it.
        raise RuntimeError(
            "The local ZenML server provider is unavailable because the "
            "ZenML server requirements seems to be unavailable on your machine. "
            "This is probably because ZenML was installed without the optional "
            "ZenML Server dependencies. To install the missing dependencies "
            f'run `pip install "zenml[server]=={__version__}"`.'
        ) from e
local_zen_server

Local ZenML server deployment service implementation.

LocalServerDeploymentConfig (ServerDeploymentConfig)

Local server deployment configuration.

Attributes:

Name Type Description
port int

The TCP port number where the server is accepting connections.

ip_address Union[IPv4Address, IPv6Address]

The IP address where the server is reachable.

blocking bool

Run the server in blocking mode instead of using a daemon process.

Source code in zenml/zen_server/deploy/local/local_zen_server.py
class LocalServerDeploymentConfig(ServerDeploymentConfig):
    """Local server deployment configuration.

    Attributes:
        port: The TCP port number where the server is accepting connections.
        ip_address: The IP address where the server is reachable. IPv4 is
            tried first when validating (``union_mode="left_to_right"``).
        blocking: Run the server in blocking mode instead of using a daemon
            process.
        store: Optional store configuration for the server; ``None`` by
            default. NOTE(review): presumably overrides the default store
            the server would use - confirm against the provider code.
        use_legacy_dashboard: Forwarded to the server process through the
            ``ENV_ZENML_SERVER_USE_LEGACY_DASHBOARD`` environment variable.
    """

    port: int = 8237
    ip_address: Union[ipaddress.IPv4Address, ipaddress.IPv6Address] = Field(
        default=ipaddress.IPv4Address(DEFAULT_LOCAL_SERVICE_IP_ADDRESS),
        union_mode="left_to_right",
    )
    blocking: bool = False
    store: Optional[StoreConfiguration] = None
    use_legacy_dashboard: bool = DEFAULT_ZENML_SERVER_USE_LEGACY_DASHBOARD

    # Reject unknown keys so that typos in deployment configs fail loudly.
    model_config = ConfigDict(extra="forbid")
LocalZenServer (LocalDaemonService)

Service daemon that can be used to start a local ZenML Server.

Attributes:

Name Type Description
config LocalZenServerConfig

service configuration

endpoint LocalDaemonServiceEndpoint

optional service endpoint

Source code in zenml/zen_server/deploy/local/local_zen_server.py
class LocalZenServer(LocalDaemonService):
    """Service daemon that can be used to start a local ZenML Server.

    Attributes:
        config: service configuration
        endpoint: optional service endpoint
    """

    SERVICE_TYPE = ServiceType(
        name="local_zenml_server",
        type="zen_server",
        flavor="local",
        description="Local ZenML server deployment",
    )

    config: LocalZenServerConfig
    endpoint: LocalDaemonServiceEndpoint

    @classmethod
    def config_path(cls) -> str:
        """Path to the directory where the local ZenML server files are located.

        Returns:
            Path to the local ZenML server runtime directory.
        """
        return os.path.join(
            get_global_config_directory(),
            "zen_server",
            "local",
        )

    @property
    def _global_config_path(self) -> str:
        """Path to the global configuration directory used by this server.

        Returns:
            Path to the global configuration directory used by this server.
        """
        return os.path.join(self.config_path(), ".zenconfig")

    @classmethod
    def get_service(cls) -> Optional["LocalZenServer"]:
        """Load and return the local ZenML server service, if present.

        The service is deserialized from ``service.json`` in the runtime
        directory returned by ``config_path()``.

        Returns:
            The local ZenML server service or None, if the local server
            deployment is not found.
        """
        config_filename = os.path.join(cls.config_path(), "service.json")
        try:
            with open(config_filename, "r") as f:
                return cast(
                    "LocalZenServer", LocalZenServer.from_json(f.read())
                )
        except FileNotFoundError:
            return None

    def _get_daemon_cmd(self) -> Tuple[List[str], Dict[str, str]]:
        """Get the command to start the daemon.

        Overrides the base class implementation to add the environment variable
        that forces the ZenML server to use the copied global config.

        Returns:
            The command to start the daemon and the environment variables to
            set for the command.
        """
        cmd, env = super()._get_daemon_cmd()
        env[ENV_ZENML_CONFIG_PATH] = self._global_config_path
        env[ENV_ZENML_SERVER_DEPLOYMENT_TYPE] = ServerDeploymentType.LOCAL
        # Set the local stores path to the same path used by the client. This
        # ensures that the server's default store configuration is initialized
        # to point at the same local SQLite database as the client.
        env[ENV_ZENML_LOCAL_STORES_PATH] = (
            GlobalConfiguration().local_stores_path
        )
        env[ENV_ZENML_DISABLE_DATABASE_MIGRATION] = "True"
        env[ENV_ZENML_SERVER_USE_LEGACY_DASHBOARD] = str(
            self.config.server.use_legacy_dashboard
        )
        env[ENV_ZENML_SERVER_AUTO_ACTIVATE] = "True"

        return cmd, env

    def provision(self) -> None:
        """Provision the service."""
        super().provision()

    def start(self, timeout: int = 0) -> None:
        """Start the service and optionally wait for it to become active.

        In non-blocking mode, the base class starts the server as a daemon.
        In blocking mode, the server runs in the current process (see
        ``run``). The ZenML config path and local stores path environment
        variables are temporarily overridden during the run. Their previous
        values, or their absence, are restored afterwards, even if the
        server fails.

        Args:
            timeout: amount of time to wait for the service to become active.
                If set to 0, the method will return immediately after checking
                the service status. Only used in non-blocking mode.
        """
        if not self.config.blocking:
            super().start(timeout)
            return

        # In the blocking mode, we need to temporarily set the environment
        # variables for the running process to make it look like the server
        # is running in a separate environment (i.e. using a different
        # global configuration path). This is necessary to avoid polluting
        # the client environment with the server's configuration.
        local_stores_path = GlobalConfiguration().local_stores_path
        GlobalConfiguration._reset_instance()
        Client._reset_instance()
        # Snapshot the current value of every variable we override (None when
        # unset). Both variables can then be restored exactly, including a
        # pre-existing local stores path or an empty-string config path.
        saved_env = {
            name: os.environ.get(name)
            for name in (ENV_ZENML_CONFIG_PATH, ENV_ZENML_LOCAL_STORES_PATH)
        }
        os.environ[ENV_ZENML_CONFIG_PATH] = self._global_config_path
        # Set the local stores path to the same path used by the client.
        # This ensures that the server's default store configuration is
        # initialized to point at the same local SQLite database as the
        # client.
        os.environ[ENV_ZENML_LOCAL_STORES_PATH] = local_stores_path
        try:
            self.run()
        finally:
            # Restore the original client environment variables. pop() with a
            # default tolerates the variable having been removed meanwhile.
            for name, value in saved_env.items():
                if value is None:
                    os.environ.pop(name, None)
                else:
                    os.environ[name] = value
            GlobalConfiguration._reset_instance()
            Client._reset_instance()

    def run(self) -> None:
        """Run the ZenML Server.

        Raises:
            ValueError: if started with a global configuration that connects to
                another ZenML server.
        """
        import uvicorn

        gc = GlobalConfiguration()
        if gc.store_configuration.type == StoreType.REST:
            raise ValueError(
                "The ZenML server cannot be started with REST store type."
            )
        logger.info(
            "Starting ZenML Server as blocking "
            "process... press CTRL+C once to stop it."
        )

        self.endpoint.prepare_for_start()

        try:
            uvicorn.run(
                ZEN_SERVER_ENTRYPOINT,
                host=self.endpoint.config.ip_address,
                port=self.endpoint.config.port or 8000,
                log_level="info",
                server_header=False,
            )
        except KeyboardInterrupt:
            logger.info("ZenML Server stopped. Resuming normal execution.")
config_path() classmethod

Path to the directory where the local ZenML server files are located.

Returns:

Type Description
str

Path to the local ZenML server runtime directory.

Source code in zenml/zen_server/deploy/local/local_zen_server.py
@classmethod
def config_path(cls) -> str:
    """Return the runtime directory holding the local ZenML server files.

    Returns:
        Path to the local ZenML server runtime directory, i.e.
        ``<global config directory>/zen_server/local``.
    """
    sub_dirs = ("zen_server", "local")
    return os.path.join(get_global_config_directory(), *sub_dirs)
get_service() classmethod

Load and return the local ZenML server service, if present.

Returns:

Type Description
Optional[LocalZenServer]

The local ZenML server service, or None if the local server deployment is not found.

Source code in zenml/zen_server/deploy/local/local_zen_server.py
@classmethod
def get_service(cls) -> Optional["LocalZenServer"]:
    """Deserialize the persisted local ZenML server service, if any.

    The service is read from ``service.json`` inside the directory
    returned by ``config_path()``.

    Returns:
        The local ZenML server service or None, if the local server
        deployment is not found.
    """
    service_file = os.path.join(cls.config_path(), "service.json")
    try:
        with open(service_file, "r") as handle:
            loaded = LocalZenServer.from_json(handle.read())
    except FileNotFoundError:
        # A missing service file means no local deployment exists.
        return None
    return cast("LocalZenServer", loaded)
provision(self)

Provision the service.

Source code in zenml/zen_server/deploy/local/local_zen_server.py
def provision(self) -> None:
    """Provision the service.

    Delegates entirely to the parent class implementation; no local
    server-specific provisioning is performed here.
    """
    super().provision()
run(self)

Run the ZenML Server.

Exceptions:

Type Description
ValueError

if started with a global configuration that connects to another ZenML server.

Source code in zenml/zen_server/deploy/local/local_zen_server.py
def run(self) -> None:
    """Run the ZenML Server in the current process until interrupted.

    Refuses to start when the active global configuration points at a REST
    store. It then prepares the endpoint and hands control to uvicorn.
    Ctrl+C is caught and logged rather than propagated.

    Raises:
        ValueError: if started with a global configuration that connects to
            another ZenML server.
    """
    import uvicorn

    store_type = GlobalConfiguration().store_configuration.type
    if store_type == StoreType.REST:
        raise ValueError(
            "The ZenML server cannot be started with REST store type."
        )
    logger.info(
        "Starting ZenML Server as blocking "
        "process... press CTRL+C once to stop it."
    )

    self.endpoint.prepare_for_start()
    # Read the endpoint config only after prepare_for_start() has run.
    endpoint_config = self.endpoint.config
    server_options = {
        "host": endpoint_config.ip_address,
        "port": endpoint_config.port or 8000,
        "log_level": "info",
        "server_header": False,
    }

    try:
        uvicorn.run(ZEN_SERVER_ENTRYPOINT, **server_options)
    except KeyboardInterrupt:
        logger.info("ZenML Server stopped. Resuming normal execution.")
start(self, timeout=0)

Start the service and optionally wait for it to become active.

Parameters:

Name Type Description Default
timeout int

amount of time to wait for the service to become active. If set to 0, the method will return immediately after checking the service status.

0
Source code in zenml/zen_server/deploy/local/local_zen_server.py
def start(self, timeout: int = 0) -> None:
    """Start the service and optionally wait for it to become active.

    In non-blocking mode, the base class starts the server as a daemon.
    In blocking mode, the server runs in the current process (see ``run``).
    The ZenML config path and local stores path environment variables are
    temporarily overridden during the run. Their previous values, or their
    absence, are restored afterwards, even if the server fails.

    Args:
        timeout: amount of time to wait for the service to become active.
            If set to 0, the method will return immediately after checking
            the service status. Only used in non-blocking mode.
    """
    if not self.config.blocking:
        super().start(timeout)
        return

    # In the blocking mode, we need to temporarily set the environment
    # variables for the running process to make it look like the server
    # is running in a separate environment (i.e. using a different
    # global configuration path). This is necessary to avoid polluting
    # the client environment with the server's configuration.
    local_stores_path = GlobalConfiguration().local_stores_path
    GlobalConfiguration._reset_instance()
    Client._reset_instance()
    # Snapshot the current value of every variable we override (None when
    # unset). Both variables can then be restored exactly, including a
    # pre-existing local stores path or an empty-string config path.
    saved_env = {
        name: os.environ.get(name)
        for name in (ENV_ZENML_CONFIG_PATH, ENV_ZENML_LOCAL_STORES_PATH)
    }
    os.environ[ENV_ZENML_CONFIG_PATH] = self._global_config_path
    # Set the local stores path to the same path used by the client.
    # This ensures that the server's default store configuration is
    # initialized to point at the same local SQLite database as the
    # client.
    os.environ[ENV_ZENML_LOCAL_STORES_PATH] = local_stores_path
    try:
        self.run()
    finally:
        # Restore the original client environment variables. pop() with a
        # default tolerates the variable having been removed meanwhile.
        for name, value in saved_env.items():
            if value is None:
                os.environ.pop(name, None)
            else:
                os.environ[name] = value
        GlobalConfiguration._reset_instance()
        Client._reset_instance()
LocalZenServerConfig (LocalDaemonServiceConfig)

Local Zen server configuration.

Attributes:

Name Type Description
server LocalServerDeploymentConfig

The deployment configuration.

Source code in zenml/zen_server/deploy/local/local_zen_server.py
class LocalZenServerConfig(LocalDaemonServiceConfig):
    """Local Zen server configuration.

    Wraps the daemon service configuration inherited from
    `LocalDaemonServiceConfig` and adds the server deployment settings.

    Attributes:
        server: The deployment configuration. `LocalZenServer` reads it,
            e.g. `self.config.server.use_legacy_dashboard` when building the
            daemon environment.
    """

    server: LocalServerDeploymentConfig

terraform special

ZenML Server Terraform Deployment.

providers special

ZenML Server Terraform Providers.

aws_provider

Zen Server AWS Terraform deployer implementation.

AWSServerDeploymentConfig (TerraformServerDeploymentConfig)

AWS server deployment configuration.

Attributes:

Name Type Description
region str

The AWS region to deploy to.

rds_name str

The name of the RDS instance to create.

db_name str

Name of RDS database to create.

db_type str

Type of RDS database to create.

db_version str

Version of RDS database to create.

db_instance_class str

Instance class of RDS database to create.

db_allocated_storage int

Allocated storage of RDS database to create.

Source code in zenml/zen_server/deploy/terraform/providers/aws_provider.py
class AWSServerDeploymentConfig(TerraformServerDeploymentConfig):
    """AWS server deployment configuration.

    Extends the generic Terraform deployment configuration with
    AWS-specific settings, all of which have defaults.

    Attributes:
        region: The AWS region to deploy to.
        rds_name: The name of the RDS instance to create.
        db_name: Name of RDS database to create.
        db_type: Type of RDS database to create.
        db_version: Version of RDS database to create.
        db_instance_class: Instance class of RDS database to create.
        db_allocated_storage: Allocated storage of RDS database to create.
            NOTE(review): presumably in GiB, per the RDS API - confirm
            against the Terraform recipe.
    """

    region: str = "eu-west-1"
    rds_name: str = "zenmlserver"
    db_name: str = "zenmlserver"
    db_type: str = "mysql"
    db_version: str = "5.7.38"
    db_instance_class: str = "db.t3.micro"
    db_allocated_storage: int = 5
AWSServerProvider (TerraformServerProvider)

AWS ZenML server provider.

Source code in zenml/zen_server/deploy/terraform/providers/aws_provider.py
class AWSServerProvider(TerraformServerProvider):
    """ZenML server provider that deploys to AWS through Terraform.

    Only the provider type and the deployment configuration class are
    specialized here; all deployment logic is inherited.
    """

    TYPE: ClassVar[ServerProviderType] = ServerProviderType.AWS
    CONFIG_TYPE: ClassVar[
        Type[TerraformServerDeploymentConfig]
    ] = AWSServerDeploymentConfig
CONFIG_TYPE (TerraformServerDeploymentConfig)

AWS server deployment configuration.

Attributes:

Name Type Description
region str

The AWS region to deploy to.

rds_name str

The name of the RDS instance to create.

db_name str

Name of RDS database to create.

db_type str

Type of RDS database to create.

db_version str

Version of RDS database to create.

db_instance_class str

Instance class of RDS database to create.

db_allocated_storage int

Allocated storage of RDS database to create.

Source code in zenml/zen_server/deploy/terraform/providers/aws_provider.py
class AWSServerDeploymentConfig(TerraformServerDeploymentConfig):
    """AWS server deployment configuration.

    Extends the generic Terraform deployment configuration with
    AWS-specific settings, all of which have defaults.

    Attributes:
        region: The AWS region to deploy to.
        rds_name: The name of the RDS instance to create.
        db_name: Name of RDS database to create.
        db_type: Type of RDS database to create.
        db_version: Version of RDS database to create.
        db_instance_class: Instance class of RDS database to create.
        db_allocated_storage: Allocated storage of RDS database to create.
            NOTE(review): presumably in GiB, per the RDS API - confirm
            against the Terraform recipe.
    """

    region: str = "eu-west-1"
    rds_name: str = "zenmlserver"
    db_name: str = "zenmlserver"
    db_type: str = "mysql"
    db_version: str = "5.7.38"
    db_instance_class: str = "db.t3.micro"
    db_allocated_storage: int = 5
azure_provider

Zen Server Azure Terraform deployer implementation.

AzureServerDeploymentConfig (TerraformServerDeploymentConfig)

Azure server deployment configuration.

Attributes:

Name Type Description
resource_group str

The Azure resource group to deploy to.

db_instance_name str

The name of the Flexible MySQL instance to create.

db_name str

Name of the MySQL database to create.

db_version str

Version of MySQL database to create.

db_sku_name str

The sku_name for the database resource.

db_disk_size int

Allocated storage of MySQL database to create.

Source code in zenml/zen_server/deploy/terraform/providers/azure_provider.py
class AzureServerDeploymentConfig(TerraformServerDeploymentConfig):
    """Azure server deployment configuration.

    Extends the generic Terraform deployment configuration with
    Azure-specific settings, all of which have defaults.

    Attributes:
        resource_group: The Azure resource group to deploy to.
        db_instance_name: The name of the Flexible MySQL instance to create.
        db_name: Name of the MySQL database to create.
        db_version: Version of MySQL database to create.
        db_sku_name: The sku_name for the database resource.
        db_disk_size: Allocated storage of MySQL database to create.
            NOTE(review): presumably in GB - confirm against the Terraform
            recipe.
    """

    resource_group: str = "zenml"
    db_instance_name: str = "zenmlserver"
    db_name: str = "zenmlserver"
    db_version: str = "5.7"
    db_sku_name: str = "B_Standard_B1s"
    db_disk_size: int = 20
AzureServerProvider (TerraformServerProvider)

Azure ZenML server provider.

Source code in zenml/zen_server/deploy/terraform/providers/azure_provider.py
class AzureServerProvider(TerraformServerProvider):
    """ZenML server provider that deploys to Azure through Terraform.

    Only the provider type and the deployment configuration class are
    specialized here; all deployment logic is inherited.
    """

    TYPE: ClassVar[ServerProviderType] = ServerProviderType.AZURE
    CONFIG_TYPE: ClassVar[
        Type[TerraformServerDeploymentConfig]
    ] = AzureServerDeploymentConfig
CONFIG_TYPE (TerraformServerDeploymentConfig)

Azure server deployment configuration.

Attributes:

Name Type Description
resource_group str

The Azure resource group to deploy to.

db_instance_name str

The name of the Flexible MySQL instance to create.

db_name str

Name of the MySQL database to create.

db_version str

Version of MySQL database to create.

db_sku_name str

The sku_name for the database resource.

db_disk_size int

Allocated storage of MySQL database to create.

Source code in zenml/zen_server/deploy/terraform/providers/azure_provider.py
class AzureServerDeploymentConfig(TerraformServerDeploymentConfig):
    """Azure server deployment configuration.

    Extends the generic Terraform deployment configuration with
    Azure-specific settings, all of which have defaults.

    Attributes:
        resource_group: The Azure resource group to deploy to.
        db_instance_name: The name of the Flexible MySQL instance to create.
        db_name: Name of the MySQL database to create.
        db_version: Version of MySQL database to create.
        db_sku_name: The sku_name for the database resource.
        db_disk_size: Allocated storage of MySQL database to create.
            NOTE(review): presumably in GB - confirm against the Terraform
            recipe.
    """

    resource_group: str = "zenml"
    db_instance_name: str = "zenmlserver"
    db_name: str = "zenmlserver"
    db_version: str = "5.7"
    db_sku_name: str = "B_Standard_B1s"
    db_disk_size: int = 20
gcp_provider

Zen Server GCP Terraform deployer implementation.

GCPServerDeploymentConfig (TerraformServerDeploymentConfig)

GCP server deployment configuration.

Attributes:

Name Type Description
project_id str

The project in GCP to deploy the server to.

region str

The GCP region to deploy to.

cloudsql_name str

The name of the CloudSQL instance to create.

db_name str

Name of CloudSQL database to create.

db_instance_tier str

Instance class of CloudSQL database to create.

db_disk_size int

Allocated storage of CloudSQL database to create.

Source code in zenml/zen_server/deploy/terraform/providers/gcp_provider.py
class GCPServerDeploymentConfig(TerraformServerDeploymentConfig):
    """GCP server deployment configuration.

    Extends the generic Terraform deployment configuration with
    GCP-specific settings. ``project_id`` has no default and must always be
    supplied; every other field has a default.

    Attributes:
        project_id: The project in GCP to deploy the server to (required).
        region: The GCP region to deploy to.
        cloudsql_name: The name of the CloudSQL instance to create.
        db_name: Name of CloudSQL database to create.
        db_instance_tier: Instance class of CloudSQL database to create.
        db_disk_size: Allocated storage of CloudSQL database to create.
            NOTE(review): presumably in GB - confirm against the Terraform
            recipe.
    """

    project_id: str
    region: str = "europe-west3"
    cloudsql_name: str = "zenmlserver"
    db_name: str = "zenmlserver"
    db_instance_tier: str = "db-n1-standard-1"
    db_disk_size: int = 10
GCPServerProvider (TerraformServerProvider)

GCP ZenML server provider.

Source code in zenml/zen_server/deploy/terraform/providers/gcp_provider.py
class GCPServerProvider(TerraformServerProvider):
    """ZenML server provider that deploys to GCP through Terraform.

    Only the provider type and the deployment configuration class are
    specialized here; all deployment logic is inherited.
    """

    TYPE: ClassVar[ServerProviderType] = ServerProviderType.GCP
    CONFIG_TYPE: ClassVar[
        Type[TerraformServerDeploymentConfig]
    ] = GCPServerDeploymentConfig
CONFIG_TYPE (TerraformServerDeploymentConfig)

GCP server deployment configuration.

Attributes:

Name Type Description
project_id str

The project in GCP to deploy the server to.

region str

The GCP region to deploy to.

cloudsql_name str

The name of the CloudSQL instance to create.

db_name str

Name of CloudSQL database to create.

db_instance_tier str

Instance class of CloudSQL database to create.

db_disk_size int

Allocated storage of CloudSQL database to create.

Source code in zenml/zen_server/deploy/terraform/providers/gcp_provider.py
class GCPServerDeploymentConfig(TerraformServerDeploymentConfig):
    """GCP server deployment configuration.

    Extends the generic Terraform deployment configuration with
    GCP-specific settings. ``project_id`` has no default and must always be
    supplied; every other field has a default.

    Attributes:
        project_id: The project in GCP to deploy the server to (required).
        region: The GCP region to deploy to.
        cloudsql_name: The name of the CloudSQL instance to create.
        db_name: Name of CloudSQL database to create.
        db_instance_tier: Instance class of CloudSQL database to create.
        db_disk_size: Allocated storage of CloudSQL database to create.
            NOTE(review): presumably in GB - confirm against the Terraform
            recipe.
    """

    project_id: str
    region: str = "europe-west3"
    cloudsql_name: str = "zenmlserver"
    db_name: str = "zenmlserver"
    db_instance_tier: str = "db-n1-standard-1"
    db_disk_size: int = 10
terraform_provider

Zen Server terraform deployer implementation.

TerraformServerProvider (BaseServerProvider)

Terraform ZenML server provider.

Source code in zenml/zen_server/deploy/terraform/providers/terraform_provider.py
class TerraformServerProvider(BaseServerProvider):
    """Terraform ZenML server provider.

    Manages a single Terraform-backed ZenML server deployment: at most one
    `TerraformZenServer` service is looked up via
    `TerraformZenServer.get_service()`.
    """

    CONFIG_TYPE: ClassVar[Type[ServerDeploymentConfig]] = (
        TerraformServerDeploymentConfig
    )

    @staticmethod
    def _get_server_recipe_root_path() -> str:
        """Get the server recipe root path.

        The Terraform recipe files for all terraform server providers are
        located in a folder relative to the `zenml.zen_server.deploy.terraform`
        Python module.

        Returns:
            The server recipe root path.
        """
        import zenml.zen_server.deploy.terraform as terraform_module

        root_path = os.path.join(
            os.path.dirname(terraform_module.__file__),
            TERRAFORM_ZENML_SERVER_RECIPE_SUBPATH,
        )
        return root_path

    @classmethod
    def _get_service_configuration(
        cls,
        server_config: ServerDeploymentConfig,
    ) -> Tuple[
        ServiceConfig,
        ServiceEndpointConfig,
        ServiceEndpointHealthMonitorConfig,
    ]:
        """Construct the service configuration from a server deployment configuration.

        The Terraform recipe directory is selected by the provider name
        (`server_config.provider`) below the recipe root path.

        Args:
            server_config: server deployment configuration.

        Returns:
            The service configuration, together with an HTTP endpoint
            configuration and an HTTP health monitor configuration.
        """
        assert isinstance(server_config, TerraformServerDeploymentConfig)

        return (
            TerraformZenServerConfig(
                name=server_config.name,
                root_runtime_path=TERRAFORM_ZENML_SERVER_CONFIG_PATH,
                singleton=True,
                directory_path=os.path.join(
                    cls._get_server_recipe_root_path(),
                    server_config.provider,
                ),
                log_level=server_config.log_level,
                variables_file_path=TERRAFORM_VALUES_FILE_PATH,
                server=server_config,
            ),
            ServiceEndpointConfig(
                protocol=ServiceEndpointProtocol.HTTP,
                allocate_port=False,
            ),
            HTTPEndpointHealthMonitorConfig(
                healthcheck_uri_path=ZEN_SERVER_HEALTHCHECK_URL_PATH,
                use_head_request=True,
            ),
        )

    def _create_service(
        self,
        config: ServerDeploymentConfig,
        timeout: Optional[int] = None,
    ) -> BaseService:
        """Create, start and return the terraform ZenML server deployment service.

        Args:
            config: The server deployment configuration.
            timeout: The timeout in seconds to wait until the service is
                running.

        Returns:
            The service instance.

        Raises:
            RuntimeError: If a terraform service is already running.
        """
        assert isinstance(config, TerraformServerDeploymentConfig)

        if timeout is None:
            timeout = TERRAFORM_ZENML_SERVER_DEFAULT_TIMEOUT

        # NOTE(review): this rejects any existing service, whatever its
        # state, even though the message says "running".
        existing_service = TerraformZenServer.get_service()
        if existing_service:
            raise RuntimeError(
                f"A terraform ZenML server with name '{existing_service.config.name}' "
                f"is already running. Please stop it first before starting a "
                f"new one."
            )

        # Only the service configuration is needed; the endpoint and
        # monitor configurations are not used by the terraform service.
        service_config, _, _ = self._get_service_configuration(config)

        service = TerraformZenServer(uuid=uuid4(), config=service_config)

        service.start(timeout=timeout)
        return service

    def _update_service(
        self,
        service: BaseService,
        config: ServerDeploymentConfig,
        timeout: Optional[int] = None,
    ) -> BaseService:
        """Update the terraform ZenML server deployment service.

        Args:
            service: The service instance.
            config: The new server deployment configuration.
            timeout: The timeout in seconds to wait until the updated service is
                running.

        Returns:
            The updated service instance.
        """
        if timeout is None:
            timeout = TERRAFORM_ZENML_SERVER_DEFAULT_TIMEOUT

        new_config, _, _ = self._get_service_configuration(config)

        assert isinstance(new_config, TerraformZenServerConfig)
        assert isinstance(service, TerraformZenServer)

        # Reuse the existing service instance, which preserves its UUID
        # across updates, and only swap in the new configuration.
        service.config = new_config
        service.start(timeout=timeout)

        return service

    def _start_service(
        self,
        service: BaseService,
        timeout: Optional[int] = None,
    ) -> BaseService:
        """Start the terraform ZenML server deployment service.

        Args:
            service: The service instance.
            timeout: The timeout in seconds to wait until the service is
                running.

        Returns:
            The updated service instance.
        """
        if timeout is None:
            timeout = TERRAFORM_ZENML_SERVER_DEFAULT_TIMEOUT

        service.start(timeout=timeout)
        return service

    def _stop_service(
        self,
        service: BaseService,
        timeout: Optional[int] = None,
    ) -> BaseService:
        """Stop the terraform ZenML server deployment service.

        Args:
            service: The service instance.
            timeout: The timeout in seconds to wait until the service is
                stopped.

        Returns:
            The updated service instance.
        """
        if timeout is None:
            timeout = TERRAFORM_ZENML_SERVER_DEFAULT_TIMEOUT

        service.stop(timeout=timeout)
        return service

    def _delete_service(
        self,
        service: BaseService,
        timeout: Optional[int] = None,
    ) -> None:
        """Remove the terraform ZenML server deployment service.

        Removal is done by stopping the service; no local files are
        deleted here.

        Args:
            service: The service instance.
            timeout: The timeout in seconds to wait until the service is
                removed.
        """
        assert isinstance(service, TerraformZenServer)

        if timeout is None:
            timeout = TERRAFORM_ZENML_SERVER_DEFAULT_TIMEOUT

        service.stop(timeout)

    def _get_service(self, server_name: str) -> BaseService:
        """Get the terraform ZenML server deployment service.

        Args:
            server_name: The server deployment name.

        Returns:
            The service instance.

        Raises:
            KeyError: If the server deployment is not found.
        """
        service = TerraformZenServer.get_service()
        if service is None:
            raise KeyError("The terraform ZenML server is not deployed.")

        if service.config.server.name != server_name:
            raise KeyError(
                "The terraform ZenML server is deployed but with a different name."
            )
        return service

    def _list_services(self) -> List[BaseService]:
        """Get all service instances for all deployed ZenML servers.

        Returns:
            A list of service instances.
        """
        service = TerraformZenServer.get_service()
        if service:
            return [service]
        return []

    def _get_deployment_config(
        self, service: BaseService
    ) -> ServerDeploymentConfig:
        """Recreate the server deployment configuration from a service instance.

        Args:
            service: The service instance.

        Returns:
            The server deployment configuration.
        """
        server = cast(TerraformZenServer, service)
        return server.config.server

    def _get_deployment_status(
        self, service: BaseService
    ) -> ServerDeploymentStatus:
        """Get the status of a server deployment from its service.

        The URL and CA certificate are only queried while the service is
        running. The deployment counts as connected when the client's store
        URL equals the server URL.

        Args:
            service: The server deployment service.

        Returns:
            The status of the server deployment.
        """
        gc = GlobalConfiguration()
        tf_service = cast(TerraformZenServer, service)
        url: Optional[str] = None
        ca_crt = None
        if tf_service.is_running:
            url = tf_service.get_server_url()
            ca_crt = tf_service.get_certificate()
        connected = url is not None and gc.store_configuration.url == url

        return ServerDeploymentStatus(
            url=url,
            status=tf_service.status.state,
            status_message=tf_service.status.last_error,
            connected=connected,
            ca_crt=ca_crt,
        )
CONFIG_TYPE (ServerDeploymentConfig)

Terraform server deployment configuration.

Attributes:

Name Type Description
log_level str

The log level to set the terraform client to. Choose one of TRACE, DEBUG, INFO, WARN or ERROR (case insensitive).

helm_chart str

The path to the ZenML server helm chart to use for deployment.

zenmlserver_image_repo str

The repository to use for the zenml server.

zenmlserver_image_tag str

The tag to use for the zenml server docker image.

namespace str

The Kubernetes namespace to deploy the ZenML server to.

kubectl_config_path str

The path to the kubectl config file to use for deployment.

ingress_tls bool

Whether to use TLS for the ingress.

ingress_tls_generate_certs bool

Whether to generate self-signed TLS certificates for the ingress.

ingress_tls_secret_name str

The name of the Kubernetes secret to use for the ingress.

create_ingress_controller bool

Whether to deploy an nginx ingress controller as part of the deployment.

ingress_controller_ip str

The ingress controller IP to use for the ingress self-signed certificate and to compute the ZenML server URL.

deploy_db bool

Whether to create a SQL database service as part of the recipe.

database_username str

The username for the database.

database_password str

The password for the database.

database_url str

The URL of the RDS instance to use for the ZenML server.

database_ssl_ca str

The path to the SSL CA certificate to use for the database connection.

database_ssl_cert str

The path to the client SSL certificate to use for the database connection.

database_ssl_key str

The path to the client SSL key to use for the database connection.

database_ssl_verify_server_cert bool

Whether to verify the database server SSL certificate.

analytics_opt_in bool

Whether to enable analytics.

Source code in zenml/zen_server/deploy/terraform/providers/terraform_provider.py
class TerraformServerDeploymentConfig(ServerDeploymentConfig):
    """Configuration for deploying a ZenML server with Terraform.

    Attributes:
        log_level: Log level for the terraform client; one of TRACE, DEBUG,
            INFO, WARN or ERROR (case insensitive).
        helm_chart: Path to the ZenML server helm chart used for the
            deployment.
        zenmlserver_image_repo: Repository used for the zenml server image.
        zenmlserver_image_tag: Tag of the zenml server docker image.
        namespace: Kubernetes namespace the ZenML server is deployed to.
        kubectl_config_path: Path to the kubectl config file used for the
            deployment.
        ingress_tls: Whether the ingress uses TLS.
        ingress_tls_generate_certs: Whether self-signed TLS certificates are
            generated for the ingress.
        ingress_tls_secret_name: Name of the Kubernetes secret used for the
            ingress.
        create_ingress_controller: Whether an nginx ingress controller is
            deployed as part of the deployment.
        ingress_controller_ip: Ingress controller IP used for the ingress
            self-signed certificate and to compute the ZenML server URL.
        deploy_db: Whether a SQL database service is created as part of the
            recipe.
        database_username: Username for the database.
        database_password: Password for the database.
        database_url: URL of the RDS instance used by the ZenML server.
        database_ssl_ca: Path to the SSL CA certificate for the database
            connection.
        database_ssl_cert: Path to the client SSL certificate for the database
            connection.
        database_ssl_key: Path to the client SSL key for the database
            connection.
        database_ssl_verify_server_cert: Whether the database server SSL
            certificate is verified.
        analytics_opt_in: Whether analytics are enabled.
    """

    # Terraform client
    log_level: str = "ERROR"

    # Helm chart, server image and cluster access
    helm_chart: str = get_helm_chart_path()
    zenmlserver_image_repo: str = "zenmldocker/zenml-server"
    zenmlserver_image_tag: str = "latest"
    namespace: str = "zenmlserver"
    kubectl_config_path: str = str(Path.home() / ".kube" / "config")

    # Ingress
    ingress_tls: bool = False
    ingress_tls_generate_certs: bool = True
    ingress_tls_secret_name: str = "zenml-tls-certs"
    create_ingress_controller: bool = True
    ingress_controller_ip: str = ""

    # Database
    deploy_db: bool = True
    database_username: str = "user"
    database_password: str = ""
    database_url: str = ""
    database_ssl_ca: str = ""
    database_ssl_cert: str = ""
    database_ssl_key: str = ""
    database_ssl_verify_server_cert: bool = True

    # Analytics
    analytics_opt_in: bool = True

    model_config = ConfigDict(extra="allow")
terraform_zen_server

Service implementation for the ZenML terraform server deployment.

TerraformServerDeploymentConfig (ServerDeploymentConfig)

Terraform server deployment configuration.

Attributes:

Name Type Description
log_level str

The log level to set the terraform client to. Choose one of TRACE, DEBUG, INFO, WARN or ERROR (case insensitive).

helm_chart str

The path to the ZenML server helm chart to use for deployment.

zenmlserver_image_repo str

The repository to use for the zenml server.

zenmlserver_image_tag str

The tag to use for the zenml server docker image.

namespace str

The Kubernetes namespace to deploy the ZenML server to.

kubectl_config_path str

The path to the kubectl config file to use for deployment.

ingress_tls bool

Whether to use TLS for the ingress.

ingress_tls_generate_certs bool

Whether to generate self-signed TLS certificates for the ingress.

ingress_tls_secret_name str

The name of the Kubernetes secret to use for the ingress.

create_ingress_controller bool

Whether to deploy an nginx ingress controller as part of the deployment.

ingress_controller_ip str

The ingress controller IP to use for the ingress self-signed certificate and to compute the ZenML server URL.

deploy_db bool

Whether to create a SQL database service as part of the recipe.

database_username str

The username for the database.

database_password str

The password for the database.

database_url str

The URL of the RDS instance to use for the ZenML server.

database_ssl_ca str

The path to the SSL CA certificate to use for the database connection.

database_ssl_cert str

The path to the client SSL certificate to use for the database connection.

database_ssl_key str

The path to the client SSL key to use for the database connection.

database_ssl_verify_server_cert bool

Whether to verify the database server SSL certificate.

analytics_opt_in bool

Whether to enable analytics.

Source code in zenml/zen_server/deploy/terraform/terraform_zen_server.py
class TerraformServerDeploymentConfig(ServerDeploymentConfig):
    """Configuration for deploying a ZenML server with Terraform.

    Attributes:
        log_level: Log level for the terraform client; one of TRACE, DEBUG,
            INFO, WARN or ERROR (case insensitive).
        helm_chart: Path to the ZenML server helm chart used for the
            deployment.
        zenmlserver_image_repo: Repository used for the zenml server image.
        zenmlserver_image_tag: Tag of the zenml server docker image.
        namespace: Kubernetes namespace the ZenML server is deployed to.
        kubectl_config_path: Path to the kubectl config file used for the
            deployment.
        ingress_tls: Whether the ingress uses TLS.
        ingress_tls_generate_certs: Whether self-signed TLS certificates are
            generated for the ingress.
        ingress_tls_secret_name: Name of the Kubernetes secret used for the
            ingress.
        create_ingress_controller: Whether an nginx ingress controller is
            deployed as part of the deployment.
        ingress_controller_ip: Ingress controller IP used for the ingress
            self-signed certificate and to compute the ZenML server URL.
        deploy_db: Whether a SQL database service is created as part of the
            recipe.
        database_username: Username for the database.
        database_password: Password for the database.
        database_url: URL of the RDS instance used by the ZenML server.
        database_ssl_ca: Path to the SSL CA certificate for the database
            connection.
        database_ssl_cert: Path to the client SSL certificate for the database
            connection.
        database_ssl_key: Path to the client SSL key for the database
            connection.
        database_ssl_verify_server_cert: Whether the database server SSL
            certificate is verified.
        analytics_opt_in: Whether analytics are enabled.
    """

    # Terraform client
    log_level: str = "ERROR"

    # Helm chart, server image and cluster access
    helm_chart: str = get_helm_chart_path()
    zenmlserver_image_repo: str = "zenmldocker/zenml-server"
    zenmlserver_image_tag: str = "latest"
    namespace: str = "zenmlserver"
    kubectl_config_path: str = str(Path.home() / ".kube" / "config")

    # Ingress
    ingress_tls: bool = False
    ingress_tls_generate_certs: bool = True
    ingress_tls_secret_name: str = "zenml-tls-certs"
    create_ingress_controller: bool = True
    ingress_controller_ip: str = ""

    # Database
    deploy_db: bool = True
    database_username: str = "user"
    database_password: str = ""
    database_url: str = ""
    database_ssl_ca: str = ""
    database_ssl_cert: str = ""
    database_ssl_key: str = ""
    database_ssl_verify_server_cert: bool = True

    # Analytics
    analytics_opt_in: bool = True

    model_config = ConfigDict(extra="allow")
TerraformZenServer (TerraformService)

Service that can be used to start a terraform ZenServer.

Attributes:

Name Type Description
config TerraformZenServerConfig

service configuration

endpoint Optional[zenml.services.service_endpoint.BaseServiceEndpoint]

service endpoint

Source code in zenml/zen_server/deploy/terraform/terraform_zen_server.py
class TerraformZenServer(TerraformService):
    """Terraform-managed service that deploys a ZenML server.

    Attributes:
        config: service configuration
        endpoint: service endpoint
    """

    SERVICE_TYPE = ServiceType(
        name="terraform_zenml_server",
        type="zen_server",
        flavor="terraform",
        description="Terraform ZenML server deployment",
    )

    config: TerraformZenServerConfig

    @classmethod
    def get_service(cls) -> Optional["TerraformZenServer"]:
        """Load the terraform ZenML server service from its config file.

        Returns:
            The deserialized terraform ZenML server service, or None if the
            service config file does not exist (no terraform server
            deployment is found).
        """
        try:
            with open(
                TERRAFORM_ZENML_SERVER_CONFIG_FILENAME, "r"
            ) as config_file:
                serialized_service = config_file.read()
                loaded_service = TerraformZenServer.from_json(
                    serialized_service
                )
                return cast(TerraformZenServer, loaded_service)
        except FileNotFoundError:
            return None

    def get_vars(self) -> Dict[str, Any]:
        """Collect the Terraform variables from the server deployment config.

        UUID values are converted to strings. The resulting variables are
        also written as JSON to the variables file in the service runtime
        path.

        Returns:
            A dictionary of variables to use for the Terraform deployment.
        """
        # keys that are not modeled as terraform deployment vars
        excluded_keys = ("log_level", "provider")
        server_settings = self.config.server.model_dump()

        tf_vars: Dict[str, Any] = {}
        for key, value in server_settings.items():
            if key in excluded_keys:
                continue
            tf_vars[key] = str(value) if isinstance(value, UUID) else value
        assert self.status.runtime_path

        vars_file_path = os.path.join(
            self.status.runtime_path, self.config.variables_file_path
        )
        with open(vars_file_path, "w") as vars_file:
            json.dump(tf_vars, vars_file, indent=4)

        return tf_vars

    def provision(self) -> None:
        """Provision the service and log the URL of the deployed server."""
        super().provision()
        server_url = self.get_server_url()
        logger.info(
            f"Your ZenML server is now deployed with URL:\n{server_url}"
        )

    def get_server_url(self) -> str:
        """Read the deployed ZenML server's URL from the Terraform outputs.

        Returns:
            The URL of the deployed ZenML server.
        """
        raw_url = self.terraform_client.output(
            TERRAFORM_DEPLOYED_ZENSERVER_OUTPUT_URL, full_value=True
        )
        return str(raw_url)

    def get_certificate(self) -> Optional[str]:
        """Read the server's CA certificate from the Terraform outputs.

        Returns:
            The CA certificate configured for the ZenML server.
        """
        ca_certificate = self.terraform_client.output(
            TERRAFORM_DEPLOYED_ZENSERVER_OUTPUT_CA_CRT, full_value=True
        )
        return cast(str, ca_certificate)
get_certificate(self)

Returns the CA certificate configured for the ZenML server.

Returns:

Type Description
Optional[str]

The CA certificate configured for the ZenML server.

Source code in zenml/zen_server/deploy/terraform/terraform_zen_server.py
def get_certificate(self) -> Optional[str]:
    """Read the server's CA certificate from the Terraform outputs.

    Returns:
        The CA certificate configured for the ZenML server.
    """
    ca_certificate = self.terraform_client.output(
        TERRAFORM_DEPLOYED_ZENSERVER_OUTPUT_CA_CRT, full_value=True
    )
    return cast(str, ca_certificate)
get_server_url(self)

Returns the deployed ZenML server's URL.

Returns:

Type Description
str

The URL of the deployed ZenML server.

Source code in zenml/zen_server/deploy/terraform/terraform_zen_server.py
def get_server_url(self) -> str:
    """Read the deployed ZenML server's URL from the Terraform outputs.

    Returns:
        The URL of the deployed ZenML server.
    """
    raw_url = self.terraform_client.output(
        TERRAFORM_DEPLOYED_ZENSERVER_OUTPUT_URL, full_value=True
    )
    return str(raw_url)
get_service() classmethod

Load and return the terraform ZenML server service, if present.

Returns:

Type Description
Optional[TerraformZenServer]

The terraform ZenML server service or None, if the terraform server deployment is not found.

Source code in zenml/zen_server/deploy/terraform/terraform_zen_server.py
@classmethod
def get_service(cls) -> Optional["TerraformZenServer"]:
    """Load the terraform ZenML server service from its config file.

    Returns:
        The deserialized terraform ZenML server service, or None if the
        service config file does not exist (no terraform server deployment
        is found).
    """
    try:
        with open(
            TERRAFORM_ZENML_SERVER_CONFIG_FILENAME, "r"
        ) as config_file:
            serialized_service = config_file.read()
            loaded_service = TerraformZenServer.from_json(serialized_service)
            return cast(TerraformZenServer, loaded_service)
    except FileNotFoundError:
        return None
get_vars(self)

Get variables as a dictionary.

Returns:

Type Description
Dict[str, Any]

A dictionary of variables to use for the Terraform deployment.

Source code in zenml/zen_server/deploy/terraform/terraform_zen_server.py
def get_vars(self) -> Dict[str, Any]:
    """Collect the Terraform variables from the server deployment config.

    UUID values are converted to strings. The resulting variables are also
    written as JSON to the variables file in the service runtime path.

    Returns:
        A dictionary of variables to use for the Terraform deployment.
    """
    # keys that are not modeled as terraform deployment vars
    excluded_keys = ("log_level", "provider")
    server_settings = self.config.server.model_dump()

    tf_vars: Dict[str, Any] = {}
    for key, value in server_settings.items():
        if key in excluded_keys:
            continue
        tf_vars[key] = str(value) if isinstance(value, UUID) else value
    assert self.status.runtime_path

    vars_file_path = os.path.join(
        self.status.runtime_path, self.config.variables_file_path
    )
    with open(vars_file_path, "w") as vars_file:
        json.dump(tf_vars, vars_file, indent=4)

    return tf_vars
model_post_init(self, _ModelMetaclass__context)

We need to both initialize private attributes and call the user-defined model_post_init method.

Source code in zenml/zen_server/deploy/terraform/terraform_zen_server.py
def wrapped_model_post_init(self: BaseModel, __context: Any) -> None:
    """Initialize private attributes, then run the original post-init hook.

    We need to both initialize private attributes and call the user-defined
    model_post_init method.

    Args:
        self: The model instance being initialized.
        __context: Context value passed through unchanged to both calls.
    """
    # NOTE(review): `init_private_attributes` and `original_model_post_init`
    # are not defined in this snippet; they come from an enclosing scope that
    # is not shown here. Private attributes are initialized before the
    # original hook is called — presumably so the hook can rely on them;
    # confirm before reordering.
    init_private_attributes(self, __context)
    original_model_post_init(self, __context)
provision(self)

Provision the service.

Source code in zenml/zen_server/deploy/terraform/terraform_zen_server.py
def provision(self) -> None:
    """Provision the service and log the URL of the deployed server."""
    super().provision()
    server_url = self.get_server_url()
    logger.info(f"Your ZenML server is now deployed with URL:\n{server_url}")
TerraformZenServerConfig (TerraformServiceConfig)

Terraform Zen server configuration.

Attributes:

Name Type Description
server TerraformServerDeploymentConfig

The deployment configuration.

Source code in zenml/zen_server/deploy/terraform/terraform_zen_server.py
class TerraformZenServerConfig(TerraformServiceConfig):
    """Service configuration of the Terraform ZenML server.

    Attributes:
        server: The deployment configuration.
        copy_terraform_files: Defaults to True. NOTE(review): presumably
            controls whether the Terraform recipe files are copied for the
            service; confirm against `TerraformServiceConfig`.
    """

    server: TerraformServerDeploymentConfig
    copy_terraform_files: bool = True
get_helm_chart_path()

Get the ZenML server helm chart path.

The ZenML server helm chart files are located in a folder relative to the zenml.zen_server.deploy Python module.

Returns:

Type Description
str

The helm chart path.

Source code in zenml/zen_server/deploy/terraform/terraform_zen_server.py
def get_helm_chart_path() -> str:
    """Return the path of the ZenML server helm chart.

    The chart lives at `ZENML_HELM_CHART_SUBPATH` relative to the directory
    of the `zenml.zen_server.deploy` Python module.

    Returns:
        The helm chart path.
    """
    import zenml.zen_server.deploy as deploy_module

    deploy_dir = os.path.dirname(deploy_module.__file__)
    return os.path.join(deploy_dir, ZENML_HELM_CHART_SUBPATH)

exceptions

REST API exception handling.

ErrorModel (BaseModel)

Base class for error responses.

Source code in zenml/zen_server/exceptions.py
class ErrorModel(BaseModel):
    """Base model for REST API error response bodies.

    Attributes:
        detail: Optional error detail payload; defaults to None.
    """

    detail: Optional[Any] = None

error_detail(error, exception_type=None)

Convert an Exception to API representation.

Parameters:

Name Type Description Default
error Exception

Exception to convert.

required
exception_type Optional[Type[Exception]]

Exception type to use in the error response instead of the type of the supplied exception. This is useful when the raised exception is a subclass of an exception type that is properly handled by the REST API.

None

Returns:

Type Description
List[str]

List of strings representing the error.

Source code in zenml/zen_server/exceptions.py
def error_detail(
    error: Exception, exception_type: Optional[Type[Exception]] = None
) -> List[str]:
    """Encode an exception as the `[class_name, message]` API detail list.

    Args:
        error: Exception to convert.
        exception_type: Exception type to use in the error response instead of
            the type of the supplied exception. This is useful when the raised
            exception is a subclass of an exception type that is properly
            handled by the REST API.

    Returns:
        List of strings representing the error.
    """
    reported_type = exception_type or type(error)
    return [reported_type.__name__, str(error)]

exception_from_response(response)

Convert an error HTTP response to an exception.

Uses the REST_API_EXCEPTIONS list to determine the appropriate exception class to use based on the response status code and the exception class name embedded in the response body.

The last entry in the list of exceptions associated with a status code is used as a fallback if the exception class name in the response body is not found in the list.

Parameters:

Name Type Description Default
response Response

HTTP error response to convert.

required

Returns:

Type Description
Optional[Exception]

Exception with the appropriate type and arguments, or None if the response does not contain an error or the response cannot be unpacked into an exception.

Source code in zenml/zen_server/exceptions.py
def exception_from_response(
    response: requests.Response,
) -> Optional[Exception]:
    """Convert an error HTTP response to an exception.

    Uses the REST_API_EXCEPTIONS list to determine the appropriate exception
    class to use based on the response status code and the exception class name
    embedded in the response body.

    The last entry in the list of exceptions associated with a status code is
    used as a fallback if the exception class name in the response body is not
    found in the list.

    Args:
        response: HTTP error response to convert.

    Returns:
        Exception with the appropriate type and arguments, or None if the
        response does not contain an error or the response cannot be unpacked
        into an exception.
    """

    def unpack_exc() -> Tuple[Optional[str], str]:
        """Split the response body into an exception class name and message.

        Returns:
            Tuple of exception name (None when it cannot be determined) and
            message.
        """
        try:
            body = response.json()
        except requests.exceptions.JSONDecodeError:
            return None, response.text

        detail = (
            body.get("detail", response.text)
            if isinstance(body, dict)
            else body
        )

        # A plain string detail carries only a message.
        if isinstance(detail, str):
            return None, detail

        # Otherwise the detail must be a non-empty list whose first item is
        # the exception class name and whose remaining items are its
        # arguments.
        is_encoded_exception = (
            isinstance(detail, list)
            and len(detail) >= 1
            and isinstance(detail[0], str)
        )
        if not is_encoded_exception:
            return None, response.text

        class_name, *exc_args = detail
        return class_name, ": ".join(str(arg) for arg in exc_args)

    exc_name, exc_msg = unpack_exc()

    # Exception types registered for this status code, in list order.
    candidates = [
        registered_type
        for registered_type, code in REST_API_EXCEPTIONS
        if code == response.status_code
    ]
    if not candidates:
        return None

    # Prefer the first exact class-name match; otherwise fall back to the
    # last exception type registered for this status code.
    exc_type = next(
        (
            candidate
            for candidate in candidates
            if candidate.__name__ == exc_name
        ),
        candidates[-1],
    )
    return exc_type(exc_msg)

http_exception_from_error(error)

Convert an Exception to a HTTP error response.

Uses the REST_API_EXCEPTIONS list to determine the appropriate status code associated with the exception type. The exception class name and arguments are embedded in the HTTP error response body.

The lookup uses the first occurrence of the exception type in the list. If the exception type is not found in the list, the lookup uses isinstance to determine the most specific exception type corresponding to the supplied exception. This allows users to call this method with exception types that are not directly listed in the REST_API_EXCEPTIONS list.

Parameters:

Name Type Description Default
error Exception

Exception to convert.

required

Returns:

Type Description
HTTPException

HTTPException with the appropriate status code and error detail.

Source code in zenml/zen_server/exceptions.py
def http_exception_from_error(error: Exception) -> "HTTPException":
    """Convert an Exception to a HTTP error response.

    Uses the REST_API_EXCEPTIONS list to determine the appropriate status code
    associated with the exception type. The exception class name and arguments
    are embedded in the HTTP error response body.

    The lookup uses the first occurrence of the exception type in the list. If
    the exception type is not found in the list, the lookup uses `isinstance`
    to determine the most specific exception type corresponding to the supplied
    exception. This allows users to call this method with exception types that
    are not directly listed in the REST_API_EXCEPTIONS list.

    Args:
        error: Exception to convert.

    Returns:
        HTTPException with the appropriate status code and error detail.
        Falls back to status 500 and `RuntimeError` when no listed exception
        type matches.
    """
    from fastapi import HTTPException

    status_code = 0
    matching_exception_type: Optional[Type[Exception]] = None

    for exception_type, exc_status_code in REST_API_EXCEPTIONS:
        if error.__class__ is exception_type:
            # Found an exact match; the first occurrence wins.
            matching_exception_type = exception_type
            status_code = exc_status_code
            break
        if not isinstance(error, exception_type):
            continue
        if matching_exception_type is None:
            # This is the first matching exception, so keep it
            matching_exception_type = exception_type
            status_code = exc_status_code
            continue

        # Replace the previous match only with a *strictly* more specific
        # exception type. A repeated entry for the same type (e.g. a type
        # listed under two status codes) must not override its first
        # occurrence. Otherwise subclasses of that type would get a different
        # status code than the type itself gets via the exact-match branch.
        if exception_type is not matching_exception_type and issubclass(
            exception_type,
            matching_exception_type,
        ):
            matching_exception_type = exception_type
            status_code = exc_status_code

    # When the matching exception is not found in the list, a 500 Internal
    # Server Error is returned
    status_code = status_code or 500
    matching_exception_type = matching_exception_type or RuntimeError

    return HTTPException(
        status_code=status_code,
        detail=error_detail(error, matching_exception_type),
    )

feature_gate special

endpoint_utils

All endpoint utils for the feature gate implementations.

check_entitlement(resource_type)

Queries the feature gate to see if the operation falls within the tenant's entitlements.

Raises an exception if the user is not entitled to create an instance of the resource. Otherwise, simply returns.

Parameters:

Name Type Description Default
resource_type ResourceType

The type of resource to check for.

required
Source code in zenml/zen_server/feature_gate/endpoint_utils.py
def check_entitlement(resource_type: ResourceType) -> None:
    """Queries the feature gate to see if the operation falls within the tenant's entitlements.

    Does nothing when the feature gate is disabled in the server config.
    Otherwise the check is delegated to the feature gate, which raises an
    exception if the user is not entitled to create an instance of the
    resource.

    Args:
        resource_type: The type of resource to check for.
    """
    if server_config().feature_gate_enabled:
        return feature_gate().check_entitlement(resource=resource_type)
    return None
report_decrement(resource_type, resource_id)

Reports the deletion/deactivation of a feature/resource.

Parameters:

Name Type Description Default
resource_type ResourceType

The type of resource to report a decrement in count for.

required
resource_id UUID

ID of the resource that was deleted.

required
Source code in zenml/zen_server/feature_gate/endpoint_utils.py
def report_decrement(resource_type: ResourceType, resource_id: UUID) -> None:
    """Report the deletion/deactivation of a feature/resource.

    Nothing is reported when the feature gate is disabled in the server
    config.

    Args:
        resource_type: The type of resource to report a decrement in count for.
        resource_id: ID of the resource that was deleted.
    """
    if server_config().feature_gate_enabled:
        feature_gate().report_event(
            resource=resource_type,
            resource_id=resource_id,
            is_decrement=True,
        )
report_usage(resource_type, resource_id)

Reports the creation/usage of a feature/resource.

Parameters:

Name Type Description Default
resource_type ResourceType

The type of resource to report a usage for

required
resource_id UUID

ID of the resource that was created.

required
Source code in zenml/zen_server/feature_gate/endpoint_utils.py
def report_usage(resource_type: ResourceType, resource_id: UUID) -> None:
    """Report the creation/usage of a feature/resource.

    Nothing is reported when the feature gate is disabled in the server
    config.

    Args:
        resource_type: The type of resource to report a usage for
        resource_id: ID of the resource that was created.
    """
    if server_config().feature_gate_enabled:
        feature_gate().report_event(
            resource=resource_type,
            resource_id=resource_id,
        )

feature_gate_interface

Definition of the feature gate interface.

FeatureGateInterface (ABC)

Feature gate interface definition.

Source code in zenml/zen_server/feature_gate/feature_gate_interface.py
class FeatureGateInterface(ABC):
    """Feature gate interface definition.

    Implementations check whether a resource may be created and report
    resource creation/deletion events to a usage backend.
    """

    @abstractmethod
    def check_entitlement(self, resource: ResourceType) -> None:
        """Checks if a user is entitled to create a resource.

        Implementations are expected to raise a
        `SubscriptionUpgradeRequiredError` when a subscription limit is
        reached.

        Args:
            resource: The resource the user wants to create
        """

    @abstractmethod
    def report_event(
        self,
        resource: ResourceType,
        resource_id: UUID,
        is_decrement: bool = False,
    ) -> None:
        """Reports the usage of a feature to the aggregator backend.

        Args:
            resource: The resource the user created
            resource_id: ID of the resource that was created/deleted.
            is_decrement: In case this event reports an actual decrement of usage
        """
check_entitlement(self, resource)

Checks if a user is entitled to create a resource.

Parameters:

Name Type Description Default
resource ResourceType

The resource the user wants to create

required
Source code in zenml/zen_server/feature_gate/feature_gate_interface.py
@abstractmethod
def check_entitlement(self, resource: ResourceType) -> None:
    """Checks if a user is entitled to create a resource.

    Implementations are expected to raise a
    `SubscriptionUpgradeRequiredError` when a subscription limit is reached.

    Args:
        resource: The resource the user wants to create
    """
report_event(self, resource, resource_id, is_decrement=False)

Reports the usage of a feature to the aggregator backend.

Parameters:

Name Type Description Default
resource ResourceType

The resource the user created

required
resource_id UUID

ID of the resource that was created/deleted.

required
is_decrement bool

In case this event reports an actual decrement of usage

False
Source code in zenml/zen_server/feature_gate/feature_gate_interface.py
@abstractmethod
def report_event(
    self,
    resource: ResourceType,
    resource_id: UUID,
    is_decrement: bool = False,
) -> None:
    """Send a feature usage event to the aggregator backend.

    Args:
        resource: The resource the user created
        resource_id: ID of the resource that was created/deleted.
        is_decrement: Whether this event reports a decrement of usage
            instead of an increment.
    """

zenml_cloud_feature_gate

ZenML Pro implementation of the feature gate.

RawUsageEvent (BaseModel)

Model for reporting raw usage of a feature.

For consumables, the UsageReport allows the Pricing Backend to increment the usage per time frame by 1.

Source code in zenml/zen_server/feature_gate/zenml_cloud_feature_gate.py
class RawUsageEvent(BaseModel):
    """Payload used to report the raw usage of a feature.

    For consumables, the UsageReport lets the Pricing Backend increment the
    usage per time-frame by 1.
    """

    organization_id: str = Field(
        description="The organization that this usage can be attributed to.",
    )
    feature: ResourceType = Field(
        description="The feature whose usage is being reported.",
    )
    total: int = Field(
        description="The total amount of entities of this type.",
    )
    metadata: Dict[str, Any] = Field(
        description="Allows attaching additional metadata to events.",
        default={},
    )
ZenMLCloudFeatureGateInterface (FeatureGateInterface)

ZenML Cloud Feature Gate implementation.

Source code in zenml/zen_server/feature_gate/zenml_cloud_feature_gate.py
class ZenMLCloudFeatureGateInterface(FeatureGateInterface):
    """ZenML Cloud Feature Gate implementation.

    Entitlement checks and usage reports are forwarded to the ZenML Pro
    backend through a single cloud connection created at construction time.
    """

    def __init__(self) -> None:
        """Initialize the object."""
        self._connection = cloud_connection()

    def check_entitlement(self, resource: ResourceType) -> None:
        """Checks if a user is entitled to create a resource.

        Any non-200 answer other than a subscription-limit error is only
        logged as a warning; creation is not blocked in that case.

        Args:
            resource: The resource the user wants to create

        Raises:
            SubscriptionUpgradeRequiredError: in case a subscription limit is reached
        """
        try:
            response = self._connection.get(
                endpoint=ENTITLEMENT_ENDPOINT + "/" + resource, params=None
            )
        except SubscriptionUpgradeRequiredError as e:
            # Re-raise with a message naming the exhausted resource while
            # keeping the original error as the cause.
            raise SubscriptionUpgradeRequiredError(
                f"Your subscription reached its `{resource}` limit. Please "
                "upgrade your subscription or reach out to us."
            ) from e

        if response.status_code != 200:
            try:
                message = response.json()
            except ValueError:
                # A non-JSON body must not turn this warning into a crash;
                # log the raw text instead. Assumes a requests-style response
                # whose `.json()` raises a ValueError subclass — TODO confirm.
                message = response.text
            logger.warning(
                "Unexpected response status code from entitlement "
                f"endpoint: {response.status_code}. Message: "
                f"{message}"
            )

    def report_event(
        self,
        resource: ResourceType,
        resource_id: UUID,
        is_decrement: bool = False,
    ) -> None:
        """Reports the usage of a feature to the aggregator backend.

        A rejected report is logged as an error and otherwise ignored.

        Args:
            resource: The resource the user created
            resource_id: ID of the resource that was created/deleted.
            is_decrement: In case this event reports an actual decrement of usage
        """
        data = RawUsageEvent(
            organization_id=ORGANIZATION_ID,
            feature=resource,
            # Each event carries a delta of exactly one entity.
            total=-1 if is_decrement else 1,
            metadata={
                "tenant_id": str(server_config.external_server_id),
                "resource_id": str(resource_id),
            },
        ).model_dump()
        response = self._connection.post(
            endpoint=USAGE_EVENT_ENDPOINT, data=data
        )
        if response.status_code != 200:
            try:
                message = response.json()
            except ValueError:
                # Keep reporting best-effort even for non-JSON error bodies.
                message = response.text
            logger.error(
                "Usage report not accepted by upstream backend. "
                f"Status Code: {response.status_code}, Message: "
                f"{message}."
            )
__init__(self) special

Initialize the object.

Source code in zenml/zen_server/feature_gate/zenml_cloud_feature_gate.py
def __init__(self) -> None:
    """Set up the feature gate with a ZenML Pro cloud connection.

    The connection is obtained once here and then reused for every
    entitlement check and usage report made through this instance.
    """
    connection = cloud_connection()
    self._connection = connection
check_entitlement(self, resource)

Checks if a user is entitled to create a resource.

Parameters:

Name Type Description Default
resource ResourceType

The resource the user wants to create

required

Exceptions:

Type Description
SubscriptionUpgradeRequiredError

in case a subscription limit is reached

Source code in zenml/zen_server/feature_gate/zenml_cloud_feature_gate.py
def check_entitlement(self, resource: ResourceType) -> None:
    """Checks if a user is entitled to create a resource.

    Any non-200 answer other than a subscription-limit error is only logged
    as a warning; creation is not blocked in that case.

    Args:
        resource: The resource the user wants to create

    Raises:
        SubscriptionUpgradeRequiredError: in case a subscription limit is reached
    """
    try:
        response = self._connection.get(
            endpoint=ENTITLEMENT_ENDPOINT + "/" + resource, params=None
        )
    except SubscriptionUpgradeRequiredError as e:
        # Re-raise with a message naming the exhausted resource while keeping
        # the original error as the cause.
        raise SubscriptionUpgradeRequiredError(
            f"Your subscription reached its `{resource}` limit. Please "
            "upgrade your subscription or reach out to us."
        ) from e

    if response.status_code != 200:
        try:
            message = response.json()
        except ValueError:
            # A non-JSON body must not turn this warning into a crash; log
            # the raw text instead. Assumes a requests-style response whose
            # `.json()` raises a ValueError subclass — TODO confirm.
            message = response.text
        logger.warning(
            "Unexpected response status code from entitlement "
            f"endpoint: {response.status_code}. Message: "
            f"{message}"
        )
report_event(self, resource, resource_id, is_decrement=False)

Reports the usage of a feature to the aggregator backend.

Parameters:

Name Type Description Default
resource ResourceType

The resource the user created

required
resource_id UUID

ID of the resource that was created/deleted.

required
is_decrement bool

In case this event reports an actual decrement of usage

False
Source code in zenml/zen_server/feature_gate/zenml_cloud_feature_gate.py
def report_event(
    self,
    resource: ResourceType,
    resource_id: UUID,
    is_decrement: bool = False,
) -> None:
    """Reports the usage of a feature to the aggregator backend.

    A rejected report is logged as an error and otherwise ignored.

    Args:
        resource: The resource the user created
        resource_id: ID of the resource that was created/deleted.
        is_decrement: In case this event reports an actual decrement of usage
    """
    data = RawUsageEvent(
        organization_id=ORGANIZATION_ID,
        feature=resource,
        # Each event carries a delta of exactly one entity.
        total=-1 if is_decrement else 1,
        metadata={
            "tenant_id": str(server_config.external_server_id),
            "resource_id": str(resource_id),
        },
    ).model_dump()
    response = self._connection.post(
        endpoint=USAGE_EVENT_ENDPOINT, data=data
    )
    if response.status_code != 200:
        try:
            message = response.json()
        except ValueError:
            # Keep reporting best-effort even for non-JSON error bodies.
            message = response.text
        logger.error(
            "Usage report not accepted by upstream backend. "
            f"Status Code: {response.status_code}, Message: "
            f"{message}."
        )

jwt

Authentication module for ZenML server.

JWTToken (BaseModel)

Pydantic object representing a JWT token.

Attributes:

Name Type Description
user_id UUID

The id of the authenticated User.

device_id Optional[uuid.UUID]

The id of the authenticated device.

api_key_id Optional[uuid.UUID]

The id of the authenticated API key for which this token was issued.

pipeline_id Optional[uuid.UUID]

The id of the pipeline for which the token was issued.

schedule_id Optional[uuid.UUID]

The id of the schedule for which the token was issued.

claims Dict[str, Any]

The original token claims.

Source code in zenml/zen_server/jwt.py
class JWTToken(BaseModel):
    """Pydantic object representing a JWT token.

    Attributes:
        user_id: The id of the authenticated User.
        device_id: The id of the authenticated device.
        api_key_id: The id of the authenticated API key for which this token
            was issued.
        pipeline_id: The id of the pipeline for which the token was issued.
        schedule_id: The id of the schedule for which the token was issued.
        claims: The original token claims.
    """

    user_id: UUID
    device_id: Optional[UUID] = None
    api_key_id: Optional[UUID] = None
    pipeline_id: Optional[UUID] = None
    schedule_id: Optional[UUID] = None
    claims: Dict[str, Any] = {}

    @classmethod
    def decode_token(
        cls,
        token: str,
        verify: bool = True,
    ) -> "JWTToken":
        """Decodes a JWT access token.

        Decodes a JWT access token and returns a `JWTToken` object with the
        information retrieved from its subject claim. The ID claims are
        removed from the returned `claims`; all other claims are kept.

        Args:
            token: The encoded JWT token.
            verify: Whether to verify the signature of the token. Per the
                PyJWT 2.x docs, disabling signature verification also
                disables the exp/nbf/iat/aud/iss checks by default.

        Returns:
            The decoded JWT access token.

        Raises:
            AuthorizationException: If the token is invalid.
        """
        config = server_config()

        try:
            claims_data = jwt.decode(
                token,
                config.jwt_secret_key,
                algorithms=[config.jwt_token_algorithm],
                audience=config.get_jwt_token_audience(),
                issuer=config.get_jwt_token_issuer(),
                # PyJWT >= 2.0 ignores a `verify` keyword argument, so the
                # signature check has to be toggled through `options`.
                options={"verify_signature": verify},
                leeway=timedelta(seconds=config.jwt_token_leeway_seconds),
            )
            claims = cast(Dict[str, Any], claims_data)
        except jwt.PyJWTError as e:
            raise AuthorizationException(f"Invalid JWT token: {e}") from e

        def _to_uuid(value: Any, claim_label: str) -> UUID:
            """Convert a claim value to a UUID or reject the token."""
            try:
                return UUID(value)
            except (AttributeError, TypeError, ValueError) as e:
                # Non-string values raise AttributeError/TypeError rather
                # than ValueError; all of them mean a malformed token.
                raise AuthorizationException(
                    f"Invalid JWT token: {claim_label} is not a valid UUID"
                ) from e

        def _pop_optional_uuid(claim_name: str) -> Optional[UUID]:
            """Pop an optional UUID claim, returning None when absent."""
            if claim_name not in claims:
                return None
            return _to_uuid(claims.pop(claim_name), f"the {claim_name} claim")

        subject: str = claims.pop("sub", "")
        if not subject:
            raise AuthorizationException(
                "Invalid JWT token: the subject claim is missing"
            )
        user_id = _to_uuid(subject, "the subject claim")

        # Popping each ID claim leaves only the remaining claims in `claims`.
        device_id = _pop_optional_uuid("device_id")
        api_key_id = _pop_optional_uuid("api_key_id")
        pipeline_id = _pop_optional_uuid("pipeline_id")
        schedule_id = _pop_optional_uuid("schedule_id")

        return cls(
            user_id=user_id,
            device_id=device_id,
            api_key_id=api_key_id,
            pipeline_id=pipeline_id,
            schedule_id=schedule_id,
            claims=claims,
        )

    def encode(self, expires: Optional[datetime] = None) -> str:
        """Creates a JWT access token.

        Encodes, signs and returns a JWT access token.

        Args:
            expires: Datetime after which the token will expire. If not
                provided, the JWT token will not be set to expire.

        Returns:
            The generated access token.
        """
        config = server_config()

        # Start from the extra claims kept from decoding, then overwrite the
        # registered claims with values derived from this object/config.
        claims: Dict[str, Any] = self.claims.copy()

        claims["sub"] = str(self.user_id)
        claims["iss"] = config.get_jwt_token_issuer()
        claims["aud"] = config.get_jwt_token_audience()

        if expires:
            claims["exp"] = expires
        else:
            # Drop any expiration inherited from the copied claims.
            claims.pop("exp", None)

        if self.device_id:
            claims["device_id"] = str(self.device_id)
        if self.api_key_id:
            claims["api_key_id"] = str(self.api_key_id)
        if self.pipeline_id:
            claims["pipeline_id"] = str(self.pipeline_id)
        if self.schedule_id:
            claims["schedule_id"] = str(self.schedule_id)

        return jwt.encode(
            claims,
            config.jwt_secret_key,
            algorithm=config.jwt_token_algorithm,
        )
decode_token(token, verify=True) classmethod

Decodes a JWT access token.

Decodes a JWT access token and returns a JWTToken object with the information retrieved from its subject claim.

Parameters:

Name Type Description Default
token str

The encoded JWT token.

required
verify bool

Whether to verify the signature of the token.

True

Returns:

Type Description
JWTToken

The decoded JWT access token.

Exceptions:

Type Description
AuthorizationException

If the token is invalid.

Source code in zenml/zen_server/jwt.py
@classmethod
def decode_token(
    cls,
    token: str,
    verify: bool = True,
) -> "JWTToken":
    """Decodes a JWT access token.

    Decodes a JWT access token and returns a `JWTToken` object with the
    information retrieved from its subject claim. The ID claims are removed
    from the returned `claims`; all other claims are kept.

    Args:
        token: The encoded JWT token.
        verify: Whether to verify the signature of the token. Per the PyJWT
            2.x docs, disabling signature verification also disables the
            exp/nbf/iat/aud/iss checks by default.

    Returns:
        The decoded JWT access token.

    Raises:
        AuthorizationException: If the token is invalid.
    """
    config = server_config()

    try:
        claims_data = jwt.decode(
            token,
            config.jwt_secret_key,
            algorithms=[config.jwt_token_algorithm],
            audience=config.get_jwt_token_audience(),
            issuer=config.get_jwt_token_issuer(),
            # PyJWT >= 2.0 ignores a `verify` keyword argument, so the
            # signature check has to be toggled through `options`.
            options={"verify_signature": verify},
            leeway=timedelta(seconds=config.jwt_token_leeway_seconds),
        )
        claims = cast(Dict[str, Any], claims_data)
    except jwt.PyJWTError as e:
        raise AuthorizationException(f"Invalid JWT token: {e}") from e

    def _to_uuid(value: Any, claim_label: str) -> UUID:
        """Convert a claim value to a UUID or reject the token."""
        try:
            return UUID(value)
        except (AttributeError, TypeError, ValueError) as e:
            # Non-string values raise AttributeError/TypeError rather than
            # ValueError; all of them mean a malformed token.
            raise AuthorizationException(
                f"Invalid JWT token: {claim_label} is not a valid UUID"
            ) from e

    def _pop_optional_uuid(claim_name: str) -> Optional[UUID]:
        """Pop an optional UUID claim, returning None when absent."""
        if claim_name not in claims:
            return None
        return _to_uuid(claims.pop(claim_name), f"the {claim_name} claim")

    subject: str = claims.pop("sub", "")
    if not subject:
        raise AuthorizationException(
            "Invalid JWT token: the subject claim is missing"
        )
    user_id = _to_uuid(subject, "the subject claim")

    # Popping each ID claim leaves only the remaining claims in `claims`.
    device_id = _pop_optional_uuid("device_id")
    api_key_id = _pop_optional_uuid("api_key_id")
    pipeline_id = _pop_optional_uuid("pipeline_id")
    schedule_id = _pop_optional_uuid("schedule_id")

    return cls(
        user_id=user_id,
        device_id=device_id,
        api_key_id=api_key_id,
        pipeline_id=pipeline_id,
        schedule_id=schedule_id,
        claims=claims,
    )
encode(self, expires=None)

Creates a JWT access token.

Encodes, signs and returns a JWT access token.

Parameters:

Name Type Description Default
expires Optional[datetime.datetime]

Datetime after which the token will expire. If not provided, the JWT token will not be set to expire.

None

Returns:

Type Description
str

The generated access token.

Source code in zenml/zen_server/jwt.py
def encode(self, expires: Optional[datetime] = None) -> str:
    """Build, sign and return a JWT access token for this object.

    Extra claims kept in ``self.claims`` are carried over; the subject,
    issuer and audience are always (re)set, and each optional ID claim is
    only emitted when the corresponding attribute is set.

    Args:
        expires: Datetime after which the token will expire. If not
            provided, the JWT token will not be set to expire.

    Returns:
        The generated access token.
    """
    config = server_config()

    payload: Dict[str, Any] = dict(self.claims)
    payload.update(
        sub=str(self.user_id),
        iss=config.get_jwt_token_issuer(),
        aud=config.get_jwt_token_audience(),
    )

    if not expires:
        # Never let an expiration inherited from the copied claims leak in.
        payload.pop("exp", None)
    else:
        payload["exp"] = expires

    optional_ids = {
        "device_id": self.device_id,
        "api_key_id": self.api_key_id,
        "pipeline_id": self.pipeline_id,
        "schedule_id": self.schedule_id,
    }
    for claim_name, claim_value in optional_ids.items():
        if claim_value:
            payload[claim_name] = str(claim_value)

    return jwt.encode(
        payload,
        config.jwt_secret_key,
        algorithm=config.jwt_token_algorithm,
    )

rate_limit

Rate limiting for the ZenML Server.

RequestLimiter

Simple in-memory rate limiter.

Source code in zenml/zen_server/rate_limit.py
class RequestLimiter:
    """Simple in-memory rate limiter.

    Hits are tracked per requester (the value returned by `_get_ipaddr`) as
    an ascending list of `time.monotonic()` timestamps.
    """

    def __init__(
        self,
        day_limit: Optional[int] = None,
        minute_limit: Optional[int] = None,
    ):
        """Initializes the limiter.

        Args:
            day_limit: The number of requests allowed per day.
            minute_limit: The number of requests allowed per minute.

        Raises:
            ValueError: If both day_limit and minute_limit are None.
        """
        self.limiting_enabled = server_config().rate_limit_enabled
        if not self.limiting_enabled:
            # The remaining attributes are never read while limiting is off.
            return
        if day_limit is None and minute_limit is None:
            raise ValueError("Pass either day or minute limits, or both.")
        self.day_limit = day_limit
        self.minute_limit = minute_limit
        # requester -> ascending hit timestamps
        self.limiter: Dict[str, List[float]] = defaultdict(list)

    def hit_limiter(self, request: Request) -> None:
        """Increase the number of hits in the limiter.

        The current request is recorded before the limits are checked, so a
        rejected request also counts towards later limits.

        Args:
            request: Request object.

        Raises:
            HTTPException: If the request limit is exceeded.
        """
        if not self.limiting_enabled:
            return
        from fastapi import HTTPException

        from bisect import bisect_left

        requester = self._get_ipaddr(request)
        # A monotonic clock keeps each list sorted even if the wall clock is
        # adjusted; both bisect calls below rely on that ordering.
        now = time.monotonic()
        minute_ago = now - 60
        day_ago = now - 60 * 60 * 24

        hits = self.limiter[requester]
        hits.append(now)

        # Drop hits older than a day.
        hits = hits[bisect_left(hits, day_ago) :]
        self.limiter[requester] = hits

        if self.day_limit and len(hits) > self.day_limit:
            raise HTTPException(
                status_code=429, detail="Daily request limit exceeded."
            )
        # Every hit from the first one within the last minute onwards is
        # recent, so the count is a single O(log n) lookup.
        minute_requests = len(hits) - bisect_left(hits, minute_ago)
        if self.minute_limit and minute_requests > self.minute_limit:
            raise HTTPException(
                status_code=429, detail="Minute request limit exceeded."
            )

    def reset_limiter(self, request: Request) -> None:
        """Resets the limiter on successful request.

        Args:
            request: Request object.
        """
        if self.limiting_enabled:
            # `pop` with a default neither raises for unknown requesters nor
            # creates an empty defaultdict entry.
            self.limiter.pop(self._get_ipaddr(request), None)

    def _get_ipaddr(self, request: Request) -> str:
        """Returns the IP address for the current request.

        Based on the X-Forwarded-For headers or client information.

        Args:
            request: The request object.

        Returns:
            The ip address for the current request (or 127.0.0.1 if none found).
        """
        # NOTE(review): proxies normally send `X-Forwarded-For` (hyphens);
        # confirm this underscore spelling matches the intended header. The
        # value is client-supplied and may be a comma-separated list.
        if "X_FORWARDED_FOR" in request.headers:
            return request.headers["X_FORWARDED_FOR"]
        if not request.client or not request.client.host:
            return "127.0.0.1"
        return request.client.host

    @contextmanager
    def limit_failed_requests(
        self, request: Request
    ) -> Generator[None, Any, Any]:
        """Limits the number of failed requests.

        Args:
            request: Request object.

        Yields:
            None
        """
        self.hit_limiter(request)

        yield

        # Only reached when the wrapped block did not raise: the request was
        # successful, so the requester's history is cleared.
        self.reset_limiter(request)
__init__(self, day_limit=None, minute_limit=None) special

Initializes the limiter.

Parameters:

Name Type Description Default
day_limit Optional[int]

The number of requests allowed per day.

None
minute_limit Optional[int]

The number of requests allowed per minute.

None

Exceptions:

Type Description
ValueError

If both day_limit and minute_limit are None.

Source code in zenml/zen_server/rate_limit.py
def __init__(
    self,
    day_limit: Optional[int] = None,
    minute_limit: Optional[int] = None,
):
    """Initializes the limiter.

    Args:
        day_limit: The number of requests allowed per day.
        minute_limit: The number of requests allowed per minute.

    Raises:
        ValueError: If both day_limit and minute_limit are None.
    """
    self.limiting_enabled = server_config().rate_limit_enabled
    if not self.limiting_enabled:
        # The remaining attributes are never read while limiting is off.
        return
    if day_limit is None and minute_limit is None:
        raise ValueError("Pass either day or minute limits, or both.")
    self.day_limit = day_limit
    self.minute_limit = minute_limit
    # requester -> ascending hit timestamps
    self.limiter: Dict[str, List[float]] = defaultdict(list)
hit_limiter(self, request)

Increase the number of hits in the limiter.

Parameters:

Name Type Description Default
request Request

Request object.

required

Exceptions:

Type Description
HTTPException

If the request limit is exceeded.

Source code in zenml/zen_server/rate_limit.py
def hit_limiter(self, request: Request) -> None:
    """Increase the number of hits in the limiter.

    The current request is recorded before the limits are checked, so a
    rejected request also counts towards later limits.

    Args:
        request: Request object.

    Raises:
        HTTPException: If the request limit is exceeded.
    """
    if not self.limiting_enabled:
        return
    from fastapi import HTTPException

    from bisect import bisect_left

    requester = self._get_ipaddr(request)
    # A monotonic clock keeps each list sorted even if the wall clock is
    # adjusted; both bisect calls below rely on that ordering.
    now = time.monotonic()
    minute_ago = now - 60
    day_ago = now - 60 * 60 * 24

    hits = self.limiter[requester]
    hits.append(now)

    # Drop hits older than a day.
    hits = hits[bisect_left(hits, day_ago) :]
    self.limiter[requester] = hits

    if self.day_limit and len(hits) > self.day_limit:
        raise HTTPException(
            status_code=429, detail="Daily request limit exceeded."
        )
    # Every hit from the first one within the last minute onwards is recent,
    # so the count is a single O(log n) lookup.
    minute_requests = len(hits) - bisect_left(hits, minute_ago)
    if self.minute_limit and minute_requests > self.minute_limit:
        raise HTTPException(
            status_code=429, detail="Minute request limit exceeded."
        )
limit_failed_requests(self, request)

Limits the number of failed requests.

Parameters:

Name Type Description Default
request Request

Request object.

required

Yields:

Type Description
Generator[NoneType, Any, Any]

None

Source code in zenml/zen_server/rate_limit.py
@contextmanager
def limit_failed_requests(
    self, request: Request
) -> Generator[None, Any, Any]:
    """Count a request against the limiter and clear it on success.

    The hit is registered (and the request possibly rejected) before the
    body of the ``with`` block runs. The requester's history is wiped only
    if that body finishes without raising, so consecutive failures add up.

    Args:
        request: Request object.

    Yields:
        None
    """
    self.hit_limiter(request)
    yield
    # Not reached when the wrapped block raised.
    self.reset_limiter(request)
reset_limiter(self, request)

Resets the limiter on successful request.

Parameters:

Name Type Description Default
request Request

Request object.

required
Source code in zenml/zen_server/rate_limit.py
def reset_limiter(self, request: Request) -> None:
    """Forget every recorded hit of the requester behind ``request``.

    Does nothing when rate limiting is disabled.

    Args:
        request: Request object.
    """
    if not self.limiting_enabled:
        return
    # `pop` with a default neither raises for unknown requesters nor creates
    # an empty defaultdict entry.
    self.limiter.pop(self._get_ipaddr(request), None)

rate_limit_requests(day_limit=None, minute_limit=None)

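Decorator to rate-limit requests to an API endpoint; each successful request resets the requester's count, so in effect repeated failed requests are limited.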

Parameters:

Name Type Description Default
day_limit Optional[int]

Number of requests allowed per day.

None
minute_limit Optional[int]

Number of requests allowed per minute.

None

Returns:

Type Description
Callable[..., Any]

Decorated function.

Source code in zenml/zen_server/rate_limit.py
def rate_limit_requests(
    day_limit: Optional[int] = None,
    minute_limit: Optional[int] = None,
) -> Callable[..., Any]:
    """Decorator to rate-limit requests to an API endpoint.

    Every call of the decorated endpoint counts as a hit for the requester.
    A call that completes without raising clears that requester's hits, so
    in effect repeated failed requests are limited. Both regular and
    `async def` endpoints are supported.

    Args:
        day_limit: Number of requests allowed per day.
        minute_limit: Number of requests allowed per minute.

    Returns:
        Decorated function.
    """
    limiter = RequestLimiter(day_limit=day_limit, minute_limit=minute_limit)

    def decorator(func: F) -> F:
        request_arg, request_kwarg = None, None
        parameters = inspect.signature(func).parameters
        for arg_num, arg_name in enumerate(parameters):
            if parameters[arg_name].annotation == Request:
                request_arg = arg_num
                request_kwarg = arg_name
                break
        if request_arg is None or request_kwarg is None:
            raise ValueError(
                "Rate limiting APIs must have argument of `Request` type."
            )
        # Non-optional copies so the closures below keep the narrowed types.
        request_position: int = request_arg
        request_name: str = request_kwarg

        def _get_request(args: Any, kwargs: Dict[str, Any]) -> Any:
            """Pick the `Request` out of the call's arguments."""
            if request_name in kwargs:
                return kwargs[request_name]
            return args[request_position]

        if inspect.iscoroutinefunction(func):

            @wraps(func)
            async def async_decorated(*args: Any, **kwargs: Any) -> Any:
                # The limiter context must stay open while the coroutine
                # runs; a sync wrapper would exit (and reset the hits)
                # before the endpoint had a chance to fail.
                with limiter.limit_failed_requests(
                    _get_request(args, kwargs)
                ):
                    return await func(*args, **kwargs)

            return cast(F, async_decorated)

        @wraps(func)
        def decorated(
            *args: Any,
            **kwargs: Any,
        ) -> Any:
            with limiter.limit_failed_requests(_get_request(args, kwargs)):
                return func(*args, **kwargs)

        return cast(F, decorated)

    return decorator

rbac special

RBAC definitions.

endpoint_utils

High-level helper functions to write endpoints with RBAC.

verify_permissions_and_create_entity(request_model, resource_type, create_method)

Verify permissions and create the entity if authorized.

Parameters:

Name Type Description Default
request_model ~AnyRequest

The entity request model.

required
resource_type ResourceType

The resource type of the entity to create.

required
create_method Callable[[~AnyRequest], ~AnyResponse]

The method to create the entity.

required

Exceptions:

Type Description
IllegalOperationError

If the request model has a different owner than the currently authenticated user.

Returns:

Type Description
~AnyResponse

A model of the created entity.

Source code in zenml/zen_server/rbac/endpoint_utils.py
def verify_permissions_and_create_entity(
    request_model: AnyRequest,
    resource_type: ResourceType,
    create_method: Callable[[AnyRequest], AnyResponse],
) -> AnyResponse:
    """Create an entity once the caller has been authorized to do so.

    User-scoped requests must belong to the authenticated user, and the
    caller needs the CREATE permission on ``resource_type``. For reportable
    resources without custom reporting, the entitlement is checked before
    the entity is created and the usage is reported afterwards.

    Args:
        request_model: The entity request model.
        resource_type: The resource type of the entity to create.
        create_method: The method to create the entity.

    Raises:
        IllegalOperationError: If the request model has a different owner than
            the currently authenticated user.

    Returns:
        A model of the created entity.
    """
    if isinstance(request_model, UserScopedRequest):
        auth_context = get_auth_context()
        assert auth_context

        # Users may only create user-scoped resources on their own behalf.
        if request_model.user != auth_context.user.id:
            raise IllegalOperationError(
                f"Not allowed to create resource '{resource_type}' for a "
                "different user."
            )

    verify_permission(resource_type=resource_type, action=Action.CREATE)

    # Resources requiring custom reporting presumably track their usage in
    # their own code path, so they are skipped here.
    tracked = (
        resource_type in REPORTABLE_RESOURCES
        and resource_type not in REQUIRES_CUSTOM_RESOURCE_REPORTING
    )
    if not tracked:
        return create_method(request_model)

    check_entitlement(resource_type)
    created = create_method(request_model)
    report_usage(resource_type, resource_id=created.id)
    return created
verify_permissions_and_delete_entity(id, get_method, delete_method)

Verify permissions and delete an entity.

Parameters:

Name Type Description Default
id ~UUIDOrStr

The ID of the entity to delete.

required
get_method Callable[[~UUIDOrStr], ~AnyResponse]

The method to fetch the entity.

required
delete_method Callable[[~UUIDOrStr], NoneType]

The method to delete the entity.

required

Returns:

Type Description
~AnyResponse

The deleted entity.

Source code in zenml/zen_server/rbac/endpoint_utils.py
def verify_permissions_and_delete_entity(
    id: UUIDOrStr,
    get_method: Callable[[UUIDOrStr], AnyResponse],
    delete_method: Callable[[UUIDOrStr], None],
) -> AnyResponse:
    """Delete an entity after checking the caller's DELETE permission.

    Args:
        id: The ID of the entity to delete.
        get_method: The method to fetch the entity.
        delete_method: The method to delete the entity.

    Returns:
        The deleted entity.
    """
    entity = get_method(id)
    verify_permission_for_model(entity, action=Action.DELETE)
    # Delete through the fetched model's own ID rather than the
    # caller-supplied identifier, which may be a string.
    delete_method(entity.id)
    return entity
verify_permissions_and_get_entity(id, get_method, **get_method_kwargs)

Verify permissions and fetch an entity.

Parameters:

Name Type Description Default
id ~UUIDOrStr

The ID of the entity to fetch.

required
get_method Callable[[~UUIDOrStr], ~AnyResponse]

The method to fetch the entity.

required
get_method_kwargs Any

Keyword arguments to pass to the get method.

{}

Returns:

Type Description
~AnyResponse

A model of the fetched entity.

Source code in zenml/zen_server/rbac/endpoint_utils.py
def verify_permissions_and_get_entity(
    id: UUIDOrStr,
    get_method: Callable[[UUIDOrStr], AnyResponse],
    **get_method_kwargs: Any,
) -> AnyResponse:
    """Fetch an entity and return it only if the caller may READ it.

    Args:
        id: The ID of the entity to fetch.
        get_method: The method to fetch the entity.
        get_method_kwargs: Keyword arguments to pass to the get method.

    Returns:
        A model of the fetched entity.
    """
    entity = get_method(id, **get_method_kwargs)
    verify_permission_for_model(entity, action=Action.READ)
    # NOTE(review): dehydration presumably strips nested data the caller is
    # not permitted to see — confirm against `dehydrate_response_model`.
    return dehydrate_response_model(entity)
verify_permissions_and_list_entities(filter_model, resource_type, list_method, **list_method_kwargs)

Verify permissions and list entities.

Parameters:

Name Type Description Default
filter_model ~AnyFilter

The entity filter model.

required
resource_type ResourceType

The resource type of the entities to list.

required
list_method Callable[[~AnyFilter], zenml.models.v2.base.page.Page[~AnyResponse]]

The method to list the entities.

required
list_method_kwargs Any

Keyword arguments to pass to the list method.

{}

Returns:

Type Description
Page[~AnyResponse]

A page of entity models.

Source code in zenml/zen_server/rbac/endpoint_utils.py
def verify_permissions_and_list_entities(
    filter_model: AnyFilter,
    resource_type: ResourceType,
    list_method: Callable[[AnyFilter], Page[AnyResponse]],
    **list_method_kwargs: Any,
) -> Page[AnyResponse]:
    """List entities, scoped to the resource IDs the caller is allowed to see.

    Args:
        filter_model: The entity filter model.
        resource_type: The resource type of the entities to list.
        list_method: The method to list the entities.
        list_method_kwargs: Keyword arguments to pass to the list method.

    Returns:
        A page of entity models.
    """
    auth_context = get_auth_context()
    assert auth_context

    # Narrow the filter with the allowed IDs before running the query.
    permitted_ids = get_allowed_resource_ids(resource_type=resource_type)
    filter_model.configure_rbac(
        authenticated_user_id=auth_context.user.id, id=permitted_ids
    )
    return dehydrate_page(list_method(filter_model, **list_method_kwargs))
verify_permissions_and_prune_entities(resource_type, prune_method, **kwargs)

Verify permissions and prune entities of certain type.

Parameters:

Name Type Description Default
resource_type ResourceType

The resource type of the entities to prune.

required
prune_method Callable[..., NoneType]

The method to prune the entities.

required
kwargs Any

Keyword arguments to pass to the prune method.

{}
Source code in zenml/zen_server/rbac/endpoint_utils.py
def verify_permissions_and_prune_entities(
    resource_type: ResourceType,
    prune_method: Callable[..., None],
    **kwargs: Any,
) -> None:
    """Prune entities of one type after checking the PRUNE permission.

    The permission check runs first; `prune_method` is only invoked if it
    does not raise.

    Args:
        resource_type: The resource type of the entities to prune.
        prune_method: The method to prune the entities.
        kwargs: Keyword arguments to pass to the prune method.
    """
    verify_permission(action=Action.PRUNE, resource_type=resource_type)
    prune_method(**kwargs)
verify_permissions_and_update_entity(id, update_model, get_method, update_method)

Verify permissions and update an entity.

Parameters:

Name Type Description Default
id ~UUIDOrStr

The ID of the entity to update.

required
update_model ~AnyUpdate

The entity update model.

required
get_method Callable[[~UUIDOrStr], ~AnyResponse]

The method to fetch the entity.

required
update_method Callable[[~UUIDOrStr, ~AnyUpdate], ~AnyResponse]

The method to update the entity.

required

Returns:

Type Description
~AnyResponse

A model of the updated entity.

Source code in zenml/zen_server/rbac/endpoint_utils.py
def verify_permissions_and_update_entity(
    id: UUIDOrStr,
    update_model: AnyUpdate,
    get_method: Callable[[UUIDOrStr], AnyResponse],
    update_method: Callable[[UUIDOrStr, AnyUpdate], AnyResponse],
) -> AnyResponse:
    """Update an entity after checking the UPDATE permission on it.

    The entity is fetched first so that the permission check can be done on
    the concrete model. The update is then applied using the fetched model's
    ID, not the raw `id` argument.

    Args:
        id: The ID of the entity to update.
        update_model: The entity update model.
        get_method: The method to fetch the entity.
        update_method: The method to update the entity.

    Returns:
        A model of the updated entity.
    """
    existing = get_method(id)
    verify_permission_for_model(existing, action=Action.UPDATE)
    return dehydrate_response_model(update_method(existing.id, update_model))

models

RBAC model classes.

Action (StrEnum)

RBAC actions.

Source code in zenml/zen_server/rbac/models.py
class Action(StrEnum):
    """RBAC actions.

    The actions a user may be allowed to perform on a resource. Members are
    passed as `action` to the RBAC component's `check_permissions` and
    `list_allowed_resource_ids` methods.
    """

    # Generic actions
    CREATE = "create"
    READ = "read"
    UPDATE = "update"
    DELETE = "delete"
    READ_SECRET_VALUE = "read_secret_value"
    # Checked by `verify_permissions_and_prune_entities`
    PRUNE = "prune"

    # Service connectors
    CLIENT = "client"

    # Models
    PROMOTE = "promote"

    # Secrets
    BACKUP_RESTORE = "backup_restore"

    # NOTE(review): usage not visible here; presumably sharing a resource with
    # other users -- confirm.
    SHARE = "share"
Resource (BaseModel)

RBAC resource model.

Source code in zenml/zen_server/rbac/models.py
class Resource(BaseModel):
    """RBAC resource model.

    Identifies either a whole resource type (`id` is unset) or a single
    instance of that type. The model is frozen, and instances are used as set
    members and dictionary keys in the RBAC utilities.
    """

    type: str
    id: Optional[UUID] = None

    def __str__(self) -> str:
        """Render the resource as `<type>` or `<type>/<id>`.

        Returns:
            Resource string representation.
        """
        if not self.id:
            return self.type
        return self.type + "/" + str(self.id)

    model_config = ConfigDict(frozen=True)
__str__(self) special

Convert to a string.

Returns:

Type Description
str

Resource string representation.

Source code in zenml/zen_server/rbac/models.py
def __str__(self) -> str:
    """Render the resource as `<type>` or `<type>/<id>`.

    Returns:
        Resource string representation.
    """
    if not self.id:
        return self.type
    return self.type + "/" + str(self.id)
ResourceType (StrEnum)

Resource types of the server API.

Source code in zenml/zen_server/rbac/models.py
class ResourceType(StrEnum):
    """Resource types of the server API.

    The values are used as the `type` of `Resource` objects sent to the RBAC
    component. `get_resource_type_for_model` maps response model classes to
    these members.
    """

    ACTION = "action"
    ARTIFACT = "artifact"
    ARTIFACT_VERSION = "artifact_version"
    CODE_REPOSITORY = "code_repository"
    EVENT_SOURCE = "event_source"
    FLAVOR = "flavor"
    MODEL = "model"
    MODEL_VERSION = "model_version"
    PIPELINE = "pipeline"
    PIPELINE_RUN = "pipeline_run"
    PIPELINE_DEPLOYMENT = "pipeline_deployment"
    PIPELINE_BUILD = "pipeline_build"
    RUN_TEMPLATE = "run_template"
    USER = "user"
    SERVICE = "service"
    RUN_METADATA = "run_metadata"
    SECRET = "secret"
    SERVICE_ACCOUNT = "service_account"
    SERVICE_CONNECTOR = "service_connector"
    STACK = "stack"
    STACK_COMPONENT = "stack_component"
    TAG = "tag"
    TRIGGER = "trigger"
    TRIGGER_EXECUTION = "trigger_execution"
    WORKSPACE = "workspace"

rbac_interface

RBAC interface definition.

RBACInterface (ABC)

RBAC interface definition.

Source code in zenml/zen_server/rbac/rbac_interface.py
class RBACInterface(ABC):
    """RBAC interface definition.

    Abstract base for RBAC components. Concrete implementations decide
    whether a user may perform an `Action` on a `Resource`.
    """

    @abstractmethod
    def check_permissions(
        self, user: "UserResponse", resources: Set[Resource], action: Action
    ) -> Dict[Resource, bool]:
        """Checks if a user has permissions to perform an action on resources.

        Args:
            user: User which wants to access a resource.
            resources: The resources the user wants to access.
            action: The action that the user wants to perform on the resources.

        Returns:
            A dictionary mapping resources to a boolean which indicates whether
            the user has permissions to perform the action on that resource.
            The dictionary should contain an entry for every requested
            resource; batch verification treats a missing entry as an
            unexpected failure.
        """

    @abstractmethod
    def list_allowed_resource_ids(
        self, user: "UserResponse", resource: Resource, action: Action
    ) -> Tuple[bool, List[str]]:
        """Lists all resource IDs of a resource type that a user can access.

        Args:
            user: User which wants to access a resource.
            resource: The resource the user wants to access.
            action: The action that the user wants to perform on the resource.

        Returns:
            A tuple (full_resource_access, resource_ids).
            `full_resource_access` will be `True` if the user can perform the
            given action on any instance of the given resource type, `False`
            otherwise. If `full_resource_access` is `False`, `resource_ids`
            will contain the list of instance IDs (as strings) that the user
            can perform the action on.
        """

    @abstractmethod
    def update_resource_membership(
        self, user: "UserResponse", resource: Resource, actions: List[Action]
    ) -> None:
        """Update the resource membership of a user.

        Args:
            user: User for which the resource membership should be updated.
            resource: The resource.
            actions: The actions that the user should be able to perform on the
                resource.
        """
check_permissions(self, user, resources, action)

Checks if a user has permissions to perform an action on resources.

Parameters:

Name Type Description Default
user UserResponse

User which wants to access a resource.

required
resources Set[zenml.zen_server.rbac.models.Resource]

The resources the user wants to access.

required
action Action

The action that the user wants to perform on the resources.

required

Returns:

Type Description
Dict[zenml.zen_server.rbac.models.Resource, bool]

A dictionary mapping resources to a boolean which indicates whether the user has permissions to perform the action on that resource.

Source code in zenml/zen_server/rbac/rbac_interface.py
@abstractmethod
def check_permissions(
    self, user: "UserResponse", resources: Set[Resource], action: Action
) -> Dict[Resource, bool]:
    """Checks if a user has permissions to perform an action on resources.

    Args:
        user: User which wants to access a resource.
        resources: The resources the user wants to access.
        action: The action that the user wants to perform on the resources.

    Returns:
        A dictionary mapping resources to a boolean which indicates whether
        the user has permissions to perform the action on that resource.
        The dictionary should contain an entry for every requested resource.
    """
list_allowed_resource_ids(self, user, resource, action)

Lists all resource IDs of a resource type that a user can access.

Parameters:

Name Type Description Default
user UserResponse

User which wants to access a resource.

required
resource Resource

The resource the user wants to access.

required
action Action

The action that the user wants to perform on the resource.

required

Returns:

Type Description
Tuple[bool, List[str]]

A tuple (full_resource_access, resource_ids). full_resource_access will be True if the user can perform the given action on any instance of the given resource type, False otherwise. If full_resource_access is False, resource_ids will contain the list of instance IDs that the user can perform the action on.

Source code in zenml/zen_server/rbac/rbac_interface.py
@abstractmethod
def list_allowed_resource_ids(
    self, user: "UserResponse", resource: Resource, action: Action
) -> Tuple[bool, List[str]]:
    """Lists all resource IDs of a resource type that a user can access.

    Args:
        user: User which wants to access a resource.
        resource: The resource the user wants to access.
        action: The action that the user wants to perform on the resource.

    Returns:
        A tuple (full_resource_access, resource_ids).
        `full_resource_access` will be `True` if the user can perform the
        given action on any instance of the given resource type, `False`
        otherwise. If `full_resource_access` is `False`, `resource_ids`
        will contain the list of instance IDs (as strings) that the user
        can perform the action on.
    """
update_resource_membership(self, user, resource, actions)

Update the resource membership of a user.

Parameters:

Name Type Description Default
user UserResponse

User for which the resource membership should be updated.

required
resource Resource

The resource.

required
actions List[zenml.zen_server.rbac.models.Action]

The actions that the user should be able to perform on the resource.

required
Source code in zenml/zen_server/rbac/rbac_interface.py
@abstractmethod
def update_resource_membership(
    self, user: "UserResponse", resource: Resource, actions: List[Action]
) -> None:
    """Update the resource membership of a user.

    Args:
        user: User for which the resource membership should be updated.
        resource: The resource.
        actions: The actions that the user should be able to perform on the
            resource.
    """

utils

RBAC utility functions.

batch_verify_permissions(resources, action)

Batch permission verification.

Parameters:

Name Type Description Default
resources Set[zenml.zen_server.rbac.models.Resource]

The resources the user wants to perform the action on.

required
action Action

The action the user wants to perform.

required

Exceptions:

Type Description
IllegalOperationError

If the user is not allowed to perform the action.

RuntimeError

If the permission verification failed unexpectedly.

Source code in zenml/zen_server/rbac/utils.py
def batch_verify_permissions(
    resources: Set[Resource],
    action: Action,
) -> None:
    """Batch permission verification.

    Does nothing when RBAC is disabled or when `resources` is empty.

    Args:
        resources: The resources the user wants to perform the action on.
        action: The action the user wants to perform.

    Raises:
        IllegalOperationError: If the user is not allowed to perform the action.
        RuntimeError: If the permission verification failed unexpectedly.
    """
    if not server_config().rbac_enabled:
        return

    auth_context = get_auth_context()
    assert auth_context

    # Nothing to verify (e.g. every model was owned by the caller): skip the
    # call to the RBAC component, which may be a remote round-trip.
    if not resources:
        return

    permissions = rbac().check_permissions(
        user=auth_context.user, resources=resources, action=action
    )

    for resource in resources:
        if resource not in permissions:
            # This should never happen if the RBAC implementation is working
            # correctly
            raise RuntimeError(
                f"Failed to verify permissions to {action.upper()} resource "
                f"'{resource}'."
            )

        if not permissions[resource]:
            raise IllegalOperationError(
                message=f"Insufficient permissions to {action.upper()} "
                f"resource '{resource}'.",
            )
batch_verify_permissions_for_models(models, action)

Batch permission verification for models.

Parameters:

Name Type Description Default
models Sequence[~AnyResponse]

The models the user wants to perform the action on.

required
action Action

The action the user wants to perform.

required
Source code in zenml/zen_server/rbac/utils.py
def batch_verify_permissions_for_models(
    models: Sequence[AnyResponse],
    action: Action,
) -> None:
    """Verify, in one batch, that the user may perform `action` on `models`.

    Models owned by the authenticated user are skipped. Every other model is
    replaced by its surrogate permission model, and the resulting resources
    are checked together.

    Args:
        models: The models the user wants to perform the action on.
        action: The action the user wants to perform.
    """
    if not server_config().rbac_enabled:
        return

    to_check = set()
    for candidate in models:
        # Owners always hold permissions on their own models
        if is_owned_by_authenticated_user(candidate):
            continue

        surrogate = get_surrogate_permission_model_for_model(
            candidate, action=action
        )
        resource = get_resource_for_model(surrogate)
        if resource:
            to_check.add(resource)

    batch_verify_permissions(resources=to_check, action=action)
dehydrate_page(page)

Dehydrate all items of a page.

Parameters:

Name Type Description Default
page Page[~AnyResponse]

The page to dehydrate.

required

Returns:

Type Description
Page[~AnyResponse]

The page with (potentially) dehydrated items.

Source code in zenml/zen_server/rbac/utils.py
def dehydrate_page(page: Page[AnyResponse]) -> Page[AnyResponse]:
    """Dehydrate every item of a page using a single permission lookup.

    The sub-resources of all items are collected first, so that READ
    permissions for the whole page are fetched with one `check_permissions`
    call. Each item is then dehydrated with those prefetched permissions.

    Args:
        page: The page to dehydrate.

    Returns:
        The page with (potentially) dehydrated items.
    """
    if not server_config().rbac_enabled:
        return page

    context = get_auth_context()
    assert context

    needed = set()
    for item in page.items:
        needed |= get_subresources_for_model(item)

    permissions = rbac().check_permissions(
        user=context.user, resources=needed, action=Action.READ
    )

    dehydrated_items = [
        dehydrate_response_model(item, permissions=permissions)
        for item in page.items
    ]
    return page.model_copy(update={"items": dehydrated_items})
dehydrate_response_model(model, permissions=None)

Dehydrate a model if necessary.

Parameters:

Name Type Description Default
model ~AnyModel

The model to dehydrate.

required
permissions Optional[Dict[zenml.zen_server.rbac.models.Resource, bool]]

Prefetched permissions that will be used to check whether sub-models will be included in the model or not. If a sub-model refers to a resource which is not included in this dictionary, the permissions will be checked with the RBAC component.

None

Returns:

Type Description
~AnyModel

The (potentially) dehydrated model.

Source code in zenml/zen_server/rbac/utils.py
def dehydrate_response_model(
    model: AnyModel, permissions: Optional[Dict[Resource, bool]] = None
) -> AnyModel:
    """Dehydrate a model if necessary.

    Args:
        model: The model to dehydrate.
        permissions: Prefetched permissions that will be used to check whether
            sub-models will be included in the model or not. If a sub-model
            refers to a resource which is not included in this dictionary, the
            permissions will be checked with the RBAC component. If `None`,
            the READ permissions for the model's sub-resources are fetched
            here.

    Returns:
        The (potentially) dehydrated model.
    """
    if not server_config().rbac_enabled:
        return model

    # Only fetch when no permissions were supplied at all. An explicitly
    # passed mapping -- even an empty one, which `dehydrate_page(...)`
    # produces when no item has sub-resources -- is reused as-is. This avoids
    # one redundant RBAC call per item.
    if permissions is None:
        auth_context = get_auth_context()
        assert auth_context

        resources = get_subresources_for_model(model)
        permissions = rbac().check_permissions(
            user=auth_context.user, resources=resources, action=Action.READ
        )

    # See `get_subresources_for_model(...)` for a detailed explanation why we
    # need to use `model.__iter__()` here
    dehydrated_values = {
        key: _dehydrate_value(value, permissions=permissions)
        for key, value in model.__iter__()
    }

    return type(model).model_validate(dehydrated_values)
get_allowed_resource_ids(resource_type, action=<Action.READ: 'read'>)

Get all resource IDs of a resource type that a user can access.

Parameters:

Name Type Description Default
resource_type str

The resource type.

required
action Action

The action the user wants to perform on the resource.

<Action.READ: 'read'>

Returns:

Type Description
Optional[Set[uuid.UUID]]

A set of resource IDs, or None if RBAC is disabled or the user has full access to all instances of the resource.

Source code in zenml/zen_server/rbac/utils.py
def get_allowed_resource_ids(
    resource_type: str,
    action: Action = Action.READ,
) -> Optional[Set[UUID]]:
    """Get all resource IDs of a resource type that a user can access.

    Args:
        resource_type: The resource type.
        action: The action the user wants to perform on the resource.

    Returns:
        A set of resource IDs, or `None` if RBAC is disabled or the user has
        full access to all instances of the resource type. An empty set means
        the user may not perform the action on any instance.
    """
    if not server_config().rbac_enabled:
        return None

    auth_context = get_auth_context()
    assert auth_context

    has_full_resource_access, allowed_ids = rbac().list_allowed_resource_ids(
        user=auth_context.user,
        resource=Resource(type=resource_type),
        action=action,
    )

    # `None` signals "no restriction on IDs" to the caller.
    if has_full_resource_access:
        return None

    # The RBAC component returns IDs as strings; convert them to UUIDs.
    return {UUID(resource_id) for resource_id in allowed_ids}
get_permission_denied_model(model)

Get a model to return in case of missing read permissions.

Parameters:

Name Type Description Default
model ~AnyResponse

The original model.

required

Returns:

Type Description
~AnyResponse

The permission denied model.

Source code in zenml/zen_server/rbac/utils.py
def get_permission_denied_model(model: AnyResponse) -> AnyResponse:
    """Build a redacted copy of a model for callers lacking read access.

    The copy has its `body`, `metadata` and `resources` cleared and is
    flagged with `permission_denied=True`. The original model is left
    untouched.

    Args:
        model: The original model.

    Returns:
        The permission denied model.
    """
    redacted_fields = {
        "body": None,
        "metadata": None,
        "resources": None,
        "permission_denied": True,
    }
    return model.model_copy(update=redacted_fields)
get_resource_for_model(model)

Get the resource associated with a model object.

Parameters:

Name Type Description Default
model ~AnyResponse

The model for which to get the resource.

required

Returns:

Type Description
Optional[zenml.zen_server.rbac.models.Resource]

The resource associated with the model, or None if the model is not associated with any resource type.

Source code in zenml/zen_server/rbac/utils.py
def get_resource_for_model(model: AnyResponse) -> Optional[Resource]:
    """Map a model object to its RBAC resource.

    Args:
        model: The model for which to get the resource.

    Returns:
        A `Resource` built from the model's resource type and ID, or `None`
        if the model's class is not tied to any RBAC resource type.
    """
    resource_type = get_resource_type_for_model(model)
    return Resource(type=resource_type, id=model.id) if resource_type else None
get_resource_type_for_model(model)

Get the resource type associated with a model object.

Parameters:

Name Type Description Default
model ~AnyResponse

The model for which to get the resource type.

required

Returns:

Type Description
Optional[zenml.zen_server.rbac.models.ResourceType]

The resource type associated with the model, or None if the model is not associated with any resource type.

Source code in zenml/zen_server/rbac/utils.py
def get_resource_type_for_model(
    model: AnyResponse,
) -> Optional[ResourceType]:
    """Get the resource type associated with a model object.

    Args:
        model: The model for which to get the resource type.

    Returns:
        The resource type associated with the model, or `None` if the model
        is not associated with any resource type.
    """
    from zenml.models import (
        ActionResponse,
        ArtifactResponse,
        ArtifactVersionResponse,
        CodeRepositoryResponse,
        ComponentResponse,
        EventSourceResponse,
        FlavorResponse,
        ModelResponse,
        ModelVersionResponse,
        PipelineBuildResponse,
        PipelineDeploymentResponse,
        PipelineResponse,
        PipelineRunResponse,
        RunMetadataResponse,
        RunTemplateResponse,
        SecretResponse,
        ServiceAccountResponse,
        ServiceConnectorResponse,
        ServiceResponse,
        StackResponse,
        TagResponse,
        TriggerExecutionResponse,
        TriggerResponse,
        UserResponse,
        WorkspaceResponse,
    )

    mapping: Dict[
        Any,
        ResourceType,
    ] = {
        ActionResponse: ResourceType.ACTION,
        EventSourceResponse: ResourceType.EVENT_SOURCE,
        FlavorResponse: ResourceType.FLAVOR,
        ServiceConnectorResponse: ResourceType.SERVICE_CONNECTOR,
        ComponentResponse: ResourceType.STACK_COMPONENT,
        StackResponse: ResourceType.STACK,
        PipelineResponse: ResourceType.PIPELINE,
        CodeRepositoryResponse: ResourceType.CODE_REPOSITORY,
        SecretResponse: ResourceType.SECRET,
        ModelResponse: ResourceType.MODEL,
        ModelVersionResponse: ResourceType.MODEL_VERSION,
        ArtifactResponse: ResourceType.ARTIFACT,
        ArtifactVersionResponse: ResourceType.ARTIFACT_VERSION,
        WorkspaceResponse: ResourceType.WORKSPACE,
        UserResponse: ResourceType.USER,
        RunMetadataResponse: ResourceType.RUN_METADATA,
        PipelineDeploymentResponse: ResourceType.PIPELINE_DEPLOYMENT,
        PipelineBuildResponse: ResourceType.PIPELINE_BUILD,
        PipelineRunResponse: ResourceType.PIPELINE_RUN,
        RunTemplateResponse: ResourceType.RUN_TEMPLATE,
        TagResponse: ResourceType.