Zen Server
zenml.zen_server
special
ZenML Server Implementation.
The ZenML Server is a centralized service meant for use in a collaborative setting in which stacks, stack components, flavors, pipelines and pipeline runs can be shared over the network with other users.
You can use the `zenml server up` command to spin up ZenML server instances that are either running locally as daemon processes or Docker containers, or to deploy a ZenML server remotely on a managed cloud platform. The other CLI commands in the same `zenml server` group can be used to manage the server instances deployed from your local machine.
To connect the local ZenML client to one of the managed ZenML servers, call `zenml server connect` with the name of the server you want to connect to.
auth
Authentication module for ZenML server.
AuthContext (BaseModel)
The authentication context.
Source code in zenml/zen_server/auth.py
class AuthContext(BaseModel):
    """The authentication context.

    Describes the account that authenticated a request, together with the
    credential objects (if any) that were used to authenticate it.

    Attributes:
        user: The authenticated account. Service accounts are also
            represented as a `UserResponse` (converted via
            `to_user_model()` during API key authentication).
        access_token: The decoded JWT access token, when the request was
            authenticated with an access token.
        encoded_access_token: The raw access token string that
            `access_token` was decoded from.
        device: The OAuth2 device associated with the credentials, if any.
        api_key: The service account API key associated with the
            credentials, if any.
    """

    # The authenticated account (a regular user or a converted service account).
    user: UserResponse
    # Decoded JWT; only populated for access-token based authentication.
    access_token: Optional[JWTToken] = None
    # The original, still-encoded token string matching `access_token`.
    encoded_access_token: Optional[str] = None
    # Authorized device the credentials are tied to, if any.
    device: Optional[OAuthDeviceInternalResponse] = None
    # Service account API key the credentials are tied to, if any.
    api_key: Optional[APIKeyInternalResponse] = None
CookieOAuth2TokenBearer (OAuth2PasswordBearer)
OAuth2 token bearer authentication scheme that uses a cookie.
Source code in zenml/zen_server/auth.py
class CookieOAuth2TokenBearer(OAuth2PasswordBearer):
    """OAuth2 bearer token scheme that first looks for the token in a cookie.

    The token may arrive either in the server's auth cookie or in the
    standard `Authorization: Bearer` header. When both are present, the
    cookie wins.
    """

    async def __call__(self, request: Request) -> Optional[str]:
        """Extract the bearer token from the request.

        Args:
            request: The incoming request.

        Returns:
            The token stored in the auth cookie if one is present, otherwise
            whatever the parent `OAuth2PasswordBearer` extracts from the
            `Authorization` header.
        """
        cookie_name = server_config().get_auth_cookie_name()
        cookie_token = request.cookies.get(cookie_name)
        if not cookie_token:
            # No (or an empty) cookie: defer to the standard header lookup.
            return await super().__call__(request)
        logger.info("Got token from cookie")
        return cookie_token
__call__(self, request)
async
special
Extract the bearer token from the request.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
request |
Request |
The request. |
required |
Returns:
Type | Description |
---|---|
Optional[str] |
The bearer token extracted from the request cookie or header. |
Source code in zenml/zen_server/auth.py
async def __call__(self, request: Request) -> Optional[str]:
    """Extract the bearer token from the request.

    The server's auth cookie is consulted first; only when it is missing or
    empty does the lookup fall back to the `Authorization` header handled by
    the parent OAuth2 bearer scheme.

    Args:
        request: The incoming request.

    Returns:
        The bearer token extracted from the request cookie or header.
    """
    token_from_cookie = request.cookies.get(
        server_config().get_auth_cookie_name()
    )
    if not token_from_cookie:
        # Fall back to the standard `Authorization: Bearer` header.
        return await super().__call__(request)
    logger.info("Got token from cookie")
    return token_from_cookie
authenticate_api_key(api_key)
Implement service account API key authentication.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
api_key |
str |
The service account API key. |
required |
Returns:
Type | Description |
---|---|
AuthContext |
The authentication context reflecting the authenticated service account. |
Exceptions:
Type | Description |
---|---|
CredentialsNotValid |
If the service account could not be authorized. |
Source code in zenml/zen_server/auth.py
def authenticate_api_key(
    api_key: str,
) -> AuthContext:
    """Implement service account API key authentication.

    The encoded key is decoded into its ID and secret parts. The stored key is
    then fetched and verified against that secret.

    Args:
        api_key: The service account API key.

    Returns:
        The authentication context reflecting the authenticated service account.

    Raises:
        CredentialsNotValid: If the service account could not be authorized.
    """
    try:
        parsed_key = APIKey.decode_api_key(api_key)
    except ValueError:
        message = "Authentication error: error decoding API key"
        logger.exception(message)
        raise CredentialsNotValid(message)

    verified_key = _fetch_and_verify_api_key(
        api_key_id=parsed_key.id, key_to_verify=parsed_key.key
    )

    # Much of the code base still expects the authenticated account to be a
    # UserResponse (a superset of ServiceAccountResponse), so the service
    # account is presented to the rest of the server as a user.
    service_account_as_user = verified_key.service_account.to_user_model()
    return AuthContext(user=service_account_as_user, api_key=verified_key)
authenticate_credentials(user_name_or_id=None, password=None, access_token=None, activation_token=None)
Verify if user authentication credentials are valid.
This function can be used to validate all supplied user credentials to cover a range of possibilities:
- username only - only when the no-auth scheme is used
- username+password - for basic HTTP authentication or the OAuth2 password grant
- access token (with embedded user id) - after successful authentication using one of the supported grants
- username+activation token - for user activation
Parameters:
Name | Type | Description | Default |
---|---|---|---|
user_name_or_id |
Union[str, uuid.UUID] |
The username or user ID. |
None |
password |
Optional[str] |
The password. |
None |
access_token |
Optional[str] |
The access token. |
None |
activation_token |
Optional[str] |
The activation token. |
None |
Returns:
Type | Description |
---|---|
AuthContext |
The authenticated account details. |
Exceptions:
Type | Description |
---|---|
CredentialsNotValid |
If the credentials are invalid. |
Source code in zenml/zen_server/auth.py
def authenticate_credentials(
    user_name_or_id: Optional[Union[str, UUID]] = None,
    password: Optional[str] = None,
    access_token: Optional[str] = None,
    activation_token: Optional[str] = None,
) -> AuthContext:
    """Verify if user authentication credentials are valid.

    This function can be used to validate all supplied user credentials to
    cover a range of possibilities:

     * username only - only when the no-auth scheme is used
     * username+password - for basic HTTP authentication or the OAuth2 password
       grant
     * access token (with embedded user id) - after successful authentication
       using one of the supported grants
     * username+activation token - for user activation

    When several secrets are supplied, only one of them is checked, in this
    order of precedence: password, activation token, access token.

    Args:
        user_name_or_id: The username or user ID.
        password: The password.
        access_token: The access token.
        activation_token: The activation token.

    Returns:
        The authenticated account details.

    Raises:
        CredentialsNotValid: If the credentials are invalid.
    """
    user: Optional[UserAuthModel] = None
    auth_context: Optional[AuthContext] = None
    if user_name_or_id:
        try:
            # NOTE: this method will not return a user if the user name or ID
            # identifies a service account instead of a regular user. This
            # is intentional because service accounts are not allowed to
            # be used to authenticate to the API using a username and password,
            # or an activation token.
            user = zen_store().get_auth_user(user_name_or_id)
            user_model = zen_store().get_user(
                user_name_or_id=user_name_or_id, include_private=True
            )
            auth_context = AuthContext(user=user_model)
        except KeyError:
            # even when the user does not exist, we still want to execute the
            # password/token verification to protect against response discrepancy
            # attacks (https://cwe.mitre.org/data/definitions/204.html)
            logger.exception(
                f"Authentication error: error retrieving account "
                f"{user_name_or_id}"
            )

    if password is not None:
        if not UserAuthModel.verify_password(password, user):
            error = "Authentication error: invalid username or password"
            logger.error(error)
            raise CredentialsNotValid(error)
        if user and not user.active:
            error = f"Authentication error: user {user.name} is not active"
            logger.error(error)
            raise CredentialsNotValid(error)

    elif activation_token is not None:
        if not UserAuthModel.verify_activation_token(activation_token, user):
            error = (
                f"Authentication error: invalid activation token for user "
                f"{user_name_or_id}"
            )
            logger.error(error)
            raise CredentialsNotValid(error)

    elif access_token is not None:
        try:
            decoded_token = JWTToken.decode_token(
                token=access_token,
            )
        except CredentialsNotValid as e:
            error = f"Authentication error: error decoding access token: {e}."
            logger.exception(error)
            # Chain the decoding error so the original cause stays visible.
            raise CredentialsNotValid(error) from e

        try:
            user_model = zen_store().get_user(
                user_name_or_id=decoded_token.user_id, include_private=True
            )
        except KeyError:
            error = (
                f"Authentication error: error retrieving token account "
                f"{decoded_token.user_id}"
            )
            logger.error(error)
            raise CredentialsNotValid(error)

        if not user_model.active:
            error = (
                f"Authentication error: account {user_model.name} is not "
                f"active"
            )
            logger.error(error)
            raise CredentialsNotValid(error)

        api_key_model: Optional[APIKeyInternalResponse] = None
        if decoded_token.api_key_id:
            # The API token was generated from an API key. We still have to
            # verify if the API key hasn't been deactivated or deleted in the
            # meantime.
            api_key_model = _fetch_and_verify_api_key(decoded_token.api_key_id)

        device_model: Optional[OAuthDeviceInternalResponse] = None
        if decoded_token.device_id:
            # Access tokens that have been issued for a device are only valid
            # for that device, so we need to check if the device ID matches any
            # of the valid devices in the database.
            try:
                device_model = zen_store().get_internal_authorized_device(
                    device_id=decoded_token.device_id
                )
            except KeyError:
                error = (
                    f"Authentication error: error retrieving token device "
                    f"{decoded_token.device_id}"
                )
                logger.error(error)
                raise CredentialsNotValid(error)

            if (
                device_model.user is None
                or device_model.user.id != user_model.id
            ):
                error = (
                    f"Authentication error: device {decoded_token.device_id} "
                    f"does not belong to user {user_model.name}"
                )
                logger.error(error)
                raise CredentialsNotValid(error)

            if device_model.status != OAuthDeviceStatus.ACTIVE:
                error = (
                    f"Authentication error: device {decoded_token.device_id} "
                    f"is not active"
                )
                logger.error(error)
                raise CredentialsNotValid(error)

            if (
                device_model.expires
                and datetime.utcnow() >= device_model.expires
            ):
                error = (
                    f"Authentication error: device {decoded_token.device_id} "
                    "has expired"
                )
                logger.error(error)
                raise CredentialsNotValid(error)

            zen_store().update_internal_authorized_device(
                device_id=device_model.id,
                update=OAuthDeviceInternalUpdate(
                    update_last_login=True,
                ),
            )

        if decoded_token.schedule_id:
            # If the token contains a schedule ID, we need to check if the
            # schedule still exists in the database. We use a cached version
            # of the schedule active status to avoid unnecessary database
            # queries.
            @cache_result(expiry=30)
            def get_schedule_active(schedule_id: UUID) -> Optional[bool]:
                """Get the active status of a schedule.

                Args:
                    schedule_id: The schedule ID.

                Returns:
                    The schedule active status or None if the schedule does not
                    exist.
                """
                try:
                    schedule = zen_store().get_schedule(
                        schedule_id, hydrate=False
                    )
                except KeyError:
                    # Report "not found" as None (not False), as documented,
                    # so the caller can tell a missing schedule from an
                    # inactive one.
                    return None
                return schedule.active

            schedule_active = get_schedule_active(decoded_token.schedule_id)
            if schedule_active is None:
                error = (
                    f"Authentication error: error retrieving token schedule "
                    f"{decoded_token.schedule_id}"
                )
                logger.error(error)
                raise CredentialsNotValid(error)

            if not schedule_active:
                error = (
                    f"Authentication error: schedule {decoded_token.schedule_id} "
                    "is not active"
                )
                logger.error(error)
                raise CredentialsNotValid(error)

        if decoded_token.pipeline_run_id:
            # If the token contains a pipeline run ID, we need to check if the
            # pipeline run exists in the database and the pipeline run has
            # not concluded. We use a cached version of the pipeline run status
            # to avoid unnecessary database queries.
            @cache_result(expiry=30)
            def get_pipeline_run_status(
                pipeline_run_id: UUID,
            ) -> Optional[ExecutionStatus]:
                """Get the status of a pipeline run.

                Args:
                    pipeline_run_id: The pipeline run ID.

                Returns:
                    The pipeline run status or None if the pipeline run does not
                    exist.
                """
                try:
                    pipeline_run = zen_store().get_run(
                        pipeline_run_id, hydrate=False
                    )
                except KeyError:
                    return None
                return pipeline_run.status

            pipeline_run_status = get_pipeline_run_status(
                decoded_token.pipeline_run_id
            )
            if pipeline_run_status is None:
                error = (
                    f"Authentication error: error retrieving token pipeline run "
                    f"{decoded_token.pipeline_run_id}"
                )
                logger.error(error)
                raise CredentialsNotValid(error)

            if pipeline_run_status.is_finished:
                error = (
                    f"The execution of pipeline run "
                    f"{decoded_token.pipeline_run_id} has already concluded and "
                    "API tokens scoped to it are no longer valid."
                )
                logger.error(error)
                raise CredentialsNotValid(error)

        if decoded_token.step_run_id:
            # If the token contains a step run ID, we need to check if the
            # step run exists in the database and the step run has not concluded.
            # We use a cached version of the step run status to avoid unnecessary
            # database queries.
            @cache_result(expiry=30)
            def get_step_run_status(
                step_run_id: UUID,
            ) -> Optional[ExecutionStatus]:
                """Get the status of a step run.

                Args:
                    step_run_id: The step run ID.

                Returns:
                    The step run status or None if the step run does not exist.
                """
                try:
                    step_run = zen_store().get_run_step(
                        step_run_id, hydrate=False
                    )
                except KeyError:
                    return None
                return step_run.status

            step_run_status = get_step_run_status(decoded_token.step_run_id)
            if step_run_status is None:
                error = (
                    f"Authentication error: error retrieving token step run "
                    f"{decoded_token.step_run_id}"
                )
                logger.error(error)
                raise CredentialsNotValid(error)

            if step_run_status.is_finished:
                error = (
                    f"The execution of step run "
                    f"{decoded_token.step_run_id} has already concluded and "
                    "API tokens scoped to it are no longer valid."
                )
                logger.error(error)
                raise CredentialsNotValid(error)

        auth_context = AuthContext(
            user=user_model,
            access_token=decoded_token,
            encoded_access_token=access_token,
            device=device_model,
            api_key=api_key_model,
        )

    else:
        # IMPORTANT: the ONLY way we allow the authentication process to
        # continue without any credentials (i.e. no password, activation
        # token or access token) is if authentication is explicitly disabled
        # by setting the auth_scheme to NO_AUTH.
        if server_config().auth_scheme != AuthScheme.NO_AUTH:
            error = "Authentication error: no credentials provided"
            logger.error(error)
            raise CredentialsNotValid(error)

    if not auth_context:
        error = "Authentication error: invalid credentials"
        logger.error(error)
        raise CredentialsNotValid(error)

    return auth_context
authenticate_device(client_id, device_code)
Verify if device authorization credentials are valid.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
client_id |
UUID |
The OAuth2 client ID. |
required |
device_code |
str |
The device code. |
required |
Returns:
Type | Description |
---|---|
AuthContext |
The authenticated account details. |
Exceptions:
Type | Description |
---|---|
OAuthError |
If the device authorization credentials are invalid. |
Source code in zenml/zen_server/auth.py
def authenticate_device(client_id: UUID, device_code: str) -> AuthContext:
    """Verify if device authorization credentials are valid.

    This is the polling step of the OAuth2 device code grant. A client device
    repeatedly asks the server whether the user has authorized it yet. The
    device is authenticated only when all of the following hold:

    1. the device code and client ID match a device in the DB;
    2. the device is in the VERIFIED state, i.e. the user has authorized it
       via the user code but the device client has not fetched its API
       access token yet;
    3. the device has not expired.

    On success the device is switched to ACTIVE and given an expiration date.

    Args:
        client_id: The OAuth2 client ID.
        device_code: The device code.

    Returns:
        The authenticated account details.

    Raises:
        OAuthError: If the device authorization credentials are invalid.
    """
    config = server_config()
    store = zen_store()

    try:
        device_model = store.get_internal_authorized_device(
            client_id=client_id
        )
    except KeyError:
        error = (
            f"Authentication error: error retrieving device with client ID "
            f"{client_id}"
        )
        logger.error(error)
        raise OAuthError(
            error="invalid_client",
            error_description=error,
        )

    current_status = device_model.status
    if current_status != OAuthDeviceStatus.VERIFIED:
        error = (
            f"Authentication error: device with client ID {client_id} is "
            f"{current_status.value}."
        )
        logger.error(error)
        # Translate the device state into a device-flow OAuth error code;
        # every other non-VERIFIED state is reported as an expired token.
        if current_status == OAuthDeviceStatus.PENDING:
            oauth_error = "authorization_pending"
        elif current_status == OAuthDeviceStatus.LOCKED:
            oauth_error = "access_denied"
        else:
            oauth_error = "expired_token"
        raise OAuthError(error=oauth_error, error_description=error)

    if device_model.expires and datetime.utcnow() >= device_model.expires:
        error = (
            f"Authentication error: device for client ID {client_id} has "
            "expired"
        )
        logger.error(error)
        raise OAuthError(error="expired_token", error_description=error)

    if not device_model.verify_device_code(device_code):
        # Record the failed attempt. Once the configured maximum number of
        # failed attempts is reached, the device is also locked.
        attempts = device_model.failed_auth_attempts + 1
        lock_device = attempts >= config.max_failed_device_auth_attempts
        failure_update = OAuthDeviceInternalUpdate(
            failed_auth_attempts=attempts
        )
        if lock_device:
            failure_update.locked = True
        store.update_internal_authorized_device(
            device_id=device_model.id,
            update=failure_update,
        )

        if lock_device:
            error = (
                f"Authentication error: device for client ID {client_id} "
                "has been locked due to too many failed authentication "
                "attempts."
            )
        else:
            error = (
                f"Authentication error: device for client ID {client_id} "
                "has an invalid device code."
            )
        logger.error(error)
        raise OAuthError(error="access_denied", error_description=error)

    # The device is valid. This is the one and only time an AuthContext is
    # returned on the strength of a device code, so that it can be exchanged
    # for an access token; later API requests authenticate with that token.
    #
    # Activate the device and give it an expiration date. That date also
    # bounds the lifetime of the access token issued for this device.
    lifetime_minutes = 0
    if config.jwt_token_expire_minutes:
        lifetime_minutes = (
            config.trusted_device_expiration_minutes
            if device_model.trusted_device
            else config.device_expiration_minutes
        ) or 0

    activation_update = OAuthDeviceInternalUpdate(
        status=OAuthDeviceStatus.ACTIVE,
        expires_in=lifetime_minutes * 60,
    )
    device_model = zen_store().update_internal_authorized_device(
        device_id=device_model.id,
        update=activation_update,
    )

    # The VERIFIED state is only ever set once a user has verified the
    # device and been associated with it, so a user must be present here.
    assert device_model.user is not None
    return AuthContext(user=device_model.user, device=device_model)
authenticate_external_user(external_access_token, request)
Implement external authentication.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
external_access_token |
str |
The access token used to authenticate the user to the external authenticator. |
required |
request |
Request |
The request object. |
required |
Returns:
Type | Description |
---|---|
AuthContext |
The authentication context reflecting the authenticated user. |
Exceptions:
Type | Description |
---|---|
AuthorizationException |
If the external user could not be authorized. |
Source code in zenml/zen_server/auth.py
def authenticate_external_user(
    external_access_token: str, request: Request
) -> AuthContext:
    """Implement external authentication.

    The external access token is sent to the configured external user-info
    endpoint, which identifies the user. The matching ZenML user is then
    created or updated accordingly.

    Args:
        external_access_token: The access token used to authenticate the user
            to the external authenticator.
        request: The request object.

    Returns:
        The authentication context reflecting the authenticated user.

    Raises:
        AuthorizationException: If the external user could not be authorized.
    """
    config = server_config()
    store = zen_store()

    assert config.external_user_info_url is not None

    # Ask the external authenticator which user the token belongs to (and
    # with which permissions).
    info_endpoint = config.external_user_info_url
    auth_headers = {"Authorization": "Bearer " + external_access_token}
    server_query = dict(server_id=str(config.get_external_server_id()))
    try:
        auth_response = requests.get(
            info_endpoint,
            headers=auth_headers,
            params=urlencode(server_query),
            timeout=EXTERNAL_AUTHENTICATOR_TIMEOUT,
        )
    except Exception as e:
        logger.exception(
            f"Error fetching user information from external authenticator: {e}"
        )
        raise AuthorizationException(
            "Error fetching user information from external authenticator."
        )

    status_code = auth_response.status_code
    if status_code in (401, 403):
        raise AuthorizationException("Not authorized to access this server.")
    if status_code == 404:
        raise AuthorizationException(
            "External authenticator did not recognize this server."
        )
    if not 200 <= status_code < 300:
        logger.error(
            f"Error fetching user information from external authenticator. "
            f"Status code: {auth_response.status_code}, "
            f"Response: {auth_response.text}"
        )
        raise AuthorizationException(
            "Error fetching user information from external authenticator. "
        )

    try:
        payload = auth_response.json()
    except requests.exceptions.JSONDecodeError:
        logger.exception(
            "Error decoding JSON response from external authenticator."
        )
        raise AuthorizationException("Unknown external authenticator error")

    external_user: Optional[ExternalUserModel] = None
    if isinstance(payload, dict):
        try:
            external_user = ExternalUserModel.model_validate(payload)
        except Exception as e:
            # Logged only; the missing user is reported right below.
            logger.exception(
                f"Error parsing user information from external "
                f"authenticator: {e}"
            )

    if not external_user:
        raise AuthorizationException("Unknown external authenticator error")

    # Mirror the external user into the ZenML database: update the existing
    # record if there is one, otherwise create it.
    try:
        existing_user = store.get_external_user(user_id=external_user.id)
        user = store.update_user(
            user_id=existing_user.id,
            user_update=UserUpdate(
                name=external_user.email,
                full_name=external_user.name or "",
                email_opted_in=True,
                active=True,
                email=external_user.email,
                is_admin=external_user.is_admin,
            ),
        )
    except KeyError:
        logger.info(
            f"External user with ID {external_user.id} not found in ZenML "
            f"server database. Creating a new user."
        )
        user = store.create_user(
            UserRequest(
                name=external_user.email,
                full_name=external_user.name or "",
                external_user_id=external_user.id,
                email_opted_in=True,
                active=True,
                email=external_user.email,
                is_admin=external_user.is_admin,
            )
        )

        with AnalyticsContext() as context:
            context.user_id = user.id
            context.identify(
                traits={
                    "email": external_user.email,
                    "source": "external_auth",
                }
            )
            context.alias(user_id=external_user.id, previous_id=user.id)

    # A successful login is the natural place to mark the "zenml login"
    # onboarding step as done, but only for CLI clients (identified by a
    # "zenml/" user agent) rather than web clients.
    if "zenml/" in request.headers.get("User-Agent", "").lower():
        store.update_onboarding_state(
            completed_steps={OnboardingStep.DEVICE_VERIFIED}
        )

    return AuthContext(user=user)
authentication_provider()
Returns the authentication provider.
Returns:
Type | Description |
---|---|
Callable[..., zenml.zen_server.auth.AuthContext] |
The authentication provider. |
Exceptions:
Type | Description |
---|---|
ValueError |
If the authentication scheme is not supported. |
Source code in zenml/zen_server/auth.py
def authentication_provider() -> Callable[..., AuthContext]:
    """Returns the authentication provider.

    The provider is selected from the server's configured auth scheme.

    Returns:
        The authentication provider.

    Raises:
        ValueError: If the authentication scheme is not supported.
    """
    scheme = server_config().auth_scheme

    if scheme == AuthScheme.NO_AUTH:
        return no_authentication
    if scheme == AuthScheme.HTTP_BASIC:
        return http_authentication
    if scheme in (AuthScheme.OAUTH2_PASSWORD_BEARER, AuthScheme.EXTERNAL):
        # Both schemes authenticate API requests with bearer access tokens
        # through the same dependency.
        return oauth2_authentication

    raise ValueError(f"Unknown authentication scheme: {scheme}")
authorize(token=Depends(CookieOAuth2TokenBearer))
Authenticates any request to the ZenML server with OAuth2 JWT tokens.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
token |
str |
The JWT bearer token to be authenticated. |
Depends(CookieOAuth2TokenBearer) |
Returns:
Type | Description |
---|---|
AuthContext |
The authentication context reflecting the authenticated user. |
noqa: DAR401
Source code in zenml/zen_server/auth.py
def oauth2_authentication(
    token: str = Depends(
        CookieOAuth2TokenBearer(
            tokenUrl=server_config().root_url_path + API + VERSION_1 + LOGIN,
        )
    ),
) -> AuthContext:
    """Authenticates any request to the ZenML server with OAuth2 JWT tokens.

    Args:
        token: The JWT bearer token to be authenticated, taken from the auth
            cookie or the `Authorization` header.

    Returns:
        The authentication context reflecting the authenticated user.

    # noqa: DAR401
    """
    try:
        return authenticate_credentials(access_token=token)
    except CredentialsNotValid as e:
        # Surface invalid credentials explicitly as a 401 Unauthorized error
        # carrying the CredentialsNotValid type. Clients can then tell this
        # case apart from other 401 errors, discard the current access token
        # and re-authenticate.
        raise http_exception_from_error(e)
generate_access_token(user_id, response=None, device=None, api_key=None, expires_in=None, schedule_id=None, pipeline_run_id=None, step_run_id=None)
Generates an access token for the given user.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
user_id |
UUID |
The ID of the user. |
required |
response |
Optional[starlette.responses.Response] |
The FastAPI response object. |
None |
device |
Optional[zenml.models.v2.core.device.OAuthDeviceInternalResponse] |
The device used for authentication. |
None |
api_key |
Optional[zenml.models.v2.core.api_key.APIKeyInternalResponse] |
The service account API key used for authentication. |
None |
expires_in |
Optional[int] |
The number of seconds until the token expires. If not set, the default value is determined automatically based on the server configuration and type of token. If set to 0, the token will not expire. |
None |
schedule_id |
Optional[uuid.UUID] |
The ID of the schedule to scope the token to. |
None |
pipeline_run_id |
Optional[uuid.UUID] |
The ID of the pipeline run to scope the token to. |
None |
step_run_id |
Optional[uuid.UUID] |
The ID of the step run to scope the token to. |
None |
Returns:
Type | Description |
---|---|
OAuthTokenResponse |
An authentication response with an access token. |
Source code in zenml/zen_server/auth.py
def generate_access_token(
    user_id: UUID,
    response: Optional[Response] = None,
    device: Optional[OAuthDeviceInternalResponse] = None,
    api_key: Optional[APIKeyInternalResponse] = None,
    expires_in: Optional[int] = None,
    schedule_id: Optional[UUID] = None,
    pipeline_run_id: Optional[UUID] = None,
    step_run_id: Optional[UUID] = None,
) -> OAuthTokenResponse:
    """Generates an access token for the given user.

    The token expiration is resolved in this order:

    * `expires_in == 0`: the token never expires;
    * `expires_in` set: the token expires that many seconds from now;
    * `device` given: the token expires together with the device;
    * otherwise: after the server's `jwt_token_expire_minutes`, if set.

    Args:
        user_id: The ID of the user.
        response: The FastAPI response object. If given and no device is
            involved, the token is also set as an HTTP-only cookie on it.
        device: The device used for authentication.
        api_key: The service account API key used for authentication.
        expires_in: The number of seconds until the token expires. If not set,
            the default value is determined automatically based on the server
            configuration and type of token. If set to 0, the token will not
            expire.
        schedule_id: The ID of the schedule to scope the token to.
        pipeline_run_id: The ID of the pipeline run to scope the token to.
        step_run_id: The ID of the step run to scope the token to.

    Returns:
        An authentication response with an access token.
    """
    config = server_config()

    expires: Optional[datetime] = None
    if expires_in == 0:
        # An explicit zero asks for a token without an expiration date.
        expires_in = None
    elif expires_in is not None:
        expires = datetime.utcnow() + timedelta(seconds=expires_in)
    elif device:
        # Device tokens share the expiration date of the device itself.
        expires = device.expires
        if expires:
            seconds_left = int(
                expires.timestamp() - datetime.utcnow().timestamp()
            )
            expires_in = max(seconds_left, 0)
    elif config.jwt_token_expire_minutes:
        expires = datetime.utcnow() + timedelta(
            minutes=config.jwt_token_expire_minutes
        )
        expires_in = config.jwt_token_expire_minutes * 60

    token_claims = JWTToken(
        user_id=user_id,
        device_id=device.id if device else None,
        api_key_id=api_key.id if api_key else None,
        schedule_id=schedule_id,
        pipeline_run_id=pipeline_run_id,
        step_run_id=step_run_id,
    )
    access_token = token_claims.encode(expires=expires)

    if not device and response:
        # Non-device logins also receive the token as an HTTP-only cookie.
        # NOTE(review): the cookie lifetime follows the server-wide
        # `jwt_token_expire_minutes`, not an explicit `expires_in` — confirm
        # this is intended.
        cookie_max_age: Optional[int]
        if config.jwt_token_expire_minutes:
            cookie_max_age = config.jwt_token_expire_minutes * 60
        else:
            cookie_max_age = None
        response.set_cookie(
            key=config.get_auth_cookie_name(),
            value=access_token,
            httponly=True,
            samesite="lax",
            max_age=cookie_max_age,
            domain=config.auth_cookie_domain,
        )

    return OAuthTokenResponse(
        access_token=access_token, expires_in=expires_in, token_type="bearer"
    )
get_auth_context()
Returns the current authentication context.
Returns:
Type | Description |
---|---|
Optional[AuthContext] |
The authentication context. |
Source code in zenml/zen_server/auth.py
def get_auth_context() -> Optional["AuthContext"]:
    """Returns the current authentication context.

    Returns:
        The authentication context stored in the module-level `_auth_context`
        holder, or whatever that holder yields when nothing has been set.
    """
    return _auth_context.get()
http_authentication(credentials=Depends(HTTPBasic))
Authenticates any request to the ZenML Server with basic HTTP authentication.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
credentials |
HTTPBasicCredentials |
HTTP basic auth credentials passed to the request. |
Depends(HTTPBasic) |
Returns:
Type | Description |
---|---|
AuthContext |
The authentication context reflecting the authenticated user. |
noqa: DAR401
Source code in zenml/zen_server/auth.py
def http_authentication(
    credentials: HTTPBasicCredentials = Depends(HTTPBasic()),
) -> AuthContext:
    """Authenticates any request to the ZenML Server with basic HTTP authentication.

    Args:
        credentials: HTTP basic auth credentials passed to the request.

    Returns:
        The authentication context reflecting the authenticated user.

    # noqa: DAR401
    """
    basic_user = credentials.username
    basic_password = credentials.password
    try:
        return authenticate_credentials(
            user_name_or_id=basic_user, password=basic_password
        )
    except CredentialsNotValid as e:
        # Surface invalid credentials explicitly as a 401 Unauthorized error
        # carrying the CredentialsNotValid type. Clients can then tell this
        # case apart from other 401 errors, discard the current access token
        # and re-authenticate.
        raise http_exception_from_error(e)
no_authentication()
Doesn't authenticate requests to the ZenML server.
Returns:
Type | Description |
---|---|
AuthContext |
The authentication context reflecting the default user. |
Source code in zenml/zen_server/auth.py
def no_authentication() -> AuthContext:
    """Doesn't authenticate requests to the ZenML server.

    Every request is attributed to the default user, without checking any
    credentials.

    Returns:
        The authentication context reflecting the default user.
    """
    default_account = DEFAULT_USERNAME
    return authenticate_credentials(user_name_or_id=default_account)
oauth2_authentication(token=Depends(CookieOAuth2TokenBearer))
Authenticates any request to the ZenML server with OAuth2 JWT tokens.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
token |
str |
The JWT bearer token to be authenticated. |
Depends(CookieOAuth2TokenBearer) |
Returns:
Type | Description |
---|---|
AuthContext |
The authentication context reflecting the authenticated user. |
noqa: DAR401
Source code in zenml/zen_server/auth.py
def oauth2_authentication(
    token: str = Depends(
        CookieOAuth2TokenBearer(
            tokenUrl=server_config().root_url_path + API + VERSION_1 + LOGIN,
        )
    ),
) -> AuthContext:
    """Authenticate a ZenML server request using an OAuth2 JWT bearer token.

    Args:
        token: The JWT bearer token to validate.

    Returns:
        The authentication context of the user the token belongs to.

    # noqa: DAR401
    """
    try:
        return authenticate_credentials(access_token=token)
    except CredentialsNotValid as error:
        # Invalid credentials are deliberately surfaced as a
        # CredentialsNotValid error encoded as 401 Unauthorized. Clients can
        # then tell this case apart from other 401 errors, discard their
        # current access token and re-authenticate.
        raise http_exception_from_error(error)
set_auth_context(auth_context)
Sets the current authentication context.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
auth_context |
AuthContext |
The authentication context. |
required |
Returns:
Type | Description |
---|---|
AuthContext |
The authentication context. |
Source code in zenml/zen_server/auth.py
def set_auth_context(auth_context: "AuthContext") -> "AuthContext":
    """Store the given authentication context as the current one.

    Args:
        auth_context: The authentication context to make current.

    Returns:
        The very same authentication context that was passed in, so the call
        can be chained.
    """
    current = auth_context
    _auth_context.set(current)
    return current
cache
Memory cache module for the ZenML server.
MemoryCache
Simple in-memory cache with expiry and capacity management.
This cache is thread-safe and can be used in both synchronous and asynchronous contexts. It uses a simple LRU (Least Recently Used) eviction strategy to manage the cache size.
Each cache entry has a key, value, timestamp, and expiry. The cache automatically removes expired entries and evicts the oldest entry when the cache reaches its maximum capacity.
Usage Example:
cache = MemoryCache()
uuid_key = UUID("12345678123456781234567812345678")
if not cache.get(uuid_key):
# Get the value from the database or other source
value = get_value_from_database()
cache.set(uuid_key, value, expiry=60)
Usage Example with decorator:
@cache_result(expiry=60)
def get_cached_value(key: UUID) -> Any:
return get_value_from_database(key)
uuid_key = UUID("12345678123456781234567812345678")
value = get_cached_value(uuid_key)
Source code in zenml/zen_server/cache.py
class MemoryCache(metaclass=SingletonMetaClass):
    """Simple in-memory cache with expiry and capacity management.

    This cache is thread-safe and can be used in both synchronous and
    asynchronous contexts. It uses a simple LRU (Least Recently Used) eviction
    strategy to manage the cache size: every successful lookup and every
    insert marks the entry as most recently used.

    Each cache entry has a key, value, timestamp, and expiry. The cache
    automatically removes expired entries and evicts the least recently used
    entry when the cache exceeds its maximum capacity.

    Usage Example:

        cache = MemoryCache()
        uuid_key = UUID("12345678123456781234567812345678")

        if not cache.get(uuid_key):
            # Get the value from the database or other source
            value = get_value_from_database()
            cache.set(uuid_key, value, expiry=60)

    Usage Example with decorator:

        @cache_result(expiry=60)
        def get_cached_value(key: UUID) -> Any:
            return get_value_from_database(key)

        uuid_key = UUID("12345678123456781234567812345678")

        value = get_cached_value(uuid_key)
    """

    def __init__(self, max_capacity: int, default_expiry: int) -> None:
        """Initialize the cache with a maximum capacity and default expiry time.

        Args:
            max_capacity: The maximum number of entries the cache can hold.
            default_expiry: The default expiry time in seconds.
        """
        # Ordered from least recently used (front) to most recently used
        # (back).
        self.cache: OrderedDictType[UUID, MemoryCacheEntry] = OrderedDict()
        self.max_capacity = max_capacity
        self.default_expiry = default_expiry
        self._lock = Lock()

    def set(self, key: UUID, value: Any, expiry: Optional[int] = None) -> None:
        """Insert value into cache with optional custom expiry time in seconds.

        Args:
            key: The key to insert the value with.
            value: The value to insert into the cache.
            expiry: The expiry time in seconds. If None, uses the default expiry.
        """
        with self._lock:
            self.cache[key] = MemoryCacheEntry(
                value=value,
                # Only fall back to the default when no expiry was given; an
                # explicit 0 must not be silently replaced by the default.
                expiry=self.default_expiry if expiry is None else expiry,
            )
            # Re-assigning an existing key keeps its old position in an
            # OrderedDict, so move it to the back to mark it most recent.
            self.cache.move_to_end(key)
            self._cleanup()

    def get(self, key: UUID) -> Optional[Any]:
        """Retrieve value if it's still valid; otherwise, return None.

        Args:
            key: The key to retrieve the value for.

        Returns:
            The value if it's still valid; otherwise, None.
        """
        with self._lock:
            return self._get_internal(key)

    def _get_internal(self, key: UUID) -> Optional[Any]:
        """Helper to retrieve a value without lock (internal use only).

        The caller must hold `self._lock`.

        Args:
            key: The key to retrieve the value for.

        Returns:
            The value if it's still valid; otherwise, None.
        """
        entry = self.cache.get(key)
        if entry is None:
            return None
        if entry.expired:
            del self.cache[key]  # Invalidate expired entry
            return None
        # A cache hit makes this the most recently used entry.
        self.cache.move_to_end(key)
        return entry.value

    def _cleanup(self) -> None:
        """Remove expired or excess entries.

        The caller must hold `self._lock`.
        """
        # Remove expired entries
        keys_to_remove = [k for k, v in self.cache.items() if v.expired]
        for k in keys_to_remove:
            del self.cache[k]

        # Ensure we don't exceed max capacity by dropping the least recently
        # used entries from the front of the ordered dict.
        while len(self.cache) > self.max_capacity:
            self.cache.popitem(last=False)
__init__(self, max_capacity, default_expiry)
special
Initialize the cache with a maximum capacity and default expiry time.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
max_capacity |
int |
The maximum number of entries the cache can hold. |
required |
default_expiry |
int |
The default expiry time in seconds. |
required |
Source code in zenml/zen_server/cache.py
def __init__(self, max_capacity: int, default_expiry: int) -> None:
    """Set up an empty cache with the given capacity and default expiry.

    Args:
        max_capacity: Upper bound on the number of entries kept in the cache.
        default_expiry: Expiry time in seconds applied to entries that are
            stored without an explicit expiry.
    """
    self._lock = Lock()
    self.default_expiry = default_expiry
    self.max_capacity = max_capacity
    self.cache: OrderedDictType[UUID, MemoryCacheEntry] = OrderedDict()
get(self, key)
Retrieve value if it's still valid; otherwise, return None.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
key |
UUID |
The key to retrieve the value for. |
required |
Returns:
Type | Description |
---|---|
Optional[Any] |
The value if it's still valid; otherwise, None. |
Source code in zenml/zen_server/cache.py
def get(self, key: UUID) -> Optional[Any]:
    """Look up a cached value, returning None when it is missing or stale.

    Args:
        key: The key whose cached value should be returned.

    Returns:
        The cached value while it is still valid, otherwise None.
    """
    lock = self._lock
    with lock:
        value = self._get_internal(key)
    return value
set(self, key, value, expiry=None)
Insert value into cache with optional custom expiry time in seconds.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
key |
UUID |
The key to insert the value with. |
required |
value |
Any |
The value to insert into the cache. |
required |
expiry |
Optional[int] |
The expiry time in seconds. If None, uses the default expiry. |
None |
Source code in zenml/zen_server/cache.py
def set(self, key: UUID, value: Any, expiry: Optional[int] = None) -> None:
    """Insert value into cache with optional custom expiry time in seconds.

    Args:
        key: The key to insert the value with.
        value: The value to insert into the cache.
        expiry: The expiry time in seconds. If None, uses the default expiry.
    """
    with self._lock:
        self.cache[key] = MemoryCacheEntry(
            value=value,
            # Only fall back to the default when no expiry was given; an
            # explicit 0 must not be silently replaced by the default.
            expiry=self.default_expiry if expiry is None else expiry,
        )
        # Re-assigning an existing key keeps its old position in an
        # OrderedDict, so move it to the back to mark it most recently used.
        self.cache.move_to_end(key)
        self._cleanup()
MemoryCacheEntry
Simple class to hold cache entry data.
Source code in zenml/zen_server/cache.py
class MemoryCacheEntry:
    """Container for a single cached value together with its expiry data."""

    def __init__(self, value: Any, expiry: int) -> None:
        """Create an entry stamped with the current wall-clock time.

        Args:
            value: The value being cached.
            expiry: Lifetime of the entry in seconds.
        """
        created_at = time.time()
        self.value: Any = value
        self.expiry: int = expiry
        self.timestamp: float = created_at

    @property
    def expired(self) -> bool:
        """Tell whether the entry has outlived its expiry.

        Returns:
            True once at least `expiry` seconds have passed since creation,
            otherwise False.
        """
        age = time.time() - self.timestamp
        return age >= self.expiry
expired: bool
property
readonly
Check if the cache entry has expired.
Returns:
Type | Description |
---|---|
bool |
True if the cache entry has expired; otherwise, False. |
__init__(self, value, expiry)
special
Initialize a cache entry with value and expiry time.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
value |
Any |
The value to store in the cache. |
required |
expiry |
int |
The expiry time in seconds. |
required |
Source code in zenml/zen_server/cache.py
def __init__(self, value: Any, expiry: int) -> None:
    """Create a cache entry stamped with the current wall-clock time.

    Args:
        value: The value being cached.
        expiry: Lifetime of the entry in seconds.
    """
    created_at = time.time()
    self.value: Any = value
    self.expiry: int = expiry
    self.timestamp: float = created_at
cache_result(expiry=None)
A decorator to cache the result of a function based on a UUID key argument.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
expiry |
Optional[int] |
Custom time in seconds for the cache entry to expire. If None, uses the default expiry time. |
None |
Returns:
Type | Description |
---|---|
Callable[[Callable[[uuid.UUID], Any]], Callable[[uuid.UUID], Any]] |
A decorator that wraps a function, caching its results based on a UUID key. |
Source code in zenml/zen_server/cache.py
def cache_result(
    expiry: Optional[int] = None,
) -> Callable[[F], F]:
    """A decorator to cache the result of a function based on a UUID key argument.

    Args:
        expiry: Custom time in seconds for the cache entry to expire. If None,
            uses the default expiry time.

    Returns:
        A decorator that wraps a function, caching its results based on a UUID
        key.
    """
    import functools

    def decorator(func: F) -> F:
        """The actual decorator that wraps the function with caching logic.

        Args:
            func: The function to wrap.

        Returns:
            The wrapped function with caching logic.
        """

        # Preserve the wrapped function's name, docstring and `__wrapped__`
        # so introspection, logging and debugging still see the original.
        @functools.wraps(func)
        def wrapper(key: UUID) -> Any:
            """The wrapped function with caching logic.

            Args:
                key: The key to use for caching.

            Returns:
                The result of the original function, either from cache or
                freshly computed.
            """
            from zenml.zen_server.utils import memcache

            # NOTE(review): the cache is shared by every function decorated
            # with `cache_result` and is keyed only by `key`. Two decorated
            # functions called with the same UUID therefore share one entry.
            cache = memcache()

            # Attempt to retrieve the result from cache. A result of None is
            # indistinguishable from a miss, so None results are never
            # served from the cache.
            cached_value = cache.get(key)
            if cached_value is not None:
                logger.debug(
                    f"Memory cache hit for key: {key} and func: {func.__name__}"
                )
                return cached_value

            # Call the original function and cache its result
            result = func(key)
            cache.set(key, result, expiry)

            return result

        return wrapper

    return decorator
cloud_utils
Utilities for communicating with the cloud control plane backend.
ZenMLCloudConfiguration (BaseModel)
ZenML Pro RBAC configuration.
Source code in zenml/zen_server/cloud_utils.py
class ZenMLCloudConfiguration(BaseModel):
    """ZenML Pro RBAC configuration."""

    api_url: str

    oauth2_client_id: str
    oauth2_client_secret: str
    oauth2_audience: str

    @field_validator("api_url")
    @classmethod
    def _strip_trailing_slashes_url(cls, url: str) -> str:
        """Strip any trailing slashes on the API URL.

        Args:
            url: The API URL.

        Returns:
            The API URL with potential trailing slashes removed.
        """
        return url.rstrip("/")

    @classmethod
    def from_environment(cls) -> "ZenMLCloudConfiguration":
        """Get the RBAC configuration from environment variables.

        Every non-empty environment variable starting with
        `ZENML_CLOUD_RBAC_ENV_PREFIX` contributes one field. The field name is
        the remainder of the variable name, lowercased.

        Returns:
            The RBAC configuration.
        """
        prefix = ZENML_CLOUD_RBAC_ENV_PREFIX
        env_config: Dict[str, Any] = {
            key[len(prefix) :].lower(): value
            for key, value in os.environ.items()
            # Empty values are treated as unset.
            if value != "" and key.startswith(prefix)
        }
        # Use `cls` so subclasses get instances of their own type.
        return cls(**env_config)

    model_config = ConfigDict(
        # Allow extra attributes from configs of previous ZenML versions to
        # permit downgrading
        extra="allow"
    )
from_environment()
classmethod
Get the RBAC configuration from environment variables.
Returns:
Type | Description |
---|---|
ZenMLCloudConfiguration |
The RBAC configuration. |
Source code in zenml/zen_server/cloud_utils.py
@classmethod
def from_environment(cls) -> "ZenMLCloudConfiguration":
    """Get the RBAC configuration from environment variables.

    Every non-empty environment variable starting with
    `ZENML_CLOUD_RBAC_ENV_PREFIX` contributes one field. The field name is the
    remainder of the variable name, lowercased.

    Returns:
        The RBAC configuration.
    """
    prefix = ZENML_CLOUD_RBAC_ENV_PREFIX
    env_config: Dict[str, Any] = {
        key[len(prefix) :].lower(): value
        for key, value in os.environ.items()
        # Empty values are treated as unset.
        if value != "" and key.startswith(prefix)
    }
    # Use `cls` so subclasses get instances of their own type.
    return cls(**env_config)
ZenMLCloudConnection
Class to use for communication between server and control plane.
Source code in zenml/zen_server/cloud_utils.py
class ZenMLCloudConnection:
    """Class to use for communication between server and control plane."""

    def __init__(self) -> None:
        """Initialize the cloud connection."""
        self._config = ZenMLCloudConfiguration.from_environment()

        # The session and token are created lazily on first use.
        self._session: Optional[requests.Session] = None
        self._token: Optional[str] = None
        self._token_expires_at: Optional[datetime] = None

    def get(
        self, endpoint: str, params: Optional[Dict[str, Any]] = None
    ) -> requests.Response:
        """Send a GET request using the active session.

        Args:
            endpoint: The endpoint to send the request to. This will be appended
                to the base URL.
            params: Parameters to include in the request.

        Raises:
            RuntimeError: If the request failed.
            SubscriptionUpgradeRequiredError: In case the current subscription
                tier is insufficient for the attempted operation.

        Returns:
            The response.
        """
        url = self._config.api_url + endpoint

        response = self.session.get(url=url, params=params, timeout=7)
        if response.status_code == 401:
            # If we get an Unauthorized error from the API server, we refresh
            # the auth token and try again
            self._clear_session()
            response = self.session.get(url=url, params=params, timeout=7)

        try:
            response.raise_for_status()
        except requests.HTTPError as e:
            if response.status_code == 402:
                raise SubscriptionUpgradeRequiredError(response.json()) from e
            raise RuntimeError(
                f"Failed with the following error {response} {response.text}"
            ) from e

        return response

    def post(
        self,
        endpoint: str,
        params: Optional[Dict[str, Any]] = None,
        data: Optional[Dict[str, Any]] = None,
    ) -> requests.Response:
        """Send a POST request using the active session.

        Args:
            endpoint: The endpoint to send the request to. This will be appended
                to the base URL.
            params: Parameters to include in the request.
            data: Data to include in the request.

        Raises:
            RuntimeError: If the request failed.

        Returns:
            The response.
        """
        url = self._config.api_url + endpoint

        response = self.session.post(
            url=url, params=params, json=data, timeout=7
        )
        if response.status_code == 401:
            # Refresh the auth token and try again
            self._clear_session()
            response = self.session.post(
                url=url, params=params, json=data, timeout=7
            )

        try:
            response.raise_for_status()
        except requests.HTTPError as e:
            raise RuntimeError(
                f"Failed while trying to contact the central zenml pro "
                f"service: {e}"
            ) from e

        return response

    @property
    def session(self) -> requests.Session:
        """Authenticate to the ZenML Pro Management Plane.

        Returns:
            A requests session with the authentication token.
        """
        if self._session is None:
            # Set up the session's connection pool size to match the server's
            # thread pool size. This allows the server to cache one connection
            # per thread, which means we can keep connections open for longer
            # and avoid the overhead of setting up a new connection for each
            # request.
            conn_pool_size = server_config().thread_pool_size

            # Fetch the token before caching anything. If this raises, no
            # session without an Authorization header is left behind for
            # later calls to reuse.
            token = self._fetch_auth_token()

            session = requests.Session()
            session.headers.update({"Authorization": "Bearer " + token})

            retries = Retry(
                total=5, backoff_factor=0.1, status_forcelist=[502, 504]
            )
            session.mount(
                "https://",
                HTTPAdapter(
                    max_retries=retries,
                    # We only use one connection pool to be cached because we
                    # only communicate with one remote server (the control
                    # plane)
                    pool_connections=1,
                    pool_maxsize=conn_pool_size,
                ),
            )
            self._session = session

        return self._session

    def _clear_session(self) -> None:
        """Clear the authentication session.

        Any existing session is closed so that its pooled connections are
        released instead of being leaked.
        """
        if self._session is not None:
            self._session.close()
        self._session = None
        self._token = None
        self._token_expires_at = None

    def _fetch_auth_token(self) -> str:
        """Fetch an auth token for the Cloud API from auth0.

        A cached token is reused while it remains valid for at least five more
        minutes.

        Raises:
            RuntimeError: If the auth token can't be fetched.

        Returns:
            Auth token.
        """
        if (
            self._token is not None
            and self._token_expires_at is not None
            and datetime.now(timezone.utc) + timedelta(minutes=5)
            < self._token_expires_at
        ):
            return self._token

        # Get an auth token from auth0
        login_url = f"{self._config.api_url}/auth/login"
        headers = {"content-type": "application/x-www-form-urlencoded"}
        payload = {
            "client_id": self._config.oauth2_client_id,
            "client_secret": self._config.oauth2_client_secret,
            "audience": self._config.oauth2_audience,
            "grant_type": "client_credentials",
        }
        try:
            response = requests.post(
                login_url, headers=headers, data=payload, timeout=7
            )
            response.raise_for_status()
        except Exception as e:
            raise RuntimeError(
                f"Error fetching auth token from auth0: {e}"
            ) from e

        json_response = response.json()
        access_token = json_response.get("access_token", "")
        expires_in = json_response.get("expires_in", 0)

        if (
            not access_token
            or not isinstance(access_token, str)
            or not expires_in
            or not isinstance(expires_in, int)
        ):
            raise RuntimeError("Could not fetch auth token from auth0.")

        self._token = access_token
        self._token_expires_at = datetime.now(timezone.utc) + timedelta(
            seconds=expires_in
        )

        assert self._token is not None
        return self._token
session: Session
property
readonly
Authenticate to the ZenML Pro Management Plane.
Returns:
Type | Description |
---|---|
Session |
A requests session with the authentication token. |
__init__(self)
special
Initialize the cloud connection.
Source code in zenml/zen_server/cloud_utils.py
def __init__(self) -> None:
    """Load the control plane configuration and reset the connection state.

    The HTTP session and the auth token start out unset.
    """
    self._config = ZenMLCloudConfiguration.from_environment()
    self._session: Optional[requests.Session] = None
    self._token_expires_at: Optional[datetime] = None
    self._token: Optional[str] = None
get(self, endpoint, params)
Send a GET request using the active session.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
endpoint |
str |
The endpoint to send the request to. This will be appended to the base URL. |
required |
params |
Optional[Dict[str, Any]] |
Parameters to include in the request. |
required |
Exceptions:
Type | Description |
---|---|
RuntimeError |
If the request failed. |
SubscriptionUpgradeRequiredError |
In case the current subscription tier is insufficient for the attempted operation. |
Returns:
Type | Description |
---|---|
Response |
The response. |
Source code in zenml/zen_server/cloud_utils.py
def get(
    self, endpoint: str, params: Optional[Dict[str, Any]] = None
) -> requests.Response:
    """Send a GET request using the active session.

    Args:
        endpoint: The endpoint to send the request to. This will be appended
            to the base URL.
        params: Parameters to include in the request.

    Raises:
        RuntimeError: If the request failed.
        SubscriptionUpgradeRequiredError: In case the current subscription
            tier is insufficient for the attempted operation.

    Returns:
        The response.
    """
    url = self._config.api_url + endpoint

    response = self.session.get(url=url, params=params, timeout=7)
    if response.status_code == 401:
        # If we get an Unauthorized error from the API server, we refresh the
        # auth token and try again
        self._clear_session()
        response = self.session.get(url=url, params=params, timeout=7)

    try:
        response.raise_for_status()
    except requests.HTTPError as e:
        if response.status_code == 402:
            raise SubscriptionUpgradeRequiredError(response.json()) from e
        raise RuntimeError(
            f"Failed with the following error {response} {response.text}"
        ) from e

    return response
post(self, endpoint, params=None, data=None)
Send a POST request using the active session.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
endpoint |
str |
The endpoint to send the request to. This will be appended to the base URL. |
required |
params |
Optional[Dict[str, Any]] |
Parameters to include in the request. |
None |
data |
Optional[Dict[str, Any]] |
Data to include in the request. |
None |
Exceptions:
Type | Description |
---|---|
RuntimeError |
If the request failed. |
Returns:
Type | Description |
---|---|
Response |
The response. |
Source code in zenml/zen_server/cloud_utils.py
def post(
    self,
    endpoint: str,
    params: Optional[Dict[str, Any]] = None,
    data: Optional[Dict[str, Any]] = None,
) -> requests.Response:
    """Send a POST request using the active session.

    Args:
        endpoint: The endpoint to send the request to. This will be appended
            to the base URL.
        params: Parameters to include in the request.
        data: Data to include in the request.

    Raises:
        RuntimeError: If the request failed.

    Returns:
        The response.
    """
    url = self._config.api_url + endpoint

    response = self.session.post(
        url=url, params=params, json=data, timeout=7
    )
    if response.status_code == 401:
        # Refresh the auth token and try again
        self._clear_session()
        response = self.session.post(
            url=url, params=params, json=data, timeout=7
        )

    try:
        response.raise_for_status()
    except requests.HTTPError as e:
        # Chain the original HTTP error so its traceback is preserved.
        raise RuntimeError(
            f"Failed while trying to contact the central zenml pro "
            f"service: {e}"
        ) from e

    return response
cloud_connection()
Return the initialized cloud connection.
Returns:
Type | Description |
---|---|
ZenMLCloudConnection |
The cloud connection. |
Source code in zenml/zen_server/cloud_utils.py
def cloud_connection() -> ZenMLCloudConnection:
    """Return the module-level cloud connection, creating it on first use.

    Returns:
        The shared cloud connection instance.
    """
    global _cloud_connection
    if _cloud_connection is not None:
        return _cloud_connection
    _cloud_connection = ZenMLCloudConnection()
    return _cloud_connection
deploy
special
ZenML server deployments.
base_provider
Base ZenML server provider class.
BaseServerProvider (ABC)
Base ZenML server provider class.
All ZenML server providers must extend and implement this base class.
Source code in zenml/zen_server/deploy/base_provider.py
class BaseServerProvider(ABC):
    """Base ZenML server provider class.

    All ZenML server providers must extend and implement this base class.
    """

    TYPE: ClassVar[ServerProviderType]
    CONFIG_TYPE: ClassVar[Type[LocalServerDeploymentConfig]] = (
        LocalServerDeploymentConfig
    )

    @classmethod
    def register_as_provider(cls) -> None:
        """Register the class as a server provider."""
        from zenml.zen_server.deploy.deployer import LocalServerDeployer

        LocalServerDeployer.register_provider(cls)

    @classmethod
    def _convert_config(
        cls, config: LocalServerDeploymentConfig
    ) -> LocalServerDeploymentConfig:
        """Convert a generic server deployment config into a provider specific config.

        Args:
            config: The generic server deployment config.

        Returns:
            The provider specific server deployment config.

        Raises:
            ServerDeploymentConfigurationError: If the configuration is not
                valid.
        """
        if isinstance(config, cls.CONFIG_TYPE):
            # Already provider specific, nothing to convert.
            return config
        try:
            return cls.CONFIG_TYPE(**config.model_dump())
        except ValidationError as e:
            raise ServerDeploymentConfigurationError(
                f"Invalid configuration for provider {cls.TYPE.value}: {e}"
            ) from e

    def deploy_server(
        self,
        config: LocalServerDeploymentConfig,
        timeout: Optional[int] = None,
    ) -> LocalServerDeployment:
        """Deploy a new ZenML server.

        Args:
            config: The generic server deployment configuration.
            timeout: The timeout in seconds to wait until the deployment is
                successful. If not supplied, the default timeout value specified
                by the provider is used.

        Returns:
            The newly created server deployment.

        Raises:
            ServerDeploymentExistsError: If a deployment already exists.
        """
        try:
            self._get_service()
        except KeyError:
            # No existing deployment: this is the expected case.
            pass
        else:
            raise ServerDeploymentExistsError(
                f"Local {self.TYPE.value} ZenML server deployment already exists"
            )

        # convert the generic deployment config to a provider specific
        # deployment config
        config = self._convert_config(config)
        service = self._create_service(config, timeout)
        return self._get_deployment(service)

    def update_server(
        self,
        config: LocalServerDeploymentConfig,
        timeout: Optional[int] = None,
    ) -> LocalServerDeployment:
        """Update an existing ZenML server deployment.

        If the new configuration equals the current one, the existing service
        is only (re)started instead of being updated.

        Args:
            config: The new generic server deployment configuration.
            timeout: The timeout in seconds to wait until the update is
                successful. If not supplied, the default timeout value specified
                by the provider is used.

        Returns:
            The updated server deployment.

        Raises:
            ServerDeploymentNotFoundError: If a deployment doesn't exist.
        """
        try:
            service = self._get_service()
        except KeyError as e:
            raise ServerDeploymentNotFoundError(
                f"The local {self.TYPE.value} ZenML server deployment was not "
                f"found"
            ) from e

        # convert the generic deployment config to a provider specific
        # deployment config
        config = self._convert_config(config)
        old_config = self._get_deployment_config(service)

        if old_config == config:
            logger.info(
                f"The local {self.TYPE.value} ZenML server is already "
                "configured with the same parameters."
            )
            service = self._start_service(service, timeout)
        else:
            logger.info(f"Updating the local {self.TYPE.value} ZenML server.")
            service = self._update_service(service, config, timeout)

        return self._get_deployment(service)

    def remove_server(
        self,
        config: LocalServerDeploymentConfig,
        timeout: Optional[int] = None,
    ) -> None:
        """Tears down and removes all resources and files associated with a ZenML server deployment.

        Args:
            config: The generic server deployment configuration.
            timeout: The timeout in seconds to wait until the server is
                removed. If not supplied, the default timeout value specified
                by the provider is used.

        Raises:
            ServerDeploymentNotFoundError: If a deployment doesn't exist.
        """
        try:
            service = self._get_service()
        except KeyError as e:
            raise ServerDeploymentNotFoundError(
                f"The local {self.TYPE.value} ZenML server deployment was not "
                f"found"
            ) from e

        logger.info(f"Shutting down the local {self.TYPE.value} ZenML server.")
        self._delete_service(service, timeout)

    def get_server(
        self,
        config: LocalServerDeploymentConfig,
    ) -> LocalServerDeployment:
        """Retrieve information about a ZenML server deployment.

        Args:
            config: The generic server deployment configuration.

        Returns:
            The server deployment.

        Raises:
            ServerDeploymentNotFoundError: If a deployment doesn't exist.
        """
        try:
            service = self._get_service()
        except KeyError as e:
            raise ServerDeploymentNotFoundError(
                f"The local {self.TYPE.value} ZenML server deployment was not "
                f"found"
            ) from e

        return self._get_deployment(service)

    def get_server_logs(
        self,
        config: LocalServerDeploymentConfig,
        follow: bool = False,
        tail: Optional[int] = None,
    ) -> Generator[str, bool, None]:
        """Retrieve the logs of a ZenML server.

        Args:
            config: The generic server deployment configuration.
            follow: if True, the logs will be streamed as they are written
            tail: only retrieve the last NUM lines of log output.

        Returns:
            A generator that can be accessed to get the service logs.

        Raises:
            ServerDeploymentNotFoundError: If a deployment doesn't exist.
        """
        try:
            service = self._get_service()
        except KeyError as e:
            raise ServerDeploymentNotFoundError(
                f"The local {self.TYPE.value} ZenML server deployment was not "
                f"found"
            ) from e

        return service.get_logs(follow=follow, tail=tail)

    def _get_deployment_status(
        self, service: BaseService
    ) -> LocalServerDeploymentStatus:
        """Get the status of a server deployment from its service.

        The deployment is reported as connected when the service is running
        and its endpoint URI matches the URL of the configured global store.

        Args:
            service: The server deployment service.

        Returns:
            The status of the server deployment.
        """
        gc = GlobalConfiguration()
        url: Optional[str] = None
        if service.is_running:
            # all services must have an endpoint
            assert service.endpoint is not None

            url = service.endpoint.status.uri
        connected = url is not None and gc.store_configuration.url == url

        return LocalServerDeploymentStatus(
            url=url,
            status=service.status.state,
            status_message=service.status.last_error,
            connected=connected,
        )

    def _get_deployment(self, service: BaseService) -> LocalServerDeployment:
        """Get the server deployment associated with a service.

        Args:
            service: The service.

        Returns:
            The server deployment.
        """
        config = self._get_deployment_config(service)

        return LocalServerDeployment(
            config=config,
            status=self._get_deployment_status(service),
        )

    @classmethod
    @abstractmethod
    def _get_service_configuration(
        cls,
        server_config: LocalServerDeploymentConfig,
    ) -> Tuple[
        ServiceConfig,
        ServiceEndpointConfig,
        ServiceEndpointHealthMonitorConfig,
    ]:
        """Construct the service configuration from a server deployment configuration.

        Args:
            server_config: server deployment configuration.

        Returns:
            The service, service endpoint and endpoint monitor configuration.
        """

    @abstractmethod
    def _create_service(
        self,
        config: LocalServerDeploymentConfig,
        timeout: Optional[int] = None,
    ) -> BaseService:
        """Create, start and return a service instance for a ZenML server deployment.

        Args:
            config: The server deployment configuration.
            timeout: The timeout in seconds to wait until the service is
                running. If not supplied, a default timeout value specified
                by the provider implementation should be used.

        Returns:
            The service instance.
        """

    @abstractmethod
    def _update_service(
        self,
        service: BaseService,
        config: LocalServerDeploymentConfig,
        timeout: Optional[int] = None,
    ) -> BaseService:
        """Update an existing service instance for a ZenML server deployment.

        Args:
            service: The service instance.
            config: The new server deployment configuration.
            timeout: The timeout in seconds to wait until the updated service is
                running. If not supplied, a default timeout value specified
                by the provider implementation should be used.

        Returns:
            The updated service instance.
        """

    @abstractmethod
    def _start_service(
        self,
        service: BaseService,
        timeout: Optional[int] = None,
    ) -> BaseService:
        """Start a service instance for a ZenML server deployment.

        Args:
            service: The service instance.
            timeout: The timeout in seconds to wait until the service is
                running. If not supplied, a default timeout value specified
                by the provider implementation should be used.

        Returns:
            The updated service instance.
        """

    @abstractmethod
    def _stop_service(
        self,
        service: BaseService,
        timeout: Optional[int] = None,
    ) -> BaseService:
        """Stop a service instance for a ZenML server deployment.

        Args:
            service: The service instance.
            timeout: The timeout in seconds to wait until the service is
                stopped. If not supplied, a default timeout value specified
                by the provider implementation should be used.

        Returns:
            The updated service instance.
        """

    @abstractmethod
    def _delete_service(
        self,
        service: BaseService,
        timeout: Optional[int] = None,
    ) -> None:
        """Remove a service instance for a ZenML server deployment.

        Args:
            service: The service instance.
            timeout: The timeout in seconds to wait until the service is
                removed. If not supplied, a default timeout value specified
                by the provider implementation should be used.
        """

    @abstractmethod
    def _get_service(self) -> BaseService:
        """Get the service instance associated with a ZenML server deployment.

        Returns:
            The service instance.

        Raises:
            KeyError: If the server deployment is not found.
        """

    @abstractmethod
    def _get_deployment_config(
        self, service: BaseService
    ) -> LocalServerDeploymentConfig:
        """Recreate the server deployment config from a service instance.

        Args:
            service: The service instance.

        Returns:
            The server deployment config.
        """
CONFIG_TYPE (BaseModel)
Generic local server deployment configuration.
All local server deployment configurations should inherit from this class and handle extra attributes as provider specific attributes.
Attributes:
Name | Type | Description |
---|---|---|
provider |
ServerProviderType |
The server provider type. |
Source code in zenml/zen_server/deploy/base_provider.py
class LocalServerDeploymentConfig(BaseModel):
    """Generic local server deployment configuration.

    Every local server deployment configuration should subclass this model.
    Provider specific settings arrive as extra attributes and are validated by
    the concrete subclasses.

    Attributes:
        provider: The server provider type.
    """

    model_config = ConfigDict(
        # Validate attributes when assigning them. We need to set this in order
        # to have a mix of mutable and immutable attributes
        validate_assignment=True,
        # Allow extra attributes to be set in the base class. The concrete
        # classes are responsible for validating the attributes.
        extra="allow",
    )

    provider: ServerProviderType

    @property
    def url(self) -> Optional[str]:
        """Return the server URL configured for this deployment.

        The generic configuration carries no URL, so this always returns None.

        Returns:
            The configured server URL.
        """
        return None
url: Optional[str]
property
readonly
Get the configured server URL.
Returns:

| Type | Description |
|---|---|
| Optional[str] | The configured server URL. |
deploy_server(self, config, timeout=None)
Deploy a new ZenML server.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| config | LocalServerDeploymentConfig | The generic server deployment configuration. | required |
| timeout | Optional[int] | The timeout in seconds to wait until the deployment is successful. If not supplied, the default timeout value specified by the provider is used. | None |

Returns:

| Type | Description |
|---|---|
| LocalServerDeployment | The newly created server deployment. |

Exceptions:

| Type | Description |
|---|---|
| ServerDeploymentExistsError | If a deployment already exists. |
Source code in zenml/zen_server/deploy/base_provider.py
def deploy_server(
    self,
    config: LocalServerDeploymentConfig,
    timeout: Optional[int] = None,
) -> LocalServerDeployment:
    """Deploy a new ZenML server.

    Args:
        config: The generic server deployment configuration.
        timeout: Number of seconds to wait for the deployment to succeed.
            When omitted, the provider's default timeout is used.

    Returns:
        The newly created server deployment.

    Raises:
        ServerDeploymentExistsError: If a deployment already exists.
    """
    try:
        self._get_service()
    except KeyError:
        # No deployment yet: this is the expected path.
        deployment_exists = False
    else:
        deployment_exists = True

    if deployment_exists:
        raise ServerDeploymentExistsError(
            f"Local {self.TYPE.value} ZenML server deployment already exists"
        )

    # Translate the generic configuration into this provider's own
    # configuration type before creating the service.
    provider_config = self._convert_config(config)
    new_service = self._create_service(provider_config, timeout)
    return self._get_deployment(new_service)
get_server(self, config)
Retrieve information about a ZenML server deployment.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| config | LocalServerDeploymentConfig | The generic server deployment configuration. | required |

Returns:

| Type | Description |
|---|---|
| LocalServerDeployment | The server deployment. |

Exceptions:

| Type | Description |
|---|---|
| ServerDeploymentNotFoundError | If a deployment doesn't exist. |
Source code in zenml/zen_server/deploy/base_provider.py
def get_server(
    self,
    config: LocalServerDeploymentConfig,
) -> LocalServerDeployment:
    """Retrieve information about a ZenML server deployment.

    Args:
        config: The generic server deployment configuration. It is not used
            for the lookup, which relies on `_get_service()` alone.

    Returns:
        The server deployment.

    Raises:
        ServerDeploymentNotFoundError: If a deployment doesn't exist.
    """
    try:
        service = self._get_service()
    except KeyError as e:
        # Translate the internal lookup failure into the public error type,
        # keeping the KeyError as the explicit cause in the traceback.
        raise ServerDeploymentNotFoundError(
            f"The local {self.TYPE.value} ZenML server deployment was not "
            f"found"
        ) from e
    return self._get_deployment(service)
get_server_logs(self, config, follow=False, tail=None)
Retrieve the logs of a ZenML server.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| config | LocalServerDeploymentConfig | The generic server deployment configuration. | required |
| follow | bool | If True, the logs will be streamed as they are written. | False |
| tail | Optional[int] | Only retrieve the last NUM lines of log output. | None |

Returns:

| Type | Description |
|---|---|
| Generator[str, bool, None] | A generator that can be accessed to get the service logs. |

Exceptions:

| Type | Description |
|---|---|
| ServerDeploymentNotFoundError | If a deployment doesn't exist. |
Source code in zenml/zen_server/deploy/base_provider.py
def get_server_logs(
    self,
    config: LocalServerDeploymentConfig,
    follow: bool = False,
    tail: Optional[int] = None,
) -> Generator[str, bool, None]:
    """Retrieve the logs of a ZenML server.

    Args:
        config: The generic server deployment configuration. It is not used
            for the lookup, which relies on `_get_service()` alone.
        follow: if True, the logs will be streamed as they are written
        tail: only retrieve the last NUM lines of log output.

    Returns:
        A generator that can be accessed to get the service logs.

    Raises:
        ServerDeploymentNotFoundError: If a deployment doesn't exist.
    """
    try:
        service = self._get_service()
    except KeyError as e:
        # Translate the internal lookup failure into the public error type,
        # keeping the KeyError as the explicit cause in the traceback.
        raise ServerDeploymentNotFoundError(
            f"The local {self.TYPE.value} ZenML server deployment was not "
            f"found"
        ) from e
    return service.get_logs(follow=follow, tail=tail)
register_as_provider()
classmethod
Register the class as a server provider.
Source code in zenml/zen_server/deploy/base_provider.py
@classmethod
def register_as_provider(cls) -> None:
    """Register this provider class with the local server deployer."""
    # Imported inside the method rather than at module level; presumably
    # to avoid an import cycle between the deployer and provider modules.
    from zenml.zen_server.deploy.deployer import LocalServerDeployer

    deployer_cls = LocalServerDeployer
    deployer_cls.register_provider(cls)
remove_server(self, config, timeout=None)
Tears down and removes all resources and files associated with a ZenML server deployment.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| config | LocalServerDeploymentConfig | The generic server deployment configuration. | required |
| timeout | Optional[int] | The timeout in seconds to wait until the server is removed. If not supplied, the default timeout value specified by the provider is used. | None |

Exceptions:

| Type | Description |
|---|---|
| ServerDeploymentNotFoundError | If a deployment doesn't exist. |
Source code in zenml/zen_server/deploy/base_provider.py
def remove_server(
    self,
    config: LocalServerDeploymentConfig,
    timeout: Optional[int] = None,
) -> None:
    """Tears down and removes all resources and files associated with a ZenML server deployment.

    Args:
        config: The generic server deployment configuration. It is not used
            for the lookup, which relies on `_get_service()` alone.
        timeout: The timeout in seconds to wait until the server is
            removed. If not supplied, the default timeout value specified
            by the provider is used.

    Raises:
        ServerDeploymentNotFoundError: If a deployment doesn't exist.
    """
    try:
        service = self._get_service()
    except KeyError as e:
        # Translate the internal lookup failure into the public error type,
        # keeping the KeyError as the explicit cause in the traceback.
        raise ServerDeploymentNotFoundError(
            f"The local {self.TYPE.value} ZenML server deployment was not "
            f"found"
        ) from e
    logger.info(f"Shutting down the local {self.TYPE.value} ZenML server.")
    self._delete_service(service, timeout)
update_server(self, config, timeout=None)
Update an existing ZenML server deployment.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| config | LocalServerDeploymentConfig | The new generic server deployment configuration. | required |
| timeout | Optional[int] | The timeout in seconds to wait until the update is successful. If not supplied, the default timeout value specified by the provider is used. | None |

Returns:

| Type | Description |
|---|---|
| LocalServerDeployment | The updated server deployment. |

Exceptions:

| Type | Description |
|---|---|
| ServerDeploymentNotFoundError | If a deployment doesn't exist. |
Source code in zenml/zen_server/deploy/base_provider.py
def update_server(
    self,
    config: LocalServerDeploymentConfig,
    timeout: Optional[int] = None,
) -> LocalServerDeployment:
    """Update an existing ZenML server deployment.

    If the new provider-specific configuration equals the configuration of
    the existing service, the service is only (re)started. Otherwise it is
    updated with the new configuration.

    Args:
        config: The new generic server deployment configuration.
        timeout: The timeout in seconds to wait until the update is
            successful. If not supplied, the default timeout value specified
            by the provider is used.

    Returns:
        The updated server deployment.

    Raises:
        ServerDeploymentNotFoundError: If a deployment doesn't exist.
    """
    try:
        service = self._get_service()
    except KeyError as e:
        # Translate the internal lookup failure into the public error type,
        # keeping the KeyError as the explicit cause in the traceback.
        raise ServerDeploymentNotFoundError(
            f"The local {self.TYPE.value} ZenML server deployment was not "
            f"found"
        ) from e

    # Convert the generic deployment config to a provider-specific one so
    # it can be compared with the configuration of the existing service.
    config = self._convert_config(config)
    old_config = self._get_deployment_config(service)
    if old_config == config:
        logger.info(
            f"The local {self.TYPE.value} ZenML server is already "
            "configured with the same parameters."
        )
        service = self._start_service(service, timeout)
    else:
        logger.info(f"Updating the local {self.TYPE.value} ZenML server.")
        service = self._update_service(service, config, timeout)
    return self._get_deployment(service)
daemon
special
ZenML Server Local Daemon Deployment.
daemon_provider
Zen Server daemon provider implementation.
DaemonServerProvider (BaseServerProvider)
Daemon ZenML server provider.
Source code in zenml/zen_server/deploy/daemon/daemon_provider.py
class DaemonServerProvider(BaseServerProvider):
    """Daemon ZenML server provider.

    Manages the local ZenML server deployment backed by a `DaemonZenServer`
    service.
    """

    TYPE: ClassVar[ServerProviderType] = ServerProviderType.DAEMON
    CONFIG_TYPE: ClassVar[Type[LocalServerDeploymentConfig]] = (
        DaemonServerDeploymentConfig
    )

    @staticmethod
    def check_local_server_dependencies() -> None:
        """Check if local server dependencies are installed.

        Raises:
            RuntimeError: If the dependencies are not installed.
        """
        try:
            # Make sure the ZenML Server dependencies are installed
            import fastapi  # noqa
            import jwt  # noqa
            import multipart  # noqa
            import uvicorn  # noqa
        except ImportError as e:
            # Unable to import the ZenML Server dependencies. Chain the
            # ImportError so the traceback still names the missing package.
            raise RuntimeError(
                "The local daemon ZenML server provider is unavailable because the "
                "ZenML server requirements seems to be unavailable on your machine. "
                "This is probably because ZenML was installed without the optional "
                "ZenML Server dependencies. To install the missing dependencies "
                f'run `pip install "zenml[server]=={__version__}"`.'
            ) from e

    @classmethod
    def _get_service_configuration(
        cls,
        server_config: LocalServerDeploymentConfig,
    ) -> Tuple[
        ServiceConfig,
        ServiceEndpointConfig,
        ServiceEndpointHealthMonitorConfig,
    ]:
        """Construct the service configuration from a server deployment configuration.

        Args:
            server_config: server deployment configuration. Must be a
                `DaemonServerDeploymentConfig`.

        Returns:
            The service, service endpoint and endpoint monitor configuration.
        """
        assert isinstance(server_config, DaemonServerDeploymentConfig)
        return (
            DaemonZenServerConfig(
                root_runtime_path=DaemonZenServer.config_path(),
                singleton=True,
                name=ServerProviderType.DAEMON.value,
                blocking=server_config.blocking,
                server=server_config,
            ),
            LocalDaemonServiceEndpointConfig(
                protocol=ServiceEndpointProtocol.HTTP,
                ip_address=str(server_config.ip_address),
                port=server_config.port,
                allocate_port=False,
            ),
            HTTPEndpointHealthMonitorConfig(
                healthcheck_uri_path=ZEN_SERVER_HEALTHCHECK_URL_PATH,
                use_head_request=True,
            ),
        )

    def _create_service(
        self,
        config: LocalServerDeploymentConfig,
        timeout: Optional[int] = None,
    ) -> BaseService:
        """Create, start and return the local daemon ZenML server deployment service.

        Args:
            config: The server deployment configuration.
            timeout: The timeout in seconds to wait until the service is
                running. Defaults to `DAEMON_ZENML_SERVER_DEFAULT_TIMEOUT`.

        Returns:
            The service instance.

        Raises:
            RuntimeError: If a local service is already running.
        """
        assert isinstance(config, DaemonServerDeploymentConfig)

        if timeout is None:
            timeout = DAEMON_ZENML_SERVER_DEFAULT_TIMEOUT

        self.check_local_server_dependencies()

        existing_service = DaemonZenServer.get_service()
        if existing_service:
            raise RuntimeError(
                f"A local daemon ZenML server with name '{existing_service.config.name}' "
                f"is already running. Please stop it first before starting a "
                f"new one."
            )

        (
            service_config,
            endpoint_cfg,
            monitor_cfg,
        ) = self._get_service_configuration(config)

        endpoint = LocalDaemonServiceEndpoint(
            config=endpoint_cfg,
            monitor=HTTPEndpointHealthMonitor(
                config=monitor_cfg,
            ),
        )
        service = DaemonZenServer(
            uuid=uuid4(), config=service_config, endpoint=endpoint
        )
        service.start(timeout=timeout)
        return service

    def _update_service(
        self,
        service: BaseService,
        config: LocalServerDeploymentConfig,
        timeout: Optional[int] = None,
    ) -> BaseService:
        """Update the local daemon ZenML server deployment service.

        The service is stopped, its service, endpoint and monitor
        configurations are replaced, and it is started again.

        Args:
            service: The service instance.
            config: The new server deployment configuration.
            timeout: The timeout in seconds to wait until the updated service is
                running. Defaults to `DAEMON_ZENML_SERVER_DEFAULT_TIMEOUT`.

        Returns:
            The updated service instance.
        """
        if timeout is None:
            timeout = DAEMON_ZENML_SERVER_DEFAULT_TIMEOUT

        (
            new_config,
            new_endpoint_cfg,
            new_monitor_cfg,
        ) = self._get_service_configuration(config)

        assert service.endpoint
        assert service.endpoint.monitor
        service.stop(timeout=timeout)
        (
            service.config,
            service.endpoint.config,
            service.endpoint.monitor.config,
        ) = (
            new_config,
            new_endpoint_cfg,
            new_monitor_cfg,
        )
        service.start(timeout=timeout)
        return service

    def _start_service(
        self,
        service: BaseService,
        timeout: Optional[int] = None,
    ) -> BaseService:
        """Start the local daemon ZenML server deployment service.

        Args:
            service: The service instance.
            timeout: The timeout in seconds to wait until the service is
                running. Defaults to `DAEMON_ZENML_SERVER_DEFAULT_TIMEOUT`.

        Returns:
            The updated service instance.
        """
        if timeout is None:
            timeout = DAEMON_ZENML_SERVER_DEFAULT_TIMEOUT

        service.start(timeout=timeout)
        return service

    def _stop_service(
        self,
        service: BaseService,
        timeout: Optional[int] = None,
    ) -> BaseService:
        """Stop the local daemon ZenML server deployment service.

        Args:
            service: The service instance.
            timeout: The timeout in seconds to wait until the service is
                stopped. Defaults to `DAEMON_ZENML_SERVER_DEFAULT_TIMEOUT`.

        Returns:
            The updated service instance.
        """
        if timeout is None:
            timeout = DAEMON_ZENML_SERVER_DEFAULT_TIMEOUT

        service.stop(timeout=timeout)
        return service

    def _delete_service(
        self,
        service: BaseService,
        timeout: Optional[int] = None,
    ) -> None:
        """Remove the local daemon ZenML server deployment service.

        Stops the service and then deletes the whole daemon server runtime
        directory (`DaemonZenServer.config_path()`).

        Args:
            service: The service instance.
            timeout: The timeout in seconds to wait until the service is
                removed. Defaults to `DAEMON_ZENML_SERVER_DEFAULT_TIMEOUT`.
        """
        assert isinstance(service, DaemonZenServer)

        if timeout is None:
            timeout = DAEMON_ZENML_SERVER_DEFAULT_TIMEOUT

        service.stop(timeout)
        shutil.rmtree(DaemonZenServer.config_path())

    def _get_service(self) -> BaseService:
        """Get the local daemon ZenML server deployment service.

        Returns:
            The service instance.

        Raises:
            KeyError: If the server deployment is not found.
        """
        service = DaemonZenServer.get_service()
        if service is None:
            raise KeyError("The local daemon ZenML server is not deployed.")

        return service

    def _get_deployment_config(
        self, service: BaseService
    ) -> LocalServerDeploymentConfig:
        """Recreate the server deployment configuration from a service instance.

        Args:
            service: The service instance, expected to be a `DaemonZenServer`.

        Returns:
            The server deployment configuration stored in the service config.
        """
        server = cast(DaemonZenServer, service)
        return server.config.server
CONFIG_TYPE (LocalServerDeploymentConfig)
Daemon server deployment configuration.
Attributes:

| Name | Type | Description |
|---|---|---|
| port | int | The TCP port number where the server is accepting connections. |
| ip_address | Union[IPv4Address, IPv6Address] | The IP address where the server is reachable. |
| blocking | bool | Run the server in blocking mode instead of using a daemon process. |
Source code in zenml/zen_server/deploy/daemon/daemon_provider.py
class DaemonServerDeploymentConfig(LocalServerDeploymentConfig):
    """Daemon server deployment configuration.

    Attributes:
        port: The TCP port number where the server is accepting connections.
        ip_address: The IP address where the server is reachable.
        blocking: Run the server in blocking mode instead of using a daemon
            process.
        store: Optional store configuration for the server.
            NOTE(review): how the server consumes this is not visible
            here - confirm.
    """

    port: int = 8237
    ip_address: Union[ipaddress.IPv4Address, ipaddress.IPv6Address] = Field(
        default=ipaddress.IPv4Address(DEFAULT_LOCAL_SERVICE_IP_ADDRESS),
        union_mode="left_to_right",
    )
    blocking: bool = False
    store: Optional[StoreConfiguration] = None

    @property
    def url(self) -> Optional[str]:
        """Get the configured server URL.

        IPv6 addresses are enclosed in square brackets, as required for IP
        literals in URLs (RFC 3986), so the port separator stays unambiguous.

        Returns:
            The configured server URL.
        """
        host = str(self.ip_address)
        if isinstance(self.ip_address, ipaddress.IPv6Address):
            host = f"[{host}]"
        return f"http://{host}:{self.port}"

    model_config = ConfigDict(extra="forbid")
url: Optional[str]
property
readonly
Get the configured server URL.
Returns:

| Type | Description |
|---|---|
| Optional[str] | The configured server URL. |
check_local_server_dependencies()
staticmethod
Check if local server dependencies are installed.
Exceptions:

| Type | Description |
|---|---|
| RuntimeError | If the dependencies are not installed. |
Source code in zenml/zen_server/deploy/daemon/daemon_provider.py
@staticmethod
def check_local_server_dependencies() -> None:
    """Check if local server dependencies are installed.

    Raises:
        RuntimeError: If the dependencies are not installed.
    """
    try:
        # Make sure the ZenML Server dependencies are installed
        import fastapi  # noqa
        import jwt  # noqa
        import multipart  # noqa
        import uvicorn  # noqa
    except ImportError as e:
        # Unable to import the ZenML Server dependencies. Chain the
        # ImportError so the traceback still names the missing package.
        raise RuntimeError(
            "The local daemon ZenML server provider is unavailable because the "
            "ZenML server requirements seems to be unavailable on your machine. "
            "This is probably because ZenML was installed without the optional "
            "ZenML Server dependencies. To install the missing dependencies "
            f'run `pip install "zenml[server]=={__version__}"`.'
        ) from e
daemon_zen_server
Local daemon ZenML server deployment service implementation.
DaemonServerDeploymentConfig (LocalServerDeploymentConfig)
Daemon server deployment configuration.
Attributes:

| Name | Type | Description |
|---|---|---|
| port | int | The TCP port number where the server is accepting connections. |
| ip_address | Union[IPv4Address, IPv6Address] | The IP address where the server is reachable. |
| blocking | bool | Run the server in blocking mode instead of using a daemon process. |
Source code in zenml/zen_server/deploy/daemon/daemon_zen_server.py
class DaemonServerDeploymentConfig(LocalServerDeploymentConfig):
    """Daemon server deployment configuration.

    Attributes:
        port: The TCP port number where the server is accepting connections.
        ip_address: The IP address where the server is reachable.
        blocking: Run the server in blocking mode instead of using a daemon
            process.
        store: Optional store configuration for the server.
            NOTE(review): how the server consumes this is not visible
            here - confirm.
    """

    port: int = 8237
    ip_address: Union[ipaddress.IPv4Address, ipaddress.IPv6Address] = Field(
        default=ipaddress.IPv4Address(DEFAULT_LOCAL_SERVICE_IP_ADDRESS),
        union_mode="left_to_right",
    )
    blocking: bool = False
    store: Optional[StoreConfiguration] = None

    @property
    def url(self) -> Optional[str]:
        """Get the configured server URL.

        IPv6 addresses are enclosed in square brackets, as required for IP
        literals in URLs (RFC 3986), so the port separator stays unambiguous.

        Returns:
            The configured server URL.
        """
        host = str(self.ip_address)
        if isinstance(self.ip_address, ipaddress.IPv6Address):
            host = f"[{host}]"
        return f"http://{host}:{self.port}"

    model_config = ConfigDict(extra="forbid")
url: Optional[str]
property
readonly
Get the configured server URL.
Returns:

| Type | Description |
|---|---|
| Optional[str] | The configured server URL. |
DaemonZenServer (LocalDaemonService)
Service daemon that can be used to start a local daemon ZenML server.
Attributes:

| Name | Type | Description |
|---|---|---|
| config | DaemonZenServerConfig | Service configuration. |
| endpoint | LocalDaemonServiceEndpoint | Service endpoint. |
Source code in zenml/zen_server/deploy/daemon/daemon_zen_server.py
class DaemonZenServer(LocalDaemonService):
    """Service daemon that can be used to start a local daemon ZenML server.

    Attributes:
        config: service configuration
        endpoint: service endpoint
    """

    SERVICE_TYPE = ServiceType(
        name="daemon_zenml_server",
        type="zen_server",
        flavor="daemon",
        description="local daemon ZenML server deployment",
    )

    config: DaemonZenServerConfig
    endpoint: LocalDaemonServiceEndpoint

    @classmethod
    def config_path(cls) -> str:
        """Path to the directory where the local daemon ZenML server files are located.

        Returns:
            Path to the local daemon ZenML server runtime directory.
        """
        return os.path.join(
            get_global_config_directory(),
            "zen_server",
            "daemon",
        )

    @property
    def _global_config_path(self) -> str:
        """Path to the global configuration directory used by this server.

        Returns:
            Path to the global configuration directory used by this server.
        """
        return os.path.join(self.config_path(), ".zenconfig")

    @classmethod
    def get_service(cls) -> Optional["DaemonZenServer"]:
        """Load and return the local daemon ZenML server service, if present.

        Returns:
            The local daemon ZenML server service or None, if the local server
            deployment is not found.
        """
        config_filename = os.path.join(cls.config_path(), "service.json")
        try:
            with open(config_filename, "r") as f:
                return cast(
                    "DaemonZenServer", DaemonZenServer.from_json(f.read())
                )
        except FileNotFoundError:
            return None

    def _get_daemon_cmd(self) -> Tuple[List[str], Dict[str, str]]:
        """Get the command to start the daemon.

        Overrides the base class implementation to add the environment
        variables that force the ZenML server to use the copied global config.

        Returns:
            The command to start the daemon and the environment variables to
            set for the command.
        """
        gc = GlobalConfiguration()
        cmd, env = super()._get_daemon_cmd()
        env[ENV_ZENML_SERVER] = "true"
        env[ENV_ZENML_CONFIG_PATH] = self._global_config_path
        env[ENV_ZENML_ANALYTICS_OPT_IN] = str(gc.analytics_opt_in)
        env[ENV_ZENML_USER_ID] = str(gc.user_id)
        # Disable authentication for the local server
        env[ENV_ZENML_SERVER_AUTH_SCHEME] = AuthScheme.NO_AUTH.value
        env[ENV_ZENML_SERVER_DEPLOYMENT_TYPE] = ServerDeploymentType.LOCAL
        # Set the local stores path to the same path used by the client. This
        # ensures that the server's default store configuration is initialized
        # to point at the same local SQLite database as the client.
        env[ENV_ZENML_LOCAL_STORES_PATH] = (
            GlobalConfiguration().local_stores_path
        )
        env[ENV_ZENML_DISABLE_DATABASE_MIGRATION] = "True"
        env[ENV_ZENML_SERVER_AUTO_ACTIVATE] = "True"

        return cmd, env

    def provision(self) -> None:
        """Provision the service."""
        super().provision()

    def start(self, timeout: int = 0) -> None:
        """Start the service and optionally wait for it to become active.

        In non-blocking mode this delegates to the parent daemon
        implementation. In blocking mode the server runs in the current
        process. The environment variables that give it its own global
        configuration are set only for the duration of the run. Afterwards
        every one of them is restored to its previous value, or removed if it
        was not set before, even if the server fails.

        Args:
            timeout: amount of time to wait for the service to become active.
                If set to 0, the method will return immediately after checking
                the service status.
        """
        if not self.config.blocking:
            super().start(timeout)
            return

        # In the blocking mode, we need to temporarily set the environment
        # variables for the running process to make it look like the server
        # is running in a separate environment (i.e. using a different
        # global configuration path). This is necessary to avoid polluting
        # the client environment with the server's configuration.
        gc = GlobalConfiguration()
        server_env = {
            ENV_ZENML_SERVER: "true",
            ENV_ZENML_CONFIG_PATH: self._global_config_path,
            ENV_ZENML_ANALYTICS_OPT_IN: str(gc.analytics_opt_in),
            ENV_ZENML_USER_ID: str(gc.user_id),
            # Use the same local stores path as the client, so the server's
            # default store points at the same local SQLite database.
            ENV_ZENML_LOCAL_STORES_PATH: gc.local_stores_path,
            ENV_ZENML_SERVER_AUTH_SCHEME: AuthScheme.NO_AUTH.value,
        }
        GlobalConfiguration._reset_instance()
        Client._reset_instance()

        # Snapshot the client's current values (None = not set) so that every
        # variable touched below, including the analytics and user ID ones,
        # is put back exactly as it was.
        saved_env = {name: os.environ.get(name) for name in server_env}
        os.environ.update(server_env)
        try:
            self.run()
        finally:
            # Restore the original client environment variables
            for name, value in saved_env.items():
                if value is None:
                    os.environ.pop(name, None)
                else:
                    os.environ[name] = value
            GlobalConfiguration._reset_instance()
            Client._reset_instance()

    def run(self) -> None:
        """Run the ZenML Server.

        Raises:
            ValueError: if started with a global configuration that connects to
                another ZenML server.
        """
        import uvicorn

        gc = GlobalConfiguration()
        if gc.store_configuration.type == StoreType.REST:
            raise ValueError(
                "The ZenML server cannot be started with REST store type."
            )
        logger.info(
            "Starting ZenML Server as blocking "
            "process... press CTRL+C once to stop it."
        )

        self.endpoint.prepare_for_start()

        try:
            uvicorn.run(
                ZEN_SERVER_ENTRYPOINT,
                host=self.endpoint.config.ip_address,
                port=self.endpoint.config.port or 8000,
                log_level="info",
                server_header=False,
            )
        except KeyboardInterrupt:
            logger.info("ZenML Server stopped. Resuming normal execution.")
config_path()
classmethod
Path to the directory where the local daemon ZenML server files are located.
Returns:

| Type | Description |
|---|---|
| str | Path to the local daemon ZenML server runtime directory. |
Source code in zenml/zen_server/deploy/daemon/daemon_zen_server.py
@classmethod
def config_path(cls) -> str:
    """Return the runtime directory of the local daemon ZenML server.

    Returns:
        The ``zen_server/daemon`` directory inside the global ZenML
        configuration directory.
    """
    runtime_subdirs = ("zen_server", "daemon")
    return os.path.join(get_global_config_directory(), *runtime_subdirs)
get_service()
classmethod
Load and return the local daemon ZenML server service, if present.
Returns:

| Type | Description |
|---|---|
| Optional[DaemonZenServer] | The local daemon ZenML server service or None, if the local server deployment is not found. |
Source code in zenml/zen_server/deploy/daemon/daemon_zen_server.py
@classmethod
def get_service(cls) -> Optional["DaemonZenServer"]:
    """Deserialize the local daemon ZenML server service from disk.

    The service is read from ``service.json`` in the daemon runtime
    directory.

    Returns:
        The deserialized service, or None when no ``service.json`` exists,
        i.e. the local server is not deployed.
    """
    service_file = os.path.join(cls.config_path(), "service.json")
    try:
        with open(service_file, "r") as handle:
            loaded = DaemonZenServer.from_json(handle.read())
            return cast("DaemonZenServer", loaded)
    except FileNotFoundError:
        return None
provision(self)
Provision the service.
Source code in zenml/zen_server/deploy/daemon/daemon_zen_server.py
def provision(self) -> None:
    """Provision the service.

    This override adds nothing of its own; it delegates straight to the
    parent class implementation.
    """
    super().provision()
run(self)
Run the ZenML Server.
Exceptions:

| Type | Description |
|---|---|
| ValueError | If started with a global configuration that connects to another ZenML server. |
Source code in zenml/zen_server/deploy/daemon/daemon_zen_server.py
def run(self) -> None:
    """Run the ZenML server in the current process until it is interrupted.

    Raises:
        ValueError: if the active global configuration uses a REST store,
            i.e. it connects to another ZenML server.
    """
    import uvicorn

    active_store_type = GlobalConfiguration().store_configuration.type
    if active_store_type == StoreType.REST:
        raise ValueError(
            "The ZenML server cannot be started with REST store type."
        )

    logger.info(
        "Starting ZenML Server as blocking process... press CTRL+C once to stop it."
    )

    endpoint = self.endpoint
    endpoint.prepare_for_start()
    # Read the endpoint configuration only after prepare_for_start() ran.
    endpoint_cfg = endpoint.config

    try:
        uvicorn.run(
            ZEN_SERVER_ENTRYPOINT,
            host=endpoint_cfg.ip_address,
            port=endpoint_cfg.port or 8000,
            log_level="info",
            server_header=False,
        )
    except KeyboardInterrupt:
        logger.info("ZenML Server stopped. Resuming normal execution.")
start(self, timeout=0)
Start the service and optionally wait for it to become active.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| timeout | int | Amount of time to wait for the service to become active. If set to 0, the method will return immediately after checking the service status. | 0 |
Source code in zenml/zen_server/deploy/daemon/daemon_zen_server.py
def start(self, timeout: int = 0) -> None:
    """Start the service and optionally wait for it to become active.

    In non-blocking mode this delegates to the parent daemon
    implementation. In blocking mode the server runs in the current process.
    The environment variables that give it its own global configuration are
    set only for the duration of the run. Afterwards every one of them is
    restored to its previous value, or removed if it was not set before,
    even if the server fails.

    Args:
        timeout: amount of time to wait for the service to become active.
            If set to 0, the method will return immediately after checking
            the service status.
    """
    if not self.config.blocking:
        super().start(timeout)
        return

    # In the blocking mode, we need to temporarily set the environment
    # variables for the running process to make it look like the server
    # is running in a separate environment (i.e. using a different
    # global configuration path). This is necessary to avoid polluting
    # the client environment with the server's configuration.
    gc = GlobalConfiguration()
    server_env = {
        ENV_ZENML_SERVER: "true",
        ENV_ZENML_CONFIG_PATH: self._global_config_path,
        ENV_ZENML_ANALYTICS_OPT_IN: str(gc.analytics_opt_in),
        ENV_ZENML_USER_ID: str(gc.user_id),
        # Use the same local stores path as the client, so the server's
        # default store points at the same local SQLite database.
        ENV_ZENML_LOCAL_STORES_PATH: gc.local_stores_path,
        ENV_ZENML_SERVER_AUTH_SCHEME: AuthScheme.NO_AUTH.value,
    }
    GlobalConfiguration._reset_instance()
    Client._reset_instance()

    # Snapshot the client's current values (None = not set) so that every
    # variable touched below, including the analytics and user ID ones, is
    # put back exactly as it was.
    saved_env = {name: os.environ.get(name) for name in server_env}
    os.environ.update(server_env)
    try:
        self.run()
    finally:
        # Restore the original client environment variables
        for name, value in saved_env.items():
            if value is None:
                os.environ.pop(name, None)
            else:
                os.environ[name] = value
        GlobalConfiguration._reset_instance()
        Client._reset_instance()
DaemonZenServerConfig (LocalDaemonServiceConfig)
Local daemon Zen server configuration.
Attributes:

| Name | Type | Description |
|---|---|---|
| server | DaemonServerDeploymentConfig | The deployment configuration. |
Source code in zenml/zen_server/deploy/daemon/daemon_zen_server.py
class DaemonZenServerConfig(LocalDaemonServiceConfig):
    """Service configuration of the local daemon ZenML server.

    Attributes:
        server: The server deployment configuration that this service
            configuration was built from.
    """

    server: DaemonServerDeploymentConfig
deployer
ZenML server deployer singleton implementation.
LocalServerDeployer
Local server deployer singleton.
This class is responsible for managing the various server provider implementations and for directing server deployment lifecycle requests to the responsible provider. It acts as a facade built on top of the various server providers.
Source code in zenml/zen_server/deploy/deployer.py
class LocalServerDeployer(metaclass=SingletonMetaClass):
"""Local server deployer singleton.
This class is responsible for managing the various server provider
implementations and for directing server deployment lifecycle requests to
the responsible provider. It acts as a facade built on top of the various
server providers.
"""
_providers: ClassVar[Dict[ServerProviderType, BaseServerProvider]] = {}
@classmethod
def register_provider(cls, provider: Type[BaseServerProvider]) -> None:
    """Instantiate a server provider class and register it by its type.

    Args:
        provider: The server provider class to register.

    Raises:
        TypeError: If a provider with the same type is already registered.
    """
    provider_type = provider.TYPE
    if provider_type in cls._providers:
        raise TypeError(
            f"Server provider '{provider_type}' is already registered."
        )

    logger.debug(f"Registering server provider '{provider_type}'.")
    cls._providers[provider_type] = provider()
@classmethod
def get_provider(
    cls, provider_type: ServerProviderType
) -> BaseServerProvider:
    """Look up the registered server provider for a provider type.

    Args:
        provider_type: The server provider type.

    Returns:
        The provider instance registered for ``provider_type``.

    Raises:
        ServerProviderNotFoundError: If no provider is registered for the
            given provider type.
    """
    registered = cls._providers.get(provider_type)
    if registered is None:
        raise ServerProviderNotFoundError(
            f"Server provider '{provider_type}' is not registered."
        )
    return registered
def initialize_local_database(self) -> None:
    """Initialize the local ZenML database.

    Creates a store from the default store configuration returned by the
    global configuration.
    """
    store_config = GlobalConfiguration().get_default_store()
    BaseZenStore.create_store(store_config)
def deploy_server(
    self,
    config: LocalServerDeploymentConfig,
    timeout: Optional[int] = None,
    restart: bool = False,
) -> LocalServerDeployment:
    """Deploy the local ZenML server, or update it if it already exists.

    Args:
        config: The server deployment configuration.
        timeout: Number of seconds to wait for the deployment to succeed.
            When omitted, the provider's default timeout is used.
        restart: If True and a server is already deployed, it is torn down
            and a new server is deployed.

    Returns:
        The local server deployment.
    """
    # The local database must exist before any local server is deployed or
    # updated.
    self.initialize_local_database()

    try:
        self.get_server()
    except ServerDeploymentNotFoundError:
        already_deployed = False
    else:
        already_deployed = True

    if already_deployed:
        return self.update_server(
            config=config, timeout=timeout, restart=restart
        )

    selected_provider = self.get_provider(config.provider)
    logger.info(f"Deploying a local {config.provider.value} ZenML server.")
    return selected_provider.deploy_server(config, timeout=timeout)
def update_server(
    self,
    config: LocalServerDeploymentConfig,
    timeout: Optional[int] = None,
    restart: bool = False,
) -> LocalServerDeployment:
    """Update an existing local ZenML server deployment.

    If the requested provider differs from the provider of the existing
    deployment, or if `restart` is set, the existing deployment is torn
    down by its own provider and a new one is deployed with the requested
    provider. Otherwise the requested provider updates the deployment.

    Args:
        config: The new server deployment configuration.
        timeout: The timeout in seconds to wait until the deployment is
            successful. If not supplied, the provider-specific default
            timeout value is used.
        restart: If True, the existing server deployment will be torn down
            and a new server will be deployed.

    Returns:
        The updated server deployment.

    Raises:
        ServerDeploymentNotFoundError: If there is no existing local server
            deployment to update.
        ServerProviderNotFoundError: If the requested or existing provider
            type is not registered.
    """
    # get_server() raises ServerDeploymentNotFoundError if no local server
    # deployment exists; that error is intentionally propagated.
    existing_server = self.get_server()
    provider = self.get_provider(config.provider)

    if existing_server.config.provider != config.provider or restart:
        # The old deployment must be removed by the provider that created
        # it, which may differ from the newly requested provider.
        existing_provider = self.get_provider(
            existing_server.config.provider
        )
        existing_provider.remove_server(
            existing_server.config, timeout=timeout
        )
        # Deploy a new server with the requested provider.
        return provider.deploy_server(config, timeout=timeout)

    return provider.update_server(config, timeout=timeout)
def remove_server(
    self,
    timeout: Optional[int] = None,
) -> None:
    """Tear down the local ZenML server deployment and remove its resources.

    All resources and files associated with the local ZenML server
    deployment are removed. If the client is currently connected to that
    server, it is disconnected first; a failure to disconnect is only logged
    as a warning. Calling this when no local server is deployed is a no-op.

    Args:
        timeout: The timeout in seconds to wait until the deployment is
            successfully torn down. If not supplied, a provider specific
            default timeout value is used.
    """
    # A missing deployment is not an error here: there is simply nothing
    # to tear down, so ServerDeploymentNotFoundError is swallowed.
    try:
        server = self.get_server()
    except ServerDeploymentNotFoundError:
        return

    provider_name = server.config.provider.value
    provider = self.get_provider(server.config.provider)

    if self.is_connected_to_server():
        # Disconnecting is best-effort; the tear-down proceeds regardless.
        try:
            self.disconnect_from_server()
        except Exception as e:
            logger.warning(
                f"Failed to disconnect from the local server: {e}"
            )

    logger.info(f"Tearing down the local {provider_name} ZenML server.")
    provider.remove_server(server.config, timeout=timeout)
def is_connected_to_server(self) -> bool:
    """Check if the ZenML client is currently connected to the local ZenML server.

    The client counts as connected when the local server deployment reports
    a URL and the global store configuration points at that same URL.

    Returns:
        True if the ZenML client is connected to the local ZenML server, False
        otherwise (including when no local server deployment exists).
    """
    # A missing deployment means there is nothing to be connected to, so
    # ServerDeploymentNotFoundError is mapped to False instead of raised.
    try:
        server = self.get_server()
    except ServerDeploymentNotFoundError:
        return False

    gc = GlobalConfiguration()
    return (
        server.status is not None
        and server.status.url is not None
        and gc.store_configuration.url == server.status.url
    )
def connect_to_server(
    self,
) -> None:
    """Point the client's global store configuration at the local server.

    Does nothing beyond logging if the client is already connected.

    Raises:
        ServerDeploymentError: If the local ZenML server is not running or
            is unreachable.
    """
    # get_server() raises ServerDeploymentNotFoundError if no local server
    # deployment exists.
    server = self.get_server()
    provider_name = server.config.provider.value
    gc = GlobalConfiguration()

    status = server.status
    if not (status and status.url):
        raise ServerDeploymentError(
            f"The local {provider_name} ZenML server is not currently "
            "running or is unreachable."
        )

    rest_store_cfg = RestZenStoreConfiguration(
        url=status.url,
    )
    already_connected = gc.store_configuration == rest_store_cfg
    if already_connected:
        logger.info(
            "Your client is already connected to the local "
            f"{provider_name} ZenML server."
        )
        return

    logger.info(
        f"Connecting to the local {provider_name} ZenML server "
        f"({rest_store_cfg.url})."
    )
    gc.set_store(rest_store_cfg)
    logger.info(
        f"Connected to the local {provider_name} ZenML server "
        f"({rest_store_cfg.url})."
    )
def disconnect_from_server(
    self,
) -> None:
    """Disconnect from the ZenML server instance.

    If the global configuration currently uses a REST store, it is reset to
    the default store; otherwise only an informational message is logged.
    """
    gc = GlobalConfiguration()
    current_store = gc.store_configuration

    if current_store.type == StoreType.REST:
        logger.info(
            f"Disconnecting from the local ({current_store.url}) ZenML server."
        )
        gc.set_default_store()
        logger.info("Disconnected from the local ZenML server.")
    else:
        logger.info(
            "Your client is not currently connected to a ZenML server."
        )
def get_server(
    self,
) -> LocalServerDeployment:
    """Return the first local server deployment found among the providers.

    Registered providers are queried in registry order.

    Returns:
        The local server deployment.

    Raises:
        ServerDeploymentNotFoundError: If no registered provider has a local
            server deployment.
    """
    for candidate in self._providers.values():
        try:
            return candidate.get_server(
                LocalServerDeploymentConfig(provider=candidate.TYPE)
            )
        except ServerDeploymentNotFoundError:
            # This provider has no deployment; try the next one.
            continue
    raise ServerDeploymentNotFoundError(
        "No local server deployment was found."
    )
def get_server_logs(
    self,
    follow: bool = False,
    tail: Optional[int] = None,
) -> Generator[str, bool, None]:
    """Retrieve the logs for the local ZenML server.

    Args:
        follow: if True, the logs will be streamed as they are written
        tail: only retrieve the last NUM lines of log output.

    Returns:
        A generator that can be accessed to get the service logs.
    """
    # get_server() raises ServerDeploymentNotFoundError if no local server
    # deployment exists.
    deployment = self.get_server()
    deployment_cfg = deployment.config
    provider_name = deployment_cfg.provider.value
    log_provider = self.get_provider(deployment_cfg.provider)
    logger.info(
        f"Fetching logs from the local {provider_name} ZenML server..."
    )
    return log_provider.get_server_logs(
        deployment_cfg, follow=follow, tail=tail
    )
connect_to_server(self)
Connect to the local ZenML server instance.
Exceptions:
Type | Description |
---|---|
ServerDeploymentError |
If the local ZenML server is not running or is unreachable. |
Source code in zenml/zen_server/deploy/deployer.py
def connect_to_server(
    self,
) -> None:
    """Point the client's global store configuration at the local server.

    Does nothing beyond logging if the client is already connected.

    Raises:
        ServerDeploymentError: If the local ZenML server is not running or
            is unreachable.
    """
    # get_server() raises ServerDeploymentNotFoundError if no local server
    # deployment exists.
    server = self.get_server()
    provider_name = server.config.provider.value
    gc = GlobalConfiguration()

    status = server.status
    if not (status and status.url):
        raise ServerDeploymentError(
            f"The local {provider_name} ZenML server is not currently "
            "running or is unreachable."
        )

    rest_store_cfg = RestZenStoreConfiguration(
        url=status.url,
    )
    already_connected = gc.store_configuration == rest_store_cfg
    if already_connected:
        logger.info(
            "Your client is already connected to the local "
            f"{provider_name} ZenML server."
        )
        return

    logger.info(
        f"Connecting to the local {provider_name} ZenML server "
        f"({rest_store_cfg.url})."
    )
    gc.set_store(rest_store_cfg)
    logger.info(
        f"Connected to the local {provider_name} ZenML server "
        f"({rest_store_cfg.url})."
    )
deploy_server(self, config, timeout=None, restart=False)
Deploy the local ZenML server or update the existing deployment.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
config |
LocalServerDeploymentConfig |
The server deployment configuration. |
required |
timeout |
Optional[int] |
The timeout in seconds to wait until the deployment is successful. If not supplied, the default timeout value specified by the provider is used. |
None |
restart |
bool |
If True, the existing server deployment will be torn down and a new server will be deployed. |
False |
Returns:
Type | Description |
---|---|
LocalServerDeployment |
The local server deployment. |
Source code in zenml/zen_server/deploy/deployer.py
def deploy_server(
    self,
    config: LocalServerDeploymentConfig,
    timeout: Optional[int] = None,
    restart: bool = False,
) -> LocalServerDeployment:
    """Deploy the local ZenML server, or update it if it already exists.

    Args:
        config: The server deployment configuration.
        timeout: The timeout in seconds to wait until the deployment is
            successful. If not supplied, the default timeout value specified
            by the provider is used.
        restart: If True, the existing server deployment will be torn down
            and a new server will be deployed.

    Returns:
        The local server deployment.
    """
    # The local database has to exist before any local server is deployed
    # or updated, so it is always initialized first.
    self.initialize_local_database()

    already_deployed = True
    try:
        self.get_server()
    except ServerDeploymentNotFoundError:
        already_deployed = False

    if already_deployed:
        # An existing deployment is delegated to the update path.
        return self.update_server(
            config=config, timeout=timeout, restart=restart
        )

    provider_name = config.provider.value
    target_provider = self.get_provider(config.provider)
    logger.info(f"Deploying a local {provider_name} ZenML server.")
    return target_provider.deploy_server(config, timeout=timeout)
disconnect_from_server(self)
Disconnect from the ZenML server instance.
Source code in zenml/zen_server/deploy/deployer.py
def disconnect_from_server(
    self,
) -> None:
    """Disconnect from the ZenML server instance.

    If the global configuration currently uses a REST store, it is reset to
    the default store; otherwise only an informational message is logged.
    """
    gc = GlobalConfiguration()
    current_store = gc.store_configuration

    if current_store.type == StoreType.REST:
        logger.info(
            f"Disconnecting from the local ({current_store.url}) ZenML server."
        )
        gc.set_default_store()
        logger.info("Disconnected from the local ZenML server.")
    else:
        logger.info(
            "Your client is not currently connected to a ZenML server."
        )
get_provider(provider_type)
classmethod
Get the server provider associated with a provider type.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
provider_type |
ServerProviderType |
The server provider type. |
required |
Returns:
Type | Description |
---|---|
BaseServerProvider |
The server provider associated with the provider type. |
Exceptions:
Type | Description |
---|---|
ServerProviderNotFoundError |
If no provider is registered for the given provider type. |
Source code in zenml/zen_server/deploy/deployer.py
@classmethod
def get_provider(
    cls, provider_type: ServerProviderType
) -> BaseServerProvider:
    """Look up the registered server provider for a provider type.

    Args:
        provider_type: The type of server provider to look up.

    Returns:
        The provider instance registered under `provider_type`.

    Raises:
        ServerProviderNotFoundError: If nothing has been registered under
            `provider_type`.
    """
    # Happy path first: return the registered instance directly.
    if provider_type in cls._providers:
        return cls._providers[provider_type]
    raise ServerProviderNotFoundError(
        f"Server provider '{provider_type}' is not registered."
    )
get_server(self)
Get the local server deployment.
Returns:
Type | Description |
---|---|
LocalServerDeployment |
The local server deployment. |
Exceptions:
Type | Description |
---|---|
ServerDeploymentNotFoundError |
If no local server deployment is found. |
Source code in zenml/zen_server/deploy/deployer.py
def get_server(
    self,
) -> LocalServerDeployment:
    """Return the first local server deployment found among the providers.

    Registered providers are queried in registry order.

    Returns:
        The local server deployment.

    Raises:
        ServerDeploymentNotFoundError: If no registered provider has a local
            server deployment.
    """
    for candidate in self._providers.values():
        try:
            return candidate.get_server(
                LocalServerDeploymentConfig(provider=candidate.TYPE)
            )
        except ServerDeploymentNotFoundError:
            # This provider has no deployment; try the next one.
            continue
    raise ServerDeploymentNotFoundError(
        "No local server deployment was found."
    )
get_server_logs(self, follow=False, tail=None)
Retrieve the logs for the local ZenML server.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
follow |
bool |
if True, the logs will be streamed as they are written |
False |
tail |
Optional[int] |
only retrieve the last NUM lines of log output. |
None |
Returns:
Type | Description |
---|---|
Generator[str, bool, NoneType] |
A generator that can be accessed to get the service logs. |
Source code in zenml/zen_server/deploy/deployer.py
def get_server_logs(
    self,
    follow: bool = False,
    tail: Optional[int] = None,
) -> Generator[str, bool, None]:
    """Retrieve the logs for the local ZenML server.

    Args:
        follow: if True, the logs will be streamed as they are written
        tail: only retrieve the last NUM lines of log output.

    Returns:
        A generator that can be accessed to get the service logs.
    """
    # get_server() raises ServerDeploymentNotFoundError if no local server
    # deployment exists.
    deployment = self.get_server()
    deployment_cfg = deployment.config
    provider_name = deployment_cfg.provider.value
    log_provider = self.get_provider(deployment_cfg.provider)
    logger.info(
        f"Fetching logs from the local {provider_name} ZenML server..."
    )
    return log_provider.get_server_logs(
        deployment_cfg, follow=follow, tail=tail
    )
initialize_local_database(self)
Initialize the local ZenML database.
Source code in zenml/zen_server/deploy/deployer.py
def initialize_local_database(self) -> None:
    """Initialize the local ZenML database.

    Creates a store from the default store configuration held by the
    global configuration.
    """
    BaseZenStore.create_store(GlobalConfiguration().get_default_store())
is_connected_to_server(self)
Check if the ZenML client is currently connected to the local ZenML server.
Returns:
Type | Description |
---|---|
bool |
True if the ZenML client is connected to the local ZenML server, False otherwise. |
Source code in zenml/zen_server/deploy/deployer.py
def is_connected_to_server(self) -> bool:
    """Check if the ZenML client is currently connected to the local ZenML server.

    The client counts as connected when the local server deployment reports
    a URL and the global store configuration points at that same URL.

    Returns:
        True if the ZenML client is connected to the local ZenML server, False
        otherwise (including when no local server deployment exists).
    """
    # A missing deployment means there is nothing to be connected to, so
    # ServerDeploymentNotFoundError is mapped to False instead of raised.
    try:
        server = self.get_server()
    except ServerDeploymentNotFoundError:
        return False

    gc = GlobalConfiguration()
    return (
        server.status is not None
        and server.status.url is not None
        and gc.store_configuration.url == server.status.url
    )
register_provider(provider)
classmethod
Register a server provider.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
provider |
Type[zenml.zen_server.deploy.base_provider.BaseServerProvider] |
The server provider to register. |
required |
Exceptions:
Type | Description |
---|---|
TypeError |
If a provider with the same type is already registered. |
Source code in zenml/zen_server/deploy/deployer.py
@classmethod
def register_provider(cls, provider: Type[BaseServerProvider]) -> None:
    """Instantiate a server provider class and add it to the registry.

    The instance is stored under the provider class's `TYPE`.

    Args:
        provider: The server provider class to register.

    Raises:
        TypeError: If a provider with the same type is already registered.
    """
    provider_type = provider.TYPE
    if provider_type in cls._providers:
        raise TypeError(
            f"Server provider '{provider_type}' is already registered."
        )
    logger.debug(f"Registering server provider '{provider_type}'.")
    cls._providers[provider_type] = provider()
remove_server(self, timeout=None)
Tears down and removes all resources and files associated with the local ZenML server deployment.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
timeout |
Optional[int] |
The timeout in seconds to wait until the deployment is successfully torn down. If not supplied, a provider specific default timeout value is used. |
None |
Source code in zenml/zen_server/deploy/deployer.py
def remove_server(
    self,
    timeout: Optional[int] = None,
) -> None:
    """Tear down the local ZenML server deployment and remove its resources.

    All resources and files associated with the local ZenML server
    deployment are removed. If the client is currently connected to that
    server, it is disconnected first; a failure to disconnect is only logged
    as a warning. Calling this when no local server is deployed is a no-op.

    Args:
        timeout: The timeout in seconds to wait until the deployment is
            successfully torn down. If not supplied, a provider specific
            default timeout value is used.
    """
    # A missing deployment is not an error here: there is simply nothing
    # to tear down, so ServerDeploymentNotFoundError is swallowed.
    try:
        server = self.get_server()
    except ServerDeploymentNotFoundError:
        return

    provider_name = server.config.provider.value
    provider = self.get_provider(server.config.provider)

    if self.is_connected_to_server():
        # Disconnecting is best-effort; the tear-down proceeds regardless.
        try:
            self.disconnect_from_server()
        except Exception as e:
            logger.warning(
                f"Failed to disconnect from the local server: {e}"
            )

    logger.info(f"Tearing down the local {provider_name} ZenML server.")
    provider.remove_server(server.config, timeout=timeout)
update_server(self, config, timeout=None, restart=False)
Update an existing local ZenML server deployment.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
config |
LocalServerDeploymentConfig |
The new server deployment configuration. |
required |
timeout |
Optional[int] |
The timeout in seconds to wait until the deployment is successful. If not supplied, the provider-specific default timeout value is used. |
None |
restart |
bool |
If True, the existing server deployment will be torn down and a new server will be deployed. |
False |
Returns:
Type | Description |
---|---|
LocalServerDeployment |
The updated server deployment. |
Source code in zenml/zen_server/deploy/deployer.py
def update_server(
    self,
    config: LocalServerDeploymentConfig,
    timeout: Optional[int] = None,
    restart: bool = False,
) -> LocalServerDeployment:
    """Update an existing local ZenML server deployment.

    If the requested provider differs from the provider of the existing
    deployment, or if `restart` is set, the existing deployment is torn
    down by its own provider and a new one is deployed with the requested
    provider. Otherwise the requested provider updates the deployment.

    Args:
        config: The new server deployment configuration.
        timeout: The timeout in seconds to wait until the deployment is
            successful. If not supplied, the provider-specific default
            timeout value is used.
        restart: If True, the existing server deployment will be torn down
            and a new server will be deployed.

    Returns:
        The updated server deployment.

    Raises:
        ServerDeploymentNotFoundError: If there is no existing local server
            deployment to update.
        ServerProviderNotFoundError: If the requested or existing provider
            type is not registered.
    """
    # get_server() raises ServerDeploymentNotFoundError if no local server
    # deployment exists; that error is intentionally propagated.
    existing_server = self.get_server()
    provider = self.get_provider(config.provider)

    if existing_server.config.provider != config.provider or restart:
        # The old deployment must be removed by the provider that created
        # it, which may differ from the newly requested provider.
        existing_provider = self.get_provider(
            existing_server.config.provider
        )
        existing_provider.remove_server(
            existing_server.config, timeout=timeout
        )
        # Deploy a new server with the requested provider.
        return provider.deploy_server(config, timeout=timeout)

    return provider.update_server(config, timeout=timeout)
deployment
Zen Server deployment definitions.
LocalServerDeployment (BaseModel)
Server deployment.
Attributes:
Name | Type | Description |
---|---|---|
config |
LocalServerDeploymentConfig |
The server deployment configuration. |
status |
Optional[zenml.zen_server.deploy.deployment.LocalServerDeploymentStatus] |
The server deployment status. |
Source code in zenml/zen_server/deploy/deployment.py
class LocalServerDeployment(BaseModel):
    """A local server deployment: its configuration plus its last status.

    Attributes:
        config: The server deployment configuration.
        status: The server deployment status.
    """

    config: LocalServerDeploymentConfig
    status: Optional[LocalServerDeploymentStatus] = None

    @property
    def is_running(self) -> bool:
        """Report whether the deployment's status is ACTIVE.

        Returns:
            False when no status is known, otherwise whether the status is
            `ServiceState.ACTIVE`.
        """
        if self.status is None:
            return False
        return self.status.status == ServiceState.ACTIVE
is_running: bool
property
readonly
Check if the server is running.
Returns:
Type | Description |
---|---|
bool |
Whether the server is running. |
LocalServerDeploymentConfig (BaseModel)
Generic local server deployment configuration.
All local server deployment configurations should inherit from this class and handle extra attributes as provider specific attributes.
Attributes:
Name | Type | Description |
---|---|---|
provider |
ServerProviderType |
The server provider type. |
Source code in zenml/zen_server/deploy/deployment.py
class LocalServerDeploymentConfig(BaseModel):
    """Base configuration shared by all local server deployments.

    Provider-specific deployment configurations are expected to subclass
    this model; any extra attributes they carry are treated as provider
    specific.

    Attributes:
        provider: The server provider type.
    """

    model_config = ConfigDict(
        # Assignments are validated so that the model can combine mutable
        # and immutable attributes.
        validate_assignment=True,
        # Unknown attributes are accepted here; validating them is left to
        # the concrete subclasses.
        extra="allow",
    )

    provider: ServerProviderType

    @property
    def url(self) -> Optional[str]:
        """Return the configured server URL.

        The generic configuration has no URL, so this always returns None;
        subclasses may override it.

        Returns:
            The configured server URL.
        """
        return None
url: Optional[str]
property
readonly
Get the configured server URL.
Returns:
Type | Description |
---|---|
Optional[str] |
The configured server URL. |
LocalServerDeploymentStatus (BaseModel)
Local server deployment status.
Ideally this should convey the following information:
- whether the server's deployment is managed by this client (i.e. if the server was deployed with `zenml login --local`)
- for a managed deployment, the status of the deployment/tear-down, e.g. not deployed, deploying, running, deleting, deployment timeout/error, tear-down timeout/error etc.
- for an unmanaged deployment, the operational status (i.e. whether the server is reachable)
- the URL of the server
Attributes:
Name | Type | Description |
---|---|---|
status |
ServiceState |
The status of the server deployment. |
status_message |
Optional[str] |
A message describing the last status. |
connected |
bool |
Whether the client is currently connected to this server. |
url |
Optional[str] |
The URL of the server. |
Source code in zenml/zen_server/deploy/deployment.py
class LocalServerDeploymentStatus(BaseModel):
    """Local server deployment status.

    Ideally this should convey the following information:

    * whether the server's deployment is managed by this client (i.e. if
      the server was deployed with `zenml login --local`)
    * for a managed deployment, the status of the deployment/tear-down, e.g.
      not deployed, deploying, running, deleting, deployment timeout/error,
      tear-down timeout/error etc.
    * for an unmanaged deployment, the operational status (i.e. whether the
      server is reachable)
    * the URL of the server

    Attributes:
        status: The status of the server deployment.
        status_message: A message describing the last status.
        connected: Whether the client is currently connected to this server.
        url: The URL of the server.
        ca_crt: Optional CA certificate associated with the server.
            NOTE(review): presumably used to verify the server's TLS
            certificate; confirm against its producers/consumers.
    """

    status: ServiceState
    status_message: Optional[str] = None
    connected: bool
    url: Optional[str] = None
    ca_crt: Optional[str] = None
docker
special
ZenML Server Docker Deployment.
docker_provider
Zen Server docker deployer implementation.
DockerServerProvider (BaseServerProvider)
Docker ZenML server provider.
Source code in zenml/zen_server/deploy/docker/docker_provider.py
class DockerServerProvider(BaseServerProvider):
    """Docker ZenML server provider.

    Deploys the local ZenML server as a `DockerZenServer` container service.
    """

    TYPE: ClassVar[ServerProviderType] = ServerProviderType.DOCKER
    CONFIG_TYPE: ClassVar[Type[LocalServerDeploymentConfig]] = (
        DockerServerDeploymentConfig
    )

    @classmethod
    def _get_service_configuration(
        cls,
        server_config: LocalServerDeploymentConfig,
    ) -> Tuple[
        ServiceConfig,
        ServiceEndpointConfig,
        ServiceEndpointHealthMonitorConfig,
    ]:
        """Construct the service configuration from a server deployment configuration.

        Args:
            server_config: server deployment configuration.

        Returns:
            The service, service endpoint and endpoint monitor configuration.
        """
        assert isinstance(server_config, DockerServerDeploymentConfig)
        return (
            DockerZenServerConfig(
                root_runtime_path=DockerZenServer.config_path(),
                singleton=True,
                image=server_config.image,
                name=ServerProviderType.DOCKER.value,
                server=server_config,
            ),
            ContainerServiceEndpointConfig(
                protocol=ServiceEndpointProtocol.HTTP,
                port=server_config.port,
                # The port comes from the deployment config and is not
                # auto-allocated.
                allocate_port=False,
            ),
            HTTPEndpointHealthMonitorConfig(
                healthcheck_uri_path=ZEN_SERVER_HEALTHCHECK_URL_PATH,
                use_head_request=True,
            ),
        )

    def _create_service(
        self,
        config: LocalServerDeploymentConfig,
        timeout: Optional[int] = None,
    ) -> BaseService:
        """Create, start and return the docker ZenML server deployment service.

        Args:
            config: The server deployment configuration.
            timeout: The timeout in seconds to wait until the service is
                running. Defaults to DOCKER_ZENML_SERVER_DEFAULT_TIMEOUT.

        Returns:
            The service instance.

        Raises:
            RuntimeError: If a docker service is already running.
        """
        assert isinstance(config, DockerServerDeploymentConfig)

        if timeout is None:
            timeout = DOCKER_ZENML_SERVER_DEFAULT_TIMEOUT

        # Only one docker server deployment may exist at a time: a persisted
        # service definition means one has already been deployed. The
        # lookup is done once (it reads the service file from disk).
        existing_service = DockerZenServer.get_service()
        if existing_service:
            raise RuntimeError(
                f"A docker ZenML server with name '{existing_service.config.name}' "
                f"is already running. Please stop it first before starting a "
                f"new one."
            )

        (
            service_config,
            endpoint_cfg,
            monitor_cfg,
        ) = self._get_service_configuration(config)

        endpoint = ContainerServiceEndpoint(
            config=endpoint_cfg,
            monitor=HTTPEndpointHealthMonitor(
                config=monitor_cfg,
            ),
        )
        service = DockerZenServer(
            uuid=uuid4(), config=service_config, endpoint=endpoint
        )
        service.start(timeout=timeout)
        return service

    def _update_service(
        self,
        service: BaseService,
        config: LocalServerDeploymentConfig,
        timeout: Optional[int] = None,
    ) -> BaseService:
        """Update the docker ZenML server deployment service.

        The service is stopped, its service/endpoint/monitor configurations
        are replaced, and it is started again.

        Args:
            service: The service instance.
            config: The new server deployment configuration.
            timeout: The timeout in seconds to wait until the updated service is
                running. Defaults to DOCKER_ZENML_SERVER_DEFAULT_TIMEOUT.

        Returns:
            The updated service instance.
        """
        if timeout is None:
            timeout = DOCKER_ZENML_SERVER_DEFAULT_TIMEOUT

        (
            new_config,
            new_endpoint_cfg,
            new_monitor_cfg,
        ) = self._get_service_configuration(config)

        assert service.endpoint
        assert service.endpoint.monitor
        service.stop(timeout=timeout)
        (
            service.config,
            service.endpoint.config,
            service.endpoint.monitor.config,
        ) = (
            new_config,
            new_endpoint_cfg,
            new_monitor_cfg,
        )
        service.start(timeout=timeout)

        return service

    def _start_service(
        self,
        service: BaseService,
        timeout: Optional[int] = None,
    ) -> BaseService:
        """Start the docker ZenML server deployment service.

        Args:
            service: The service instance.
            timeout: The timeout in seconds to wait until the service is
                running. Defaults to DOCKER_ZENML_SERVER_DEFAULT_TIMEOUT.

        Returns:
            The updated service instance.
        """
        if timeout is None:
            timeout = DOCKER_ZENML_SERVER_DEFAULT_TIMEOUT

        service.start(timeout=timeout)
        return service

    def _stop_service(
        self,
        service: BaseService,
        timeout: Optional[int] = None,
    ) -> BaseService:
        """Stop the docker ZenML server deployment service.

        Args:
            service: The service instance.
            timeout: The timeout in seconds to wait until the service is
                stopped. Defaults to DOCKER_ZENML_SERVER_DEFAULT_TIMEOUT.

        Returns:
            The updated service instance.
        """
        if timeout is None:
            timeout = DOCKER_ZENML_SERVER_DEFAULT_TIMEOUT

        service.stop(timeout=timeout)
        return service

    def _delete_service(
        self,
        service: BaseService,
        timeout: Optional[int] = None,
    ) -> None:
        """Remove the docker ZenML server deployment service.

        The service is stopped and the docker server runtime directory is
        deleted.

        Args:
            service: The service instance.
            timeout: The timeout in seconds to wait until the service is
                removed. Defaults to DOCKER_ZENML_SERVER_DEFAULT_TIMEOUT.
        """
        assert isinstance(service, DockerZenServer)

        if timeout is None:
            timeout = DOCKER_ZENML_SERVER_DEFAULT_TIMEOUT

        service.stop(timeout=timeout)
        shutil.rmtree(DockerZenServer.config_path())

    def _get_service(self) -> BaseService:
        """Get the docker ZenML server deployment service.

        Returns:
            The service instance.

        Raises:
            KeyError: If the server deployment is not found.
        """
        service = DockerZenServer.get_service()
        if service is None:
            raise KeyError("The docker ZenML server is not deployed.")

        return service

    def _get_deployment_config(
        self, service: BaseService
    ) -> LocalServerDeploymentConfig:
        """Recreate the server deployment configuration from a service instance.

        Args:
            service: The service instance.

        Returns:
            The server deployment configuration.
        """
        server = cast(DockerZenServer, service)
        return server.config.server
CONFIG_TYPE (LocalServerDeploymentConfig)
Docker server deployment configuration.
Attributes:
Name | Type | Description |
---|---|---|
port |
int |
The TCP port number where the server is accepting connections. |
image |
str |
The Docker image to use for the server. |
Source code in zenml/zen_server/deploy/docker/docker_provider.py
class DockerServerDeploymentConfig(LocalServerDeploymentConfig):
    """Docker server deployment configuration.

    Attributes:
        port: The TCP port number where the server is accepting connections.
        image: The Docker image to use for the server.
        store: Optional store configuration for the server deployment.
            NOTE(review): how the docker server consumes this is not
            visible here; confirm before relying on it.
    """

    port: int = 8238
    image: str = DOCKER_ZENML_SERVER_DEFAULT_IMAGE
    store: Optional[StoreConfiguration] = None

    @property
    def url(self) -> Optional[str]:
        """Get the configured server URL.

        Returns:
            The HTTP URL built from the local service IP address and the
            configured port.
        """
        return f"http://{DEFAULT_LOCAL_SERVICE_IP_ADDRESS}:{self.port}"

    # Unlike the permissive base class, unknown attributes are rejected.
    model_config = ConfigDict(extra="forbid")
url: Optional[str]
property
readonly
Get the configured server URL.
Returns:
Type | Description |
---|---|
Optional[str] |
The configured server URL. |
docker_zen_server
Service implementation for the ZenML docker server deployment.
DockerServerDeploymentConfig (LocalServerDeploymentConfig)
Docker server deployment configuration.
Attributes:
Name | Type | Description |
---|---|---|
port |
int |
The TCP port number where the server is accepting connections. |
image |
str |
The Docker image to use for the server. |
Source code in zenml/zen_server/deploy/docker/docker_zen_server.py
class DockerServerDeploymentConfig(LocalServerDeploymentConfig):
    """Docker server deployment configuration.

    Attributes:
        port: The TCP port number where the server is accepting connections.
        image: The Docker image to use for the server.
        store: Optional store configuration for the server deployment.
            NOTE(review): how the docker server consumes this is not
            visible here; confirm before relying on it.
    """

    port: int = 8238
    image: str = DOCKER_ZENML_SERVER_DEFAULT_IMAGE
    store: Optional[StoreConfiguration] = None

    @property
    def url(self) -> Optional[str]:
        """Get the configured server URL.

        Returns:
            The HTTP URL built from the local service IP address and the
            configured port.
        """
        return f"http://{DEFAULT_LOCAL_SERVICE_IP_ADDRESS}:{self.port}"

    # Unlike the permissive base class, unknown attributes are rejected.
    model_config = ConfigDict(extra="forbid")
url: Optional[str]
property
readonly
Get the configured server URL.
Returns:
Type | Description |
---|---|
Optional[str] |
The configured server URL. |
DockerZenServer (ContainerService)
Service that can be used to start a docker ZenServer.
Attributes:
Name | Type | Description |
---|---|---|
config |
DockerZenServerConfig |
service configuration |
endpoint |
ContainerServiceEndpoint |
service endpoint |
Source code in zenml/zen_server/deploy/docker/docker_zen_server.py
class DockerZenServer(ContainerService):
    """Service that can be used to start a docker ZenServer.

    Attributes:
        config: service configuration
        endpoint: service endpoint
    """

    SERVICE_TYPE = ServiceType(
        name="docker_zenml_server",
        type="zen_server",
        flavor="docker",
        description="Docker ZenML server deployment",
    )

    config: DockerZenServerConfig
    endpoint: ContainerServiceEndpoint

    @classmethod
    def config_path(cls) -> str:
        """Path to the directory where the docker ZenML server files are located.

        Returns:
            Path to the docker ZenML server runtime directory.
        """
        return os.path.join(
            get_global_config_directory(),
            "zen_server",
            "docker",
        )

    @property
    def _global_config_path(self) -> str:
        """Path to the global configuration directory used by this server.

        Returns:
            Path to the global configuration directory used by this server.
        """
        return os.path.join(
            self.config_path(), SERVICE_CONTAINER_GLOBAL_CONFIG_DIR
        )

    @classmethod
    def get_service(cls) -> Optional["DockerZenServer"]:
        """Load and return the docker ZenML server service, if present.

        The service is deserialized from `service.json` in `config_path()`.

        Returns:
            The docker ZenML server service or None, if the docker server
            deployment is not found.
        """
        config_filename = os.path.join(cls.config_path(), "service.json")
        try:
            with open(config_filename, "r") as f:
                return cast(
                    "DockerZenServer", DockerZenServer.from_json(f.read())
                )
        except FileNotFoundError:
            return None

    def _get_container_cmd(self) -> Tuple[List[str], Dict[str, str]]:
        """Get the command to run the service container.

        Override the inherited method to use a ZenML global config path inside
        the container that points to the global config copy instead of the
        one mounted from the local host.

        Returns:
            Command needed to launch the docker container and the environment
            variables to set, in the formats accepted by subprocess.Popen.
        """
        gc = GlobalConfiguration()
        cmd, env = super()._get_container_cmd()
        env[ENV_ZENML_SERVER] = "true"
        env[ENV_ZENML_CONFIG_PATH] = os.path.join(
            SERVICE_CONTAINER_PATH,
            SERVICE_CONTAINER_GLOBAL_CONFIG_DIR,
        )
        env[ENV_ZENML_SERVER_AUTH_SCHEME] = AuthScheme.NO_AUTH.value
        # Environment values must be plain strings, so the enum's value is
        # used here (consistent with the auth scheme above).
        env[ENV_ZENML_SERVER_DEPLOYMENT_TYPE] = (
            ServerDeploymentType.DOCKER.value
        )
        env[ENV_ZENML_ANALYTICS_OPT_IN] = str(gc.analytics_opt_in)
        env[ENV_ZENML_USER_ID] = str(gc.user_id)
        # Set the local stores path to the same path used by the client (mounted
        # in the container by the super class). This ensures that the server's
        # default store configuration is initialized to point at the same local
        # SQLite database as the client.
        env[ENV_ZENML_LOCAL_STORES_PATH] = os.path.join(
            SERVICE_CONTAINER_GLOBAL_CONFIG_PATH,
            LOCAL_STORES_DIRECTORY_NAME,
        )
        # NOTE(review): presumably migrations are disabled because the
        # client initializes the shared local database itself - confirm.
        env[ENV_ZENML_DISABLE_DATABASE_MIGRATION] = "True"
        env[ENV_ZENML_SERVER_AUTO_ACTIVATE] = "True"

        return cmd, env

    def provision(self) -> None:
        """Provision the service."""
        super().provision()

    def run(self) -> None:
        """Run the ZenML Server.

        Blocks in `uvicorn.run` until the server is stopped.

        Raises:
            ValueError: if started with a global configuration that connects to
                another ZenML server.
        """
        import uvicorn

        gc = GlobalConfiguration()
        if gc.store_configuration.type == StoreType.REST:
            raise ValueError(
                "The ZenML server cannot be started with REST store type."
            )
        logger.info(
            "Starting ZenML Server as blocking "
            "process... press CTRL+C once to stop it."
        )

        self.endpoint.prepare_for_start()

        try:
            uvicorn.run(
                ZEN_SERVER_ENTRYPOINT,
                host="0.0.0.0",  # nosec
                port=self.endpoint.config.port or 8000,
                log_level="info",
                server_header=False,
            )
        except KeyboardInterrupt:
            logger.info("ZenML Server stopped. Resuming normal execution.")
config_path()
classmethod
Path to the directory where the docker ZenML server files are located.
Returns:
Type | Description |
---|---|
str |
Path to the docker ZenML server runtime directory. |
Source code in zenml/zen_server/deploy/docker/docker_zen_server.py
@classmethod
def config_path(cls) -> str:
    """Return the runtime directory used by the Docker ZenML server.

    Returns:
        Path to the docker ZenML server runtime directory, located under the
        global configuration directory.
    """
    relative_parts = ("zen_server", "docker")
    return os.path.join(get_global_config_directory(), *relative_parts)
get_service()
classmethod
Load and return the docker ZenML server service, if present.
Returns:
Type | Description |
---|---|
Optional[DockerZenServer] |
The docker ZenML server service or None, if the docker server deployment is not found. |
Source code in zenml/zen_server/deploy/docker/docker_zen_server.py
@classmethod
def get_service(cls) -> Optional["DockerZenServer"]:
    """Load the Docker ZenML server service from its `service.json` file.

    Returns:
        The docker ZenML server service or None, if the docker server
        deployment is not found (i.e. the service file does not exist).
    """
    service_file = os.path.join(cls.config_path(), "service.json")
    try:
        with open(service_file, "r") as handle:
            serialized = handle.read()
            return cast(
                "DockerZenServer", DockerZenServer.from_json(serialized)
            )
    except FileNotFoundError:
        return None
model_post_init(/, self, context)
We need to both initialize private attributes and call the user-defined model_post_init method.
Source code in zenml/zen_server/deploy/docker/docker_zen_server.py
def wrapped_model_post_init(self: BaseModel, context: Any, /) -> None:
    """Initialize private attributes, then run the user-defined hook.

    Private attributes are initialized first; the user-defined
    `model_post_init` runs afterwards with the same context.
    """
    for hook in (init_private_attributes, original_model_post_init):
        hook(self, context)
provision(self)
Provision the service.
Source code in zenml/zen_server/deploy/docker/docker_zen_server.py
def provision(self) -> None:
    """Provision the service.

    No Docker-server-specific provisioning happens here; the work is fully
    delegated to the parent class implementation.
    """
    super().provision()
run(self)
Run the ZenML Server.
Exceptions:
Type | Description |
---|---|
ValueError |
if started with a global configuration that connects to another ZenML server. |
Source code in zenml/zen_server/deploy/docker/docker_zen_server.py
def run(self) -> None:
    """Run the ZenML Server in the foreground until it is interrupted.

    Raises:
        ValueError: if started with a global configuration that connects to
            another ZenML server.
    """
    import uvicorn

    # A REST store type means the global configuration points at another
    # ZenML server, which this process cannot serve from.
    store_type = GlobalConfiguration().store_configuration.type
    if store_type == StoreType.REST:
        raise ValueError(
            "The ZenML server cannot be started with REST store type."
        )

    logger.info(
        "Starting ZenML Server as blocking "
        "process... press CTRL+C once to stop it."
    )
    self.endpoint.prepare_for_start()

    # The port is read only after `prepare_for_start()` in case that call
    # updates the endpoint configuration; 8000 is the fallback.
    listen_port = self.endpoint.config.port or 8000
    try:
        uvicorn.run(
            ZEN_SERVER_ENTRYPOINT,
            host="0.0.0.0",  # nosec
            port=listen_port,
            log_level="info",
            server_header=False,
        )
    except KeyboardInterrupt:
        logger.info("ZenML Server stopped. Resuming normal execution.")
DockerZenServerConfig (ContainerServiceConfig)
Docker Zen server configuration.
Attributes:
Name | Type | Description |
---|---|---|
server |
DockerServerDeploymentConfig |
The deployment configuration. |
Source code in zenml/zen_server/deploy/docker/docker_zen_server.py
class DockerZenServerConfig(ContainerServiceConfig):
    """Service configuration for a ZenML server running in a Docker container.

    Attributes:
        server: The deployment configuration.
    """

    # Deployment settings of the Docker-hosted ZenML server.
    server: DockerServerDeploymentConfig
exceptions
ZenML server deployment exceptions.
ServerDeploymentConfigurationError (ServerDeploymentError)
Raised when there is a ZenML server deployment configuration error.
Source code in zenml/zen_server/deploy/exceptions.py
class ServerDeploymentConfigurationError(ServerDeploymentError):
    """Raised when a ZenML server deployment is configured incorrectly."""
ServerDeploymentError (ZenMLBaseException)
Base exception class for all ZenML server deployment related errors.
Source code in zenml/zen_server/deploy/exceptions.py
class ServerDeploymentError(ZenMLBaseException):
    """Common base class for every ZenML server deployment error."""
ServerDeploymentExistsError (ServerDeploymentError)
Raised when trying to deploy a new ZenML server with the same name.
Source code in zenml/zen_server/deploy/exceptions.py
class ServerDeploymentExistsError(ServerDeploymentError):
    """Raised when a ZenML server deployment with the same name already exists."""
ServerDeploymentNotFoundError (ServerDeploymentError)
Raised when trying to fetch a ZenML server deployment that doesn't exist.
Source code in zenml/zen_server/deploy/exceptions.py
class ServerDeploymentNotFoundError(ServerDeploymentError):
    """Raised when a requested ZenML server deployment does not exist."""
ServerProviderNotFoundError (ServerDeploymentError)
Raised when using a ZenML server provider that doesn't exist.
Source code in zenml/zen_server/deploy/exceptions.py
class ServerProviderNotFoundError(ServerDeploymentError):
    """Raised when the requested ZenML server provider does not exist."""
exceptions
REST API exception handling.
ErrorModel (BaseModel)
Base class for error responses.
Source code in zenml/zen_server/exceptions.py
class ErrorModel(BaseModel):
    """Base model for error response bodies."""

    # Arbitrary error payload; None when no detail is provided.
    detail: Optional[Any] = None
error_detail(error, exception_type=None)
Convert an Exception to API representation.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
error |
Exception |
Exception to convert. |
required |
exception_type |
Optional[Type[Exception]] |
Exception type to use in the error response instead of the type of the supplied exception. This is useful when the raised exception is a subclass of an exception type that is properly handled by the REST API. |
None |
Returns:
Type | Description |
---|---|
List[str] |
List of strings representing the error. |
Source code in zenml/zen_server/exceptions.py
def error_detail(
    error: Exception, exception_type: Optional[Type[Exception]] = None
) -> List[str]:
    """Serialize an exception into its REST API error representation.

    Args:
        error: Exception to convert.
        exception_type: Exception type to use in the error response instead of
            the type of the supplied exception. This is useful when the raised
            exception is a subclass of an exception type that is properly
            handled by the REST API.

    Returns:
        A two-item list: the exception class name followed by the string
        form of the exception.
    """
    reported_type = exception_type if exception_type else type(error)
    return [reported_type.__name__, str(error)]
exception_from_response(response)
Convert an error HTTP response to an exception.
Uses the REST_API_EXCEPTIONS list to determine the appropriate exception class to use based on the response status code and the exception class name embedded in the response body.
The last entry in the list of exceptions associated with a status code is used as a fallback if the exception class name in the response body is not found in the list.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
response |
Response |
HTTP error response to convert. |
required |
Returns:
Type | Description |
---|---|
Optional[Exception] |
Exception with the appropriate type and arguments, or None if the response does not contain an error or the response cannot be unpacked into an exception. |
Source code in zenml/zen_server/exceptions.py
def exception_from_response(
    response: requests.Response,
) -> Optional[Exception]:
    """Convert an error HTTP response to an exception.

    Uses the REST_API_EXCEPTIONS list to determine the appropriate exception
    class to use based on the response status code and the exception class name
    embedded in the response body.

    The last entry in the list of exceptions associated with a status code is
    used as a fallback if the exception class name in the response body is not
    found in the list.

    Args:
        response: HTTP error response to convert.

    Returns:
        Exception with the appropriate type and arguments, or None if the
        response does not contain an error or the response cannot be unpacked
        into an exception.
    """

    def unpack_exc() -> Tuple[Optional[str], str]:
        """Unpack the response body into an exception name and message.

        Returns:
            Tuple of exception name and message.
        """
        try:
            response_json = response.json()
        except requests.exceptions.JSONDecodeError:
            return None, response.text

        if isinstance(response_json, dict):
            detail = response_json.get("detail", response.text)
        else:
            detail = response_json

        # The detail can also be a single string
        if isinstance(detail, str):
            return None, detail

        # The detail should be a list of strings encoding the exception
        # class name and the exception message
        if not isinstance(detail, list):
            return None, response.text

        # First detail item is the exception class name
        if len(detail) < 1 or not isinstance(detail[0], str):
            return None, response.text

        # Remaining detail items are the exception arguments
        message = ": ".join([str(arg) for arg in detail[1:]])
        return detail[0], message

    exc_name, exc_msg = unpack_exc()
    default_exc: Optional[Type[Exception]] = None
    for exception, status_code in REST_API_EXCEPTIONS:
        if response.status_code != status_code:
            continue
        default_exc = exception
        if exc_name == exception.__name__:
            # An entry was found that is an exact match for both the status
            # code and the exception class name.
            break
    else:
        # The exception class name extracted from the response body was not
        # found in the list of exceptions associated with the status code, so
        # use the last entry as a fallback.
        if default_exc is None:
            return None
        exception = default_exc

    # There is one special case where we want to return a specific exception:
    # 401 Unauthorized exceptions thrown directly by FastAPI in the course of
    # authentication are interpreted as AuthorizationException, but we want to
    # return CredentialsNotValid instead.
    # The class is inspected with `issubclass` instead of instantiating it:
    # calling `exception()` with no arguments fails for exception types whose
    # constructor requires parameters.
    if (
        response.status_code == 401
        and not issubclass(exception, CredentialsNotValid)
        and response.headers.get("WWW-Authenticate")
    ):
        return CredentialsNotValid(exc_msg)

    return exception(exc_msg)
http_exception_from_error(error)
Convert an Exception to an HTTP error response.
Uses the REST_API_EXCEPTIONS list to determine the appropriate status code associated with the exception type. The exception class name and arguments are embedded in the HTTP error response body.
The lookup uses the first occurrence of the exception type in the list. If
the exception type is not found in the list, the lookup uses isinstance
to determine the most specific exception type corresponding to the supplied
exception. This allows users to call this method with exception types that
are not directly listed in the REST_API_EXCEPTIONS list.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
error |
Exception |
Exception to convert. |
required |
Returns:
Type | Description |
---|---|
HTTPException |
HTTPException with the appropriate status code and error detail. |
Source code in zenml/zen_server/exceptions.py
def http_exception_from_error(error: Exception) -> "HTTPException":
    """Convert an Exception to an HTTP error response.

    Uses the REST_API_EXCEPTIONS list to determine the appropriate status code
    associated with the exception type. The exception class name and arguments
    are embedded in the HTTP error response body.

    The lookup uses the first occurrence of the exception type in the list. If
    the exception type is not found in the list, the lookup uses `isinstance`
    to determine the most specific exception type corresponding to the supplied
    exception. This allows users to call this method with exception types that
    are not directly listed in the REST_API_EXCEPTIONS list.

    Args:
        error: Exception to convert.

    Returns:
        HTTPException with the appropriate status code and error detail.
    """
    from fastapi import HTTPException

    best_type: Optional[Type[Exception]] = None
    best_code = 0
    for candidate, candidate_code in REST_API_EXCEPTIONS:
        if error.__class__ is candidate:
            # Exact type match wins immediately.
            best_type, best_code = candidate, candidate_code
            break
        if not isinstance(error, candidate):
            continue
        # Keep the first compatible type, then replace it only with a more
        # specific (sub)class found later in the list.
        if best_type is None or issubclass(candidate, best_type):
            best_type, best_code = candidate, candidate_code

    # Unmapped exceptions are reported as a 500 Internal Server Error
    # carrying the RuntimeError class name.
    return HTTPException(
        status_code=best_code or 500,
        detail=error_detail(error, best_type or RuntimeError),
    )
feature_gate
special
endpoint_utils
All endpoint utils for the feature gate implementations.
check_entitlement(resource_type)
Queries the feature gate to see if the operation falls within the tenant's entitlements.
Raises an exception if the user is not entitled to create an instance of the resource. Otherwise, simply returns.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
resource_type |
ResourceType |
The type of resource to check for. |
required |
Source code in zenml/zen_server/feature_gate/endpoint_utils.py
def check_entitlement(resource_type: ResourceType) -> None:
    """Verify that the tenant is entitled to create a resource of this type.

    The check is delegated to the feature gate, which raises when the
    entitlement is missing. When the feature gate is disabled in the server
    configuration this is a no-op.

    Args:
        resource_type: The type of resource to check for.
    """
    if server_config().feature_gate_enabled:
        return feature_gate().check_entitlement(resource=resource_type)
    return None
report_decrement(resource_type, resource_id)
Reports the deletion/deactivation of a feature/resource.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
resource_type |
ResourceType |
The type of resource to report a decrement in count for. |
required |
resource_id |
UUID |
ID of the resource that was deleted. |
required |
Source code in zenml/zen_server/feature_gate/endpoint_utils.py
def report_decrement(resource_type: ResourceType, resource_id: UUID) -> None:
    """Report that a resource was deleted or deactivated.

    Nothing is reported when the feature gate is disabled in the server
    configuration.

    Args:
        resource_type: The type of resource to report a decrement in count for.
        resource_id: ID of the resource that was deleted.
    """
    if server_config().feature_gate_enabled:
        feature_gate().report_event(
            resource=resource_type,
            resource_id=resource_id,
            is_decrement=True,
        )
report_usage(resource_type, resource_id)
Reports the creation/usage of a feature/resource.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
resource_type |
ResourceType |
The type of resource to report a usage for |
required |
resource_id |
UUID |
ID of the resource that was created. |
required |
Source code in zenml/zen_server/feature_gate/endpoint_utils.py
def report_usage(resource_type: ResourceType, resource_id: UUID) -> None:
    """Report that a resource was created or a feature was used.

    Nothing is reported when the feature gate is disabled in the server
    configuration.

    Args:
        resource_type: The type of resource to report a usage for.
        resource_id: ID of the resource that was created.
    """
    if server_config().feature_gate_enabled:
        feature_gate().report_event(
            resource=resource_type,
            resource_id=resource_id,
        )
feature_gate_interface
Definition of the feature gate interface.
FeatureGateInterface (ABC)
Feature gate interface definition.
Source code in zenml/zen_server/feature_gate/feature_gate_interface.py
class FeatureGateInterface(ABC):
    """Feature gate interface definition.

    Implementations decide whether a resource may be created (entitlement
    checks) and record resource creation/deletion events.
    """

    @abstractmethod
    def check_entitlement(self, resource: ResourceType) -> None:
        """Checks if a user is entitled to create a resource.

        Args:
            resource: The resource the user wants to create.

        Raises:
            SubscriptionUpgradeRequiredError: in case a subscription limit is
                reached.
        """

    @abstractmethod
    def report_event(
        self,
        resource: ResourceType,
        resource_id: UUID,
        is_decrement: bool = False,
    ) -> None:
        """Reports the usage of a feature to the aggregator backend.

        Args:
            resource: The resource the user created.
            resource_id: ID of the resource that was created/deleted.
            is_decrement: In case this event reports an actual decrement of
                usage.
        """
check_entitlement(self, resource)
Checks if a user is entitled to create a resource.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
resource |
ResourceType |
The resource the user wants to create |
required |
Source code in zenml/zen_server/feature_gate/feature_gate_interface.py
@abstractmethod
def check_entitlement(self, resource: ResourceType) -> None:
    """Checks if a user is entitled to create a resource.

    Args:
        resource: The resource the user wants to create.

    Raises:
        SubscriptionUpgradeRequiredError: in case a subscription limit is
            reached.
    """
report_event(self, resource, resource_id, is_decrement=False)
Reports the usage of a feature to the aggregator backend.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
resource |
ResourceType |
The resource the user created |
required |
resource_id |
UUID |
ID of the resource that was created/deleted. |
required |
is_decrement |
bool |
In case this event reports an actual decrement of usage |
False |
Source code in zenml/zen_server/feature_gate/feature_gate_interface.py
@abstractmethod
def report_event(
    self,
    resource: ResourceType,
    resource_id: UUID,
    is_decrement: bool = False,
) -> None:
    """Send a feature usage event to the aggregator backend.

    Args:
        resource: The resource the user created.
        resource_id: ID of the resource that was created/deleted.
        is_decrement: Whether the event reports a decrement of usage
            (e.g. a deletion) instead of an increment.
    """
zenml_cloud_feature_gate
ZenML Pro implementation of the feature gate.
RawUsageEvent (BaseModel)
Model for reporting raw usage of a feature.
In case of consumables the UsageReport allows the Pricing Backend to increment the usage per time-frame by 1.
Source code in zenml/zen_server/feature_gate/zenml_cloud_feature_gate.py
class RawUsageEvent(BaseModel):
    """Payload used to report raw usage of a feature.

    For consumables, the usage report lets the pricing backend increment the
    usage for the current time frame by 1.
    """

    # Organization to which the usage is attributed.
    organization_id: str = Field(
        description="The organization that this usage can be attributed to.",
    )
    # Feature (resource type) whose usage is reported.
    feature: ResourceType = Field(
        description="The feature whose usage is being reported.",
    )
    # Signed amount for this event.
    total: int = Field(
        description="The total amount of entities of this type."
    )
    # Free-form extra information attached to the event.
    metadata: Dict[str, Any] = Field(
        default={},
        description="Allows attaching additional metadata to events.",
    )
ZenMLCloudFeatureGateInterface (FeatureGateInterface)
ZenML Cloud Feature Gate implementation.
Source code in zenml/zen_server/feature_gate/zenml_cloud_feature_gate.py
class ZenMLCloudFeatureGateInterface(FeatureGateInterface):
    """ZenML Cloud Feature Gate implementation."""

    def __init__(self) -> None:
        """Initialize the object with a connection to the cloud backend."""
        self._connection = cloud_connection()

    def check_entitlement(self, resource: ResourceType) -> None:
        """Checks if a user is entitled to create a resource.

        An unexpected (non-200) response is only logged as a warning; it does
        not block the operation.

        Args:
            resource: The resource the user wants to create.

        Raises:
            SubscriptionUpgradeRequiredError: in case a subscription limit is
                reached.
        """
        try:
            response = self._connection.get(
                endpoint=ENTITLEMENT_ENDPOINT + "/" + resource, params=None
            )
        except SubscriptionUpgradeRequiredError as e:
            # Re-raise with a user-facing message naming the exhausted
            # resource, keeping the original error as the cause.
            raise SubscriptionUpgradeRequiredError(
                f"Your subscription reached its `{resource}` limit. Please "
                f"upgrade your subscription or reach out to us."
            ) from e

        if response.status_code != 200:
            # Log the raw body: an unexpected response is not guaranteed to
            # be JSON, and `response.json()` would raise while logging.
            logger.warning(
                "Unexpected response status code from entitlement "
                f"endpoint: {response.status_code}. Message: "
                f"{response.text}"
            )

    def report_event(
        self,
        resource: ResourceType,
        resource_id: UUID,
        is_decrement: bool = False,
    ) -> None:
        """Reports the usage of a feature to the aggregator backend.

        A rejected report is logged as an error and not raised.

        Args:
            resource: The resource the user created.
            resource_id: ID of the resource that was created/deleted.
            is_decrement: In case this event reports an actual decrement of
                usage.
        """
        data = RawUsageEvent(
            organization_id=ORGANIZATION_ID,
            feature=resource,
            total=-1 if is_decrement else 1,
            metadata={
                # NOTE(review): `server_config` is used as an object here (not
                # called); presumably a module-level configuration instance
                # in this module - verify.
                "tenant_id": str(server_config.get_external_server_id()),
                "resource_id": str(resource_id),
            },
        ).model_dump()
        response = self._connection.post(
            endpoint=USAGE_EVENT_ENDPOINT, data=data
        )
        if response.status_code != 200:
            # Log the raw body: an error response is not guaranteed to be
            # JSON, and `response.json()` would raise while logging.
            logger.error(
                "Usage report not accepted by upstream backend. "
                f"Status Code: {response.status_code}, Message: "
                f"{response.text}."
            )
__init__(self)
special
Initialize the object.
Source code in zenml/zen_server/feature_gate/zenml_cloud_feature_gate.py
def __init__(self) -> None:
    """Set up the feature gate with a connection to the cloud backend."""
    self._connection = cloud_connection()
check_entitlement(self, resource)
Checks if a user is entitled to create a resource.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
resource |
ResourceType |
The resource the user wants to create |
required |
Exceptions:
Type | Description |
---|---|
SubscriptionUpgradeRequiredError |
in case a subscription limit is reached |
Source code in zenml/zen_server/feature_gate/zenml_cloud_feature_gate.py
def check_entitlement(self, resource: ResourceType) -> None:
    """Checks if a user is entitled to create a resource.

    An unexpected (non-200) response is only logged as a warning; it does
    not block the operation.

    Args:
        resource: The resource the user wants to create.

    Raises:
        SubscriptionUpgradeRequiredError: in case a subscription limit is
            reached.
    """
    try:
        response = self._connection.get(
            endpoint=ENTITLEMENT_ENDPOINT + "/" + resource, params=None
        )
    except SubscriptionUpgradeRequiredError as e:
        # Re-raise with a user-facing message naming the exhausted resource,
        # keeping the original error as the cause.
        raise SubscriptionUpgradeRequiredError(
            f"Your subscription reached its `{resource}` limit. Please "
            f"upgrade your subscription or reach out to us."
        ) from e

    if response.status_code != 200:
        # Log the raw body: an unexpected response is not guaranteed to be
        # JSON, and `response.json()` would raise while logging.
        logger.warning(
            "Unexpected response status code from entitlement "
            f"endpoint: {response.status_code}. Message: "
            f"{response.text}"
        )
report_event(self, resource, resource_id, is_decrement=False)
Reports the usage of a feature to the aggregator backend.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
resource |
ResourceType |
The resource the user created |
required |
resource_id |
UUID |
ID of the resource that was created/deleted. |
required |
is_decrement |
bool |
In case this event reports an actual decrement of usage |
False |
Source code in zenml/zen_server/feature_gate/zenml_cloud_feature_gate.py
def report_event(
    self,
    resource: ResourceType,
    resource_id: UUID,
    is_decrement: bool = False,
) -> None:
    """Reports the usage of a feature to the aggregator backend.

    A rejected report is logged as an error and not raised.

    Args:
        resource: The resource the user created.
        resource_id: ID of the resource that was created/deleted.
        is_decrement: In case this event reports an actual decrement of usage.
    """
    data = RawUsageEvent(
        organization_id=ORGANIZATION_ID,
        feature=resource,
        total=-1 if is_decrement else 1,
        metadata={
            # NOTE(review): `server_config` is used as an object here (not
            # called); presumably a module-level configuration instance in
            # this module - verify.
            "tenant_id": str(server_config.get_external_server_id()),
            "resource_id": str(resource_id),
        },
    ).model_dump()
    response = self._connection.post(
        endpoint=USAGE_EVENT_ENDPOINT, data=data
    )
    if response.status_code != 200:
        # Log the raw body: an error response is not guaranteed to be JSON,
        # and `response.json()` would raise while logging.
        logger.error(
            "Usage report not accepted by upstream backend. "
            f"Status Code: {response.status_code}, Message: "
            f"{response.text}."
        )
jwt
Authentication module for ZenML server.
JWTToken (BaseModel)
Pydantic object representing a JWT token.
Attributes:
Name | Type | Description |
---|---|---|
user_id |
UUID |
The id of the authenticated User. |
device_id |
Optional[uuid.UUID] |
The id of the authenticated device. |
api_key_id |
Optional[uuid.UUID] |
The id of the authenticated API key for which this token was issued. |
schedule_id |
Optional[uuid.UUID] |
The id of the schedule for which the token was issued. |
pipeline_run_id |
Optional[uuid.UUID] |
The id of the pipeline run for which the token was issued. |
step_run_id |
Optional[uuid.UUID] |
The id of the step run for which the token was issued. |
claims |
Dict[str, Any] |
The original token claims. |
Source code in zenml/zen_server/jwt.py
class JWTToken(BaseModel):
    """Pydantic object representing a JWT token.

    Attributes:
        user_id: The id of the authenticated User.
        device_id: The id of the authenticated device.
        api_key_id: The id of the authenticated API key for which this token
            was issued.
        schedule_id: The id of the schedule for which the token was issued.
        pipeline_run_id: The id of the pipeline run for which the token was
            issued.
        step_run_id: The id of the step run for which the token was
            issued.
        claims: The original token claims.
    """

    user_id: UUID
    device_id: Optional[UUID] = None
    api_key_id: Optional[UUID] = None
    schedule_id: Optional[UUID] = None
    pipeline_run_id: Optional[UUID] = None
    step_run_id: Optional[UUID] = None
    claims: Dict[str, Any] = {}

    @classmethod
    def decode_token(
        cls,
        token: str,
        verify: bool = True,
    ) -> "JWTToken":
        """Decodes a JWT access token.

        Decodes a JWT access token and returns a `JWTToken` object with the
        information retrieved from its subject claim.

        Args:
            token: The encoded JWT token.
            verify: Whether to verify the signature of the token. When False,
                the signature is not checked and, per PyJWT defaults, neither
                are the registered claims (expiration, audience, issuer).
                Never disable verification for tokens used to authenticate.

        Returns:
            The decoded JWT access token.

        Raises:
            CredentialsNotValid: If the token is invalid.
        """
        config = server_config()
        try:
            claims_data = jwt.decode(
                token,
                config.jwt_secret_key,
                algorithms=[config.jwt_token_algorithm],
                audience=config.get_jwt_token_audience(),
                issuer=config.get_jwt_token_issuer(),
                # PyJWT >= 2.0 ignores the legacy `verify` keyword argument;
                # signature verification is controlled through `options`.
                options={"verify_signature": verify},
                leeway=timedelta(seconds=config.jwt_token_leeway_seconds),
            )
            claims = cast(Dict[str, Any], claims_data)
        except jwt.PyJWTError as e:
            raise CredentialsNotValid(f"Invalid JWT token: {e}") from e

        def _pop_uuid_claim(name: str) -> Optional[UUID]:
            """Remove an optional UUID claim from `claims` and parse it."""
            if name not in claims:
                return None
            try:
                # `str()` first so that non-string claim values end up as a
                # ValueError (and thus CredentialsNotValid) too.
                return UUID(str(claims.pop(name)))
            except ValueError as e:
                raise CredentialsNotValid(
                    f"Invalid JWT token: the {name} claim is not a valid UUID"
                ) from e

        subject: str = claims.pop("sub", "")
        if not subject:
            raise CredentialsNotValid(
                "Invalid JWT token: the subject claim is missing"
            )
        try:
            user_id = UUID(str(subject))
        except ValueError as e:
            raise CredentialsNotValid(
                "Invalid JWT token: the subject claim is not a valid UUID"
            ) from e

        # Claims are popped (and validated) in a fixed order; whatever is left
        # in `claims` afterwards is kept as the original extra claims.
        device_id = _pop_uuid_claim("device_id")
        api_key_id = _pop_uuid_claim("api_key_id")
        schedule_id = _pop_uuid_claim("schedule_id")
        pipeline_run_id = _pop_uuid_claim("pipeline_run_id")
        step_run_id = _pop_uuid_claim("step_run_id")

        return cls(
            user_id=user_id,
            device_id=device_id,
            api_key_id=api_key_id,
            schedule_id=schedule_id,
            pipeline_run_id=pipeline_run_id,
            step_run_id=step_run_id,
            claims=claims,
        )

    def encode(self, expires: Optional[datetime] = None) -> str:
        """Creates a JWT access token.

        Encodes, signs and returns a JWT access token.

        Args:
            expires: Datetime after which the token will expire. If not
                provided, the JWT token will not be set to expire.

        Returns:
            The generated access token.
        """
        config = server_config()

        claims: Dict[str, Any] = self.claims.copy()
        claims["sub"] = str(self.user_id)
        claims["iss"] = config.get_jwt_token_issuer()
        claims["aud"] = config.get_jwt_token_audience()

        if expires:
            claims["exp"] = expires
        else:
            # Drop any expiration inherited from the original claims.
            claims.pop("exp", None)

        # Optional ID claims are only emitted when set, in a fixed order.
        for claim_name in (
            "device_id",
            "api_key_id",
            "schedule_id",
            "pipeline_run_id",
            "step_run_id",
        ):
            claim_value = getattr(self, claim_name)
            if claim_value:
                claims[claim_name] = str(claim_value)

        return jwt.encode(
            claims,
            config.jwt_secret_key,
            algorithm=config.jwt_token_algorithm,
        )
decode_token(token, verify=True)
classmethod
Decodes a JWT access token.
Decodes a JWT access token and returns a JWTToken object with the information retrieved from its subject claim.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
token |
str |
The encoded JWT token. |
required |
verify |
bool |
Whether to verify the signature of the token. |
True |
Returns:
Type | Description |
---|---|
JWTToken |
The decoded JWT access token. |
Exceptions:
Type | Description |
---|---|
CredentialsNotValid |
If the token is invalid. |
Source code in zenml/zen_server/jwt.py
@classmethod
def decode_token(
    cls,
    token: str,
    verify: bool = True,
) -> "JWTToken":
    """Decodes a JWT access token.

    Decodes a JWT access token and returns a `JWTToken` object with the
    information retrieved from its subject claim.

    Args:
        token: The encoded JWT token.
        verify: Whether to verify the signature of the token. When False, the
            signature is not checked and, per PyJWT defaults, neither are the
            registered claims (expiration, audience, issuer). Never disable
            verification for tokens used to authenticate.

    Returns:
        The decoded JWT access token.

    Raises:
        CredentialsNotValid: If the token is invalid.
    """
    config = server_config()
    try:
        claims_data = jwt.decode(
            token,
            config.jwt_secret_key,
            algorithms=[config.jwt_token_algorithm],
            audience=config.get_jwt_token_audience(),
            issuer=config.get_jwt_token_issuer(),
            # PyJWT >= 2.0 ignores the legacy `verify` keyword argument;
            # signature verification is controlled through `options`.
            options={"verify_signature": verify},
            leeway=timedelta(seconds=config.jwt_token_leeway_seconds),
        )
        claims = cast(Dict[str, Any], claims_data)
    except jwt.PyJWTError as e:
        raise CredentialsNotValid(f"Invalid JWT token: {e}") from e

    def _pop_uuid_claim(name: str) -> Optional[UUID]:
        """Remove an optional UUID claim from `claims` and parse it."""
        if name not in claims:
            return None
        try:
            # `str()` first so that non-string claim values end up as a
            # ValueError (and thus CredentialsNotValid) too.
            return UUID(str(claims.pop(name)))
        except ValueError as e:
            raise CredentialsNotValid(
                f"Invalid JWT token: the {name} claim is not a valid UUID"
            ) from e

    subject: str = claims.pop("sub", "")
    if not subject:
        raise CredentialsNotValid(
            "Invalid JWT token: the subject claim is missing"
        )
    try:
        user_id = UUID(str(subject))
    except ValueError as e:
        raise CredentialsNotValid(
            "Invalid JWT token: the subject claim is not a valid UUID"
        ) from e

    # Claims are popped (and validated) in a fixed order; whatever is left in
    # `claims` afterwards is kept as the original extra claims.
    device_id = _pop_uuid_claim("device_id")
    api_key_id = _pop_uuid_claim("api_key_id")
    schedule_id = _pop_uuid_claim("schedule_id")
    pipeline_run_id = _pop_uuid_claim("pipeline_run_id")
    step_run_id = _pop_uuid_claim("step_run_id")

    return cls(
        user_id=user_id,
        device_id=device_id,
        api_key_id=api_key_id,
        schedule_id=schedule_id,
        pipeline_run_id=pipeline_run_id,
        step_run_id=step_run_id,
        claims=claims,
    )
encode(self, expires=None)
Creates a JWT access token.
Encodes, signs and returns a JWT access token.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
expires |
Optional[datetime.datetime] |
Datetime after which the token will expire. If not provided, the JWT token will not be set to expire. |
None |
Returns:
Type | Description |
---|---|
str |
The generated access token. |
Source code in zenml/zen_server/jwt.py
def encode(self, expires: Optional[datetime] = None) -> str:
    """Build, sign and return a JWT access token for this object.

    Args:
        expires: Datetime after which the token will expire. If not
            provided, the JWT token will not be set to expire.

    Returns:
        The generated access token.
    """
    config = server_config()

    payload: Dict[str, Any] = dict(self.claims)
    payload.update(
        sub=str(self.user_id),
        iss=config.get_jwt_token_issuer(),
        aud=config.get_jwt_token_audience(),
    )

    if not expires:
        # Drop any expiration inherited from the original claims.
        payload.pop("exp", None)
    else:
        payload["exp"] = expires

    # Optional ID claims are only emitted when set, in a fixed order.
    for claim_name in (
        "device_id",
        "api_key_id",
        "schedule_id",
        "pipeline_run_id",
        "step_run_id",
    ):
        claim_value = getattr(self, claim_name)
        if claim_value:
            payload[claim_name] = str(claim_value)

    return jwt.encode(
        payload,
        config.jwt_secret_key,
        algorithm=config.jwt_token_algorithm,
    )
rate_limit
Rate limiting for the ZenML Server.
RequestLimiter
Simple in-memory rate limiter.
Source code in zenml/zen_server/rate_limit.py
class RequestLimiter:
    """Simple in-memory rate limiter.

    Hit timestamps are stored per requester (keyed by the value returned by
    `_get_ipaddr`) in a dictionary held on this instance.
    """

    def __init__(
        self,
        day_limit: Optional[int] = None,
        minute_limit: Optional[int] = None,
    ):
        """Initializes the limiter.

        Args:
            day_limit: The number of requests allowed per day.
            minute_limit: The number of requests allowed per minute.

        Raises:
            ValueError: If rate limiting is enabled in the server config and
                both day_limit and minute_limit are None.
        """
        self.limiting_enabled = server_config().rate_limit_enabled
        if not self.limiting_enabled:
            # No further state is needed: hit_limiter and reset_limiter both
            # return early when limiting is disabled.
            return
        if day_limit is None and minute_limit is None:
            raise ValueError("Pass either day or minute limits, or both.")
        self.day_limit = day_limit
        self.minute_limit = minute_limit
        # Per-requester list of hit timestamps (time.time() values), appended
        # in call order.
        self.limiter: Dict[str, List[float]] = defaultdict(list)

    def hit_limiter(self, request: Request) -> None:
        """Increase the number of hits in the limiter.

        The hit is recorded before the limits are checked, so rejected
        requests still count towards the limits.

        Args:
            request: Request object.

        Raises:
            HTTPException: If the request limit is exceeded.
        """
        if not self.limiting_enabled:
            return
        from bisect import bisect_left

        from fastapi import HTTPException

        requester = self._get_ipaddr(request)
        now = time.time()
        minute_ago = now - 60
        day_ago = now - 60 * 60 * 24

        hits = self.limiter[requester]
        hits.append(now)

        # Timestamps are appended in call order, so the list is treated as
        # sorted and window boundaries are found by binary search.
        # Remove hits older than a day.
        del hits[: bisect_left(hits, day_ago)]
        if self.day_limit and len(hits) > self.day_limit:
            raise HTTPException(
                status_code=429, detail="Daily request limit exceeded."
            )

        # Number of hits at or after `minute_ago`.
        minute_requests = len(hits) - bisect_left(hits, minute_ago)
        if self.minute_limit and minute_requests > self.minute_limit:
            raise HTTPException(
                status_code=429, detail="Minute request limit exceeded."
            )

    def reset_limiter(self, request: Request) -> None:
        """Resets the limiter on successful request.

        Args:
            request: Request object.
        """
        if self.limiting_enabled:
            # `pop` with a default removes the requester's hits without
            # creating an empty defaultdict entry when none are recorded.
            self.limiter.pop(self._get_ipaddr(request), None)

    def _get_ipaddr(self, request: Request) -> str:
        """Returns the IP address for the current request.

        Based on the X-Forwarded-For headers or client information.

        Args:
            request: The request object.

        Returns:
            The ip address for the current request (or 127.0.0.1 if none found).
        """
        # NOTE(review): proxies send this header as `X-Forwarded-For`. Header
        # lookups are presumably case-insensitive but do not equate `_` with
        # `-`, so this branch may only match a literal `X_FORWARDED_FOR`
        # header. Confirm the trusted-proxy setup before changing it: a
        # client-supplied forwarding header can be used to spoof the
        # rate-limit key.
        if "X_FORWARDED_FOR" in request.headers:
            return request.headers["X_FORWARDED_FOR"]
        else:
            if not request.client or not request.client.host:
                return "127.0.0.1"

            return request.client.host

    @contextmanager
    def limit_failed_requests(
        self, request: Request
    ) -> Generator[None, Any, Any]:
        """Limits the number of failed requests.

        Every use counts as a hit; `hit_limiter` may raise before the wrapped
        block runs. If the wrapped block completes without raising, the
        requester's hits are reset, so only failures since the last success
        accumulate.

        Args:
            request: Request object.

        Yields:
            None
        """
        self.hit_limiter(request)
        yield
        # An exception in the wrapped block propagates out of `yield` and
        # skips the reset, so the hit stays counted.
        self.reset_limiter(request)
__init__(self, day_limit=None, minute_limit=None)
special
Initializes the limiter.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
day_limit |
Optional[int] |
The number of requests allowed per day. |
None |
minute_limit |
Optional[int] |
The number of requests allowed per minute. |
None |
Exceptions:
Type | Description |
---|---|
ValueError |
If both day_limit and minute_limit are None. |
Source code in zenml/zen_server/rate_limit.py
def __init__(
    self,
    day_limit: Optional[int] = None,
    minute_limit: Optional[int] = None,
):
    """Initializes the limiter.

    Args:
        day_limit: The number of requests allowed per day.
        minute_limit: The number of requests allowed per minute.

    Raises:
        ValueError: If rate limiting is enabled in the server config and
            both day_limit and minute_limit are None.
    """
    self.limiting_enabled = server_config().rate_limit_enabled
    if not self.limiting_enabled:
        # No further state is needed when limiting is disabled.
        return
    if day_limit is None and minute_limit is None:
        raise ValueError("Pass either day or minute limits, or both.")
    self.day_limit = day_limit
    self.minute_limit = minute_limit
    # Per-requester list of hit timestamps, appended in call order.
    self.limiter: Dict[str, List[float]] = defaultdict(list)
hit_limiter(self, request)
Increase the number of hits in the limiter.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
request |
Request |
Request object. |
required |
Exceptions:
Type | Description |
---|---|
HTTPException |
If the request limit is exceeded. |
Source code in zenml/zen_server/rate_limit.py
def hit_limiter(self, request: Request) -> None:
    """Increase the number of hits in the limiter.

    The hit is recorded before the limits are checked, so rejected requests
    still count towards the limits.

    Args:
        request: Request object.

    Raises:
        HTTPException: If the request limit is exceeded.
    """
    if not self.limiting_enabled:
        return
    from bisect import bisect_left

    from fastapi import HTTPException

    requester = self._get_ipaddr(request)
    now = time.time()
    minute_ago = now - 60
    day_ago = now - 60 * 60 * 24

    hits = self.limiter[requester]
    hits.append(now)

    # Timestamps are appended in call order, so the list is treated as
    # sorted and window boundaries are found by binary search.
    # Remove hits older than a day.
    del hits[: bisect_left(hits, day_ago)]
    if self.day_limit and len(hits) > self.day_limit:
        raise HTTPException(
            status_code=429, detail="Daily request limit exceeded."
        )

    # Number of hits at or after `minute_ago`.
    minute_requests = len(hits) - bisect_left(hits, minute_ago)
    if self.minute_limit and minute_requests > self.minute_limit:
        raise HTTPException(
            status_code=429, detail="Minute request limit exceeded."
        )
limit_failed_requests(self, request)
Limits the number of failed requests.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
request |
Request |
Request object. |
required |
Yields:
Type | Description |
---|---|
Generator[NoneType, Any, Any] |
None |
Source code in zenml/zen_server/rate_limit.py
@contextmanager
def limit_failed_requests(
    self, request: Request
) -> Generator[None, Any, Any]:
    """Limits the number of failed requests.

    Every use counts as a hit for the requester; `hit_limiter` may raise
    before the wrapped block runs. If the wrapped block completes without
    raising, the requester's hits are reset, so only failures since the last
    success accumulate towards the limits.

    Args:
        request: Request object.

    Yields:
        None
    """
    self.hit_limiter(request)
    # An exception raised inside the wrapped block propagates out of this
    # `yield`, which skips the reset below and keeps the hit counted.
    yield
    # if request was successful - reset limiter
    self.reset_limiter(request)
reset_limiter(self, request)
Resets the limiter on successful request.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
request |
Request |
Request object. |
required |
Source code in zenml/zen_server/rate_limit.py
def reset_limiter(self, request: Request) -> None:
    """Forget all recorded hits of the requester after a successful request.

    Does nothing when rate limiting is disabled.

    Args:
        request: Request object.
    """
    if not self.limiting_enabled:
        return
    # `pop` with a default never creates a defaultdict entry.
    self.limiter.pop(self._get_ipaddr(request), None)
rate_limit_requests(day_limit=None, minute_limit=None)
Decorator to rate limit failed requests to an API endpoint.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
day_limit |
Optional[int] |
Number of requests allowed per day. |
None |
minute_limit |
Optional[int] |
Number of requests allowed per minute. |
None |
Returns:
Type | Description |
---|---|
Callable[..., Any] |
Decorated function. |
Source code in zenml/zen_server/rate_limit.py
def rate_limit_requests(
    day_limit: Optional[int] = None,
    minute_limit: Optional[int] = None,
) -> Callable[..., Any]:
    """Decorator to rate limit failed requests to an API endpoint.

    Each call counts as a hit against the caller's limits. A call that
    completes without raising resets the caller's hits (see
    `RequestLimiter.limit_failed_requests`). Both regular and `async`
    endpoint functions are supported. For coroutine functions, the coroutine
    is awaited inside the limiter context, so its failures are counted.

    Args:
        day_limit: Number of requests allowed per day.
        minute_limit: Number of requests allowed per minute.

    Returns:
        Decorated function.
    """
    limiter = RequestLimiter(day_limit=day_limit, minute_limit=minute_limit)

    def decorator(func: F) -> F:
        request_arg, request_kwarg = None, None
        parameters = inspect.signature(func).parameters
        for arg_num, arg_name in enumerate(parameters):
            if parameters[arg_name].annotation == Request:
                request_arg = arg_num
                request_kwarg = arg_name
                break
        if request_arg is None or request_kwarg is None:
            raise ValueError(
                "Rate limiting APIs must have argument of `Request` type."
            )

        def _get_request(args: Any, kwargs: Any) -> Any:
            # The request may be passed by keyword or positionally.
            if request_kwarg in kwargs:
                return kwargs[request_kwarg]
            return args[request_arg]

        if inspect.iscoroutinefunction(func):

            @wraps(func)
            async def async_decorated(
                *args: Any,
                **kwargs: Any,
            ) -> Any:
                request = _get_request(args, kwargs)
                # Await inside the context so that exceptions raised by the
                # coroutine skip the limiter reset.
                with limiter.limit_failed_requests(request):
                    return await func(*args, **kwargs)

            return cast(F, async_decorated)

        @wraps(func)
        def decorated(
            *args: Any,
            **kwargs: Any,
        ) -> Any:
            request = _get_request(args, kwargs)
            with limiter.limit_failed_requests(request):
                return func(*args, **kwargs)

        return cast(F, decorated)

    return decorator
rbac
special
RBAC definitions.
endpoint_utils
High-level helper functions to write endpoints with RBAC.
verify_permissions_and_batch_create_entity(batch, resource_type, create_method)
Verify permissions and create a batch of entities if authorized.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
batch |
List[~AnyRequest] |
The batch to create. |
required |
resource_type |
ResourceType |
The resource type of the entities to create. |
required |
create_method |
Callable[[List[~AnyRequest]], List[~AnyResponse]] |
The method to create the entities. |
required |
Exceptions:
Type | Description |
---|---|
IllegalOperationError |
If the request model has a different owner than the currently authenticated user. |
RuntimeError |
If the resource type is usage-tracked. |
Returns:
Type | Description |
---|---|
List[~AnyResponse] |
The created entities. |
Source code in zenml/zen_server/rbac/endpoint_utils.py
def verify_permissions_and_batch_create_entity(
    batch: List[AnyRequest],
    resource_type: ResourceType,
    create_method: Callable[[List[AnyRequest]], List[AnyResponse]],
) -> List[AnyResponse]:
    """Verify permissions and create a batch of entities if authorized.

    The checks run in this order:

    1. Every user-scoped request in the batch must belong to the
       authenticated user.
    2. The user needs the CREATE permission for the resource type.
    3. The resource type must not be usage-tracked.

    Args:
        batch: The batch to create.
        resource_type: The resource type of the entities to create.
        create_method: The method to create the entities.

    Raises:
        IllegalOperationError: If a request model in the batch has a
            different owner than the currently authenticated user.
        RuntimeError: If the resource type is usage-tracked.

    Returns:
        The created entities.
    """
    auth_context = get_auth_context()
    assert auth_context

    owned_by_other_user = any(
        isinstance(request_model, UserScopedRequest)
        and request_model.user != auth_context.user.id
        for request_model in batch
    )
    if owned_by_other_user:
        raise IllegalOperationError(
            f"Not allowed to create resource '{resource_type}' for a "
            "different user."
        )

    verify_permission(resource_type=resource_type, action=Action.CREATE)

    # Batch creation is rejected for usage-tracked (reportable) resources.
    if resource_type in REPORTABLE_RESOURCES:
        raise RuntimeError(
            "Batch requests are currently not possible with usage-tracked features."
        )

    return create_method(batch)
verify_permissions_and_create_entity(request_model, resource_type, create_method)
Verify permissions and create the entity if authorized.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
request_model |
~AnyRequest |
The entity request model. |
required |
resource_type |
ResourceType |
The resource type of the entity to create. |
required |
create_method |
Callable[[~AnyRequest], ~AnyResponse] |
The method to create the entity. |
required |
Exceptions:
Type | Description |
---|---|
IllegalOperationError |
If the request model has a different owner than the currently authenticated user. |
Returns:
Type | Description |
---|---|
~AnyResponse |
A model of the created entity. |
Source code in zenml/zen_server/rbac/endpoint_utils.py
def verify_permissions_and_create_entity(
    request_model: AnyRequest,
    resource_type: ResourceType,
    create_method: Callable[[AnyRequest], AnyResponse],
) -> AnyResponse:
    """Verify permissions and create the entity if authorized.

    For reportable resource types that do not need custom reporting, the
    entitlement is checked before creation and the usage is reported
    afterwards.

    Args:
        request_model: The entity request model.
        resource_type: The resource type of the entity to create.
        create_method: The method to create the entity.

    Raises:
        IllegalOperationError: If the request model has a different owner than
            the currently authenticated user.

    Returns:
        A model of the created entity.
    """
    if isinstance(request_model, UserScopedRequest):
        auth_context = get_auth_context()
        assert auth_context
        owner_matches = request_model.user == auth_context.user.id
        if not owner_matches:
            raise IllegalOperationError(
                f"Not allowed to create resource '{resource_type}' for a "
                "different user."
            )

    verify_permission(resource_type=resource_type, action=Action.CREATE)

    track_usage = (
        resource_type in REPORTABLE_RESOURCES
        and resource_type not in REQUIRES_CUSTOM_RESOURCE_REPORTING
    )
    if not track_usage:
        return create_method(request_model)

    check_entitlement(resource_type)
    created = create_method(request_model)
    report_usage(resource_type, resource_id=created.id)
    return created
verify_permissions_and_delete_entity(id, get_method, delete_method)
Verify permissions and delete an entity.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
id |
~UUIDOrStr |
The ID of the entity to delete. |
required |
get_method |
Callable[[~UUIDOrStr, bool], ~AnyResponse] |
The method to fetch the entity. |
required |
delete_method |
Callable[[~UUIDOrStr], NoneType] |
The method to delete the entity. |
required |
Returns:
Type | Description |
---|---|
~AnyResponse |
The deleted entity. |
Source code in zenml/zen_server/rbac/endpoint_utils.py
def verify_permissions_and_delete_entity(
    id: UUIDOrStr,
    get_method: Callable[[UUIDOrStr, bool], AnyResponse],
    delete_method: Callable[[UUIDOrStr], None],
) -> AnyResponse:
    """Fetch an entity, check the DELETE permission on it and delete it.

    Args:
        id: The ID of the entity to delete.
        get_method: The method to fetch the entity. It is called with
            `False` as the second argument (the unhydrated version).
        delete_method: The method to delete the entity. It receives the ID
            of the fetched model.

    Returns:
        The deleted entity, as fetched before deletion.
    """
    # The unhydrated model is enough for the permission check.
    entity = get_method(id, False)
    verify_permission_for_model(entity, action=Action.DELETE)

    delete_method(entity.id)
    return entity
verify_permissions_and_get_entity(id, get_method, **get_method_kwargs)
Verify permissions and fetch an entity.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
id |
~UUIDOrStr |
The ID of the entity to fetch. |
required |
get_method |
Callable[[~UUIDOrStr], ~AnyResponse] |
The method to fetch the entity. |
required |
get_method_kwargs |
Any |
Keyword arguments to pass to the get method. |
{} |
Returns:
Type | Description |
---|---|
~AnyResponse |
A model of the fetched entity. |
Source code in zenml/zen_server/rbac/endpoint_utils.py
def verify_permissions_and_get_entity(
    id: UUIDOrStr,
    get_method: Callable[[UUIDOrStr], AnyResponse],
    **get_method_kwargs: Any,
) -> AnyResponse:
    """Fetch an entity and return it if the user may READ it.

    Args:
        id: The ID of the entity to fetch.
        get_method: The method to fetch the entity.
        get_method_kwargs: Keyword arguments to pass to the get method.

    Returns:
        A model of the fetched entity, passed through
        `dehydrate_response_model`.
    """
    entity = get_method(id, **get_method_kwargs)
    verify_permission_for_model(entity, action=Action.READ)
    return dehydrate_response_model(entity)
verify_permissions_and_list_entities(filter_model, resource_type, list_method, **list_method_kwargs)
Verify permissions and list entities.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
filter_model |
~AnyFilter |
The entity filter model. |
required |
resource_type |
ResourceType |
The resource type of the entities to list. |
required |
list_method |
Callable[[~AnyFilter], zenml.models.v2.base.page.Page[~AnyResponse]] |
The method to list the entities. |
required |
list_method_kwargs |
Any |
Keyword arguments to pass to the list method. |
{} |
Returns:
Type | Description |
---|---|
Page[~AnyResponse] |
A page of entity models. |
Source code in zenml/zen_server/rbac/endpoint_utils.py
def verify_permissions_and_list_entities(
    filter_model: AnyFilter,
    resource_type: ResourceType,
    list_method: Callable[[AnyFilter], Page[AnyResponse]],
    **list_method_kwargs: Any,
) -> Page[AnyResponse]:
    """List entities, restricted to those the user is allowed to access.

    The filter model is configured with the authenticated user's ID and the
    IDs returned by `get_allowed_resource_ids` before the list method runs.

    Args:
        filter_model: The entity filter model.
        resource_type: The resource type of the entities to list.
        list_method: The method to list the entities.
        list_method_kwargs: Keyword arguments to pass to the list method.

    Returns:
        A page of entity models, passed through `dehydrate_page`.
    """
    auth_context = get_auth_context()
    assert auth_context

    accessible_ids = get_allowed_resource_ids(resource_type=resource_type)
    filter_model.configure_rbac(
        authenticated_user_id=auth_context.user.id, id=accessible_ids
    )

    return dehydrate_page(list_method(filter_model, **list_method_kwargs))
verify_permissions_and_prune_entities(resource_type, prune_method, **kwargs)
Verify permissions and prune entities of certain type.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
resource_type |
ResourceType |
The resource type of the entities to prune. |
required |
prune_method |
Callable[..., NoneType] |
The method to prune the entities. |
required |
kwargs |
Any |
Keyword arguments to pass to the prune method. |
{} |
Source code in zenml/zen_server/rbac/endpoint_utils.py
def verify_permissions_and_prune_entities(
    resource_type: ResourceType,
    prune_method: Callable[..., None],
    **kwargs: Any,
) -> None:
    """Check the PRUNE permission for a resource type, then run the prune.

    Args:
        resource_type: The resource type of the entities to prune.
        prune_method: The method to prune the entities.
        kwargs: Keyword arguments to pass to the prune method.
    """
    verify_permission(resource_type=resource_type, action=Action.PRUNE)
    # Reached only if the permission check above did not raise
    # (verify_permission presumably raises when the permission is missing).
    prune_method(**kwargs)
verify_permissions_and_update_entity(id, update_model, get_method, update_method)
Verify permissions and update an entity.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
id |
~UUIDOrStr |
The ID of the entity to update. |
required |
update_model |
~AnyUpdate |
The entity update model. |
required |
get_method |
Callable[[~UUIDOrStr, bool], ~AnyResponse] |
The method to fetch the entity. |
required |
update_method |
Callable[[~UUIDOrStr, ~AnyUpdate], ~AnyResponse] |
The method to update the entity. |
required |
Returns:
Type | Description |
---|---|
~AnyResponse |
A model of the updated entity. |
Source code in zenml/zen_server/rbac/endpoint_utils.py
def verify_permissions_and_update_entity(
    id: UUIDOrStr,
    update_model: AnyUpdate,
    get_method: Callable[[UUIDOrStr, bool], AnyResponse],
    update_method: Callable[[UUIDOrStr, AnyUpdate], AnyResponse],
) -> AnyResponse:
    """Fetch an entity, check the UPDATE permission on it and update it.

    Args:
        id: The ID of the entity to update.
        update_model: The entity update model.
        get_method: The method to fetch the entity. It is called with
            `False` as the second argument (the unhydrated version).
        update_method: The method to update the entity. It receives the ID
            of the fetched model and the update model.

    Returns:
        A model of the updated entity, passed through
        `dehydrate_response_model`.
    """
    # The unhydrated model is enough for the permission check.
    existing = get_method(id, False)
    verify_permission_for_model(existing, action=Action.UPDATE)

    return dehydrate_response_model(update_method(existing.id, update_model))
models
RBAC model classes.
Action (StrEnum)
RBAC actions.
Source code in zenml/zen_server/rbac/models.py
class Action(StrEnum):
    """RBAC actions that a user can be permitted to perform on a resource."""

    # Basic create/read/update/delete actions
    CREATE = "create"
    READ = "read"
    UPDATE = "update"
    DELETE = "delete"
    READ_SECRET_VALUE = "read_secret_value"
    # Checked by verify_permissions_and_prune_entities before pruning
    PRUNE = "prune"

    # Service connectors
    CLIENT = "client"

    # Models
    PROMOTE = "promote"

    # Secrets
    BACKUP_RESTORE = "backup_restore"

    # NOTE(review): SHARE is not referenced in this part of the file; confirm
    # which resource types it applies to.
    SHARE = "share"
Resource (BaseModel)
RBAC resource model.
Source code in zenml/zen_server/rbac/models.py
class Resource(BaseModel):
    """RBAC resource model.

    Identifies either a whole resource type (when `id` is unset) or a single
    instance of that type. The model is frozen, so instances are immutable.
    """

    type: str
    id: Optional[UUID] = None

    model_config = ConfigDict(frozen=True)

    def __str__(self) -> str:
        """Render the resource as `<type>` or `<type>/<id>`.

        Returns:
            Resource string representation.
        """
        return self.type + f"/{self.id}" if self.id else self.type
__str__(self)
special
Convert to a string.
Returns:
Type | Description |
---|---|
str |
Resource string representation. |
Source code in zenml/zen_server/rbac/models.py
def __str__(self) -> str:
    """Render the resource as `<type>` or `<type>/<id>`.

    Returns:
        Resource string representation.
    """
    if not self.id:
        return self.type
    return self.type + f"/{self.id}"
ResourceType (StrEnum)
Resource types of the server API.
Source code in zenml/zen_server/rbac/models.py
class ResourceType(StrEnum):
    """Resource types of the server API.

    These are the resource types passed to the RBAC permission checks (for
    example as `resource_type` in the endpoint helpers).
    """

    ACTION = "action"
    ARTIFACT = "artifact"
    ARTIFACT_VERSION = "artifact_version"
    CODE_REPOSITORY = "code_repository"
    EVENT_SOURCE = "event_source"
    FLAVOR = "flavor"
    MODEL = "model"
    MODEL_VERSION = "model_version"
    PIPELINE = "pipeline"
    PIPELINE_RUN = "pipeline_run"
    PIPELINE_DEPLOYMENT = "pipeline_deployment"
    PIPELINE_BUILD = "pipeline_build"
    RUN_TEMPLATE = "run_template"
    SERVICE = "service"
    RUN_METADATA = "run_metadata"
    SECRET = "secret"
    SERVICE_ACCOUNT = "service_account"
    SERVICE_CONNECTOR = "service_connector"
    STACK = "stack"
    STACK_COMPONENT = "stack_component"
    TAG = "tag"
    TRIGGER = "trigger"
    TRIGGER_EXECUTION = "trigger_execution"

    # Deactivated for now
    # USER = "user"
    # WORKSPACE = "workspace"
rbac_interface
RBAC interface definition.
RBACInterface (ABC)
RBAC interface definition.
Source code in zenml/zen_server/rbac/rbac_interface.py
class RBACInterface(ABC):
    """RBAC interface definition.

    Implementations must be able to check permissions for a set of resources,
    list the resource IDs a user may access, and update a user's resource
    membership.
    """

    @abstractmethod
    def check_permissions(
        self, user: "UserResponse", resources: Set[Resource], action: Action
    ) -> Dict[Resource, bool]:
        """Checks if a user has permissions to perform an action on resources.

        Args:
            user: User which wants to access a resource.
            resources: The resources the user wants to access.
            action: The action that the user wants to perform on the resources.

        Returns:
            A dictionary mapping resources to a boolean which indicates whether
            the user has permissions to perform the action on that resource.
            Callers expect an entry for every requested resource.
        """

    @abstractmethod
    def list_allowed_resource_ids(
        self, user: "UserResponse", resource: Resource, action: Action
    ) -> Tuple[bool, List[str]]:
        """Lists all resource IDs of a resource type that a user can access.

        Args:
            user: User which wants to access a resource.
            resource: The resource the user wants to access.
            action: The action that the user wants to perform on the resource.

        Returns:
            A tuple (full_resource_access, resource_ids).
            `full_resource_access` will be `True` if the user can perform the
            given action on any instance of the given resource type, `False`
            otherwise. If `full_resource_access` is `False`, `resource_ids`
            will contain the list of instance IDs that the user can perform
            the action on.
        """

    @abstractmethod
    def update_resource_membership(
        self, user: "UserResponse", resource: Resource, actions: List[Action]
    ) -> None:
        """Update the resource membership of a user.

        Args:
            user: User for which the resource membership should be updated.
            resource: The resource.
            actions: The actions that the user should be able to perform on the
                resource.
        """
check_permissions(self, user, resources, action)
Checks if a user has permissions to perform an action on resources.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
user |
UserResponse |
User which wants to access a resource. |
required |
resources |
Set[zenml.zen_server.rbac.models.Resource] |
The resources the user wants to access. |
required |
action |
Action |
The action that the user wants to perform on the resources. |
required |
Returns:
Type | Description |
---|---|
Dict[zenml.zen_server.rbac.models.Resource, bool] |
A dictionary mapping resources to a boolean which indicates whether the user has permissions to perform the action on that resource. |
Source code in zenml/zen_server/rbac/rbac_interface.py
@abstractmethod
def check_permissions(
    self, user: "UserResponse", resources: Set[Resource], action: Action
) -> Dict[Resource, bool]:
    """Checks if a user has permissions to perform an action on resources.

    Args:
        user: User which wants to access a resource.
        resources: The resources the user wants to access.
        action: The action that the user wants to perform on the resources.

    Returns:
        A dictionary mapping resources to a boolean which indicates whether
        the user has permissions to perform the action on that resource.
        Callers expect an entry for every requested resource.
    """
list_allowed_resource_ids(self, user, resource, action)
Lists all resource IDs of a resource type that a user can access.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
user |
UserResponse |
User which wants to access a resource. |
required |
resource |
Resource |
The resource the user wants to access. |
required |
action |
Action |
The action that the user wants to perform on the resource. |
required |
Returns:
Type | Description |
---|---|
Tuple[bool, List[str]] |
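A tuple (full_resource_access, resource_ids). full_resource_access is True if the user can perform the given action on any instance of the given resource type, False otherwise. If full_resource_access is False, resource_ids contains the list of instance IDs that the user can perform the action on. |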
Source code in zenml/zen_server/rbac/rbac_interface.py
@abstractmethod
def list_allowed_resource_ids(
    self, user: "UserResponse", resource: Resource, action: Action
) -> Tuple[bool, List[str]]:
    """Lists all resource IDs of a resource type that a user can access.

    Args:
        user: User which wants to access a resource.
        resource: The resource the user wants to access.
        action: The action that the user wants to perform on the resource.

    Returns:
        A tuple (full_resource_access, resource_ids).
        `full_resource_access` will be `True` if the user can perform the
        given action on any instance of the given resource type, `False`
        otherwise. If `full_resource_access` is `False`, `resource_ids`
        will contain the list of instance IDs that the user can perform
        the action on.
    """
update_resource_membership(self, user, resource, actions)
Update the resource membership of a user.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
user |
UserResponse |
User for which the resource membership should be updated. |
required |
resource |
Resource |
The resource. |
required |
actions |
List[zenml.zen_server.rbac.models.Action] |
The actions that the user should be able to perform on the resource. |
required |
Source code in zenml/zen_server/rbac/rbac_interface.py
@abstractmethod
def update_resource_membership(
    self, user: "UserResponse", resource: Resource, actions: List[Action]
) -> None:
    """Update the resource membership of a user.

    Args:
        user: User for which the resource membership should be updated.
        resource: The resource (a resource type, optionally with an
            instance ID).
        actions: The actions that the user should be able to perform on the
            resource.
    """
rbac_sql_zen_store
RBAC SQL Zen Store implementation.
RBACSqlZenStore (SqlZenStore)
Wrapper around the SQLZenStore that implements RBAC functionality.
Source code in zenml/zen_server/rbac/rbac_sql_zen_store.py
class RBACSqlZenStore(SqlZenStore):
    """Wrapper around the SQLZenStore that implements RBAC functionality.

    Overrides the get-or-create helpers for models and model versions:

    - Creating requires the CREATE permission. For models, an entitlement
      check is also required.
    - Fetching an existing entity requires the READ permission.
    """

    def _get_or_create_model(
        self, model_request: ModelRequest
    ) -> Tuple[bool, ModelResponse]:
        """Get or create a model.

        If the user may not create models, an existing model with the
        requested name is still returned (subject to a READ check). The
        creation error is raised only when no such model exists.

        Args:
            model_request: The model request.

        # noqa: DAR401
        Raises:
            Exception: If the user is not allowed to create a model.

        Returns:
            A boolean whether the model was created or not, and the model.
        """
        creation_error: Optional[Exception] = None
        try:
            verify_permission(
                resource_type=ResourceType.MODEL, action=Action.CREATE
            )
            check_entitlement(resource_type=ResourceType.MODEL)
        except Exception as e:
            # Keep the error: it is re-raised if the model does not exist.
            creation_error = e

        if creation_error is None:
            created, model_response = super()._get_or_create_model(
                model_request
            )
        else:
            try:
                model_response = self.get_model(model_request.name)
            except KeyError:
                # Surface why the model could not be created instead of the
                # KeyError saying it does not exist.
                raise creation_error from None
            created = False

        if not created:
            verify_permission_for_model(model_response, action=Action.READ)
        else:
            report_usage(
                resource_type=ResourceType.MODEL, resource_id=model_response.id
            )
        return created, model_response

    def _get_model_version(
        self,
        model_id: UUID,
        version_name: Optional[str] = None,
        producer_run_id: Optional[UUID] = None,
    ) -> ModelVersionResponse:
        """Get a model version, requiring the READ permission on it.

        Args:
            model_id: The ID of the model.
            version_name: The name of the model version.
            producer_run_id: The ID of the producer pipeline run. If this is
                set, only numeric versions created as part of the pipeline run
                will be returned.

        Returns:
            The model version.
        """
        version_response = super()._get_model_version(
            model_id=model_id,
            version_name=version_name,
            producer_run_id=producer_run_id,
        )
        verify_permission_for_model(version_response, action=Action.READ)
        return version_response

    def _get_or_create_model_version(
        self,
        model_version_request: ModelVersionRequest,
        producer_run_id: Optional[UUID] = None,
    ) -> Tuple[bool, ModelVersionResponse]:
        """Get or create a model version.

        If the user may not create model versions, an existing version is
        looked up via `_get_model_version`, which enforces the READ
        permission. The creation error is raised only when no such version
        exists.

        Args:
            model_version_request: The model version request.
            producer_run_id: ID of the producer pipeline run.

        # noqa: DAR401
        Raises:
            Exception: If the authenticated user is not allowed to
                create a model version.

        Returns:
            A boolean whether the model version was created or not, and the
            model version.
        """
        creation_error: Optional[Exception] = None
        try:
            verify_permission(
                resource_type=ResourceType.MODEL_VERSION, action=Action.CREATE
            )
        except Exception as e:
            creation_error = e

        if creation_error is None:
            created, model_version_response = (
                super()._get_or_create_model_version(
                    model_version_request, producer_run_id=producer_run_id
                )
            )
            return created, model_version_response

        try:
            existing_version = self._get_model_version(
                model_id=model_version_request.model,
                version_name=model_version_request.name,
                producer_run_id=producer_run_id,
            )
        except KeyError:
            # Surface why the version could not be created instead of the
            # KeyError saying it does not exist.
            raise creation_error from None
        return False, existing_version
model_post_init(/, self, context)
We need to both initialize private attributes and call the user-defined model_post_init method.
Source code in zenml/zen_server/rbac/rbac_sql_zen_store.py
def wrapped_model_post_init(self: BaseModel, context: Any, /) -> None:
    """Initialize private attributes, then run the wrapped post-init hook.

    Both steps are needed, in this order: private attributes are initialized
    first, and then `original_model_post_init` runs. That name is a free
    variable not defined in this snippet; it is presumably the model class's
    user-defined `model_post_init`.
    """
    init_private_attributes(self, context)
    original_model_post_init(self, context)
utils
RBAC utility functions.
batch_verify_permissions(resources, action)
Batch permission verification.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
resources |
Set[zenml.zen_server.rbac.models.Resource] |
The resources the user wants to perform the action on. |
required |
action |
Action |
The action the user wants to perform. |
required |
Exceptions:
Type | Description |
---|---|
IllegalOperationError |
If the user is not allowed to perform the action. |
RuntimeError |
If the permission verification failed unexpectedly. |
Source code in zenml/zen_server/rbac/utils.py
def batch_verify_permissions(
    resources: Set[Resource],
    action: Action,
) -> None:
    """Verify, in one RBAC call, that the user may act on all resources.

    Does nothing when RBAC is disabled in the server config.

    Args:
        resources: The resources the user wants to perform the action on.
        action: The action the user wants to perform.

    Raises:
        IllegalOperationError: If the user is not allowed to perform the action.
        RuntimeError: If the permission verification failed unexpectedly.
    """
    if not server_config().rbac_enabled:
        return

    auth_context = get_auth_context()
    assert auth_context

    granted = rbac().check_permissions(
        user=auth_context.user, resources=resources, action=action
    )
    verb = action.upper()

    for resource in resources:
        if resource not in granted:
            # The RBAC implementation must return a decision for every
            # requested resource; a missing entry indicates a bug there.
            raise RuntimeError(
                f"Failed to verify permissions to {verb} resource "
                f"'{resource}'."
            )
        if not granted[resource]:
            raise IllegalOperationError(
                message=f"Insufficient permissions to {verb} "
                f"resource '{resource}'.",
            )
batch_verify_permissions_for_models(models, action)
Batch permission verification for models.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
models |
Sequence[~AnyResponse] |
The models the user wants to perform the action on. |
required |
action |
Action |
The action the user wants to perform. |
required |
Source code in zenml/zen_server/rbac/utils.py
def batch_verify_permissions_for_models(
    models: Sequence[AnyResponse],
    action: Action,
) -> None:
    """Batch permission verification for models.

    Models owned by the authenticated user are skipped. For every other model,
    the resource of its surrogate permission model (if there is one) is
    collected, and all collected resources are verified together.

    Args:
        models: The models the user wants to perform the action on.
        action: The action the user wants to perform.
    """
    if not server_config().rbac_enabled:
        return

    resources = set()
    for model in models:
        # The model owner always has permissions.
        if not is_owned_by_authenticated_user(model):
            surrogate = get_surrogate_permission_model_for_model(
                model, action=action
            )
            resource = get_resource_for_model(surrogate)
            if resource:
                resources.add(resource)

    batch_verify_permissions(resources=resources, action=action)
dehydrate_page(page)
Dehydrate all items of a page.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
page |
Page[~AnyResponse] |
The page to dehydrate. |
required |
Returns:
Type | Description |
---|---|
Page[~AnyResponse] |
The page with (potentially) dehydrated items. |
Source code in zenml/zen_server/rbac/utils.py
def dehydrate_page(page: Page[AnyResponse]) -> Page[AnyResponse]:
    """Dehydrate every item of a page using a single READ permission check.

    The sub-resources of all items are gathered first so that the RBAC
    backend is queried once for the whole page. Each item is then dehydrated
    with the prefetched permissions. When RBAC is disabled the page is
    returned unchanged.

    Args:
        page: The page to dehydrate.

    Returns:
        A copy of the page with (potentially) dehydrated items.
    """
    if not server_config().rbac_enabled:
        return page

    auth_context = get_auth_context()
    assert auth_context

    # Union of the sub-resources of all items on this page.
    all_resources = set()
    for entry in page.items:
        all_resources |= get_subresources_for_model(entry)

    permissions = rbac().check_permissions(
        user=auth_context.user, resources=all_resources, action=Action.READ
    )

    dehydrated_items = []
    for entry in page.items:
        dehydrated_items.append(
            dehydrate_response_model(entry, permissions=permissions)
        )
    return page.model_copy(update={"items": dehydrated_items})
dehydrate_response_model(model, permissions=None)
Dehydrate a model if necessary.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model |
~AnyModel |
The model to dehydrate. |
required |
permissions |
Optional[Dict[zenml.zen_server.rbac.models.Resource, bool]] |
Prefetched permissions that will be used to check whether sub-models will be included in the model or not. If a sub-model refers to a resource which is not included in this dictionary, the permissions will be checked with the RBAC component. |
None |
Returns:
Type | Description |
---|---|
~AnyModel |
The (potentially) dehydrated model. |
Source code in zenml/zen_server/rbac/utils.py
def dehydrate_response_model(
    model: AnyModel, permissions: Optional[Dict[Resource, bool]] = None
) -> AnyModel:
    """Dehydrate a model if necessary.

    Every field value of the model is passed through `_dehydrate_value`, and
    a new instance of the same model class is validated from the result.
    When RBAC is disabled the model is returned unchanged.

    Args:
        model: The model to dehydrate.
        permissions: Prefetched permissions that will be used to check whether
            sub-models will be included in the model or not. If a sub-model
            refers to a resource which is not included in this dictionary, the
            permissions will be checked with the RBAC component.

    Returns:
        The (potentially) dehydrated model.
    """
    if not server_config().rbac_enabled:
        return model

    if not permissions:
        # Nothing usable was prefetched (None or an empty mapping), so fetch
        # READ permissions for this model's own sub-resources.
        auth_context = get_auth_context()
        assert auth_context
        permissions = rbac().check_permissions(
            user=auth_context.user,
            resources=get_subresources_for_model(model),
            action=Action.READ,
        )

    # `model.__iter__()` is used instead of `model_dump()` so nested models
    # keep their classes; see `get_subresources_for_model(...)` for details.
    dehydrated_values = {
        field_name: _dehydrate_value(field_value, permissions=permissions)
        for field_name, field_value in model.__iter__()
    }
    return type(model).model_validate(dehydrated_values)
get_allowed_resource_ids(resource_type, action=<Action.READ: 'read'>)
Get all resource IDs of a resource type that a user can access.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
resource_type |
str |
The resource type. |
required |
action |
Action |
The action the user wants to perform on the resource. |
<Action.READ: 'read'> |
Returns:
Type | Description |
---|---|
Optional[Set[uuid.UUID]] |
A set of resource IDs, or `None` if the user has full access to all instances of the resource. |
Source code in zenml/zen_server/rbac/utils.py
def get_allowed_resource_ids(
    resource_type: str,
    action: Action = Action.READ,
) -> Optional[Set[UUID]]:
    """Get all resource IDs of a resource type that a user can access.

    Args:
        resource_type: The resource type.
        action: The action the user wants to perform on the resource.

    Returns:
        A set of resource IDs, or `None` if RBAC is disabled or the user has
        full access to all instances of the resource type.
    """
    if not server_config().rbac_enabled:
        return None

    auth_context = get_auth_context()
    assert auth_context

    has_full_resource_access, allowed_ids = rbac().list_allowed_resource_ids(
        user=auth_context.user,
        resource=Resource(type=resource_type),
        action=action,
    )
    if has_full_resource_access:
        return None

    # The RBAC backend reports the IDs as strings; convert them to UUIDs.
    # (`resource_id` avoids shadowing the builtin `id`.)
    return {UUID(resource_id) for resource_id in allowed_ids}
get_permission_denied_model(model)
Get a model to return in case of missing read permissions.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model |
~AnyResponse |
The original model. |
required |
Returns:
Type | Description |
---|---|
~AnyResponse |
The permission denied model. |
Source code in zenml/zen_server/rbac/utils.py
def get_permission_denied_model(model: AnyResponse) -> AnyResponse:
    """Build a redacted copy of a model for users without READ permissions.

    Args:
        model: The original model.

    Returns:
        A copy of `model` whose `body`, `metadata` and `resources` are cleared
        and whose `permission_denied` flag is set to `True`.
    """
    redacted_fields = dict.fromkeys(("body", "metadata", "resources"))
    redacted_fields["permission_denied"] = True
    return model.model_copy(update=redacted_fields)
get_resource_for_model(model)
Get the resource associated with a model object.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model |
~AnyResponse |
The model for which to get the resource. |
required |
Returns:
Type | Description |
---|---|
Optional[zenml.zen_server.rbac.models.Resource] |
The resource associated with the model, or `None` if the model is not associated with any resource type. |
Source code in zenml/zen_server/rbac/utils.py
def get_resource_for_model(model: AnyResponse) -> Optional[Resource]:
    """Get the RBAC resource that corresponds to a model object.

    Args:
        model: The model for which to get the resource.

    Returns:
        A resource built from the model's resource type and ID, or `None` if
        the model is not associated with any resource type.
    """
    resource_type = get_resource_type_for_model(model)
    if resource_type:
        return Resource(type=resource_type, id=model.id)
    # The model is not tied to any RBAC resource type.
    return None
get_resource_type_for_model(model)
Get the resource type associated with a model object.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model |
~AnyResponse |
The model for which to get the resource type. |
required |
Returns:
Type | Description |
---|---|
Optional[zenml.zen_server.rbac.models.ResourceType] |
The resource type associated with the model, or `None` if the model is not associated with any resource type. |
Source code in zenml/zen_server/rbac/utils.py
def get_resource_type_for_model(
    model: AnyResponse,
) -> Optional[ResourceType]:
    """Look up the RBAC resource type associated with a model object.

    The lookup is keyed on the exact class of `model`, so subclasses of the
    listed response classes are not matched.

    Args:
        model: The model for which to get the resource type.

    Returns:
        The resource type associated with the model, or `None` if the model
        is not associated with any resource type.
    """
    from zenml.models import (
        ActionResponse,
        ArtifactResponse,
        ArtifactVersionResponse,
        CodeRepositoryResponse,
        ComponentResponse,
        EventSourceResponse,
        FlavorResponse,
        ModelResponse,
        ModelVersionResponse,
        PipelineBuildResponse,
        PipelineDeploymentResponse,
        PipelineResponse,
        PipelineRunResponse,
        RunTemplateResponse,
        SecretResponse,
        ServiceAccountResponse,
        ServiceConnectorResponse,
        ServiceResponse,
        StackResponse,
        TagResponse,
        TriggerExecutionResponse,
        TriggerResponse,
    )

    # NOTE: `WorkspaceResponse` and `UserResponse` are not mapped here, so
    # they resolve to `None`.
    response_class_to_resource_type: Dict[Any, ResourceType] = {
        ActionResponse: ResourceType.ACTION,
        ArtifactResponse: ResourceType.ARTIFACT,
        ArtifactVersionResponse: ResourceType.ARTIFACT_VERSION,
        CodeRepositoryResponse: ResourceType.CODE_REPOSITORY,
        ComponentResponse: ResourceType.STACK_COMPONENT,
        EventSourceResponse: ResourceType.EVENT_SOURCE,
        FlavorResponse: ResourceType.FLAVOR,
        ModelResponse: ResourceType.MODEL,
        ModelVersionResponse: ResourceType.MODEL_VERSION,
        PipelineBuildResponse: ResourceType.PIPELINE_BUILD,
        PipelineDeploymentResponse: ResourceType.PIPELINE_DEPLOYMENT,
        PipelineResponse: ResourceType.PIPELINE,
        PipelineRunResponse: ResourceType.PIPELINE_RUN,
        RunTemplateResponse: ResourceType.RUN_TEMPLATE,
        SecretResponse: ResourceType.SECRET,
        ServiceAccountResponse: ResourceType.SERVICE_ACCOUNT,
        ServiceConnectorResponse: ResourceType.SERVICE_CONNECTOR,
        ServiceResponse: ResourceType.SERVICE,
        StackResponse: ResourceType.STACK,
        TagResponse: ResourceType.TAG,
        TriggerExecutionResponse: ResourceType.TRIGGER_EXECUTION,
        TriggerResponse: ResourceType.TRIGGER,
    }
    model_class = type(model)
    return response_class_to_resource_type.get(model_class)
get_schema_for_resource_type(resource_type)
Get the database schema for a resource type.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
resource_type |
ResourceType |
The resource type for which to get the database schema. |
required |
Returns:
Type | Description |
---|---|
Type[BaseSchema] |
The database schema. |
Source code in zenml/zen_server/rbac/utils.py
def get_schema_for_resource_type(
    resource_type: ResourceType,
) -> Type["BaseSchema"]:
    """Get the database schema class that stores a given resource type.

    Args:
        resource_type: The resource type for which to get the database schema.

    Returns:
        The database schema.

    Raises:
        KeyError: If no schema is registered for the resource type (for
            example `ResourceType.WORKSPACE` or `ResourceType.USER`, which are
            not mapped).
    """
    from zenml.zen_stores.schemas import (
        ActionSchema,
        ArtifactSchema,
        ArtifactVersionSchema,
        CodeRepositorySchema,
        EventSourceSchema,
        FlavorSchema,
        ModelSchema,
        ModelVersionSchema,
        PipelineBuildSchema,
        PipelineDeploymentSchema,
        PipelineRunSchema,
        PipelineSchema,
        RunMetadataSchema,
        RunTemplateSchema,
        SecretSchema,
        ServiceConnectorSchema,
        ServiceSchema,
        StackComponentSchema,
        StackSchema,
        TagSchema,
        TriggerExecutionSchema,
        TriggerSchema,
        UserSchema,
    )

    schema_by_resource_type: Dict[ResourceType, Type["BaseSchema"]] = {
        ResourceType.ACTION: ActionSchema,
        ResourceType.ARTIFACT: ArtifactSchema,
        ResourceType.ARTIFACT_VERSION: ArtifactVersionSchema,
        ResourceType.CODE_REPOSITORY: CodeRepositorySchema,
        ResourceType.EVENT_SOURCE: EventSourceSchema,
        ResourceType.FLAVOR: FlavorSchema,
        ResourceType.MODEL: ModelSchema,
        ResourceType.MODEL_VERSION: ModelVersionSchema,
        ResourceType.PIPELINE: PipelineSchema,
        ResourceType.PIPELINE_BUILD: PipelineBuildSchema,
        ResourceType.PIPELINE_DEPLOYMENT: PipelineDeploymentSchema,
        ResourceType.PIPELINE_RUN: PipelineRunSchema,
        ResourceType.RUN_METADATA: RunMetadataSchema,
        ResourceType.RUN_TEMPLATE: RunTemplateSchema,
        ResourceType.SECRET: SecretSchema,
        ResourceType.SERVICE: ServiceSchema,
        # Service accounts are stored in the user table.
        ResourceType.SERVICE_ACCOUNT: UserSchema,
        ResourceType.SERVICE_CONNECTOR: ServiceConnectorSchema,
        ResourceType.STACK: StackSchema,
        ResourceType.STACK_COMPONENT: StackComponentSchema,
        ResourceType.TAG: TagSchema,
        ResourceType.TRIGGER: TriggerSchema,
        ResourceType.TRIGGER_EXECUTION: TriggerExecutionSchema,
    }
    return schema_by_resource_type[resource_type]
get_subresources_for_model(model)
Get all sub-resources of a model which need permission verification.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model |
~AnyModel |
The model for which to get all the resources. |
required |
Returns:
Type | Description |
---|---|
Set[zenml.zen_server.rbac.models.Resource] |
All resources of a model which need permission verification. |
Source code in zenml/zen_server/rbac/utils.py
def get_subresources_for_model(
    model: AnyModel,
) -> Set[Resource]:
    """Collect every sub-resource of a model that needs permission checks.

    Args:
        model: The model for which to get all the resources.

    Returns:
        All resources of a model which need permission verification.
    """
    # Values are taken by iteration rather than `model.model_dump()`:
    # dumping would recursively turn nested models into dicts, but the
    # recursive `_get_subresources_for_value` calls need the original
    # classes. `dict(model)` was used before but caused issues with models
    # that override `__getattr__`; `model.__iter__()` gives the same values.
    if isinstance(model, Page):
        children = (item for item in model)
    else:
        children = (value for _, value in model.__iter__())

    collected = set()
    for child in children:
        collected.update(_get_subresources_for_value(child))
    return collected
get_surrogate_permission_model_for_model(model, action)
Get a surrogate permission model for a model.
In some cases a different model instead of the original model is used to verify permissions. For example, a parent container model might be used to verify permissions for all its children.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model |
~AnyResponse |
The original model. |
required |
action |
str |
The action that the user wants to perform on the model. |
required |
Returns:
Type | Description |
---|---|
BaseIdentifiedResponse[Any, Any, Any] |
A surrogate model or the original. |
Source code in zenml/zen_server/rbac/utils.py
def get_surrogate_permission_model_for_model(
    model: AnyResponse, action: str
) -> BaseIdentifiedResponse[Any, Any, Any]:
    """Get a surrogate permission model for a model.

    In some cases a different model instead of the original model is used to
    verify permissions. For example, a parent container model might be used
    to verify permissions for all its children.

    Args:
        model: The original model.
        action: The action that the user wants to perform on the model.

    Returns:
        For READ on a model version, its parent model. For READ on an
        artifact version, its parent artifact. Otherwise the original model.
    """
    from zenml.models import ArtifactVersionResponse, ModelVersionResponse

    # Only READ permissions are redirected to a surrogate.
    if action != Action.READ:
        return model

    # Reading an entity that is a version of another entity is authorized
    # against the parent entity.
    if isinstance(model, ModelVersionResponse):
        return model.model
    if isinstance(model, ArtifactVersionResponse):
        return model.artifact
    return model
has_permissions_for_model(model, action)
If the active user has permissions to perform the action on the model.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model |
~AnyResponse |
The model the user wants to perform the action on. |
required |
action |
Action |
The action the user wants to perform. |
required |
Returns:
Type | Description |
---|---|
bool |
If the active user has permissions to perform the action on the model. |
Source code in zenml/zen_server/rbac/utils.py
def has_permissions_for_model(model: AnyResponse, action: Action) -> bool:
    """Tell whether the active user may perform an action on a model.

    Unlike `verify_permission_for_model`, this returns a boolean instead of
    raising `IllegalOperationError` when permissions are missing.

    Args:
        model: The model the user wants to perform the action on.
        action: The action the user wants to perform.

    Returns:
        If the active user has permissions to perform the action on the model.
    """
    if is_owned_by_authenticated_user(model):
        return True

    try:
        verify_permission_for_model(model=model, action=action)
    except IllegalOperationError:
        return False
    else:
        return True
is_owned_by_authenticated_user(model)
Returns whether the currently authenticated user owns the model.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model |
~AnyResponse |
The model for which to check the ownership. |
required |
Returns:
Type | Description |
---|---|
bool |
Whether the currently authenticated user owns the model. |
Source code in zenml/zen_server/rbac/utils.py
def is_owned_by_authenticated_user(model: AnyResponse) -> bool:
    """Tell whether the currently authenticated user owns the model.

    Args:
        model: The model for which to check the ownership.

    Returns:
        `True` if the model is user-scoped and either has no owning user
        (server-owned) or is owned by the authenticated user, else `False`.
    """
    auth_context = get_auth_context()
    assert auth_context

    if not isinstance(model, UserScopedResponse):
        return False

    owner = model.user
    if not owner:
        # Server-owned models count as owned by every user for RBAC purposes.
        return True
    return owner.id == auth_context.user.id
update_resource_membership(user, resource, actions)
Update the resource membership of a user.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
user |
UserResponse |
User for which the resource membership should be updated. |
required |
resource |
Resource |
The resource. |
required |
actions |
List[zenml.zen_server.rbac.models.Action] |
The actions that the user should be able to perform on the resource. |
required |
Source code in zenml/zen_server/rbac/utils.py
def update_resource_membership(
    user: UserResponse, resource: Resource, actions: List[Action]
) -> None:
    """Update the resource membership of a user via the RBAC component.

    This is a no-op when RBAC is disabled.

    Args:
        user: User for which the resource membership should be updated.
        resource: The resource.
        actions: The actions that the user should be able to perform on the
            resource.
    """
    if server_config().rbac_enabled:
        rbac().update_resource_membership(
            user=user, resource=resource, actions=actions
        )
verify_permission(resource_type, action, resource_id=None)
Verifies if a user has permission to perform an action on a resource.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
resource_type |
str |
The type of resource that the user wants to perform the action on. |
required |
action |
Action |
The action the user wants to perform. |
required |
resource_id |
Optional[uuid.UUID] |
ID of the resource the user wants to perform the action on. |
None |
Source code in zenml/zen_server/rbac/utils.py
def verify_permission(
    resource_type: str,
    action: Action,
    resource_id: Optional[UUID] = None,
) -> None:
    """Verify that the user may perform an action on a single resource.

    Builds one `Resource` and checks it through `batch_verify_permissions`.

    Args:
        resource_type: The type of resource that the user wants to perform the
            action on.
        action: The action the user wants to perform.
        resource_id: ID of the resource the user wants to perform the action on.
    """
    batch_verify_permissions(
        resources={Resource(type=resource_type, id=resource_id)},
        action=action,
    )
verify_permission_for_model(model, action)
Verifies if a user has permission to perform an action on a model.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model |
~AnyResponse |
The model the user wants to perform the action on. |
required |
action |
Action |
The action the user wants to perform. |
required |
Source code in zenml/zen_server/rbac/utils.py
def verify_permission_for_model(model: AnyResponse, action: Action) -> None:
    """Verify that the user may perform an action on one model.

    This delegates to `batch_verify_permissions_for_models` with a
    single-element batch. It presumably raises `IllegalOperationError` when
    permissions are missing (that is what `has_permissions_for_model`
    catches).

    Args:
        model: The model the user wants to perform the action on.
        action: The action the user wants to perform.
    """
    single_model_batch = [model]
    batch_verify_permissions_for_models(models=single_model_batch, action=action)
zenml_cloud_rbac
Cloud RBAC implementation.
ZenMLCloudRBAC (RBACInterface)
RBAC implementation that uses the ZenML Pro Management Plane as a backend.
Source code in zenml/zen_server/rbac/zenml_cloud_rbac.py
class ZenMLCloudRBAC(RBACInterface):
    """RBAC implementation that uses the ZenML Pro Management Plane as a backend."""

    def __init__(self) -> None:
        """Initialize the object."""
        # Connection to the ZenML Pro API used by all RBAC requests below.
        self._connection = cloud_connection()

    def check_permissions(
        self, user: "UserResponse", resources: Set[Resource], action: Action
    ) -> Dict[Resource, bool]:
        """Checks if a user has permissions to perform an action on resources.

        Args:
            user: User which wants to access a resource.
            resources: The resources the user wants to access.
            action: The action that the user wants to perform on the resources.

        Returns:
            A dictionary mapping resources to a boolean which indicates whether
            the user has permissions to perform the action on that resource.
        """
        if not resources:
            # No need to send a request if there are no resources
            return {}

        if user.is_service_account:
            # Service accounts have full permissions for now
            return {resource: True for resource in resources}

        # At this point it's a regular user, which in a ZenML Pro with RBAC
        # enabled is always authenticated using external authentication
        assert user.external_user_id

        params = {
            "user_id": str(user.external_user_id),
            "resources": [
                _convert_to_cloud_resource(resource) for resource in resources
            ],
            "action": str(action),
        }
        response = self._connection.get(
            endpoint=PERMISSIONS_ENDPOINT, params=params
        )
        value = response.json()

        assert isinstance(value, dict)
        return {_convert_from_cloud_resource(k): v for k, v in value.items()}

    def list_allowed_resource_ids(
        self, user: "UserResponse", resource: Resource, action: Action
    ) -> Tuple[bool, List[str]]:
        """Lists all resource IDs of a resource type that a user can access.

        Args:
            user: User which wants to access a resource.
            resource: The resource the user wants to access. Only the type is
                used; it must not reference a specific instance ID.
            action: The action that the user wants to perform on the resource.

        Returns:
            A tuple (full_resource_access, resource_ids).
            `full_resource_access` will be `True` if the user can perform the
            given action on any instance of the given resource type, `False`
            otherwise. If `full_resource_access` is `False`, `resource_ids`
            will contain the list of instance IDs that the user can perform
            the action on.
        """
        assert not resource.id
        if user.is_service_account:
            # Service accounts have full permissions for now
            return True, []

        # At this point it's a regular user, which in the ZenML Pro with RBAC
        # enabled is always authenticated using external authentication
        assert user.external_user_id

        params = {
            "user_id": str(user.external_user_id),
            "resource": _convert_to_cloud_resource(resource),
            "action": str(action),
        }
        response = self._connection.get(
            endpoint=ALLOWED_RESOURCE_IDS_ENDPOINT, params=params
        )
        response_json = response.json()

        full_resource_access: bool = response_json["full_access"]
        allowed_ids: List[str] = response_json["ids"]

        return full_resource_access, allowed_ids

    def update_resource_membership(
        self, user: "UserResponse", resource: Resource, actions: List[Action]
    ) -> None:
        """Update the resource membership of a user.

        Args:
            user: User for which the resource membership should be updated.
            resource: The resource.
            actions: The actions that the user should be able to perform on the
                resource.
        """
        if user.is_service_account:
            # Service accounts have full permissions for now
            return

        # Same invariant as in the other methods: regular users are always
        # externally authenticated. Without this check a missing ID would be
        # sent to the backend as the literal string "None".
        assert user.external_user_id

        data = {
            "user_id": str(user.external_user_id),
            "resource": _convert_to_cloud_resource(resource),
            "actions": [str(action) for action in actions],
        }

        self._connection.post(endpoint=RESOURCE_MEMBERSHIP_ENDPOINT, data=data)
__init__(self)
special
Initialize the object.
Source code in zenml/zen_server/rbac/zenml_cloud_rbac.py
def __init__(self) -> None:
    """Store the ZenML Pro connection used for all RBAC requests."""
    connection = cloud_connection()
    self._connection = connection
check_permissions(self, user, resources, action)
Checks if a user has permissions to perform an action on resources.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
user |
UserResponse |
User which wants to access a resource. |
required |
resources |
Set[zenml.zen_server.rbac.models.Resource] |
The resources the user wants to access. |
required |
action |
Action |
The action that the user wants to perform on the resources. |
required |
Returns:
Type | Description |
---|---|
Dict[zenml.zen_server.rbac.models.Resource, bool] |
A dictionary mapping resources to a boolean which indicates whether the user has permissions to perform the action on that resource. |
Source code in zenml/zen_server/rbac/zenml_cloud_rbac.py
def check_permissions(
    self, user: "UserResponse", resources: Set[Resource], action: Action
) -> Dict[Resource, bool]:
    """Report, per resource, whether `user` may perform `action` on it.

    Args:
        user: User which wants to access a resource.
        resources: The resources the user wants to access.
        action: The action that the user wants to perform on the resources.

    Returns:
        A dictionary mapping resources to a boolean which indicates whether
        the user has permissions to perform the action on that resource.
    """
    # Empty input: answer locally without contacting the backend.
    if not resources:
        return {}

    # Service accounts currently get full permissions.
    if user.is_service_account:
        return dict.fromkeys(resources, True)

    # Any other user in a ZenML Pro deployment with RBAC enabled is
    # authenticated externally, so an external user ID must be present.
    assert user.external_user_id

    cloud_resources = [
        _convert_to_cloud_resource(resource) for resource in resources
    ]
    response = self._connection.get(
        endpoint=PERMISSIONS_ENDPOINT,
        params={
            "user_id": str(user.external_user_id),
            "resources": cloud_resources,
            "action": str(action),
        },
    )
    payload = response.json()

    assert isinstance(payload, dict)
    return {
        _convert_from_cloud_resource(cloud_resource): allowed
        for cloud_resource, allowed in payload.items()
    }
list_allowed_resource_ids(self, user, resource, action)
Lists all resource IDs of a resource type that a user can access.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
user |
UserResponse |
User which wants to access a resource. |
required |
resource |
Resource |
The resource the user wants to access. |
required |
action |
Action |
The action that the user wants to perform on the resource. |
required |
Returns:
Type | Description |
---|---|
Tuple[bool, List[str]] |
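A tuple (full_resource_access, resource_ids). `full_resource_access` is `True` if the user can perform the given action on any instance of the given resource type, `False` otherwise; if it is `False`, `resource_ids` contains the IDs of the instances the user can perform the action on.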
|
Source code in zenml/zen_server/rbac/zenml_cloud_rbac.py
def list_allowed_resource_ids(
    self, user: "UserResponse", resource: Resource, action: Action
) -> Tuple[bool, List[str]]:
    """List the instance IDs of a resource type that a user can access.

    Args:
        user: User which wants to access a resource.
        resource: The resource the user wants to access. It must describe a
            resource type only; its `id` has to be unset.
        action: The action that the user wants to perform on the resource.

    Returns:
        A tuple (full_resource_access, resource_ids).
        `full_resource_access` will be `True` if the user can perform the
        given action on any instance of the given resource type, `False`
        otherwise. If `full_resource_access` is `False`, `resource_ids`
        will contain the list of instance IDs that the user can perform
        the action on.
    """
    assert not resource.id

    # Service accounts currently get full permissions.
    if user.is_service_account:
        return True, []

    # Any other user in a ZenML Pro deployment with RBAC enabled is
    # authenticated externally, so an external user ID must be present.
    assert user.external_user_id

    response_json = self._connection.get(
        endpoint=ALLOWED_RESOURCE_IDS_ENDPOINT,
        params={
            "user_id": str(user.external_user_id),
            "resource": _convert_to_cloud_resource(resource),
            "action": str(action),
        },
    ).json()
    return response_json["full_access"], response_json["ids"]
update_resource_membership(self, user, resource, actions)
Update the resource membership of a user.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
user |
UserResponse |
User for which the resource membership should be updated. |
required |
resource |
Resource |
The resource. |
required |
actions |
List[zenml.zen_server.rbac.models.Action] |
The actions that the user should be able to perform on the resource. |
required |
Source code in zenml/zen_server/rbac/zenml_cloud_rbac.py
def update_resource_membership(
    self, user: "UserResponse", resource: Resource, actions: List[Action]
) -> None:
    """Update the resource membership of a user.

    Args:
        user: User for which the resource membership should be updated.
        resource: The resource.
        actions: The actions that the user should be able to perform on the
            resource.
    """
    if user.is_service_account:
        # Service accounts have full permissions for now
        return

    # Regular users in ZenML Pro with RBAC enabled are always externally
    # authenticated. Enforce that here as the other RBAC methods do; otherwise
    # a missing ID would be posted as the literal string "None".
    assert user.external_user_id

    data = {
        "user_id": str(user.external_user_id),
        "resource": _convert_to_cloud_resource(resource),
        "actions": [str(action) for action in actions],
    }

    self._connection.post(endpoint=RESOURCE_MEMBERSHIP_ENDPOINT, data=data)
routers
special
Endpoint definitions.
actions_endpoints
Endpoint definitions for actions.
create_action(action, _=Security(oauth2_authentication))
Creates an action.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
action |
ActionRequest |
Action to create. |
required |
Exceptions:
Type | Description |
---|---|
ValueError |
If the action handler for flavor/type is not valid. |
Returns:
Type | Description |
---|---|
ActionResponse |
The created action. |
Source code in zenml/zen_server/routers/actions_endpoints.py
@router.post(
    "",
    response_model=ActionResponse,
    responses={401: error_response, 409: error_response, 422: error_response},
)
@handle_exceptions
def create_action(
    action: ActionRequest,
    _: AuthContext = Security(authorize),
) -> ActionResponse:
    """Create an action through the action handler of its flavor.

    The caller must have READ permissions on the service account referenced
    by the action. Creation is then delegated to
    `verify_permissions_and_create_entity` using the handler's
    `create_action` method.

    Args:
        action: Action to create.

    Raises:
        ValueError: If the action handler for flavor/type is not valid.

    Returns:
        The created action.
    """
    verify_permission_for_model(
        zen_store().get_service_account(
            service_account_name_or_id=action.service_account_id
        ),
        action=Action.READ,
    )

    handler = plugin_flavor_registry().get_plugin(
        name=action.flavor,
        _type=PluginType.ACTION,
        subtype=action.plugin_subtype,
    )
    # The registered plugin for this flavor/subtype must be an action handler.
    if not isinstance(handler, BaseActionHandler):
        raise ValueError(
            f"Action handler plugin {action.plugin_subtype} "
            f"for flavor {action.flavor} is not a valid action "
            "handler plugin."
        )

    return verify_permissions_and_create_entity(
        request_model=action,
        resource_type=ResourceType.ACTION,
        create_method=handler.create_action,
    )
delete_action(action_id, force=False, _=Security(oauth2_authentication))
Delete an action.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
action_id |
UUID |
ID of the action. |
required |
force |
bool |
Flag deciding whether to force delete the action. |
False |
Exceptions:
Type | Description |
---|---|
ValueError |
If the action handler for flavor/type is not valid. |
Source code in zenml/zen_server/routers/actions_endpoints.py
@router.delete(
    "/{action_id}",
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def delete_action(
    action_id: UUID,
    force: bool = False,
    _: AuthContext = Security(authorize),
) -> None:
    """Delete an action through the action handler of its flavor.

    Args:
        action_id: ID of the action.
        force: Flag deciding whether to force delete the action.

    Raises:
        ValueError: If the action handler for flavor/type is not valid.
    """
    action = zen_store().get_action(action_id=action_id)
    verify_permission_for_model(action, action=Action.DELETE)

    handler = plugin_flavor_registry().get_plugin(
        name=action.flavor,
        _type=PluginType.ACTION,
        subtype=action.plugin_subtype,
    )
    # The registered plugin for this flavor/subtype must be an action handler.
    if not isinstance(handler, BaseActionHandler):
        raise ValueError(
            f"Action handler plugin {action.plugin_subtype} "
            f"for flavor {action.flavor} is not a valid action "
            "handler plugin."
        )

    handler.delete_action(action=action, force=force)
get_action(action_id, hydrate=True, _=Security(oauth2_authentication))
Returns the requested action.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
action_id |
UUID |
ID of the action. |
required |
hydrate |
bool |
Flag deciding whether to hydrate the output model(s) by including metadata fields in the response. |
True |
Exceptions:
Type | Description |
---|---|
ValueError |
If the action handler for flavor/type is not valid. |
Returns:
Type | Description |
---|---|
ActionResponse |
The requested action. |
Source code in zenml/zen_server/routers/actions_endpoints.py
@router.get(
    "/{action_id}",
    response_model=ActionResponse,
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def get_action(
    action_id: UUID,
    hydrate: bool = True,
    _: AuthContext = Security(authorize),
) -> ActionResponse:
    """Fetch one action and pass it through its flavor's action handler.

    Args:
        action_id: ID of the action.
        hydrate: Flag deciding whether to hydrate the output model(s)
            by including metadata fields in the response.

    Raises:
        ValueError: If the action handler for flavor/type is not valid.

    Returns:
        The requested action, dehydrated according to the caller's
        permissions.
    """
    stored_action = zen_store().get_action(action_id=action_id, hydrate=hydrate)
    verify_permission_for_model(stored_action, action=Action.READ)

    handler = plugin_flavor_registry().get_plugin(
        name=stored_action.flavor,
        _type=PluginType.ACTION,
        subtype=stored_action.plugin_subtype,
    )
    # The registered plugin for this flavor/subtype must be an action handler.
    if not isinstance(handler, BaseActionHandler):
        raise ValueError(
            f"Action handler plugin {stored_action.plugin_subtype} "
            f"for flavor {stored_action.flavor} is not a valid action "
            "handler plugin."
        )

    handled_action = handler.get_action(stored_action, hydrate=hydrate)
    return dehydrate_response_model(handled_action)
list_actions(action_filter_model=Depends(init_cls_and_handle_errors), hydrate=False, _=Security(oauth2_authentication))
List actions.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
action_filter_model |
ActionFilter |
Filter model used for pagination, sorting, filtering. |
Depends(init_cls_and_handle_errors) |
hydrate |
bool |
Flag deciding whether to hydrate the output model(s) by including metadata fields in the response. |
False |
Returns:
Type | Description |
---|---|
Page[ActionResponse] |
Page of actions. |
Source code in zenml/zen_server/routers/actions_endpoints.py
@router.get(
    "",
    response_model=Page[ActionResponse],
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def list_actions(
    action_filter_model: ActionFilter = Depends(make_dependable(ActionFilter)),
    hydrate: bool = False,
    _: AuthContext = Security(authorize),
) -> Page[ActionResponse]:
    """List actions, each passed through its flavor's action handler.

    Args:
        action_filter_model: Filter model used for pagination, sorting,
            filtering.
        hydrate: Flag deciding whether to hydrate the output model(s)
            by including metadata fields in the response.

    Returns:
        Page of actions.
    """

    def list_actions_fn(
        filter_model: ActionFilter,
    ) -> Page[ActionResponse]:
        """List actions and replace each item with its handler's version.

        Args:
            filter_model: Filter model used for pagination, sorting,
                filtering.

        Raises:
            ValueError: If the action handler for flavor/type is not valid.

        Returns:
            All actions.
        """
        page = zen_store().list_actions(
            action_filter_model=filter_model, hydrate=hydrate
        )

        # Items are replaced in place, keeping their position on the page.
        for position, listed_action in enumerate(page.items):
            handler = plugin_flavor_registry().get_plugin(
                name=listed_action.flavor,
                _type=PluginType.ACTION,
                subtype=listed_action.plugin_subtype,
            )
            if not isinstance(handler, BaseActionHandler):
                raise ValueError(
                    f"Action handler plugin {listed_action.plugin_subtype} "
                    f"for flavor {listed_action.flavor} is not a valid action "
                    "handler plugin."
                )
            page.items[position] = handler.get_action(
                listed_action, hydrate=hydrate
            )

        return page

    return verify_permissions_and_list_entities(
        filter_model=action_filter_model,
        resource_type=ResourceType.ACTION,
        list_method=list_actions_fn,
    )
update_action(action_id, action_update, _=Security(oauth2_authentication))
Update an action.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
action_id |
UUID |
ID of the action to update. |
required |
action_update |
ActionUpdate |
The action update. |
required |
Exceptions:
Type | Description |
---|---|
ValueError |
If the action handler for flavor/type is not valid. |
Returns:
Type | Description |
---|---|
ActionResponse |
The updated action. |
Source code in zenml/zen_server/routers/actions_endpoints.py
@router.put(
    "/{action_id}",
    response_model=ActionResponse,
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def update_action(
    action_id: UUID,
    action_update: ActionUpdate,
    _: AuthContext = Security(authorize),
) -> ActionResponse:
    """Apply an update to an existing action via its plugin handler.

    Args:
        action_id: ID of the action to update.
        action_update: The action update.

    Raises:
        ValueError: If the action handler for flavor/type is not valid.

    Returns:
        The updated action.
    """
    existing_action = zen_store().get_action(action_id=action_id)
    verify_permission_for_model(existing_action, action=Action.UPDATE)

    # If the update references a service account, the caller must also be
    # allowed to read that service account.
    referenced_account_id = action_update.service_account_id
    if referenced_account_id:
        referenced_account = zen_store().get_service_account(
            service_account_name_or_id=referenced_account_id
        )
        verify_permission_for_model(referenced_account, action=Action.READ)

    handler = plugin_flavor_registry().get_plugin(
        name=existing_action.flavor,
        _type=PluginType.ACTION,
        subtype=existing_action.plugin_subtype,
    )
    if not isinstance(handler, BaseActionHandler):
        raise ValueError(
            f"Action handler plugin {existing_action.plugin_subtype} "
            f"for flavor {existing_action.flavor} is not a valid action "
            "handler plugin."
        )

    return dehydrate_response_model(
        handler.update_action(
            action=existing_action,
            action_update=action_update,
        )
    )
artifact_endpoint
Endpoint definitions for artifacts.
create_artifact(artifact, _=Security(oauth2_authentication))
Create a new artifact.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
artifact |
ArtifactRequest |
The artifact to create. |
required |
Returns:
Type | Description |
---|---|
ArtifactResponse |
The created artifact. |
Source code in zenml/zen_server/routers/artifact_endpoint.py
@artifact_router.post(
    "",
    response_model=ArtifactResponse,
    responses={401: error_response, 409: error_response, 422: error_response},
)
@handle_exceptions
def create_artifact(
    artifact: ArtifactRequest,
    _: AuthContext = Security(authorize),
) -> ArtifactResponse:
    """Create a new artifact.

    Permission checks and creation are delegated to
    `verify_permissions_and_create_entity`.

    Args:
        artifact: The artifact to create.

    Returns:
        The created artifact.
    """
    store = zen_store()
    return verify_permissions_and_create_entity(
        request_model=artifact,
        resource_type=ResourceType.ARTIFACT,
        create_method=store.create_artifact,
    )
delete_artifact(artifact_id, _=Security(oauth2_authentication))
Delete an artifact by ID.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
artifact_id |
UUID |
The ID of the artifact to delete. |
required |
Source code in zenml/zen_server/routers/artifact_endpoint.py
@artifact_router.delete(
    "/{artifact_id}",
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def delete_artifact(
    artifact_id: UUID,
    _: AuthContext = Security(authorize),
) -> None:
    """Delete an artifact by ID.

    The lookup, permission check and deletion are delegated to
    `verify_permissions_and_delete_entity`.

    Args:
        artifact_id: The ID of the artifact to delete.
    """
    store = zen_store()
    verify_permissions_and_delete_entity(
        id=artifact_id,
        get_method=store.get_artifact,
        delete_method=store.delete_artifact,
    )
get_artifact(artifact_id, hydrate=True, _=Security(oauth2_authentication))
Get an artifact by ID.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
artifact_id |
UUID |
The ID of the artifact to get. |
required |
hydrate |
bool |
Flag deciding whether to hydrate the output model(s) by including metadata fields in the response. |
True |
Returns:
Type | Description |
---|---|
ArtifactResponse |
The artifact with the given ID. |
Source code in zenml/zen_server/routers/artifact_endpoint.py
@artifact_router.get(
    "/{artifact_id}",
    response_model=ArtifactResponse,
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def get_artifact(
    artifact_id: UUID,
    hydrate: bool = True,
    _: AuthContext = Security(authorize),
) -> ArtifactResponse:
    """Get an artifact by ID.

    Args:
        artifact_id: The ID of the artifact to get.
        hydrate: Flag deciding whether to hydrate the output model(s)
            by including metadata fields in the response.

    Returns:
        The artifact with the given ID.
    """
    fetch_artifact = zen_store().get_artifact
    return verify_permissions_and_get_entity(
        id=artifact_id,
        get_method=fetch_artifact,
        hydrate=hydrate,
    )
list_artifacts(artifact_filter_model=Depends(init_cls_and_handle_errors), hydrate=False, _=Security(oauth2_authentication))
Get artifacts according to query filters.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
artifact_filter_model |
ArtifactFilter |
Filter model used for pagination, sorting, filtering. |
Depends(init_cls_and_handle_errors) |
hydrate |
bool |
Flag deciding whether to hydrate the output model(s) by including metadata fields in the response. |
False |
Returns:
Type | Description |
---|---|
Page[ArtifactResponse] |
The artifacts according to query filters. |
Source code in zenml/zen_server/routers/artifact_endpoint.py
@artifact_router.get(
    "",
    response_model=Page[ArtifactResponse],
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def list_artifacts(
    artifact_filter_model: ArtifactFilter = Depends(
        make_dependable(ArtifactFilter)
    ),
    hydrate: bool = False,
    _: AuthContext = Security(authorize),
) -> Page[ArtifactResponse]:
    """Get artifacts according to query filters.

    Args:
        artifact_filter_model: Filter model used for pagination, sorting,
            filtering.
        hydrate: Flag deciding whether to hydrate the output model(s)
            by including metadata fields in the response.

    Returns:
        The artifacts according to query filters.
    """
    store = zen_store()
    return verify_permissions_and_list_entities(
        filter_model=artifact_filter_model,
        resource_type=ResourceType.ARTIFACT,
        list_method=store.list_artifacts,
        hydrate=hydrate,
    )
update_artifact(artifact_id, artifact_update, _=Security(oauth2_authentication))
Update an artifact by ID.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
artifact_id |
UUID |
The ID of the artifact to update. |
required |
artifact_update |
ArtifactUpdate |
The update to apply to the artifact. |
required |
Returns:
Type | Description |
---|---|
ArtifactResponse |
The updated artifact. |
Source code in zenml/zen_server/routers/artifact_endpoint.py
@artifact_router.put(
    "/{artifact_id}",
    response_model=ArtifactResponse,
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def update_artifact(
    artifact_id: UUID,
    artifact_update: ArtifactUpdate,
    _: AuthContext = Security(authorize),
) -> ArtifactResponse:
    """Update an artifact by ID.

    Args:
        artifact_id: The ID of the artifact to update.
        artifact_update: The update to apply to the artifact.

    Returns:
        The updated artifact.
    """
    store = zen_store()
    return verify_permissions_and_update_entity(
        id=artifact_id,
        update_model=artifact_update,
        get_method=store.get_artifact,
        update_method=store.update_artifact,
    )
artifact_version_endpoints
Endpoint definitions for artifact versions.
batch_create_artifact_version(artifact_versions, _=Security(oauth2_authentication))
Create a batch of artifact versions.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
artifact_versions |
List[zenml.models.v2.core.artifact_version.ArtifactVersionRequest] |
The artifact versions to create. |
required |
Returns:
Type | Description |
---|---|
List[zenml.models.v2.core.artifact_version.ArtifactVersionResponse] |
The created artifact versions. |
Source code in zenml/zen_server/routers/artifact_version_endpoints.py
@artifact_version_router.post(
    BATCH,
    responses={401: error_response, 409: error_response, 422: error_response},
)
@handle_exceptions
def batch_create_artifact_version(
    artifact_versions: List[ArtifactVersionRequest],
    _: AuthContext = Security(authorize),
) -> List[ArtifactVersionResponse]:
    """Create a batch of artifact versions.

    Args:
        artifact_versions: The artifact versions to create.

    Returns:
        The created artifact versions.
    """
    store = zen_store()
    return verify_permissions_and_batch_create_entity(
        batch=artifact_versions,
        resource_type=ResourceType.ARTIFACT_VERSION,
        create_method=store.batch_create_artifact_versions,
    )
create_artifact_version(artifact_version, _=Security(oauth2_authentication))
Create a new artifact version.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
artifact_version |
ArtifactVersionRequest |
The artifact version to create. |
required |
Returns:
Type | Description |
---|---|
ArtifactVersionResponse |
The created artifact version. |
Source code in zenml/zen_server/routers/artifact_version_endpoints.py
@artifact_version_router.post(
    "",
    response_model=ArtifactVersionResponse,
    responses={401: error_response, 409: error_response, 422: error_response},
)
@handle_exceptions
def create_artifact_version(
    artifact_version: ArtifactVersionRequest,
    _: AuthContext = Security(authorize),
) -> ArtifactVersionResponse:
    """Create a new artifact version.

    Args:
        artifact_version: The artifact version to create.

    Returns:
        The created artifact version.
    """
    store = zen_store()
    return verify_permissions_and_create_entity(
        request_model=artifact_version,
        resource_type=ResourceType.ARTIFACT_VERSION,
        create_method=store.create_artifact_version,
    )
delete_artifact_version(artifact_version_id, _=Security(oauth2_authentication))
Delete an artifact version by ID.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
artifact_version_id |
UUID |
The ID of the artifact version to delete. |
required |
Source code in zenml/zen_server/routers/artifact_version_endpoints.py
@artifact_version_router.delete(
    "/{artifact_version_id}",
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def delete_artifact_version(
    artifact_version_id: UUID,
    _: AuthContext = Security(authorize),
) -> None:
    """Delete an artifact version by ID.

    Args:
        artifact_version_id: The ID of the artifact version to delete.
    """
    store = zen_store()
    verify_permissions_and_delete_entity(
        id=artifact_version_id,
        get_method=store.get_artifact_version,
        delete_method=store.delete_artifact_version,
    )
get_artifact_version(artifact_version_id, hydrate=True, _=Security(oauth2_authentication))
Get an artifact version by ID.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
artifact_version_id |
UUID |
The ID of the artifact version to get. |
required |
hydrate |
bool |
Flag deciding whether to hydrate the output model(s) by including metadata fields in the response. |
True |
Returns:
Type | Description |
---|---|
ArtifactVersionResponse |
The artifact version with the given ID. |
Source code in zenml/zen_server/routers/artifact_version_endpoints.py
@artifact_version_router.get(
    "/{artifact_version_id}",
    response_model=ArtifactVersionResponse,
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def get_artifact_version(
    artifact_version_id: UUID,
    hydrate: bool = True,
    _: AuthContext = Security(authorize),
) -> ArtifactVersionResponse:
    """Get an artifact version by ID.

    Args:
        artifact_version_id: The ID of the artifact version to get.
        hydrate: Flag deciding whether to hydrate the output model(s)
            by including metadata fields in the response.

    Returns:
        The artifact version with the given ID.
    """
    fetch_version = zen_store().get_artifact_version
    return verify_permissions_and_get_entity(
        id=artifact_version_id,
        get_method=fetch_version,
        hydrate=hydrate,
    )
get_artifact_visualization(artifact_version_id, index=0, _=Security(oauth2_authentication))
Get the visualization of an artifact.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
artifact_version_id |
UUID |
ID of the artifact version for which to get the visualization. |
required |
index |
int |
Index of the visualization to get (if there are multiple). |
0 |
Returns:
Type | Description |
---|---|
LoadedVisualization |
The visualization of the artifact version. |
Source code in zenml/zen_server/routers/artifact_version_endpoints.py
@artifact_version_router.get(
    "/{artifact_version_id}" + VISUALIZE,
    response_model=LoadedVisualization,
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def get_artifact_visualization(
    artifact_version_id: UUID,
    index: int = 0,
    _: AuthContext = Security(authorize),
) -> LoadedVisualization:
    """Get the visualization of an artifact.

    Args:
        artifact_version_id: ID of the artifact version for which to get
            the visualization.
        index: Index of the visualization to get (if there are multiple).

    Returns:
        The visualization of the artifact version.
    """
    store = zen_store()
    artifact_version = verify_permissions_and_get_entity(
        id=artifact_version_id,
        get_method=store.get_artifact_version,
    )
    # Images are requested encoded so they can be embedded in the response.
    return load_artifact_visualization(
        artifact=artifact_version,
        index=index,
        zen_store=store,
        encode_image=True,
    )
list_artifact_versions(artifact_version_filter_model=Depends(init_cls_and_handle_errors), hydrate=False, auth_context=Security(oauth2_authentication))
Get artifact versions according to query filters.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
artifact_version_filter_model |
ArtifactVersionFilter |
Filter model used for pagination, sorting, filtering. |
Depends(init_cls_and_handle_errors) |
hydrate |
bool |
Flag deciding whether to hydrate the output model(s) by including metadata fields in the response. |
False |
auth_context |
AuthContext |
The authentication context. |
Security(oauth2_authentication) |
Returns:
Type | Description |
---|---|
Page[ArtifactVersionResponse] |
The artifact versions according to query filters. |
Source code in zenml/zen_server/routers/artifact_version_endpoints.py
@artifact_version_router.get(
    "",
    response_model=Page[ArtifactVersionResponse],
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def list_artifact_versions(
    artifact_version_filter_model: ArtifactVersionFilter = Depends(
        make_dependable(ArtifactVersionFilter)
    ),
    hydrate: bool = False,
    auth_context: AuthContext = Security(authorize),
) -> Page[ArtifactVersionResponse]:
    """Get artifact versions according to query filters.

    The filter is first scoped (via `configure_rbac`) to the artifact IDs
    returned by `get_allowed_resource_ids` and to the authenticated user.

    Args:
        artifact_version_filter_model: Filter model used for pagination,
            sorting, filtering.
        hydrate: Flag deciding whether to hydrate the output model(s)
            by including metadata fields in the response.
        auth_context: The authentication context.

    Returns:
        The artifact versions according to query filters.
    """
    permitted_artifact_ids = get_allowed_resource_ids(
        resource_type=ResourceType.ARTIFACT
    )
    current_user_id = auth_context.user.id
    artifact_version_filter_model.configure_rbac(
        authenticated_user_id=current_user_id,
        artifact_id=permitted_artifact_ids,
    )

    page = zen_store().list_artifact_versions(
        artifact_version_filter_model=artifact_version_filter_model,
        hydrate=hydrate,
    )
    return dehydrate_page(page)
prune_artifact_versions(only_versions=True, _=Security(oauth2_authentication))
Prunes unused artifact versions and their artifacts.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
only_versions |
bool |
Only delete artifact versions, keeping artifacts. |
True |
Source code in zenml/zen_server/routers/artifact_version_endpoints.py
@artifact_version_router.delete(
    "",
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def prune_artifact_versions(
    only_versions: bool = True,
    _: AuthContext = Security(authorize),
) -> None:
    """Prunes unused artifact versions and their artifacts.

    Args:
        only_versions: Only delete artifact versions, keeping artifacts.
    """
    store = zen_store()
    verify_permissions_and_prune_entities(
        resource_type=ResourceType.ARTIFACT_VERSION,
        prune_method=store.prune_artifact_versions,
        only_versions=only_versions,
    )
update_artifact_version(artifact_version_id, artifact_version_update, _=Security(oauth2_authentication))
Update an artifact version by ID.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
artifact_version_id |
UUID |
The ID of the artifact version to update. |
required |
artifact_version_update |
ArtifactVersionUpdate |
The update to apply to the artifact version. |
required |
Returns:
Type | Description |
---|---|
ArtifactVersionResponse |
The updated artifact version. |
Source code in zenml/zen_server/routers/artifact_version_endpoints.py
@artifact_version_router.put(
    "/{artifact_version_id}",
    response_model=ArtifactVersionResponse,
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def update_artifact_version(
    artifact_version_id: UUID,
    artifact_version_update: ArtifactVersionUpdate,
    _: AuthContext = Security(authorize),
) -> ArtifactVersionResponse:
    """Update an artifact version by ID.

    Fetching, permission checks and the update itself are delegated to
    `verify_permissions_and_update_entity`.

    Args:
        artifact_version_id: The ID of the artifact version to update.
        artifact_version_update: The update to apply to the artifact version.

    Returns:
        The updated artifact version.
    """
    return verify_permissions_and_update_entity(
        id=artifact_version_id,
        update_model=artifact_version_update,
        get_method=zen_store().get_artifact_version,
        update_method=zen_store().update_artifact_version,
    )
auth_endpoints
Endpoint definitions for authentication (login).
OAuthLoginRequestForm
OAuth2 grant type request form.
This form allows multiple grant types to be used with the same endpoint:
- standard OAuth2 password grant type
- standard OAuth2 device authorization grant type
- ZenML service account + API key grant type (proprietary)
- ZenML External Authenticator grant type (proprietary)
Source code in zenml/zen_server/routers/auth_endpoints.py
class OAuthLoginRequestForm:
    """OAuth2 grant type request form.

    This form allows multiple grant types to be used with the same endpoint:
    * standard OAuth2 password grant type
    * standard OAuth2 device authorization grant type
    * ZenML service account + API key grant type (proprietary)
    * ZenML External Authenticator grant type (proprietary)

    `grant_type` is always set on the instance. The other attributes are only
    set for the grant type that uses them:
    * password grant: `username`, `password`
    * device authorization grant: `client_id`, `device_code`
    * API key grant: `api_key`
    """

    def __init__(
        self,
        grant_type: Optional[str] = Form(None),
        username: Optional[str] = Form(None),
        password: Optional[str] = Form(None),
        client_id: Optional[str] = Form(None),
        device_code: Optional[str] = Form(None),
    ) -> None:
        """Initializes the form.

        Args:
            grant_type: The grant type. If not set, it is inferred from the
                other form fields and the server's configured auth scheme.
            username: The username. Only used for the password grant type.
            password: The password for the password grant type, or the API
                key for the ZenML API key grant type.
            client_id: The client ID. Only used for the device authorization
                grant type, where it must be a valid UUID.
            device_code: The device code. Only used for the device
                authorization grant type.

        Raises:
            HTTPException: If the request is invalid.
        """
        config = server_config()
        # Auth schemes under which the OAuth2 password grant type is
        # accepted. The no-auth and basic HTTP auth schemes also allow it for
        # backwards compatibility.
        password_grant_schemes = (
            AuthScheme.OAUTH2_PASSWORD_BEARER,
            AuthScheme.NO_AUTH,
            AuthScheme.HTTP_BASIC,
        )

        if not grant_type:
            # Detect the grant type from the form data. The order matters: a
            # username (even an empty one) takes precedence over a password,
            # which takes precedence over a device code.
            if username is not None:
                self.grant_type = OAuthGrantTypes.OAUTH_PASSWORD
            elif password:
                self.grant_type = OAuthGrantTypes.ZENML_API_KEY
            elif device_code:
                self.grant_type = OAuthGrantTypes.OAUTH_DEVICE_CODE
            elif config.auth_scheme == AuthScheme.EXTERNAL:
                self.grant_type = OAuthGrantTypes.ZENML_EXTERNAL
            elif config.auth_scheme in password_grant_schemes:
                # No username was supplied here, so the username check below
                # rejects the request with a descriptive error.
                self.grant_type = OAuthGrantTypes.OAUTH_PASSWORD
            else:
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail="Invalid request: grant type is required.",
                )
        else:
            if grant_type not in OAuthGrantTypes.values():
                logger.info(
                    f"Request with unsupported grant type: {grant_type}"
                )
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail=f"Unsupported grant type: {grant_type}",
                )
            self.grant_type = OAuthGrantTypes(grant_type)

        if self.grant_type == OAuthGrantTypes.OAUTH_PASSWORD:
            if config.auth_scheme not in password_grant_schemes:
                logger.info(
                    f"Request with unsupported grant type: {self.grant_type}"
                )
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail=f"Unsupported grant type: {self.grant_type}.",
                )
            if not username:
                logger.info("Request with missing username")
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail="Invalid request: username is required.",
                )
            self.username = username
            self.password = password or ""
        elif self.grant_type == OAuthGrantTypes.OAUTH_DEVICE_CODE:
            if not device_code or not client_id:
                logger.info("Request with missing device code or client ID")
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail="Invalid request: device code and client ID are "
                    "required.",
                )
            try:
                self.client_id = UUID(client_id)
            except ValueError:
                logger.info(f"Request with invalid client ID: {client_id}")
                # The malformed client ID is a client error; the parsing
                # traceback adds nothing, so suppress the implicit context.
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail="Invalid request: invalid client ID.",
                ) from None
            self.device_code = device_code
        elif self.grant_type == OAuthGrantTypes.ZENML_API_KEY:
            # For this grant type the API key travels in the password field.
            if not password:
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail="API key is required.",
                )
            self.api_key = password
        elif self.grant_type == OAuthGrantTypes.ZENML_EXTERNAL:
            if config.auth_scheme != AuthScheme.EXTERNAL:
                logger.info(
                    f"Request with unsupported grant type: {self.grant_type}"
                )
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail=f"Unsupported grant type: {self.grant_type}.",
                )
__init__(self, grant_type=Form(None), username=Form(None), password=Form(None), client_id=Form(None), device_code=Form(None))
special
Initializes the form.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
grant_type |
Optional[str] |
The grant type. |
Form(None) |
username |
Optional[str] |
The username. Only used for the password grant type. |
Form(None) |
password |
Optional[str] |
The password for the password grant type, or the API key for the ZenML API key grant type. |
Form(None) |
client_id |
Optional[str] |
The client ID. Only used for the device authorization grant type, where it must be a valid UUID. |
Form(None) |
device_code |
Optional[str] |
The device code. Only used for the device authorization grant type. |
Form(None) |
Exceptions:
Type | Description |
---|---|
HTTPException |
If the request is invalid. |
Source code in zenml/zen_server/routers/auth_endpoints.py
def __init__(
    self,
    grant_type: Optional[str] = Form(None),
    username: Optional[str] = Form(None),
    password: Optional[str] = Form(None),
    client_id: Optional[str] = Form(None),
    device_code: Optional[str] = Form(None),
) -> None:
    """Initializes the form.

    Args:
        grant_type: The grant type. If not set, it is inferred from the
            other form fields and the server's configured auth scheme.
        username: The username. Only used for the password grant type.
        password: The password for the password grant type, or the API
            key for the ZenML API key grant type.
        client_id: The client ID. Only used for the device authorization
            grant type, where it must be a valid UUID.
        device_code: The device code. Only used for the device
            authorization grant type.

    Raises:
        HTTPException: If the request is invalid.
    """
    config = server_config()
    # Auth schemes under which the OAuth2 password grant type is accepted.
    # The no-auth and basic HTTP auth schemes also allow it for backwards
    # compatibility.
    password_grant_schemes = (
        AuthScheme.OAUTH2_PASSWORD_BEARER,
        AuthScheme.NO_AUTH,
        AuthScheme.HTTP_BASIC,
    )

    if not grant_type:
        # Detect the grant type from the form data. The order matters: a
        # username (even an empty one) takes precedence over a password,
        # which takes precedence over a device code.
        if username is not None:
            self.grant_type = OAuthGrantTypes.OAUTH_PASSWORD
        elif password:
            self.grant_type = OAuthGrantTypes.ZENML_API_KEY
        elif device_code:
            self.grant_type = OAuthGrantTypes.OAUTH_DEVICE_CODE
        elif config.auth_scheme == AuthScheme.EXTERNAL:
            self.grant_type = OAuthGrantTypes.ZENML_EXTERNAL
        elif config.auth_scheme in password_grant_schemes:
            # No username was supplied here, so the username check below
            # rejects the request with a descriptive error.
            self.grant_type = OAuthGrantTypes.OAUTH_PASSWORD
        else:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Invalid request: grant type is required.",
            )
    else:
        if grant_type not in OAuthGrantTypes.values():
            logger.info(
                f"Request with unsupported grant type: {grant_type}"
            )
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"Unsupported grant type: {grant_type}",
            )
        self.grant_type = OAuthGrantTypes(grant_type)

    if self.grant_type == OAuthGrantTypes.OAUTH_PASSWORD:
        if config.auth_scheme not in password_grant_schemes:
            logger.info(
                f"Request with unsupported grant type: {self.grant_type}"
            )
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"Unsupported grant type: {self.grant_type}.",
            )
        if not username:
            logger.info("Request with missing username")
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Invalid request: username is required.",
            )
        self.username = username
        self.password = password or ""
    elif self.grant_type == OAuthGrantTypes.OAUTH_DEVICE_CODE:
        if not device_code or not client_id:
            logger.info("Request with missing device code or client ID")
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Invalid request: device code and client ID are "
                "required.",
            )
        try:
            self.client_id = UUID(client_id)
        except ValueError:
            logger.info(f"Request with invalid client ID: {client_id}")
            # The malformed client ID is a client error; the parsing
            # traceback adds nothing, so suppress the implicit context.
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Invalid request: invalid client ID.",
            ) from None
        self.device_code = device_code
    elif self.grant_type == OAuthGrantTypes.ZENML_API_KEY:
        # For this grant type the API key travels in the password field.
        if not password:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="API key is required.",
            )
        self.api_key = password
    elif self.grant_type == OAuthGrantTypes.ZENML_EXTERNAL:
        if config.auth_scheme != AuthScheme.EXTERNAL:
            logger.info(
                f"Request with unsupported grant type: {self.grant_type}"
            )
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"Unsupported grant type: {self.grant_type}.",
            )
api_token(token_type=<APITokenType.GENERIC: 'generic'>, expires_in=None, schedule_id=None, pipeline_run_id=None, step_run_id=None, auth_context=Security(oauth2_authentication))
Generate an API token for the current user.
Use this endpoint to generate an API token for the current user. Two types of API tokens are supported:
- Generic API token: This token is short-lived and can be used for generic automation tasks. The expiration can be set by the user, but the server will impose a maximum expiration time.
- Workload API token: This token is scoped to a specific pipeline run, step run or schedule and is used by pipeline workloads to authenticate with the server. A pipeline run ID, step run ID or schedule ID must be provided and the generated token will only be valid for the indicated pipeline run, step run or schedule. No time limit is imposed on the validity of the token. A workload API token can be used to authenticate and generate another workload API token, but only for the same schedule, pipeline run ID or step run ID, in that order.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
token_type |
APITokenType |
The type of API token to generate. |
<APITokenType.GENERIC: 'generic'> |
expires_in |
Optional[int] |
The expiration time of the generic API token in seconds. If not set, the server will use the default expiration time for generic API tokens. The server also imposes a maximum expiration time. |
None |
schedule_id |
Optional[uuid.UUID] |
The ID of the schedule to scope the workload API token to. |
None |
pipeline_run_id |
Optional[uuid.UUID] |
The ID of the pipeline run to scope the workload API token to. |
None |
step_run_id |
Optional[uuid.UUID] |
The ID of the step run to scope the workload API token to. |
None |
auth_context |
AuthContext |
The authentication context. |
Security(oauth2_authentication) |
Returns:
Type | Description |
---|---|
str |
The API token. |
Exceptions:
Type | Description |
---|---|
AuthorizationException |
If not authorized to generate the API token. |
ValueError |
If the request is invalid. |
Source code in zenml/zen_server/routers/auth_endpoints.py
@router.get(
    API_TOKEN,
    response_model=str,
)
@handle_exceptions
def api_token(
    token_type: APITokenType = APITokenType.GENERIC,
    expires_in: Optional[int] = None,
    schedule_id: Optional[UUID] = None,
    pipeline_run_id: Optional[UUID] = None,
    step_run_id: Optional[UUID] = None,
    auth_context: AuthContext = Security(authorize),
) -> str:
    """Generate an API token for the current user.

    Use this endpoint to generate an API token for the current user. Two types
    of API tokens are supported:

    * Generic API token: This token is short-lived and can be used for
      generic automation tasks. The expiration can be set by the user, but
      the server will impose a maximum expiration time.
    * Workload API token: This token is scoped to a specific pipeline run,
      step run or schedule and is used by pipeline workloads to authenticate
      with the server. A pipeline run ID, step run ID or schedule ID must be
      provided and the generated token will only be valid for the indicated
      pipeline run, step run or schedule. No time limit is imposed on the
      validity of the token. A workload API token can itself be used to
      generate another workload API token, but only for the same schedule,
      pipeline run and step run it is already scoped to. Any scope that is
      not explicitly requested is inherited from the authenticating token.

    Args:
        token_type: The type of API token to generate.
        expires_in: The expiration time of the generic API token in seconds.
            If not set (or 0), the server will use the default expiration
            time for generic API tokens. Negative values are rejected, and
            the server also imposes a maximum expiration time.
        schedule_id: The ID of the schedule to scope the workload API token to.
        pipeline_run_id: The ID of the pipeline run to scope the workload API
            token to.
        step_run_id: The ID of the step run to scope the workload API token to.
        auth_context: The authentication context.

    Returns:
        The API token.

    Raises:
        AuthorizationException: If not authorized to generate the API token.
        ValueError: If the request is invalid.
    """
    token = auth_context.access_token
    if not token or not auth_context.encoded_access_token:
        # Should not happen
        raise AuthorizationException("Not authenticated.")

    if token_type == APITokenType.GENERIC:
        if schedule_id or pipeline_run_id or step_run_id:
            raise ValueError(
                "Generic API tokens cannot be scoped to a schedule, pipeline "
                "run or step run."
            )

        config = server_config()

        # 0/None both fall back to the server default, so a caller can never
        # request a non-expiring generic token (0 means "never expire" for
        # generate_access_token, see the workload branch below).
        if not expires_in:
            expires_in = config.generic_api_token_lifetime

        # Negative values would slip past the maximum-lifetime check below
        # and be forwarded unchecked to generate_access_token.
        if expires_in < 0:
            raise ValueError(
                "The expiration time for generic API tokens must be a "
                "positive number of seconds."
            )

        if expires_in > config.generic_api_token_max_lifetime:
            raise ValueError(
                f"The maximum expiration time for generic API tokens allowed "
                f"by this server is {config.generic_api_token_max_lifetime} "
                "seconds."
            )

        return generate_access_token(
            user_id=token.user_id,
            expires_in=expires_in,
        ).access_token

    # Workload tokens: the caller must be allowed to create pipeline runs.
    verify_permission(
        resource_type=ResourceType.PIPELINE_RUN, action=Action.CREATE
    )

    # Inherit any scope not explicitly requested from the current token.
    schedule_id = schedule_id or token.schedule_id
    pipeline_run_id = pipeline_run_id or token.pipeline_run_id
    step_run_id = step_run_id or token.step_run_id

    if not pipeline_run_id and not schedule_id and not step_run_id:
        raise ValueError(
            "Workload API tokens must be scoped to a schedule, pipeline run "
            "or step run."
        )

    # A scoped token may only mint tokens for the very same scope.
    if schedule_id and token.schedule_id and schedule_id != token.schedule_id:
        raise AuthorizationException(
            f"Unable to scope API token to schedule {schedule_id}. The "
            f"token used to authorize this request is already scoped to "
            f"schedule {token.schedule_id}."
        )

    if (
        pipeline_run_id
        and token.pipeline_run_id
        and pipeline_run_id != token.pipeline_run_id
    ):
        raise AuthorizationException(
            f"Unable to scope API token to pipeline run {pipeline_run_id}. The "
            f"token used to authorize this request is already scoped to "
            f"pipeline run {token.pipeline_run_id}."
        )

    if step_run_id and token.step_run_id and step_run_id != token.step_run_id:
        raise AuthorizationException(
            f"Unable to scope API token to step run {step_run_id}. The "
            f"token used to authorize this request is already scoped to "
            f"step run {token.step_run_id}."
        )

    if schedule_id:
        # The schedule must exist
        try:
            schedule = zen_store().get_schedule(schedule_id, hydrate=False)
        except KeyError:
            raise ValueError(
                f"Schedule {schedule_id} does not exist and API tokens cannot "
                "be generated for non-existent schedules for security reasons."
            )

        if not schedule.active:
            raise ValueError(
                f"Schedule {schedule_id} is not active and API tokens cannot "
                "be generated for inactive schedules for security reasons."
            )

    if pipeline_run_id:
        # The pipeline run must exist and the run must not be concluded
        try:
            pipeline_run = zen_store().get_run(pipeline_run_id, hydrate=False)
        except KeyError:
            raise ValueError(
                f"Pipeline run {pipeline_run_id} does not exist and API tokens "
                "cannot be generated for non-existent pipeline runs for "
                "security reasons."
            )

        if pipeline_run.status.is_finished:
            raise ValueError(
                f"The execution of pipeline run {pipeline_run_id} has already "
                "concluded and API tokens can no longer be generated for it "
                "for security reasons."
            )

    if step_run_id:
        # The step run must exist and the step must not be concluded
        try:
            step_run = zen_store().get_run_step(step_run_id, hydrate=False)
        except KeyError:
            raise ValueError(
                f"Step run {step_run_id} does not exist and API tokens cannot "
                "be generated for non-existent step runs for security reasons."
            )

        if step_run.status.is_finished:
            raise ValueError(
                f"The execution of step run {step_run_id} has already "
                "concluded and API tokens can no longer be generated for it "
                "for security reasons."
            )

    return generate_access_token(
        user_id=token.user_id,
        # Keep the original API key and device token scopes
        api_key=auth_context.api_key,
        device=auth_context.device,
        schedule_id=schedule_id,
        pipeline_run_id=pipeline_run_id,
        step_run_id=step_run_id,
        # Never expire the token
        expires_in=0,
    ).access_token
device_authorization(request, client_id=Form(PydanticUndefined))
OAuth2 device authorization endpoint.
This endpoint implements the OAuth2 device authorization grant flow as defined in https://tools.ietf.org/html/rfc8628. It is called to initiate the device authorization flow by requesting a device and user code for a given client ID.
For a new client ID, a new OAuth device is created, stored in the DB and returned to the client along with a pair of newly generated device and user codes. If a device for the given client ID already exists, the existing DB entry is reused and new device and user codes are generated.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
request |
Request |
The request object. |
required |
client_id |
UUID |
The client ID. |
Form(PydanticUndefined) |
Returns:
Type | Description |
---|---|
OAuthDeviceAuthorizationResponse |
The device authorization response. |
Source code in zenml/zen_server/routers/auth_endpoints.py
@router.post(
    DEVICE_AUTHORIZATION,
    response_model=OAuthDeviceAuthorizationResponse,
)
def device_authorization(
    request: Request,
    client_id: UUID = Form(...),
) -> OAuthDeviceAuthorizationResponse:
    """OAuth2 device authorization endpoint.

    This endpoint implements the OAuth2 device authorization grant flow as
    defined in https://tools.ietf.org/html/rfc8628. It is called to initiate
    the device authorization flow by requesting a device and user code for a
    given client ID.

    For a new client ID, a new OAuth device is created, stored in the DB and
    returned to the client along with a pair of newly generated device and user
    codes. If a device for the given client ID already exists, the existing
    DB entry is reused and new device and user codes are generated.

    Args:
        request: The request object.
        client_id: The client ID.

    Returns:
        The device authorization response.
    """
    config = server_config()
    store = zen_store()

    try:
        # Use this opportunity to delete expired devices
        store.delete_expired_authorized_devices()
    except Exception:
        # Best-effort cleanup: a failure here must not block authorization.
        logger.exception("Failed to delete expired devices")

    # Fetch additional details about the client from the user-agent header
    user_agent_header = request.headers.get("User-Agent")
    if user_agent_header:
        device_details = OAuthDeviceUserAgentHeader.decode(user_agent_header)
    else:
        device_details = OAuthDeviceUserAgentHeader()

    # Fetch the IP address of the client. The first X-Forwarded-For entry is
    # preferred over the direct peer address.
    # NOTE(review): X-Forwarded-For is not validated here, so the address and
    # derived location should be treated as informational only.
    ip_address: str = ""
    city, region, country = "", "", ""
    forwarded = request.headers.get("X-Forwarded-For")
    if forwarded:
        ip_address = forwarded.split(",")[0].strip()
    elif request.client and request.client.host:
        ip_address = request.client.host

    if ip_address:
        city, region, country = get_ip_location(ip_address)

    # Check if a device is already registered for the same client ID.
    try:
        device_model = store.get_internal_authorized_device(
            client_id=client_id
        )
    except KeyError:
        device_model = store.create_authorized_device(
            OAuthDeviceInternalRequest(
                client_id=client_id,
                expires_in=config.device_auth_timeout,
                ip_address=ip_address,
                city=city,
                region=region,
                country=country,
                **device_details.model_dump(exclude_none=True),
            )
        )
    else:
        # Put the device into pending state and generate new codes. This
        # effectively invalidates the old codes and the device cannot be used
        # for authentication anymore.
        device_model = store.update_internal_authorized_device(
            device_id=device_model.id,
            update=OAuthDeviceInternalUpdate(
                trusted_device=False,
                expires_in=config.device_auth_timeout,
                status=OAuthDeviceStatus.PENDING,
                failed_auth_attempts=0,
                generate_new_codes=True,
                ip_address=ip_address,
                city=city,
                region=region,
                country=country,
                **device_details.model_dump(exclude_none=True),
            ),
        )

    dashboard_url = config.dashboard_url or config.server_url
    if dashboard_url:
        # The path constants begin with "/", so drop any *trailing* slash from
        # the base URL to avoid producing "...//devices/verify". (The previous
        # lstrip("/") was a no-op for absolute URLs.)
        verification_uri = dashboard_url.rstrip("/") + DEVICES + DEVICE_VERIFY
    else:
        verification_uri = DEVICES + DEVICE_VERIFY

    verification_uri_complete = (
        verification_uri
        + "?"
        + urlencode(
            dict(
                device_id=str(device_model.id),
                user_code=str(device_model.user_code),
            )
        )
    )
    return OAuthDeviceAuthorizationResponse(
        device_code=device_model.device_code,
        user_code=device_model.user_code,
        expires_in=config.device_auth_timeout,
        interval=config.device_auth_polling_interval,
        verification_uri=verification_uri,
        verification_uri_complete=verification_uri_complete,
    )
logout(response)
Logs out the user.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
response |
Response |
The response object. |
required |
Source code in zenml/zen_server/routers/auth_endpoints.py
@router.get(
    LOGOUT,
)
def logout(
    response: Response,
) -> None:
    """Log the current user out by clearing the authentication cookie.

    Args:
        response: The response object on which the cookie removal is set.
    """
    config = server_config()
    cookie_name = config.get_auth_cookie_name()
    cookie_domain = config.auth_cookie_domain

    # Clear the HTTP-only auth cookie whether or not the client currently
    # holds one.
    response.delete_cookie(
        key=cookie_name,
        httponly=True,
        samesite="lax",
        domain=cookie_domain,
    )
token(request, response, auth_form_data=Depends(OAuthLoginRequestForm))
OAuth2 token endpoint.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
request |
Request |
The request object. |
required |
response |
Response |
The response object. |
required |
auth_form_data |
OAuthLoginRequestForm |
The OAuth 2.0 authentication form data. |
Depends(OAuthLoginRequestForm) |
Returns:
Type | Description |
---|---|
Union[zenml.models.v2.misc.auth_models.OAuthTokenResponse, zenml.models.v2.misc.auth_models.OAuthRedirectResponse] |
An access token or a redirect response. |
Exceptions:
Type | Description |
---|---|
ValueError |
If the grant type is invalid. |
Source code in zenml/zen_server/routers/auth_endpoints.py
def _extract_external_access_token(
    request: Request, cookie_name: str
) -> Optional[str]:
    """Find the external access token attached to a request, if any.

    The external authentication cookie is checked first. Only when it is
    missing or empty is a ``Bearer`` token read from the ``Authorization``
    header.

    Args:
        request: The incoming request.
        cookie_name: Name of the cookie carrying the external access token.

    Returns:
        The external access token, or None if the request carries none.
    """
    cookie_value = request.cookies.get(cookie_name)
    if cookie_value:
        logger.info("External access token found in cookie.")
        return cookie_value

    header_value = request.headers.get("Authorization")
    if header_value:
        scheme, _, credentials = header_value.partition(" ")
        if credentials and scheme.lower() == "bearer":
            logger.info(
                "External access token found in authorization header."
            )
            return credentials

    return None


@router.post(
    LOGIN,
    response_model=Union[OAuthTokenResponse, OAuthRedirectResponse],
)
@rate_limit_requests(
    day_limit=server_config().login_rate_limit_day,
    minute_limit=server_config().login_rate_limit_minute,
)
@handle_exceptions
def token(
    request: Request,
    response: Response,
    auth_form_data: OAuthLoginRequestForm = Depends(),
) -> Union[OAuthTokenResponse, OAuthRedirectResponse]:
    """OAuth2 token endpoint.

    Authenticates the caller according to the requested grant type and issues
    an access token. For the external grant, a redirect to the external login
    URL is returned when the request carries no external access token.

    Args:
        request: The request object.
        response: The response object.
        auth_form_data: The OAuth 2.0 authentication form data.

    Returns:
        An access token or a redirect response.

    Raises:
        ValueError: If the grant type is invalid.
    """
    config = server_config()
    grant_type = auth_form_data.grant_type

    if grant_type == OAuthGrantTypes.OAUTH_PASSWORD:
        auth_context = authenticate_credentials(
            user_name_or_id=auth_form_data.username,
            password=auth_form_data.password,
        )
    elif grant_type == OAuthGrantTypes.OAUTH_DEVICE_CODE:
        auth_context = authenticate_device(
            client_id=auth_form_data.client_id,
            device_code=auth_form_data.device_code,
        )
    elif grant_type == OAuthGrantTypes.ZENML_API_KEY:
        auth_context = authenticate_api_key(
            api_key=auth_form_data.api_key,
        )
    elif grant_type == OAuthGrantTypes.ZENML_EXTERNAL:
        assert config.external_cookie_name is not None
        assert config.external_login_url is not None

        external_access_token = _extract_external_access_token(
            request, config.external_cookie_name
        )
        if not external_access_token:
            logger.info(
                "External access token not found. Redirecting to "
                "external authenticator."
            )
            # Send the user to the external authentication login endpoint.
            return OAuthRedirectResponse(
                authorization_url=config.external_login_url
            )

        auth_context = authenticate_external_user(
            external_access_token=external_access_token,
            request=request,
        )
    else:
        # Shouldn't happen, because we verify all grants in the form data
        raise ValueError("Invalid grant type.")

    return generate_access_token(
        user_id=auth_context.user.id,
        response=response,
        device=auth_context.device,
        api_key=auth_context.api_key,
    )
code_repositories_endpoints
Endpoint definitions for code repositories.
delete_code_repository(code_repository_id, _=Security(oauth2_authentication))
Deletes a specific code repository.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
code_repository_id |
UUID |
The ID of the code repository to delete. |
required |
Source code in zenml/zen_server/routers/code_repositories_endpoints.py
@router.delete(
    "/{code_repository_id}",
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def delete_code_repository(
    code_repository_id: UUID,
    _: AuthContext = Security(authorize),
) -> None:
    """Delete a single code repository, identified by its ID.

    The lookup and removal are delegated to
    ``verify_permissions_and_delete_entity``.

    Args:
        code_repository_id: ID of the code repository to remove.
    """
    store = zen_store()
    verify_permissions_and_delete_entity(
        id=code_repository_id,
        get_method=store.get_code_repository,
        delete_method=store.delete_code_repository,
    )
get_code_repository(code_repository_id, hydrate=True, _=Security(oauth2_authentication))
Gets a specific code repository using its unique ID.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
code_repository_id |
UUID |
The ID of the code repository to get. |
required |
hydrate |
bool |
Flag deciding whether to hydrate the output model(s) by including metadata fields in the response. |
True |
Returns:
Type | Description |
---|---|
CodeRepositoryResponse |
A specific code repository object. |
Source code in zenml/zen_server/routers/code_repositories_endpoints.py
@router.get(
    "/{code_repository_id}",
    response_model=CodeRepositoryResponse,
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def get_code_repository(
    code_repository_id: UUID,
    hydrate: bool = True,
    _: AuthContext = Security(authorize),
) -> CodeRepositoryResponse:
    """Fetch a single code repository by its unique ID.

    Args:
        code_repository_id: ID of the code repository to fetch.
        hydrate: Whether to hydrate the returned model, i.e. include its
            metadata fields in the response.

    Returns:
        The requested code repository.
    """
    getter = zen_store().get_code_repository
    return verify_permissions_and_get_entity(
        id=code_repository_id, get_method=getter, hydrate=hydrate
    )
list_code_repositories(filter_model=Depends(init_cls_and_handle_errors), hydrate=False, _=Security(oauth2_authentication))
Gets a page of code repositories.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
filter_model |
CodeRepositoryFilter |
Filter model used for pagination, sorting, filtering. |
Depends(init_cls_and_handle_errors) |
hydrate |
bool |
Flag deciding whether to hydrate the output model(s) by including metadata fields in the response. |
False |
Returns:
Type | Description |
---|---|
Page[CodeRepositoryResponse] |
Page of code repository objects. |
Source code in zenml/zen_server/routers/code_repositories_endpoints.py
@router.get(
    "",
    response_model=Page[CodeRepositoryResponse],
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def list_code_repositories(
    filter_model: CodeRepositoryFilter = Depends(
        make_dependable(CodeRepositoryFilter)
    ),
    hydrate: bool = False,
    _: AuthContext = Security(authorize),
) -> Page[CodeRepositoryResponse]:
    """Return one page of code repositories.

    Args:
        filter_model: Pagination, sorting and filtering options.
        hydrate: Whether to hydrate the returned models, i.e. include their
            metadata fields in the response.

    Returns:
        A page of code repositories.
    """
    store = zen_store()
    page = verify_permissions_and_list_entities(
        filter_model=filter_model,
        resource_type=ResourceType.CODE_REPOSITORY,
        list_method=store.list_code_repositories,
        hydrate=hydrate,
    )
    return page
update_code_repository(code_repository_id, update, _=Security(oauth2_authentication))
Updates a code repository.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
code_repository_id |
UUID |
The ID of the code repository to update. |
required |
update |
CodeRepositoryUpdate |
The model containing the attributes to update. |
required |
Returns:
Type | Description |
---|---|
CodeRepositoryResponse |
The updated code repository object. |
Source code in zenml/zen_server/routers/code_repositories_endpoints.py
@router.put(
    "/{code_repository_id}",
    response_model=CodeRepositoryResponse,
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def update_code_repository(
    code_repository_id: UUID,
    update: CodeRepositoryUpdate,
    _: AuthContext = Security(authorize),
) -> CodeRepositoryResponse:
    """Apply an update to an existing code repository.

    Args:
        code_repository_id: ID of the code repository to modify.
        update: Model holding the attributes to change.

    Returns:
        The code repository after the update was applied.
    """
    store = zen_store()
    return verify_permissions_and_update_entity(
        id=code_repository_id,
        update_model=update,
        get_method=store.get_code_repository,
        update_method=store.update_code_repository,
    )
devices_endpoints
Endpoint definitions for OAuth2 authorized devices.
delete_authorized_device(device_id, auth_context=Security(oauth2_authentication))
Deletes a specific OAuth2 authorized device using its unique ID.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
device_id |
UUID |
The ID of the OAuth2 authorized device to delete. |
required |
auth_context |
AuthContext |
The current auth context. |
Security(oauth2_authentication) |
Exceptions:
Type | Description |
---|---|
KeyError |
If the device with the given ID does not exist or does not belong to the current user. |
Source code in zenml/zen_server/routers/devices_endpoints.py
@router.delete(
    "/{device_id}",
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def delete_authorized_device(
    device_id: UUID,
    auth_context: AuthContext = Security(authorize),
) -> None:
    """Delete one of the current user's OAuth2 authorized devices.

    Args:
        device_id: ID of the OAuth2 authorized device to remove.
        auth_context: The current auth context.

    Raises:
        KeyError: If the device with the given ID does not exist or does not
            belong to the current user.
    """
    store = zen_store()
    owner = store.get_authorized_device(device_id=device_id).user

    # Devices without an owner and devices owned by someone else get the same
    # "not found" error, presumably so callers cannot tell them apart.
    is_own_device = bool(owner) and owner.id == auth_context.user.id
    if not is_own_device:
        raise KeyError(
            f"Unable to get device with ID {device_id}: No device with "
            "this ID found."
        )

    store.delete_authorized_device(device_id=device_id)
get_authorization_device(device_id, user_code=None, hydrate=True, auth_context=Security(oauth2_authentication))
Gets a specific OAuth2 authorized device using its unique ID.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
device_id |
UUID |
The ID of the OAuth2 authorized device to get. |
required |
user_code |
Optional[str] |
The user code of the OAuth2 authorized device to get. Needs to be specified with devices that have not been verified yet. |
None |
hydrate |
bool |
Flag deciding whether to hydrate the output model(s) by including metadata fields in the response. |
True |
auth_context |
AuthContext |
The current auth context. |
Security(oauth2_authentication) |
Returns:
Type | Description |
---|---|
OAuthDeviceResponse |
A specific OAuth2 authorized device object. |
Exceptions:
Type | Description |
---|---|
KeyError |
If the device with the given ID does not exist, does not belong to the current user or could not be verified using the given user code. |
Source code in zenml/zen_server/routers/devices_endpoints.py
@router.get(
    "/{device_id}",
    response_model=OAuthDeviceResponse,
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def get_authorization_device(
    device_id: UUID,
    user_code: Optional[str] = None,
    hydrate: bool = True,
    auth_context: AuthContext = Security(authorize),
) -> OAuthDeviceResponse:
    """Fetch a single OAuth2 authorized device by its unique ID.

    A device that already has an owner is only returned to that owner. A
    device that has no owner yet (not verified) is only returned when the
    matching ``user_code`` is supplied.

    Args:
        device_id: ID of the OAuth2 authorized device to fetch.
        user_code: User code of the device. Required for devices that have
            not been verified yet.
        hydrate: Whether to hydrate the returned model, i.e. include its
            metadata fields in the response.
        auth_context: The current auth context.

    Returns:
        The requested OAuth2 authorized device.

    Raises:
        KeyError: If the device with the given ID does not exist, does not
            belong to the current user or could not be verified using the
            given user code.
    """
    store = zen_store()
    device = store.get_authorized_device(device_id=device_id, hydrate=hydrate)
    owner = device.user

    if owner:
        if owner.id == auth_context.user.id:
            return device
    elif user_code:
        # Unowned device: the caller must prove knowledge of the user code.
        internal_device = store.get_internal_authorized_device(
            device_id=device_id, hydrate=hydrate
        )
        if internal_device.verify_user_code(user_code=user_code):
            return device

    raise KeyError(
        f"Unable to get device with ID {device_id}: No device with "
        "this ID found."
    )
list_authorized_devices(filter_model=Depends(init_cls_and_handle_errors), hydrate=False, auth_context=Security(oauth2_authentication))
Gets a page of OAuth2 authorized devices belonging to the current user.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
filter_model |
OAuthDeviceFilter |
Filter model used for pagination, sorting, filtering. |
Depends(init_cls_and_handle_errors) |
hydrate |
bool |
Flag deciding whether to hydrate the output model(s) by including metadata fields in the response. |
False |
auth_context |
AuthContext |
The current auth context. |
Security(oauth2_authentication) |
Returns:
Type | Description |
---|---|
Page[OAuthDeviceResponse] |
Page of OAuth2 authorized device objects. |
Source code in zenml/zen_server/routers/devices_endpoints.py
@router.get(
    "",
    response_model=Page[OAuthDeviceResponse],
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def list_authorized_devices(
    filter_model: OAuthDeviceFilter = Depends(
        make_dependable(OAuthDeviceFilter)
    ),
    hydrate: bool = False,
    auth_context: AuthContext = Security(authorize),
) -> Page[OAuthDeviceResponse]:
    """Return one page of the current user's OAuth2 authorized devices.

    Args:
        filter_model: Pagination, sorting and filtering options.
        hydrate: Whether to hydrate the returned models, i.e. include their
            metadata fields in the response.
        auth_context: The current auth context.

    Returns:
        A page of OAuth2 authorized devices.
    """
    # Scope the filter to the requesting user before querying the store.
    filter_model.set_scope_user(auth_context.user.id)
    store = zen_store()
    return store.list_authorized_devices(
        filter_model=filter_model,
        hydrate=hydrate,
    )
update_authorized_device(device_id, update, auth_context=Security(oauth2_authentication))
Updates a specific OAuth2 authorized device using its unique ID.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
device_id |
UUID |
The ID of the OAuth2 authorized device to update. |
required |
update |
OAuthDeviceUpdate |
The model containing the attributes to update. |
required |
auth_context |
AuthContext |
The current auth context. |
Security(oauth2_authentication) |
Returns:
Type | Description |
---|---|
OAuthDeviceResponse |
The updated OAuth2 authorized device object. |
Exceptions:
Type | Description |
---|---|
KeyError |
If the device with the given ID does not exist or does not belong to the current user. |
Source code in zenml/zen_server/routers/devices_endpoints.py
@router.put(
    "/{device_id}",
    response_model=OAuthDeviceResponse,
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def update_authorized_device(
    device_id: UUID,
    update: OAuthDeviceUpdate,
    auth_context: AuthContext = Security(authorize),
) -> OAuthDeviceResponse:
    """Apply an update to one of the current user's OAuth2 authorized devices.

    Args:
        device_id: ID of the OAuth2 authorized device to modify.
        update: Model holding the attributes to change.
        auth_context: The current auth context.

    Returns:
        The OAuth2 authorized device after the update was applied.

    Raises:
        KeyError: If the device with the given ID does not exist or does not
            belong to the current user.
    """
    store = zen_store()
    owner = store.get_authorized_device(device_id=device_id).user

    # Unowned and foreign devices are reported as "not found" alike.
    is_own_device = bool(owner) and owner.id == auth_context.user.id
    if not is_own_device:
        raise KeyError(
            f"Unable to get device with ID {device_id}: No device with "
            "this ID found."
        )

    return store.update_authorized_device(device_id=device_id, update=update)
verify_authorized_device(device_id, request, auth_context=Security(oauth2_authentication))
Verifies a specific OAuth2 authorized device using its unique ID.
This endpoint implements the OAuth2 device authorization grant flow as defined in https://tools.ietf.org/html/rfc8628. It is called to verify the user code for a given device ID.
If the user code is valid, the device is marked as verified and associated with the user that authorized the device. This association is required to be able to issue access tokens or revoke the device later on.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
device_id |
UUID |
The ID of the OAuth2 authorized device to update. |
required |
request |
OAuthDeviceVerificationRequest |
The model containing the verification request. |
required |
auth_context |
AuthContext |
The current auth context. |
Security(oauth2_authentication) |
Returns:
Type | Description |
---|---|
OAuthDeviceResponse |
The updated OAuth2 authorized device object. |
Exceptions:
Type | Description |
---|---|
ValueError |
If the device verification request fails. |
Source code in zenml/zen_server/routers/devices_endpoints.py
@router.put(
    "/{device_id}" + DEVICE_VERIFY,
    response_model=OAuthDeviceResponse,
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def verify_authorized_device(
    device_id: UUID,
    request: OAuthDeviceVerificationRequest,
    auth_context: AuthContext = Security(authorize),
) -> OAuthDeviceResponse:
    """Verifies a specific OAuth2 authorized device using its unique ID.

    This endpoint implements the OAuth2 device authorization grant flow as
    defined in https://tools.ietf.org/html/rfc8628. It is called to verify
    the user code for a given device ID.

    If the user code is valid, the device is marked as verified and associated
    with the user that authorized the device. This association is required to
    be able to issue access tokens or revoke the device later on.

    Every wrong user code increments the device's failed-attempt counter. Once
    ``max_failed_device_auth_attempts`` is reached, the device is locked.

    Args:
        device_id: The ID of the OAuth2 authorized device to update.
        request: The model containing the verification request.
        auth_context: The current auth context.

    Returns:
        The updated OAuth2 authorized device object.

    Raises:
        ValueError: If the device verification request fails, i.e. the device
            is not pending verification, its verification window has expired,
            it is associated with another user, or the user code is invalid.
    """
    config = server_config()
    store = zen_store()

    # Check if a device is registered for the ID
    device_model = store.get_internal_authorized_device(
        device_id=device_id,
    )

    # NOTE(review): presumably records a DEVICE_VERIFIED analytics event for
    # the wrapped verification steps; confirm track_handler semantics.
    with track_handler(event=AnalyticsEvent.DEVICE_VERIFIED):
        # Check if the device is in a state that allows verification.
        if device_model.status != OAuthDeviceStatus.PENDING:
            raise ValueError(
                "Invalid request: device not pending verification.",
            )

        # Check if the device verification has expired.
        # NOTE(review): compares against naive UTC time; assumes `expires` is
        # stored as a naive UTC datetime.
        if device_model.expires and device_model.expires < datetime.utcnow():
            raise ValueError(
                "Invalid request: device verification expired.",
            )

        # Check if the device already has a user associated with it. If so, the
        # current user and the user associated with the device must be the same.
        if device_model.user and device_model.user.id != auth_context.user.id:
            raise ValueError(
                "Invalid request: this device is associated with another user.",
            )

        # Check if the device code is valid.
        if not device_model.verify_user_code(request.user_code):
            # If the device code is invalid, increment the failed auth attempts
            # counter and lock the device if the maximum number of failed auth
            # attempts has been reached.
            failed_auth_attempts = device_model.failed_auth_attempts + 1
            update = OAuthDeviceInternalUpdate(
                failed_auth_attempts=failed_auth_attempts
            )
            if failed_auth_attempts >= config.max_failed_device_auth_attempts:
                update.locked = True
            # Persist the failed attempt (and possible lock) before raising,
            # so the counter is not lost when the error is returned.
            store.update_internal_authorized_device(
                device_id=device_model.id,
                update=update,
            )
            if failed_auth_attempts >= config.max_failed_device_auth_attempts:
                raise ValueError(
                    "Invalid request: device locked due to too many failed "
                    "authentication attempts.",
                )
            raise ValueError(
                "Invalid request: invalid user code.",
            )

        # If the device code is valid, associate the device with the current user.
        # We don't reset the expiration date yet, because we want to make sure
        # that the client has received the access token before we do so, to avoid
        # brute force attacks on the device code.
        update = OAuthDeviceInternalUpdate(
            status=OAuthDeviceStatus.VERIFIED,
            user_id=auth_context.user.id,
            failed_auth_attempts=0,
            trusted_device=request.trusted_device,
        )
        device_model = store.update_internal_authorized_device(
            device_id=device_model.id,
            update=update,
        )

    # Record the onboarding milestone once verification has succeeded.
    store.update_onboarding_state(
        completed_steps={OnboardingStep.DEVICE_VERIFIED}
    )

    return device_model
event_source_endpoints
Endpoint definitions for event sources.
create_event_source(event_source, _=Security(oauth2_authentication))
Creates an event source.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
event_source |
EventSourceRequest |
EventSource to register. |
required |
Returns:
Type | Description |
---|---|
EventSourceResponse |
The created event source. |
Exceptions:
Type | Description |
---|---|
ValueError |
If the plugin for an event source is not a valid event source plugin. |
Source code in zenml/zen_server/routers/event_source_endpoints.py
@event_source_router.post(
    "",
    response_model=EventSourceResponse,
    responses={401: error_response, 409: error_response, 422: error_response},
)
@handle_exceptions
def create_event_source(
    event_source: EventSourceRequest,
    _: AuthContext = Security(authorize),
) -> EventSourceResponse:
    """Register a new event source.

    The handler registered for the request's flavor and plugin subtype
    performs the actual creation.

    Args:
        event_source: EventSource to register.

    Returns:
        The created event source.

    Raises:
        ValueError: If the plugin for an event source is not a valid event
            source plugin.
    """
    handler = plugin_flavor_registry().get_plugin(
        name=event_source.flavor,
        _type=PluginType.EVENT_SOURCE,
        subtype=event_source.plugin_subtype,
    )

    # Only delegate creation once we know the flavor/subtype pair resolved
    # to an actual event source handler implementation.
    if isinstance(handler, BaseEventSourceHandler):
        return verify_permissions_and_create_entity(
            request_model=event_source,
            resource_type=ResourceType.EVENT_SOURCE,
            create_method=handler.create_event_source,
        )

    raise ValueError(
        f"Event source plugin {event_source.plugin_subtype} "
        f"for flavor {event_source.flavor} is not a valid event source "
        "handler implementation."
    )
delete_event_source(event_source_id, force=False, _=Security(oauth2_authentication))
Deletes an event source.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
event_source_id |
UUID |
ID of the event source. |
required |
force |
bool |
Flag deciding whether to force delete the event source. |
False |
Exceptions:
Type | Description |
---|---|
ValueError |
If the plugin for an event source is not a valid event source plugin. |
Source code in zenml/zen_server/routers/event_source_endpoints.py
@event_source_router.delete(
    "/{event_source_id}",
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def delete_event_source(
    event_source_id: UUID,
    force: bool = False,
    _: AuthContext = Security(authorize),
) -> None:
    """Deletes an event source.

    The event source is loaded from the store and the caller's DELETE
    permission on it is verified. The deletion itself is then delegated to
    the plugin handler registered for the event source's flavor and subtype.

    Args:
        event_source_id: ID of the event source.
        force: Flag deciding whether to force delete the event source.

    Raises:
        ValueError: If the plugin for an event source is not a valid event
            source plugin.
    """
    event_source = zen_store().get_event_source(
        event_source_id=event_source_id
    )
    verify_permission_for_model(event_source, action=Action.DELETE)
    event_source_handler = plugin_flavor_registry().get_plugin(
        name=event_source.flavor,
        _type=PluginType.EVENT_SOURCE,
        subtype=event_source.plugin_subtype,
    )
    # Validate that the flavor and plugin_type correspond to an event source
    # implementation
    if not isinstance(event_source_handler, BaseEventSourceHandler):
        raise ValueError(
            f"Event source plugin {event_source.plugin_subtype} "
            f"for flavor {event_source.flavor} is not a valid event source "
            "handler implementation."
        )
    event_source_handler.delete_event_source(
        event_source=event_source,
        force=force,
    )
get_event_source(event_source_id, hydrate=True, _=Security(oauth2_authentication))
Returns the requested event_source.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
event_source_id |
UUID |
ID of the event_source. |
required |
hydrate |
bool |
Flag deciding whether to hydrate the output model(s) by including metadata fields in the response. |
True |
Returns:
Type | Description |
---|---|
EventSourceResponse |
The requested event_source. |
Exceptions:
Type | Description |
---|---|
ValueError |
If the plugin for an event source is not a valid event source plugin. |
Source code in zenml/zen_server/routers/event_source_endpoints.py
@event_source_router.get(
    "/{event_source_id}",
    response_model=EventSourceResponse,
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def get_event_source(
    event_source_id: UUID,
    hydrate: bool = True,
    _: AuthContext = Security(authorize),
) -> EventSourceResponse:
    """Fetch one event source and post-process it through its plugin.

    The stored event source is loaded and READ permission on it is verified.
    The plugin handler matching its flavor/subtype then processes it, and
    the result is dehydrated before being returned.

    Args:
        event_source_id: ID of the event_source.
        hydrate: Flag deciding whether to hydrate the output model(s)
            by including metadata fields in the response.

    Returns:
        The requested event_source.

    Raises:
        ValueError: If the plugin for an event source is not a valid event
            source plugin.
    """
    stored = zen_store().get_event_source(
        event_source_id=event_source_id, hydrate=hydrate
    )
    verify_permission_for_model(stored, action=Action.READ)

    handler = plugin_flavor_registry().get_plugin(
        name=stored.flavor,
        _type=PluginType.EVENT_SOURCE,
        subtype=stored.plugin_subtype,
    )
    if isinstance(handler, BaseEventSourceHandler):
        processed = handler.get_event_source(stored, hydrate=hydrate)
        return dehydrate_response_model(processed)

    raise ValueError(
        f"Event source plugin {stored.plugin_subtype} "
        f"for flavor {stored.flavor} is not a valid event source "
        "handler implementation."
    )
list_event_sources(event_source_filter_model=Depends(init_cls_and_handle_errors), hydrate=False, _=Security(oauth2_authentication))
Returns all event_sources.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
event_source_filter_model |
EventSourceFilter |
Filter model used for pagination, sorting, filtering. |
Depends(init_cls_and_handle_errors) |
hydrate |
bool |
Flag deciding whether to hydrate the output model(s) by including metadata fields in the response. |
False |
Returns:
Type | Description |
---|---|
Page[EventSourceResponse] |
All event_sources. |
Source code in zenml/zen_server/routers/event_source_endpoints.py
@event_source_router.get(
    "",
    response_model=Page[EventSourceResponse],
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def list_event_sources(
    event_source_filter_model: EventSourceFilter = Depends(
        make_dependable(EventSourceFilter)
    ),
    hydrate: bool = False,
    _: AuthContext = Security(authorize),
) -> Page[EventSourceResponse]:
    """List event sources, passing every entry through its plugin handler.

    Args:
        event_source_filter_model: Filter model used for pagination, sorting,
            filtering.
        hydrate: Flag deciding whether to hydrate the output model(s)
            by including metadata fields in the response.

    Returns:
        All event_sources.
    """

    def _process(event_source: EventSourceResponse) -> EventSourceResponse:
        """Run a single stored event source through its plugin handler.

        Args:
            event_source: The event source as returned by the store.

        Returns:
            The event source as returned by the plugin handler.

        Raises:
            ValueError: If the plugin for an event source is not a valid
                event source plugin.
        """
        handler = plugin_flavor_registry().get_plugin(
            name=event_source.flavor,
            _type=PluginType.EVENT_SOURCE,
            subtype=event_source.plugin_subtype,
        )
        if not isinstance(handler, BaseEventSourceHandler):
            raise ValueError(
                f"Event source plugin {event_source.plugin_subtype} "
                f"for flavor {event_source.flavor} is not a valid event "
                "source handler implementation."
            )
        return handler.get_event_source(event_source, hydrate=hydrate)

    def list_event_sources_fn(
        filter_model: EventSourceFilter,
    ) -> Page[EventSourceResponse]:
        """Query the store and replace each page item with its processed form.

        Args:
            filter_model: Filter model used for pagination, sorting,
                filtering.

        Returns:
            All event sources.
        """
        page = zen_store().list_event_sources(
            event_source_filter_model=filter_model, hydrate=hydrate
        )
        # Items are replaced in place so the page metadata is left untouched.
        for position, source in enumerate(page.items):
            page.items[position] = _process(source)
        return page

    return verify_permissions_and_list_entities(
        filter_model=event_source_filter_model,
        resource_type=ResourceType.EVENT_SOURCE,
        list_method=list_event_sources_fn,
    )
update_event_source(event_source_id, event_source_update, _=Security(oauth2_authentication))
Updates an event source.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
event_source_id |
UUID |
ID of the event source. |
required |
event_source_update |
EventSourceUpdate |
EventSource to use for the update. |
required |
Returns:
Type | Description |
---|---|
EventSourceResponse |
The updated event_source. |
Exceptions:
Type | Description |
---|---|
ValueError |
If the plugin for an event source is not a valid event source plugin. |
Source code in zenml/zen_server/routers/event_source_endpoints.py
@event_source_router.put(
    "/{event_source_id}",
    response_model=EventSourceResponse,
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def update_event_source(
    event_source_id: UUID,
    event_source_update: EventSourceUpdate,
    _: AuthContext = Security(authorize),
) -> EventSourceResponse:
    """Updates an event source.

    The existing event source is loaded and the caller's UPDATE permission
    on it is verified. The update is then delegated to the plugin handler
    registered for the event source's flavor and subtype, and the result is
    dehydrated before being returned.

    Args:
        event_source_id: ID of the event source.
        event_source_update: EventSource to use for the update.

    Returns:
        The updated event source.

    Raises:
        ValueError: If the plugin for an event source is not a valid event
            source plugin.
    """
    event_source = zen_store().get_event_source(
        event_source_id=event_source_id
    )
    verify_permission_for_model(event_source, action=Action.UPDATE)
    event_source_handler = plugin_flavor_registry().get_plugin(
        name=event_source.flavor,
        _type=PluginType.EVENT_SOURCE,
        subtype=event_source.plugin_subtype,
    )
    # Validate that the flavor and plugin_type correspond to an event source
    # implementation
    if not isinstance(event_source_handler, BaseEventSourceHandler):
        raise ValueError(
            f"Event source plugin {event_source.plugin_subtype} "
            f"for flavor {event_source.flavor} is not a valid event source "
            "handler implementation."
        )
    updated_event_source = event_source_handler.update_event_source(
        event_source=event_source,
        event_source_update=event_source_update,
    )
    return dehydrate_response_model(updated_event_source)
flavors_endpoints
Endpoint definitions for flavors.
create_flavor(flavor, _=Security(oauth2_authentication))
Creates a stack component flavor.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
flavor |
FlavorRequest |
Stack component flavor to register. |
required |
Returns:
Type | Description |
---|---|
FlavorResponse |
The created stack component flavor. |
Source code in zenml/zen_server/routers/flavors_endpoints.py
@router.post(
    "",
    response_model=FlavorResponse,
    responses={401: error_response, 409: error_response, 422: error_response},
)
@handle_exceptions
def create_flavor(
    flavor: FlavorRequest,
    _: AuthContext = Security(authorize),
) -> FlavorResponse:
    """Register a new stack component flavor.

    Args:
        flavor: Stack component flavor to register.

    Returns:
        The created stack component flavor.
    """
    store = zen_store()
    created = verify_permissions_and_create_entity(
        request_model=flavor,
        resource_type=ResourceType.FLAVOR,
        create_method=store.create_flavor,
    )
    return created
delete_flavor(flavor_id, _=Security(oauth2_authentication))
Deletes a flavor.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
flavor_id |
UUID |
ID of the flavor. |
required |
Source code in zenml/zen_server/routers/flavors_endpoints.py
@router.delete(
    "/{flavor_id}",
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def delete_flavor(
    flavor_id: UUID,
    _: AuthContext = Security(authorize),
) -> None:
    """Remove a flavor after checking the caller may delete it.

    Args:
        flavor_id: ID of the flavor.
    """
    store = zen_store()
    verify_permissions_and_delete_entity(
        id=flavor_id,
        get_method=store.get_flavor,
        delete_method=store.delete_flavor,
    )
get_flavor(flavor_id, hydrate=True, _=Security(oauth2_authentication))
Returns the requested flavor.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
flavor_id |
UUID |
ID of the flavor. |
required |
hydrate |
bool |
Flag deciding whether to hydrate the output model(s) by including metadata fields in the response. |
True |
Returns:
Type | Description |
---|---|
FlavorResponse |
The requested flavor. |
Source code in zenml/zen_server/routers/flavors_endpoints.py
@router.get(
    "/{flavor_id}",
    response_model=FlavorResponse,
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def get_flavor(
    flavor_id: UUID,
    hydrate: bool = True,
    _: AuthContext = Security(authorize),
) -> FlavorResponse:
    """Returns the requested flavor.

    Args:
        flavor_id: ID of the flavor.
        hydrate: Flag deciding whether to hydrate the output model(s)
            by including metadata fields in the response.

    Returns:
        The requested flavor.
    """
    return verify_permissions_and_get_entity(
        id=flavor_id, get_method=zen_store().get_flavor, hydrate=hydrate
    )
list_flavors(flavor_filter_model=Depends(init_cls_and_handle_errors), hydrate=False, _=Security(oauth2_authentication))
Returns all flavors.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
flavor_filter_model |
FlavorFilter |
Filter model used for pagination, sorting, filtering. |
Depends(init_cls_and_handle_errors) |
hydrate |
bool |
Flag deciding whether to hydrate the output model(s) by including metadata fields in the response. |
False |
Returns:
Type | Description |
---|---|
Page[FlavorResponse] |
All flavors. |
Source code in zenml/zen_server/routers/flavors_endpoints.py
@router.get(
    "",
    response_model=Page[FlavorResponse],
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def list_flavors(
    flavor_filter_model: FlavorFilter = Depends(make_dependable(FlavorFilter)),
    hydrate: bool = False,
    _: AuthContext = Security(authorize),
) -> Page[FlavorResponse]:
    """List the flavors the caller is allowed to see.

    Args:
        flavor_filter_model: Filter model used for pagination, sorting,
            filtering.
        hydrate: Flag deciding whether to hydrate the output model(s)
            by including metadata fields in the response.

    Returns:
        All flavors.
    """
    list_method = zen_store().list_flavors
    return verify_permissions_and_list_entities(
        filter_model=flavor_filter_model,
        resource_type=ResourceType.FLAVOR,
        list_method=list_method,
        hydrate=hydrate,
    )
sync_flavors(_=Security(oauth2_authentication))
Purge all in-built and integration flavors from the DB and sync.
Returns:
Type | Description |
---|---|
None |
None if successful. Raises an exception otherwise. |
Source code in zenml/zen_server/routers/flavors_endpoints.py
@router.patch(
    "/sync",
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def sync_flavors(
    _: AuthContext = Security(authorize),
) -> None:
    """Purge all in-built and integration flavors from the DB and re-sync.

    The caller needs UPDATE permission on the flavor resource type.

    Returns:
        None if successful. Raises an exception otherwise.
    """
    verify_permission(resource_type=ResourceType.FLAVOR, action=Action.UPDATE)
    store = zen_store()
    return store._sync_flavors()
update_flavor(flavor_id, flavor_update, _=Security(oauth2_authentication))
Updates a flavor.
noqa: DAR401
Parameters:
Name | Type | Description | Default |
---|---|---|---|
flavor_id |
UUID |
ID of the flavor to update. |
required |
flavor_update |
FlavorUpdate |
Flavor update. |
required |
Returns:
Type | Description |
---|---|
FlavorResponse |
The updated flavor. |
Source code in zenml/zen_server/routers/flavors_endpoints.py
@router.put(
    "/{flavor_id}",
    response_model=FlavorResponse,
    responses={401: error_response, 409: error_response, 422: error_response},
)
@handle_exceptions
def update_flavor(
    flavor_id: UUID,
    flavor_update: FlavorUpdate,
    _: AuthContext = Security(authorize),
) -> FlavorResponse:
    """Apply an update to an existing flavor.

    # noqa: DAR401

    Args:
        flavor_id: ID of the flavor to update.
        flavor_update: Flavor update.

    Returns:
        The updated flavor.
    """
    store = zen_store()
    updated = verify_permissions_and_update_entity(
        id=flavor_id,
        update_model=flavor_update,
        get_method=store.get_flavor,
        update_method=store.update_flavor,
    )
    return updated
logs_endpoints
Endpoint definitions for logs.
get_logs(logs_id, hydrate=True, _=Security(oauth2_authentication))
Returns the requested logs.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
logs_id |
UUID |
ID of the logs. |
required |
hydrate |
bool |
Flag deciding whether to hydrate the output model(s) by including metadata fields in the response. |
True |
Returns:
Type | Description |
---|---|
LogsResponse |
The requested logs. |
Source code in zenml/zen_server/routers/logs_endpoints.py
@router.get(
    "/{logs_id}",
    response_model=LogsResponse,
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def get_logs(
    logs_id: UUID,
    hydrate: bool = True,
    _: AuthContext = Security(authorize),
) -> LogsResponse:
    """Fetch a single logs entry after verifying read access.

    Args:
        logs_id: ID of the logs.
        hydrate: Flag deciding whether to hydrate the output model(s)
            by including metadata fields in the response.

    Returns:
        The requested logs.
    """
    fetch = zen_store().get_logs
    return verify_permissions_and_get_entity(
        id=logs_id,
        get_method=fetch,
        hydrate=hydrate,
    )
model_versions_endpoints
Endpoint definitions for model versions.
create_model_version_artifact_link(model_version_artifact_link, _=Security(oauth2_authentication))
Create a new model version to artifact link.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model_version_artifact_link |
ModelVersionArtifactRequest |
The model version to artifact link to create. |
required |
Returns:
Type | Description |
---|---|
ModelVersionArtifactResponse |
The created model version to artifact link. |
Source code in zenml/zen_server/routers/model_versions_endpoints.py
@model_version_artifacts_router.post(
    "",
    responses={401: error_response, 409: error_response, 422: error_response},
)
@handle_exceptions
def create_model_version_artifact_link(
    model_version_artifact_link: ModelVersionArtifactRequest,
    _: AuthContext = Security(authorize),
) -> ModelVersionArtifactResponse:
    """Link an artifact to a model version.

    Creating a link counts as an update of the target model version, so
    UPDATE permission on that model version is required.

    Args:
        model_version_artifact_link: The model version to artifact link to create.

    Returns:
        The created model version to artifact link.
    """
    target_version = zen_store().get_model_version(
        model_version_artifact_link.model_version
    )
    verify_permission_for_model(target_version, action=Action.UPDATE)
    return zen_store().create_model_version_artifact_link(
        model_version_artifact_link
    )
create_model_version_pipeline_run_link(model_version_pipeline_run_link, _=Security(oauth2_authentication))
Create a new model version to pipeline run link.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model_version_pipeline_run_link |
ModelVersionPipelineRunRequest |
The model version to pipeline run link to create. |
required |
Returns:
Type | Description |
---|---|
ModelVersionPipelineRunResponse |
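The existing link if one already exists; otherwise, the newly created model version to pipeline run link. |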
Source code in zenml/zen_server/routers/model_versions_endpoints.py
@model_version_pipeline_runs_router.post(
    "",
    responses={401: error_response, 409: error_response, 422: error_response},
)
@handle_exceptions
def create_model_version_pipeline_run_link(
    model_version_pipeline_run_link: ModelVersionPipelineRunRequest,
    _: AuthContext = Security(authorize),
) -> ModelVersionPipelineRunResponse:
    """Link a pipeline run to a model version.

    The target model version is loaded unhydrated, only for the permission
    check. UPDATE permission on it is required.

    Args:
        model_version_pipeline_run_link: The model version to pipeline run link to create.

    Returns:
        - If Model Version to Pipeline Run Link already exists - returns the existing link.
        - Otherwise, returns the newly created model version to pipeline run link.
    """
    target_version = zen_store().get_model_version(
        model_version_pipeline_run_link.model_version, hydrate=False
    )
    verify_permission_for_model(target_version, action=Action.UPDATE)
    return zen_store().create_model_version_pipeline_run_link(
        model_version_pipeline_run_link
    )
delete_all_model_version_artifact_links(model_version_id, only_links=True, _=Security(oauth2_authentication))
Deletes all model version to artifact links.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model_version_id |
UUID |
ID of the model version containing links. |
required |
only_links |
bool |
Whether to only delete the link to the artifact. |
True |
Source code in zenml/zen_server/routers/model_versions_endpoints.py
@router.delete(
    "/{model_version_id}" + ARTIFACTS,
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def delete_all_model_version_artifact_links(
    model_version_id: UUID,
    only_links: bool = True,
    _: AuthContext = Security(authorize),
) -> None:
    """Remove every artifact link attached to a model version.

    Removing links counts as an update of the model version, so UPDATE
    permission on it is required.

    Args:
        model_version_id: ID of the model version containing links.
        only_links: Whether to only delete the link to the artifact.
    """
    store = zen_store()
    owning_version = store.get_model_version(model_version_id)
    verify_permission_for_model(owning_version, action=Action.UPDATE)
    store.delete_all_model_version_artifact_links(model_version_id, only_links)
delete_model_version(model_version_id, _=Security(oauth2_authentication))
Delete a model version by ID.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model_version_id |
UUID |
The ID of the model version to delete. |
required |
Source code in zenml/zen_server/routers/model_versions_endpoints.py
@router.delete(
    "/{model_version_id}",
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def delete_model_version(
    model_version_id: UUID,
    _: AuthContext = Security(authorize),
) -> None:
    """Delete a model version by ID.

    The model version is loaded first so that the caller's DELETE
    permission on it can be verified before anything is removed.

    Args:
        model_version_id: The ID of the model version to delete.
    """
    model_version = zen_store().get_model_version(model_version_id)
    verify_permission_for_model(model_version, action=Action.DELETE)
    zen_store().delete_model_version(model_version_id)
delete_model_version_artifact_link(model_version_id, model_version_artifact_link_name_or_id, _=Security(oauth2_authentication))
Deletes a model version to artifact link.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model_version_id |
UUID |
ID of the model version containing the link. |
required |
model_version_artifact_link_name_or_id |
Union[str, uuid.UUID] |
name or ID of the model version to artifact link to be deleted. |
required |
Source code in zenml/zen_server/routers/model_versions_endpoints.py
@router.delete(
    "/{model_version_id}"
    + ARTIFACTS
    + "/{model_version_artifact_link_name_or_id}",
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def delete_model_version_artifact_link(
    model_version_id: UUID,
    model_version_artifact_link_name_or_id: Union[str, UUID],
    _: AuthContext = Security(authorize),
) -> None:
    """Remove a single artifact link from a model version.

    UPDATE permission on the owning model version is required.

    Args:
        model_version_id: ID of the model version containing the link.
        model_version_artifact_link_name_or_id: name or ID of the model
            version to artifact link to be deleted.
    """
    store = zen_store()
    owning_version = store.get_model_version(model_version_id)
    verify_permission_for_model(owning_version, action=Action.UPDATE)
    store.delete_model_version_artifact_link(
        model_version_id, model_version_artifact_link_name_or_id
    )
delete_model_version_pipeline_run_link(model_version_id, model_version_pipeline_run_link_name_or_id, _=Security(oauth2_authentication))
Deletes a model version link.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model_version_id |
UUID |
ID of the model version containing the link. |
required |
model_version_pipeline_run_link_name_or_id |
Union[str, uuid.UUID] |
name or ID of the model version link to be deleted. |
required |
Source code in zenml/zen_server/routers/model_versions_endpoints.py
@router.delete(
    "/{model_version_id}"
    + RUNS
    + "/{model_version_pipeline_run_link_name_or_id}",
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def delete_model_version_pipeline_run_link(
    model_version_id: UUID,
    model_version_pipeline_run_link_name_or_id: Union[str, UUID],
    _: AuthContext = Security(authorize),
) -> None:
    """Deletes a model version to pipeline run link.

    Removing a link counts as an update of the owning model version, so
    UPDATE permission on that model version is required.

    Args:
        model_version_id: ID of the model version containing the link.
        model_version_pipeline_run_link_name_or_id: name or ID of the model
            version to pipeline run link to be deleted.
    """
    model_version = zen_store().get_model_version(model_version_id)
    verify_permission_for_model(model_version, action=Action.UPDATE)
    zen_store().delete_model_version_pipeline_run_link(
        model_version_id=model_version_id,
        model_version_pipeline_run_link_name_or_id=model_version_pipeline_run_link_name_or_id,
    )
get_model_version(model_version_id, hydrate=True, _=Security(oauth2_authentication))
Get a model version by ID.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model_version_id |
UUID |
id of the model version to be retrieved. |
required |
hydrate |
bool |
Flag deciding whether to hydrate the output model(s) by including metadata fields in the response. |
True |
Returns:
Type | Description |
---|---|
ModelVersionResponse |
The model version with the given ID. |
Source code in zenml/zen_server/routers/model_versions_endpoints.py
@router.get(
    "/{model_version_id}",
    response_model=ModelVersionResponse,
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def get_model_version(
    model_version_id: UUID,
    hydrate: bool = True,
    _: AuthContext = Security(authorize),
) -> ModelVersionResponse:
    """Get a model version by ID.

    Args:
        model_version_id: ID of the model version to be retrieved.
        hydrate: Flag deciding whether to hydrate the output model(s)
            by including metadata fields in the response.

    Returns:
        The model version with the given ID.
    """
    model_version = zen_store().get_model_version(
        model_version_id=model_version_id,
        hydrate=hydrate,
    )
    # READ access is checked against the parent model (not the version
    # itself), unlike the update/delete endpoints for model versions.
    verify_permission_for_model(model_version.model, action=Action.READ)
    return dehydrate_response_model(model_version)
list_model_version_artifact_links(model_version_artifact_link_filter_model=Depends(init_cls_and_handle_errors), hydrate=False, _=Security(oauth2_authentication))
Get model version to artifact links according to query filters.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model_version_artifact_link_filter_model |
ModelVersionArtifactFilter |
Filter model used for pagination, sorting, filtering. |
Depends(init_cls_and_handle_errors) |
hydrate |
bool |
Flag deciding whether to hydrate the output model(s) by including metadata fields in the response. |
False |
Returns:
Type | Description |
---|---|
Page[ModelVersionArtifactResponse] |
The model version to artifact links according to query filters. |
Source code in zenml/zen_server/routers/model_versions_endpoints.py
@model_version_artifacts_router.get(
    "",
    response_model=Page[ModelVersionArtifactResponse],
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def list_model_version_artifact_links(
    model_version_artifact_link_filter_model: ModelVersionArtifactFilter = Depends(
        make_dependable(ModelVersionArtifactFilter)
    ),
    hydrate: bool = False,
    _: AuthContext = Security(authorize),
) -> Page[ModelVersionArtifactResponse]:
    """List model version to artifact links matching the query filters.

    Args:
        model_version_artifact_link_filter_model: Filter model used for
            pagination, sorting, filtering.
        hydrate: Flag deciding whether to hydrate the output model(s)
            by including metadata fields in the response.

    Returns:
        The model version to artifact links according to query filters.
    """
    # NOTE(review): unlike list_model_versions, no RBAC filtering is applied
    # here before querying the store. Confirm whether access control is
    # enforced elsewhere.
    store = zen_store()
    links = store.list_model_version_artifact_links(
        model_version_artifact_link_filter_model=model_version_artifact_link_filter_model,
        hydrate=hydrate,
    )
    return links
list_model_version_pipeline_run_links(model_version_pipeline_run_link_filter_model=Depends(init_cls_and_handle_errors), hydrate=False, _=Security(oauth2_authentication))
Get model version to pipeline run links according to query filters.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model_version_pipeline_run_link_filter_model |
ModelVersionPipelineRunFilter |
Filter model used for pagination, sorting, and filtering. |
Depends(init_cls_and_handle_errors) |
hydrate |
bool |
Flag deciding whether to hydrate the output model(s) by including metadata fields in the response. |
False |
Returns:
Type | Description |
---|---|
Page[ModelVersionPipelineRunResponse] |
The model version to pipeline run links according to query filters. |
Source code in zenml/zen_server/routers/model_versions_endpoints.py
@model_version_pipeline_runs_router.get(
    "",
    response_model=Page[ModelVersionPipelineRunResponse],
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def list_model_version_pipeline_run_links(
    model_version_pipeline_run_link_filter_model: ModelVersionPipelineRunFilter = Depends(
        make_dependable(ModelVersionPipelineRunFilter)
    ),
    hydrate: bool = False,
    _: AuthContext = Security(authorize),
) -> Page[ModelVersionPipelineRunResponse]:
    """List model version to pipeline run links matching the query filters.

    Args:
        model_version_pipeline_run_link_filter_model: Filter model used for
            pagination, sorting, and filtering.
        hydrate: Flag deciding whether to hydrate the output model(s)
            by including metadata fields in the response.

    Returns:
        The model version to pipeline run links according to query filters.
    """
    # NOTE(review): no RBAC filtering is applied here before querying the
    # store. Confirm whether access control is enforced elsewhere.
    store = zen_store()
    links = store.list_model_version_pipeline_run_links(
        model_version_pipeline_run_link_filter_model=model_version_pipeline_run_link_filter_model,
        hydrate=hydrate,
    )
    return links
list_model_versions(model_version_filter_model=Depends(init_cls_and_handle_errors), hydrate=False, auth_context=Security(oauth2_authentication))
Get model versions according to query filters.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model_version_filter_model |
ModelVersionFilter |
Filter model used for pagination, sorting, filtering. |
Depends(init_cls_and_handle_errors) |
hydrate |
bool |
Flag deciding whether to hydrate the output model(s) by including metadata fields in the response. |
False |
auth_context |
AuthContext |
The authentication context. |
Security(oauth2_authentication) |
Returns:
Type | Description |
---|---|
Page[ModelVersionResponse] |
The model versions according to query filters. |
Source code in zenml/zen_server/routers/model_versions_endpoints.py
@router.get(
    "",
    response_model=Page[ModelVersionResponse],
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def list_model_versions(
    model_version_filter_model: ModelVersionFilter = Depends(
        make_dependable(ModelVersionFilter)
    ),
    hydrate: bool = False,
    auth_context: AuthContext = Security(authorize),
) -> Page[ModelVersionResponse]:
    """List model versions, restricted to models the caller may access.

    The filter is first scoped via ``configure_rbac`` to the model IDs
    returned by ``get_allowed_resource_ids`` for the MODEL resource type and
    to the authenticated user. The resulting page is then dehydrated.

    Args:
        model_version_filter_model: Filter model used for pagination, sorting,
            filtering.
        hydrate: Flag deciding whether to hydrate the output model(s)
            by including metadata fields in the response.
        auth_context: The authentication context.

    Returns:
        The model versions according to query filters.
    """
    permitted_model_ids = get_allowed_resource_ids(
        resource_type=ResourceType.MODEL
    )
    model_version_filter_model.configure_rbac(
        authenticated_user_id=auth_context.user.id,
        model_id=permitted_model_ids,
    )
    page = zen_store().list_model_versions(
        model_version_filter_model=model_version_filter_model,
        hydrate=hydrate,
    )
    return dehydrate_page(page)
update_model_version(model_version_id, model_version_update_model, _=Security(oauth2_authentication))
Update a model version.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model_version_id |
UUID |
The ID of the model version to be updated. |
required |
model_version_update_model |
ModelVersionUpdate |
The update to apply to the model version. |
required |
Returns:
Type | Description |
---|---|
ModelVersionResponse |
An updated model version. |
Source code in zenml/zen_server/routers/model_versions_endpoints.py
@router.put(
    "/{model_version_id}",
    response_model=ModelVersionResponse,
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def update_model_version(
    model_version_id: UUID,
    model_version_update_model: ModelVersionUpdate,
    _: AuthContext = Security(authorize),
) -> ModelVersionResponse:
    """Updates a model version.

    UPDATE permission on the model version is always required. If the update
    carries a (truthy) stage, PROMOTE permission on the parent model is
    required as well.

    Args:
        model_version_id: The ID of the model version to be updated.
        model_version_update_model: The update to apply to the model version.

    Returns:
        An updated model version.
    """
    model_version = zen_store().get_model_version(model_version_id)
    if model_version_update_model.stage:
        # Make sure the user has permissions to promote the model
        verify_permission_for_model(model_version.model, action=Action.PROMOTE)
    verify_permission_for_model(model_version, action=Action.UPDATE)
    updated_model_version = zen_store().update_model_version(
        model_version_id=model_version_id,
        model_version_update_model=model_version_update_model,
    )
    return dehydrate_response_model(updated_model_version)
models_endpoints
Endpoint definitions for models.
delete_model(model_name_or_id, _=Security(oauth2_authentication))
Delete a model by name or ID.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model_name_or_id |
Union[str, uuid.UUID] |
The name or ID of the model to delete. |
required |
Source code in zenml/zen_server/routers/models_endpoints.py
@router.delete(
    "/{model_name_or_id}",
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def delete_model(
    model_name_or_id: Union[str, UUID],
    _: AuthContext = Security(authorize),
) -> None:
    """Remove a model, addressed either by its name or by its ID.

    When the server has feature gating enabled and models are a reportable
    resource, the removal is also reported as a usage decrement.

    Args:
        model_name_or_id: The name or ID of the model to delete.
    """
    deleted_model = verify_permissions_and_delete_entity(
        id=model_name_or_id,
        get_method=zen_store().get_model,
        delete_method=zen_store().delete_model,
    )

    # Both conditions must hold; `and` short-circuits exactly like the
    # nested checks it replaces.
    if (
        server_config().feature_gate_enabled
        and ResourceType.MODEL in REPORTABLE_RESOURCES
    ):
        report_decrement(ResourceType.MODEL, resource_id=deleted_model.id)
get_model(model_name_or_id, hydrate=True, _=Security(oauth2_authentication))
Get a model by name or ID.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model_name_or_id |
Union[str, uuid.UUID] |
The name or ID of the model to get. |
required |
hydrate |
bool |
Flag deciding whether to hydrate the output model(s) by including metadata fields in the response. |
True |
Returns:
Type | Description |
---|---|
ModelResponse |
The model with the given name or ID. |
Source code in zenml/zen_server/routers/models_endpoints.py
@router.get(
    "/{model_name_or_id}",
    response_model=ModelResponse,
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def get_model(
    model_name_or_id: Union[str, UUID],
    hydrate: bool = True,
    _: AuthContext = Security(authorize),
) -> ModelResponse:
    """Fetch a single model, looked up by its name or its ID.

    Args:
        model_name_or_id: The name or ID of the model to get.
        hydrate: Whether the returned model should include its metadata
            fields.

    Returns:
        The model matching the given name or ID.
    """
    store = zen_store()
    return verify_permissions_and_get_entity(
        id=model_name_or_id,
        get_method=store.get_model,
        hydrate=hydrate,
    )
list_model_versions(model_name_or_id, model_version_filter_model=Depends(init_cls_and_handle_errors), hydrate=False, auth_context=Security(oauth2_authentication))
Get model versions according to query filters.
This endpoint serves the purpose of allowing scoped filtering by model_id.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model_name_or_id |
Union[str, uuid.UUID] |
The name or ID of the model to list in. |
required |
model_version_filter_model |
ModelVersionFilter |
Filter model used for pagination, sorting, filtering. |
Depends(init_cls_and_handle_errors) |
hydrate |
bool |
Flag deciding whether to hydrate the output model(s) by including metadata fields in the response. |
False |
auth_context |
AuthContext |
The authentication context. |
Security(oauth2_authentication) |
Returns:
Type | Description |
---|---|
Page[ModelVersionResponse] |
The model versions according to query filters. |
Source code in zenml/zen_server/routers/models_endpoints.py
@router.get(
    "/{model_name_or_id}" + MODEL_VERSIONS,
    response_model=Page[ModelVersionResponse],
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def list_model_versions(
    model_name_or_id: Union[str, UUID],
    model_version_filter_model: ModelVersionFilter = Depends(
        make_dependable(ModelVersionFilter)
    ),
    hydrate: bool = False,
    auth_context: AuthContext = Security(authorize),
) -> Page[ModelVersionResponse]:
    """List the versions of one model, applying the query filters.

    This endpoint allows filtering scoped to a single model (by name or ID).

    Args:
        model_name_or_id: The name or ID of the model whose versions are
            listed.
        model_version_filter_model: Filter model used for pagination,
            sorting and filtering.
        hydrate: Whether the returned model versions should include their
            metadata fields.
        auth_context: The authentication context of the current request.

    Returns:
        A page of model versions matching the query filters.
    """
    # Restrict the listing to model versions whose parent model the caller
    # has access to.
    permitted_model_ids = get_allowed_resource_ids(
        resource_type=ResourceType.MODEL
    )
    model_version_filter_model.configure_rbac(
        authenticated_user_id=auth_context.user.id,
        model_id=permitted_model_ids,
    )

    versions_page = zen_store().list_model_versions(
        model_name_or_id=model_name_or_id,
        model_version_filter_model=model_version_filter_model,
        hydrate=hydrate,
    )
    return dehydrate_page(versions_page)
list_models(model_filter_model=Depends(init_cls_and_handle_errors), hydrate=False, _=Security(oauth2_authentication))
Get models according to query filters.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model_filter_model |
ModelFilter |
Filter model used for pagination, sorting, filtering. |
Depends(init_cls_and_handle_errors) |
hydrate |
bool |
Flag deciding whether to hydrate the output model(s) by including metadata fields in the response. |
False |
Returns:
Type | Description |
---|---|
Page[ModelResponse] |
The models according to query filters. |
Source code in zenml/zen_server/routers/models_endpoints.py
@router.get(
    "",
    response_model=Page[ModelResponse],
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def list_models(
    model_filter_model: ModelFilter = Depends(make_dependable(ModelFilter)),
    hydrate: bool = False,
    _: AuthContext = Security(authorize),
) -> Page[ModelResponse]:
    """List models matching the query filters.

    Args:
        model_filter_model: Filter model used for pagination, sorting and
            filtering.
        hydrate: Whether the returned models should include their metadata
            fields.

    Returns:
        A page of models matching the query filters.
    """
    list_method = zen_store().list_models
    return verify_permissions_and_list_entities(
        filter_model=model_filter_model,
        resource_type=ResourceType.MODEL,
        list_method=list_method,
        hydrate=hydrate,
    )
update_model(model_id, model_update, _=Security(oauth2_authentication))
Updates a model.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model_id |
UUID |
ID of the model to update. |
required |
model_update |
ModelUpdate |
The update to apply to the model. |
required |
Returns:
Type | Description |
---|---|
ModelResponse |
The updated model. |
Source code in zenml/zen_server/routers/models_endpoints.py
@router.put(
    "/{model_id}",
    response_model=ModelResponse,
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def update_model(
    model_id: UUID,
    model_update: ModelUpdate,
    _: AuthContext = Security(authorize),
) -> ModelResponse:
    """Updates a model.

    Permission checks, the lookup and the update itself are delegated to
    `verify_permissions_and_update_entity`.

    Args:
        model_id: ID of the model to update.
        model_update: The update to apply to the model.

    Returns:
        The updated model.
    """
    return verify_permissions_and_update_entity(
        id=model_id,
        update_model=model_update,
        get_method=zen_store().get_model,
        update_method=zen_store().update_model,
    )
pipeline_builds_endpoints
Endpoint definitions for builds.
delete_build(build_id, _=Security(oauth2_authentication))
Deletes a specific build.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
build_id |
UUID |
ID of the build to delete. |
required |
Source code in zenml/zen_server/routers/pipeline_builds_endpoints.py
@router.delete(
    "/{build_id}",
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def delete_build(
    build_id: UUID,
    _: AuthContext = Security(authorize),
) -> None:
    """Remove a single pipeline build.

    Args:
        build_id: ID of the build to delete.
    """
    store = zen_store()
    verify_permissions_and_delete_entity(
        id=build_id,
        get_method=store.get_build,
        delete_method=store.delete_build,
    )
get_build(build_id, hydrate=True, _=Security(oauth2_authentication))
Gets a specific build using its unique id.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
build_id |
UUID |
ID of the build to get. |
required |
hydrate |
bool |
Flag deciding whether to hydrate the output model(s) by including metadata fields in the response. |
True |
Returns:
Type | Description |
---|---|
PipelineBuildResponse |
A specific build object. |
Source code in zenml/zen_server/routers/pipeline_builds_endpoints.py
@router.get(
    "/{build_id}",
    response_model=PipelineBuildResponse,
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def get_build(
    build_id: UUID,
    hydrate: bool = True,
    _: AuthContext = Security(authorize),
) -> PipelineBuildResponse:
    """Fetch one pipeline build by its unique ID.

    Args:
        build_id: ID of the build to get.
        hydrate: Whether the returned build should include its metadata
            fields.

    Returns:
        The requested build.
    """
    store = zen_store()
    return verify_permissions_and_get_entity(
        id=build_id,
        get_method=store.get_build,
        hydrate=hydrate,
    )
list_builds(build_filter_model=Depends(init_cls_and_handle_errors), hydrate=False, _=Security(oauth2_authentication))
Gets a list of builds.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
build_filter_model |
PipelineBuildFilter |
Filter model used for pagination, sorting, filtering. |
Depends(init_cls_and_handle_errors) |
hydrate |
bool |
Flag deciding whether to hydrate the output model(s) by including metadata fields in the response. |
False |
Returns:
Type | Description |
---|---|
Page[PipelineBuildResponse] |
List of build objects. |
Source code in zenml/zen_server/routers/pipeline_builds_endpoints.py
@router.get(
    "",
    response_model=Page[PipelineBuildResponse],
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def list_builds(
    build_filter_model: PipelineBuildFilter = Depends(
        make_dependable(PipelineBuildFilter)
    ),
    hydrate: bool = False,
    _: AuthContext = Security(authorize),
) -> Page[PipelineBuildResponse]:
    """List pipeline builds matching the query filters.

    Args:
        build_filter_model: Filter model used for pagination, sorting and
            filtering.
        hydrate: Whether the returned builds should include their metadata
            fields.

    Returns:
        A page of builds.
    """
    list_method = zen_store().list_builds
    return verify_permissions_and_list_entities(
        filter_model=build_filter_model,
        resource_type=ResourceType.PIPELINE_BUILD,
        list_method=list_method,
        hydrate=hydrate,
    )
pipeline_deployments_endpoints
Endpoint definitions for deployments.
delete_deployment(deployment_id, _=Security(oauth2_authentication))
Deletes a specific deployment.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
deployment_id |
UUID |
ID of the deployment to delete. |
required |
Source code in zenml/zen_server/routers/pipeline_deployments_endpoints.py
@router.delete(
    "/{deployment_id}",
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def delete_deployment(
    deployment_id: UUID,
    _: AuthContext = Security(authorize),
) -> None:
    """Remove a single pipeline deployment.

    Args:
        deployment_id: ID of the deployment to delete.
    """
    store = zen_store()
    verify_permissions_and_delete_entity(
        id=deployment_id,
        get_method=store.get_deployment,
        delete_method=store.delete_deployment,
    )
get_deployment(deployment_id, hydrate=True, _=Security(oauth2_authentication))
Gets a specific deployment using its unique id.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
deployment_id |
UUID |
ID of the deployment to get. |
required |
hydrate |
bool |
Flag deciding whether to hydrate the output model(s) by including metadata fields in the response. |
True |
Returns:
Type | Description |
---|---|
PipelineDeploymentResponse |
A specific deployment object. |
Source code in zenml/zen_server/routers/pipeline_deployments_endpoints.py
@router.get(
    "/{deployment_id}",
    response_model=PipelineDeploymentResponse,
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def get_deployment(
    deployment_id: UUID,
    hydrate: bool = True,
    _: AuthContext = Security(authorize),
) -> PipelineDeploymentResponse:
    """Fetch one pipeline deployment by its unique ID.

    Args:
        deployment_id: ID of the deployment to get.
        hydrate: Whether the returned deployment should include its metadata
            fields.

    Returns:
        The requested deployment.
    """
    store = zen_store()
    return verify_permissions_and_get_entity(
        id=deployment_id,
        get_method=store.get_deployment,
        hydrate=hydrate,
    )
list_deployments(deployment_filter_model=Depends(init_cls_and_handle_errors), hydrate=False, _=Security(oauth2_authentication))
Gets a list of deployments.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
deployment_filter_model |
PipelineDeploymentFilter |
Filter model used for pagination, sorting, filtering. |
Depends(init_cls_and_handle_errors) |
hydrate |
bool |
Flag deciding whether to hydrate the output model(s) by including metadata fields in the response. |
False |
Returns:
Type | Description |
---|---|
Page[PipelineDeploymentResponse] |
List of deployment objects. |
Source code in zenml/zen_server/routers/pipeline_deployments_endpoints.py
@router.get(
    "",
    response_model=Page[PipelineDeploymentResponse],
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def list_deployments(
    deployment_filter_model: PipelineDeploymentFilter = Depends(
        make_dependable(PipelineDeploymentFilter)
    ),
    hydrate: bool = False,
    _: AuthContext = Security(authorize),
) -> Page[PipelineDeploymentResponse]:
    """List pipeline deployments matching the query filters.

    Args:
        deployment_filter_model: Filter model used for pagination, sorting
            and filtering.
        hydrate: Whether the returned deployments should include their
            metadata fields.

    Returns:
        A page of deployments.
    """
    list_method = zen_store().list_deployments
    return verify_permissions_and_list_entities(
        filter_model=deployment_filter_model,
        resource_type=ResourceType.PIPELINE_DEPLOYMENT,
        list_method=list_method,
        hydrate=hydrate,
    )
pipelines_endpoints
Endpoint definitions for pipelines.
delete_pipeline(pipeline_id, _=Security(oauth2_authentication))
Deletes a specific pipeline.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
pipeline_id |
UUID |
ID of the pipeline to delete. |
required |
Source code in zenml/zen_server/routers/pipelines_endpoints.py
@router.delete(
    "/{pipeline_id}",
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def delete_pipeline(
    pipeline_id: UUID,
    _: AuthContext = Security(authorize),
) -> None:
    """Remove a single pipeline.

    If pipelines are a reportable resource, a usage decrement is reported,
    but only once no pipeline with the deleted pipeline's name remains.

    Args:
        pipeline_id: ID of the pipeline to delete.
    """
    deleted_pipeline = verify_permissions_and_delete_entity(
        id=pipeline_id,
        get_method=zen_store().get_pipeline,
        delete_method=zen_store().delete_pipeline,
    )

    if ResourceType.PIPELINE not in REPORTABLE_RESOURCES:
        return

    # Pipelines sharing a name are counted together: only report once the
    # last one with this name is gone.
    remaining_with_same_name = zen_store().count_pipelines(
        PipelineFilter(name=deleted_pipeline.name)
    )
    if remaining_with_same_name == 0:
        report_decrement(ResourceType.PIPELINE, resource_id=pipeline_id)
get_pipeline(pipeline_id, hydrate=True, _=Security(oauth2_authentication))
Gets a specific pipeline using its unique id.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
pipeline_id |
UUID |
ID of the pipeline to get. |
required |
hydrate |
bool |
Flag deciding whether to hydrate the output model(s) by including metadata fields in the response. |
True |
Returns:
Type | Description |
---|---|
PipelineResponse |
A specific pipeline object. |
Source code in zenml/zen_server/routers/pipelines_endpoints.py
@router.get(
    "/{pipeline_id}",
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def get_pipeline(
    pipeline_id: UUID,
    hydrate: bool = True,
    _: AuthContext = Security(authorize),
) -> PipelineResponse:
    """Fetch one pipeline by its unique ID.

    Args:
        pipeline_id: ID of the pipeline to get.
        hydrate: Whether the returned pipeline should include its metadata
            fields.

    Returns:
        The requested pipeline.
    """
    store = zen_store()
    return verify_permissions_and_get_entity(
        id=pipeline_id,
        get_method=store.get_pipeline,
        hydrate=hydrate,
    )
list_pipeline_runs(pipeline_run_filter_model=Depends(init_cls_and_handle_errors), hydrate=False, _=Security(oauth2_authentication))
Get pipeline runs according to query filters.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
pipeline_run_filter_model |
PipelineRunFilter |
Filter model used for pagination, sorting, filtering |
Depends(init_cls_and_handle_errors) |
hydrate |
bool |
Flag deciding whether to hydrate the output model(s) by including metadata fields in the response. |
False |
Returns:
Type | Description |
---|---|
Page[PipelineRunResponse] |
The pipeline runs according to query filters. |
Source code in zenml/zen_server/routers/pipelines_endpoints.py
@router.get(
    "/{pipeline_id}" + RUNS,
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def list_pipeline_runs(
    pipeline_id: UUID,
    pipeline_run_filter_model: PipelineRunFilter = Depends(
        make_dependable(PipelineRunFilter)
    ),
    hydrate: bool = False,
    _: AuthContext = Security(authorize),
) -> Page[PipelineRunResponse]:
    """Get the runs of a specific pipeline according to query filters.

    Args:
        pipeline_id: ID of the pipeline whose runs are listed. Taken from
            the request path.
        pipeline_run_filter_model: Filter model used for pagination,
            sorting and filtering.
        hydrate: Flag deciding whether to hydrate the output model(s)
            by including metadata fields in the response.

    Returns:
        The runs of the pipeline according to query filters.
    """
    # The route is scoped to one pipeline: check the caller may access it
    # before listing anything that belongs to it.
    verify_permissions_and_get_entity(
        id=pipeline_id, get_method=zen_store().get_pipeline, hydrate=False
    )

    # Scope the listing to the pipeline from the path; otherwise the path
    # parameter would be ignored and every run would be returned.
    pipeline_run_filter_model.pipeline_id = pipeline_id

    # Go through the RBAC-aware list helper like the other list endpoints,
    # so that only runs the caller may read are returned.
    return verify_permissions_and_list_entities(
        filter_model=pipeline_run_filter_model,
        resource_type=ResourceType.PIPELINE_RUN,
        list_method=zen_store().list_runs,
        hydrate=hydrate,
    )
list_pipelines(pipeline_filter_model=Depends(init_cls_and_handle_errors), hydrate=False, _=Security(oauth2_authentication))
Gets a list of pipelines.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
pipeline_filter_model |
PipelineFilter |
Filter model used for pagination, sorting, filtering. |
Depends(init_cls_and_handle_errors) |
hydrate |
bool |
Flag deciding whether to hydrate the output model(s) by including metadata fields in the response. |
False |
Returns:
Type | Description |
---|---|
Page[PipelineResponse] |
List of pipeline objects. |
Source code in zenml/zen_server/routers/pipelines_endpoints.py
@router.get(
    "",
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def list_pipelines(
    pipeline_filter_model: PipelineFilter = Depends(
        make_dependable(PipelineFilter)
    ),
    hydrate: bool = False,
    _: AuthContext = Security(authorize),
) -> Page[PipelineResponse]:
    """List pipelines matching the query filters.

    Args:
        pipeline_filter_model: Filter model used for pagination, sorting and
            filtering.
        hydrate: Whether the returned pipelines should include their
            metadata fields.

    Returns:
        A page of pipelines.
    """
    list_method = zen_store().list_pipelines
    return verify_permissions_and_list_entities(
        filter_model=pipeline_filter_model,
        resource_type=ResourceType.PIPELINE,
        list_method=list_method,
        hydrate=hydrate,
    )
update_pipeline(pipeline_id, pipeline_update, _=Security(oauth2_authentication))
Updates the attribute on a specific pipeline using its unique id.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
pipeline_id |
UUID |
ID of the pipeline to update. |
required |
pipeline_update |
PipelineUpdate |
the model containing the attributes to update. |
required |
Returns:
Type | Description |
---|---|
PipelineResponse |
The updated pipeline object. |
Source code in zenml/zen_server/routers/pipelines_endpoints.py
@router.put(
    "/{pipeline_id}",
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def update_pipeline(
    pipeline_id: UUID,
    pipeline_update: PipelineUpdate,
    _: AuthContext = Security(authorize),
) -> PipelineResponse:
    """Apply an update to the attributes of one pipeline.

    Args:
        pipeline_id: ID of the pipeline to update.
        pipeline_update: The model containing the attributes to update.

    Returns:
        The updated pipeline.
    """
    store = zen_store()
    return verify_permissions_and_update_entity(
        id=pipeline_id,
        update_model=pipeline_update,
        get_method=store.get_pipeline,
        update_method=store.update_pipeline,
    )
plugin_endpoints
Endpoint definitions for plugin flavors.
get_flavor(name, type=Query(PydanticUndefined), subtype=Query(PydanticUndefined), _=Security(oauth2_authentication))
Returns the requested flavor.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
name |
str |
Name of the flavor. |
required |
type |
PluginType |
Type of Plugin |
Query(PydanticUndefined) |
subtype |
PluginSubType |
Subtype of Plugin |
Query(PydanticUndefined) |
Returns:
Type | Description |
---|---|
BasePluginFlavorResponse |
The requested flavor response. |
Source code in zenml/zen_server/routers/plugin_endpoints.py
@plugin_router.get(
    "/{name}",
    response_model=BasePluginFlavorResponse,
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def get_flavor(
    name: str,
    type: PluginType = Query(..., alias="type"),
    subtype: PluginSubType = Query(..., alias="subtype"),
    _: AuthContext = Security(authorize),
) -> BasePluginFlavorResponse:  # type: ignore[type-arg]
    """Look up one plugin flavor by name, type and subtype.

    Args:
        name: Name of the flavor.
        type: Type of the plugin.
        subtype: Subtype of the plugin.

    Returns:
        The requested flavor, as a hydrated response model.
    """
    registry = plugin_flavor_registry()
    flavor_class = registry.get_flavor_class(
        name=name, _type=type, subtype=subtype
    )
    return flavor_class.get_flavor_response_model(hydrate=True)
list_flavors(type, subtype, page=1, size=20, hydrate=False, _=Security(oauth2_authentication))
Returns a page of plugin flavors of the given type and subtype.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
type |
PluginType |
The type of Plugin |
required |
subtype |
PluginSubType |
The subtype of the plugin |
required |
page |
int |
Page for pagination (offset +1) |
1 |
size |
int |
Page size for pagination |
20 |
hydrate |
bool |
Whether to hydrate the response bodies |
False |
Returns:
Type | Description |
---|---|
Page[BasePluginFlavorResponse] |
A page of flavors. |
Source code in zenml/zen_server/routers/plugin_endpoints.py
@plugin_router.get(
    "",
    response_model=Page[BasePluginFlavorResponse],  # type: ignore[type-arg]
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def list_flavors(
    type: PluginType,
    subtype: PluginSubType,
    page: int = PAGINATION_STARTING_PAGE,
    size: int = PAGE_SIZE_DEFAULT,
    hydrate: bool = False,
    _: AuthContext = Security(authorize),
) -> Page[BasePluginFlavorResponse]:  # type: ignore[type-arg]
    """List the available plugin flavors of one type and subtype.

    Args:
        type: The type of the plugin.
        subtype: The subtype of the plugin.
        page: Page for pagination (offset +1).
        size: Page size for pagination.
        hydrate: Whether to hydrate the response bodies.

    Returns:
        A page of flavors.
    """
    registry = plugin_flavor_registry()
    return registry.list_available_flavor_responses_for_type_and_subtype(
        _type=type,
        subtype=subtype,
        page=page,
        size=size,
        hydrate=hydrate,
    )
run_templates_endpoints
Endpoint definitions for run templates.
delete_run_template(template_id, _=Security(oauth2_authentication))
Delete a run template.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
template_id |
UUID |
ID of the run template to delete. |
required |
Source code in zenml/zen_server/routers/run_templates_endpoints.py
@router.delete(
    "/{template_id}",
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def delete_run_template(
    template_id: UUID,
    _: AuthContext = Security(authorize),
) -> None:
    """Remove a single run template.

    Args:
        template_id: ID of the run template to delete.
    """
    store = zen_store()
    verify_permissions_and_delete_entity(
        id=template_id,
        get_method=store.get_run_template,
        delete_method=store.delete_run_template,
    )
get_run_template(template_id, hydrate=True, _=Security(oauth2_authentication))
Get a run template.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
template_id |
UUID |
ID of the run template to get. |
required |
hydrate |
bool |
Flag deciding whether to hydrate the output model(s) by including metadata fields in the response. |
True |
Returns:
Type | Description |
---|---|
RunTemplateResponse |
The run template. |
Source code in zenml/zen_server/routers/run_templates_endpoints.py
@router.get(
    "/{template_id}",
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def get_run_template(
    template_id: UUID,
    hydrate: bool = True,
    _: AuthContext = Security(authorize),
) -> RunTemplateResponse:
    """Fetch one run template by its ID.

    Args:
        template_id: ID of the run template to get.
        hydrate: Whether the returned template should include its metadata
            fields.

    Returns:
        The requested run template.
    """
    store = zen_store()
    return verify_permissions_and_get_entity(
        id=template_id,
        get_method=store.get_run_template,
        hydrate=hydrate,
    )
list_run_templates(filter_model=Depends(init_cls_and_handle_errors), hydrate=False, _=Security(oauth2_authentication))
Get a page of run templates.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
filter_model |
RunTemplateFilter |
Filter model used for pagination, sorting, filtering. |
Depends(init_cls_and_handle_errors) |
hydrate |
bool |
Flag deciding whether to hydrate the output model(s) by including metadata fields in the response. |
False |
Returns:
Type | Description |
---|---|
Page[RunTemplateResponse] |
Page of run templates. |
Source code in zenml/zen_server/routers/run_templates_endpoints.py
@router.get(
    "",
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def list_run_templates(
    filter_model: RunTemplateFilter = Depends(
        make_dependable(RunTemplateFilter)
    ),
    hydrate: bool = False,
    _: AuthContext = Security(authorize),
) -> Page[RunTemplateResponse]:
    """List run templates matching the query filters.

    Args:
        filter_model: Filter model used for pagination, sorting and
            filtering.
        hydrate: Whether the returned templates should include their
            metadata fields.

    Returns:
        A page of run templates.
    """
    list_method = zen_store().list_run_templates
    return verify_permissions_and_list_entities(
        filter_model=filter_model,
        resource_type=ResourceType.RUN_TEMPLATE,
        list_method=list_method,
        hydrate=hydrate,
    )
update_run_template(template_id, update, _=Security(oauth2_authentication))
Update a run template.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
template_id |
UUID |
ID of the run template to update. |
required |
update |
RunTemplateUpdate |
The updates to apply. |
required |
Returns:
Type | Description |
---|---|
RunTemplateResponse |
The updated run template. |
Source code in zenml/zen_server/routers/run_templates_endpoints.py
@router.put(
    "/{template_id}",
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def update_run_template(
    template_id: UUID,
    update: RunTemplateUpdate,
    _: AuthContext = Security(authorize),
) -> RunTemplateResponse:
    """Apply an update to one run template.

    Args:
        template_id: ID of the run template to update.
        update: The updates to apply.

    Returns:
        The updated run template.
    """
    store = zen_store()
    return verify_permissions_and_update_entity(
        id=template_id,
        update_model=update,
        get_method=store.get_run_template,
        update_method=store.update_run_template,
    )
runs_endpoints
Endpoint definitions for pipeline runs.
delete_run(run_id, _=Security(oauth2_authentication))
Deletes a run.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
run_id |
UUID |
ID of the run. |
required |
Source code in zenml/zen_server/routers/runs_endpoints.py
@router.delete(
    "/{run_id}",
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def delete_run(
    run_id: UUID,
    _: AuthContext = Security(authorize),
) -> None:
    """Remove a single pipeline run.

    Args:
        run_id: ID of the run to delete.
    """
    store = zen_store()
    verify_permissions_and_delete_entity(
        id=run_id,
        get_method=store.get_run,
        delete_method=store.delete_run,
    )
get_pipeline_configuration(run_id, _=Security(oauth2_authentication))
Get the pipeline configuration of a specific pipeline run using its ID.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
run_id |
UUID |
ID of the pipeline run to get. |
required |
Returns:
Type | Description |
---|---|
Dict[str, Any] |
The pipeline configuration of the pipeline run. |
Source code in zenml/zen_server/routers/runs_endpoints.py
@router.get(
    "/{run_id}" + PIPELINE_CONFIGURATION,
    response_model=Dict[str, Any],
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def get_pipeline_configuration(
    run_id: UUID,
    _: AuthContext = Security(authorize),
) -> Dict[str, Any]:
    """Return the pipeline configuration a run was executed with.

    Args:
        run_id: ID of the pipeline run to get.

    Returns:
        The run's pipeline configuration, dumped to a dictionary.
    """
    # The configuration is a hydrated field, so always fetch hydrated.
    pipeline_run = verify_permissions_and_get_entity(
        id=run_id,
        get_method=zen_store().get_run,
        hydrate=True,
    )
    run_config = pipeline_run.config
    return run_config.model_dump()
get_run(run_id, hydrate=True, refresh_status=False, _=Security(oauth2_authentication))
Get a specific pipeline run using its ID.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
run_id |
UUID |
ID of the pipeline run to get. |
required |
hydrate |
bool |
Flag deciding whether to hydrate the output model(s) by including metadata fields in the response. |
True |
refresh_status |
bool |
Flag deciding whether we should try to refresh the status of the pipeline run using its orchestrator. |
False |
Returns:
Type | Description |
---|---|
PipelineRunResponse |
The pipeline run. |
Exceptions:
Type | Description |
---|---|
RuntimeError |
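If the stack or the orchestrator of the run is deleted (raised internally while refreshing the status; it is caught and logged as a warning, not propagated to the caller). |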
Source code in zenml/zen_server/routers/runs_endpoints.py
@router.get(
    "/{run_id}",
    response_model=PipelineRunResponse,
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def get_run(
    run_id: UUID,
    hydrate: bool = True,
    refresh_status: bool = False,
    _: AuthContext = Security(authorize),
) -> PipelineRunResponse:
    """Get a specific pipeline run using its ID.

    Refreshing the status is best-effort. Any error raised while checking
    the run's stack/orchestrator or while refreshing is logged as a warning.
    In that case the run is returned as it was fetched, without a refreshed
    status.

    Args:
        run_id: ID of the pipeline run to get.
        hydrate: Flag deciding whether to hydrate the output model(s)
            by including metadata fields in the response.
        refresh_status: Flag deciding whether we should try to refresh
            the status of the pipeline run using its orchestrator.

    Returns:
        The pipeline run.

    Raises:
        RuntimeError: If the stack or the orchestrator of the run is
            deleted. This is raised only inside the best-effort refresh and
            is caught and logged there, so it never reaches the caller.
    """
    run = verify_permissions_and_get_entity(
        id=run_id, get_method=zen_store().get_run, hydrate=hydrate
    )

    if refresh_status:
        try:
            # Refreshing needs the run's orchestrator, and the caller must be
            # allowed to read it.
            if run.stack is None:
                raise RuntimeError(
                    f"The stack, the run '{run.id}' was executed on, is deleted."
                )

            orchestrators = run.stack.components.get(
                StackComponentType.ORCHESTRATOR, []
            )
            if not orchestrators:
                raise RuntimeError(
                    f"The orchestrator, the run '{run.id}' was executed "
                    "with, is deleted."
                )

            verify_permission_for_model(
                model=orchestrators[0], action=Action.READ
            )
            run = run.refresh_run_status()
        except Exception as e:
            # Deliberately broad: a failed refresh must not fail the GET.
            logger.warning(
                "An error occurred while refreshing the status of the "
                f"pipeline run: {e}"
            )

    return run
get_run_status(run_id, _=Security(oauth2_authentication))
Get the status of a specific pipeline run.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
run_id |
UUID |
ID of the pipeline run for which to get the status. |
required |
Returns:
Type | Description |
---|---|
ExecutionStatus |
The status of the pipeline run. |
Source code in zenml/zen_server/routers/runs_endpoints.py
@router.get(
    "/{run_id}" + STATUS,
    response_model=ExecutionStatus,
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def get_run_status(
    run_id: UUID,
    _: AuthContext = Security(authorize),
) -> ExecutionStatus:
    """Return the current execution status of one pipeline run.

    Args:
        run_id: ID of the pipeline run whose status is requested.

    Returns:
        The status of the pipeline run.
    """
    pipeline_run = verify_permissions_and_get_entity(
        id=run_id,
        get_method=zen_store().get_run,
        hydrate=False,
    )
    status = pipeline_run.status
    return status
get_run_steps(run_id, step_run_filter_model=Depends(init_cls_and_handle_errors), _=Security(oauth2_authentication))
Get all steps for a given pipeline run.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
run_id |
UUID |
ID of the pipeline run. |
required |
step_run_filter_model |
StepRunFilter |
Filter model used for pagination, sorting, filtering |
Depends(init_cls_and_handle_errors) |
Returns:
Type | Description |
---|---|
Page[StepRunResponse] |
The steps for a given pipeline run. |
Source code in zenml/zen_server/routers/runs_endpoints.py
@router.get(
    "/{run_id}" + STEPS,
    response_model=Page[StepRunResponse],
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def get_run_steps(
    run_id: UUID,
    step_run_filter_model: StepRunFilter = Depends(
        make_dependable(StepRunFilter)
    ),
    _: AuthContext = Security(authorize),
) -> Page[StepRunResponse]:
    """List the steps that belong to one pipeline run.

    Args:
        run_id: ID of the pipeline run.
        step_run_filter_model: Filter model used for pagination, sorting
            and filtering. Its `pipeline_run_id` is overwritten with
            `run_id`.

    Returns:
        A page of step runs of the given pipeline run.
    """
    # Verify access to the run first; the returned model is not needed.
    verify_permissions_and_get_entity(
        id=run_id, get_method=zen_store().get_run, hydrate=False
    )

    # Restrict the listing to the requested run.
    step_run_filter_model.pipeline_run_id = run_id
    steps_page = zen_store().list_run_steps(step_run_filter_model)
    return steps_page
list_runs(runs_filter_model=Depends(init_cls_and_handle_errors), hydrate=False, _=Security(oauth2_authentication))
Get pipeline runs according to query filters.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
runs_filter_model |
PipelineRunFilter |
Filter model used for pagination, sorting, filtering. |
Depends(init_cls_and_handle_errors) |
hydrate |
bool |
Flag deciding whether to hydrate the output model(s) by including metadata fields in the response. |
False |
Returns:
Type | Description |
---|---|
Page[PipelineRunResponse] |
The pipeline runs according to query filters. |
Source code in zenml/zen_server/routers/runs_endpoints.py
@router.get(
    "",
    response_model=Page[PipelineRunResponse],
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def list_runs(
    runs_filter_model: PipelineRunFilter = Depends(
        make_dependable(PipelineRunFilter)
    ),
    hydrate: bool = False,
    _: AuthContext = Security(authorize),
) -> Page[PipelineRunResponse]:
    """List pipeline runs matching the given query filters.

    Args:
        runs_filter_model: Filter model used for pagination, sorting and
            filtering.
        hydrate: Whether to include the metadata fields in each returned
            model.

    Returns:
        A page of pipeline runs matching the query filters.
    """
    store = zen_store()
    runs_page = verify_permissions_and_list_entities(
        filter_model=runs_filter_model,
        resource_type=ResourceType.PIPELINE_RUN,
        list_method=store.list_runs,
        hydrate=hydrate,
    )
    return runs_page
refresh_run_status(run_id, _=Security(oauth2_authentication))
Refreshes the status of a specific pipeline run.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
run_id |
UUID |
ID of the pipeline run to refresh. |
required |
Exceptions:
Type | Description |
---|---|
RuntimeError |
If the stack or the orchestrator of the run is deleted. |
Source code in zenml/zen_server/routers/runs_endpoints.py
@router.get(
    "/{run_id}" + REFRESH,
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def refresh_run_status(
    run_id: UUID,
    _: AuthContext = Security(authorize),
) -> None:
    """Trigger a status refresh for one pipeline run.

    The caller needs access to the run itself and read access to the
    orchestrator of the stack the run was executed on.

    Args:
        run_id: ID of the pipeline run to refresh.

    Raises:
        RuntimeError: If the run's stack, or that stack's orchestrator, no
            longer exists.
    """
    # Verify access to the run (hydrated, because the stack is needed).
    pipeline_run = verify_permissions_and_get_entity(
        id=run_id,
        get_method=zen_store().get_run,
        hydrate=True,
    )

    # Guard: the run's stack must still exist.
    if pipeline_run.stack is None:
        raise RuntimeError(
            f"The stack, the run '{pipeline_run.id}' was executed on, is deleted."
        )

    # Guard: the stack must still have an orchestrator.
    orchestrator_list = pipeline_run.stack.components.get(
        StackComponentType.ORCHESTRATOR, []
    )
    if not orchestrator_list:
        raise RuntimeError(
            f"The orchestrator, the run '{pipeline_run.id}' was executed "
            "with, is deleted."
        )

    verify_permission_for_model(model=orchestrator_list[0], action=Action.READ)
    pipeline_run.refresh_run_status()
update_run(run_id, run_model, _=Security(oauth2_authentication))
Updates a run.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
run_id |
UUID |
ID of the run. |
required |
run_model |
PipelineRunUpdate |
Run model to use for the update. |
required |
Returns:
Type | Description |
---|---|
PipelineRunResponse |
The updated run model. |
Source code in zenml/zen_server/routers/runs_endpoints.py
@router.put(
    "/{run_id}",
    response_model=PipelineRunResponse,
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def update_run(
    run_id: UUID,
    run_model: PipelineRunUpdate,
    _: AuthContext = Security(authorize),
) -> PipelineRunResponse:
    """Apply an update to an existing pipeline run.

    Args:
        run_id: ID of the run.
        run_model: Update model holding the changes to apply.

    Returns:
        The pipeline run after the update.
    """
    store = zen_store()
    updated_run = verify_permissions_and_update_entity(
        id=run_id,
        update_model=run_model,
        get_method=store.get_run,
        update_method=store.update_run,
    )
    return updated_run
schedule_endpoints
Endpoint definitions for pipeline run schedules.
delete_schedule(schedule_id, _=Security(oauth2_authentication))
Deletes a specific schedule using its unique id.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
schedule_id |
UUID |
ID of the schedule to delete. |
required |
Source code in zenml/zen_server/routers/schedule_endpoints.py
@router.delete(
    "/{schedule_id}",
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def delete_schedule(
    schedule_id: UUID,
    _: AuthContext = Security(authorize),
) -> None:
    """Delete the schedule identified by the given ID.

    Args:
        schedule_id: ID of the schedule to delete.
    """
    # NOTE(review): the store is called directly, without any of the
    # `verify_permission*` helpers the secret/run endpoints use; confirm
    # that schedules are intentionally exempt from RBAC checks.
    store = zen_store()
    store.delete_schedule(schedule_id=schedule_id)
get_schedule(schedule_id, hydrate=True, _=Security(oauth2_authentication))
Gets a specific schedule using its unique id.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
schedule_id |
UUID |
ID of the schedule to get. |
required |
hydrate |
bool |
Flag deciding whether to hydrate the output model(s) by including metadata fields in the response. |
True |
Returns:
Type | Description |
---|---|
ScheduleResponse |
A specific schedule object. |
Source code in zenml/zen_server/routers/schedule_endpoints.py
@router.get(
    "/{schedule_id}",
    response_model=ScheduleResponse,
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def get_schedule(
    schedule_id: UUID,
    hydrate: bool = True,
    _: AuthContext = Security(authorize),
) -> ScheduleResponse:
    """Fetch a single schedule by its ID.

    Args:
        schedule_id: ID of the schedule to get.
        hydrate: Whether to include the metadata fields in the returned
            model.

    Returns:
        The requested schedule.
    """
    # NOTE(review): no `verify_permission*` helper is called here, unlike
    # the secret/run endpoints; confirm this is intended.
    schedule = zen_store().get_schedule(
        schedule_id=schedule_id, hydrate=hydrate
    )
    return schedule
list_schedules(schedule_filter_model=Depends(init_cls_and_handle_errors), hydrate=False, _=Security(oauth2_authentication))
Gets a list of schedules.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
schedule_filter_model |
ScheduleFilter |
Filter model used for pagination, sorting and filtering. |
Depends(init_cls_and_handle_errors) |
hydrate |
bool |
Flag deciding whether to hydrate the output model(s) by including metadata fields in the response. |
False |
Returns:
Type | Description |
---|---|
Page[ScheduleResponse] |
List of schedule objects. |
Source code in zenml/zen_server/routers/schedule_endpoints.py
@router.get(
    "",
    response_model=Page[ScheduleResponse],
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def list_schedules(
    schedule_filter_model: ScheduleFilter = Depends(
        make_dependable(ScheduleFilter)
    ),
    hydrate: bool = False,
    _: AuthContext = Security(authorize),
) -> Page[ScheduleResponse]:
    """List schedules matching the given filter.

    Args:
        schedule_filter_model: Filter model used for pagination, sorting
            and filtering.
        hydrate: Whether to include the metadata fields in each returned
            model.

    Returns:
        A page of schedules.
    """
    # NOTE(review): results are not filtered through
    # `verify_permissions_and_list_entities` as other list endpoints are;
    # confirm this is intended.
    schedules_page = zen_store().list_schedules(
        schedule_filter_model=schedule_filter_model,
        hydrate=hydrate,
    )
    return schedules_page
update_schedule(schedule_id, schedule_update, _=Security(oauth2_authentication))
Updates the attribute on a specific schedule using its unique id.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
schedule_id |
UUID |
ID of the schedule to update. |
required |
schedule_update |
ScheduleUpdate |
The model containing the attributes to update. |
required |
Returns:
Type | Description |
---|---|
ScheduleResponse |
The updated schedule object. |
Source code in zenml/zen_server/routers/schedule_endpoints.py
@router.put(
    "/{schedule_id}",
    response_model=ScheduleResponse,
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def update_schedule(
    schedule_id: UUID,
    schedule_update: ScheduleUpdate,
    _: AuthContext = Security(authorize),
) -> ScheduleResponse:
    """Apply an update to the schedule with the given ID.

    Args:
        schedule_id: ID of the schedule to update.
        schedule_update: The model containing the attributes to update.

    Returns:
        The schedule after the update.
    """
    # NOTE(review): no `verify_permission*` helper is called here, unlike
    # the secret/run endpoints; confirm this is intended.
    updated_schedule = zen_store().update_schedule(
        schedule_id=schedule_id,
        schedule_update=schedule_update,
    )
    return updated_schedule
secrets_endpoints
Endpoint definitions for secrets.
backup_secrets(ignore_errors=True, delete_secrets=False, _=Security(oauth2_authentication))
Backs up all secrets in the secrets store to the backup secrets store.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
ignore_errors |
bool |
Whether to ignore individual errors when backing up secrets and continue with the backup operation until all secrets have been backed up. |
True |
delete_secrets |
bool |
Whether to delete the secrets that have been successfully backed up from the primary secrets store. Setting this flag effectively moves all secrets from the primary secrets store to the backup secrets store. |
False |
Source code in zenml/zen_server/routers/secrets_endpoints.py
@op_router.put(
    SECRETS_BACKUP,
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def backup_secrets(
    ignore_errors: bool = True,
    delete_secrets: bool = False,
    _: AuthContext = Security(authorize),
) -> None:
    """Copy every secret from the primary secrets store to the backup store.

    Args:
        ignore_errors: If set, errors on individual secrets are skipped and
            the backup continues until all secrets have been processed.
        delete_secrets: If set, secrets that were backed up successfully are
            removed from the primary store, which turns the backup into a
            move.
    """
    # The operation is gated on the BACKUP_RESTORE action for secrets.
    verify_permission(
        resource_type=ResourceType.SECRET,
        action=Action.BACKUP_RESTORE,
    )
    store = zen_store()
    store.backup_secrets(
        ignore_errors=ignore_errors,
        delete_secrets=delete_secrets,
    )
delete_secret(secret_id, _=Security(oauth2_authentication))
Deletes a specific secret using its unique id.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
secret_id |
UUID |
ID of the secret to delete. |
required |
Source code in zenml/zen_server/routers/secrets_endpoints.py
@router.delete(
    "/{secret_id}",
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def delete_secret(
    secret_id: UUID,
    _: AuthContext = Security(authorize),
) -> None:
    """Delete the secret identified by the given ID.

    Args:
        secret_id: ID of the secret to delete.
    """
    store = zen_store()
    verify_permissions_and_delete_entity(
        id=secret_id,
        get_method=store.get_secret,
        delete_method=store.delete_secret,
    )
get_secret(secret_id, hydrate=True, _=Security(oauth2_authentication))
Gets a specific secret using its unique id.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
secret_id |
UUID |
ID of the secret to get. |
required |
hydrate |
bool |
Flag deciding whether to hydrate the output model(s) by including metadata fields in the response. |
True |
Returns:
Type | Description |
---|---|
SecretResponse |
A specific secret object. |
Source code in zenml/zen_server/routers/secrets_endpoints.py
@router.get(
    "/{secret_id}",
    response_model=SecretResponse,
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def get_secret(
    secret_id: UUID,
    hydrate: bool = True,
    _: AuthContext = Security(authorize),
) -> SecretResponse:
    """Fetch a single secret by its ID.

    The secret values are stripped from the response unless the caller has
    the READ_SECRET_VALUE permission for this secret.

    Args:
        secret_id: ID of the secret to get.
        hydrate: Whether to include the metadata fields in the returned
            model.

    Returns:
        The requested secret, possibly without its values.
    """
    secret = verify_permissions_and_get_entity(
        id=secret_id,
        get_method=zen_store().get_secret,
        hydrate=hydrate,
    )
    may_read_values = has_permissions_for_model(
        secret, action=Action.READ_SECRET_VALUE
    )
    if not may_read_values:
        secret.remove_secrets()
    return secret
list_secrets(secret_filter_model=Depends(init_cls_and_handle_errors), hydrate=False, _=Security(oauth2_authentication))
Gets a list of secrets.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
secret_filter_model |
SecretFilter |
Filter model used for pagination, sorting, filtering. |
Depends(init_cls_and_handle_errors) |
hydrate |
bool |
Flag deciding whether to hydrate the output model(s) by including metadata fields in the response. |
False |
Returns:
Type | Description |
---|---|
Page[SecretResponse] |
List of secret objects. |
Source code in zenml/zen_server/routers/secrets_endpoints.py
@router.get(
    "",
    response_model=Page[SecretResponse],
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def list_secrets(
    secret_filter_model: SecretFilter = Depends(make_dependable(SecretFilter)),
    hydrate: bool = False,
    _: AuthContext = Security(authorize),
) -> Page[SecretResponse]:
    """List secrets matching the given filter.

    Values are removed from every secret that the caller neither owns nor
    holds the READ_SECRET_VALUE permission for.

    Args:
        secret_filter_model: Filter model used for pagination, sorting and
            filtering.
        hydrate: Whether to include the metadata fields in each returned
            model.

    Returns:
        A page of secrets, some possibly without their values.
    """
    secrets_page = verify_permissions_and_list_entities(
        filter_model=secret_filter_model,
        resource_type=ResourceType.SECRET,
        list_method=zen_store().list_secrets,
        hydrate=hydrate,
    )

    readable_ids = get_allowed_resource_ids(
        resource_type=ResourceType.SECRET,
        action=Action.READ_SECRET_VALUE,
    )
    # `None` means the caller may read the values of all secrets.
    if readable_ids is None:
        return secrets_page

    for secret in secrets_page.items:
        can_read = secret.id in readable_ids or is_owned_by_authenticated_user(
            secret
        )
        if not can_read:
            secret.remove_secrets()
    return secrets_page
restore_secrets(ignore_errors=False, delete_secrets=False, _=Security(oauth2_authentication))
Restores all secrets from the backup secrets store into the main secrets store.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
ignore_errors |
bool |
Whether to ignore individual errors when restoring secrets and continue with the restore operation until all secrets have been restored. |
False |
delete_secrets |
bool |
Whether to delete the secrets that have been successfully restored from the backup secrets store. Setting this flag effectively moves all secrets from the backup secrets store to the primary secrets store. |
False |
Source code in zenml/zen_server/routers/secrets_endpoints.py
@op_router.put(
    SECRETS_RESTORE,
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def restore_secrets(
    ignore_errors: bool = False,
    delete_secrets: bool = False,
    _: AuthContext = Security(authorize),
) -> None:
    """Copy every secret from the backup secrets store into the primary store.

    Args:
        ignore_errors: If set, errors on individual secrets are skipped and
            the restore continues until all secrets have been processed.
        delete_secrets: If set, secrets that were restored successfully are
            removed from the backup store, which turns the restore into a
            move.
    """
    # The operation is gated on the BACKUP_RESTORE action for secrets.
    verify_permission(
        resource_type=ResourceType.SECRET,
        action=Action.BACKUP_RESTORE,
    )
    store = zen_store()
    store.restore_secrets(
        ignore_errors=ignore_errors,
        delete_secrets=delete_secrets,
    )
update_secret(secret_id, secret_update, patch_values=False, _=Security(oauth2_authentication))
Updates the attribute on a specific secret using its unique id.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
secret_id |
UUID |
ID of the secret to update. |
required |
secret_update |
SecretUpdate |
The model containing the attributes to update. |
required |
patch_values |
Optional[bool] |
Whether to patch the secret values or replace them. |
False |
Returns:
Type | Description |
---|---|
SecretResponse |
The updated secret object. |
Source code in zenml/zen_server/routers/secrets_endpoints.py
@router.put(
    "/{secret_id}",
    response_model=SecretResponse,
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def update_secret(
    secret_id: UUID,
    secret_update: SecretUpdate,
    patch_values: Optional[bool] = False,
    _: AuthContext = Security(authorize),
) -> SecretResponse:
    """Updates the attribute on a specific secret using its unique id.

    Args:
        secret_id: ID of the secret to update.
        secret_update: The model containing the attributes to update.
        patch_values: Whether to patch the secret values or replace them.
            When falsy, the values in `secret_update` are treated as a
            complete replacement: every existing key that is missing from
            the update is set to `None` so that it is deleted.

    Returns:
        The updated secret object.
    """
    # The existing secret is only needed when there are update values that
    # must be completed into a full replacement. Without update values there
    # is nothing to complete, so the extra store lookup is skipped. The
    # `values is not None` check is loop-invariant and is done once here.
    if not patch_values and secret_update.values is not None:
        existing_secret = zen_store().get_secret(secret_id=secret_id)
        for key in existing_secret.values:
            if key not in secret_update.values:
                secret_update.values[key] = None
    return verify_permissions_and_update_entity(
        id=secret_id,
        update_model=secret_update,
        get_method=zen_store().get_secret,
        update_method=zen_store().update_secret,
    )
server_endpoints
Endpoint definitions for server information, settings and activation.
activate_server(activate_request)
Activates the server.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
activate_request |
ServerActivationRequest |
The request to activate the server. |
required |
Returns:
Type | Description |
---|---|
Optional[zenml.models.v2.core.user.UserResponse] |
The default admin user that was created during activation, if any. |
Source code in zenml/zen_server/routers/server_endpoints.py
@router.put(
    ACTIVATE,
    responses={
        401: error_response,
        404: error_response,
        422: error_response,
    },
)
@handle_exceptions
def activate_server(
    activate_request: ServerActivationRequest,
) -> Optional[UserResponse]:
    """Activates the server.

    Args:
        activate_request: The request to activate the server.

    Returns:
        The default admin user that was created during activation, if any.
    """
    # NOTE(review): unlike most endpoints here, no `Security(authorize)`
    # dependency is declared — presumably because activation can run before
    # any user exists; confirm the store enforces one-time activation.
    return zen_store().activate_server(activate_request)
get_onboarding_state(_=Security(oauth2_authentication))
Get the onboarding state of the server.
Returns:
Type | Description |
---|---|
List[str] |
The onboarding state of the server. |
Source code in zenml/zen_server/routers/server_endpoints.py
@router.get(
    ONBOARDING_STATE,
    responses={
        401: error_response,
        404: error_response,
        422: error_response,
    },
)
@handle_exceptions
def get_onboarding_state(
    _: AuthContext = Security(authorize),
) -> List[str]:
    """Return the server's onboarding state.

    Returns:
        The onboarding state, as a list of strings.
    """
    onboarding_state = zen_store().get_onboarding_state()
    return onboarding_state
get_settings(_=Security(oauth2_authentication), hydrate=True)
Get settings of the server.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
hydrate |
bool |
Whether to hydrate the response. |
True |
Returns:
Type | Description |
---|---|
ServerSettingsResponse |
Settings of the server. |
Source code in zenml/zen_server/routers/server_endpoints.py
@router.get(
    SERVER_SETTINGS,
    responses={
        401: error_response,
        404: error_response,
        422: error_response,
    },
)
@handle_exceptions
def get_settings(
    _: AuthContext = Security(authorize),
    hydrate: bool = True,
) -> ServerSettingsResponse:
    """Return the current server settings.

    Args:
        hydrate: Whether to hydrate the response.

    Returns:
        The server settings.
    """
    settings = zen_store().get_server_settings(hydrate=hydrate)
    return settings
server_info()
Get information about the server.
Returns:
Type | Description |
---|---|
ServerModel |
Information about the server. |
Source code in zenml/zen_server/routers/server_endpoints.py
@router.get(
    INFO,
    response_model=ServerModel,
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def server_info() -> ServerModel:
    """Describe the running server.

    Returns:
        The server information reported by the store.
    """
    store_info = zen_store().get_store_info()
    return store_info
server_load_info(_=Security(oauth2_authentication))
Get information about the server load.
Returns:
Type | Description |
---|---|
ServerLoadInfo |
Information about the server load. |
Source code in zenml/zen_server/routers/server_endpoints.py
@router.get(
    LOAD_INFO,
    response_model=ServerLoadInfo,
)
@handle_exceptions
def server_load_info(_: AuthContext = Security(authorize)) -> ServerLoadInfo:
    """Report thread and database-connection usage of the server.

    Returns:
        The current thread count plus total, active and overflow database
        connections (all zero for SQLite).
    """
    import threading

    thread_count = len(threading.enumerate())
    store = zen_store()

    if store.config.driver == "sqlite":
        # SQLite doesn't use a connection pool, so report zeros.
        return ServerLoadInfo(
            threads=thread_count,
            db_connections_total=0,
            db_connections_active=0,
            db_connections_overflow=0,
        )

    from sqlalchemy.pool import QueuePool

    pool = store.engine.pool
    assert isinstance(pool, QueuePool)
    checked_in = pool.checkedin()
    checked_out = pool.checkedout()
    # `overflow()` can be negative while the pool is not full; clamp to 0.
    return ServerLoadInfo(
        threads=thread_count,
        db_connections_total=checked_in + checked_out,
        db_connections_active=checked_out,
        db_connections_overflow=max(0, pool.overflow()),
    )
update_server_settings(settings_update, auth_context=Security(oauth2_authentication))
Updates the settings of the server.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
settings_update |
ServerSettingsUpdate |
Settings update. |
required |
auth_context |
AuthContext |
Authentication context. |
Security(oauth2_authentication) |
Exceptions:
Type | Description |
---|---|
IllegalOperationError |
If trying to update admin properties without admin permissions. |
Returns:
Type | Description |
---|---|
ServerSettingsResponse |
The updated settings. |
Source code in zenml/zen_server/routers/server_endpoints.py
@router.put(
    SERVER_SETTINGS,
    responses={
        401: error_response,
        404: error_response,
        422: error_response,
    },
)
@handle_exceptions
def update_server_settings(
    settings_update: ServerSettingsUpdate,
    auth_context: AuthContext = Security(authorize),
) -> ServerSettingsResponse:
    """Apply an update to the server settings.

    When RBAC is disabled, only admins may change anything other than the
    onboarding state.

    Args:
        settings_update: Settings update.
        auth_context: Authentication context.

    Raises:
        IllegalOperationError: If a non-admin tries to change admin-only
            settings while RBAC is disabled.

    Returns:
        The settings after the update.
    """
    rbac_enabled = server_config().rbac_enabled
    if not rbac_enabled:
        # Every field set in the update except `onboarding_state` is
        # considered an admin-only property.
        admin_only_changes = settings_update.model_dump(
            exclude_none=True, exclude={"onboarding_state"}
        )
        if not auth_context.user.is_admin and admin_only_changes:
            raise IllegalOperationError(
                "Only admins can update server settings."
            )
    updated_settings = zen_store().update_server_settings(settings_update)
    return updated_settings
version()
Get version of the server.
Returns:
Type | Description |
---|---|
str |
String representing the version of the server. |
Source code in zenml/zen_server/routers/server_endpoints.py
@router.get("/version")
def version() -> str:
    """Return the version of the running ZenML server.

    Returns:
        The ``zenml.__version__`` string.
    """
    server_version: str = zenml.__version__
    return server_version
service_accounts_endpoints
Endpoint definitions for service accounts and their API keys.
create_api_key(service_account_id, api_key, _=Security(oauth2_authentication))
Creates an API key for a service account.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
service_account_id |
UUID |
ID of the service account for which to create the API key. |
required |
api_key |
APIKeyRequest |
API key to create. |
required |
Returns:
Type | Description |
---|---|
APIKeyResponse |
The created API key. |
Source code in zenml/zen_server/routers/service_accounts_endpoints.py
@router.post(
    "/{service_account_id}" + API_KEYS,
    response_model=APIKeyResponse,
    responses={401: error_response, 409: error_response, 422: error_response},
)
@handle_exceptions
def create_api_key(
    service_account_id: UUID,
    api_key: APIKeyRequest,
    _: AuthContext = Security(authorize),
) -> APIKeyResponse:
    """Create a new API key for a service account.

    Args:
        service_account_id: ID of the service account that will own the key.
        api_key: API key to create.

    Returns:
        The newly created API key.
    """
    store = zen_store()
    # Key creation requires UPDATE permission on the owning service account.
    owner = store.get_service_account(service_account_id)
    verify_permission_for_model(owner, action=Action.UPDATE)
    return store.create_api_key(
        service_account_id=service_account_id,
        api_key=api_key,
    )
create_service_account(service_account, _=Security(oauth2_authentication))
Creates a service account.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
service_account |
ServiceAccountRequest |
Service account to create. |
required |
Returns:
Type | Description |
---|---|
ServiceAccountResponse |
The created service account. |
Source code in zenml/zen_server/routers/service_accounts_endpoints.py
@router.post(
    "",
    response_model=ServiceAccountResponse,
    responses={
        401: error_response,
        409: error_response,
        422: error_response,
    },
)
@handle_exceptions
def create_service_account(
    service_account: ServiceAccountRequest,
    _: AuthContext = Security(authorize),
) -> ServiceAccountResponse:
    """Create a new service account.

    Args:
        service_account: Service account to create.

    Returns:
        The newly created service account.
    """
    store = zen_store()
    created_account = verify_permissions_and_create_entity(
        request_model=service_account,
        resource_type=ResourceType.SERVICE_ACCOUNT,
        create_method=store.create_service_account,
    )
    return created_account
delete_api_key(service_account_id, api_key_name_or_id, _=Security(oauth2_authentication))
Deletes an API key.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
service_account_id |
UUID |
ID of the service account to which the API key belongs. |
required |
api_key_name_or_id |
Union[str, uuid.UUID] |
Name or ID of the API key to delete. |
required |
Source code in zenml/zen_server/routers/service_accounts_endpoints.py
@router.delete(
    "/{service_account_id}" + API_KEYS + "/{api_key_name_or_id}",
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def delete_api_key(
    service_account_id: UUID,
    api_key_name_or_id: Union[str, UUID],
    _: AuthContext = Security(authorize),
) -> None:
    """Delete one API key of a service account.

    Args:
        service_account_id: ID of the service account that owns the key.
        api_key_name_or_id: Name or ID of the API key to delete.
    """
    store = zen_store()
    # Key deletion requires UPDATE permission on the owning service account.
    owner = store.get_service_account(service_account_id)
    verify_permission_for_model(owner, action=Action.UPDATE)
    store.delete_api_key(
        service_account_id=service_account_id,
        api_key_name_or_id=api_key_name_or_id,
    )
delete_service_account(service_account_name_or_id, _=Security(oauth2_authentication))
Delete a specific service account.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
service_account_name_or_id |
Union[str, uuid.UUID] |
Name or ID of the service account. |
required |
Source code in zenml/zen_server/routers/service_accounts_endpoints.py
@router.delete(
    "/{service_account_name_or_id}",
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def delete_service_account(
    service_account_name_or_id: Union[str, UUID],
    _: AuthContext = Security(authorize),
) -> None:
    """Delete the service account with the given name or ID.

    Args:
        service_account_name_or_id: Name or ID of the service account.
    """
    store = zen_store()
    verify_permissions_and_delete_entity(
        id=service_account_name_or_id,
        get_method=store.get_service_account,
        delete_method=store.delete_service_account,
    )
get_api_key(service_account_id, api_key_name_or_id, hydrate=True, _=Security(oauth2_authentication))
Returns the requested API key.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
service_account_id |
UUID |
ID of the service account to which the API key belongs. |
required |
hydrate |
bool |
Flag deciding whether to hydrate the output model(s) by including metadata fields in the response. |
True |
api_key_name_or_id |
Union[str, uuid.UUID] |
Name or ID of the API key to return. |
required |
Returns:
Type | Description |
---|---|
APIKeyResponse |
The requested API key. |
Source code in zenml/zen_server/routers/service_accounts_endpoints.py
@router.get(
    "/{service_account_id}" + API_KEYS + "/{api_key_name_or_id}",
    response_model=APIKeyResponse,
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def get_api_key(
    service_account_id: UUID,
    api_key_name_or_id: Union[str, UUID],
    hydrate: bool = True,
    _: AuthContext = Security(authorize),
) -> APIKeyResponse:
    """Fetch one API key of a service account.

    Args:
        service_account_id: ID of the service account that owns the key.
        hydrate: Whether to include the metadata fields in the returned
            model.
        api_key_name_or_id: Name or ID of the API key to return.

    Returns:
        The requested API key.
    """
    store = zen_store()
    # Reading a key requires READ permission on the owning service account.
    owner = store.get_service_account(service_account_id)
    verify_permission_for_model(owner, action=Action.READ)
    return store.get_api_key(
        service_account_id=service_account_id,
        api_key_name_or_id=api_key_name_or_id,
        hydrate=hydrate,
    )
get_service_account(service_account_name_or_id, _=Security(oauth2_authentication), hydrate=True)
Returns a specific service account.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
service_account_name_or_id |
Union[str, uuid.UUID] |
Name or ID of the service account. |
required |
hydrate |
bool |
Flag deciding whether to hydrate the output model(s) by including metadata fields in the response. |
True |
Returns:
Type | Description |
---|---|
ServiceAccountResponse |
The service account matching the given name or ID. |
Source code in zenml/zen_server/routers/service_accounts_endpoints.py
@router.get(
    "/{service_account_name_or_id}",
    response_model=ServiceAccountResponse,
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def get_service_account(
    service_account_name_or_id: Union[str, UUID],
    _: AuthContext = Security(authorize),
    hydrate: bool = True,
) -> ServiceAccountResponse:
    """Fetch the service account with the given name or ID.

    Args:
        service_account_name_or_id: Name or ID of the service account.
        hydrate: Whether to include the metadata fields in the returned
            model.

    Returns:
        The matching service account.
    """
    account = verify_permissions_and_get_entity(
        id=service_account_name_or_id,
        get_method=zen_store().get_service_account,
        hydrate=hydrate,
    )
    return account
list_api_keys(service_account_id, filter_model=Depends(init_cls_and_handle_errors), hydrate=False, _=Security(oauth2_authentication))
List API keys associated with a service account.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
service_account_id |
UUID |
ID of the service account to which the API keys belong. |
required |
hydrate |
bool |
Flag deciding whether to hydrate the output model(s) by including metadata fields in the response. |
False |
filter_model |
APIKeyFilter |
Filter model used for pagination, sorting and filtering. |
Depends(init_cls_and_handle_errors) |
Returns:
Type | Description |
---|---|
Page[APIKeyResponse] |
All API keys matching the filter and associated with the supplied service account. |
Source code in zenml/zen_server/routers/service_accounts_endpoints.py
@router.get(
    "/{service_account_id}" + API_KEYS,
    response_model=Page[APIKeyResponse],
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def list_api_keys(
    service_account_id: UUID,
    filter_model: APIKeyFilter = Depends(make_dependable(APIKeyFilter)),
    hydrate: bool = False,
    _: AuthContext = Security(authorize),
) -> Page[APIKeyResponse]:
    """List the API keys owned by a service account.

    Args:
        service_account_id: ID of the service account that owns the keys.
        hydrate: Whether to include the metadata fields in each returned
            model.
        filter_model: Filter model used for pagination, sorting and
            filtering.

    Returns:
        A page of the service account's API keys that match the filter.
    """
    store = zen_store()
    # Listing keys requires READ permission on the owning service account.
    owner = store.get_service_account(service_account_id)
    verify_permission_for_model(owner, action=Action.READ)
    return store.list_api_keys(
        service_account_id=service_account_id,
        filter_model=filter_model,
        hydrate=hydrate,
    )
list_service_accounts(filter_model=Depends(init_cls_and_handle_errors), hydrate=False, _=Security(oauth2_authentication))
Returns a list of service accounts.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
filter_model |
ServiceAccountFilter |
Model that takes care of filtering, sorting and pagination. |
Depends(init_cls_and_handle_errors) |
hydrate |
bool |
Flag deciding whether to hydrate the output model(s) by including metadata fields in the response. |
False |
Returns:
Type | Description |
---|---|
Page[ServiceAccountResponse] |
A list of service accounts matching the filter. |
Source code in zenml/zen_server/routers/service_accounts_endpoints.py
@router.get(
    "",
    response_model=Page[ServiceAccountResponse],
    responses={
        401: error_response,
        404: error_response,
        422: error_response,
    },
)
@handle_exceptions
def list_service_accounts(
    filter_model: ServiceAccountFilter = Depends(
        make_dependable(ServiceAccountFilter)
    ),
    hydrate: bool = False,
    _: AuthContext = Security(authorize),
) -> Page[ServiceAccountResponse]:
    """Return a page of service accounts matching a filter.

    Permission checks are delegated to
    `verify_permissions_and_list_entities`.

    Args:
        filter_model: Model that takes care of filtering, sorting and
            pagination.
        hydrate: Flag deciding whether to hydrate the output model(s)
            by including metadata fields in the response.

    Returns:
        A list of service accounts matching the filter.
    """
    list_accounts = zen_store().list_service_accounts
    return verify_permissions_and_list_entities(
        filter_model=filter_model,
        resource_type=ResourceType.SERVICE_ACCOUNT,
        list_method=list_accounts,
        hydrate=hydrate,
    )
rotate_api_key(service_account_id, api_key_name_or_id, rotate_request, _=Security(oauth2_authentication))
Rotate an API key.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
service_account_id |
UUID |
ID of the service account to which the API key belongs. |
required |
api_key_name_or_id |
Union[str, uuid.UUID] |
Name or ID of the API key to rotate. |
required |
rotate_request |
APIKeyRotateRequest |
API key rotation request. |
required |
Returns:
Type | Description |
---|---|
APIKeyResponse |
The updated API key. |
Source code in zenml/zen_server/routers/service_accounts_endpoints.py
@router.put(
    "/{service_account_id}"
    + API_KEYS
    + "/{api_key_name_or_id}"
    + API_KEY_ROTATE,
    response_model=APIKeyResponse,
    responses={
        401: error_response,
        409: error_response,
        422: error_response,
    },
)
@handle_exceptions
def rotate_api_key(
    service_account_id: UUID,
    api_key_name_or_id: Union[str, UUID],
    rotate_request: APIKeyRotateRequest,
    _: AuthContext = Security(authorize),
) -> APIKeyResponse:
    """Rotate one of a service account's API keys.

    The caller must hold UPDATE permission on the owning service account.

    Args:
        service_account_id: ID of the service account to which the API key
            belongs.
        api_key_name_or_id: Name or ID of the API key to rotate.
        rotate_request: API key rotation request.

    Returns:
        The updated API key.
    """
    verify_permission_for_model(
        zen_store().get_service_account(service_account_id),
        action=Action.UPDATE,
    )

    rotated_key = zen_store().rotate_api_key(
        service_account_id=service_account_id,
        api_key_name_or_id=api_key_name_or_id,
        rotate_request=rotate_request,
    )
    return rotated_key
update_api_key(service_account_id, api_key_name_or_id, api_key_update, _=Security(oauth2_authentication))
Updates an API key for a service account.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
service_account_id |
UUID |
ID of the service account to which the API key belongs. |
required |
api_key_name_or_id |
Union[str, uuid.UUID] |
Name or ID of the API key to update. |
required |
api_key_update |
APIKeyUpdate |
API key update. |
required |
Returns:
Type | Description |
---|---|
APIKeyResponse |
The updated API key. |
Source code in zenml/zen_server/routers/service_accounts_endpoints.py
@router.put(
    "/{service_account_id}" + API_KEYS + "/{api_key_name_or_id}",
    response_model=APIKeyResponse,
    responses={
        401: error_response,
        409: error_response,
        422: error_response,
    },
)
@handle_exceptions
def update_api_key(
    service_account_id: UUID,
    api_key_name_or_id: Union[str, UUID],
    api_key_update: APIKeyUpdate,
    _: AuthContext = Security(authorize),
) -> APIKeyResponse:
    """Apply an update to one of a service account's API keys.

    The caller must hold UPDATE permission on the owning service account.

    Args:
        service_account_id: ID of the service account to which the API key
            belongs.
        api_key_name_or_id: Name or ID of the API key to update.
        api_key_update: API key update.

    Returns:
        The updated API key.
    """
    verify_permission_for_model(
        zen_store().get_service_account(service_account_id),
        action=Action.UPDATE,
    )

    updated_key = zen_store().update_api_key(
        service_account_id=service_account_id,
        api_key_name_or_id=api_key_name_or_id,
        api_key_update=api_key_update,
    )
    return updated_key
update_service_account(service_account_name_or_id, service_account_update, _=Security(oauth2_authentication))
Updates a specific service account.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
service_account_name_or_id |
Union[str, uuid.UUID] |
Name or ID of the service account. |
required |
service_account_update |
ServiceAccountUpdate |
The service account update to apply. |
required |
Returns:
Type | Description |
---|---|
ServiceAccountResponse |
The updated service account. |
Source code in zenml/zen_server/routers/service_accounts_endpoints.py
@router.put(
    "/{service_account_name_or_id}",
    response_model=ServiceAccountResponse,
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def update_service_account(
    service_account_name_or_id: Union[str, UUID],
    service_account_update: ServiceAccountUpdate,
    _: AuthContext = Security(authorize),
) -> ServiceAccountResponse:
    """Apply an update to a service account identified by name or ID.

    Permission checks are delegated to
    `verify_permissions_and_update_entity`.

    Args:
        service_account_name_or_id: Name or ID of the service account.
        service_account_update: The service account update to apply.

    Returns:
        The updated service account.
    """
    fetch_account = zen_store().get_service_account
    apply_update = zen_store().update_service_account
    return verify_permissions_and_update_entity(
        id=service_account_name_or_id,
        update_model=service_account_update,
        get_method=fetch_account,
        update_method=apply_update,
    )
service_connectors_endpoints
Endpoint definitions for service connectors.
delete_service_connector(connector_id, _=Security(oauth2_authentication))
Deletes a service connector.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
connector_id |
UUID |
ID of the service connector. |
required |
Source code in zenml/zen_server/routers/service_connectors_endpoints.py
@router.delete(
    "/{connector_id}",
    responses={
        401: error_response,
        404: error_response,
        422: error_response,
    },
)
@handle_exceptions
def delete_service_connector(
    connector_id: UUID,
    _: AuthContext = Security(authorize),
) -> None:
    """Remove a service connector.

    Permission checks are delegated to
    `verify_permissions_and_delete_entity`.

    Args:
        connector_id: ID of the service connector.
    """
    fetch_connector = zen_store().get_service_connector
    remove_connector = zen_store().delete_service_connector
    verify_permissions_and_delete_entity(
        id=connector_id,
        get_method=fetch_connector,
        delete_method=remove_connector,
    )
get_resources_based_on_service_connector_info(connector_info=None, connector_uuid=None, _=Security(oauth2_authentication))
Gets the list of resources that a service connector can access.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
connector_info |
Optional[zenml.models.v2.misc.info_models.ServiceConnectorInfo] |
The service connector info. |
None |
connector_uuid |
Optional[uuid.UUID] |
The service connector uuid. |
None |
Returns:
Type | Description |
---|---|
ServiceConnectorResourcesInfo |
The list of resources that the service connector configuration has access to and that can be consumed from the UI/CLI. |
Exceptions:
Type | Description |
---|---|
ValueError |
If both connector_info and connector_uuid are provided. |
ValueError |
If neither connector_info nor connector_uuid is provided. |
Source code in zenml/zen_server/routers/service_connectors_endpoints.py
@router.post(
    SERVICE_CONNECTOR_FULL_STACK,
    responses={
        401: error_response,
        409: error_response,
        422: error_response,
    },
)
@handle_exceptions
def get_resources_based_on_service_connector_info(
    connector_info: Optional[ServiceConnectorInfo] = None,
    connector_uuid: Optional[UUID] = None,
    _: AuthContext = Security(authorize),
) -> ServiceConnectorResourcesInfo:
    """List the resources a service connector is able to access.

    Exactly one of `connector_info` (an inline connector configuration) or
    `connector_uuid` (an existing connector) must be supplied. Inline
    configurations require CREATE permission on service connectors.
    Existing connectors require READ permission on that specific connector.

    Args:
        connector_info: The service connector info.
        connector_uuid: The service connector uuid.

    Returns:
        The list of resources that the service connector configuration has
        access to and that can be consumed from the UI/CLI.

    Raises:
        ValueError: If both connector_info and connector_uuid are provided.
        ValueError: If neither connector_info nor connector_uuid is provided.
    """
    has_info = connector_info is not None
    has_uuid = connector_uuid is not None

    if has_info and has_uuid:
        raise ValueError(
            "Only one of connector_info or connector_uuid must be provided."
        )
    if not (has_info or has_uuid):
        raise ValueError(
            "Either connector_info or connector_uuid must be provided."
        )

    if has_info:
        verify_permission(
            resource_type=ResourceType.SERVICE_CONNECTOR,
            action=Action.CREATE,
        )
    else:
        # Only reachable when connector_uuid is set (checked above).
        verify_permission(
            resource_type=ResourceType.SERVICE_CONNECTOR,
            action=Action.READ,
            resource_id=connector_uuid,
        )

    return get_resources_options_from_resource_model_for_full_stack(
        connector_details=connector_info or connector_uuid  # type: ignore[arg-type]
    )
get_service_connector(connector_id, expand_secrets=True, hydrate=True, _=Security(oauth2_authentication))
Returns the requested service connector.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
connector_id |
UUID |
ID of the service connector. |
required |
expand_secrets |
bool |
Whether to expand secrets or not. |
True |
hydrate |
bool |
Flag deciding whether to hydrate the output model(s) by including metadata fields in the response. |
True |
Returns:
Type | Description |
---|---|
ServiceConnectorResponse |
The requested service connector. |
Source code in zenml/zen_server/routers/service_connectors_endpoints.py
@router.get(
    "/{connector_id}",
    response_model=ServiceConnectorResponse,
    responses={
        401: error_response,
        404: error_response,
        422: error_response,
    },
)
@handle_exceptions
def get_service_connector(
    connector_id: UUID,
    expand_secrets: bool = True,
    hydrate: bool = True,
    _: AuthContext = Security(authorize),
) -> ServiceConnectorResponse:
    """Fetch a single service connector by ID.

    The caller must hold READ permission on the connector. Secret values
    are merged into the connector configuration only when three conditions
    hold: `expand_secrets` is set, the connector references a secret, and
    the caller holds READ_SECRET_VALUE permission on the connector.

    Args:
        connector_id: ID of the service connector.
        expand_secrets: Whether to expand secrets or not.
        hydrate: Flag deciding whether to hydrate the output model(s)
            by including metadata fields in the response.

    Returns:
        The requested service connector.
    """
    service_connector = zen_store().get_service_connector(
        connector_id, hydrate=hydrate
    )
    verify_permission_for_model(service_connector, action=Action.READ)

    if expand_secrets and service_connector.secret_id:
        if has_permissions_for_model(
            service_connector, action=Action.READ_SECRET_VALUE
        ):
            linked_secret = zen_store().get_secret(
                secret_id=service_connector.secret_id
            )
            # Merge the secret values into the connector configuration.
            service_connector.configuration.update(
                linked_secret.secret_values
            )

    return dehydrate_response_model(service_connector)
get_service_connector_client(connector_id, resource_type=None, resource_id=None, _=Security(oauth2_authentication))
Get a service connector client for a service connector and given resource.
This requires the service connector implementation to be installed on the ZenML server, otherwise a 501 Not Implemented error will be returned.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
connector_id |
UUID |
ID of the service connector. |
required |
resource_type |
Optional[str] |
Type of the resource to list. |
None |
resource_id |
Optional[str] |
ID of the resource to list. |
None |
Returns:
Type | Description |
---|---|
ServiceConnectorResponse |
A service connector client that can be used to access the given resource. |
Source code in zenml/zen_server/routers/service_connectors_endpoints.py
@router.get(
    "/{connector_id}" + SERVICE_CONNECTOR_CLIENT,
    response_model=ServiceConnectorResponse,
    responses={
        401: error_response,
        404: error_response,
        422: error_response,
    },
)
@handle_exceptions
def get_service_connector_client(
    connector_id: UUID,
    resource_type: Optional[str] = None,
    resource_id: Optional[str] = None,
    _: AuthContext = Security(authorize),
) -> ServiceConnectorResponse:
    """Obtain a client for a service connector, scoped to a given resource.

    This requires the service connector implementation to be installed
    on the ZenML server, otherwise a 501 Not Implemented error will be
    returned. The caller must hold both READ and CLIENT permissions on
    the connector.

    Args:
        connector_id: ID of the service connector.
        resource_type: Type of the resource to list.
        resource_id: ID of the resource to list.

    Returns:
        A service connector client that can be used to access the given
        resource.
    """
    service_connector = zen_store().get_service_connector(connector_id)
    for required_action in (Action.READ, Action.CLIENT):
        verify_permission_for_model(
            model=service_connector, action=required_action
        )

    return zen_store().get_service_connector_client(
        service_connector_id=connector_id,
        resource_type=resource_type,
        resource_id=resource_id,
    )
get_service_connector_type(connector_type, _=Security(oauth2_authentication))
Returns the requested service connector type.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
connector_type |
str |
the service connector type identifier. |
required |
Returns:
Type | Description |
---|---|
ServiceConnectorTypeModel |
The requested service connector type. |
Source code in zenml/zen_server/routers/service_connectors_endpoints.py
@types_router.get(
    "/{connector_type}",
    response_model=ServiceConnectorTypeModel,
    responses={
        401: error_response,
        404: error_response,
        422: error_response,
    },
)
@handle_exceptions
def get_service_connector_type(
    connector_type: str,
    _: AuthContext = Security(authorize),
) -> ServiceConnectorTypeModel:
    """Look up a single service connector type by its identifier.

    Args:
        connector_type: the service connector type identifier.

    Returns:
        The requested service connector type.
    """
    store = zen_store()
    return store.get_service_connector_type(connector_type)
list_service_connector_types(connector_type=None, resource_type=None, auth_method=None, _=Security(oauth2_authentication))
Get a list of service connector types.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
connector_type |
Optional[str] |
Filter by connector type. |
None |
resource_type |
Optional[str] |
Filter by resource type. |
None |
auth_method |
Optional[str] |
Filter by auth method. |
None |
Returns:
Type | Description |
---|---|
List[zenml.models.v2.misc.service_connector_type.ServiceConnectorTypeModel] |
List of service connector types. |
Source code in zenml/zen_server/routers/service_connectors_endpoints.py
@types_router.get(
    "",
    response_model=List[ServiceConnectorTypeModel],
    responses={
        401: error_response,
        404: error_response,
        422: error_response,
    },
)
@handle_exceptions
def list_service_connector_types(
    connector_type: Optional[str] = None,
    resource_type: Optional[str] = None,
    auth_method: Optional[str] = None,
    _: AuthContext = Security(authorize),
) -> List[ServiceConnectorTypeModel]:
    """List the available service connector types, optionally filtered.

    Args:
        connector_type: Filter by connector type.
        resource_type: Filter by resource type.
        auth_method: Filter by auth method.

    Returns:
        List of service connector types.
    """
    return zen_store().list_service_connector_types(
        connector_type=connector_type,
        resource_type=resource_type,
        auth_method=auth_method,
    )
list_service_connectors(connector_filter_model=Depends(init_cls_and_handle_errors), expand_secrets=True, hydrate=False, _=Security(oauth2_authentication))
Get a page of service connectors matching the supplied filter.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
connector_filter_model |
ServiceConnectorFilter |
Filter model used for pagination, sorting, and filtering. |
Depends(init_cls_and_handle_errors) |
expand_secrets |
bool |
Whether to expand secrets or not. |
True |
hydrate |
bool |
Flag deciding whether to hydrate the output model(s) by including metadata fields in the response. |
False |
Returns:
Type | Description |
---|---|
Page[ServiceConnectorResponse] |
Page of service connectors matching the supplied filter. |
Source code in zenml/zen_server/routers/service_connectors_endpoints.py
@router.get(
    "",
    response_model=Page[ServiceConnectorResponse],
    responses={
        401: error_response,
        404: error_response,
        422: error_response,
    },
)
@handle_exceptions
def list_service_connectors(
    connector_filter_model: ServiceConnectorFilter = Depends(
        make_dependable(ServiceConnectorFilter)
    ),
    expand_secrets: bool = True,
    hydrate: bool = False,
    _: AuthContext = Security(authorize),
) -> Page[ServiceConnectorResponse]:
    """Return a page of service connectors matching the supplied filter.

    When `expand_secrets` is set, each connector that references a secret
    has that secret's values merged into its configuration. This happens
    only if one of the following holds:

    - the caller may read secret values of every service connector;
    - the caller owns the connector;
    - the connector's ID is in the caller's allowed set.

    Connectors that fail this check are returned without their secret
    values. No error is raised for them.

    Args:
        connector_filter_model: Filter model used for pagination, sorting,
            and filtering.
        expand_secrets: Whether to expand secrets or not.
        hydrate: Flag deciding whether to hydrate the output model(s)
            by including metadata fields in the response.

    Returns:
        Page of service connectors matching the filter.
    """
    page = verify_permissions_and_list_entities(
        filter_model=connector_filter_model,
        resource_type=ResourceType.SERVICE_CONNECTOR,
        list_method=zen_store().list_service_connectors,
        hydrate=hydrate,
    )
    if not expand_secrets:
        return page

    # `None` means the caller may read secret values of every connector.
    readable_ids = get_allowed_resource_ids(
        resource_type=ResourceType.SERVICE_CONNECTOR,
        action=Action.READ_SECRET_VALUE,
    )
    for item in page.items:
        if not item.secret_id:
            continue
        may_read_secret = (
            readable_ids is None
            or is_owned_by_authenticated_user(item)
            or item.id in readable_ids
        )
        if not may_read_secret:
            continue
        linked_secret = zen_store().get_secret(secret_id=item.secret_id)
        # Merge the secret values into the connector configuration.
        item.configuration.update(linked_secret.secret_values)
    return page
update_service_connector(connector_id, connector_update, _=Security(oauth2_authentication))
Updates a service connector.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
connector_id |
UUID |
ID of the service connector. |
required |
connector_update |
ServiceConnectorUpdate |
Service connector to use to update. |
required |
Returns:
Type | Description |
---|---|
ServiceConnectorResponse |
Updated service connector. |
Source code in zenml/zen_server/routers/service_connectors_endpoints.py
@router.put(
    "/{connector_id}",
    response_model=ServiceConnectorResponse,
    responses={
        401: error_response,
        404: error_response,
        422: error_response,
    },
)
@handle_exceptions
def update_service_connector(
    connector_id: UUID,
    connector_update: ServiceConnectorUpdate,
    _: AuthContext = Security(authorize),
) -> ServiceConnectorResponse:
    """Apply an update to an existing service connector.

    Permission checks are delegated to
    `verify_permissions_and_update_entity`.

    Args:
        connector_id: ID of the service connector.
        connector_update: Service connector to use to update.

    Returns:
        Updated service connector.
    """
    fetch_connector = zen_store().get_service_connector
    apply_update = zen_store().update_service_connector
    return verify_permissions_and_update_entity(
        id=connector_id,
        update_model=connector_update,
        get_method=fetch_connector,
        update_method=apply_update,
    )
validate_and_verify_service_connector(connector_id, resource_type=None, resource_id=None, list_resources=True, _=Security(oauth2_authentication))
Verifies if a service connector instance has access to one or more resources.
This requires the service connector implementation to be installed on the ZenML server, otherwise a 501 Not Implemented error will be returned.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
connector_id |
UUID |
The ID of the service connector to verify. |
required |
resource_type |
Optional[str] |
The type of resource to verify access to. |
None |
resource_id |
Optional[str] |
The ID of the resource to verify access to. |
None |
list_resources |
bool |
If True, the list of all resources accessible through the service connector and matching the supplied resource type and ID is returned. |
True |
Returns:
Type | Description |
---|---|
ServiceConnectorResourcesModel |
The list of resources that the service connector has access to, scoped to the supplied resource type and ID, if provided. |
Source code in zenml/zen_server/routers/service_connectors_endpoints.py
@router.put(
    "/{connector_id}" + SERVICE_CONNECTOR_VERIFY,
    response_model=ServiceConnectorResourcesModel,
    responses={
        401: error_response,
        404: error_response,
        422: error_response,
    },
)
@handle_exceptions
def validate_and_verify_service_connector(
    connector_id: UUID,
    resource_type: Optional[str] = None,
    resource_id: Optional[str] = None,
    list_resources: bool = True,
    _: AuthContext = Security(authorize),
) -> ServiceConnectorResourcesModel:
    """Check whether a stored service connector can reach its resources.

    This requires the service connector implementation to be installed
    on the ZenML server, otherwise a 501 Not Implemented error will be
    returned. The caller must hold READ permission on the connector.

    Args:
        connector_id: The ID of the service connector to verify.
        resource_type: The type of resource to verify access to.
        resource_id: The ID of the resource to verify access to.
        list_resources: If True, the list of all resources accessible
            through the service connector and matching the supplied resource
            type and ID is returned.

    Returns:
        The list of resources that the service connector has access to, scoped
        to the supplied resource type and ID, if provided.
    """
    verify_permission_for_model(
        model=zen_store().get_service_connector(connector_id),
        action=Action.READ,
    )

    verification = zen_store().verify_service_connector(
        service_connector_id=connector_id,
        resource_type=resource_type,
        resource_id=resource_id,
        list_resources=list_resources,
    )
    return verification
validate_and_verify_service_connector_config(connector, list_resources=True, _=Security(oauth2_authentication))
Verifies if a service connector configuration has access to resources.
This requires the service connector implementation to be installed on the ZenML server, otherwise a 501 Not Implemented error will be returned.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
connector |
ServiceConnectorRequest |
The service connector configuration to verify. |
required |
list_resources |
bool |
If True, the list of all resources accessible through the service connector is returned. |
True |
Returns:
Type | Description |
---|---|
ServiceConnectorResourcesModel |
The list of resources that the service connector configuration has access to. |
Source code in zenml/zen_server/routers/service_connectors_endpoints.py
@router.post(
    SERVICE_CONNECTOR_VERIFY,
    response_model=ServiceConnectorResourcesModel,
    responses={
        401: error_response,
        409: error_response,
        422: error_response,
    },
)
@handle_exceptions
def validate_and_verify_service_connector_config(
    connector: ServiceConnectorRequest,
    list_resources: bool = True,
    _: AuthContext = Security(authorize),
) -> ServiceConnectorResourcesModel:
    """Check whether an unsaved connector configuration can reach resources.

    This requires the service connector implementation to be installed
    on the ZenML server, otherwise a 501 Not Implemented error will be
    returned. The caller must hold CREATE permission on service connectors.

    Args:
        connector: The service connector configuration to verify.
        list_resources: If True, the list of all resources accessible
            through the service connector is returned.

    Returns:
        The list of resources that the service connector configuration has
        access to.
    """
    verify_permission(
        resource_type=ResourceType.SERVICE_CONNECTOR,
        action=Action.CREATE,
    )

    verification = zen_store().verify_service_connector_config(
        service_connector=connector,
        list_resources=list_resources,
    )
    return verification
service_endpoints
Endpoint definitions for services.
create_service(service, _=Security(oauth2_authentication))
Creates a new service.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
service |
ServiceRequest |
The model containing the attributes of the new service. |
required |
Returns:
Type | Description |
---|---|
ServiceResponse |
The created service object. |
Source code in zenml/zen_server/routers/service_endpoints.py
@router.post(
    "",
    response_model=ServiceResponse,
    responses={
        401: error_response,
        422: error_response,
    },
)
@handle_exceptions
def create_service(
    service: ServiceRequest,
    _: AuthContext = Security(authorize),
) -> ServiceResponse:
    """Register a new service.

    Permission checks are delegated to
    `verify_permissions_and_create_entity`.

    Args:
        service: The model containing the attributes of the new service.

    Returns:
        The created service object.
    """
    persist_service = zen_store().create_service
    return verify_permissions_and_create_entity(
        request_model=service,
        create_method=persist_service,
        resource_type=ResourceType.SERVICE,
    )
delete_service(service_id, _=Security(oauth2_authentication))
Deletes a specific service.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
service_id |
UUID |
The ID of the service to delete. |
required |
Source code in zenml/zen_server/routers/service_endpoints.py
@router.delete(
    "/{service_id}",
    responses={
        401: error_response,
        404: error_response,
        422: error_response,
    },
)
@handle_exceptions
def delete_service(
    service_id: UUID,
    _: AuthContext = Security(authorize),
) -> None:
    """Remove a single service.

    Permission checks are delegated to
    `verify_permissions_and_delete_entity`.

    Args:
        service_id: The ID of the service to delete.
    """
    fetch_service = zen_store().get_service
    remove_service = zen_store().delete_service
    verify_permissions_and_delete_entity(
        id=service_id,
        get_method=fetch_service,
        delete_method=remove_service,
    )
get_service(service_id, hydrate=True, _=Security(oauth2_authentication))
Gets a specific service using its unique ID.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
service_id |
UUID |
The ID of the service to get. |
required |
hydrate |
bool |
Flag deciding whether to hydrate the output model(s) by including metadata fields in the response. |
True |
Returns:
Type | Description |
---|---|
ServiceResponse |
A specific service object. |
Source code in zenml/zen_server/routers/service_endpoints.py
@router.get(
    "/{service_id}",
    response_model=ServiceResponse,
    responses={
        401: error_response,
        404: error_response,
        422: error_response,
    },
)
@handle_exceptions
def get_service(
    service_id: UUID,
    hydrate: bool = True,
    _: AuthContext = Security(authorize),
) -> ServiceResponse:
    """Fetch a single service by its unique ID.

    Permission checks are delegated to `verify_permissions_and_get_entity`.

    Args:
        service_id: The ID of the service to get.
        hydrate: Flag deciding whether to hydrate the output model(s)
            by including metadata fields in the response.

    Returns:
        A specific service object.
    """
    fetch_service = zen_store().get_service
    return verify_permissions_and_get_entity(
        id=service_id,
        get_method=fetch_service,
        hydrate=hydrate,
    )
list_services(filter_model=Depends(init_cls_and_handle_errors), hydrate=False, _=Security(oauth2_authentication))
Gets a page of service objects.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
filter_model |
ServiceFilter |
Filter model used for pagination, sorting, filtering. |
Depends(init_cls_and_handle_errors) |
hydrate |
bool |
Flag deciding whether to hydrate the output model(s) by including metadata fields in the response. |
False |
Returns:
Type | Description |
---|---|
Page[ServiceResponse] |
Page of service objects. |
Source code in zenml/zen_server/routers/service_endpoints.py
@router.get(
    "",
    response_model=Page[ServiceResponse],
    responses={
        401: error_response,
        404: error_response,
        422: error_response,
    },
)
@handle_exceptions
def list_services(
    filter_model: ServiceFilter = Depends(make_dependable(ServiceFilter)),
    hydrate: bool = False,
    _: AuthContext = Security(authorize),
) -> Page[ServiceResponse]:
    """Return a page of services matching the supplied filter.

    Permission checks are delegated to
    `verify_permissions_and_list_entities`.

    Args:
        filter_model: Filter model used for pagination, sorting,
            and filtering.
        hydrate: Flag deciding whether to hydrate the output model(s)
            by including metadata fields in the response.

    Returns:
        Page of service objects.
    """
    list_all_services = zen_store().list_services
    return verify_permissions_and_list_entities(
        filter_model=filter_model,
        resource_type=ResourceType.SERVICE,
        list_method=list_all_services,
        hydrate=hydrate,
    )
update_service(service_id, update, _=Security(oauth2_authentication))
Updates a service.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
service_id |
UUID |
The ID of the service to update. |
required |
update |
ServiceUpdate |
The model containing the attributes to update. |
required |
Returns:
Type | Description |
---|---|
ServiceResponse |
The updated service object. |
Source code in zenml/zen_server/routers/service_endpoints.py
@router.put(
    "/{service_id}",
    response_model=ServiceResponse,
    responses={
        401: error_response,
        404: error_response,
        422: error_response,
    },
)
@handle_exceptions
def update_service(
    service_id: UUID,
    update: ServiceUpdate,
    _: AuthContext = Security(authorize),
) -> ServiceResponse:
    """Apply an update to an existing service.

    Permission checks are delegated to
    `verify_permissions_and_update_entity`.

    Args:
        service_id: The ID of the service to update.
        update: The model containing the attributes to update.

    Returns:
        The updated service object.
    """
    fetch_service = zen_store().get_service
    apply_update = zen_store().update_service
    return verify_permissions_and_update_entity(
        id=service_id,
        update_model=update,
        get_method=fetch_service,
        update_method=apply_update,
    )
stack_components_endpoints
Endpoint definitions for stack components.
deregister_stack_component(component_id, _=Security(oauth2_authentication))
Deletes a stack component.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
component_id |
UUID |
ID of the stack component. |
required |
Source code in zenml/zen_server/routers/stack_components_endpoints.py
@router.delete(
    "/{component_id}",
    responses={
        401: error_response,
        404: error_response,
        422: error_response,
    },
)
@handle_exceptions
def deregister_stack_component(
    component_id: UUID,
    _: AuthContext = Security(authorize),
) -> None:
    """Remove a stack component.

    Permission checks are delegated to
    `verify_permissions_and_delete_entity`.

    Args:
        component_id: ID of the stack component.
    """
    fetch_component = zen_store().get_stack_component
    remove_component = zen_store().delete_stack_component
    verify_permissions_and_delete_entity(
        id=component_id,
        get_method=fetch_component,
        delete_method=remove_component,
    )
get_stack_component(component_id, hydrate=True, _=Security(oauth2_authentication))
Returns the requested stack component.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
component_id |
UUID |
ID of the stack component. |
required |
hydrate |
bool |
Flag deciding whether to hydrate the output model(s) by including metadata fields in the response. |
True |
Returns:
Type | Description |
---|---|
ComponentResponse |
The requested stack component. |
Source code in zenml/zen_server/routers/stack_components_endpoints.py
@router.get(
    "/{component_id}",
    response_model=ComponentResponse,
    responses={
        401: error_response,
        404: error_response,
        422: error_response,
    },
)
@handle_exceptions
def get_stack_component(
    component_id: UUID,
    hydrate: bool = True,
    _: AuthContext = Security(authorize),
) -> ComponentResponse:
    """Fetch a single stack component by ID.

    Permission checks are delegated to `verify_permissions_and_get_entity`.

    Args:
        component_id: ID of the stack component.
        hydrate: Flag deciding whether to hydrate the output model(s)
            by including metadata fields in the response.

    Returns:
        The requested stack component.
    """
    fetch_component = zen_store().get_stack_component
    return verify_permissions_and_get_entity(
        id=component_id,
        get_method=fetch_component,
        hydrate=hydrate,
    )
get_stack_component_types(_=Security(oauth2_authentication))
Get a list of all stack component types.
Returns:
Type | Description |
---|---|
List[str] |
List of stack component types. |
Source code in zenml/zen_server/routers/stack_components_endpoints.py
@types_router.get(
    "",
    response_model=List[str],
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def get_stack_component_types(
    _: AuthContext = Security(authorize),
) -> List[str]:
    """Get a list of all stack component types.

    Returns:
        List of stack component types, as returned by
        `StackComponentType.values()`.
    """
    return StackComponentType.values()
list_stack_components(component_filter_model=Depends(init_cls_and_handle_errors), hydrate=False, _=Security(oauth2_authentication))
Get a list of all stack components for a specific type.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
component_filter_model |
ComponentFilter |
Filter model used for pagination, sorting, filtering. |
Depends(init_cls_and_handle_errors) |
hydrate |
bool |
Flag deciding whether to hydrate the output model(s) by including metadata fields in the response. |
False |
Returns:
Type | Description |
---|---|
Page[ComponentResponse] |
List of stack components for a specific type. |
Source code in zenml/zen_server/routers/stack_components_endpoints.py
@router.get(
    "",
    response_model=Page[ComponentResponse],
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def list_stack_components(
    component_filter_model: ComponentFilter = Depends(
        make_dependable(ComponentFilter)
    ),
    hydrate: bool = False,
    _: AuthContext = Security(authorize),
) -> Page[ComponentResponse]:
    """List the stack components that match the given filter.

    Permission handling and listing are delegated to
    `verify_permissions_and_list_entities` for the STACK_COMPONENT
    resource type.

    Args:
        component_filter_model: Filter model used for pagination, sorting,
            filtering.
        hydrate: Flag deciding whether to hydrate the output model(s)
            by including metadata fields in the response.

    Returns:
        A page of stack components matching the filter.
    """
    list_components = zen_store().list_stack_components
    return verify_permissions_and_list_entities(
        filter_model=component_filter_model,
        resource_type=ResourceType.STACK_COMPONENT,
        list_method=list_components,
        hydrate=hydrate,
    )
update_stack_component(component_id, component_update, _=Security(oauth2_authentication))
Updates a stack component.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
component_id |
UUID |
ID of the stack component. |
required |
component_update |
ComponentUpdate |
Stack component to use to update. |
required |
Returns:
Type | Description |
---|---|
ComponentResponse |
Updated stack component. |
Source code in zenml/zen_server/routers/stack_components_endpoints.py
@router.put(
    "/{component_id}",
    response_model=ComponentResponse,
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def update_stack_component(
    component_id: UUID,
    component_update: ComponentUpdate,
    _: AuthContext = Security(authorize),
) -> ComponentResponse:
    """Apply an update to an existing stack component.

    When the update carries a new configuration, it is validated against the
    flavor and type of the existing component. When it references a service
    connector, the caller must have READ permission on that connector.

    Args:
        component_id: ID of the stack component.
        component_update: Stack component to use to update.

    Returns:
        Updated stack component.
    """
    store = zen_store()

    if component_update.configuration:
        from zenml.stack.utils import validate_stack_component_config

        current = store.get_stack_component(component_id)
        validate_stack_component_config(
            configuration_dict=component_update.configuration,
            flavor=current.flavor_name,
            component_type=current.type,
            zen_store=store,
            # We allow custom flavors to fail import on the server side.
            validate_custom_flavors=False,
        )

    if component_update.connector:
        connector = store.get_service_connector(component_update.connector)
        verify_permission_for_model(connector, action=Action.READ)

    return verify_permissions_and_update_entity(
        id=component_id,
        update_model=component_update,
        get_method=store.get_stack_component,
        update_method=store.update_stack_component,
    )
stack_deployment_endpoints
Endpoint definitions for stack deployments.
get_deployed_stack(provider, stack_name, location=None, date_start=None, terraform=False, _=Security(oauth2_authentication))
Return a matching ZenML stack that was deployed and registered.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
provider |
StackDeploymentProvider |
The stack deployment provider. |
required |
stack_name |
str |
The name of the stack. |
required |
location |
Optional[str] |
The location where the stack should be deployed. |
None |
date_start |
Optional[datetime.datetime] |
The date when the deployment started. |
None |
terraform |
bool |
Whether the stack was deployed using Terraform. |
False |
Returns:
Type | Description |
---|---|
Optional[zenml.models.v2.misc.stack_deployment.DeployedStack] |
The ZenML stack that was deployed and registered or None if the stack was not found. |
Source code in zenml/zen_server/routers/stack_deployment_endpoints.py
@router.get(
    STACK,
)
@handle_exceptions
def get_deployed_stack(
    provider: StackDeploymentProvider,
    stack_name: str,
    location: Optional[str] = None,
    date_start: Optional[datetime.datetime] = None,
    terraform: bool = False,
    _: AuthContext = Security(authorize),
) -> Optional[DeployedStack]:
    """Look up a ZenML stack that was deployed and registered.

    Args:
        provider: The stack deployment provider.
        stack_name: The name of the stack.
        location: The location where the stack should be deployed.
        date_start: The date when the deployment started.
        terraform: Whether the stack was deployed using Terraform.

    Returns:
        The ZenML stack that was deployed and registered or None if the stack
        was not found.
    """
    deployment_cls = get_stack_deployment_class(provider)
    # This is only a lookup, so the server URL and API token that the
    # deployment class accepts are passed as empty strings.
    deployment = deployment_cls(
        terraform=terraform,
        stack_name=stack_name,
        location=location,
        zenml_server_url="",
        zenml_server_api_token="",
    )
    return deployment.get_stack(date_start=date_start)
get_stack_deployment_config(request, provider, stack_name, location=None, terraform=False, auth_context=Security(oauth2_authentication))
Return the URL to deploy the ZenML stack to the specified cloud provider.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
request |
Request |
The FastAPI request object. |
required |
provider |
StackDeploymentProvider |
The stack deployment provider. |
required |
stack_name |
str |
The name of the stack. |
required |
location |
Optional[str] |
The location where the stack should be deployed. |
None |
terraform |
bool |
Whether the stack should be deployed using Terraform. |
False |
auth_context |
AuthContext |
The authentication context. |
Security(oauth2_authentication) |
Returns:
Type | Description |
---|---|
StackDeploymentConfig |
The cloud provider console URL where the stack will be deployed and the configuration for the stack deployment. |
Source code in zenml/zen_server/routers/stack_deployment_endpoints.py
@router.get(
    CONFIG,
)
@handle_exceptions
def get_stack_deployment_config(
    request: Request,
    provider: StackDeploymentProvider,
    stack_name: str,
    location: Optional[str] = None,
    terraform: bool = False,
    auth_context: AuthContext = Security(authorize),
) -> StackDeploymentConfig:
    """Build the deployment config for a ZenML stack on a cloud provider.

    The caller needs CREATE permission on service connectors, stack
    components and stacks. A new API token is generated from the caller's
    access token and handed to the deployment together with the server URL.

    Args:
        request: The FastAPI request object.
        provider: The stack deployment provider.
        stack_name: The name of the stack.
        location: The location where the stack should be deployed.
        terraform: Whether the stack should be deployed using Terraform.
        auth_context: The authentication context.

    Returns:
        The cloud provider console URL where the stack will be deployed and
        the configuration for the stack deployment.
    """
    for required_resource in (
        ResourceType.SERVICE_CONNECTOR,
        ResourceType.STACK_COMPONENT,
        ResourceType.STACK,
    ):
        verify_permission(resource_type=required_resource, action=Action.CREATE)

    deployment_cls = get_stack_deployment_class(provider)

    # Base URL of this server, taken from the incoming request with path and
    # query stripped and the scheme forced to HTTPS.
    server_url = (
        request.url.replace(path="")
        .replace(query="")
        .replace(scheme="https")
    )

    access_token = auth_context.access_token
    assert access_token is not None

    # The new API token expires STACK_DEPLOYMENT_API_TOKEN_EXPIRATION minutes
    # from now. NOTE(review): utcnow() returns a naive datetime and is
    # deprecated since Python 3.12; kept as-is to match the token encoder.
    expiry = datetime.datetime.utcnow() + datetime.timedelta(
        minutes=STACK_DEPLOYMENT_API_TOKEN_EXPIRATION
    )
    deployment_api_token = access_token.encode(expires=expiry)

    deployment = deployment_cls(
        terraform=terraform,
        stack_name=stack_name,
        location=location,
        zenml_server_url=str(server_url),
        zenml_server_api_token=deployment_api_token,
    )
    return deployment.get_deployment_config()
get_stack_deployment_info(provider, _=Security(oauth2_authentication))
Get information about a stack deployment provider.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
provider |
StackDeploymentProvider |
The stack deployment provider. |
required |
Returns:
Type | Description |
---|---|
StackDeploymentInfo |
Information about the stack deployment provider. |
Source code in zenml/zen_server/routers/stack_deployment_endpoints.py
@router.get(
    INFO,
)
@handle_exceptions
def get_stack_deployment_info(
    provider: StackDeploymentProvider,
    _: AuthContext = Security(authorize),
) -> StackDeploymentInfo:
    """Describe the given stack deployment provider.

    Args:
        provider: The stack deployment provider.

    Returns:
        Information about the stack deployment provider.
    """
    return get_stack_deployment_class(provider).get_deployment_info()
stacks_endpoints
Endpoint definitions for stacks.
delete_stack(stack_id, _=Security(oauth2_authentication))
Deletes a stack.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
stack_id |
UUID |
ID of the stack. |
required |
Source code in zenml/zen_server/routers/stacks_endpoints.py
@router.delete(
    "/{stack_id}",
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def delete_stack(
    stack_id: UUID,
    _: AuthContext = Security(authorize),
) -> None:
    """Deletes a stack.

    Fetching the stack, verifying the caller's permissions on it and deleting
    it are delegated to `verify_permissions_and_delete_entity`.

    Args:
        stack_id: ID of the stack.
    """
    verify_permissions_and_delete_entity(
        id=stack_id,
        get_method=zen_store().get_stack,
        delete_method=zen_store().delete_stack,
    )
get_stack(stack_id, hydrate=True, _=Security(oauth2_authentication))
Returns the requested stack.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
stack_id |
UUID |
ID of the stack. |
required |
hydrate |
bool |
Flag deciding whether to hydrate the output model(s) by including metadata fields in the response. |
True |
Returns:
Type | Description |
---|---|
StackResponse |
The requested stack. |
Source code in zenml/zen_server/routers/stacks_endpoints.py
@router.get(
    "/{stack_id}",
    response_model=StackResponse,
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def get_stack(
    stack_id: UUID,
    hydrate: bool = True,
    _: AuthContext = Security(authorize),
) -> StackResponse:
    """Fetch a single stack by its ID.

    Args:
        stack_id: ID of the stack.
        hydrate: Flag deciding whether to hydrate the output model(s)
            by including metadata fields in the response.

    Returns:
        The requested stack.
    """
    fetch_stack = zen_store().get_stack
    return verify_permissions_and_get_entity(
        id=stack_id,
        get_method=fetch_stack,
        hydrate=hydrate,
    )
list_stacks(stack_filter_model=Depends(init_cls_and_handle_errors), hydrate=False, _=Security(oauth2_authentication))
Returns all stacks.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
stack_filter_model |
StackFilter |
Filter model used for pagination, sorting, filtering. |
Depends(init_cls_and_handle_errors) |
hydrate |
bool |
Flag deciding whether to hydrate the output model(s) by including metadata fields in the response. |
False |
Returns:
Type | Description |
---|---|
Page[StackResponse] |
All stacks. |
Source code in zenml/zen_server/routers/stacks_endpoints.py
@router.get(
    "",
    response_model=Page[StackResponse],
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def list_stacks(
    stack_filter_model: StackFilter = Depends(make_dependable(StackFilter)),
    hydrate: bool = False,
    _: AuthContext = Security(authorize),
) -> Page[StackResponse]:
    """List the stacks that match the given filter.

    Args:
        stack_filter_model: Filter model used for pagination, sorting,
            filtering.
        hydrate: Flag deciding whether to hydrate the output model(s)
            by including metadata fields in the response.

    Returns:
        All stacks matching the filter, as a page.
    """
    list_method = zen_store().list_stacks
    return verify_permissions_and_list_entities(
        filter_model=stack_filter_model,
        resource_type=ResourceType.STACK,
        list_method=list_method,
        hydrate=hydrate,
    )
update_stack(stack_id, stack_update, _=Security(oauth2_authentication))
Updates a stack.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
stack_id |
UUID |
ID of the stack. |
required |
stack_update |
StackUpdate |
Stack to use for the update. |
required |
Returns:
Type | Description |
---|---|
StackResponse |
The updated stack. |
Source code in zenml/zen_server/routers/stacks_endpoints.py
@router.put(
    "/{stack_id}",
    response_model=StackResponse,
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def update_stack(
    stack_id: UUID,
    stack_update: StackUpdate,
    _: AuthContext = Security(authorize),
) -> StackResponse:
    """Updates a stack.

    If the update references stack components, the caller must have READ
    permission on every one of them before the update is applied.

    Args:
        stack_id: ID of the stack.
        stack_update: Stack to use for the update.

    Returns:
        The updated stack.
    """
    if stack_update.components:
        store = zen_store()
        # `components` maps to collections of component IDs (presumably keyed
        # by component type); flatten them and fetch each component so the
        # READ permission can be checked in one batch.
        updated_components = [
            store.get_stack_component(component_id)
            for component_ids in stack_update.components.values()
            for component_id in component_ids
        ]
        batch_verify_permissions_for_models(
            updated_components, action=Action.READ
        )

    return verify_permissions_and_update_entity(
        id=stack_id,
        update_model=stack_update,
        get_method=zen_store().get_stack,
        update_method=zen_store().update_stack,
    )
steps_endpoints
Endpoint definitions for steps (and artifacts) of pipeline runs.
create_run_step(step, _=Security(oauth2_authentication))
Create a run step.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
step |
StepRunRequest |
The run step to create. |
required |
Returns:
Type | Description |
---|---|
StepRunResponse |
The created run step. |
Source code in zenml/zen_server/routers/steps_endpoints.py
@router.post(
    "",
    response_model=StepRunResponse,
    responses={401: error_response, 409: error_response, 422: error_response},
)
@handle_exceptions
def create_run_step(
    step: StepRunRequest,
    _: AuthContext = Security(authorize),
) -> StepRunResponse:
    """Create a new step inside an existing pipeline run.

    The caller needs UPDATE permission on the parent pipeline run.

    Args:
        step: The run step to create.

    Returns:
        The created run step.
    """
    store = zen_store()
    parent_run = store.get_run(step.pipeline_run_id)
    verify_permission_for_model(parent_run, action=Action.UPDATE)
    created_step = store.create_run_step(step_run=step)
    return dehydrate_response_model(created_step)
get_step(step_id, hydrate=True, _=Security(oauth2_authentication))
Get one specific step.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
step_id |
UUID |
ID of the step to get. |
required |
hydrate |
bool |
Flag deciding whether to hydrate the output model(s) by including metadata fields in the response. |
True |
Returns:
Type | Description |
---|---|
StepRunResponse |
The step. |
Source code in zenml/zen_server/routers/steps_endpoints.py
@router.get(
    "/{step_id}",
    response_model=StepRunResponse,
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def get_step(
    step_id: UUID,
    hydrate: bool = True,
    _: AuthContext = Security(authorize),
) -> StepRunResponse:
    """Fetch a single step run by its ID.

    The caller needs READ permission on the pipeline run that owns the step.

    Args:
        step_id: ID of the step to get.
        hydrate: Flag deciding whether to hydrate the output model(s)
            by including metadata fields in the response.

    Returns:
        The step.
    """
    store = zen_store()
    # Always load the step hydrated: its pipeline_run_id is needed for the
    # permission check. Metadata is dropped again below if the caller asked
    # for an unhydrated response.
    step_run = store.get_run_step(step_id, hydrate=True)
    owning_run = store.get_run(step_run.pipeline_run_id)
    verify_permission_for_model(owning_run, action=Action.READ)

    if hydrate is False:
        step_run.metadata = None
    return dehydrate_response_model(step_run)
get_step_configuration(step_id, _=Security(oauth2_authentication))
Get the configuration of a specific step.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
step_id |
UUID |
ID of the step to get. |
required |
Returns:
Type | Description |
---|---|
Dict[str, Any] |
The step configuration. |
Source code in zenml/zen_server/routers/steps_endpoints.py
@router.get(
    "/{step_id}" + STEP_CONFIGURATION,
    response_model=Dict[str, Any],
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def get_step_configuration(
    step_id: UUID,
    _: AuthContext = Security(authorize),
) -> Dict[str, Any]:
    """Return the configuration of a step run as a plain dictionary.

    The caller needs READ permission on the pipeline run that owns the step.

    Args:
        step_id: ID of the step to get.

    Returns:
        The step configuration.
    """
    store = zen_store()
    step_run = store.get_run_step(step_id, hydrate=True)
    verify_permission_for_model(
        store.get_run(step_run.pipeline_run_id), action=Action.READ
    )
    step_config = step_run.config
    return step_config.model_dump()
get_step_logs(step_id, offset=0, length=16777216, _=Security(oauth2_authentication))
Get the logs of a specific step.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
step_id |
UUID |
ID of the step for which to get the logs. |
required |
offset |
int |
The offset from which to start reading. |
0 |
length |
int |
The amount of bytes that should be read. |
16777216 |
Returns:
Type | Description |
---|---|
str |
The logs of the step. |
Exceptions:
Type | Description |
---|---|
HTTPException |
If no logs are available for this step. |
Source code in zenml/zen_server/routers/steps_endpoints.py
@router.get(
    "/{step_id}" + LOGS,
    response_model=str,
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def get_step_logs(
    step_id: UUID,
    offset: int = 0,
    length: int = 1024 * 1024 * 16,  # Default to 16MiB of data
    _: AuthContext = Security(authorize),
) -> str:
    """Read a window of the logs written by a step run.

    The caller needs READ permission on the pipeline run that owns the step.

    Args:
        step_id: ID of the step for which to get the logs.
        offset: The offset from which to start reading.
        length: The amount of bytes that should be read.

    Returns:
        The logs of the step.

    Raises:
        HTTPException: If no logs are available for this step.
    """
    store = zen_store()
    step_run = store.get_run_step(step_id, hydrate=True)
    verify_permission_for_model(
        store.get_run(step_run.pipeline_run_id), action=Action.READ
    )

    step_logs = step_run.logs
    if step_logs is None:
        raise HTTPException(
            status_code=404, detail="No logs available for this step"
        )

    return fetch_logs(
        zen_store=store,
        artifact_store_id=step_logs.artifact_store_id,
        logs_uri=step_logs.uri,
        offset=offset,
        length=length,
    )
get_step_status(step_id, _=Security(oauth2_authentication))
Get the status of a specific step.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
step_id |
UUID |
ID of the step for which to get the status. |
required |
Returns:
Type | Description |
---|---|
ExecutionStatus |
The status of the step. |
Source code in zenml/zen_server/routers/steps_endpoints.py
@router.get(
    "/{step_id}" + STATUS,
    response_model=ExecutionStatus,
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def get_step_status(
    step_id: UUID,
    _: AuthContext = Security(authorize),
) -> ExecutionStatus:
    """Return the execution status of a step run.

    The caller needs READ permission on the pipeline run that owns the step.

    Args:
        step_id: ID of the step for which to get the status.

    Returns:
        The status of the step.
    """
    store = zen_store()
    step_run = store.get_run_step(step_id, hydrate=True)
    owning_run = store.get_run(step_run.pipeline_run_id)
    verify_permission_for_model(owning_run, action=Action.READ)
    status = step_run.status
    return status
list_run_steps(step_run_filter_model=Depends(init_cls_and_handle_errors), hydrate=False, auth_context=Security(oauth2_authentication))
Get run steps according to query filters.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
step_run_filter_model |
StepRunFilter |
Filter model used for pagination, sorting, filtering. |
Depends(init_cls_and_handle_errors) |
hydrate |
bool |
Flag deciding whether to hydrate the output model(s) by including metadata fields in the response. |
False |
auth_context |
AuthContext |
Authentication context. |
Security(oauth2_authentication) |
Returns:
Type | Description |
---|---|
Page[StepRunResponse] |
The run steps according to query filters. |
Source code in zenml/zen_server/routers/steps_endpoints.py
@router.get(
    "",
    response_model=Page[StepRunResponse],
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def list_run_steps(
    step_run_filter_model: StepRunFilter = Depends(
        make_dependable(StepRunFilter)
    ),
    hydrate: bool = False,
    auth_context: AuthContext = Security(authorize),
) -> Page[StepRunResponse]:
    """List step runs matching the filter, limited to accessible pipeline runs.

    The filter is restricted (via `configure_rbac`) to the pipeline run IDs
    returned by `get_allowed_resource_ids` for the PIPELINE_RUN resource type.

    Args:
        step_run_filter_model: Filter model used for pagination, sorting,
            filtering.
        hydrate: Flag deciding whether to hydrate the output model(s)
            by including metadata fields in the response.
        auth_context: Authentication context.

    Returns:
        The run steps according to query filters.
    """
    allowed_run_ids = get_allowed_resource_ids(
        resource_type=ResourceType.PIPELINE_RUN
    )
    step_run_filter_model.configure_rbac(
        authenticated_user_id=auth_context.user.id,
        pipeline_run_id=allowed_run_ids,
    )
    return dehydrate_page(
        zen_store().list_run_steps(
            step_run_filter_model=step_run_filter_model, hydrate=hydrate
        )
    )
update_step(step_id, step_model, _=Security(oauth2_authentication))
Updates a step.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
step_id |
UUID |
ID of the step. |
required |
step_model |
StepRunUpdate |
Step model to use for the update. |
required |
Returns:
Type | Description |
---|---|
StepRunResponse |
The updated step model. |
Source code in zenml/zen_server/routers/steps_endpoints.py
@router.put(
    "/{step_id}",
    response_model=StepRunResponse,
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def update_step(
    step_id: UUID,
    step_model: StepRunUpdate,
    _: AuthContext = Security(authorize),
) -> StepRunResponse:
    """Apply an update to an existing step run.

    The caller needs UPDATE permission on the pipeline run that owns the step.

    Args:
        step_id: ID of the step.
        step_model: Step model to use for the update.

    Returns:
        The updated step model.
    """
    store = zen_store()
    existing_step = store.get_run_step(step_id, hydrate=True)
    owning_run = store.get_run(existing_step.pipeline_run_id)
    verify_permission_for_model(owning_run, action=Action.UPDATE)
    result = store.update_run_step(
        step_run_id=step_id,
        step_run_update=step_model,
    )
    return dehydrate_response_model(result)
tags_endpoints
Endpoint definitions for tags.
create_tag(tag, _=Security(oauth2_authentication))
Create a new tag.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
tag |
TagRequest |
The tag to create. |
required |
Returns:
Type | Description |
---|---|
TagResponse |
The created tag. |
Source code in zenml/zen_server/routers/tags_endpoints.py
@router.post(
    "",
    response_model=TagResponse,
    responses={401: error_response, 409: error_response, 422: error_response},
)
@handle_exceptions
def create_tag(
    tag: TagRequest,
    _: AuthContext = Security(authorize),
) -> TagResponse:
    """Create a new tag after verifying the caller may create TAG resources.

    Args:
        tag: The tag to create.

    Returns:
        The created tag.
    """
    create_method = zen_store().create_tag
    return verify_permissions_and_create_entity(
        request_model=tag,
        resource_type=ResourceType.TAG,
        create_method=create_method,
    )
delete_tag(tag_name_or_id, _=Security(oauth2_authentication))
Delete a tag by name or ID.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
tag_name_or_id |
Union[str, uuid.UUID] |
The name or ID of the tag to delete. |
required |
Source code in zenml/zen_server/routers/tags_endpoints.py
@router.delete(
    "/{tag_name_or_id}",
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def delete_tag(
    tag_name_or_id: Union[str, UUID],
    _: AuthContext = Security(authorize),
) -> None:
    """Delete the tag identified by name or ID.

    Args:
        tag_name_or_id: The name or ID of the tag to delete.
    """
    store = zen_store()
    verify_permissions_and_delete_entity(
        id=tag_name_or_id,
        get_method=store.get_tag,
        delete_method=store.delete_tag,
    )
get_tag(tag_name_or_id, hydrate=True, _=Security(oauth2_authentication))
Get a tag by name or ID.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
tag_name_or_id |
Union[str, uuid.UUID] |
The name or ID of the tag to get. |
required |
hydrate |
bool |
Flag deciding whether to hydrate the output model(s) by including metadata fields in the response. |
True |
Returns:
Type | Description |
---|---|
TagResponse |
The tag with the given name or ID. |
Source code in zenml/zen_server/routers/tags_endpoints.py
@router.get(
    "/{tag_name_or_id}",
    response_model=TagResponse,
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def get_tag(
    tag_name_or_id: Union[str, UUID],
    hydrate: bool = True,
    _: AuthContext = Security(authorize),
) -> TagResponse:
    """Fetch a tag identified by name or ID.

    Args:
        tag_name_or_id: The name or ID of the tag to get.
        hydrate: Flag deciding whether to hydrate the output model(s)
            by including metadata fields in the response.

    Returns:
        The tag with the given name or ID.
    """
    fetch_tag = zen_store().get_tag
    return verify_permissions_and_get_entity(
        id=tag_name_or_id,
        get_method=fetch_tag,
        hydrate=hydrate,
    )
list_tags(tag_filter_model=Depends(init_cls_and_handle_errors), hydrate=False, _=Security(oauth2_authentication))
Get tags according to query filters.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
tag_filter_model |
TagFilter |
Filter model used for pagination, sorting, filtering. |
Depends(init_cls_and_handle_errors) |
hydrate |
bool |
Flag deciding whether to hydrate the output model(s) by including metadata fields in the response. |
False |
Returns:
Type | Description |
---|---|
Page[TagResponse] |
The tags according to query filters. |
Source code in zenml/zen_server/routers/tags_endpoints.py
@router.get(
    "",
    response_model=Page[TagResponse],
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def list_tags(
    tag_filter_model: TagFilter = Depends(make_dependable(TagFilter)),
    hydrate: bool = False,
    _: AuthContext = Security(authorize),
) -> Page[TagResponse]:
    """List the tags that match the given filter.

    Args:
        tag_filter_model: Filter model used for pagination, sorting,
            filtering.
        hydrate: Flag deciding whether to hydrate the output model(s)
            by including metadata fields in the response.

    Returns:
        The tags according to query filters.
    """
    list_method = zen_store().list_tags
    return verify_permissions_and_list_entities(
        filter_model=tag_filter_model,
        resource_type=ResourceType.TAG,
        list_method=list_method,
        hydrate=hydrate,
    )
update_tag(tag_id, tag_update_model, _=Security(oauth2_authentication))
Updates a tag.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
tag_id |
UUID |
ID of the tag. |
required |
tag_update_model |
TagUpdate |
Tag to use for the update. |
required |
Returns:
Type | Description |
---|---|
TagResponse |
The updated tag. |
Source code in zenml/zen_server/routers/tags_endpoints.py
@router.put(
    "/{tag_id}",
    response_model=TagResponse,
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def update_tag(
    tag_id: UUID,
    tag_update_model: TagUpdate,
    _: AuthContext = Security(authorize),
) -> TagResponse:
    """Updates a tag.

    Fetching the tag, verifying the caller's permissions and applying the
    update are delegated to `verify_permissions_and_update_entity`.

    Args:
        tag_id: ID of the tag.
        tag_update_model: Tag to use for the update.

    Returns:
        The updated tag.
    """
    return verify_permissions_and_update_entity(
        id=tag_id,
        update_model=tag_update_model,
        get_method=zen_store().get_tag,
        update_method=zen_store().update_tag,
    )
triggers_endpoints
Endpoint definitions for triggers.
create_trigger(trigger, _=Security(oauth2_authentication))
Creates a trigger.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
trigger |
TriggerRequest |
Trigger to register. |
required |
Returns:
Type | Description |
---|---|
TriggerResponse |
The created trigger. |
Exceptions:
Type | Description |
---|---|
ValueError |
If the plugin registered for the event source's flavor/subtype combination is not a valid event source handler implementation. |
Source code in zenml/zen_server/routers/triggers_endpoints.py
@router.post(
    "",
    response_model=TriggerResponse,
    responses={401: error_response, 409: error_response, 422: error_response},
)
@handle_exceptions
def create_trigger(
    trigger: TriggerRequest,
    _: AuthContext = Security(authorize),
) -> TriggerResponse:
    """Creates a trigger.

    If the trigger references an event source and defines an event filter,
    the filter is first validated by the handler plugin registered for that
    event source's flavor and subtype.

    Args:
        trigger: Trigger to register.

    Returns:
        The created trigger.

    Raises:
        ValueError: If the plugin registered for the event source's
            flavor/subtype combination is not an event source handler
            implementation.
    """
    if trigger.event_source_id and trigger.event_filter:
        event_source = zen_store().get_event_source(
            event_source_id=trigger.event_source_id
        )

        event_source_handler = plugin_flavor_registry().get_plugin(
            name=event_source.flavor,
            _type=PluginType.EVENT_SOURCE,
            subtype=event_source.plugin_subtype,
        )

        # Validate that the flavor and plugin_type correspond to an event
        # source implementation
        if not isinstance(event_source_handler, BaseEventSourceHandler):
            raise ValueError(
                f"Event source plugin {event_source.plugin_subtype} "
                f"for flavor {event_source.flavor} is not a valid event source "
                "handler implementation."
            )

        # Validate the trigger event filter
        event_source_handler.validate_event_filter_configuration(
            trigger.event_filter
        )

    return verify_permissions_and_create_entity(
        request_model=trigger,
        resource_type=ResourceType.TRIGGER,
        create_method=zen_store().create_trigger,
    )
delete_trigger(trigger_id, _=Security(oauth2_authentication))
Deletes a trigger.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
trigger_id |
UUID |
ID of the trigger. |
required |
Source code in zenml/zen_server/routers/triggers_endpoints.py
@router.delete(
    "/{trigger_id}",
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def delete_trigger(
    trigger_id: UUID,
    _: AuthContext = Security(authorize),
) -> None:
    """Deletes a trigger.

    The trigger is fetched first so that the caller's DELETE permission on it
    can be verified before it is removed.

    Args:
        trigger_id: ID of the trigger.
    """
    trigger = zen_store().get_trigger(trigger_id=trigger_id)
    verify_permission_for_model(trigger, action=Action.DELETE)
    zen_store().delete_trigger(trigger_id=trigger_id)
delete_trigger_execution(trigger_execution_id, _=Security(oauth2_authentication))
Deletes a trigger execution.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
trigger_execution_id |
UUID |
ID of the trigger execution. |
required |
Source code in zenml/zen_server/routers/triggers_endpoints.py
@executions_router.delete(
    "/{trigger_execution_id}",
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def delete_trigger_execution(
    trigger_execution_id: UUID,
    _: AuthContext = Security(authorize),
) -> None:
    """Delete the trigger execution identified by `trigger_execution_id`.

    Args:
        trigger_execution_id: ID of the trigger execution.
    """
    store = zen_store()
    verify_permissions_and_delete_entity(
        id=trigger_execution_id,
        get_method=store.get_trigger_execution,
        delete_method=store.delete_trigger_execution,
    )
get_trigger(trigger_id, hydrate=True, _=Security(oauth2_authentication))
Returns the requested trigger.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
trigger_id |
UUID |
ID of the trigger. |
required |
hydrate |
bool |
Flag deciding whether to hydrate the output model(s) by including metadata fields in the response. |
True |
Returns:
Type | Description |
---|---|
TriggerResponse |
The requested trigger. |
Source code in zenml/zen_server/routers/triggers_endpoints.py
@router.get(
    "/{trigger_id}",
    response_model=TriggerResponse,
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def get_trigger(
    trigger_id: UUID,
    hydrate: bool = True,
    _: AuthContext = Security(authorize),
) -> TriggerResponse:
    """Fetch a trigger by ID, after checking the caller may read it.

    Args:
        trigger_id: ID of the trigger.
        hydrate: Flag deciding whether to hydrate the output model(s)
            by including metadata fields in the response.

    Returns:
        The requested trigger.
    """
    requested_trigger = zen_store().get_trigger(
        trigger_id=trigger_id,
        hydrate=hydrate,
    )
    verify_permission_for_model(requested_trigger, action=Action.READ)
    return dehydrate_response_model(requested_trigger)
get_trigger_execution(trigger_execution_id, hydrate=True, _=Security(authorize))
Returns the requested trigger execution.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| trigger_execution_id | UUID | ID of the trigger execution. | required |
| hydrate | bool | Flag deciding whether to hydrate the output model(s) by including metadata fields in the response. | True |
Returns:
| Type | Description |
|---|---|
| TriggerExecutionResponse | The requested trigger execution. |
Source code in zenml/zen_server/routers/triggers_endpoints.py
@executions_router.get(
    "/{trigger_execution_id}",
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def get_trigger_execution(
    trigger_execution_id: UUID,
    hydrate: bool = True,
    _: AuthContext = Security(authorize),
) -> TriggerExecutionResponse:
    """Fetch a single trigger execution by its ID.

    Args:
        trigger_execution_id: ID of the trigger execution.
        hydrate: Whether to hydrate the output model(s) by including
            metadata fields in the response.

    Returns:
        The requested trigger execution, as returned by
        `verify_permissions_and_get_entity`.
    """
    fetch_execution = zen_store().get_trigger_execution
    return verify_permissions_and_get_entity(
        id=trigger_execution_id,
        get_method=fetch_execution,
        hydrate=hydrate,
    )
list_trigger_executions(trigger_execution_filter_model=Depends(make_dependable(TriggerExecutionFilter)), hydrate=False, _=Security(authorize))
List trigger executions.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| trigger_execution_filter_model | TriggerExecutionFilter | Filter model used for pagination, sorting, filtering. | Depends(make_dependable(TriggerExecutionFilter)) |
| hydrate | bool | Flag deciding whether to hydrate the output model(s) by including metadata fields in the response. | False |
Returns:
| Type | Description |
|---|---|
| Page[TriggerExecutionResponse] | Page of trigger executions. |
Source code in zenml/zen_server/routers/triggers_endpoints.py
@executions_router.get(
    "",
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def list_trigger_executions(
    trigger_execution_filter_model: TriggerExecutionFilter = Depends(
        make_dependable(TriggerExecutionFilter)
    ),
    hydrate: bool = False,
    _: AuthContext = Security(authorize),
) -> Page[TriggerExecutionResponse]:
    """List trigger executions.

    Permission handling is delegated to
    `verify_permissions_and_list_entities` for the TRIGGER_EXECUTION
    resource type.

    Args:
        trigger_execution_filter_model: Filter model used for pagination,
            sorting and filtering.
        hydrate: Whether to hydrate the output model(s) by including
            metadata fields in the response.

    Returns:
        Page of trigger executions.
    """
    store = zen_store()
    return verify_permissions_and_list_entities(
        filter_model=trigger_execution_filter_model,
        resource_type=ResourceType.TRIGGER_EXECUTION,
        list_method=store.list_trigger_executions,
        hydrate=hydrate,
    )
list_triggers(trigger_filter_model=Depends(make_dependable(TriggerFilter)), hydrate=False, _=Security(authorize))
Returns all triggers.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| trigger_filter_model | TriggerFilter | Filter model used for pagination, sorting, filtering. | Depends(make_dependable(TriggerFilter)) |
| hydrate | bool | Flag deciding whether to hydrate the output model(s) by including metadata fields in the response. | False |
Returns:
| Type | Description |
|---|---|
| Page[TriggerResponse] | All triggers. |
Source code in zenml/zen_server/routers/triggers_endpoints.py
@router.get(
    "",
    response_model=Page[TriggerResponse],
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def list_triggers(
    trigger_filter_model: TriggerFilter = Depends(
        make_dependable(TriggerFilter)
    ),
    hydrate: bool = False,
    _: AuthContext = Security(authorize),
) -> Page[TriggerResponse]:
    """List triggers.

    Permission handling is delegated to
    `verify_permissions_and_list_entities` for the TRIGGER resource type.

    Args:
        trigger_filter_model: Filter model used for pagination, sorting
            and filtering.
        hydrate: Whether to hydrate the output model(s) by including
            metadata fields in the response.

    Returns:
        All triggers.
    """
    store = zen_store()
    return verify_permissions_and_list_entities(
        filter_model=trigger_filter_model,
        resource_type=ResourceType.TRIGGER,
        list_method=store.list_triggers,
        hydrate=hydrate,
    )
update_trigger(trigger_id, trigger_update, _=Security(authorize))
Updates a trigger.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| trigger_id | UUID | ID of the trigger. | required |
| trigger_update | TriggerUpdate | Trigger to use for the update. | required |
Returns:
| Type | Description |
|---|---|
| TriggerResponse | The updated trigger. |
Exceptions:
| Type | Description |
|---|---|
| ValueError | If an event filter is set for a trigger without an event source, or if the event source flavor/subtype does not resolve to a valid event source handler implementation. |
Source code in zenml/zen_server/routers/triggers_endpoints.py
@router.put(
    "/{trigger_id}",
    response_model=TriggerResponse,
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def update_trigger(
    trigger_id: UUID,
    trigger_update: TriggerUpdate,
    _: AuthContext = Security(authorize),
) -> TriggerResponse:
    """Updates a trigger.

    The caller's UPDATE permission on the trigger is verified before the
    trigger's event source is loaded or any event filter is validated.
    If the update sets an event filter, the filter is validated by the
    handler plugin registered for the trigger's event source.

    Args:
        trigger_id: ID of the trigger.
        trigger_update: Trigger to use for the update.

    Returns:
        The updated trigger.

    Raises:
        ValueError: If an event filter is supplied for a trigger without an
            event source, or if the event source's flavor/subtype does not
            resolve to a valid event source handler implementation.
    """
    trigger = zen_store().get_trigger(trigger_id=trigger_id)
    # Authorize before doing any further work: the validation below loads
    # the event source and its plugin, and its error messages reveal the
    # event source's flavor and subtype.
    verify_permission_for_model(trigger, action=Action.UPDATE)

    if trigger_update.event_filter:
        if not trigger.event_source:
            raise ValueError(
                "Trying to set event filter for trigger without event source."
            )
        event_source = zen_store().get_event_source(
            event_source_id=trigger.event_source.id
        )
        event_source_handler = plugin_flavor_registry().get_plugin(
            name=event_source.flavor,
            _type=PluginType.EVENT_SOURCE,
            subtype=event_source.plugin_subtype,
        )
        # Validate that the flavor and plugin_type correspond to an event
        # source implementation
        if not isinstance(event_source_handler, BaseEventSourceHandler):
            raise ValueError(
                f"Event source plugin {event_source.plugin_subtype} "
                f"for flavor {event_source.flavor} is not a valid event source "
                "handler implementation."
            )
        # Validate the trigger event filter
        event_source_handler.validate_event_filter_configuration(
            trigger_update.event_filter
        )

    updated_trigger = zen_store().update_trigger(
        trigger_id=trigger_id, trigger_update=trigger_update
    )
    return dehydrate_response_model(updated_trigger)
users_endpoints
Endpoint definitions for users.
activate_user(user_name_or_id, user_update)
Activates a specific user.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| user_name_or_id | Union[str, uuid.UUID] | Name or ID of the user. | required |
| user_update | UserUpdate | The user to use for the update. | required |
Returns:
| Type | Description |
|---|---|
| UserResponse | The updated user. |
Source code in zenml/zen_server/routers/users_endpoints.py
@activation_router.put(
    "/{user_name_or_id}" + ACTIVATE,
    response_model=UserResponse,
    responses={
        401: error_response,
        404: error_response,
        422: error_response,
    },
)
@handle_exceptions
def activate_user(
    user_name_or_id: Union[str, UUID],
    user_update: UserUpdate,
) -> UserResponse:
    """Activate a specific user with its activation token.

    Args:
        user_name_or_id: Name or ID of the user.
        user_update: the user to use for the update. Its activation token
            is passed to `authenticate_credentials`.

    Returns:
        The updated user.
    """
    existing_user = zen_store().get_user(user_name_or_id)

    # The update that is actually applied is a filtered copy of the
    # request, so the caller cannot directly control the attributes this
    # endpoint manages itself.
    non_updatable = {
        "activation_token",
        "external_user_id",
        "is_admin",
        "active",
        "old_password",
    }
    applied_update = user_update.create_copy(exclude=non_updatable)

    # NOTE: this raises an exception if the activation token is not set.
    authenticate_credentials(
        user_name_or_id=user_name_or_id,
        activation_token=user_update.activation_token,
    )

    # Mark the account as active and clear its activation token.
    applied_update.active = True
    applied_update.activation_token = None
    return zen_store().update_user(
        user_id=existing_user.id, user_update=applied_update
    )
create_user(user, auth_context=Security(authorize))
Creates a user.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| user | UserRequest | User to create. | required |
| auth_context | AuthContext | Authentication context. | Security(authorize) |
Returns:
| Type | Description |
|---|---|
| UserResponse | The created user. |
Source code in zenml/zen_server/routers/users_endpoints.py
@router.post(
    "",
    response_model=UserResponse,
    responses={
        401: error_response,
        409: error_response,
        422: error_response,
    },
)
@handle_exceptions
def create_user(
    user: UserRequest,
    auth_context: AuthContext = Security(authorize),
) -> UserResponse:
    """Creates a user.

    # noqa: DAR401

    A user submitted with a password is created active right away. A user
    submitted without a password is created inactive, and an activation
    token is generated so the account can be activated later.

    Args:
        user: User to create.
        auth_context: Authentication context.

    Returns:
        The created user. When an activation token was generated, the
        response body carries the original unhashed token.
    """
    activation_token: Optional[str] = None
    if user.password is not None:
        user.active = True
    else:
        user.active = False
        activation_token = user.generate_activation_token()

    verify_admin_status_if_no_rbac(
        auth_context.user.is_admin, "create user"
    )
    # NOTE: RBAC-based creation through verify_permissions_and_create_entity
    # (resource type ResourceType.USER) is currently disabled; the user is
    # created directly in the store instead.
    created = zen_store().create_user(user)

    # Hand the unhashed token back to the client so it can be used to
    # activate the account.
    if activation_token:
        created.get_body().activation_token = activation_token
    return created
deactivate_user(user_name_or_id, auth_context=Security(authorize))
Deactivates a user and generates a new activation token for it.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| user_name_or_id | Union[str, uuid.UUID] | Name or ID of the user. | required |
| auth_context | AuthContext | Authentication context. | Security(authorize) |
Returns:
| Type | Description |
|---|---|
| UserResponse | The deactivated user, including the newly generated activation token. |
Exceptions:
| Type | Description |
|---|---|
| IllegalOperationError | If the user is trying to deactivate themselves. |
Source code in zenml/zen_server/routers/users_endpoints.py
@router.put(
    "/{user_name_or_id}" + DEACTIVATE,
    response_model=UserResponse,
    responses={
        401: error_response,
        404: error_response,
        422: error_response,
    },
)
@handle_exceptions
def deactivate_user(
    user_name_or_id: Union[str, UUID],
    auth_context: AuthContext = Security(authorize),
) -> UserResponse:
    """Deactivate a user and issue a fresh activation token for it.

    Args:
        user_name_or_id: Name or ID of the user.
        auth_context: Authentication context.

    Returns:
        The deactivated user, whose body carries the generated activation
        token.

    Raises:
        IllegalOperationError: if the user is trying to deactivate
            themselves.
    """
    target = zen_store().get_user(user_name_or_id)
    if target.id == auth_context.user.id:
        raise IllegalOperationError("Cannot deactivate yourself.")

    verify_admin_status_if_no_rbac(
        auth_context.user.is_admin, "deactivate user"
    )
    # NOTE: an RBAC check (verify_permission_for_model with Action.UPDATE)
    # is currently disabled for this endpoint.

    deactivation = UserUpdate(active=False)
    plain_token = deactivation.generate_activation_token()
    updated = zen_store().update_user(
        user_id=target.id, user_update=deactivation
    )

    # Return the original unhashed activation token to the caller.
    updated.get_body().activation_token = plain_token
    return dehydrate_response_model(updated)
delete_user(user_name_or_id, auth_context=Security(authorize))
Deletes a specific user.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| user_name_or_id | Union[str, uuid.UUID] | Name or ID of the user. | required |
| auth_context | AuthContext | The authentication context. | Security(authorize) |
Exceptions:
| Type | Description |
|---|---|
| IllegalOperationError | If the user tries to delete the account they are currently authenticated with. |
Source code in zenml/zen_server/routers/users_endpoints.py
@router.delete(
    "/{user_name_or_id}",
    responses={
        401: error_response,
        404: error_response,
        422: error_response,
    },
)
@handle_exceptions
def delete_user(
    user_name_or_id: Union[str, UUID],
    auth_context: AuthContext = Security(authorize),
) -> None:
    """Delete a specific user.

    Args:
        user_name_or_id: Name or ID of the user.
        auth_context: The authentication context.

    Raises:
        IllegalOperationError: If the user tries to delete the account they
            are currently authenticated with.
    """
    target = zen_store().get_user(user_name_or_id)
    if auth_context.user.id == target.id:
        raise IllegalOperationError(
            "You cannot delete the user account currently used to authenticate "
            "to the ZenML server. If you wish to delete this account, "
            "please authenticate with another account or contact your ZenML "
            "administrator."
        )

    verify_admin_status_if_no_rbac(
        auth_context.user.is_admin, "delete user"
    )
    # NOTE: an RBAC check (verify_permission_for_model with Action.DELETE)
    # is currently disabled for this endpoint.
    zen_store().delete_user(user_name_or_id=user_name_or_id)
email_opt_in_response(user_name_or_id, user_response, auth_context=Security(authorize))
Sets the response of the user to the email prompt.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| user_name_or_id | Union[str, uuid.UUID] | Name or ID of the user. | required |
| user_response | UserUpdate | User response to the email prompt. | required |
| auth_context | AuthContext | The authentication context of the user. | Security(authorize) |
Returns:
| Type | Description |
|---|---|
| UserResponse | The updated user. |
Exceptions:
| Type | Description |
|---|---|
| AuthorizationException | If the authenticated user tries to set the email prompt response on behalf of another user. |
Source code in zenml/zen_server/routers/users_endpoints.py
@router.put(
    "/{user_name_or_id}" + EMAIL_ANALYTICS,
    response_model=UserResponse,
    responses={
        401: error_response,
        404: error_response,
        422: error_response,
    },
)
@handle_exceptions
def email_opt_in_response(
    user_name_or_id: Union[str, UUID],
    user_response: UserUpdate,
    auth_context: AuthContext = Security(authorize),
) -> UserResponse:
    """Sets the response of the user to the email prompt.

    Only the authenticated user can set their own response. The target user
    may be addressed by name or by ID.

    Args:
        user_name_or_id: Name or ID of the user.
        user_response: User Response to email prompt
        auth_context: The authentication context of the user

    Returns:
        The updated user.

    Raises:
        AuthorizationException: if the user does not have the required
            permissions
    """
    user = zen_store().get_user(user_name_or_id)
    # Compare resolved IDs rather than the raw path parameter: the path may
    # hold a user name or a non-canonical UUID string, and comparing it
    # textually to the authenticated user's ID would wrongly reject the
    # user's own request.
    if user.id != auth_context.user.id:
        raise AuthorizationException(
            "Users can not opt in on behalf of another user."
        )

    user_update = UserUpdate(
        email=user_response.email,
        email_opted_in=user_response.email_opted_in,
    )
    if user_response.email_opted_in is not None:
        email_opt_int(
            opted_in=user_response.email_opted_in,
            email=user_response.email,
            source="zenml server",
        )
    updated_user = zen_store().update_user(
        user_id=user.id, user_update=user_update
    )
    return dehydrate_response_model(updated_user)
get_current_user(auth_context=Security(authorize))
Returns the model of the authenticated user.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| auth_context | AuthContext | The authentication context. | Security(authorize) |
Returns:
| Type | Description |
|---|---|
| UserResponse | The model of the authenticated user. |
Source code in zenml/zen_server/routers/users_endpoints.py
@current_user_router.get(
    "/current-user",
    response_model=UserResponse,
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def get_current_user(
    auth_context: AuthContext = Security(authorize),
) -> UserResponse:
    """Return the user that authenticated this request.

    Args:
        auth_context: The authentication context.

    Returns:
        The authenticated user's model, passed through
        `dehydrate_response_model`.
    """
    current = auth_context.user
    return dehydrate_response_model(current)
get_user(user_name_or_id, hydrate=True, auth_context=Security(authorize))
Returns a specific user.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| user_name_or_id | Union[str, uuid.UUID] | Name or ID of the user. | required |
| hydrate | bool | Flag deciding whether to hydrate the output model(s) by including metadata fields in the response. | True |
| auth_context | AuthContext | Authentication context. | Security(authorize) |
Returns:
| Type | Description |
|---|---|
| UserResponse | A specific user. |
Source code in zenml/zen_server/routers/users_endpoints.py
@router.get(
    "/{user_name_or_id}",
    response_model=UserResponse,
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def get_user(
    user_name_or_id: Union[str, UUID],
    hydrate: bool = True,
    auth_context: AuthContext = Security(authorize),
) -> UserResponse:
    """Fetch a specific user.

    Looking up your own account needs no extra check. Looking up another
    account goes through `verify_admin_status_if_no_rbac`.

    Args:
        user_name_or_id: Name or ID of the user.
        hydrate: Whether to hydrate the output model(s) by including
            metadata fields in the response.
        auth_context: Authentication context.

    Returns:
        A specific user.
    """
    requested_user = zen_store().get_user(
        user_name_or_id=user_name_or_id, hydrate=hydrate
    )
    looking_up_self = requested_user.id == auth_context.user.id
    if not looking_up_self:
        verify_admin_status_if_no_rbac(
            auth_context.user.is_admin, "get other user"
        )
        # NOTE: an RBAC check (verify_permission_for_model with Action.READ)
        # is currently disabled for this endpoint.
    return dehydrate_response_model(requested_user)
list_users(user_filter_model=Depends(make_dependable(UserFilter)), hydrate=False, auth_context=Security(authorize))
Returns a list of all users.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| user_filter_model | UserFilter | Model that takes care of filtering, sorting and pagination. | Depends(make_dependable(UserFilter)) |
| hydrate | bool | Flag deciding whether to hydrate the output model(s) by including metadata fields in the response. | False |
| auth_context | AuthContext | Authentication context. | Security(authorize) |
Returns:
| Type | Description |
|---|---|
| Page[UserResponse] | A list of all users. |
Source code in zenml/zen_server/routers/users_endpoints.py
@router.get(
    "",
    response_model=Page[UserResponse],
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def list_users(
    user_filter_model: UserFilter = Depends(make_dependable(UserFilter)),
    hydrate: bool = False,
    auth_context: AuthContext = Security(authorize),
) -> Page[UserResponse]:
    """List users.

    A non-admin caller on a server without RBAC enabled has the filter
    restricted to their own user ID.

    Args:
        user_filter_model: Model that takes care of filtering, sorting and
            pagination.
        hydrate: Whether to hydrate the output model(s) by including
            metadata fields in the response.
        auth_context: Authentication context.

    Returns:
        A list of all users.
    """
    # NOTE: an RBAC-based variant that used get_allowed_resource_ids for
    # ResourceType.USER (always adding the caller's own ID) is currently
    # disabled.
    caller = auth_context.user
    if not caller.is_admin and not server_config().rbac_enabled:
        user_filter_model.configure_rbac(
            authenticated_user_id=caller.id,
            id={caller.id},
        )

    users_page = zen_store().list_users(
        user_filter_model=user_filter_model, hydrate=hydrate
    )
    return dehydrate_page(users_page)
update_myself(user, request, auth_context=Security(authorize))
Updates the authenticated user.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| user | UserUpdate | The user to use for the update. | required |
| request | Request | The request object. | required |
| auth_context | AuthContext | The authentication context. | Security(authorize) |
Returns:
| Type | Description |
|---|---|
| UserResponse | The updated user. |
Exceptions:
| Type | Description |
|---|---|
| IllegalOperationError | If the current password is not supplied when changing the password, or if the current password is incorrect. |
Source code in zenml/zen_server/routers/users_endpoints.py
@current_user_router.put(
    "/current-user",
    response_model=UserResponse,
    responses={
        401: error_response,
        404: error_response,
        422: error_response,
    },
)
@handle_exceptions
def update_myself(
    user: UserUpdate,
    request: Request,
    auth_context: AuthContext = Security(authorize),
) -> UserResponse:
    """Update the account of the authenticated user.

    Args:
        user: the user to use for the update.
        request: The request object.
        auth_context: The authentication context.

    Returns:
        The updated user.

    Raises:
        IllegalOperationError: if the current password is not supplied when
            changing the password or if the current password is incorrect.
    """
    # The update that is actually applied is a filtered copy of the
    # request, so the caller cannot directly control attributes that this
    # endpoint must not change: activation_token, external_user_id,
    # is_admin, active and old_password.
    non_updatable = {
        "activation_token",
        "external_user_id",
        "is_admin",
        "active",
        "old_password",
    }
    applied_update = user.create_copy(exclude=non_updatable)

    if user.password is not None:
        # A password change must be confirmed with the current password.
        if user.old_password is None:
            raise IllegalOperationError(
                "The current password must be supplied when changing the "
                "password."
            )
        # Failed verifications raise inside the limiter's context.
        with pass_change_limiter.limit_failed_requests(request):
            stored_auth = zen_store().get_auth_user(auth_context.user.id)
            old_password_ok = UserAuthModel.verify_password(
                user.old_password, stored_auth
            )
            if not old_password_ok:
                raise IllegalOperationError(
                    "The current password is incorrect."
                )
        applied_update.password = user.password

    refreshed = zen_store().update_user(
        user_id=auth_context.user.id, user_update=applied_update
    )
    return dehydrate_response_model(refreshed)
update_user(user_name_or_id, user_update, request, auth_context=Security(authorize))
Updates a specific user.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| user_name_or_id | Union[str, uuid.UUID] | Name or ID of the user. | required |
| user_update | UserUpdate | The user to use for the update. | required |
| request | Request | The request object. | required |
| auth_context | AuthContext | Authentication context. | Security(authorize) |
Returns:
| Type | Description |
|---|---|
| UserResponse | The updated user. |
Exceptions:
| Type | Description |
|---|---|
| IllegalOperationError | If the user tries to change the admin or active status of their own account, tries to change the admin or active status of another account while not an admin, tries to change the password or private account information of another user, or tries to change their own password without providing the old password or while providing an incorrect old password. |
Source code in zenml/zen_server/routers/users_endpoints.py
@router.put(
    "/{user_name_or_id}",
    response_model=UserResponse,
    responses={
        401: error_response,
        404: error_response,
        422: error_response,
    },
)
@handle_exceptions
def update_user(
    user_name_or_id: Union[str, UUID],
    user_update: UserUpdate,
    request: Request,
    auth_context: AuthContext = Security(authorize),
) -> UserResponse:
    """Updates a specific user.

    Args:
        user_name_or_id: Name or ID of the user.
        user_update: the user to use for the update.
        request: The request object.
        auth_context: Authentication context.

    Returns:
        The updated user.

    Raises:
        IllegalOperationError: if the user tries to change the admin or
            active status of their own account, tries to change the admin
            or active status of another account while not an admin, tries
            to change the password or the private account information
            (email, email opt-in) of another user, or tries to change their
            own password without providing the old password or providing
            an incorrect old password.
    """
    user = zen_store().get_user(user_name_or_id)
    is_self = user.id == auth_context.user.id

    # Use a separate object to compute the update that will be applied to
    # the user to avoid giving the API requester direct control over the
    # user attributes that are updated.
    #
    # Exclude attributes that cannot be updated through this endpoint:
    #
    # - activation_token
    # - external_user_id
    # - old_password
    #
    # Exclude things that are not always safe to update and need to be
    # validated first (they are copied over below once validated):
    #
    # - is_admin
    # - active
    # - password
    # - email_opted_in + email
    #
    safe_user_update = user_update.create_copy(
        exclude={
            "activation_token",
            "external_user_id",
            "is_admin",
            "active",
            "password",
            "old_password",
            "email_opted_in",
            "email",
        },
    )

    if not is_self:
        verify_admin_status_if_no_rbac(
            auth_context.user.is_admin, "update other user account"
        )
        # NOTE: an RBAC check (verify_permission_for_model with
        # Action.UPDATE) is currently disabled for this endpoint.

    # Validate a password change
    if user_update.password is not None:
        if not is_self:
            raise IllegalOperationError(
                "Users cannot change the password of other users. Use the "
                "account deactivation and activation flow instead."
            )
        # If the user is updating their own password, we need to verify
        # the old password
        if user_update.old_password is None:
            raise IllegalOperationError(
                "The current password must be supplied when changing the "
                "password."
            )
        with pass_change_limiter.limit_failed_requests(request):
            auth_user = zen_store().get_auth_user(user_name_or_id)
            if not UserAuthModel.verify_password(
                user_update.old_password, auth_user
            ):
                raise IllegalOperationError(
                    "The current password is incorrect."
                )
        # Accept the password update
        safe_user_update.password = user_update.password

    # Validate an admin status change
    if (
        user_update.is_admin is not None
        and user.is_admin != user_update.is_admin
    ):
        if is_self:
            raise IllegalOperationError(
                "Cannot change the admin status of your own user account."
            )
        if not auth_context.user.is_admin:
            raise IllegalOperationError(
                "Only admins are allowed to change the admin status of "
                "other user accounts."
            )
        # Accept the admin status update
        safe_user_update.is_admin = user_update.is_admin

    # Validate an active status change
    if (
        user_update.active is not None
        and user.active != user_update.active
    ):
        if is_self:
            raise IllegalOperationError(
                "Cannot change the active status of your own user account."
            )
        if not auth_context.user.is_admin:
            raise IllegalOperationError(
                "Only admins are allowed to change the active status of "
                "other user accounts."
            )
        # Accept the active status update. `active` was excluded from
        # safe_user_update above, so it must be copied here or the change
        # is silently dropped.
        safe_user_update.active = user_update.active

    # Validate changes to private user account information
    if (
        user_update.email_opted_in is not None
        or user_update.email is not None
    ):
        if not is_self:
            raise IllegalOperationError(
                "Cannot change the private user account information for "
                "another user account."
            )
        # Accept the private user account information update. Read from
        # user_update: both fields were excluded from safe_user_update
        # above, so they are always None there.
        if user_update.email_opted_in is not None:
            safe_user_update.email_opted_in = user_update.email_opted_in
        if user_update.email is not None:
            safe_user_update.email = user_update.email

    updated_user = zen_store().update_user(
        user_id=user.id,
        user_update=safe_user_update,
    )
    return dehydrate_response_model(updated_user)
webhook_endpoints
Endpoint definitions for webhooks.
get_body(request)
async
Get access to the raw body.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| request | Request | The request. | required |
Returns:
| Type | Description |
|---|---|
| bytes | The raw request body. |
Source code in zenml/zen_server/routers/webhook_endpoints.py
async def get_body(request: Request) -> bytes:
    """Read the raw request body.

    Args:
        request: The incoming request.

    Returns:
        The raw request body, as returned by `request.body()`.
    """
    payload = await request.body()
    return payload
webhook(event_source_id, request, background_tasks, raw_body=Depends(get_body))
Webhook to receive events from external event sources.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| event_source_id | UUID | The event source ID. | required |
| request | Request | The request object. | required |
| background_tasks | BackgroundTasks | Background task handler. | required |
| raw_body | bytes | The raw request body. | Depends(get_body) |
Returns:
| Type | Description |
|---|---|
| Dict[str, str] | Static dict stating that the event was received. |
Exceptions:
| Type | Description |
|---|---|
| AuthorizationException | If the event source does not exist. |
| KeyError | If no appropriate plugin is found in the plugin registry. |
| ValueError | If the event source is not actually a webhook event source. |
| WebhookInactiveError | If this webhook has been deactivated. |
Source code in zenml/zen_server/routers/webhook_endpoints.py
@router.post(
    "/{event_source_id}",
    response_model=Dict[str, str],
)
@handle_exceptions
def webhook(
    event_source_id: UUID,
    request: Request,
    background_tasks: BackgroundTasks,
    raw_body: bytes = Depends(get_body),
) -> Dict[str, str]:
    """Receive an event from an external event source.

    The event source is resolved and checked to be active. Its flavor's
    webhook plugin is then looked up, and processing of the raw body and
    headers is scheduled as a background task.

    Args:
        event_source_id: The event_source_id
        request: The request object
        background_tasks: Background task handler
        raw_body: The raw request body

    Returns:
        Static dict stating that event is received.

    Raises:
        AuthorizationException: If the Event Source does not exist.
        KeyError: If no appropriate Plugin found in the plugin registry
        ValueError: If the id of the Event Source is not actually a webhook
            event source
        WebhookInactiveError: In case this webhook has been deactivated
    """
    # Unknown event sources are reported as an authorization failure.
    try:
        event_source = zen_store().get_event_source(event_source_id)
    except KeyError:
        logger.error(
            "Webhook HTTP request received for unknown event source "
            f"'{event_source_id}'."
        )
        # TODO: Are we sure about this error message?
        raise AuthorizationException(
            f"No webhook is registered at '{router.prefix}/{event_source_id}'"
        )

    if not event_source.is_active:
        raise WebhookInactiveError(f"Webhook {event_source_id} is inactive.")

    flavor = event_source.flavor
    try:
        handler = plugin_flavor_registry().get_plugin(
            name=flavor,
            _type=PluginType.EVENT_SOURCE,
            subtype=PluginSubType.WEBHOOK,
        )
    except KeyError:
        logger.error(
            "Webhook HTTP request received for event source "
            f"'{event_source_id}' and flavor {flavor} but no matching "
            "plugin was found."
        )
        raise KeyError(
            f"No listener plugin found for event source {event_source_id}."
        )

    if not isinstance(handler, BaseWebhookEventSourceHandler):
        raise ValueError(
            f"Event Source {event_source_id} is not a valid Webhook event "
            "source!"
        )

    # Hand the untouched body and the request headers to the plugin.
    background_tasks.add_task(
        handler.process_webhook_event,
        event_source=event_source,
        raw_body=raw_body,
        headers=dict(request.headers.items()),
    )
    return {"status": "Event Received."}
workspaces_endpoints
Endpoint definitions for workspaces.
create_build(workspace_name_or_id, build, auth_context=Security(authorize))
Creates a build.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| workspace_name_or_id | Union[str, uuid.UUID] | Name or ID of the workspace. | required |
| build | PipelineBuildRequest | Build to create. | required |
| auth_context | AuthContext | Authentication context. | Security(authorize) |
Returns:
| Type | Description |
|---|---|
| PipelineBuildResponse | The created build. |
Exceptions:
| Type | Description |
|---|---|
| IllegalOperationError | If the workspace specified in the build does not match the current workspace. |
Source code in zenml/zen_server/routers/workspaces_endpoints.py
@router.post(
    WORKSPACES + "/{workspace_name_or_id}" + PIPELINE_BUILDS,
    response_model=PipelineBuildResponse,
    responses={401: error_response, 409: error_response, 422: error_response},
)
@handle_exceptions
def create_build(
    workspace_name_or_id: Union[str, UUID],
    build: PipelineBuildRequest,
    auth_context: AuthContext = Security(authorize),
) -> PipelineBuildResponse:
    """Create a pipeline build inside the workspace named in the path.

    Args:
        workspace_name_or_id: Name or ID of the workspace.
        build: Build to create.
        auth_context: Authentication context.

    Returns:
        The created build.

    Raises:
        IllegalOperationError: If the workspace specified in the build
            does not match the current workspace.
    """
    target_workspace = zen_store().get_workspace(workspace_name_or_id)
    # The request body must not target a different workspace than the URL.
    if build.workspace != target_workspace.id:
        raise IllegalOperationError(
            "Creating builds outside of the workspace scope "
            f"of this endpoint `{workspace_name_or_id}` is "
            "not supported."
        )

    return verify_permissions_and_create_entity(
        request_model=build,
        resource_type=ResourceType.PIPELINE_BUILD,
        create_method=zen_store().create_build,
    )
create_code_repository(workspace_name_or_id, code_repository, _=Security(authorize))
Creates a code repository.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| workspace_name_or_id | Union[str, uuid.UUID] | Name or ID of the workspace. | required |
| code_repository | CodeRepositoryRequest | Code repository to create. | required |
Returns:
| Type | Description |
|---|---|
| CodeRepositoryResponse | The created code repository. |
Exceptions:
| Type | Description |
|---|---|
| IllegalOperationError | If the workspace specified in the code repository does not match the current workspace. |
Source code in zenml/zen_server/routers/workspaces_endpoints.py
@router.post(
    WORKSPACES + "/{workspace_name_or_id}" + CODE_REPOSITORIES,
    response_model=CodeRepositoryResponse,
    responses={401: error_response, 409: error_response, 422: error_response},
)
@handle_exceptions
def create_code_repository(
    workspace_name_or_id: Union[str, UUID],
    code_repository: CodeRepositoryRequest,
    _: AuthContext = Security(authorize),
) -> CodeRepositoryResponse:
    """Register a code repository in the workspace addressed by the URL.

    Args:
        workspace_name_or_id: Name or ID of the workspace from the URL path.
        code_repository: The code repository request to persist.

    Returns:
        The newly created code repository.

    Raises:
        IllegalOperationError: If the workspace or user specified in the
            code repository does not match the current workspace or
            authenticated user.
    """
    scope_workspace = zen_store().get_workspace(workspace_name_or_id)
    if code_repository.workspace == scope_workspace.id:
        return verify_permissions_and_create_entity(
            request_model=code_repository,
            resource_type=ResourceType.CODE_REPOSITORY,
            create_method=zen_store().create_code_repository,
        )
    # The request body points at a different workspace than the URL.
    raise IllegalOperationError(
        "Creating code repositories outside of the workspace scope "
        f"of this endpoint `{workspace_name_or_id}` is "
        f"not supported."
    )
create_deployment(workspace_name_or_id, deployment, auth_context=Security(oauth2_authentication))
Creates a deployment.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
workspace_name_or_id |
Union[str, uuid.UUID] |
Name or ID of the workspace. |
required |
deployment |
PipelineDeploymentRequest |
Deployment to create. |
required |
auth_context |
AuthContext |
Authentication context. |
Security(oauth2_authentication) |
Returns:
Type | Description |
---|---|
PipelineDeploymentResponse |
The created deployment. |
Exceptions:
Type | Description |
---|---|
IllegalOperationError |
If the workspace specified in the deployment does not match the current workspace. |
Source code in zenml/zen_server/routers/workspaces_endpoints.py
@router.post(
    WORKSPACES + "/{workspace_name_or_id}" + PIPELINE_DEPLOYMENTS,
    response_model=PipelineDeploymentResponse,
    responses={401: error_response, 409: error_response, 422: error_response},
)
@handle_exceptions
def create_deployment(
    workspace_name_or_id: Union[str, UUID],
    deployment: PipelineDeploymentRequest,
    auth_context: AuthContext = Security(authorize),
) -> PipelineDeploymentResponse:
    """Create a pipeline deployment inside the workspace addressed by the URL.

    Args:
        workspace_name_or_id: Name or ID of the workspace from the URL path.
        deployment: The deployment request to persist.
        auth_context: Authentication context resolved by the `authorize`
            security dependency; not otherwise used in the body.

    Returns:
        The newly created deployment.

    Raises:
        IllegalOperationError: If the deployment request targets a workspace
            other than the one addressed by this endpoint.
    """
    scope_workspace = zen_store().get_workspace(workspace_name_or_id)
    if deployment.workspace == scope_workspace.id:
        return verify_permissions_and_create_entity(
            request_model=deployment,
            resource_type=ResourceType.PIPELINE_DEPLOYMENT,
            create_method=zen_store().create_deployment,
        )
    # The request body points at a different workspace than the URL.
    raise IllegalOperationError(
        "Creating deployments outside of the workspace scope "
        f"of this endpoint `{workspace_name_or_id}` is "
        f"not supported."
    )
create_model(workspace_name_or_id, model, _=Security(oauth2_authentication))
Create a new model.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
workspace_name_or_id |
Union[str, uuid.UUID] |
Name or ID of the workspace. |
required |
model |
ModelRequest |
The model to create. |
required |
Returns:
Type | Description |
---|---|
ModelResponse |
The created model. |
Exceptions:
Type | Description |
---|---|
IllegalOperationError |
If the workspace or user specified in the model does not match the current workspace or authenticated user. |
Source code in zenml/zen_server/routers/workspaces_endpoints.py
@router.post(
    WORKSPACES + "/{workspace_name_or_id}" + MODELS,
    response_model=ModelResponse,
    responses={401: error_response, 409: error_response, 422: error_response},
)
@handle_exceptions
def create_model(
    workspace_name_or_id: Union[str, UUID],
    model: ModelRequest,
    _: AuthContext = Security(authorize),
) -> ModelResponse:
    """Create a model inside the workspace addressed by the URL.

    Args:
        workspace_name_or_id: Name or ID of the workspace from the URL path.
        model: The model request to persist.

    Returns:
        The newly created model.

    Raises:
        IllegalOperationError: If the workspace or user specified in the
            model does not match the current workspace or authenticated
            user.
    """
    scope_workspace = zen_store().get_workspace(workspace_name_or_id)
    if model.workspace == scope_workspace.id:
        return verify_permissions_and_create_entity(
            request_model=model,
            resource_type=ResourceType.MODEL,
            create_method=zen_store().create_model,
        )
    # The request body points at a different workspace than the URL.
    raise IllegalOperationError(
        "Creating models outside of the workspace scope "
        f"of this endpoint `{workspace_name_or_id}` is "
        f"not supported."
    )
create_model_version(workspace_name_or_id, model_name_or_id, model_version, auth_context=Security(oauth2_authentication))
Create a new model version.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model_name_or_id |
Union[str, uuid.UUID] |
Name or ID of the model. |
required |
workspace_name_or_id |
Union[str, uuid.UUID] |
Name or ID of the workspace. |
required |
model_version |
ModelVersionRequest |
The model version to create. |
required |
auth_context |
AuthContext |
Authentication context. |
Security(oauth2_authentication) |
Returns:
Type | Description |
---|---|
ModelVersionResponse |
The created model version. |
Exceptions:
Type | Description |
---|---|
IllegalOperationError |
If the workspace specified in the model version does not match the current workspace. |
Source code in zenml/zen_server/routers/workspaces_endpoints.py
@router.post(
    WORKSPACES
    + "/{workspace_name_or_id}"
    + MODELS
    + "/{model_name_or_id}"
    + MODEL_VERSIONS,
    response_model=ModelVersionResponse,
    responses={401: error_response, 409: error_response, 422: error_response},
)
@handle_exceptions
def create_model_version(
    workspace_name_or_id: Union[str, UUID],
    model_name_or_id: Union[str, UUID],
    model_version: ModelVersionRequest,
    auth_context: AuthContext = Security(authorize),
) -> ModelVersionResponse:
    """Create a model version inside the workspace addressed by the URL.

    Args:
        workspace_name_or_id: Name or ID of the workspace from the URL path.
        model_name_or_id: Name or ID of the model from the URL path.
            NOTE(review): this value is not read by the body, so it is never
            checked against the model referenced by `model_version`. Confirm
            whether a scope check (like the workspace one) is wanted here.
        model_version: The model version request to persist.
        auth_context: Authentication context resolved by the `authorize`
            security dependency; not otherwise used in the body.

    Returns:
        The newly created model version.

    Raises:
        IllegalOperationError: If the model version request targets a
            workspace other than the one addressed by this endpoint.
    """
    scope_workspace = zen_store().get_workspace(workspace_name_or_id)
    if model_version.workspace == scope_workspace.id:
        return verify_permissions_and_create_entity(
            request_model=model_version,
            resource_type=ResourceType.MODEL_VERSION,
            create_method=zen_store().create_model_version,
        )
    # The request body points at a different workspace than the URL.
    raise IllegalOperationError(
        "Creating model versions outside of the workspace scope "
        f"of this endpoint `{workspace_name_or_id}` is "
        f"not supported."
    )
create_pipeline(workspace_name_or_id, pipeline, _=Security(oauth2_authentication))
Creates a pipeline.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
workspace_name_or_id |
Union[str, uuid.UUID] |
Name or ID of the workspace. |
required |
pipeline |
PipelineRequest |
Pipeline to create. |
required |
Returns:
Type | Description |
---|---|
PipelineResponse |
The created pipeline. |
Exceptions:
Type | Description |
---|---|
IllegalOperationError |
If the workspace or user specified in the pipeline does not match the current workspace or authenticated user. |
Source code in zenml/zen_server/routers/workspaces_endpoints.py
@router.post(
    WORKSPACES + "/{workspace_name_or_id}" + PIPELINES,
    response_model=PipelineResponse,
    responses={401: error_response, 409: error_response, 422: error_response},
)
@handle_exceptions
def create_pipeline(
    workspace_name_or_id: Union[str, UUID],
    pipeline: PipelineRequest,
    _: AuthContext = Security(authorize),
) -> PipelineResponse:
    """Create a pipeline inside the workspace addressed by the URL.

    Usage is only checked and reported when the pipeline resource type is
    reportable and no pipeline with the same name exists yet.

    Args:
        workspace_name_or_id: Name or ID of the workspace from the URL path.
        pipeline: The pipeline request to persist.

    Returns:
        The newly created pipeline.

    Raises:
        IllegalOperationError: If the workspace or user specified in the
            pipeline does not match the current workspace or authenticated
            user.
    """
    scope_workspace = zen_store().get_workspace(workspace_name_or_id)
    if pipeline.workspace != scope_workspace.id:
        raise IllegalOperationError(
            "Creating pipelines outside of the workspace scope "
            f"of this endpoint `{workspace_name_or_id}` is "
            f"not supported."
        )

    # We limit pipeline namespaces, not pipeline versions: only the first
    # pipeline with a given name counts towards usage.
    # NOTE(review): the filter only sets `name`, so the count is not scoped
    # to this workspace -- confirm that is intended.
    track_usage = False
    if ResourceType.PIPELINE in REPORTABLE_RESOURCES:
        same_name_count = zen_store().count_pipelines(
            PipelineFilter(name=pipeline.name)
        )
        track_usage = same_name_count == 0

    if track_usage:
        check_entitlement(ResourceType.PIPELINE)

    created_pipeline = verify_permissions_and_create_entity(
        request_model=pipeline,
        resource_type=ResourceType.PIPELINE,
        create_method=zen_store().create_pipeline,
    )

    if track_usage:
        report_usage(
            resource_type=ResourceType.PIPELINE,
            resource_id=created_pipeline.id,
        )
    return created_pipeline
create_pipeline_run(workspace_name_or_id, pipeline_run, _=Security(oauth2_authentication))
Creates a pipeline run.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
workspace_name_or_id |
Union[str, uuid.UUID] |
Name or ID of the workspace. |
required |
pipeline_run |
PipelineRunRequest |
Pipeline run to create. |
required |
Returns:
Type | Description |
---|---|
PipelineRunResponse |
The created pipeline run. |
Exceptions:
Type | Description |
---|---|
IllegalOperationError |
If the workspace specified in the pipeline run does not match the current workspace. |
Source code in zenml/zen_server/routers/workspaces_endpoints.py
@router.post(
    WORKSPACES + "/{workspace_name_or_id}" + RUNS,
    response_model=PipelineRunResponse,
    responses={401: error_response, 409: error_response, 422: error_response},
)
@handle_exceptions
def create_pipeline_run(
    workspace_name_or_id: Union[str, UUID],
    pipeline_run: PipelineRunRequest,
    _: AuthContext = Security(authorize),
) -> PipelineRunResponse:
    """Create a pipeline run inside the workspace addressed by the URL.

    Args:
        workspace_name_or_id: Name or ID of the workspace from the URL path.
        pipeline_run: The pipeline run request to persist.

    Returns:
        The newly created pipeline run.

    Raises:
        IllegalOperationError: If the pipeline run request targets a
            workspace other than the one addressed by this endpoint.
    """
    scope_workspace = zen_store().get_workspace(workspace_name_or_id)
    if pipeline_run.workspace == scope_workspace.id:
        return verify_permissions_and_create_entity(
            request_model=pipeline_run,
            resource_type=ResourceType.PIPELINE_RUN,
            create_method=zen_store().create_run,
        )
    # The request body points at a different workspace than the URL.
    raise IllegalOperationError(
        "Creating pipeline runs outside of the workspace scope "
        f"of this endpoint `{workspace_name_or_id}` is "
        f"not supported."
    )
create_run_metadata(workspace_name_or_id, run_metadata, auth_context=Security(oauth2_authentication))
Creates run metadata.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
workspace_name_or_id |
Union[str, uuid.UUID] |
Name or ID of the workspace. |
required |
run_metadata |
RunMetadataRequest |
The run metadata to create. |
required |
auth_context |
AuthContext |
Authentication context. |
Security(oauth2_authentication) |
Returns:
Type | Description |
---|---|
None |
None; the created run metadata is not returned to the caller. |
Exceptions:
Type | Description |
---|---|
IllegalOperationError |
If the workspace or user specified in the run metadata does not match the current workspace or authenticated user. |
RuntimeError |
If the resource type is not supported. |
Source code in zenml/zen_server/routers/workspaces_endpoints.py
@router.post(
    WORKSPACES + "/{workspace_name_or_id}" + RUN_METADATA,
    responses={401: error_response, 409: error_response, 422: error_response},
)
@handle_exceptions
def create_run_metadata(
    workspace_name_or_id: Union[str, UUID],
    run_metadata: RunMetadataRequest,
    auth_context: AuthContext = Security(authorize),
) -> None:
    """Creates run metadata.

    The caller must hold UPDATE permission on every resource the metadata is
    attached to, plus CREATE permission on run metadata.

    Args:
        workspace_name_or_id: Name or ID of the workspace.
        run_metadata: The run metadata to create.
        auth_context: Authentication context.

    Returns:
        None; the created run metadata is not returned to the caller.

    Raises:
        IllegalOperationError: If the workspace or user specified in the run
            metadata does not match the current workspace or authenticated
            user.
        RuntimeError: If the resource type is not supported.
    """
    # Resolve the workspace from the URL path, not from the request body.
    # Resolving it from `run_metadata.workspace` would make the scope check
    # below compare the request's workspace with itself, so it could never
    # fail.
    workspace = zen_store().get_workspace(workspace_name_or_id)
    if run_metadata.workspace != workspace.id:
        raise IllegalOperationError(
            "Creating run metadata outside of the workspace scope "
            f"of this endpoint `{workspace_name_or_id}` is "
            f"not supported."
        )
    if run_metadata.user != auth_context.user.id:
        raise IllegalOperationError(
            "Creating run metadata for a user other than yourself "
            "is not supported."
        )

    # Collect every target resource so UPDATE permission can be checked on
    # all of them in a single batch call.
    verify_models: List[Any] = []
    for resource in run_metadata.resources:
        if resource.type == MetadataResourceTypes.PIPELINE_RUN:
            verify_models.append(zen_store().get_run(resource.id))
        elif resource.type == MetadataResourceTypes.STEP_RUN:
            verify_models.append(zen_store().get_run_step(resource.id))
        elif resource.type == MetadataResourceTypes.ARTIFACT_VERSION:
            verify_models.append(zen_store().get_artifact_version(resource.id))
        elif resource.type == MetadataResourceTypes.MODEL_VERSION:
            verify_models.append(zen_store().get_model_version(resource.id))
        else:
            raise RuntimeError(f"Unknown resource type: {resource.type}")

    batch_verify_permissions_for_models(
        models=verify_models,
        action=Action.UPDATE,
    )
    verify_permission(
        resource_type=ResourceType.RUN_METADATA, action=Action.CREATE
    )
    zen_store().create_run_metadata(run_metadata)
    return None
create_run_template(workspace_name_or_id, run_template, _=Security(oauth2_authentication))
Create a run template.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
workspace_name_or_id |
Union[str, uuid.UUID] |
Name or ID of the workspace. |
required |
run_template |
RunTemplateRequest |
Run template to create. |
required |
Returns:
Type | Description |
---|---|
RunTemplateResponse |
The created run template. |
Exceptions:
Type | Description |
---|---|
IllegalOperationError |
If the workspace specified in the run template does not match the current workspace. |
Source code in zenml/zen_server/routers/workspaces_endpoints.py
@router.post(
    WORKSPACES + "/{workspace_name_or_id}" + RUN_TEMPLATES,
    # Declared explicitly, like the sibling create endpoints, so the OpenAPI
    # schema and response serialization are pinned to the template model.
    response_model=RunTemplateResponse,
    responses={401: error_response, 409: error_response, 422: error_response},
)
@handle_exceptions
def create_run_template(
    workspace_name_or_id: Union[str, UUID],
    run_template: RunTemplateRequest,
    _: AuthContext = Security(authorize),
) -> RunTemplateResponse:
    """Create a run template.

    Args:
        workspace_name_or_id: Name or ID of the workspace.
        run_template: Run template to create.

    Returns:
        The created run template.

    Raises:
        IllegalOperationError: If the workspace specified in the
            run template does not match the current workspace.
    """
    workspace = zen_store().get_workspace(workspace_name_or_id)
    if run_template.workspace != workspace.id:
        raise IllegalOperationError(
            "Creating run templates outside of the workspace scope "
            f"of this endpoint `{workspace_name_or_id}` is "
            f"not supported."
        )

    return verify_permissions_and_create_entity(
        request_model=run_template,
        resource_type=ResourceType.RUN_TEMPLATE,
        create_method=zen_store().create_run_template,
    )
create_schedule(workspace_name_or_id, schedule, auth_context=Security(oauth2_authentication))
Creates a schedule.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
workspace_name_or_id |
Union[str, uuid.UUID] |
Name or ID of the workspace. |
required |
schedule |
ScheduleRequest |
Schedule to create. |
required |
auth_context |
AuthContext |
Authentication context. |
Security(oauth2_authentication) |
Returns:
Type | Description |
---|---|
ScheduleResponse |
The created schedule. |
Exceptions:
Type | Description |
---|---|
IllegalOperationError |
If the workspace or user specified in the schedule does not match the current workspace or authenticated user. |
Source code in zenml/zen_server/routers/workspaces_endpoints.py
@router.post(
    WORKSPACES + "/{workspace_name_or_id}" + SCHEDULES,
    response_model=ScheduleResponse,
    responses={401: error_response, 409: error_response, 422: error_response},
)
@handle_exceptions
def create_schedule(
    workspace_name_or_id: Union[str, UUID],
    schedule: ScheduleRequest,
    auth_context: AuthContext = Security(authorize),
) -> ScheduleResponse:
    """Creates a schedule.

    NOTE(review): unlike most sibling create endpoints, this one does not go
    through `verify_permissions_and_create_entity`, so no RBAC CREATE check
    is performed here. Confirm that is intended.

    Args:
        workspace_name_or_id: Name or ID of the workspace.
        schedule: Schedule to create.
        auth_context: Authentication context.

    Returns:
        The created schedule.

    Raises:
        IllegalOperationError: If the workspace or user specified in the
            schedule does not match the current workspace or authenticated
            user.
    """
    workspace = zen_store().get_workspace(workspace_name_or_id)
    # The messages below previously said "pipeline runs" (copy-paste from
    # the pipeline run endpoint); they now name the resource being created.
    if schedule.workspace != workspace.id:
        raise IllegalOperationError(
            "Creating schedules outside of the workspace scope "
            f"of this endpoint `{workspace_name_or_id}` is "
            f"not supported."
        )
    if schedule.user != auth_context.user.id:
        raise IllegalOperationError(
            "Creating schedules for a user other than yourself "
            "is not supported."
        )
    return zen_store().create_schedule(schedule=schedule)
create_secret(workspace_name_or_id, secret, _=Security(oauth2_authentication))
Creates a secret.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
workspace_name_or_id |
Union[str, uuid.UUID] |
Name or ID of the workspace. |
required |
secret |
SecretRequest |
Secret to create. |
required |
Returns:
Type | Description |
---|---|
SecretResponse |
The created secret. |
Exceptions:
Type | Description |
---|---|
IllegalOperationError |
If the workspace specified in the secret does not match the current workspace. |
Source code in zenml/zen_server/routers/workspaces_endpoints.py
@router.post(
    WORKSPACES + "/{workspace_name_or_id}" + SECRETS,
    response_model=SecretResponse,
    responses={401: error_response, 409: error_response, 422: error_response},
)
@handle_exceptions
def create_secret(
    workspace_name_or_id: Union[str, UUID],
    secret: SecretRequest,
    _: AuthContext = Security(authorize),
) -> SecretResponse:
    """Store a secret inside the workspace addressed by the URL.

    Args:
        workspace_name_or_id: Name or ID of the workspace from the URL path.
        secret: The secret request to persist.

    Returns:
        The newly created secret.

    Raises:
        IllegalOperationError: If the secret request targets a workspace
            other than the one addressed by this endpoint.
    """
    scope_workspace = zen_store().get_workspace(workspace_name_or_id)
    if secret.workspace == scope_workspace.id:
        return verify_permissions_and_create_entity(
            request_model=secret,
            resource_type=ResourceType.SECRET,
            create_method=zen_store().create_secret,
        )
    # The request body points at a different workspace than the URL.
    raise IllegalOperationError(
        "Creating a secret outside of the workspace scope "
        f"of this endpoint `{workspace_name_or_id}` is "
        f"not supported."
    )
create_service(workspace_name_or_id, service, _=Security(oauth2_authentication))
Create a new service.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
workspace_name_or_id |
Union[str, uuid.UUID] |
Name or ID of the workspace. |
required |
service |
ServiceRequest |
The service to create. |
required |
Returns:
Type | Description |
---|---|
ServiceResponse |
The created service. |
Exceptions:
Type | Description |
---|---|
IllegalOperationError |
If the workspace or user specified in the service does not match the current workspace or authenticated user. |
Source code in zenml/zen_server/routers/workspaces_endpoints.py
@router.post(
    WORKSPACES + "/{workspace_name_or_id}" + SERVICES,
    response_model=ServiceResponse,
    responses={401: error_response, 409: error_response, 422: error_response},
)
@handle_exceptions
def create_service(
    workspace_name_or_id: Union[str, UUID],
    service: ServiceRequest,
    _: AuthContext = Security(authorize),
) -> ServiceResponse:
    """Create a new service.

    Args:
        workspace_name_or_id: Name or ID of the workspace.
        service: The service to create.

    Returns:
        The created service.

    Raises:
        IllegalOperationError: If the workspace or user specified in the
            service does not match the current workspace or authenticated
            user.
    """
    workspace = zen_store().get_workspace(workspace_name_or_id)
    if service.workspace != workspace.id:
        # Previously said "Creating models" (copy-paste from create_model).
        raise IllegalOperationError(
            "Creating services outside of the workspace scope "
            f"of this endpoint `{workspace_name_or_id}` is "
            f"not supported."
        )

    return verify_permissions_and_create_entity(
        request_model=service,
        resource_type=ResourceType.SERVICE,
        create_method=zen_store().create_service,
    )
create_service_connector(workspace_name_or_id, connector, _=Security(oauth2_authentication))
Creates a service connector.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
workspace_name_or_id |
Union[str, uuid.UUID] |
Name or ID of the workspace. |
required |
connector |
ServiceConnectorRequest |
Service connector to register. |
required |
Returns:
Type | Description |
---|---|
ServiceConnectorResponse |
The created service connector. |
Exceptions:
Type | Description |
---|---|
IllegalOperationError |
If the workspace or user specified in the service connector does not match the current workspace or authenticated user. |
Source code in zenml/zen_server/routers/workspaces_endpoints.py
@router.post(
    WORKSPACES + "/{workspace_name_or_id}" + SERVICE_CONNECTORS,
    response_model=ServiceConnectorResponse,
    responses={401: error_response, 409: error_response, 422: error_response},
)
@handle_exceptions
def create_service_connector(
    workspace_name_or_id: Union[str, UUID],
    connector: ServiceConnectorRequest,
    _: AuthContext = Security(authorize),
) -> ServiceConnectorResponse:
    """Register a service connector in the workspace addressed by the URL.

    Args:
        workspace_name_or_id: Name or ID of the workspace from the URL path.
        connector: The service connector request to persist.

    Returns:
        The newly created service connector.

    Raises:
        IllegalOperationError: If the workspace or user specified in the
            service connector does not match the current workspace or
            authenticated user.
    """
    scope_workspace = zen_store().get_workspace(workspace_name_or_id)
    if connector.workspace == scope_workspace.id:
        return verify_permissions_and_create_entity(
            request_model=connector,
            resource_type=ResourceType.SERVICE_CONNECTOR,
            create_method=zen_store().create_service_connector,
        )
    # The request body points at a different workspace than the URL.
    raise IllegalOperationError(
        "Creating connectors outside of the workspace scope "
        f"of this endpoint `{workspace_name_or_id}` is "
        f"not supported."
    )
create_stack(workspace_name_or_id, stack, auth_context=Security(oauth2_authentication))
Creates a stack for a particular workspace.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
workspace_name_or_id |
Union[str, uuid.UUID] |
Name or ID of the workspace. |
required |
stack |
StackRequest |
Stack to register. |
required |
auth_context |
AuthContext |
Authentication context. |
Security(oauth2_authentication) |
Returns:
Type | Description |
---|---|
StackResponse |
The created stack. |
Source code in zenml/zen_server/routers/workspaces_endpoints.py
@router.post(
    WORKSPACES + "/{workspace_name_or_id}" + STACKS,
    response_model=StackResponse,
    responses={401: error_response, 409: error_response, 422: error_response},
)
@handle_exceptions
def create_stack(
    workspace_name_or_id: Union[str, UUID],
    stack: StackRequest,
    auth_context: AuthContext = Security(authorize),
) -> StackResponse:
    """Create a stack in the workspace addressed by the URL.

    Service connectors and components referenced by UUID must be readable by
    the caller. Any non-UUID entry requires CREATE permission for that
    resource type. The stack's user and workspace are overwritten with the
    authenticated user and the workspace from the URL before creation.

    Args:
        workspace_name_or_id: Name or ID of the workspace from the URL path.
        stack: The stack request to persist.
        auth_context: Authentication context; its user becomes the stack's
            owner.

    Returns:
        The newly created stack.
    """
    scope_workspace = zen_store().get_workspace(workspace_name_or_id)

    # Service connectors: read-check existing ones, note if any must be made.
    must_create_connector = False
    for connector_ref in stack.service_connectors:
        if not isinstance(connector_ref, UUID):
            must_create_connector = True
            continue
        existing_connector = zen_store().get_service_connector(
            connector_ref, hydrate=False
        )
        verify_permission_for_model(
            model=existing_connector, action=Action.READ
        )
    if must_create_connector:
        verify_permission(
            resource_type=ResourceType.SERVICE_CONNECTOR, action=Action.CREATE
        )

    # Stack components: same treatment, across every component type.
    must_create_component = False
    for component_refs in stack.components.values():
        for component_ref in component_refs:
            if not isinstance(component_ref, UUID):
                must_create_component = True
                continue
            existing_component = zen_store().get_stack_component(
                component_ref, hydrate=False
            )
            verify_permission_for_model(
                model=existing_component, action=Action.READ
            )
    if must_create_component:
        verify_permission(
            resource_type=ResourceType.STACK_COMPONENT,
            action=Action.CREATE,
        )

    # Finally, the stack itself.
    verify_permission(resource_type=ResourceType.STACK, action=Action.CREATE)

    stack.user = auth_context.user.id
    stack.workspace = scope_workspace.id
    return zen_store().create_stack(stack)
create_stack_component(workspace_name_or_id, component, _=Security(oauth2_authentication))
Creates a stack component.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
workspace_name_or_id |
Union[str, uuid.UUID] |
Name or ID of the workspace. |
required |
component |
ComponentRequest |
Stack component to register. |
required |
Returns:
Type | Description |
---|---|
ComponentResponse |
The created stack component. |
Exceptions:
Type | Description |
---|---|
IllegalOperationError |
If the workspace specified in the stack component does not match the current workspace. |
Source code in zenml/zen_server/routers/workspaces_endpoints.py
@router.post(
    WORKSPACES + "/{workspace_name_or_id}" + STACK_COMPONENTS,
    response_model=ComponentResponse,
    responses={401: error_response, 409: error_response, 422: error_response},
)
@handle_exceptions
def create_stack_component(
    workspace_name_or_id: Union[str, UUID],
    component: ComponentRequest,
    _: AuthContext = Security(authorize),
) -> ComponentResponse:
    """Register a stack component in the workspace addressed by the URL.

    The component configuration is validated against its flavor before the
    component is created. Custom flavors are allowed to fail import on the
    server.

    Args:
        workspace_name_or_id: Name or ID of the workspace from the URL path.
        component: The stack component request to persist.

    Returns:
        The newly created stack component.

    Raises:
        IllegalOperationError: If the component request targets a workspace
            other than the one addressed by this endpoint.
    """
    scope_workspace = zen_store().get_workspace(workspace_name_or_id)
    if component.workspace != scope_workspace.id:
        raise IllegalOperationError(
            "Creating components outside of the workspace scope "
            f"of this endpoint `{workspace_name_or_id}` is "
            f"not supported."
        )

    # A component linked to a service connector requires read access to it.
    if component.connector:
        linked_connector = zen_store().get_service_connector(
            component.connector
        )
        verify_permission_for_model(linked_connector, action=Action.READ)

    # Function-scope import, kept as in the original (presumably to avoid an
    # import cycle at module load -- confirm).
    from zenml.stack.utils import validate_stack_component_config

    validate_stack_component_config(
        configuration_dict=component.configuration,
        flavor=component.flavor,
        component_type=component.type,
        zen_store=zen_store(),
        # We allow custom flavors to fail import on the server side.
        validate_custom_flavors=False,
    )

    return verify_permissions_and_create_entity(
        request_model=component,
        resource_type=ResourceType.STACK_COMPONENT,
        create_method=zen_store().create_stack_component,
    )
create_workspace(workspace_request, _=Security(oauth2_authentication))
Creates a workspace based on the requestBody.
noqa: DAR401
Parameters:
Name | Type | Description | Default |
---|---|---|---|
workspace_request |
WorkspaceRequest |
Workspace to create. |
required |
Returns:
Type | Description |
---|---|
WorkspaceResponse |
The created workspace. |
Source code in zenml/zen_server/routers/workspaces_endpoints.py
@router.post(
    WORKSPACES,
    # Declared explicitly, matching `get_workspace`, so the OpenAPI schema
    # and response serialization are pinned to the workspace model.
    response_model=WorkspaceResponse,
    responses={401: error_response, 409: error_response, 422: error_response},
)
@handle_exceptions
def create_workspace(
    workspace_request: WorkspaceRequest,
    _: AuthContext = Security(authorize),
) -> WorkspaceResponse:
    """Creates a workspace based on the requestBody.

    # noqa: DAR401

    Args:
        workspace_request: Workspace to create.

    Returns:
        The created workspace, passed through `dehydrate_response_model`.
    """
    workspace = zen_store().create_workspace(workspace_request)
    return dehydrate_response_model(workspace)
delete_workspace(workspace_name_or_id, _=Security(oauth2_authentication))
Deletes a workspace.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
workspace_name_or_id |
Union[str, uuid.UUID] |
Name or ID of the workspace. |
required |
Source code in zenml/zen_server/routers/workspaces_endpoints.py
@router.delete(
    WORKSPACES + "/{workspace_name_or_id}",
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def delete_workspace(
    workspace_name_or_id: Union[str, UUID],
    _: AuthContext = Security(authorize),
) -> None:
    """Remove the workspace identified by name or ID.

    Args:
        workspace_name_or_id: Name or ID of the workspace to delete.
    """
    store = zen_store()
    store.delete_workspace(workspace_name_or_id)
get_or_create_pipeline_run(workspace_name_or_id, pipeline_run, auth_context=Security(oauth2_authentication))
Get or create a pipeline run.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
workspace_name_or_id |
Union[str, uuid.UUID] |
Name or ID of the workspace. |
required |
pipeline_run |
PipelineRunRequest |
Pipeline run to create. |
required |
auth_context |
AuthContext |
Authentication context. |
Security(oauth2_authentication) |
Returns:
Type | Description |
---|---|
Tuple[zenml.models.v2.core.pipeline_run.PipelineRunResponse, bool] |
The pipeline run and a boolean indicating whether the run was created or not. |
Exceptions:
Type | Description |
---|---|
IllegalOperationError |
If the workspace or user specified in the pipeline run does not match the current workspace or authenticated user. |
Source code in zenml/zen_server/routers/workspaces_endpoints.py
@router.post(
    WORKSPACES + "/{workspace_name_or_id}" + RUNS + GET_OR_CREATE,
    response_model=Tuple[PipelineRunResponse, bool],
    responses={401: error_response, 409: error_response, 422: error_response},
)
@handle_exceptions
def get_or_create_pipeline_run(
    workspace_name_or_id: Union[str, UUID],
    pipeline_run: PipelineRunRequest,
    auth_context: AuthContext = Security(authorize),
) -> Tuple[PipelineRunResponse, bool]:
    """Fetch a matching pipeline run, or create it if none exists.

    Creation requires CREATE permission and an entitlement check, and usage
    is reported for the new run. Returning an existing run requires READ
    permission on it.

    Args:
        workspace_name_or_id: Name or ID of the workspace from the URL path.
        pipeline_run: The pipeline run request.
        auth_context: Authentication context; the request's user must match
            its user.

    Returns:
        A `(run, created)` pair, where `created` tells whether the run was
        newly created.

    Raises:
        IllegalOperationError: If the workspace or user specified in the
            pipeline run does not match the current workspace or
            authenticated user.
    """
    scope_workspace = zen_store().get_workspace(workspace_name_or_id)
    if pipeline_run.workspace != scope_workspace.id:
        raise IllegalOperationError(
            "Creating pipeline runs outside of the workspace scope "
            f"of this endpoint `{workspace_name_or_id}` is "
            f"not supported."
        )
    if pipeline_run.user != auth_context.user.id:
        raise IllegalOperationError(
            "Creating pipeline runs for a user other than yourself "
            "is not supported."
        )

    def _authorize_run_creation() -> None:
        # Passed to the store as `pre_creation_hook`; presumably invoked only
        # when a new run is about to be created (confirm in the store).
        verify_permission(
            resource_type=ResourceType.PIPELINE_RUN, action=Action.CREATE
        )
        check_entitlement(resource_type=ResourceType.PIPELINE_RUN)

    run, was_created = zen_store().get_or_create_run(
        pipeline_run=pipeline_run,
        pre_creation_hook=_authorize_run_creation,
    )

    if not was_created:
        # An existing run was returned; the caller still needs read access.
        verify_permission_for_model(run, action=Action.READ)
        return run, was_created

    report_usage(resource_type=ResourceType.PIPELINE_RUN, resource_id=run.id)
    return run, was_created
get_workspace(workspace_name_or_id, hydrate=True, _=Security(oauth2_authentication))
Get a workspace for given name.
noqa: DAR401
Parameters:
Name | Type | Description | Default |
---|---|---|---|
workspace_name_or_id |
Union[str, uuid.UUID] |
Name or ID of the workspace. |
required |
hydrate |
bool |
Flag deciding whether to hydrate the output model(s) by including metadata fields in the response. |
True |
Returns:
Type | Description |
---|---|
WorkspaceResponse |
The requested workspace. |
Source code in zenml/zen_server/routers/workspaces_endpoints.py
@router.get(
    WORKSPACES + "/{workspace_name_or_id}",
    response_model=WorkspaceResponse,
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def get_workspace(
    workspace_name_or_id: Union[str, UUID],
    hydrate: bool = True,
    _: AuthContext = Security(authorize),
) -> WorkspaceResponse:
    """Fetch a single workspace by name or ID.

    # noqa: DAR401

    Args:
        workspace_name_or_id: Name or ID of the workspace.
        hydrate: Whether to include the metadata fields in the returned
            model.

    Returns:
        The requested workspace, passed through `dehydrate_response_model`.
    """
    return dehydrate_response_model(
        zen_store().get_workspace(workspace_name_or_id, hydrate=hydrate)
    )
get_workspace_statistics(workspace_name_or_id, auth_context=Security(oauth2_authentication))
Gets statistics of a workspace.
noqa: DAR401
Parameters:
Name | Type | Description | Default |
---|---|---|---|
workspace_name_or_id |
Union[str, uuid.UUID] |
Name or ID of the workspace to get statistics for. |
required |
auth_context |
AuthContext |
Authentication context. |
Security(oauth2_authentication) |
Returns:
Type | Description |
---|---|
Dict[str, int] |
Number of stacks, stack components, pipelines and pipeline runs within the workspace. |
Source code in zenml/zen_server/routers/workspaces_endpoints.py
@router.get(
    WORKSPACES + "/{workspace_name_or_id}" + STATISTICS,
    response_model=Dict[str, int],
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def get_workspace_statistics(
    workspace_name_or_id: Union[str, UUID],
    auth_context: AuthContext = Security(authorize),
) -> Dict[str, int]:
    """Gets statistics of a workspace.

    Each count is computed with a filter scoped to the workspace and
    restricted via `configure_rbac` to the resource IDs returned by
    `get_allowed_resource_ids` for the authenticated user.

    # noqa: DAR401

    Args:
        workspace_name_or_id: Name or ID of the workspace to get statistics for.
        auth_context: Authentication context.

    Returns:
        A dictionary with the number of stacks (`stacks`), stack components
        (`components`), pipelines (`pipelines`) and pipeline runs (`runs`)
        in the workspace.
    """
    workspace = zen_store().get_workspace(workspace_name_or_id)
    user_id = auth_context.user.id

    component_filter = ComponentFilter(workspace_id=workspace.id)
    component_filter.configure_rbac(
        authenticated_user_id=user_id,
        id=get_allowed_resource_ids(
            resource_type=ResourceType.STACK_COMPONENT
        ),
    )

    stack_filter = StackFilter(workspace_id=workspace.id)
    stack_filter.configure_rbac(
        authenticated_user_id=user_id,
        id=get_allowed_resource_ids(resource_type=ResourceType.STACK),
    )

    run_filter = PipelineRunFilter(workspace_id=workspace.id)
    run_filter.configure_rbac(
        authenticated_user_id=user_id,
        id=get_allowed_resource_ids(resource_type=ResourceType.PIPELINE_RUN),
    )

    pipeline_filter = PipelineFilter(workspace_id=workspace.id)
    pipeline_filter.configure_rbac(
        authenticated_user_id=user_id,
        id=get_allowed_resource_ids(resource_type=ResourceType.PIPELINE),
    )

    return {
        "stacks": zen_store().count_stacks(filter_model=stack_filter),
        "components": zen_store().count_stack_components(
            filter_model=component_filter
        ),
        "pipelines": zen_store().count_pipelines(filter_model=pipeline_filter),
        "runs": zen_store().count_runs(filter_model=run_filter),
    }
list_runs(workspace_name_or_id, runs_filter_model=Depends(init_cls_and_handle_errors), hydrate=False, _=Security(oauth2_authentication))
Get pipeline runs according to query filters.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
workspace_name_or_id |
Union[str, uuid.UUID] |
Name or ID of the workspace. |
required |
runs_filter_model |
PipelineRunFilter |
Filter model used for pagination, sorting, filtering. |
Depends(init_cls_and_handle_errors) |
hydrate |
bool |
Flag deciding whether to hydrate the output model(s) by including metadata fields in the response. |
False |
Returns:
Type | Description |
---|---|
Page[PipelineRunResponse] |
The pipeline runs according to query filters. |
Source code in zenml/zen_server/routers/workspaces_endpoints.py
@router.get(
    WORKSPACES + "/{workspace_name_or_id}" + RUNS,
    response_model=Page[PipelineRunResponse],
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def list_runs(
    workspace_name_or_id: Union[str, UUID],
    runs_filter_model: PipelineRunFilter = Depends(
        make_dependable(PipelineRunFilter)
    ),
    hydrate: bool = False,
    _: AuthContext = Security(authorize),
) -> Page[PipelineRunResponse]:
    """List the pipeline runs of a workspace that match the query filters.

    Args:
        workspace_name_or_id: Name or ID of the workspace.
        runs_filter_model: Filter model used for pagination, sorting,
            filtering.
        hydrate: Flag deciding whether to hydrate the output model(s)
            by including metadata fields in the response.

    Returns:
        The pipeline runs according to query filters.
    """
    # Restrict the filter to the resolved workspace before listing.
    scope_id = zen_store().get_workspace(workspace_name_or_id).id
    runs_filter_model.set_scope_workspace(scope_id)

    store = zen_store()
    return verify_permissions_and_list_entities(
        filter_model=runs_filter_model,
        resource_type=ResourceType.PIPELINE_RUN,
        list_method=store.list_runs,
        hydrate=hydrate,
    )
list_service_connector_resources(workspace_name_or_id, connector_type=None, resource_type=None, resource_id=None, auth_context=Security(oauth2_authentication))
List resources that can be accessed by service connectors.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
workspace_name_or_id |
Union[str, uuid.UUID] |
Name or ID of the workspace. |
required |
connector_type |
Optional[str] |
the service connector type identifier to filter by. |
None |
resource_type |
Optional[str] |
the resource type identifier to filter by. |
None |
resource_id |
Optional[str] |
the resource identifier to filter by. |
None |
auth_context |
AuthContext |
Authentication context. |
Security(oauth2_authentication) |
Returns:
Type | Description |
---|---|
List[zenml.models.v2.misc.service_connector_type.ServiceConnectorResourcesModel] |
The matching list of resources that available service connectors have access to. |
Source code in zenml/zen_server/routers/workspaces_endpoints.py
@router.get(
    WORKSPACES
    + "/{workspace_name_or_id}"
    + SERVICE_CONNECTORS
    + SERVICE_CONNECTOR_RESOURCES,
    response_model=List[ServiceConnectorResourcesModel],
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def list_service_connector_resources(
    workspace_name_or_id: Union[str, UUID],
    connector_type: Optional[str] = None,
    resource_type: Optional[str] = None,
    resource_id: Optional[str] = None,
    auth_context: AuthContext = Security(authorize),
) -> List[ServiceConnectorResourcesModel]:
    """List the resources reachable through the workspace's service connectors.

    Args:
        workspace_name_or_id: Name or ID of the workspace.
        connector_type: the service connector type identifier to filter by.
        resource_type: the resource type identifier to filter by.
        resource_id: the resource identifier to filter by.
        auth_context: Authentication context.

    Returns:
        The matching list of resources that available service
        connectors have access to.
    """
    workspace_id = zen_store().get_workspace(workspace_name_or_id).id

    # Build a connector filter scoped to the workspace and limited to the
    # connector IDs the authenticated user is allowed to access.
    connector_filter = ServiceConnectorFilter(
        connector_type=connector_type,
        resource_type=resource_type,
    )
    connector_filter.set_scope_workspace(workspace_id)
    permitted_ids = get_allowed_resource_ids(
        resource_type=ResourceType.SERVICE_CONNECTOR
    )
    connector_filter.configure_rbac(
        authenticated_user_id=auth_context.user.id,
        id=permitted_ids,
    )

    store = zen_store()
    return store.list_service_connector_resources(
        workspace_name_or_id=workspace_name_or_id,
        connector_type=connector_type,
        resource_type=resource_type,
        resource_id=resource_id,
        filter_model=connector_filter,
    )
list_workspace_builds(workspace_name_or_id, build_filter_model=Depends(init_cls_and_handle_errors), hydrate=False, _=Security(oauth2_authentication))
Gets builds defined for a specific workspace.
noqa: DAR401
Parameters:
Name | Type | Description | Default |
---|---|---|---|
workspace_name_or_id |
Union[str, uuid.UUID] |
Name or ID of the workspace. |
required |
build_filter_model |
PipelineBuildFilter |
Filter model used for pagination, sorting, filtering. |
Depends(init_cls_and_handle_errors) |
hydrate |
bool |
Flag deciding whether to hydrate the output model(s) by including metadata fields in the response. |
False |
Returns:
Type | Description |
---|---|
Page[PipelineBuildResponse] |
All builds within the workspace. |
Source code in zenml/zen_server/routers/workspaces_endpoints.py
@router.get(
    WORKSPACES + "/{workspace_name_or_id}" + PIPELINE_BUILDS,
    response_model=Page[PipelineBuildResponse],
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def list_workspace_builds(
    workspace_name_or_id: Union[str, UUID],
    build_filter_model: PipelineBuildFilter = Depends(
        make_dependable(PipelineBuildFilter)
    ),
    hydrate: bool = False,
    _: AuthContext = Security(authorize),
) -> Page[PipelineBuildResponse]:
    """List the pipeline builds that belong to a workspace.

    # noqa: DAR401

    Args:
        workspace_name_or_id: Name or ID of the workspace.
        build_filter_model: Filter model used for pagination, sorting,
            filtering.
        hydrate: Flag deciding whether to hydrate the output model(s)
            by including metadata fields in the response.

    Returns:
        All builds within the workspace.
    """
    target_workspace = zen_store().get_workspace(workspace_name_or_id)
    build_filter_model.set_scope_workspace(target_workspace.id)

    store = zen_store()
    return verify_permissions_and_list_entities(
        filter_model=build_filter_model,
        resource_type=ResourceType.PIPELINE_BUILD,
        list_method=store.list_builds,
        hydrate=hydrate,
    )
list_workspace_code_repositories(workspace_name_or_id, filter_model=Depends(init_cls_and_handle_errors), hydrate=False, _=Security(oauth2_authentication))
Gets code repositories defined for a specific workspace.
noqa: DAR401
Parameters:
Name | Type | Description | Default |
---|---|---|---|
workspace_name_or_id |
Union[str, uuid.UUID] |
Name or ID of the workspace. |
required |
filter_model |
CodeRepositoryFilter |
Filter model used for pagination, sorting, filtering. |
Depends(init_cls_and_handle_errors) |
hydrate |
bool |
Flag deciding whether to hydrate the output model(s) by including metadata fields in the response. |
False |
Returns:
Type | Description |
---|---|
Page[CodeRepositoryResponse] |
All code repositories within the workspace. |
Source code in zenml/zen_server/routers/workspaces_endpoints.py
@router.get(
    WORKSPACES + "/{workspace_name_or_id}" + CODE_REPOSITORIES,
    response_model=Page[CodeRepositoryResponse],
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def list_workspace_code_repositories(
    workspace_name_or_id: Union[str, UUID],
    filter_model: CodeRepositoryFilter = Depends(
        make_dependable(CodeRepositoryFilter)
    ),
    hydrate: bool = False,
    _: AuthContext = Security(authorize),
) -> Page[CodeRepositoryResponse]:
    """List the code repositories registered in a workspace.

    # noqa: DAR401

    Args:
        workspace_name_or_id: Name or ID of the workspace.
        filter_model: Filter model used for pagination, sorting,
            filtering.
        hydrate: Flag deciding whether to hydrate the output model(s)
            by including metadata fields in the response.

    Returns:
        All code repositories within the workspace.
    """
    scope_id = zen_store().get_workspace(workspace_name_or_id).id
    filter_model.set_scope_workspace(scope_id)

    store = zen_store()
    return verify_permissions_and_list_entities(
        filter_model=filter_model,
        resource_type=ResourceType.CODE_REPOSITORY,
        list_method=store.list_code_repositories,
        hydrate=hydrate,
    )
list_workspace_deployments(workspace_name_or_id, deployment_filter_model=Depends(init_cls_and_handle_errors), hydrate=False, _=Security(oauth2_authentication))
Gets deployments defined for a specific workspace.
noqa: DAR401
Parameters:
Name | Type | Description | Default |
---|---|---|---|
workspace_name_or_id |
Union[str, uuid.UUID] |
Name or ID of the workspace. |
required |
deployment_filter_model |
PipelineDeploymentFilter |
Filter model used for pagination, sorting, filtering. |
Depends(init_cls_and_handle_errors) |
hydrate |
bool |
Flag deciding whether to hydrate the output model(s) by including metadata fields in the response. |
False |
Returns:
Type | Description |
---|---|
Page[PipelineDeploymentResponse] |
All deployments within the workspace. |
Source code in zenml/zen_server/routers/workspaces_endpoints.py
@router.get(
    WORKSPACES + "/{workspace_name_or_id}" + PIPELINE_DEPLOYMENTS,
    response_model=Page[PipelineDeploymentResponse],
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def list_workspace_deployments(
    workspace_name_or_id: Union[str, UUID],
    deployment_filter_model: PipelineDeploymentFilter = Depends(
        make_dependable(PipelineDeploymentFilter)
    ),
    hydrate: bool = False,
    _: AuthContext = Security(authorize),
) -> Page[PipelineDeploymentResponse]:
    """List the pipeline deployments that belong to a workspace.

    # noqa: DAR401

    Args:
        workspace_name_or_id: Name or ID of the workspace.
        deployment_filter_model: Filter model used for pagination, sorting,
            filtering.
        hydrate: Flag deciding whether to hydrate the output model(s)
            by including metadata fields in the response.

    Returns:
        All deployments within the workspace.
    """
    target_workspace = zen_store().get_workspace(workspace_name_or_id)
    deployment_filter_model.set_scope_workspace(target_workspace.id)

    store = zen_store()
    return verify_permissions_and_list_entities(
        filter_model=deployment_filter_model,
        resource_type=ResourceType.PIPELINE_DEPLOYMENT,
        list_method=store.list_deployments,
        hydrate=hydrate,
    )
list_workspace_pipelines(workspace_name_or_id, pipeline_filter_model=Depends(init_cls_and_handle_errors), hydrate=False, _=Security(oauth2_authentication))
Gets pipelines defined for a specific workspace.
noqa: DAR401
Parameters:
Name | Type | Description | Default |
---|---|---|---|
workspace_name_or_id |
Union[str, uuid.UUID] |
Name or ID of the workspace. |
required |
pipeline_filter_model |
PipelineFilter |
Filter model used for pagination, sorting, filtering. |
Depends(init_cls_and_handle_errors) |
hydrate |
bool |
Flag deciding whether to hydrate the output model(s) by including metadata fields in the response. |
False |
Returns:
Type | Description |
---|---|
Page[PipelineResponse] |
All pipelines within the workspace. |
Source code in zenml/zen_server/routers/workspaces_endpoints.py
@router.get(
    WORKSPACES + "/{workspace_name_or_id}" + PIPELINES,
    response_model=Page[PipelineResponse],
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def list_workspace_pipelines(
    workspace_name_or_id: Union[str, UUID],
    pipeline_filter_model: PipelineFilter = Depends(
        make_dependable(PipelineFilter)
    ),
    hydrate: bool = False,
    _: AuthContext = Security(authorize),
) -> Page[PipelineResponse]:
    """List the pipelines that belong to a workspace.

    # noqa: DAR401

    Args:
        workspace_name_or_id: Name or ID of the workspace.
        pipeline_filter_model: Filter model used for pagination, sorting,
            filtering.
        hydrate: Flag deciding whether to hydrate the output model(s)
            by including metadata fields in the response.

    Returns:
        All pipelines within the workspace.
    """
    scope_id = zen_store().get_workspace(workspace_name_or_id).id
    pipeline_filter_model.set_scope_workspace(scope_id)

    store = zen_store()
    return verify_permissions_and_list_entities(
        filter_model=pipeline_filter_model,
        resource_type=ResourceType.PIPELINE,
        list_method=store.list_pipelines,
        hydrate=hydrate,
    )
list_workspace_run_templates(workspace_name_or_id, filter_model=Depends(init_cls_and_handle_errors), hydrate=False, _=Security(oauth2_authentication))
Get a page of run templates.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
workspace_name_or_id |
Union[str, uuid.UUID] |
Name or ID of the workspace. |
required |
filter_model |
RunTemplateFilter |
Filter model used for pagination, sorting, filtering. |
Depends(init_cls_and_handle_errors) |
hydrate |
bool |
Flag deciding whether to hydrate the output model(s) by including metadata fields in the response. |
False |
Returns:
Type | Description |
---|---|
Page[RunTemplateResponse] |
Page of run templates. |
Source code in zenml/zen_server/routers/workspaces_endpoints.py
@router.get(
    WORKSPACES + "/{workspace_name_or_id}" + RUN_TEMPLATES,
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def list_workspace_run_templates(
    workspace_name_or_id: Union[str, UUID],
    filter_model: RunTemplateFilter = Depends(
        make_dependable(RunTemplateFilter)
    ),
    hydrate: bool = False,
    _: AuthContext = Security(authorize),
) -> Page[RunTemplateResponse]:
    """List one page of the run templates that belong to a workspace.

    Args:
        workspace_name_or_id: Name or ID of the workspace.
        filter_model: Filter model used for pagination, sorting,
            filtering.
        hydrate: Flag deciding whether to hydrate the output model(s)
            by including metadata fields in the response.

    Returns:
        Page of run templates.
    """
    target_workspace = zen_store().get_workspace(workspace_name_or_id)
    filter_model.set_scope_workspace(target_workspace.id)

    store = zen_store()
    return verify_permissions_and_list_entities(
        filter_model=filter_model,
        resource_type=ResourceType.RUN_TEMPLATE,
        list_method=store.list_run_templates,
        hydrate=hydrate,
    )
list_workspace_service_connectors(workspace_name_or_id, connector_filter_model=Depends(init_cls_and_handle_errors), hydrate=False, _=Security(oauth2_authentication))
List service connectors that are part of a specific workspace.
noqa: DAR401
Parameters:
Name | Type | Description | Default |
---|---|---|---|
workspace_name_or_id |
Union[str, uuid.UUID] |
Name or ID of the workspace. |
required |
connector_filter_model |
ServiceConnectorFilter |
Filter model used for pagination, sorting, filtering. |
Depends(init_cls_and_handle_errors) |
hydrate |
bool |
Flag deciding whether to hydrate the output model(s) by including metadata fields in the response. |
False |
Returns:
Type | Description |
---|---|
Page[ServiceConnectorResponse] |
All service connectors part of the specified workspace. |
Source code in zenml/zen_server/routers/workspaces_endpoints.py
@router.get(
    WORKSPACES + "/{workspace_name_or_id}" + SERVICE_CONNECTORS,
    response_model=Page[ServiceConnectorResponse],
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def list_workspace_service_connectors(
    workspace_name_or_id: Union[str, UUID],
    connector_filter_model: ServiceConnectorFilter = Depends(
        make_dependable(ServiceConnectorFilter)
    ),
    hydrate: bool = False,
    _: AuthContext = Security(authorize),
) -> Page[ServiceConnectorResponse]:
    """List the service connectors that belong to a workspace.

    # noqa: DAR401

    Args:
        workspace_name_or_id: Name or ID of the workspace.
        connector_filter_model: Filter model used for pagination, sorting,
            filtering.
        hydrate: Flag deciding whether to hydrate the output model(s)
            by including metadata fields in the response.

    Returns:
        All service connectors part of the specified workspace.
    """
    scope_id = zen_store().get_workspace(workspace_name_or_id).id
    connector_filter_model.set_scope_workspace(scope_id)

    store = zen_store()
    return verify_permissions_and_list_entities(
        filter_model=connector_filter_model,
        resource_type=ResourceType.SERVICE_CONNECTOR,
        list_method=store.list_service_connectors,
        hydrate=hydrate,
    )
list_workspace_stack_components(workspace_name_or_id, component_filter_model=Depends(init_cls_and_handle_errors), hydrate=False, _=Security(oauth2_authentication))
List stack components that are part of a specific workspace.
noqa: DAR401
Parameters:
Name | Type | Description | Default |
---|---|---|---|
workspace_name_or_id |
Union[str, uuid.UUID] |
Name or ID of the workspace. |
required |
component_filter_model |
ComponentFilter |
Filter model used for pagination, sorting, filtering. |
Depends(init_cls_and_handle_errors) |
hydrate |
bool |
Flag deciding whether to hydrate the output model(s) by including metadata fields in the response. |
False |
Returns:
Type | Description |
---|---|
Page[ComponentResponse] |
All stack components part of the specified workspace. |
Source code in zenml/zen_server/routers/workspaces_endpoints.py
@router.get(
    WORKSPACES + "/{workspace_name_or_id}" + STACK_COMPONENTS,
    response_model=Page[ComponentResponse],
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def list_workspace_stack_components(
    workspace_name_or_id: Union[str, UUID],
    component_filter_model: ComponentFilter = Depends(
        make_dependable(ComponentFilter)
    ),
    hydrate: bool = False,
    _: AuthContext = Security(authorize),
) -> Page[ComponentResponse]:
    """List the stack components that belong to a workspace.

    # noqa: DAR401

    Args:
        workspace_name_or_id: Name or ID of the workspace.
        component_filter_model: Filter model used for pagination, sorting,
            filtering.
        hydrate: Flag deciding whether to hydrate the output model(s)
            by including metadata fields in the response.

    Returns:
        All stack components part of the specified workspace.
    """
    target_workspace = zen_store().get_workspace(workspace_name_or_id)
    component_filter_model.set_scope_workspace(target_workspace.id)

    store = zen_store()
    return verify_permissions_and_list_entities(
        filter_model=component_filter_model,
        resource_type=ResourceType.STACK_COMPONENT,
        list_method=store.list_stack_components,
        hydrate=hydrate,
    )
list_workspace_stacks(workspace_name_or_id, stack_filter_model=Depends(init_cls_and_handle_errors), hydrate=False, _=Security(oauth2_authentication))
Get stacks that are part of a specific workspace for the user.
noqa: DAR401
Parameters:
Name | Type | Description | Default |
---|---|---|---|
workspace_name_or_id |
Union[str, uuid.UUID] |
Name or ID of the workspace. |
required |
stack_filter_model |
StackFilter |
Filter model used for pagination, sorting, filtering. |
Depends(init_cls_and_handle_errors) |
hydrate |
bool |
Flag deciding whether to hydrate the output model(s) by including metadata fields in the response. |
False |
Returns:
Type | Description |
---|---|
Page[StackResponse] |
All stacks part of the specified workspace. |
Source code in zenml/zen_server/routers/workspaces_endpoints.py
@router.get(
    WORKSPACES + "/{workspace_name_or_id}" + STACKS,
    response_model=Page[StackResponse],
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def list_workspace_stacks(
    workspace_name_or_id: Union[str, UUID],
    stack_filter_model: StackFilter = Depends(make_dependable(StackFilter)),
    hydrate: bool = False,
    _: AuthContext = Security(authorize),
) -> Page[StackResponse]:
    """List the stacks that belong to a workspace.

    # noqa: DAR401

    Args:
        workspace_name_or_id: Name or ID of the workspace.
        stack_filter_model: Filter model used for pagination, sorting,
            filtering.
        hydrate: Flag deciding whether to hydrate the output model(s)
            by including metadata fields in the response.

    Returns:
        All stacks part of the specified workspace.
    """
    scope_id = zen_store().get_workspace(workspace_name_or_id).id
    stack_filter_model.set_scope_workspace(scope_id)

    store = zen_store()
    return verify_permissions_and_list_entities(
        filter_model=stack_filter_model,
        resource_type=ResourceType.STACK,
        list_method=store.list_stacks,
        hydrate=hydrate,
    )
list_workspaces(workspace_filter_model=Depends(init_cls_and_handle_errors), hydrate=False, _=Security(oauth2_authentication))
Lists all workspaces in the organization.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
workspace_filter_model |
WorkspaceFilter |
Filter model used for pagination, sorting, filtering. |
Depends(init_cls_and_handle_errors) |
hydrate |
bool |
Flag deciding whether to hydrate the output model(s) by including metadata fields in the response. |
False |
Returns:
Type | Description |
---|---|
Page[WorkspaceResponse] |
A list of workspaces. |
Source code in zenml/zen_server/routers/workspaces_endpoints.py
@router.get(
    WORKSPACES,
    response_model=Page[WorkspaceResponse],
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def list_workspaces(
    workspace_filter_model: WorkspaceFilter = Depends(
        make_dependable(WorkspaceFilter)
    ),
    hydrate: bool = False,
    _: AuthContext = Security(authorize),
) -> Page[WorkspaceResponse]:
    """List the workspaces that match the given filter.

    Args:
        workspace_filter_model: Filter model used for pagination, sorting,
            filtering.
        hydrate: Flag deciding whether to hydrate the output model(s)
            by including metadata fields in the response.

    Returns:
        A page of workspaces, passed through `dehydrate_page`.
    """
    store = zen_store()
    workspace_page = store.list_workspaces(
        workspace_filter_model, hydrate=hydrate
    )
    return dehydrate_page(workspace_page)
update_workspace(workspace_name_or_id, workspace_update, _=Security(oauth2_authentication))
Update a workspace.
noqa: DAR401
Parameters:
Name | Type | Description | Default |
---|---|---|---|
workspace_name_or_id |
UUID |
Name or ID of the workspace to update. |
required |
workspace_update |
WorkspaceUpdate |
The update to apply to the workspace. |
required |
Returns:
Type | Description |
---|---|
WorkspaceResponse |
The updated workspace. |
Source code in zenml/zen_server/routers/workspaces_endpoints.py
@router.put(
    WORKSPACES + "/{workspace_name_or_id}",
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def update_workspace(
    workspace_name_or_id: Union[str, UUID],
    workspace_update: WorkspaceUpdate,
    _: AuthContext = Security(authorize),
) -> WorkspaceResponse:
    """Updates a workspace.

    # noqa: DAR401

    Args:
        workspace_name_or_id: Name or ID of the workspace to update. Accepts
            the same forms as the other workspace endpoints (previously only
            a UUID was accepted despite the parameter name).
        workspace_update: The update to apply to the workspace.

    Returns:
        The updated workspace.
    """
    # Resolve the name or ID to the workspace's ID; the store's update
    # method is called with `workspace_id`.
    workspace = zen_store().get_workspace(workspace_name_or_id, hydrate=False)
    updated_workspace = zen_store().update_workspace(
        workspace_id=workspace.id, workspace_update=workspace_update
    )
    return dehydrate_response_model(updated_workspace)
secure_headers
Secure headers for the ZenML Server.
initialize_secure_headers()
Initialize the secure headers component.
Source code in zenml/zen_server/secure_headers.py
def initialize_secure_headers() -> None:
    """Build the global secure headers component from the server config.

    Each header supported by the `secure` library is driven by one server
    configuration setting:

    - a falsy value leaves the header unset (`None`),
    - `True` creates the header with the library's default value,
    - a string creates the header and uses the string as its value.

    The `Server` header is an exception: when enabled with a non-string
    value, it is set to the server's deployment ID. The resulting
    `secure.Secure` instance is stored in the module-level
    `_secure_headers` variable.
    """
    global _secure_headers

    config = server_config()

    def _make_header(header_cls, setting):  # type: ignore[no-untyped-def]
        """Create a `header_cls` instance, or return None if `setting` is falsy."""
        if not setting:
            return None
        header = header_cls()
        if isinstance(setting, str):
            header.set(setting)
        return header

    server = _make_header(secure.Server, config.secure_headers_server)
    if server is not None and not isinstance(
        config.secure_headers_server, str
    ):
        server.set(str(config.deployment_id))

    hsts = _make_header(
        secure.StrictTransportSecurity, config.secure_headers_hsts
    )
    xfo = _make_header(secure.XFrameOptions, config.secure_headers_xfo)
    xxp = _make_header(secure.XXSSProtection, config.secure_headers_xxp)
    csp = _make_header(secure.ContentSecurityPolicy, config.secure_headers_csp)
    content = _make_header(
        secure.XContentTypeOptions, config.secure_headers_content
    )
    referrer = _make_header(
        secure.ReferrerPolicy, config.secure_headers_referrer
    )
    cache = _make_header(secure.CacheControl, config.secure_headers_cache)

    # The permissions policy is configured by assigning its `value`
    # attribute instead of calling `set()`.
    permissions: Optional[secure.PermissionsPolicy] = None
    if config.secure_headers_permissions:
        permissions = secure.PermissionsPolicy()
        if isinstance(config.secure_headers_permissions, str):
            permissions.value = config.secure_headers_permissions

    _secure_headers = secure.Secure(
        server=server,
        hsts=hsts,
        xfo=xfo,
        xxp=xxp,
        csp=csp,
        content=content,
        referrer=referrer,
        cache=cache,
        permissions=permissions,
    )
secure_headers()
Return the secure headers component.
Returns:
Type | Description |
---|---|
Secure |
The secure headers component. |
Exceptions:
Type | Description |
---|---|
RuntimeError |
If the secure headers component is not initialized. |
Source code in zenml/zen_server/secure_headers.py
def secure_headers() -> secure.Secure:
    """Get the initialized secure headers component.

    Returns:
        The secure headers component.

    Raises:
        RuntimeError: If the secure headers component has not been
            initialized (see `initialize_secure_headers`).
    """
    if _secure_headers is not None:
        return _secure_headers
    raise RuntimeError("Secure headers component not initialized")
template_execution
special
runner_entrypoint_configuration
Runner entrypoint configuration.
RunnerEntrypointConfiguration (BaseEntrypointConfiguration)
Runner entrypoint configuration.
Source code in zenml/zen_server/template_execution/runner_entrypoint_configuration.py
class RunnerEntrypointConfiguration(BaseEntrypointConfiguration):
    """Runner entrypoint configuration."""

    def run(self) -> None:
        """Run the entrypoint configuration.

        This method runs the pipeline defined by the deployment given as input
        to the entrypoint configuration, using the client's active stack.

        Raises:
            RuntimeError: If the deployment has no stack, or if its stack is
                not the client's active stack.
        """
        deployment = self.load_deployment()
        stack = Client().active_stack

        # Explicit check instead of `assert`: assertions are stripped when
        # Python runs with `-O`, which would let the deployment run on a
        # stack other than the one it was created for.
        if not deployment.stack or stack.id != deployment.stack.id:
            deployment_stack_id = (
                deployment.stack.id if deployment.stack else None
            )
            raise RuntimeError(
                "The deployment can only be run on the stack it was created "
                f"for (deployment stack: {deployment_stack_id}, active "
                f"stack: {stack.id})."
            )

        placeholder_run = get_placeholder_run(deployment_id=deployment.id)
        deploy_pipeline(
            deployment=deployment,
            stack=stack,
            placeholder_run=placeholder_run,
        )
run(self)
Run the entrypoint configuration.
This method runs the pipeline defined by the deployment given as input to the entrypoint configuration.
Source code in zenml/zen_server/template_execution/runner_entrypoint_configuration.py
def run(self) -> None:
    """Run the entrypoint configuration.

    This method runs the pipeline defined by the deployment given as input
    to the entrypoint configuration, using the client's active stack.

    Raises:
        RuntimeError: If the deployment has no stack, or if its stack is
            not the client's active stack.
    """
    deployment = self.load_deployment()
    stack = Client().active_stack

    # Explicit check instead of `assert`: assertions are stripped when
    # Python runs with `-O`, which would let the deployment run on a
    # stack other than the one it was created for.
    if not deployment.stack or stack.id != deployment.stack.id:
        deployment_stack_id = (
            deployment.stack.id if deployment.stack else None
        )
        raise RuntimeError(
            "The deployment can only be run on the stack it was created "
            f"for (deployment stack: {deployment_stack_id}, active "
            f"stack: {stack.id})."
        )

    placeholder_run = get_placeholder_run(deployment_id=deployment.id)
    deploy_pipeline(
        deployment=deployment,
        stack=stack,
        placeholder_run=placeholder_run,
    )
utils
Utility functions to run a pipeline from the server.
deployment_request_from_template(template, config, user_id)
Generate a deployment request from a template.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
template |
RunTemplateResponse |
The template from which to create the deployment request. |
required |
config |
PipelineRunConfiguration |
The run configuration. |
required |
user_id |
UUID |
ID of the user that is trying to run the template. |
required |
Exceptions:
Type | Description |
---|---|
ValueError |
If the run configuration is missing step parameters. |
Returns:
Type | Description |
---|---|
PipelineDeploymentRequest |
The generated deployment request. |
Source code in zenml/zen_server/template_execution/utils.py
def deployment_request_from_template(
    template: RunTemplateResponse,
    config: PipelineRunConfiguration,
    user_id: UUID,
) -> "PipelineDeploymentRequest":
    """Generate a deployment request from a template.

    The pipeline-level values of the run configuration are applied on top of
    the template's source deployment, but the pipeline name and parameters
    are always taken from the source deployment. Each step configuration is
    assembled from the pipeline-level configuration, a fixed set of fields of
    the source step configuration and the step-level overrides of the run
    configuration.

    Args:
        template: The template from which to create the deployment request.
        config: The run configuration.
        user_id: ID of the user that is trying to run the template.

    Raises:
        ValueError: If the run configuration is missing required step
            parameters or contains parameters that the step does not define.

    Returns:
        The generated deployment request.
    """
    deployment = template.source_deployment
    assert deployment

    pipeline_configuration = PipelineConfiguration(
        **config.model_dump(
            include=set(PipelineConfiguration.model_fields),
            exclude={"name", "parameters"},
        ),
        name=deployment.pipeline_configuration.name,
        parameters=deployment.pipeline_configuration.parameters,
    )

    step_config_dict_base = pipeline_configuration.model_dump(
        exclude={"name", "parameters", "tags"}
    )
    steps = {}
    for invocation_id, step in deployment.step_configurations.items():
        step_config_dict = {
            # Deep copy so nested values of the base are never shared between
            # the per-step dicts that get updated below.
            **copy.deepcopy(step_config_dict_base),
            **step.config.model_dump(
                # TODO: Maybe we need to make some of these configurable via
                # yaml as well, e.g. the lazy loaders?
                include={
                    "name",
                    "caching_parameters",
                    "external_input_artifacts",
                    "model_artifacts_or_metadata",
                    "client_lazy_loaders",
                    "substitutions",
                    "outputs",
                }
            ),
        }

        # Every parameter the step had in the source deployment has to be
        # provided by the run configuration, and nothing else.
        required_parameters = set(step.config.parameters)
        configured_parameters = set()

        if update := config.steps.get(invocation_id):
            update_dict = update.model_dump()
            # Get rid of deprecated name to prevent overriding the step name
            # with `None`.
            update_dict.pop("name", None)
            configured_parameters = set(update.parameters)
            step_config_dict = dict_utils.recursive_update(
                step_config_dict, update=update_dict
            )

        missing_parameters = required_parameters - configured_parameters
        if missing_parameters:
            raise ValueError(
                "Run configuration is missing the following required "
                f"parameters for step {step.config.name}: {missing_parameters}."
            )

        unknown_parameters = configured_parameters - required_parameters
        if unknown_parameters:
            raise ValueError(
                "Run configuration contains the following unknown "
                f"parameters for step {step.config.name}: {unknown_parameters}."
            )

        step_config = StepConfiguration.model_validate(step_config_dict)
        steps[invocation_id] = Step(spec=step.spec, config=step_config)

    code_reference_request = None
    if deployment.code_reference:
        code_reference_request = CodeReferenceRequest(
            commit=deployment.code_reference.commit,
            subdirectory=deployment.code_reference.subdirectory,
            code_repository=deployment.code_reference.code_repository.id,
        )

    zenml_version = zen_store().get_store_info().version
    assert deployment.stack
    assert deployment.build
    deployment_request = PipelineDeploymentRequest(
        user=user_id,
        workspace=deployment.workspace.id,
        run_name_template=config.run_name
        or get_default_run_name(pipeline_name=pipeline_configuration.name),
        pipeline_configuration=pipeline_configuration,
        step_configurations=steps,
        client_environment={},
        client_version=zenml_version,
        server_version=zenml_version,
        stack=deployment.stack.id,
        pipeline=deployment.pipeline.id if deployment.pipeline else None,
        build=deployment.build.id,
        schedule=None,
        code_reference=code_reference_request,
        code_path=deployment.code_path,
        template=template.id,
        pipeline_version_hash=deployment.pipeline_version_hash,
        pipeline_spec=deployment.pipeline_spec,
    )

    return deployment_request
ensure_async_orchestrator(deployment, stack)
Ensures the orchestrator is configured to run async.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
deployment |
PipelineDeploymentRequest |
Deployment request in which the orchestrator configuration should be updated to ensure the orchestrator is running async. |
required |
stack |
StackResponse |
The stack on which the deployment will run. |
required |
Source code in zenml/zen_server/template_execution/utils.py
def ensure_async_orchestrator(
    deployment: PipelineDeploymentRequest, stack: StackResponse
) -> None:
    """Force the stack's orchestrator into asynchronous mode for a deployment.

    Orchestrator flavors whose config has no `synchronous` field are left
    alone. For all others, the orchestrator settings in the deployment's
    pipeline configuration are (re)written with `synchronous=False`, keeping
    any other settings that were already present.

    Args:
        deployment: Deployment request in which the orchestrator
            configuration should be updated to ensure the orchestrator is
            running async.
        stack: The stack on which the deployment will run.
    """
    orchestrator_model = stack.components[StackComponentType.ORCHESTRATOR][0]
    flavor_models = zen_store().list_flavors(
        FlavorFilter(
            name=orchestrator_model.flavor_name, type=orchestrator_model.type
        )
    )
    orchestrator_flavor = Flavor.from_model(flavor_models[0])

    if "synchronous" not in orchestrator_flavor.config_class.model_fields:
        return

    settings_key = settings_utils.get_flavor_setting_key(orchestrator_flavor)
    pipeline_settings = deployment.pipeline_configuration.settings
    existing_settings = pipeline_settings.get(settings_key)

    settings_dict = existing_settings.model_dump() if existing_settings else {}
    settings_dict["synchronous"] = False
    pipeline_settings[settings_key] = BaseSettings.model_validate(settings_dict)
generate_dockerfile(pypi_requirements, apt_packages, zenml_version, python_version)
Generate a Dockerfile that installs the requirements.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
pypi_requirements |
List[str] |
The PyPI requirements to install. |
required |
apt_packages |
List[str] |
The APT packages to install. |
required |
zenml_version |
str |
The ZenML version to use as parent image. |
required |
python_version |
str |
The Python version to use as parent image. |
required |
Returns:
Type | Description |
---|---|
str |
The Dockerfile. |
Source code in zenml/zen_server/template_execution/utils.py
def generate_dockerfile(
    pypi_requirements: List[str],
    apt_packages: List[str],
    zenml_version: str,
    python_version: str,
) -> str:
    """Generate a Dockerfile that installs the requirements.

    Every APT package and PyPI requirement is wrapped in single quotes so the
    shell executing the shell-form `RUN` instructions does not interpret
    characters such as `<`, `>` or `;` in version specifiers. Single quotes
    inside a value (e.g. in PEP 508 environment markers like
    `sys_platform == 'linux'`) are escaped so the quoting stays intact.
    Values without single quotes are rendered exactly as `'value'`.

    Args:
        pypi_requirements: The PyPI requirements to install.
        apt_packages: The APT packages to install.
        zenml_version: The ZenML version to use as parent image.
        python_version: The Python version to use as parent image.

    Returns:
        The Dockerfile.
    """

    def _quote(value: str) -> str:
        # POSIX single-quoting: end the quoted string, emit an escaped quote
        # and start a new quoted string, i.e. ' -> '\''
        return "'" + value.replace("'", "'\\''") + "'"

    parent_image = f"zenmldocker/zenml:{zenml_version}-py{python_version}"

    lines = [f"FROM {parent_image}"]
    if apt_packages:
        apt_packages_string = " ".join(_quote(p) for p in apt_packages)
        lines.append(
            "RUN apt-get update && apt-get install -y "
            f"--no-install-recommends {apt_packages_string}"
        )

    if pypi_requirements:
        pypi_requirements_string = " ".join(
            _quote(r) for r in pypi_requirements
        )
        lines.append(
            f"RUN pip install --default-timeout=60 --no-cache-dir "
            f"{pypi_requirements_string}"
        )

    return "\n".join(lines)
generate_image_hash(dockerfile)
Generate a hash of the Dockerfile.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
dockerfile |
str |
The Dockerfile for which to generate the hash. |
required |
Returns:
Type | Description |
---|---|
str |
The hash of the Dockerfile. |
Source code in zenml/zen_server/template_execution/utils.py
def generate_image_hash(dockerfile: str) -> str:
    """Compute the MD5 hex digest of a Dockerfile's contents.

    The digest only identifies the Dockerfile content (it is used as an image
    tag), it is not used for any security purpose.

    Args:
        dockerfile: The Dockerfile for which to generate the hash.

    Returns:
        The hash of the Dockerfile.
    """
    # Development tip: feeding `os.getpid()` into the hash as well guarantees
    # a fresh docker image gets built after every server restart.
    return hashlib.md5(dockerfile.encode()).hexdigest()  # nosec
get_pipeline_run_analytics_metadata(deployment, stack, template_id, run_id)
Get metadata for the pipeline run analytics event.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
deployment |
PipelineDeploymentResponse |
The deployment of the run. |
required |
stack |
StackResponse |
The stack on which the run will happen. |
required |
template_id |
UUID |
ID of the template from which the run was started. |
required |
run_id |
UUID |
ID of the run. |
required |
Returns:
Type | Description |
---|---|
Dict[str, Any] |
The analytics metadata. |
Source code in zenml/zen_server/template_execution/utils.py
def get_pipeline_run_analytics_metadata(
    deployment: "PipelineDeploymentResponse",
    stack: StackResponse,
    template_id: UUID,
    run_id: UUID,
) -> Dict[str, Any]:
    """Get metadata for the pipeline run analytics event.

    Args:
        deployment: The deployment of the run.
        stack: The stack on which the run will happen.
        template_id: ID of the template from which the run was started.
        run_id: ID of the run.

    Returns:
        The analytics metadata.
    """
    # A run uses a custom materializer if any output of any step has a
    # materializer source that is not internal.
    custom_materializer = any(
        not source.is_internal
        for step in deployment.step_configurations.values()
        for output in step.config.outputs.values()
        for source in output.materializer_source
    )

    assert deployment.user
    stack_creator = stack.user
    # Coerce to a real bool: a stack without a creator would otherwise report
    # `None` instead of `False`.
    own_stack = bool(
        stack_creator and stack_creator.id == deployment.user.id
    )

    # Flavor name of the first component of each component type.
    stack_metadata = {
        component_type.value: component_list[0].flavor_name
        for component_type, component_list in stack.components.items()
    }
    return {
        "store_type": "rest",  # This method is called from within a REST endpoint
        **stack_metadata,
        "total_steps": len(deployment.step_configurations),
        "schedule": deployment.schedule is not None,
        "custom_materializer": custom_materializer,
        "own_stack": own_stack,
        "pipeline_run_id": str(run_id),
        "template_id": str(template_id),
    }
run_template(template, auth_context, background_tasks=None, run_config=None)
Run a pipeline from a template.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
template |
RunTemplateResponse |
The template to run. |
required |
auth_context |
AuthContext |
Authentication context. |
required |
background_tasks |
Optional[fastapi.background.BackgroundTasks] |
Background tasks. |
None |
run_config |
Optional[zenml.config.pipeline_run_configuration.PipelineRunConfiguration] |
The run configuration. |
None |
Exceptions:
Type | Description |
---|---|
ValueError |
If the template can not be run. |
RuntimeError |
If the server URL is not set in the server configuration. |
Returns:
Type | Description |
---|---|
PipelineRunResponse |
The placeholder pipeline run created for the new run. |
Source code in zenml/zen_server/template_execution/utils.py
def run_template(
    template: RunTemplateResponse,
    auth_context: AuthContext,
    background_tasks: Optional[BackgroundTasks] = None,
    run_config: Optional[PipelineRunConfiguration] = None,
) -> PipelineRunResponse:
    """Run a pipeline from a template.

    Validates the template, its stack and the optional run configuration.
    It then creates a new deployment and a placeholder run, and launches a
    runner container for the deployment through the workload manager. The
    launch happens as a background task if `background_tasks` is given, or
    synchronously otherwise.

    Args:
        template: The template to run.
        auth_context: Authentication context.
        background_tasks: Background tasks. If not given, the runner is
            built and started synchronously.
        run_config: The run configuration.

    Raises:
        ValueError: If the template can not be run.
        RuntimeError: If the server URL is not set in the server configuration.

    Returns:
        The placeholder pipeline run created for the new run.
    """
    if not template.runnable:
        raise ValueError(
            "This template can not be run because its associated deployment, "
            "stack or build have been deleted."
        )

    # Guaranteed by the `runnable` check above
    build = template.build
    assert build
    stack = build.stack
    assert stack

    if build.stack_checksum and build.stack_checksum != compute_stack_checksum(
        stack=stack
    ):
        raise ValueError(
            f"The stack {stack.name} has been updated since it was used for "
            "the run that is the base for this template. This means the Docker "
            "images associated with this template most likely do not contain "
            "the necessary requirements. Please create a new template from a "
            "recent run on this stack."
        )

    validate_stack_is_runnable_from_server(zen_store=zen_store(), stack=stack)
    if run_config:
        validate_run_config_is_runnable_from_server(run_config)

    deployment_request = deployment_request_from_template(
        template=template,
        config=run_config or PipelineRunConfiguration(),
        user_id=auth_context.user.id,
    )

    ensure_async_orchestrator(deployment=deployment_request, stack=stack)

    # Validate the server configuration before persisting anything, so that a
    # misconfigured server does not leave an orphaned deployment behind.
    server_url = server_config().server_url
    if not server_url:
        raise RuntimeError(
            "The server URL is not set in the server configuration."
        )
    assert build.zenml_version
    zenml_version = build.zenml_version

    new_deployment = zen_store().create_deployment(deployment_request)

    placeholder_run = create_placeholder_run(deployment=new_deployment)
    assert placeholder_run

    # We create an API token scoped to the pipeline run that never expires
    api_token = generate_access_token(
        user_id=auth_context.user.id,
        pipeline_run_id=placeholder_run.id,
        # Keep the original API key or device scopes, if any
        api_key=auth_context.api_key,
        device=auth_context.device,
        # Never expire the token
        expires_in=0,
    ).access_token

    # Environment for the runner container: it connects back to this server
    # as a REST store client with the run-scoped token created above.
    environment = {
        ENV_ZENML_ACTIVE_WORKSPACE_ID: str(new_deployment.workspace.id),
        ENV_ZENML_ACTIVE_STACK_ID: str(stack.id),
        "ZENML_VERSION": zenml_version,
        "ZENML_STORE_URL": server_url,
        "ZENML_STORE_TYPE": StoreType.REST.value,
        "ZENML_STORE_API_TOKEN": api_token,
        "ZENML_STORE_VERIFY_SSL": "True",
    }

    command = RunnerEntrypointConfiguration.get_entrypoint_command()
    args = RunnerEntrypointConfiguration.get_entrypoint_arguments(
        deployment_id=new_deployment.id
    )

    def _task() -> None:
        # Build (or reuse, via the content hash tag) the runner image and
        # start the runner container for the new deployment.
        pypi_requirements, apt_packages = (
            requirements_utils.get_requirements_for_stack(stack=stack)
        )

        if build.python_version:
            version_info = version.parse(build.python_version)
            python_version = f"{version_info.major}.{version_info.minor}"
        else:
            python_version = (
                f"{sys.version_info.major}.{sys.version_info.minor}"
            )

        dockerfile = generate_dockerfile(
            pypi_requirements=pypi_requirements,
            apt_packages=apt_packages,
            zenml_version=zenml_version,
            python_version=python_version,
        )

        image_hash = generate_image_hash(dockerfile=dockerfile)

        runner_image = workload_manager().build_and_push_image(
            workload_id=new_deployment.id,
            dockerfile=dockerfile,
            image_name=f"{RUNNER_IMAGE_REPOSITORY}:{image_hash}",
            sync=True,
        )

        workload_manager().log(
            workload_id=new_deployment.id,
            message="Starting pipeline run.",
        )
        workload_manager().run(
            workload_id=new_deployment.id,
            image=runner_image,
            command=command,
            arguments=args,
            environment=environment,
            timeout_in_seconds=30,
            sync=True,
        )
        workload_manager().log(
            workload_id=new_deployment.id,
            message="Pipeline run started successfully.",
        )

    def _task_with_analytics_and_error_handling() -> None:
        # On failure the placeholder run is marked FAILED before re-raising.
        with track_handler(
            event=AnalyticsEvent.RUN_PIPELINE
        ) as analytics_handler:
            analytics_handler.metadata = get_pipeline_run_analytics_metadata(
                deployment=new_deployment,
                stack=stack,
                template_id=template.id,
                run_id=placeholder_run.id,
            )

            try:
                _task()
            except Exception:
                logger.exception(
                    "Failed to run template %s, run ID: %s",
                    str(template.id),
                    str(placeholder_run.id),
                )
                zen_store().update_run(
                    run_id=placeholder_run.id,
                    run_update=PipelineRunUpdate(
                        status=ExecutionStatus.FAILED
                    ),
                )
                raise

    if background_tasks:
        background_tasks.add_task(_task_with_analytics_and_error_handling)
    else:
        # Run synchronously if no background tasks were passed. This is probably
        # when coming from a trigger which itself is already running in the
        # background
        _task_with_analytics_and_error_handling()

    return placeholder_run
workload_manager_interface
Workload manager interface definition.
WorkloadManagerInterface (ABC)
Workload manager interface.
Source code in zenml/zen_server/template_execution/workload_manager_interface.py
class WorkloadManagerInterface(ABC):
    """Abstract interface for workload managers.

    Implementations build/push Docker images, run containers and keep logs,
    all grouped under a workload ID.
    """

    @abstractmethod
    def run(
        self,
        workload_id: UUID,
        image: str,
        command: List[str],
        arguments: List[str],
        environment: Optional[Dict[str, str]] = None,
        sync: bool = True,
        timeout_in_seconds: int = 0,
    ) -> None:
        """Start a Docker container for a workload.

        Args:
            workload_id: Workload ID.
            image: The Docker image to run.
            command: The command to run in the container.
            arguments: The arguments for the command.
            environment: The environment to set in the container.
            sync: Whether to block until the container has finished running.
            timeout_in_seconds: Seconds to wait before the container gets
                cancelled. A value of 0 lets the container run until it
                fails or finishes.
        """
        ...

    @abstractmethod
    def build_and_push_image(
        self,
        workload_id: UUID,
        dockerfile: str,
        image_name: str,
        sync: bool = True,
        timeout_in_seconds: int = 0,
    ) -> str:
        """Build a Docker image and push it.

        Args:
            workload_id: Workload ID.
            dockerfile: The dockerfile content to build the image.
            image_name: The image repository and tag.
            sync: Whether to block until the build has finished.
            timeout_in_seconds: Seconds to wait before the container gets
                cancelled. A value of 0 lets the container run until it
                fails or finishes.

        Returns:
            The full image name including container registry.
        """
        ...

    @abstractmethod
    def delete_workload(self, workload_id: UUID) -> None:
        """Remove a workload.

        Args:
            workload_id: Workload ID.
        """
        ...

    @abstractmethod
    def get_logs(self, workload_id: UUID) -> str:
        """Fetch the logs stored for a workload.

        Args:
            workload_id: Workload ID.

        Returns:
            The stored logs.
        """
        ...

    @abstractmethod
    def log(self, workload_id: UUID, message: str) -> None:
        """Append a message to a workload's logs.

        Args:
            workload_id: Workload ID.
            message: The message to log.
        """
        ...
build_and_push_image(self, workload_id, dockerfile, image_name, sync=True, timeout_in_seconds=0)
Build and push a Docker image.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
workload_id |
UUID |
Workload ID. |
required |
dockerfile |
str |
The dockerfile content to build the image. |
required |
image_name |
str |
The image repository and tag. |
required |
sync |
bool |
If True, will wait until the build finished before returning. |
True |
timeout_in_seconds |
int |
Timeout in seconds to wait before cancelling the container. If set to 0 the container will run until it fails or finishes. |
0 |
Returns:
Type | Description |
---|---|
str |
The full image name including container registry. |
Source code in zenml/zen_server/template_execution/workload_manager_interface.py
@abstractmethod
def build_and_push_image(
    self,
    workload_id: UUID,
    dockerfile: str,
    image_name: str,
    sync: bool = True,
    timeout_in_seconds: int = 0,
) -> str:
    """Build a Docker image and push it.

    Args:
        workload_id: Workload ID.
        dockerfile: The dockerfile content to build the image.
        image_name: The image repository and tag.
        sync: Whether to block until the build has finished.
        timeout_in_seconds: Seconds to wait before the container gets
            cancelled. A value of 0 lets the container run until it fails or
            finishes.

    Returns:
        The full image name including container registry.
    """
    ...
delete_workload(self, workload_id)
Delete a workload.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
workload_id |
UUID |
Workload ID. |
required |
Source code in zenml/zen_server/template_execution/workload_manager_interface.py
@abstractmethod
def delete_workload(self, workload_id: UUID) -> None:
    """Remove a workload.

    Args:
        workload_id: Workload ID.
    """
    ...
get_logs(self, workload_id)
Get logs for a workload.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
workload_id |
UUID |
Workload ID. |
required |
Returns:
Type | Description |
---|---|
str |
The stored logs. |
Source code in zenml/zen_server/template_execution/workload_manager_interface.py
@abstractmethod
def get_logs(self, workload_id: UUID) -> str:
    """Fetch the logs stored for a workload.

    Args:
        workload_id: Workload ID.

    Returns:
        The stored logs.
    """
    ...
log(self, workload_id, message)
Log a message.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
workload_id |
UUID |
Workload ID. |
required |
message |
str |
The message to log. |
required |
Source code in zenml/zen_server/template_execution/workload_manager_interface.py
@abstractmethod
def log(self, workload_id: UUID, message: str) -> None:
    """Append a message to a workload's logs.

    Args:
        workload_id: Workload ID.
        message: The message to log.
    """
    ...
run(self, workload_id, image, command, arguments, environment=None, sync=True, timeout_in_seconds=0)
Run a Docker container.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
workload_id |
UUID |
Workload ID. |
required |
image |
str |
The Docker image to run. |
required |
command |
List[str] |
The command to run in the container. |
required |
arguments |
List[str] |
The arguments for the command. |
required |
environment |
Optional[Dict[str, str]] |
The environment to set in the container. |
None |
sync |
bool |
If True, will wait until the container finished running before returning. |
True |
timeout_in_seconds |
int |
Timeout in seconds to wait before cancelling the container. If set to 0 the container will run until it fails or finishes. |
0 |
Source code in zenml/zen_server/template_execution/workload_manager_interface.py
@abstractmethod
def run(
    self,
    workload_id: UUID,
    image: str,
    command: List[str],
    arguments: List[str],
    environment: Optional[Dict[str, str]] = None,
    sync: bool = True,
    timeout_in_seconds: int = 0,
) -> None:
    """Start a Docker container for a workload.

    Args:
        workload_id: Workload ID.
        image: The Docker image to run.
        command: The command to run in the container.
        arguments: The arguments for the command.
        environment: The environment to set in the container.
        sync: Whether to block until the container has finished running.
        timeout_in_seconds: Seconds to wait before the container gets
            cancelled. A value of 0 lets the container run until it fails or
            finishes.
    """
    ...
utils
Util functions for the ZenML Server.
connected_to_local_server()
Check if the client is connected to a local server.
Returns:
Type | Description |
---|---|
bool |
True if the client is connected to a local server, False otherwise. |
Source code in zenml/zen_server/utils.py
def connected_to_local_server() -> bool:
    """Tell whether the client is currently connected to a local server.

    Returns:
        `True` when the client talks to a locally deployed server, `False`
        otherwise.
    """
    from zenml.zen_server.deploy.deployer import LocalServerDeployer

    return LocalServerDeployer().is_connected_to_server()
feature_gate()
Return the initialized Feature Gate component.
Exceptions:
Type | Description |
---|---|
RuntimeError |
If the Feature Gate component is not initialized. |
Returns:
Type | Description |
---|---|
FeatureGateInterface |
The Feature Gate component. |
Source code in zenml/zen_server/utils.py
def feature_gate() -> FeatureGateInterface:
    """Return the initialized Feature Gate component.

    Raises:
        RuntimeError: If the Feature Gate component is not initialized.

    Returns:
        The Feature Gate component.
    """
    # Read-only access to the module-level singleton: no `global` needed.
    if _feature_gate is None:
        raise RuntimeError("Feature gate component not initialized.")
    return _feature_gate
get_ip_location(ip_address)
Get the location of the given IP address.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
ip_address |
str |
The IP address to get the location for. |
required |
Returns:
Type | Description |
---|---|
Tuple[str, str, str] |
A tuple of city, region, country. |
Source code in zenml/zen_server/utils.py
def get_ip_location(ip_address: str) -> Tuple[str, str, str]:
    """Get the location of the given IP address.

    Lookup errors are logged and never raised. Location fields that the
    lookup does not provide are returned as empty strings, while the
    available ones are still returned.

    Args:
        ip_address: The IP address to get the location for.

    Returns:
        A tuple of city, region, country.
    """
    import ipinfo  # type: ignore[import-untyped]

    try:
        handler = ipinfo.getHandler()
        details = handler.getDetails(ip_address)
    except Exception:
        logger.exception("Could not get IP location for %s.", ip_address)
        return "", "", ""

    # ipinfo's `Details` raises AttributeError for fields that are missing
    # from the response (e.g. for private/bogon addresses), so read each
    # field with a default instead of losing all of them at once.
    return (
        getattr(details, "city", None) or "",
        getattr(details, "region", None) or "",
        getattr(details, "country_name", None) or "",
    )
get_local_server()
Get the active local server.
Call this function to retrieve the local server deployed on this machine.
Returns:
Type | Description |
---|---|
Optional[LocalServerDeployment] |
The local server deployment or None, if no local server deployment was found. |
Source code in zenml/zen_server/utils.py
def get_local_server() -> Optional["LocalServerDeployment"]:
    """Look up the local server deployed on this machine.

    Returns:
        The local server deployment, or `None` if there is no local server
        deployment.
    """
    from zenml.zen_server.deploy.deployer import LocalServerDeployer

    local_deployer = LocalServerDeployer()
    try:
        local_server = local_deployer.get_server()
    except ServerDeploymentNotFoundError:
        local_server = None
    return local_server
handle_exceptions(func)
Decorator to handle exceptions in the API.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
func |
~F |
Function to decorate. |
required |
Returns:
Type | Description |
---|---|
~F |
Decorated function. |
Source code in zenml/zen_server/utils.py
def handle_exceptions(func: F) -> F:
    """Wrap an API endpoint so that errors are turned into proper responses.

    The wrapper makes the first `AuthContext` argument the active auth
    context, checking positional arguments before keyword arguments. It then
    calls the endpoint. `OAuthError`s become JSON responses, `HTTPException`s
    pass through unchanged, and any other exception is logged and converted
    into an HTTP exception.

    Args:
        func: Function to decorate.

    Returns:
        Decorated function.
    """

    @wraps(func)
    def decorated(*args: Any, **kwargs: Any) -> Any:
        # These imports can't happen at module level as this module is also
        # used by the CLI when installed without the `server` extra
        from fastapi import HTTPException
        from fastapi.responses import JSONResponse

        from zenml.zen_server.auth import AuthContext, set_auth_context

        auth_context = next(
            (
                candidate
                for candidate in (*args, *kwargs.values())
                if isinstance(candidate, AuthContext)
            ),
            None,
        )
        if auth_context is not None:
            set_auth_context(auth_context)

        try:
            return func(*args, **kwargs)
        except OAuthError as error:
            # The OAuthError is special because it needs to have a JSON response
            return JSONResponse(
                status_code=error.status_code,
                content=error.to_dict(),
            )
        except HTTPException:
            raise
        except Exception as error:
            logger.exception("API error")
            raise http_exception_from_error(error)

    return cast(F, decorated)
initialize_feature_gate()
Initialize the Feature Gate component.
Source code in zenml/zen_server/utils.py
def initialize_feature_gate() -> None:
    """Instantiate the configured Feature Gate implementation, if any.

    Nothing happens when the server config has no
    `feature_gate_implementation_source`.
    """
    global _feature_gate

    feature_gate_source = server_config().feature_gate_implementation_source
    if not feature_gate_source:
        return

    from zenml.utils import source_utils

    implementation_class = source_utils.load_and_validate_class(
        feature_gate_source, expected_class=FeatureGateInterface
    )
    _feature_gate = implementation_class()
initialize_memcache(max_capacity, default_expiry)
Initialize the memory cache.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
max_capacity |
int |
The maximum capacity of the cache. |
required |
default_expiry |
int |
The default expiry time in seconds. |
required |
Source code in zenml/zen_server/utils.py
def initialize_memcache(max_capacity: int, default_expiry: int) -> None:
    """Create the module-level memory cache, replacing any existing one.

    Args:
        max_capacity: The maximum capacity of the cache.
        default_expiry: The default expiry time in seconds.
    """
    global _memcache

    new_cache = MemoryCache(max_capacity, default_expiry)
    _memcache = new_cache
initialize_plugins()
Initialize the event plugins registry.
Source code in zenml/zen_server/utils.py
def initialize_plugins() -> None:
    """Initialize all plugins of the event plugin flavor registry."""
    registry = plugin_flavor_registry()
    registry.initialize_plugins()
initialize_rbac()
Initialize the RBAC component.
Source code in zenml/zen_server/utils.py
def initialize_rbac() -> None:
    """Instantiate the configured RBAC implementation, if any.

    Nothing happens when the server config has no
    `rbac_implementation_source`.
    """
    global _rbac

    rbac_source = server_config().rbac_implementation_source
    if not rbac_source:
        return

    from zenml.utils import source_utils

    implementation_class = source_utils.load_and_validate_class(
        rbac_source, expected_class=RBACInterface
    )
    _rbac = implementation_class()
initialize_workload_manager()
Initialize the workload manager component.
This does not fail if the source can't be loaded but only logs a warning.
Source code in zenml/zen_server/utils.py
def initialize_workload_manager() -> None:
    """Instantiate the configured workload manager implementation, if any.

    If the configured source can't be loaded (`ModuleNotFoundError` or
    `KeyError` while loading it), only a warning is logged and no workload
    manager is set.
    """
    global _workload_manager

    source = server_config().workload_manager_implementation_source
    if not source:
        return

    from zenml.utils import source_utils

    try:
        workload_manager_class: Type[WorkloadManagerInterface] = (
            source_utils.load_and_validate_class(
                source=source, expected_class=WorkloadManagerInterface
            )
        )
    except (ModuleNotFoundError, KeyError):
        logger.warning("Unable to load workload manager source.")
        return

    _workload_manager = workload_manager_class()
initialize_zen_store()
Initialize the ZenML Store.
Exceptions:
Type | Description |
---|---|
ValueError |
If the ZenML Store is using a REST back-end. |
Source code in zenml/zen_server/utils.py
def initialize_zen_store() -> None:
    """Set up the module-level ZenML store used by the server.

    Also flags the current process as a ZenML server via an environment
    variable.

    Raises:
        ValueError: If the ZenML Store is using a REST back-end.
    """
    global _zen_store

    logger.debug("Initializing ZenML Store for FastAPI...")

    # Use an environment variable to flag the instance as a server
    os.environ[ENV_ZENML_SERVER] = "true"

    store = GlobalConfiguration().zen_store
    if isinstance(store, SqlZenStore):
        _zen_store = store
        return

    raise ValueError(
        "Server cannot be started with a REST store type. Make sure you "
        "configure ZenML to use a non-networked store backend "
        "when trying to start the ZenML Server."
    )
is_user_request(request)
Determine if the incoming request is a user request.
This function checks various aspects of the request to determine if it's a user-initiated request or a system request.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
request |
Request |
The incoming FastAPI request object. |
required |
Returns:
Type | Description |
---|---|
bool |
True if it's a user request, False otherwise. |
Source code in zenml/zen_server/utils.py
def is_user_request(request: "Request") -> bool:
    """Decide whether an incoming request was initiated by a user.

    The checks run in a fixed order and the first matching rule wins.
    Requests under the versioned user API prefix count as user requests,
    except for a few excluded endpoints. Other requests are rejected if any
    heuristic matches: a system path, a system header or user agent, an
    internal client IP, a CORS preflight, or a system-check query parameter.

    Args:
        request: The incoming FastAPI request object.

    Returns:
        True if it's a user request, False otherwise.
    """
    path = request.url.path
    user_prefix = f"{API}{VERSION_1}"

    # Endpoints under the user prefix that do not count as user requests.
    if path in {user_prefix + suffix for suffix in [INFO]}:
        return False

    # Every other endpoint under the user prefix is a user request.
    if path.startswith(user_prefix):
        return True

    system_path_prefixes = (
        "/health",
        "/metrics",
        "/system",
        "/docs",
        "/redoc",
        "/openapi.json",
    )
    if path.startswith(system_path_prefixes):
        return False

    if request.headers.get("X-System-Request") == "true":
        return False

    # Monitoring tools identified by their user agent.
    user_agent = request.headers.get("User-Agent", "").lower()
    monitoring_agents = ("prometheus", "datadog", "newrelic", "pingdom")
    if any(agent in user_agent for agent in monitoring_agents):
        return False

    client_host = request.client.host if request.client else None
    if client_host and client_host.startswith(("10.", "192.168.")):
        return False

    # CORS preflight requests.
    if request.method == "OPTIONS":
        return False

    if request.query_params.get("system_check"):
        return False

    return True
make_dependable(cls)
This function makes a pydantic model usable for fastapi query parameters.
Additionally, it converts `InternalServerError`s that would happen due to
`pydantic.ValidationError` into 422 responses that signal an invalid
request.
Check out https://github.com/tiangolo/fastapi/issues/1474 for context.
Usage: `def f(model: Model = Depends(make_dependable(Model))): ...`
UPDATE: Function from above mentioned Github issue was extended to support
multi-input parameters, e.g. tags: List[str]. It needs a default set to
`Query(<default>)` rather than just plain `<default>`.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
cls |
Type[pydantic.main.BaseModel] |
The model class. |
required |
Returns:
Type | Description |
---|---|
Callable[..., Any] |
Function to use in FastAPI |
Source code in zenml/zen_server/utils.py
def make_dependable(cls: Type[BaseModel]) -> Callable[..., Any]:
    """This function makes a pydantic model usable for fastapi query parameters.

    Additionally, it converts `InternalServerError`s that would happen due to
    `pydantic.ValidationError` into 422 responses that signal an invalid
    request.

    Check out https://github.com/tiangolo/fastapi/issues/1474 for context.

    Usage:
        def f(model: Model = Depends(make_dependable(Model))):
            ...

    UPDATE: Function from above mentioned Github issue was extended to support
    multi-input parameters, e.g. tags: List[str]. Those need a default set to
    `Query(<default>)` rather than just plain `<default>`; the names of such
    parameters are read from the model's `API_MULTI_INPUT_PARAMS` attribute.

    Args:
        cls: The model class.

    Returns:
        Function to use in FastAPI `Depends`.
    """
    from fastapi import Query

    from zenml.zen_server.exceptions import error_detail

    def init_cls_and_handle_errors(*args: Any, **kwargs: Any) -> BaseModel:
        from fastapi import HTTPException

        try:
            # Check the call against the signature patched in below.
            inspect.signature(init_cls_and_handle_errors).bind(*args, **kwargs)
            return cls(*args, **kwargs)
        except ValidationError as e:
            detail = error_detail(e, exception_type=ValueError)
            raise HTTPException(422, detail=detail)

    parameters = {
        parameter.name: parameter
        for parameter in inspect.signature(cls).parameters.values()
    }

    # Wrap the defaults of multi-input parameters in `Query(...)`.
    for name in getattr(cls, "API_MULTI_INPUT_PARAMS", []):
        if name not in parameters:
            continue
        original = parameters[name]
        parameters[name] = original.replace(default=Query(original.default))

    init_cls_and_handle_errors.__signature__ = inspect.Signature(  # type: ignore[attr-defined]
        parameters=list(parameters.values())
    )

    return init_cls_and_handle_errors
memcache()
Return the memory cache.
Returns:
Type | Description |
---|---|
MemoryCache |
The memory cache. |
Exceptions:
Type | Description |
---|---|
RuntimeError |
If the memory cache is not initialized. |
Source code in zenml/zen_server/utils.py
def memcache() -> MemoryCache:
    """Get the module-level memory cache.

    Returns:
        The memory cache.

    Raises:
        RuntimeError: If the memory cache is not initialized.
    """
    cache = _memcache
    if cache is not None:
        return cache
    raise RuntimeError("Memory cache not initialized")
plugin_flavor_registry()
Get the plugin flavor registry.
Returns:
Type | Description |
---|---|
PluginFlavorRegistry |
The plugin flavor registry. |
Source code in zenml/zen_server/utils.py
def plugin_flavor_registry() -> PluginFlavorRegistry:
    """Get the plugin flavor registry, creating it on first access.

    The first call builds a `PluginFlavorRegistry`, initializes its plugins
    and caches it at module level. Later calls return the cached instance.

    Returns:
        The plugin flavor registry.
    """
    global _plugin_flavor_registry
    if _plugin_flavor_registry is not None:
        return _plugin_flavor_registry
    _plugin_flavor_registry = PluginFlavorRegistry()
    _plugin_flavor_registry.initialize_plugins()
    return _plugin_flavor_registry
rbac()
Return the initialized RBAC component.
Exceptions:
Type | Description |
---|---|
RuntimeError |
If the RBAC component is not initialized. |
Returns:
Type | Description |
---|---|
RBACInterface |
The RBAC component. |
Source code in zenml/zen_server/utils.py
def rbac() -> RBACInterface:
    """Get the initialized RBAC component.

    Raises:
        RuntimeError: If the RBAC component is not initialized.

    Returns:
        The RBAC component.
    """
    # The module-level component is only read here, never reassigned.
    component = _rbac
    if component is None:
        raise RuntimeError("RBAC component not initialized")
    return component
server_config()
Returns the ZenML Server configuration.
Returns:
Type | Description |
---|---|
ServerConfiguration |
The ZenML Server configuration. |
Source code in zenml/zen_server/utils.py
def server_config() -> ServerConfiguration:
    """Get the ZenML Server configuration, loading it on first access.

    The configuration is obtained from
    `ServerConfiguration.get_server_config()` once and cached at module
    level for subsequent calls.

    Returns:
        The ZenML Server configuration.
    """
    global _server_config
    if _server_config is not None:
        return _server_config
    _server_config = ServerConfiguration.get_server_config()
    return _server_config
show_dashboard(local=False, ngrok_token=None)
Show the ZenML dashboard.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
local |
bool |
Whether to show the dashboard for the local server or the one for the active server. |
False |
ngrok_token |
Optional[str] |
An ngrok auth token to use for exposing the ZenML dashboard on a public domain. Primarily used for accessing the dashboard in Colab. |
None |
Exceptions:
Type | Description |
---|---|
RuntimeError |
If no server is connected. |
Source code in zenml/zen_server/utils.py
def show_dashboard(
    local: bool = False,
    ngrok_token: Optional[str] = None,
) -> None:
    """Show the ZenML dashboard.

    Unless `local` is set, the URL of the REST server the client is
    connected to is used. Otherwise, or if no such server is configured,
    the URL of a local server with a known status URL is used.

    Args:
        local: Whether to show the dashboard for the local server or the
            one for the active server.
        ngrok_token: An ngrok auth token to use for exposing the ZenML
            dashboard on a public domain. Primarily used for accessing the
            dashboard in Colab.

    Raises:
        RuntimeError: If no server is connected.
    """
    # Alias the import so it does not shadow this function's own name.
    from zenml.utils.dashboard_utils import (
        show_dashboard as _open_dashboard,
    )
    from zenml.utils.networking_utils import get_or_create_ngrok_tunnel

    url: Optional[str] = None
    if not local:
        gc = GlobalConfiguration()
        if gc.store_configuration.type == StoreType.REST:
            url = gc.store_configuration.url

    if not url:
        # Else, check for local servers
        server = get_local_server()
        if server and server.status and server.status.url:
            url = server.status.url

    if not url:
        raise RuntimeError(
            "ZenML is not connected to any server right now. Please use "
            "`zenml login` to connect to a server or spin up a new local server "
            "via `zenml login --local`."
        )

    if ngrok_token:
        parsed_url = urlparse(url)
        # `urlparse(...).port` is None when the URL has no explicit port;
        # fall back to the scheme's default port (443 for https, else 80)
        # instead of always assuming 80.
        port = parsed_url.port or (
            443 if parsed_url.scheme == "https" else 80
        )
        ngrok_url = get_or_create_ngrok_tunnel(
            ngrok_token=ngrok_token, port=port
        )
        logger.debug(f"Tunneling dashboard from {url} to {ngrok_url}.")
        url = ngrok_url

    _open_dashboard(url)
verify_admin_status_if_no_rbac(admin_status, action=None)
Validate the admin status for sensitive requests.
Only add this check in endpoints meant for admin use only.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
admin_status |
Optional[bool] |
Whether the user is an admin or not. It is only used if explicitly passed, and it is ignored whenever RBAC is enabled. |
required |
action |
Optional[str] |
The action that is being performed, used for output only. |
None |
Exceptions:
Type | Description |
---|---|
IllegalOperationError |
If the admin status is not valid. |
Source code in zenml/zen_server/utils.py
def verify_admin_status_if_no_rbac(
    admin_status: Optional[bool],
    action: Optional[str] = None,
) -> None:
    """Validate the admin status for sensitive requests.

    Only add this check in endpoints meant for admin use only. Nothing is
    checked when RBAC is enabled on the server. Otherwise, a caller whose
    `admin_status` is explicitly `False` is rejected.

    Args:
        admin_status: Whether the user is an admin or not. It is only used
            if explicitly passed, and it is ignored whenever RBAC is
            enabled.
        action: The action that is being performed, used for output only.

    Raises:
        IllegalOperationError: If the admin status is not valid.
    """
    if server_config().rbac_enabled:
        return

    # Wrap the action name in backticks for the message, first stripping any
    # surrounding backticks the caller already supplied.
    action_label = f"`{action.strip('`')}`" if action else "this action"
    if admin_status is False:
        raise IllegalOperationError(
            message=f"Only admin users can perform {action_label} "
            "without RBAC enabled.",
        )
workload_manager()
Return the initialized workload manager component.
Exceptions:
Type | Description |
---|---|
RuntimeError |
If the workload manager component is not initialized. |
Returns:
Type | Description |
---|---|
WorkloadManagerInterface |
The workload manager component. |
Source code in zenml/zen_server/utils.py
def workload_manager() -> WorkloadManagerInterface:
    """Get the initialized workload manager component.

    Raises:
        RuntimeError: If the workload manager component is not initialized.

    Returns:
        The workload manager component.
    """
    # The module-level component is only read here, never reassigned.
    manager = _workload_manager
    if manager is None:
        raise RuntimeError("Workload manager component not initialized")
    return manager
zen_store()
Return the initialized ZenML Store.
Returns:
Type | Description |
---|---|
SqlZenStore |
The ZenML Store. |
Exceptions:
Type | Description |
---|---|
RuntimeError |
If the ZenML Store has not been initialized. |
Source code in zenml/zen_server/utils.py
def zen_store() -> "SqlZenStore":
    """Return the initialized ZenML Store.

    This accessor does not create the store; it only returns the instance
    already stored at module level.

    Returns:
        The ZenML Store.

    Raises:
        RuntimeError: If the ZenML Store has not been initialized.
    """
    store = _zen_store
    if store is None:
        raise RuntimeError("ZenML Store not initialized")
    return store
zen_server_api
Zen Server API.
To run this file locally, execute:
```
uvicorn zenml.zen_server.zen_server_api:app --reload
```
RequestBodyLimit (BaseHTTPMiddleware)
Limits the size of the request body.
Source code in zenml/zen_server/zen_server_api.py
class RequestBodyLimit(BaseHTTPMiddleware):
    """Limits the size of the request body."""

    def __init__(self, app: ASGIApp, max_bytes: int) -> None:
        """Limits the size of the request body.

        Args:
            app: The FastAPI app.
            max_bytes: The maximum size of the request body.
        """
        super().__init__(app)
        self.max_bytes = max_bytes

    async def dispatch(
        self, request: Request, call_next: RequestResponseEndpoint
    ) -> Response:
        """Limits the size of the request body.

        Only the declared `Content-Length` header is checked. A declared
        size above `max_bytes` is rejected with 413. A header that is not a
        non-negative integer is rejected with 400. Requests without a
        `Content-Length` header are passed through unchecked.

        Args:
            request: The incoming request.
            call_next: The next function to be called.

        Returns:
            The response to the request.
        """
        if content_length := request.headers.get("content-length"):
            try:
                declared_size = int(content_length)
            except ValueError:
                # Reject a malformed header explicitly instead of letting
                # int() raise out of the middleware.
                declared_size = -1
            if declared_size < 0:
                return Response(status_code=400)  # Bad Request
            if declared_size > self.max_bytes:
                return Response(status_code=413)  # Request Entity Too Large
        try:
            return await call_next(request)
        except Exception:
            logger.exception("An error occurred while processing the request")
            return JSONResponse(
                status_code=500,
                content={"detail": "An unexpected error occurred."},
            )
__init__(self, app, max_bytes)
special
Limits the size of the request body.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
app |
Callable[[MutableMapping[str, Any], Callable[[], Awaitable[MutableMapping[str, Any]]], Callable[[MutableMapping[str, Any]], Awaitable[NoneType]]], Awaitable[NoneType]] |
The FastAPI app. |
required |
max_bytes |
int |
The maximum size of the request body. |
required |
Source code in zenml/zen_server/zen_server_api.py
def __init__(self, app: ASGIApp, max_bytes: int) -> None:
    """Set up the middleware around `app` with a request body size limit.

    Args:
        app: The FastAPI app.
        max_bytes: The maximum size of the request body.
    """
    # Let the base middleware wire itself to the wrapped app first.
    super().__init__(app)
    # Upper bound compared against the declared Content-Length.
    self.max_bytes = max_bytes
dispatch(self, request, call_next)
async
Limits the size of the request body.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
request |
Request |
The incoming request. |
required |
call_next |
Callable[[starlette.requests.Request], Awaitable[starlette.responses.Response]] |
The next function to be called. |
required |
Returns:
Type | Description |
---|---|
Response |
The response to the request. |
Source code in zenml/zen_server/zen_server_api.py
async def dispatch(
    self, request: Request, call_next: RequestResponseEndpoint
) -> Response:
    """Limits the size of the request body.

    Only the declared `Content-Length` header is checked. A declared size
    above `max_bytes` is rejected with 413. A header that is not a
    non-negative integer is rejected with 400. Requests without a
    `Content-Length` header are passed through unchecked.

    Args:
        request: The incoming request.
        call_next: The next function to be called.

    Returns:
        The response to the request.
    """
    if content_length := request.headers.get("content-length"):
        try:
            declared_size = int(content_length)
        except ValueError:
            # Reject a malformed header explicitly instead of letting int()
            # raise out of the middleware.
            declared_size = -1
        if declared_size < 0:
            return Response(status_code=400)  # Bad Request
        if declared_size > self.max_bytes:
            return Response(status_code=413)  # Request Entity Too Large
    try:
        return await call_next(request)
    except Exception:
        logger.exception("An error occurred while processing the request")
        return JSONResponse(
            status_code=500,
            content={"detail": "An unexpected error occurred."},
        )
RestrictFileUploadsMiddleware (BaseHTTPMiddleware)
Restrict file uploads to certain paths.
Source code in zenml/zen_server/zen_server_api.py
class RestrictFileUploadsMiddleware(BaseHTTPMiddleware):
    """Restrict file uploads to certain paths."""

    def __init__(self, app: FastAPI, allowed_paths: Set[str]):
        """Restrict file uploads to certain paths.

        Args:
            app: The FastAPI app.
            allowed_paths: The allowed paths.
        """
        super().__init__(app)
        self.allowed_paths = allowed_paths

    async def dispatch(
        self, request: Request, call_next: RequestResponseEndpoint
    ) -> Response:
        """Restrict file uploads to certain paths.

        POST requests with a `multipart/form-data` content type are rejected
        with 403 unless the request path is in `allowed_paths`.

        Args:
            request: The incoming request.
            call_next: The next function to be called.

        Returns:
            The response to the request.
        """
        if request.method == "POST":
            # Media types are case-insensitive, so normalize before matching;
            # otherwise e.g. "Multipart/Form-Data" slips past this check.
            content_type = request.headers.get("content-type", "").lower()
            if (
                "multipart/form-data" in content_type
                and request.url.path not in self.allowed_paths
            ):
                return JSONResponse(
                    status_code=403,
                    content={
                        "detail": "File uploads are not allowed on this endpoint."
                    },
                )
        try:
            return await call_next(request)
        except Exception:
            logger.exception("An error occurred while processing the request")
            return JSONResponse(
                status_code=500,
                content={"detail": "An unexpected error occurred."},
            )
__init__(self, app, allowed_paths)
special
Restrict file uploads to certain paths.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
app |
FastAPI |
The FastAPI app. |
required |
allowed_paths |
Set[str] |
The allowed paths. |
required |
Source code in zenml/zen_server/zen_server_api.py
def __init__(self, app: FastAPI, allowed_paths: Set[str]):
    """Set up the middleware with the set of paths that accept uploads.

    Args:
        app: The FastAPI app.
        allowed_paths: The allowed paths.
    """
    # Let the base middleware wire itself to the wrapped app first.
    super().__init__(app)
    # Request paths for which multipart uploads are permitted.
    self.allowed_paths = allowed_paths
dispatch(self, request, call_next)
async
Restrict file uploads to certain paths.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
request |
Request |
The incoming request. |
required |
call_next |
Callable[[starlette.requests.Request], Awaitable[starlette.responses.Response]] |
The next function to be called. |
required |
Returns:
Type | Description |
---|---|
Response |
The response to the request. |
Source code in zenml/zen_server/zen_server_api.py
async def dispatch(
    self, request: Request, call_next: RequestResponseEndpoint
) -> Response:
    """Restrict file uploads to certain paths.

    POST requests with a `multipart/form-data` content type are rejected
    with 403 unless the request path is in `allowed_paths`.

    Args:
        request: The incoming request.
        call_next: The next function to be called.

    Returns:
        The response to the request.
    """
    if request.method == "POST":
        # Media types are case-insensitive, so normalize before matching;
        # otherwise e.g. "Multipart/Form-Data" slips past this check.
        content_type = request.headers.get("content-type", "").lower()
        if (
            "multipart/form-data" in content_type
            and request.url.path not in self.allowed_paths
        ):
            return JSONResponse(
                status_code=403,
                content={
                    "detail": "File uploads are not allowed on this endpoint."
                },
            )
    try:
        return await call_next(request)
    except Exception:
        logger.exception("An error occurred while processing the request")
        return JSONResponse(
            status_code=500,
            content={"detail": "An unexpected error occurred."},
        )
catch_all(request, file_path)
async
Dashboard endpoint.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
request |
Request |
Request object. |
required |
file_path |
str |
Path to a file in the dashboard root folder. |
required |
Returns:
Type | Description |
---|---|
Any |
The ZenML dashboard. |
Source code in zenml/zen_server/zen_server_api.py
@app.get("/{file_path:path}", include_in_schema=False)
async def catch_all(request: Request, file_path: str) -> Any:
    """Dashboard endpoint for every path not matched by another route.

    Args:
        request: Request object.
        file_path: Path to a file in the dashboard root folder.

    Returns:
        The ZenML dashboard.
    """
    if DASHBOARD_REDIRECT_URL:
        return RedirectResponse(url=DASHBOARD_REDIRECT_URL)

    # Only names listed in `root_static_files` are served straight from the
    # root dashboard directory; the membership test also keeps arbitrary
    # paths from reaching the file system.
    is_root_static_file = bool(file_path) and file_path in root_static_files
    if not is_root_static_file:
        # Everything else goes to the single-page application entry point.
        return templates.TemplateResponse("index.html", {"request": request})

    logger.debug(f"Returning static file: {file_path}")
    dashboard_root = relative_path(DASHBOARD_DIRECTORY)
    return FileResponse(os.path.join(dashboard_root, file_path))
dashboard(request)
async
Dashboard endpoint.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
request |
Request |
Request object. |
required |
Returns:
Type | Description |
---|---|
Any |
The ZenML dashboard. |
Exceptions:
Type | Description |
---|---|
HTTPException |
If the dashboard files are not included. |
Source code in zenml/zen_server/zen_server_api.py
@app.get("/", include_in_schema=False)
async def dashboard(request: Request) -> Any:
    """Dashboard endpoint for the root URL.

    Args:
        request: Request object.

    Returns:
        The ZenML dashboard.

    Raises:
        HTTPException: If the dashboard files are not included.
    """
    if DASHBOARD_REDIRECT_URL:
        return RedirectResponse(url=DASHBOARD_REDIRECT_URL)

    index_file = os.path.join(relative_path(DASHBOARD_DIRECTORY), "index.html")
    if os.path.isfile(index_file):
        return templates.TemplateResponse("index.html", {"request": request})
    # No index.html in the dashboard directory: nothing to serve.
    raise HTTPException(status_code=404)
get_root_static_files()
Get the list of static files in the root dashboard directory.
These are static files that live outside the /static subdirectory but still need to be served under the root URL path.
Returns:
Type | Description |
---|---|
List[str] |
List of static files in the root directory. |
Source code in zenml/zen_server/zen_server_api.py
def get_root_static_files() -> List[str]:
    """Get the list of static files in the root dashboard directory.

    These are static files that live outside the /static subdirectory but
    still need to be served under the root URL path.

    Returns:
        List of static files in the root directory, or an empty list if the
        dashboard directory does not exist.
    """
    dashboard_root = relative_path(DASHBOARD_DIRECTORY)
    if not os.path.isdir(dashboard_root):
        return []
    # `index.html` is excluded because it is served separately;
    # subdirectories are skipped because only regular files qualify.
    return [
        entry
        for entry in os.listdir(dashboard_root)
        if entry != "index.html"
        and isfile(os.path.join(dashboard_root, entry))
    ]
health()
async
Get health status of the server.
Returns:
Type | Description |
---|---|
str |
String representing the health status of the server. |
Source code in zenml/zen_server/zen_server_api.py
@app.head(HEALTH, include_in_schema=False)
@app.get(HEALTH)
async def health() -> str:
    """Report the health status of the server.

    Returns:
        The constant string "OK" whenever the server can answer requests.
    """
    status = "OK"
    return status
infer_source_context(request, call_next)
async
A middleware to track the source of an event.
It extracts the source context from the header of incoming requests and applies it to the ZenML source context on the API side. This way, the outgoing analytics request can append it as an additional field.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
request |
Request |
the incoming request object. |
required |
call_next |
Any |
a function that will receive the request as a parameter and pass it to the corresponding path operation. |
required |
Returns:
Type | Description |
---|---|
Any |
the response to the request. |
Source code in zenml/zen_server/zen_server_api.py
def _apply_source_context_from_headers(request: Request) -> None:
    """Set the ZenML source context from the incoming request headers.

    A missing header resolves to the API source context. A value that
    cannot be interpreted is logged as a warning and also falls back to
    the API source context.

    Args:
        request: the incoming request object.
    """
    try:
        raw_context = request.headers.get(
            source_context.name,
            default=SourceContextTypes.API.value,
        )
        source_context.set(SourceContextTypes(raw_context))
    except Exception as e:
        logger.warning(
            f"An unexpected error occurred while getting the source "
            f"context: {e}"
        )
        source_context.set(SourceContextTypes.API)


@app.middleware("http")
async def infer_source_context(request: Request, call_next: Any) -> Any:
    """A middleware to track the source of an event.

    The source context sent in the request headers is applied to the ZenML
    source context on the API side, so that outgoing analytics requests can
    include it as an additional field.

    Args:
        request: the incoming request object.
        call_next: a function that will receive the request as a parameter and
            pass it to the corresponding path operation.

    Returns:
        the response to the request.
    """
    _apply_source_context_from_headers(request)

    try:
        return await call_next(request)
    except Exception:
        logger.exception("An error occurred while processing the request")
        return JSONResponse(
            status_code=500,
            content={"detail": "An unexpected error occurred."},
        )
initialize()
Initialize the ZenML server.
Source code in zenml/zen_server/zen_server_api.py
@app.on_event("startup")
def initialize() -> None:
    """Initialize the ZenML server.

    Registered for the FastAPI "startup" event. It sizes the worker thread
    pool from the server configuration and then initializes the server-wide
    components in a fixed order.
    """
    cfg = server_config()
    # Set the maximum number of worker threads
    # (the total tokens of anyio's default thread limiter bound how many
    # worker threads can run blocking work concurrently).
    to_thread.current_default_thread_limiter().total_tokens = (
        cfg.thread_pool_size
    )
    # IMPORTANT: these need to be run before the fastapi app starts, to avoid
    # race conditions
    # NOTE(review): keep the call order below as-is; later initializers may
    # depend on components set up by earlier ones (e.g. the zen store) —
    # confirm before reordering.
    initialize_zen_store()
    initialize_rbac()
    initialize_feature_gate()
    initialize_workload_manager()
    initialize_plugins()
    initialize_secure_headers()
    initialize_memcache(cfg.memcache_max_capacity, cfg.memcache_default_expiry)
invalid_api(invalid_api_path)
async
Invalid API endpoint.
All API endpoints that are not defined in the API routers will be redirected to this endpoint and will return a 404 error.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
invalid_api_path |
str |
Invalid API path. |
required |
Exceptions:
Type | Description |
---|---|
HTTPException |
404 error. |
Source code in zenml/zen_server/zen_server_api.py
@app.get(
    API + "/{invalid_api_path:path}", status_code=404, include_in_schema=False
)
async def invalid_api(invalid_api_path: str) -> None:
    """Invalid API endpoint.

    Any path under the API prefix that no API router handles is routed here
    and answered with a 404 error.

    Args:
        invalid_api_path: Invalid API path.

    Raises:
        HTTPException: 404 error.
    """
    not_found = HTTPException(status_code=404)
    logger.debug(f"Invalid API path requested: {invalid_api_path}")
    raise not_found
relative_path(rel)
Get the absolute path of a path relative to the ZenML server module.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
rel |
str |
Relative path. |
required |
Returns:
Type | Description |
---|---|
str |
Absolute path. |
Source code in zenml/zen_server/zen_server_api.py
def relative_path(rel: str) -> str:
    """Resolve `rel` against the directory containing this module.

    Args:
        rel: Relative path.

    Returns:
        Absolute path.
    """
    module_dir = os.path.dirname(__file__)
    return os.path.join(module_dir, rel)
set_secure_headers(request, call_next)
async
Middleware to set secure headers.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
request |
Request |
The incoming request. |
required |
call_next |
Any |
The next function to be called. |
required |
Returns:
Type | Description |
---|---|
Any |
The response with secure headers set. |
Source code in zenml/zen_server/zen_server_api.py
@app.middleware("http")
async def set_secure_headers(request: Request, call_next: Any) -> Any:
    """Middleware that adds secure headers to outgoing responses.

    Responses for paths starting with "/docs" or "/redoc" (the OpenAPI docs
    pages) are returned without secure headers.

    Args:
        request: The incoming request.
        call_next: The next function to be called.

    Returns:
        The response with secure headers set.
    """
    try:
        response = await call_next(request)
    except Exception:
        logger.exception("An error occurred while processing the request")
        response = JSONResponse(
            status_code=500,
            content={"detail": "An unexpected error occurred."},
        )

    # NOTE(review): docs pages are presumably excluded because the secure
    # header policy would interfere with the docs UI — confirm.
    is_docs_page = request.url.path.startswith(("/docs", "/redoc"))
    if not is_docs_page:
        secure_headers().framework.fastapi(response)
    return response
track_last_user_activity(request, call_next)
async
A middleware to track last user activity.
This middleware checks if the incoming request is a user request and updates the last activity timestamp if it is.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
request |
Request |
The incoming request object. |
required |
call_next |
Any |
A function that will receive the request as a parameter and pass it to the corresponding path operation. |
required |
Returns:
Type | Description |
---|---|
Any |
The response to the request. |
Source code in zenml/zen_server/zen_server_api.py
@app.middleware("http")
async def track_last_user_activity(request: Request, call_next: Any) -> Any:
    """A middleware to track last user activity.

    This middleware checks if the incoming request is a user request and
    updates the in-memory last activity timestamp if it is. Once more than
    `DEFAULT_ZENML_SERVER_REPORT_USER_ACTIVITY_TO_DB_SECONDS` have elapsed
    since the last report, the timestamp is also written to the store. Both
    steps are best-effort: failures are logged and never fail the request.

    Args:
        request: The incoming request object.
        call_next: A function that will receive the request as a parameter and
            pass it to the corresponding path operation.

    Returns:
        The response to the request.
    """
    global last_user_activity
    global last_user_activity_reported

    try:
        if is_user_request(request):
            last_user_activity = datetime.now(timezone.utc)
    except Exception as e:
        logger.debug(
            f"An unexpected error occurred while checking user activity: {e}"
        )

    now = datetime.now(timezone.utc)
    if (
        now - last_user_activity_reported
    ).total_seconds() > DEFAULT_ZENML_SERVER_REPORT_USER_ACTIVITY_TO_DB_SECONDS:
        # Update the report time first so a failing store is retried at most
        # once per reporting interval rather than on every request.
        last_user_activity_reported = now
        try:
            zen_store()._update_last_user_activity_timestamp(
                last_user_activity=last_user_activity
            )
        except Exception as e:
            # Activity reporting is best-effort: a store failure (or a store
            # that is not initialized yet) must not fail this request.
            logger.warning(
                f"Failed to report the last user activity to the store: {e}"
            )

    try:
        return await call_next(request)
    except Exception:
        logger.exception("An error occurred while processing the request")
        return JSONResponse(
            status_code=500,
            content={"detail": "An unexpected error occurred."},
        )
validation_exception_handler(request, exc)
Custom validation exception handler.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
request |
Any |
The request. |
required |
exc |
Exception |
The exception. |
required |
Returns:
Type | Description |
---|---|
ORJSONResponse |
The error response formatted using the ZenML API conventions. |
Source code in zenml/zen_server/zen_server_api.py
@app.exception_handler(RequestValidationError)
def validation_exception_handler(
    request: Any, exc: Exception
) -> ORJSONResponse:
    """Turn request validation errors into 422 responses.

    Args:
        request: The request.
        exc: The exception.

    Returns:
        The error response formatted using the ZenML API conventions.
    """
    body = error_detail(exc, ValueError)
    return ORJSONResponse(body, status_code=422)