Zen Server
zenml.zen_server
special
ZenML Server Implementation.
The ZenML Server is a centralized service meant for use in a collaborative setting in which stacks, stack components, flavors, pipelines and pipeline runs can be shared over the network with other users.
You can use the `zenml server up`
command to spin up ZenML server instances
that are either running locally as daemon processes or Docker containers, or
to deploy a ZenML server remotely on a managed cloud platform. The other CLI
commands in the same `zenml server`
group can be used to manage the server
instances deployed from your local machine.
To connect the local ZenML client to one of the managed ZenML servers, call
`zenml server connect`
with the name of the server you want to connect to.
auth
Authentication module for ZenML server.
AuthContext (BaseModel)
pydantic-model
The authentication context.
Source code in zenml/zen_server/auth.py
class AuthContext(BaseModel):
    """The authentication context.

    Carries the authenticated user. When the request was authenticated with
    an access token, it also carries the decoded token, its encoded form and
    the device the token was issued for (if any).
    """

    user: UserResponseModel
    access_token: Optional[JWTToken] = None
    encoded_access_token: Optional[str] = None
    device: Optional[OAuthDeviceInternalResponseModel] = None

    @property
    def permissions(self) -> Set[PermissionType]:
        """Returns the permissions of the user.

        The result is the union of the permissions granted by every role
        assigned to the user, so duplicates across roles collapse.

        Returns:
            The permissions of the user (empty if the user has no roles).
        """
        user_roles = self.user.roles
        if not user_roles:
            return set()
        return {
            permission
            for user_role in user_roles
            for permission in user_role.permissions
        }
permissions: Set[zenml.enums.PermissionType]
property
readonly
Returns the permissions of the user.
Returns:
Type | Description |
---|---|
Set[zenml.enums.PermissionType] |
The permissions of the user. |
CookieOAuth2TokenBearer (OAuth2PasswordBearer)
OAuth2 token bearer authentication scheme that uses a cookie.
Source code in zenml/zen_server/auth.py
class CookieOAuth2TokenBearer(OAuth2PasswordBearer):
    """OAuth2 token bearer authentication scheme that uses a cookie."""

    async def __call__(self, request: Request) -> Optional[str]:
        """Extract the bearer token from the request.

        The server's auth cookie takes precedence. The parent
        ``OAuth2PasswordBearer`` (which reads the ``Authorization`` header)
        is consulted only when the cookie is missing or empty.

        Args:
            request: The request.

        Returns:
            The bearer token extracted from the request cookie or header.
        """
        cookie_name = server_config().get_auth_cookie_name()
        cookie_token = request.cookies.get(cookie_name)
        if not cookie_token:
            # Fall back to the standard Authorization header handling.
            return await super().__call__(request)

        logger.info("Got token from cookie")
        return cookie_token
__call__(self, request)
async
special
Extract the bearer token from the request.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
request |
Request |
The request. |
required |
Returns:
Type | Description |
---|---|
Optional[str] |
The bearer token extracted from the request cookie or header. |
Source code in zenml/zen_server/auth.py
async def __call__(self, request: Request) -> Optional[str]:
    """Extract the bearer token from the request.

    The server's auth cookie takes precedence. The parent
    ``OAuth2PasswordBearer`` (which reads the ``Authorization`` header) is
    consulted only when the cookie is missing or empty.

    Args:
        request: The request.

    Returns:
        The bearer token extracted from the request cookie or header.
    """
    cookie_name = server_config().get_auth_cookie_name()
    cookie_token = request.cookies.get(cookie_name)
    if not cookie_token:
        # Fall back to the standard Authorization header handling.
        return await super().__call__(request)

    logger.info("Got token from cookie")
    return cookie_token
authenticate_credentials(user_name_or_id=None, password=None, access_token=None, activation_token=None)
Verify if user authentication credentials are valid.
This function can be used to validate all supplied user credentials to cover a range of possibilities:
- username only - only when the no-auth scheme is used
- username+password - for basic HTTP authentication or the OAuth2 password grant
- access token (with embedded user id) - after successful authentication using one of the supported grants
- username+activation token - for user activation
Parameters:
Name | Type | Description | Default |
---|---|---|---|
user_name_or_id |
Union[str, uuid.UUID] |
The username or user ID. |
None |
password |
Optional[str] |
The password. |
None |
access_token |
Optional[str] |
The access token. |
None |
activation_token |
Optional[str] |
The activation token. |
None |
Returns:
Type | Description |
---|---|
AuthContext |
The authenticated account details. |
Exceptions:
Type | Description |
---|---|
AuthorizationException |
If the credentials are invalid. |
Source code in zenml/zen_server/auth.py
def authenticate_credentials(
    user_name_or_id: Optional[Union[str, UUID]] = None,
    password: Optional[str] = None,
    access_token: Optional[str] = None,
    activation_token: Optional[str] = None,
) -> AuthContext:
    """Verify if user authentication credentials are valid.

    This function can be used to validate all supplied user credentials to
    cover a range of possibilities:

    * username only - only when the no-auth scheme is used
    * username+password - for basic HTTP authentication or the OAuth2 password
      grant
    * access token (with embedded user id) - after successful authentication
      using one of the supported grants
    * username+activation token - for user activation

    Only one secret is checked per call. The first non-None value among
    `password`, `access_token` and `activation_token` (in that order)
    selects the verification branch.

    Args:
        user_name_or_id: The username or user ID.
        password: The password.
        access_token: The access token.
        activation_token: The activation token.

    Returns:
        The authenticated account details.

    Raises:
        AuthorizationException: If the credentials are invalid.
    """
    # `user` is the auth model used for password / activation token checks.
    # It stays None when the user is unknown, and the verify_* helpers are
    # still called with None (see the comment in the except branch below).
    user: Optional[UserAuthModel] = None
    auth_context: Optional[AuthContext] = None
    if user_name_or_id:
        try:
            user = zen_store().get_auth_user(user_name_or_id)
            user_model = zen_store().get_user(
                user_name_or_id=user_name_or_id, include_private=True
            )
            # Provisional context. It is returned as-is only if no secret
            # was supplied (username-only / no-auth case) or the secret
            # checked below is valid.
            auth_context = AuthContext(user=user_model)
        except KeyError:
            # even when the user does not exist, we still want to execute the
            # password/token verification to protect against response discrepancy
            # attacks (https://cwe.mitre.org/data/definitions/204.html)
            logger.exception(
                f"Authentication error: error retrieving user "
                f"{user_name_or_id}"
            )
            pass

    if password is not None:
        # NOTE(review): no `active` check is made in this branch; presumably
        # `verify_password` rejects inactive/None users - confirm.
        if not UserAuthModel.verify_password(password, user):
            error = "Authentication error: invalid username or password"
            logger.error(error)
            raise AuthorizationException(error)

    elif access_token is not None:
        # Access tokens embed the user ID, so `user_name_or_id` is not
        # needed here; the user is re-loaded from the decoded token.
        try:
            decoded_token = JWTToken.decode_token(
                token=access_token,
            )
        except AuthorizationException:
            error = "Authentication error: error decoding access token"
            logger.exception(error)
            raise AuthorizationException(error)

        try:
            user_model = zen_store().get_user(
                user_name_or_id=decoded_token.user_id, include_private=True
            )
        except KeyError:
            error = (
                f"Authentication error: error retrieving token user "
                f"{decoded_token.user_id}"
            )
            logger.error(error)
            raise AuthorizationException(error)

        if not user_model.active:
            error = (
                f"Authentication error: user {decoded_token.user_id} is not "
                f"active"
            )
            logger.error(error)
            raise AuthorizationException(error)

        device_model: Optional[OAuthDeviceInternalResponseModel] = None
        if decoded_token.device_id:
            # Access tokens that have been issued for a device are only valid
            # for that device, so we need to check if the device ID matches any
            # of the valid devices in the database.
            try:
                device_model = zen_store().get_internal_authorized_device(
                    device_id=decoded_token.device_id
                )
            except KeyError:
                error = (
                    f"Authentication error: error retrieving token device "
                    f"{decoded_token.device_id}"
                )
                logger.error(error)
                raise AuthorizationException(error)

            # The device must be owned by the token's user...
            if (
                device_model.user is None
                or device_model.user.id != user_model.id
            ):
                error = (
                    f"Authentication error: device {decoded_token.device_id} "
                    f"does not belong to user {user_model.id}"
                )
                logger.error(error)
                raise AuthorizationException(error)

            # ...be in the ACTIVE state...
            if device_model.status != OAuthDeviceStatus.ACTIVE:
                error = (
                    f"Authentication error: device {decoded_token.device_id} "
                    f"is not active"
                )
                logger.error(error)
                raise AuthorizationException(error)

            # ...and not be past its expiration date (a falsy `expires`
            # skips this check).
            if (
                device_model.expires
                and datetime.utcnow() >= device_model.expires
            ):
                error = (
                    f"Authentication error: device {decoded_token.device_id} "
                    "has expired"
                )
                logger.error(error)
                raise AuthorizationException(error)

            # Record the successful use of the device.
            zen_store().update_internal_authorized_device(
                device_id=device_model.id,
                update=OAuthDeviceInternalUpdateModel(
                    update_last_login=True,
                ),
            )

        auth_context = AuthContext(
            user=user_model,
            access_token=decoded_token,
            encoded_access_token=access_token,
            device=device_model,
        )

    elif activation_token is not None:
        if not UserAuthModel.verify_activation_token(activation_token, user):
            error = (
                f"Authentication error: invalid activation token for user "
                f"{user_name_or_id}"
            )
            logger.error(error)
            raise AuthorizationException(error)

    # Reached with no context when the username lookup failed and no check
    # above raised, or when no credentials were supplied at all.
    if not auth_context:
        error = "Authentication error: invalid credentials"
        logger.error(error)
        raise AuthorizationException(error)

    return auth_context
authenticate_device(client_id, device_code)
Verify if device authorization credentials are valid.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
client_id |
UUID |
The OAuth2 client ID. |
required |
device_code |
str |
The device code. |
required |
Returns:
Type | Description |
---|---|
AuthContext |
The authenticated account details. |
Exceptions:
Type | Description |
---|---|
OAuthError |
If the device authorization credentials are invalid. |
Source code in zenml/zen_server/auth.py
def authenticate_device(client_id: UUID, device_code: str) -> AuthContext:
    """Verify if device authorization credentials are valid.

    This is the part of the OAuth2 device code grant flow where a client
    device is continuously polling the server to check if the user has
    authorized it. On success the device is switched to the ACTIVE state,
    given an expiration date, and an authentication context for the
    associated user is returned so it can be exchanged for an access token.

    Args:
        client_id: The OAuth2 client ID.
        device_code: The device code.

    Returns:
        The authenticated account details.

    Raises:
        OAuthError: If the device authorization credentials are invalid.
    """
    # The following needs to happen to successfully authenticate the device
    # and return a valid access token:
    #
    # 1. the device code and client ID must match a device in the DB
    # 2. the device must be in the VERIFIED state, meaning that the user
    #    has successfully authorized the device via the user code but the
    #    device client hasn't fetched the associated API access token yet.
    # 3. the device must not be expired
    config = server_config()
    store = zen_store()

    try:
        device_model = store.get_internal_authorized_device(
            client_id=client_id
        )
    except KeyError:
        error = (
            f"Authentication error: error retrieving device with client ID "
            f"{client_id}"
        )
        logger.error(error)
        raise OAuthError(
            error="invalid_client",
            error_description=error,
        )

    if device_model.status != OAuthDeviceStatus.VERIFIED:
        error = (
            f"Authentication error: device with client ID {client_id} is "
            f"{device_model.status.value}."
        )
        logger.error(error)
        # Map the device state onto an OAuth2 device-grant error code:
        # PENDING -> keep polling, LOCKED -> denied, anything else (e.g. a
        # device that is already ACTIVE) -> treated as an expired code.
        if device_model.status == OAuthDeviceStatus.PENDING:
            oauth_error = "authorization_pending"
        elif device_model.status == OAuthDeviceStatus.LOCKED:
            oauth_error = "access_denied"
        else:
            oauth_error = "expired_token"
        raise OAuthError(
            error=oauth_error,
            error_description=error,
        )

    if device_model.expires and datetime.utcnow() >= device_model.expires:
        error = (
            f"Authentication error: device for client ID {client_id} has "
            "expired"
        )
        logger.error(error)
        raise OAuthError(
            error="expired_token",
            error_description=error,
        )

    # Check the device code
    if not device_model.verify_device_code(device_code):
        # If the device code is invalid, increment the failed auth attempts
        # counter and lock the device if the maximum number of failed auth
        # attempts has been reached.
        failed_auth_attempts = device_model.failed_auth_attempts + 1
        lock_device = (
            failed_auth_attempts >= config.max_failed_device_auth_attempts
        )
        update = OAuthDeviceInternalUpdateModel(
            failed_auth_attempts=failed_auth_attempts
        )
        if lock_device:
            update.locked = True
        store.update_internal_authorized_device(
            device_id=device_model.id,
            update=update,
        )
        if lock_device:
            error = (
                f"Authentication error: device for client ID {client_id} "
                "has been locked due to too many failed authentication "
                "attempts."
            )
        else:
            error = (
                f"Authentication error: device for client ID {client_id} "
                "has an invalid device code."
            )
        logger.error(error)
        raise OAuthError(
            error="access_denied",
            error_description=error,
        )

    # The device is valid, so we can return the user associated with it.
    # This is the one and only time we return an AuthContext authorized by
    # a device code in order to be exchanged for an access token. Subsequent
    # requests to the API will be authenticated using the access token.
    #
    # Update the device state to ACTIVE and set an expiration date for it
    # past which it can no longer be used for authentication. The expiration
    # date also determines the expiration date of the access token issued
    # for this device. A device expiration is only applied when JWT tokens
    # are configured to expire; otherwise `expires_in` stays 0
    # (NOTE(review): presumably meaning "no expiration" - confirm in store).
    expires_in: int = 0
    if config.jwt_token_expire_minutes:
        if device_model.trusted_device:
            expires_in = config.trusted_device_expiration_minutes or 0
        else:
            expires_in = config.device_expiration_minutes or 0

    update = OAuthDeviceInternalUpdateModel(
        status=OAuthDeviceStatus.ACTIVE,
        expires_in=expires_in * 60,
    )
    device_model = store.update_internal_authorized_device(
        device_id=device_model.id,
        update=update,
    )

    # This can never happen because the VERIFIED state is only set if
    # a user verified and has been associated with the device.
    assert device_model.user is not None

    return AuthContext(user=device_model.user, device=device_model)
authenticate_external_user(external_access_token)
Implement external authentication.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
external_access_token |
str |
The access token used to authenticate the user to the external authenticator. |
required |
Returns:
Type | Description |
---|---|
AuthContext |
The authentication context reflecting the authenticated user. |
Exceptions:
Type | Description |
---|---|
AuthorizationException |
If the external user could not be authorized. |
Source code in zenml/zen_server/auth.py
def _fetch_external_user(external_access_token: str) -> ExternalUserModel:
    """Fetch and parse the user information from the external authenticator.

    Args:
        external_access_token: The access token used to authenticate the user
            to the external authenticator.

    Returns:
        The external user described by the authenticator's response.

    Raises:
        AuthorizationException: If the request fails, the authenticator
            rejects the token or this server, or the response cannot be
            parsed into a user.
    """
    config = server_config()
    assert config.external_user_info_url is not None

    # Build the request outside the `try` so that configuration errors are
    # not reported as authenticator errors.
    headers = {"Authorization": f"Bearer {external_access_token}"}
    query_params = {"server_id": str(config.get_external_server_id())}

    try:
        auth_response = requests.get(
            config.external_user_info_url,
            headers=headers,
            params=urlencode(query_params),
            timeout=EXTERNAL_AUTHENTICATOR_TIMEOUT,
        )
    except Exception as e:
        logger.exception(
            f"Error fetching user information from external authenticator: "
            f"{e}"
        )
        raise AuthorizationException(
            "Error fetching user information from external authenticator."
        ) from e

    status_code = auth_response.status_code
    if status_code in (401, 403):
        raise AuthorizationException("Not authorized to access this server.")
    if status_code == 404:
        raise AuthorizationException(
            "External authenticator did not recognize this server."
        )
    if not 200 <= status_code < 300:
        logger.error(
            f"Error fetching user information from external authenticator. "
            f"Status code: {status_code}, "
            f"Response: {auth_response.text}"
        )
        raise AuthorizationException(
            "Error fetching user information from external authenticator. "
        )

    try:
        payload = auth_response.json()
    except requests.exceptions.JSONDecodeError as e:
        logger.exception(
            "Error decoding JSON response from external authenticator."
        )
        raise AuthorizationException(
            "Unknown external authenticator error"
        ) from e

    if isinstance(payload, dict):
        try:
            return ExternalUserModel.parse_obj(payload)
        except Exception as e:
            # Fall through to the generic error below after logging details.
            logger.exception(
                f"Error parsing user information from external "
                f"authenticator: {e}"
            )

    raise AuthorizationException("Unknown external authenticator error")


def authenticate_external_user(external_access_token: str) -> AuthContext:
    """Implement external authentication.

    The external access token is used to fetch the user's information from
    the configured external authenticator. The matching ZenML user (looked
    up by external user ID) is then updated, or created if it does not exist
    yet. Newly created users are assigned the store's admin role.

    Args:
        external_access_token: The access token used to authenticate the user
            to the external authenticator.

    Returns:
        The authentication context reflecting the authenticated user.

    Raises:
        AuthorizationException: If the external user could not be authorized.
    """
    external_user = _fetch_external_user(external_access_token)
    store = zen_store()

    # Check if the external user already exists in the ZenML server database.
    # Only the lookup is guarded, so a KeyError raised while updating an
    # existing user cannot be mistaken for "user not found".
    try:
        existing_user = store.get_external_user(user_id=external_user.id)
    except KeyError:
        existing_user = None

    if existing_user is not None:
        # Update the user information
        user = store.update_user(
            user_id=existing_user.id,
            user_update=UserUpdateModel(
                name=external_user.email,
                full_name=external_user.name or "",
                email_opted_in=True,
                active=True,
                email=external_user.email,
            ),
        )
        return AuthContext(user=user)

    logger.info(
        f"External user with ID {external_user.id} not found in ZenML "
        f"server database. Creating a new user."
    )
    user = store.create_user(
        UserRequestModel(
            name=external_user.email,
            full_name=external_user.name or "",
            external_user_id=external_user.id,
            email_opted_in=True,
            active=True,
            email=external_user.email,
        )
    )

    with AnalyticsContext() as context:
        context.user_id = user.id
        context.identify(
            traits={"email": user.email, "source": "external_auth"}
        )
        context.alias(user_id=user.id, previous_id=external_user.id)

    # Create a new user role assignment for the new user
    store.create_user_role_assignment(
        UserRoleAssignmentRequestModel(
            role=store._admin_role.id,
            user=user.id,
            workspace=None,
        )
    )

    return AuthContext(user=user)
authentication_provider()
Returns the authentication provider.
Returns:
Type | Description |
---|---|
Callable[..., zenml.zen_server.auth.AuthContext] |
The authentication provider. |
Exceptions:
Type | Description |
---|---|
ValueError |
If the authentication scheme is not supported. |
Source code in zenml/zen_server/auth.py
def authentication_provider() -> Callable[..., AuthContext]:
    """Returns the authentication provider.

    The provider is chosen from the server's configured auth scheme. The
    OAuth2 password-bearer and external schemes share the same JWT-based
    handler.

    Returns:
        The authentication provider.

    Raises:
        ValueError: If the authentication scheme is not supported.
    """
    scheme = server_config().auth_scheme

    if scheme == AuthScheme.NO_AUTH:
        return no_authentication
    if scheme == AuthScheme.HTTP_BASIC:
        return http_authentication
    if scheme in (AuthScheme.OAUTH2_PASSWORD_BEARER, AuthScheme.EXTERNAL):
        return oauth2_authentication

    raise ValueError(f"Unknown authentication scheme: {scheme}")
authorize(security_scopes, token=Depends(CookieOAuth2TokenBearer))
Authenticates any request to the ZenML server with OAuth2 JWT tokens.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
security_scopes |
SecurityScopes |
Security scope for this token |
required |
token |
str |
The JWT bearer token to be authenticated. |
Depends(CookieOAuth2TokenBearer) |
Returns:
Type | Description |
---|---|
AuthContext |
The authentication context reflecting the authenticated user. |
Exceptions:
Type | Description |
---|---|
HTTPException |
If the JWT token could not be authorized. |
Source code in zenml/zen_server/auth.py
def oauth2_authentication(
    security_scopes: SecurityScopes,
    token: str = Depends(
        CookieOAuth2TokenBearer(
            tokenUrl=server_config().root_url_path + API + VERSION_1 + LOGIN,
            scopes={
                "read": "Read permissions on all entities",
                "write": "Write permissions on all entities",
                "me": "Editing permissions to own user",
            },
        )
    ),
) -> AuthContext:
    """Authenticates any request to the ZenML server with OAuth2 JWT tokens.

    An invalid token yields HTTP 401. A valid token whose permissions lack
    any of the requested scopes yields HTTP 403.

    Args:
        security_scopes: Security scope for this token
        token: The JWT bearer token to be authenticated.

    Returns:
        The authentication context reflecting the authenticated user.

    Raises:
        HTTPException: If the JWT token could not be authorized.
    """
    authenticate_value = (
        f'Bearer scope="{security_scopes.scope_str}"'
        if security_scopes.scopes
        else "Bearer"
    )

    try:
        context = authenticate_credentials(access_token=token)
    except AuthorizationException as err:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=str(err),
            headers={"WWW-Authenticate": authenticate_value},
        )

    # Scope checks only apply when the context carries a decoded token.
    decoded = context.access_token
    if decoded and any(
        required not in decoded.permissions
        for required in security_scopes.scopes
    ):
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Not enough permissions",
            headers={"WWW-Authenticate": authenticate_value},
        )

    return context
get_auth_context()
Returns the current authentication context.
Returns:
Type | Description |
---|---|
Optional[AuthContext] |
The authentication context. |
Source code in zenml/zen_server/auth.py
def get_auth_context() -> Optional["AuthContext"]:
    """Returns the current authentication context.

    Reads the value last stored by `set_auth_context` in the `_auth_context`
    holder (defined elsewhere in this module).

    Returns:
        The authentication context.
    """
    return _auth_context.get()
http_authentication(security_scopes, credentials=Depends(HTTPBasic))
Authenticates any request to the ZenML Server with basic HTTP authentication.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
security_scopes |
SecurityScopes |
Security scope will be ignored for http_auth |
required |
credentials |
HTTPBasicCredentials |
HTTP basic auth credentials passed to the request. |
Depends(HTTPBasic) |
Returns:
Type | Description |
---|---|
AuthContext |
The authentication context reflecting the authenticated user. |
Exceptions:
Type | Description |
---|---|
HTTPException |
If the credentials are invalid. |
Source code in zenml/zen_server/auth.py
def http_authentication(
    security_scopes: SecurityScopes,
    credentials: HTTPBasicCredentials = Depends(HTTPBasic()),
) -> AuthContext:
    """Authenticates any request to the ZenML Server with basic HTTP authentication.

    Invalid credentials are reported as HTTP 401 with a ``Basic`` challenge.

    Args:
        security_scopes: Security scope will be ignored for http_auth
        credentials: HTTP basic auth credentials passed to the request.

    Returns:
        The authentication context reflecting the authenticated user.

    Raises:
        HTTPException: If the credentials are invalid.
    """
    username, secret = credentials.username, credentials.password
    try:
        context = authenticate_credentials(
            user_name_or_id=username, password=secret
        )
    except AuthorizationException as err:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=str(err),
            headers={"WWW-Authenticate": "Basic"},
        )
    return context
no_authentication(security_scopes)
Doesn't authenticate requests to the ZenML server.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
security_scopes |
SecurityScopes |
Security scope will be ignored for no_auth |
required |
Returns:
Type | Description |
---|---|
AuthContext |
The authentication context reflecting the default user. |
Source code in zenml/zen_server/auth.py
def no_authentication(security_scopes: SecurityScopes) -> AuthContext:
    """Doesn't authenticate requests to the ZenML server.

    Every request is attributed to the default user: the context is built by
    looking up `DEFAULT_USERNAME` without checking any secret.

    Args:
        security_scopes: Security scope will be ignored for no_auth.

    Returns:
        The authentication context reflecting the default user.
    """
    return authenticate_credentials(user_name_or_id=DEFAULT_USERNAME)
oauth2_authentication(security_scopes, token=Depends(CookieOAuth2TokenBearer))
Authenticates any request to the ZenML server with OAuth2 JWT tokens.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
security_scopes |
SecurityScopes |
Security scope for this token |
required |
token |
str |
The JWT bearer token to be authenticated. |
Depends(CookieOAuth2TokenBearer) |
Returns:
Type | Description |
---|---|
AuthContext |
The authentication context reflecting the authenticated user. |
Exceptions:
Type | Description |
---|---|
HTTPException |
If the JWT token could not be authorized. |
Source code in zenml/zen_server/auth.py
def oauth2_authentication(
    security_scopes: SecurityScopes,
    token: str = Depends(
        CookieOAuth2TokenBearer(
            tokenUrl=server_config().root_url_path + API + VERSION_1 + LOGIN,
            scopes={
                "read": "Read permissions on all entities",
                "write": "Write permissions on all entities",
                "me": "Editing permissions to own user",
            },
        )
    ),
) -> AuthContext:
    """Authenticates any request to the ZenML server with OAuth2 JWT tokens.

    An invalid token yields HTTP 401. A valid token whose permissions lack
    any of the requested scopes yields HTTP 403.

    Args:
        security_scopes: Security scope for this token
        token: The JWT bearer token to be authenticated.

    Returns:
        The authentication context reflecting the authenticated user.

    Raises:
        HTTPException: If the JWT token could not be authorized.
    """
    authenticate_value = (
        f'Bearer scope="{security_scopes.scope_str}"'
        if security_scopes.scopes
        else "Bearer"
    )

    try:
        context = authenticate_credentials(access_token=token)
    except AuthorizationException as err:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=str(err),
            headers={"WWW-Authenticate": authenticate_value},
        )

    # Scope checks only apply when the context carries a decoded token.
    decoded = context.access_token
    if decoded and any(
        required not in decoded.permissions
        for required in security_scopes.scopes
    ):
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Not enough permissions",
            headers={"WWW-Authenticate": authenticate_value},
        )

    return context
set_auth_context(auth_context)
Sets the current authentication context.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
auth_context |
AuthContext |
The authentication context. |
required |
Returns:
Type | Description |
---|---|
AuthContext |
The authentication context. |
Source code in zenml/zen_server/auth.py
def set_auth_context(auth_context: "AuthContext") -> "AuthContext":
    """Sets the current authentication context.

    The value is stored in the `_auth_context` holder (defined elsewhere in
    this module), from which `get_auth_context` reads it. The same object is
    handed back so the call can be used inline.

    Args:
        auth_context: The authentication context.

    Returns:
        The authentication context.
    """
    _auth_context.set(auth_context)

    return auth_context
deploy
special
ZenML server deployments.
base_provider
Base ZenML server provider class.
BaseServerProvider (ABC)
Base ZenML server provider class.
All ZenML server providers must extend and implement this base class.
Source code in zenml/zen_server/deploy/base_provider.py
class BaseServerProvider(ABC):
"""Base ZenML server provider class.
All ZenML server providers must extend and implement this base class.
"""
TYPE: ClassVar[ServerProviderType]
CONFIG_TYPE: ClassVar[
Type[ServerDeploymentConfig]
] = ServerDeploymentConfig
@classmethod
def register_as_provider(cls) -> None:
    """Register this class with the `ServerDeployer` provider registry."""
    # Local import, presumably to avoid an import cycle with the deployer
    # module (which itself refers to providers) - confirm.
    from zenml.zen_server.deploy.deployer import ServerDeployer

    ServerDeployer.register_provider(cls)
@classmethod
def _convert_config(
    cls, config: ServerDeploymentConfig
) -> ServerDeploymentConfig:
    """Convert a generic server deployment config into a provider specific config.

    A config that is already an instance of the provider's `CONFIG_TYPE` is
    returned unchanged. Otherwise its fields are re-validated through
    `CONFIG_TYPE`.

    Args:
        config: The generic server deployment config.

    Returns:
        The provider specific server deployment config.

    Raises:
        ServerDeploymentConfigurationError: If the configuration is not
            valid.
    """
    if not isinstance(config, cls.CONFIG_TYPE):
        try:
            config = cls.CONFIG_TYPE(**config.dict())
        except ValidationError as err:
            raise ServerDeploymentConfigurationError(
                f"Invalid configuration for provider {cls.TYPE.value}: {err}"
            )
    return config
def deploy_server(
    self,
    config: ServerDeploymentConfig,
    timeout: Optional[int] = None,
) -> ServerDeployment:
    """Deploy a new ZenML server.

    Fails if this provider already manages a service named `config.name`.

    Args:
        config: The generic server deployment configuration.
        timeout: The timeout in seconds to wait until the deployment is
            successful. If not supplied, the default timeout value specified
            by the provider is used.

    Returns:
        The newly created server deployment.

    Raises:
        ServerDeploymentExistsError: If a deployment with the same name
            already exists.
    """
    try:
        self._get_service(config.name)
    except KeyError:
        name_taken = False
    else:
        name_taken = True

    if name_taken:
        raise ServerDeploymentExistsError(
            f"ZenML server deployment with name '{config.name}' already "
            f"exists"
        )

    # Switch to the provider's own CONFIG_TYPE before creating the service.
    provider_config = self._convert_config(config)
    new_service = self._create_service(provider_config, timeout)
    return self._get_deployment(new_service)
def update_server(
    self,
    config: ServerDeploymentConfig,
    timeout: Optional[int] = None,
) -> ServerDeployment:
    """Update an existing ZenML server deployment.

    If the provider-specific form of `config` equals the deployment's
    current configuration, the existing service is only started. Otherwise
    the service is updated with the new configuration.

    Args:
        config: The new generic server deployment configuration.
        timeout: The timeout in seconds to wait until the update is
            successful. If not supplied, the default timeout value specified
            by the provider is used.

    Returns:
        The updated server deployment.

    Raises:
        ServerDeploymentNotFoundError: If a deployment with the given name
            doesn't exist.
    """
    try:
        service = self._get_service(config.name)
    except KeyError:
        raise ServerDeploymentNotFoundError(
            f"ZenML server deployment with name '{config.name}' was not "
            f"found"
        )

    new_config = self._convert_config(config)
    current_config = self._get_deployment_config(service)

    if current_config == new_config:
        logger.info(
            f"The {new_config.name} ZenML server is already configured with "
            f"the same parameters."
        )
        refreshed = self._start_service(service, timeout)
    else:
        logger.info(f"Updating the {new_config.name} ZenML server.")
        refreshed = self._update_service(service, new_config, timeout)

    return self._get_deployment(refreshed)
def remove_server(
    self,
    config: ServerDeploymentConfig,
    timeout: Optional[int] = None,
) -> None:
    """Tears down and removes all resources and files associated with a ZenML server deployment.

    Args:
        config: The generic server deployment configuration; only its
            `name` is used to locate the deployment.
        timeout: The timeout in seconds to wait until the server is
            removed. If not supplied, the default timeout value specified
            by the provider is used.

    Raises:
        ServerDeploymentNotFoundError: If a deployment with the given name
            doesn't exist.
    """
    deployment_name = config.name
    try:
        target = self._get_service(deployment_name)
    except KeyError:
        raise ServerDeploymentNotFoundError(
            f"ZenML server deployment with name '{deployment_name}' was not "
            f"found"
        )

    logger.info(f"Removing the {deployment_name} ZenML server.")
    self._delete_service(target, timeout)
def get_server(
    self,
    config: ServerDeploymentConfig,
) -> ServerDeployment:
    """Retrieve information about a ZenML server deployment.

    Args:
        config: The generic server deployment configuration; only its
            `name` is used to locate the deployment.

    Returns:
        The server deployment.

    Raises:
        ServerDeploymentNotFoundError: If a deployment with the given name
            doesn't exist.
    """
    deployment_name = config.name
    try:
        target = self._get_service(deployment_name)
    except KeyError:
        raise ServerDeploymentNotFoundError(
            f"ZenML server deployment with name '{deployment_name}' was not "
            f"found"
        )

    return self._get_deployment(target)
def list_servers(self) -> List[ServerDeployment]:
    """List all server deployments managed by this provider.

    Returns:
        One deployment per service reported by `_list_services`, in the
        same order.
    """
    return list(map(self._get_deployment, self._list_services()))
def get_server_logs(
    self,
    config: ServerDeploymentConfig,
    follow: bool = False,
    tail: Optional[int] = None,
) -> Generator[str, bool, None]:
    """Retrieve the logs of a ZenML server.

    Args:
        config: The generic server deployment configuration; only its
            `name` is used to locate the deployment.
        follow: If True, the logs will be streamed as they are written.
        tail: Only retrieve the last `tail` lines of log output.

    Returns:
        A generator that can be accessed to get the service logs.

    Raises:
        ServerDeploymentNotFoundError: If a deployment with the given name
            doesn't exist.
    """
    deployment_name = config.name
    try:
        target = self._get_service(deployment_name)
    except KeyError:
        raise ServerDeploymentNotFoundError(
            f"ZenML server deployment with name '{deployment_name}' was not "
            f"found"
        )

    return target.get_logs(follow=follow, tail=tail)
def _get_deployment_status(
    self, service: BaseService
) -> ServerDeploymentStatus:
    """Get the status of a server deployment from its service.

    The URL is only reported while the service is running. The deployment
    counts as "connected" when that URL equals the URL of the store in the
    global configuration.

    Args:
        service: The server deployment service.

    Returns:
        The status of the server deployment.
    """
    gc = GlobalConfiguration()

    endpoint_url: Optional[str] = None
    if service.is_running:
        # all services must have an endpoint
        assert service.endpoint is not None
        endpoint_url = service.endpoint.status.uri

    is_connected = (
        endpoint_url is not None
        and gc.store is not None
        and gc.store.url == endpoint_url
    )

    return ServerDeploymentStatus(
        url=endpoint_url,
        status=service.status.state,
        status_message=service.status.last_error,
        connected=is_connected,
    )
def _get_deployment(self, service: BaseService) -> ServerDeployment:
    """Get the server deployment associated with a service.

    Args:
        service: The service.

    Returns:
        The server deployment, combining the service's deployment config
        with its current status.
    """
    return ServerDeployment(
        config=self._get_deployment_config(service),
        status=self._get_deployment_status(service),
    )
@classmethod
@abstractmethod
def _get_service_configuration(
    cls,
    server_config: ServerDeploymentConfig,
) -> Tuple[
    ServiceConfig,
    ServiceEndpointConfig,
    ServiceEndpointHealthMonitorConfig,
]:
    """Construct the service configuration from a server deployment configuration.

    Abstract: each concrete provider must implement it.

    Args:
        server_config: server deployment configuration.

    Returns:
        The service, service endpoint and endpoint monitor configuration.
    """
@abstractmethod
def _create_service(
    self,
    config: ServerDeploymentConfig,
    timeout: Optional[int] = None,
) -> BaseService:
    """Create and start the service backing a new ZenML server deployment.

    Concrete providers must implement this.

    Args:
        config: The server deployment configuration.
        timeout: How many seconds to wait for the service to be running.
            When omitted, the provider implementation should fall back
            to its own default timeout.

    Returns:
        The service instance.
    """
@abstractmethod
def _update_service(
    self,
    service: BaseService,
    config: ServerDeploymentConfig,
    timeout: Optional[int] = None,
) -> BaseService:
    """Apply a new deployment configuration to an existing server service.

    Concrete providers must implement this.

    Args:
        service: The service instance.
        config: The new server deployment configuration.
        timeout: How many seconds to wait for the updated service to be
            running. When omitted, the provider implementation should
            fall back to its own default timeout.

    Returns:
        The updated service instance.
    """
@abstractmethod
def _start_service(
    self,
    service: BaseService,
    timeout: Optional[int] = None,
) -> BaseService:
    """Start the service backing a ZenML server deployment.

    Concrete providers must implement this.

    Args:
        service: The service instance.
        timeout: How many seconds to wait for the service to be running.
            When omitted, the provider implementation should fall back
            to its own default timeout.

    Returns:
        The updated service instance.
    """
@abstractmethod
def _stop_service(
    self,
    service: BaseService,
    timeout: Optional[int] = None,
) -> BaseService:
    """Stop the service backing a ZenML server deployment.

    Concrete providers must implement this.

    Args:
        service: The service instance.
        timeout: How many seconds to wait for the service to be stopped.
            When omitted, the provider implementation should fall back
            to its own default timeout.

    Returns:
        The updated service instance.
    """
@abstractmethod
def _delete_service(
    self,
    service: BaseService,
    timeout: Optional[int] = None,
) -> None:
    """Remove the service backing a ZenML server deployment.

    Concrete providers must implement this.

    Args:
        service: The service instance.
        timeout: How many seconds to wait for the service to be removed.
            When omitted, the provider implementation should fall back
            to its own default timeout.
    """
@abstractmethod
def _get_service(self, server_name: str) -> BaseService:
    """Look up the service backing a ZenML server deployment.

    Concrete providers must implement this.

    Args:
        server_name: The server deployment name.

    Returns:
        The service instance.

    Raises:
        KeyError: If the server deployment is not found. Public methods
            such as `get_server` translate this into a
            `ServerDeploymentNotFoundError`.
    """
@abstractmethod
def _list_services(self) -> List[BaseService]:
    """Collect the services of all ZenML servers deployed by this provider.

    Concrete providers must implement this.

    Returns:
        A list of service instances.
    """
@abstractmethod
def _get_deployment_config(
    self, service: BaseService
) -> ServerDeploymentConfig:
    """Rebuild a server deployment configuration from its service.

    Concrete providers must implement this.

    Args:
        service: The service instance.

    Returns:
        The server deployment config.
    """
CONFIG_TYPE (BaseModel)
pydantic-model
Generic server deployment configuration.
All server deployment configurations should inherit from this class and handle extra attributes as provider specific attributes.
Attributes:

| Name | Type | Description |
|---|---|---|
| name | str | Name of the server deployment. |
| provider | ServerProviderType | The server provider type. |
Source code in zenml/zen_server/deploy/base_provider.py
class ServerDeploymentConfig(BaseModel):
    """Provider-agnostic configuration of a ZenML server deployment.

    Concrete deployment configurations are expected to subclass this model.
    Any extra attributes are accepted here as provider specific settings
    and must be validated by those subclasses.

    Attributes:
        name: Name of the server deployment.
        provider: The server provider type.
    """

    name: str
    provider: ServerProviderType

    class Config:
        """Pydantic configuration class."""

        # Extra attributes are accepted by the base model; the concrete
        # subclasses are in charge of validating them.
        extra = "allow"
        # Re-validate values on assignment, which is required to support
        # a mix of mutable and immutable attributes.
        validate_assignment = True
Config
Pydantic configuration class.
Source code in zenml/zen_server/deploy/base_provider.py
class Config:
    """Pydantic configuration class."""

    # Extra attributes are accepted by the base model; the concrete
    # subclasses are in charge of validating them.
    extra = "allow"
    # Re-validate values on assignment, which is required to support a mix
    # of mutable and immutable attributes.
    validate_assignment = True
deploy_server(self, config, timeout=None)
Deploy a new ZenML server.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| config | ServerDeploymentConfig | The generic server deployment configuration. | required |
| timeout | Optional[int] | The timeout in seconds to wait until the deployment is successful. If not supplied, the default timeout value specified by the provider is used. | None |

Returns:

| Type | Description |
|---|---|
| ServerDeployment | The newly created server deployment. |

Exceptions:

| Type | Description |
|---|---|
| ServerDeploymentExistsError | If a deployment with the same name already exists. |
Source code in zenml/zen_server/deploy/base_provider.py
def deploy_server(
    self,
    config: ServerDeploymentConfig,
    timeout: Optional[int] = None,
) -> ServerDeployment:
    """Provision a brand new ZenML server with this provider.

    Args:
        config: The generic server deployment configuration.
        timeout: How many seconds to wait for the deployment to succeed.
            When omitted, the provider's default timeout is used.

    Returns:
        The newly created server deployment.

    Raises:
        ServerDeploymentExistsError: If a deployment with the same name
            already exists.
    """
    # `_get_service` signals a missing deployment with a KeyError.
    try:
        self._get_service(config.name)
    except KeyError:
        name_taken = False
    else:
        name_taken = True

    if name_taken:
        raise ServerDeploymentExistsError(
            f"ZenML server deployment with name '{config.name}' already "
            f"exists"
        )

    # Providers operate on their own config flavor, so convert the generic
    # one before creating the backing service.
    provider_config = self._convert_config(config)
    new_service = self._create_service(provider_config, timeout)
    return self._get_deployment(new_service)
get_server(self, config)
Retrieve information about a ZenML server deployment.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| config | ServerDeploymentConfig | The generic server deployment configuration. | required |

Returns:

| Type | Description |
|---|---|
| ServerDeployment | The server deployment. |

Exceptions:

| Type | Description |
|---|---|
| ServerDeploymentNotFoundError | If a deployment with the given name doesn't exist. |
Source code in zenml/zen_server/deploy/base_provider.py
def get_server(
    self,
    config: ServerDeploymentConfig,
) -> ServerDeployment:
    """Retrieve information about a ZenML server deployment.

    Args:
        config: The generic server deployment configuration. Only its
            `name` is used, to look up the backing service.

    Returns:
        The server deployment.

    Raises:
        ServerDeploymentNotFoundError: If a deployment with the given name
            doesn't exist.
    """
    try:
        service = self._get_service(config.name)
    except KeyError:
        # The KeyError is an internal lookup detail; suppress it so the
        # user only sees the meaningful "not found" error instead of a
        # chained "during handling of the above exception" traceback.
        raise ServerDeploymentNotFoundError(
            f"ZenML server deployment with name '{config.name}' was not "
            f"found"
        ) from None
    return self._get_deployment(service)
get_server_logs(self, config, follow=False, tail=None)
Retrieve the logs of a ZenML server.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
config |
ServerDeploymentConfig |
The generic server deployment configuration. |
required |
follow |
bool |
if True, the logs will be streamed as they are written |
False |
tail |
Optional[int] |
only retrieve the last NUM lines of log output. |
None |
Returns:
Type | Description |
---|---|
Generator[str, bool, NoneType] |
A generator that can be accessed to get the service logs. |
Exceptions:
Type | Description |
---|---|
ServerDeploymentNotFoundError |
If a deployment with the given name doesn't exist. |
Source code in zenml/zen_server/deploy/base_provider.py
def get_server_logs(
    self,
    config: ServerDeploymentConfig,
    follow: bool = False,
    tail: Optional[int] = None,
) -> Generator[str, bool, None]:
    """Retrieve the logs of a ZenML server.

    Args:
        config: The generic server deployment configuration. Only its
            `name` is used, to look up the backing service.
        follow: if True, the logs will be streamed as they are written
        tail: only retrieve the last NUM lines of log output.

    Returns:
        A generator that can be accessed to get the service logs.

    Raises:
        ServerDeploymentNotFoundError: If a deployment with the given name
            doesn't exist.
    """
    try:
        service = self._get_service(config.name)
    except KeyError:
        # The KeyError is an internal lookup detail; suppress it so the
        # user only sees the meaningful "not found" error instead of a
        # chained "during handling of the above exception" traceback.
        raise ServerDeploymentNotFoundError(
            f"ZenML server deployment with name '{config.name}' was not "
            f"found"
        ) from None
    return service.get_logs(follow=follow, tail=tail)
list_servers(self)
List all server deployments managed by this provider.
Returns:

| Type | Description |
|---|---|
| List[zenml.zen_server.deploy.deployment.ServerDeployment] | The list of server deployments. |
Source code in zenml/zen_server/deploy/base_provider.py
def list_servers(self) -> List[ServerDeployment]:
    """Return every server deployment known to this provider.

    Each service reported by `_list_services` is wrapped into a
    `ServerDeployment` via `_get_deployment`.

    Returns:
        The list of server deployments.
    """
    services = self._list_services()
    return list(map(self._get_deployment, services))
register_as_provider()
classmethod
Register the class as a server provider.
Source code in zenml/zen_server/deploy/base_provider.py
@classmethod
def register_as_provider(cls) -> None:
    """Add this provider class to the `ServerDeployer` registry."""
    # Imported locally rather than at module level; presumably this avoids
    # a circular import with the deployer module (TODO confirm).
    from zenml.zen_server.deploy.deployer import ServerDeployer

    register = ServerDeployer.register_provider
    register(cls)
remove_server(self, config, timeout=None)
Tears down and removes all resources and files associated with a ZenML server deployment.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| config | ServerDeploymentConfig | The generic server deployment configuration. | required |
| timeout | Optional[int] | The timeout in seconds to wait until the server is removed. If not supplied, the default timeout value specified by the provider is used. | None |

Exceptions:

| Type | Description |
|---|---|
| ServerDeploymentNotFoundError | If a deployment with the given name doesn't exist. |
Source code in zenml/zen_server/deploy/base_provider.py
def remove_server(
    self,
    config: ServerDeploymentConfig,
    timeout: Optional[int] = None,
) -> None:
    """Tears down and removes all resources and files associated with a ZenML server deployment.

    Args:
        config: The generic server deployment configuration. Only its
            `name` is used, to look up the backing service.
        timeout: The timeout in seconds to wait until the server is
            removed. If not supplied, the default timeout value specified
            by the provider is used.

    Raises:
        ServerDeploymentNotFoundError: If a deployment with the given name
            doesn't exist.
    """
    try:
        service = self._get_service(config.name)
    except KeyError:
        # The KeyError is an internal lookup detail; suppress it so the
        # user only sees the meaningful "not found" error instead of a
        # chained "during handling of the above exception" traceback.
        raise ServerDeploymentNotFoundError(
            f"ZenML server deployment with name '{config.name}' was not "
            f"found"
        ) from None
    logger.info(f"Removing the {config.name} ZenML server.")
    self._delete_service(service, timeout)
update_server(self, config, timeout=None)
Update an existing ZenML server deployment.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
config |
ServerDeploymentConfig |
The new generic server deployment configuration. |
required |
timeout |
Optional[int] |
The timeout in seconds to wait until the update is successful. If not supplied, the default timeout value specified by the provider is used. |
None |
Returns:
Type | Description |
---|---|
ServerDeployment |
The updated server deployment. |
Exceptions:
Type | Description |
---|---|
ServerDeploymentNotFoundError |
If a deployment with the given name doesn't exist. |
Source code in zenml/zen_server/deploy/base_provider.py
def update_server(
    self,
    config: ServerDeploymentConfig,
    timeout: Optional[int] = None,
) -> ServerDeployment:
    """Update an existing ZenML server deployment.

    If the provider specific configuration is unchanged, the existing
    service is started instead of being updated.

    Args:
        config: The new generic server deployment configuration.
        timeout: The timeout in seconds to wait until the update is
            successful. If not supplied, the default timeout value specified
            by the provider is used.

    Returns:
        The updated server deployment.

    Raises:
        ServerDeploymentNotFoundError: If a deployment with the given name
            doesn't exist.
    """
    try:
        service = self._get_service(config.name)
    except KeyError:
        # The KeyError is an internal lookup detail; suppress it so the
        # user only sees the meaningful "not found" error instead of a
        # chained "during handling of the above exception" traceback.
        raise ServerDeploymentNotFoundError(
            f"ZenML server deployment with name '{config.name}' was not "
            f"found"
        ) from None
    # convert the generic deployment config to a provider specific
    # deployment config
    config = self._convert_config(config)
    old_config = self._get_deployment_config(service)
    if old_config == config:
        logger.info(
            f"The {config.name} ZenML server is already configured with "
            f"the same parameters."
        )
        # Nothing to reconfigure: only (re)start the existing service.
        service = self._start_service(service, timeout)
    else:
        logger.info(f"Updating the {config.name} ZenML server.")
        service = self._update_service(service, config, timeout)
    return self._get_deployment(service)
deployer
ZenML server deployer singleton implementation.
ServerDeployer
Server deployer singleton.
This class is responsible for managing the various server provider implementations and for directing server deployment lifecycle requests to the responsible provider. It acts as a facade built on top of the various server providers.
Source code in zenml/zen_server/deploy/deployer.py
class ServerDeployer(metaclass=SingletonMetaClass):
    """Server deployer singleton.

    This class is responsible for managing the various server provider
    implementations and for directing server deployment lifecycle requests to
    the responsible provider. It acts as a facade built on top of the various
    server providers.
    """

    # Provider instances keyed by the provider type they handle. Filled in by
    # `register_provider` and shared at class level.
    _providers: ClassVar[Dict[ServerProviderType, BaseServerProvider]] = {}

    @classmethod
    def register_provider(cls, provider: Type[BaseServerProvider]) -> None:
        """Register a server provider.

        The provider class is instantiated once, without arguments, and the
        instance is stored under the class's `TYPE`.

        Args:
            provider: The server provider to register.

        Raises:
            TypeError: If a provider with the same type is already registered.
        """
        if provider.TYPE in cls._providers:
            raise TypeError(
                f"Server provider '{provider.TYPE}' is already registered."
            )
        logger.debug(f"Registering server provider '{provider.TYPE}'.")
        cls._providers[provider.TYPE] = provider()

    @classmethod
    def get_provider(
        cls, provider_type: ServerProviderType
    ) -> BaseServerProvider:
        """Get the server provider associated with a provider type.

        Args:
            provider_type: The server provider type.

        Returns:
            The server provider associated with the provider type.

        Raises:
            ServerProviderNotFoundError: If no provider is registered for the
                given provider type.
        """
        if provider_type not in cls._providers:
            raise ServerProviderNotFoundError(
                f"Server provider '{provider_type}' is not registered."
            )
        return cls._providers[provider_type]

    def deploy_server(
        self,
        config: ServerDeploymentConfig,
        timeout: Optional[int] = None,
    ) -> ServerDeployment:
        """Deploy a new ZenML server or update an existing deployment.

        Args:
            config: The server deployment configuration.
            timeout: The timeout in seconds to wait until the deployment is
                successful. If not supplied, the default timeout value specified
                by the provider is used.

        Returns:
            The server deployment.
        """
        # We do this here to ensure that the zenml store is always initialized
        # before the server is deployed. This is necessary because the server
        # may require access to the local store configuration or database.
        gc = GlobalConfiguration()
        _ = gc.zen_store

        try:
            self.get_server(config.name)
        except ServerDeploymentNotFoundError:
            pass
        else:
            # A deployment with this name already exists: update it instead.
            return self.update_server(config=config, timeout=timeout)

        provider_name = config.provider.value
        provider = self.get_provider(config.provider)

        logger.info(
            f"Deploying a {provider_name} ZenML server with name "
            f"'{config.name}'."
        )
        return provider.deploy_server(config, timeout=timeout)

    def update_server(
        self,
        config: ServerDeploymentConfig,
        timeout: Optional[int] = None,
    ) -> ServerDeployment:
        """Update an existing ZenML server deployment.

        Args:
            config: The new server deployment configuration.
            timeout: The timeout in seconds to wait until the deployment is
                successful. If not supplied, the default timeout value
                specified by the provider is used.

        Returns:
            The updated server deployment.

        Raises:
            ServerDeploymentExistsError: If an existing deployment with the same
                name but a different provider type is found.
        """
        # this will also raise ServerDeploymentNotFoundError if the server
        # does not exist
        existing_server = self.get_server(config.name)

        provider = self.get_provider(config.provider)
        existing_provider = existing_server.config.provider

        if existing_provider != config.provider:
            raise ServerDeploymentExistsError(
                f"A server deployment with the same name '{config.name}' but "
                f"with a different provider '{existing_provider.value}' "
                f"is already provisioned. Please choose a different name or "
                f"tear down the existing deployment."
            )

        return provider.update_server(config, timeout=timeout)

    def remove_server(
        self,
        server_name: str,
        timeout: Optional[int] = None,
    ) -> None:
        """Tears down and removes all resources and files associated with a ZenML server deployment.

        If the client is currently connected to this server, it is
        disconnected first.

        Args:
            server_name: The server deployment name.
            timeout: The timeout in seconds to wait until the deployment is
                successfully torn down. If not supplied, a provider specific
                default timeout value is used.

        Raises:
            ServerDeploymentNotFoundError: If no server deployment with the
                given name is found (raised by `get_server`).
        """
        # this will also raise ServerDeploymentNotFoundError if the server
        # does not exist
        server = self.get_server(server_name)

        provider_name = server.config.provider.value
        provider = self.get_provider(server.config.provider)

        if self.is_connected_to_server(server_name):
            self.disconnect_from_server(server_name)

        logger.info(
            f"Tearing down the '{server_name}' {provider_name} ZenML server."
        )
        provider.remove_server(server.config, timeout=timeout)

    def is_connected_to_server(self, server_name: str) -> bool:
        """Check if the ZenML client is currently connected to a ZenML server.

        Args:
            server_name: The server deployment name.

        Returns:
            True if the ZenML client is connected to the ZenML server, False
            otherwise.
        """
        # this will also raise ServerDeploymentNotFoundError if the server
        # does not exist
        server = self.get_server(server_name)

        gc = GlobalConfiguration()
        return (
            server.status is not None
            and server.status.url is not None
            and gc.store is not None
            and gc.store.url == server.status.url
        )

    def connect_to_server(
        self,
        server_name: str,
        username: str,
        password: str,
        verify_ssl: Union[bool, str] = True,
    ) -> None:
        """Connect to a ZenML server instance.

        Args:
            server_name: The server deployment name.
            username: The username to use to connect to the server.
            password: The password to use to connect to the server.
            verify_ssl: Either a boolean, in which case it controls whether we
                verify the server's TLS certificate, or a string, in which case
                it must be a path to a CA bundle to use or the CA bundle value
                itself.

        Raises:
            ServerDeploymentError: If the ZenML server is not running or
                is unreachable.
        """
        # this will also raise ServerDeploymentNotFoundError if the server
        # does not exist
        server = self.get_server(server_name)
        provider_name = server.config.provider.value

        gc = GlobalConfiguration()

        if not server.status or not server.status.url:
            raise ServerDeploymentError(
                f"The {provider_name} {server_name} ZenML "
                f"server is not currently running or is unreachable."
            )

        store_config = RestZenStoreConfiguration(
            url=server.status.url,
            username=username,
            password=password,
            verify_ssl=verify_ssl,
            secrets_store=RestSecretsStoreConfiguration(),
        )

        if gc.store == store_config:
            logger.info(
                f"ZenML is already connected to the '{server_name}' "
                f"{provider_name} ZenML server."
            )
            return

        logger.info(
            f"Connecting ZenML to the '{server_name}' "
            f"{provider_name} ZenML server ({store_config.url})."
        )

        gc.set_store(store_config)

        logger.info(
            f"Connected ZenML to the '{server_name}' "
            f"{provider_name} ZenML server ({store_config.url})."
        )

    def disconnect_from_server(
        self,
        server_name: Optional[str] = None,
    ) -> None:
        """Disconnect from a ZenML server instance.

        Args:
            server_name: The server deployment name. If supplied, the deployer
                will check if the ZenML client is indeed connected to the server
                and disconnect only if that is the case. Otherwise the deployer
                will disconnect from any ZenML server.
        """
        gc = GlobalConfiguration()

        # Only a REST store counts as "connected to a server".
        if not gc.store or gc.store.type != StoreType.REST:
            logger.info("ZenML is not currently connected to a ZenML server.")
            return

        if server_name:
            # this will also raise ServerDeploymentNotFoundError if the server
            # does not exist
            server = self.get_server(server_name)
            provider_name = server.config.provider.value

            if not self.is_connected_to_server(server_name):
                logger.info(
                    f"ZenML is not currently connected to the '{server_name}' "
                    f"{provider_name} ZenML server."
                )
                return

            logger.info(
                f"Disconnecting ZenML from the '{server_name}' "
                f"{provider_name} ZenML server ({gc.store.url})."
            )
        else:
            logger.info(
                f"Disconnecting ZenML from the {gc.store.url} ZenML server."
            )

        gc.set_default_store()

        logger.info("Disconnected ZenML from the ZenML server.")

    def get_server(
        self,
        server_name: str,
    ) -> ServerDeployment:
        """Get a server deployment.

        Every registered provider is queried in turn; the first one that
        knows the name wins.

        Args:
            server_name: The server deployment name.

        Returns:
            The requested server deployment.

        Raises:
            ServerDeploymentNotFoundError: If no server deployment with the
                given name is found.
        """
        for provider in self._providers.values():
            try:
                return provider.get_server(
                    ServerDeploymentConfig(
                        name=server_name, provider=provider.TYPE
                    )
                )
            except ServerDeploymentNotFoundError:
                pass

        raise ServerDeploymentNotFoundError(
            f"Server deployment '{server_name}' not found."
        )

    def list_servers(
        self,
        server_name: Optional[str] = None,
        provider_type: Optional[ServerProviderType] = None,
    ) -> List[ServerDeployment]:
        """List all server deployments.

        Args:
            server_name: The server deployment name to filter by.
            provider_type: The server provider type to filter by.

        Returns:
            The list of server deployments.
        """
        providers: List[BaseServerProvider]
        if provider_type:
            providers = [self.get_provider(provider_type)]
        else:
            providers = list(self._providers.values())

        servers: List[ServerDeployment] = []
        for provider in providers:
            if server_name:
                try:
                    servers.append(
                        provider.get_server(
                            ServerDeploymentConfig(
                                name=server_name,
                                provider=provider.TYPE,
                            )
                        )
                    )
                except ServerDeploymentNotFoundError:
                    pass
            else:
                servers.extend(provider.list_servers())

        return servers

    def get_server_logs(
        self,
        server_name: str,
        follow: bool = False,
        tail: Optional[int] = None,
    ) -> Generator[str, bool, None]:
        """Retrieve the logs of a ZenML server.

        Args:
            server_name: The server deployment name.
            follow: if True, the logs will be streamed as they are written
            tail: only retrieve the last NUM lines of log output.

        Returns:
            A generator that can be accessed to get the service logs.
        """
        # this will also raise ServerDeploymentNotFoundError if the server
        # does not exist
        server = self.get_server(server_name)

        provider_name = server.config.provider.value
        provider = self.get_provider(server.config.provider)

        logger.info(
            f"Fetching logs from the '{server_name}' {provider_name} ZenML "
            f"server..."
        )
        return provider.get_server_logs(
            server.config, follow=follow, tail=tail
        )
connect_to_server(self, server_name, username, password, verify_ssl=True)
Connect to a ZenML server instance.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
server_name |
str |
The server deployment name. |
required |
username |
str |
The username to use to connect to the server. |
required |
password |
str |
The password to use to connect to the server. |
required |
verify_ssl |
Union[bool, str] |
Either a boolean, in which case it controls whether we verify the server's TLS certificate, or a string, in which case it must be a path to a CA bundle to use or the CA bundle value itself. |
True |
Exceptions:
Type | Description |
---|---|
ServerDeploymentError |
If the ZenML server is not running or is unreachable. |
Source code in zenml/zen_server/deploy/deployer.py
def connect_to_server(
    self,
    server_name: str,
    username: str,
    password: str,
    verify_ssl: Union[bool, str] = True,
) -> None:
    """Point the local ZenML client at a deployed ZenML server.

    The global configuration is switched to a REST store that targets the
    server's URL, unless it already holds an identical store configuration.

    Args:
        server_name: The server deployment name.
        username: The username to use to connect to the server.
        password: The password to use to connect to the server.
        verify_ssl: Either a boolean, in which case it controls whether we
            verify the server's TLS certificate, or a string, in which case
            it must be a path to a CA bundle to use or the CA bundle value
            itself.

    Raises:
        ServerDeploymentError: If the ZenML server is not running or
            is unreachable.
    """
    # The lookup raises ServerDeploymentNotFoundError for unknown names.
    server = self.get_server(server_name)
    provider_name = server.config.provider.value
    gc = GlobalConfiguration()

    status = server.status
    server_url = status.url if status else None
    if not server_url:
        raise ServerDeploymentError(
            f"The {provider_name} {server_name} ZenML "
            f"server is not currently running or is unreachable."
        )

    store_config = RestZenStoreConfiguration(
        url=server_url,
        username=username,
        password=password,
        verify_ssl=verify_ssl,
        secrets_store=RestSecretsStoreConfiguration(),
    )

    already_connected = gc.store == store_config
    if already_connected:
        logger.info(
            f"ZenML is already connected to the '{server_name}' "
            f"{provider_name} ZenML server."
        )
        return

    logger.info(
        f"Connecting ZenML to the '{server_name}' "
        f"{provider_name} ZenML server ({store_config.url})."
    )
    gc.set_store(store_config)
    logger.info(
        f"Connected ZenML to the '{server_name}' "
        f"{provider_name} ZenML server ({store_config.url})."
    )
deploy_server(self, config, timeout=None)
Deploy a new ZenML server or update an existing deployment.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
config |
ServerDeploymentConfig |
The server deployment configuration. |
required |
timeout |
Optional[int] |
The timeout in seconds to wait until the deployment is successful. If not supplied, the default timeout value specified by the provider is used. |
None |
Returns:
Type | Description |
---|---|
ServerDeployment |
The server deployment. |
Source code in zenml/zen_server/deploy/deployer.py
def deploy_server(
    self,
    config: ServerDeploymentConfig,
    timeout: Optional[int] = None,
) -> ServerDeployment:
    """Deploy a ZenML server, or update it if the name is already deployed.

    Args:
        config: The server deployment configuration.
        timeout: How many seconds to wait for the deployment to succeed.
            When omitted, the provider's default timeout is used.

    Returns:
        The server deployment.
    """
    # Initialize the zen store up front: the server may need access to the
    # local store configuration or database while it is being deployed.
    gc = GlobalConfiguration()
    _ = gc.zen_store

    try:
        self.get_server(config.name)
    except ServerDeploymentNotFoundError:
        already_deployed = False
    else:
        already_deployed = True

    if already_deployed:
        # Delegate existing deployments to the update path.
        return self.update_server(config=config, timeout=timeout)

    provider_label = config.provider.value
    target_provider = self.get_provider(config.provider)
    logger.info(
        f"Deploying a {provider_label} ZenML server with name "
        f"'{config.name}'."
    )
    return target_provider.deploy_server(config, timeout=timeout)
disconnect_from_server(self, server_name=None)
Disconnect from a ZenML server instance.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| server_name | Optional[str] | The server deployment name. If supplied, the deployer will check if the ZenML client is indeed connected to the server and disconnect only if that is the case. Otherwise the deployer will disconnect from any ZenML server. | None |
Source code in zenml/zen_server/deploy/deployer.py
def disconnect_from_server(
    self,
    server_name: Optional[str] = None,
) -> None:
    """Switch the local ZenML client away from a ZenML server.

    Nothing happens unless the client currently uses a REST store. When it
    does, the global configuration is reset to the default store.

    Args:
        server_name: The server deployment name. If supplied, the deployer
            will check if the ZenML client is indeed connected to the server
            and disconnect only if that is the case. Otherwise the deployer
            will disconnect from any ZenML server.
    """
    gc = GlobalConfiguration()
    if not gc.store or gc.store.type != StoreType.REST:
        logger.info("ZenML is not currently connected to a ZenML server.")
        return

    if not server_name:
        logger.info(
            f"Disconnecting ZenML from the {gc.store.url} ZenML server."
        )
    else:
        # The lookup raises ServerDeploymentNotFoundError for unknown names.
        deployment = self.get_server(server_name)
        provider_label = deployment.config.provider.value
        if not self.is_connected_to_server(server_name):
            logger.info(
                f"ZenML is not currently connected to the '{server_name}' "
                f"{provider_label} ZenML server."
            )
            return
        logger.info(
            f"Disconnecting ZenML from the '{server_name}' "
            f"{provider_label} ZenML server ({gc.store.url})."
        )

    gc.set_default_store()
    logger.info("Disconnected ZenML from the ZenML server.")
get_provider(provider_type)
classmethod
Get the server provider associated with a provider type.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
provider_type |
ServerProviderType |
The server provider type. |
required |
Returns:
Type | Description |
---|---|
BaseServerProvider |
The server provider associated with the provider type. |
Exceptions:
Type | Description |
---|---|
ServerProviderNotFoundError |
If no provider is registered for the given provider type. |
Source code in zenml/zen_server/deploy/deployer.py
@classmethod
def get_provider(
    cls, provider_type: ServerProviderType
) -> BaseServerProvider:
    """Look up the registered provider instance for a provider type.

    Args:
        provider_type: The server provider type.

    Returns:
        The server provider associated with the provider type.

    Raises:
        ServerProviderNotFoundError: If no provider is registered for the
            given provider type.
    """
    registry = cls._providers
    if provider_type in registry:
        return registry[provider_type]
    raise ServerProviderNotFoundError(
        f"Server provider '{provider_type}' is not registered."
    )
get_server(self, server_name)
Get a server deployment.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
server_name |
str |
The server deployment name. |
required |
Returns:
Type | Description |
---|---|
ServerDeployment |
The requested server deployment. |
Exceptions:
Type | Description |
---|---|
ServerDeploymentNotFoundError |
If no server deployment with the given name is found. |
Source code in zenml/zen_server/deploy/deployer.py
def get_server(
    self,
    server_name: str,
) -> ServerDeployment:
    """Find a server deployment by name across all registered providers.

    Providers are queried in registration order; the first match wins.

    Args:
        server_name: The server deployment name.

    Returns:
        The requested server deployment.

    Raises:
        ServerDeploymentNotFoundError: If no server deployment with the
            given name is found.
    """
    for candidate in self._providers.values():
        try:
            lookup = ServerDeploymentConfig(
                name=server_name, provider=candidate.TYPE
            )
            return candidate.get_server(lookup)
        except ServerDeploymentNotFoundError:
            continue

    raise ServerDeploymentNotFoundError(
        f"Server deployment '{server_name}' not found."
    )
get_server_logs(self, server_name, follow=False, tail=None)
Retrieve the logs of a ZenML server.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
server_name |
str |
The server deployment name. |
required |
follow |
bool |
if True, the logs will be streamed as they are written |
False |
tail |
Optional[int] |
only retrieve the last NUM lines of log output. |
None |
Returns:
Type | Description |
---|---|
Generator[str, bool, NoneType] |
A generator that can be accessed to get the service logs. |
Source code in zenml/zen_server/deploy/deployer.py
def get_server_logs(
    self,
    server_name: str,
    follow: bool = False,
    tail: Optional[int] = None,
) -> Generator[str, bool, None]:
    """Fetch the logs of a deployed ZenML server.

    The request is delegated to the provider that manages the deployment.

    Args:
        server_name: The server deployment name.
        follow: if True, the logs will be streamed as they are written
        tail: only retrieve the last NUM lines of log output.

    Returns:
        A generator that can be accessed to get the service logs.
    """
    # The lookup raises ServerDeploymentNotFoundError for unknown names.
    deployment = self.get_server(server_name)
    deployment_config = deployment.config
    provider_label = deployment_config.provider.value
    owner = self.get_provider(deployment_config.provider)

    logger.info(
        f"Fetching logs from the '{server_name}' {provider_label} ZenML "
        f"server..."
    )
    return owner.get_server_logs(deployment_config, follow=follow, tail=tail)
is_connected_to_server(self, server_name)
Check if the ZenML client is currently connected to a ZenML server.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| server_name | str | The server deployment name. | required |

Returns:

| Type | Description |
|---|---|
| bool | True if the ZenML client is connected to the ZenML server, False otherwise. |
Source code in zenml/zen_server/deploy/deployer.py
def is_connected_to_server(self, server_name: str) -> bool:
    """Tell whether the local client currently talks to the given server.

    The client counts as connected when the server reports a URL and the
    global configuration's store points at exactly that URL.

    Args:
        server_name: Name of the server deployment to check.

    Returns:
        True when the client is connected to that ZenML server, False
        otherwise.
    """
    # Raises ServerDeploymentNotFoundError for unknown deployments.
    deployment = self.get_server(server_name)
    global_config = GlobalConfiguration()

    deployment_status = deployment.status
    if deployment_status is None or deployment_status.url is None:
        return False

    active_store = global_config.store
    return active_store is not None and active_store.url == deployment_status.url
list_servers(self, server_name=None, provider_type=None)
List all server deployments.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
server_name |
Optional[str] |
The server deployment name to filter by. |
None |
provider_type |
Optional[zenml.enums.ServerProviderType] |
The server provider type to filter by. |
None |
Returns:
Type | Description |
---|---|
List[zenml.zen_server.deploy.deployment.ServerDeployment] |
The list of server deployments. |
Source code in zenml/zen_server/deploy/deployer.py
def list_servers(
    self,
    server_name: Optional[str] = None,
    provider_type: Optional[ServerProviderType] = None,
) -> List[ServerDeployment]:
    """Collect server deployments, optionally narrowed by name and provider.

    Args:
        server_name: Only return deployments with this name, if given.
        provider_type: Only consult this server provider type, if given.

    Returns:
        The matching server deployments.
    """
    if provider_type:
        candidates: List[BaseServerProvider] = [
            self.get_provider(provider_type)
        ]
    else:
        candidates = list(self._providers.values())

    if not server_name:
        # No name filter: every deployment of every candidate provider.
        return [
            deployment
            for candidate in candidates
            for deployment in candidate.list_servers()
        ]

    matches: List[ServerDeployment] = []
    for candidate in candidates:
        try:
            matches.append(
                candidate.get_server(
                    ServerDeploymentConfig(
                        name=server_name,
                        provider=candidate.TYPE,
                    )
                )
            )
        except ServerDeploymentNotFoundError:
            # This provider has no deployment with that name; try the next.
            continue
    return matches
register_provider(provider)
classmethod
Register a server provider.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
provider |
Type[zenml.zen_server.deploy.base_provider.BaseServerProvider] |
The server provider to register. |
required |
Exceptions:
Type | Description |
---|---|
TypeError |
If a provider with the same type is already registered. |
Source code in zenml/zen_server/deploy/deployer.py
@classmethod
def register_provider(cls, provider: Type[BaseServerProvider]) -> None:
    """Add a server provider class to the class-level provider registry.

    An instance of the provider is created and stored under its `TYPE`.

    Args:
        provider: The server provider class to register.

    Raises:
        TypeError: If a provider for the same type was registered before.
    """
    provider_type = provider.TYPE
    if provider_type in cls._providers:
        raise TypeError(
            f"Server provider '{provider_type}' is already registered."
        )
    logger.debug(f"Registering server provider '{provider_type}'.")
    cls._providers[provider_type] = provider()
remove_server(self, server_name, timeout=None)
Tears down and removes all resources and files associated with a ZenML server deployment.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
server_name |
str |
The server deployment name. |
required |
timeout |
Optional[int] |
The timeout in seconds to wait until the deployment is successfully torn down. If not supplied, a provider specific default timeout value is used. |
None |
Source code in zenml/zen_server/deploy/deployer.py
def remove_server(
    self,
    server_name: str,
    timeout: Optional[int] = None,
) -> None:
    """Tears down and removes all resources and files associated with a ZenML server deployment.

    If the client is currently connected to the server, it is disconnected
    before the teardown starts.

    Args:
        server_name: Name of the server deployment to remove.
        timeout: Seconds to wait for the teardown to finish. When omitted,
            the provider falls back to its own default timeout.
    """
    # Unknown names surface as ServerDeploymentNotFoundError here.
    deployment = self.get_server(server_name)
    deployment_config = deployment.config
    provider_label = deployment_config.provider.value
    server_provider = self.get_provider(deployment_config.provider)

    if self.is_connected_to_server(server_name):
        # Do not leave the client pointing at a server that is going away.
        self.disconnect_from_server(server_name)

    logger.info(
        f"Tearing down the '{server_name}' {provider_label} ZenML server."
    )
    server_provider.remove_server(deployment_config, timeout=timeout)
update_server(self, config, timeout=None)
Update an existing ZenML server deployment.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
config |
ServerDeploymentConfig |
The new server deployment configuration. |
required |
timeout |
Optional[int] |
The timeout in seconds to wait until the deployment is successful. If not supplied, a provider specific default timeout value is used. |
None |
Returns:
Type | Description |
---|---|
ServerDeployment |
The updated server deployment. |
Exceptions:
Type | Description |
---|---|
ServerDeploymentExistsError |
If an existing deployment with the same name but a different provider type is found. |
Source code in zenml/zen_server/deploy/deployer.py
def update_server(
    self,
    config: ServerDeploymentConfig,
    timeout: Optional[int] = None,
) -> ServerDeployment:
    """Update an existing ZenML server deployment.

    Args:
        config: The new server deployment configuration.
        timeout: The timeout in seconds to wait until the deployment is
            successful. If not supplied, a provider specific default
            timeout value is used.

    Returns:
        The updated server deployment.

    Raises:
        ServerDeploymentExistsError: If an existing deployment with the same
            name but a different provider type is found.
    """
    # this will also raise ServerDeploymentNotFoundError if the server
    # does not exist
    existing_server = self.get_server(config.name)
    provider = self.get_provider(config.provider)
    existing_provider = existing_server.config.provider
    if existing_provider != config.provider:
        # An update cannot move a deployment between providers; the user has
        # to pick another name or remove the existing deployment first.
        raise ServerDeploymentExistsError(
            f"A server deployment with the same name '{config.name}' but "
            f"with a different provider '{existing_provider.value}' "
            f"is already provisioned. Please choose a different name or "
            f"tear down the existing deployment."
        )
    return provider.update_server(config, timeout=timeout)
deployment
Zen Server deployment definitions.
ServerDeployment (BaseModel)
pydantic-model
Server deployment.
Attributes:
Name | Type | Description |
---|---|---|
config |
ServerDeploymentConfig |
The server deployment configuration. |
status |
Optional[zenml.zen_server.deploy.deployment.ServerDeploymentStatus] |
The server deployment status. |
Source code in zenml/zen_server/deploy/deployment.py
class ServerDeployment(BaseModel):
    """A server deployment: its configuration plus its last known status.

    Attributes:
        config: The server deployment configuration.
        status: The server deployment status.
    """

    config: ServerDeploymentConfig
    status: Optional[ServerDeploymentStatus]

    @property
    def is_running(self) -> bool:
        """Report whether the deployment's service is active.

        Returns:
            True only when a status is present and its state is ACTIVE.
        """
        current_status = self.status
        if current_status is None:
            return False
        return current_status.status == ServiceState.ACTIVE
is_running: bool
property
readonly
Check if the server is running.
Returns:
Type | Description |
---|---|
bool |
Whether the server is running. |
ServerDeploymentConfig (BaseModel)
pydantic-model
Generic server deployment configuration.
All server deployment configurations should inherit from this class and handle extra attributes as provider specific attributes.
Attributes:
Name | Type | Description |
---|---|---|
name |
str |
Name of the server deployment. |
provider |
ServerProviderType |
The server provider type. |
Source code in zenml/zen_server/deploy/deployment.py
class ServerDeploymentConfig(BaseModel):
    """Base configuration shared by every server deployment.

    Provider-specific configuration classes derive from this model and treat
    any additional attributes as provider specific settings.

    Attributes:
        name: Name of the server deployment.
        provider: The server provider type.
    """

    name: str
    provider: ServerProviderType

    class Config:
        """Pydantic model settings."""

        # Extra attributes are accepted on this base class; the concrete
        # subclasses are responsible for validating them.
        extra = "allow"
        # Re-validate attributes on assignment, which is needed to mix
        # mutable and immutable attributes.
        validate_assignment = True
Config
Pydantic configuration class.
Source code in zenml/zen_server/deploy/deployment.py
class Config:
    """Pydantic model settings."""

    # Extra attributes are accepted on this base class; the concrete
    # subclasses are responsible for validating them.
    extra = "allow"
    # Re-validate attributes on assignment, which is needed to mix mutable
    # and immutable attributes.
    validate_assignment = True
ServerDeploymentStatus (BaseModel)
pydantic-model
Server deployment status.
Ideally this should convey the following information:
- whether the server's deployment is managed by this client (i.e. if the server was deployed with `zenml up`)
- for a managed deployment, the status of the deployment/tear-down, e.g. not deployed, deploying, running, deleting, deployment timeout/error, tear-down timeout/error etc.
- for an unmanaged deployment, the operational status (i.e. whether the server is reachable)
- the URL of the server
Attributes:
Name | Type | Description |
---|---|---|
status |
ServiceState |
The status of the server deployment. |
status_message |
Optional[str] |
A message describing the last status. |
connected |
bool |
Whether the client is currently connected to this server. |
url |
Optional[str] |
The URL of the server. |
Source code in zenml/zen_server/deploy/deployment.py
class ServerDeploymentStatus(BaseModel):
    """Server deployment status.

    Ideally this should convey the following information:

    * whether the server's deployment is managed by this client (i.e. if
      the server was deployed with `zenml up`)
    * for a managed deployment, the status of the deployment/tear-down, e.g.
      not deployed, deploying, running, deleting, deployment timeout/error,
      tear-down timeout/error etc.
    * for an unmanaged deployment, the operational status (i.e. whether the
      server is reachable)
    * the URL of the server

    Attributes:
        status: The status of the server deployment.
        status_message: A message describing the last status.
        connected: Whether the client is currently connected to this server.
        url: The URL of the server.
        ca_crt: Optional string that, judging by the field name, holds a CA
            certificate for the server. NOTE(review): format and origin are
            not visible here; confirm with the providers that populate it.
    """

    status: ServiceState
    status_message: Optional[str] = None
    connected: bool
    url: Optional[str] = None
    ca_crt: Optional[str] = None
docker
special
ZenML Server Docker Deployment.
docker_provider
Zen Server docker deployer implementation.
DockerServerProvider (BaseServerProvider)
Docker ZenML server provider.
Source code in zenml/zen_server/deploy/docker/docker_provider.py
class DockerServerProvider(BaseServerProvider):
    """Docker ZenML server provider.

    Manages a single ZenML server deployment run as a `DockerZenServer`
    container service.
    """

    TYPE: ClassVar[ServerProviderType] = ServerProviderType.DOCKER
    CONFIG_TYPE: ClassVar[
        Type[ServerDeploymentConfig]
    ] = DockerServerDeploymentConfig

    @classmethod
    def _get_service_configuration(
        cls,
        server_config: ServerDeploymentConfig,
    ) -> Tuple[
        ServiceConfig,
        ServiceEndpointConfig,
        ServiceEndpointHealthMonitorConfig,
    ]:
        """Construct the service configuration from a server deployment configuration.

        Args:
            server_config: server deployment configuration.

        Returns:
            The service, service endpoint and endpoint monitor configuration.
        """
        assert isinstance(server_config, DockerServerDeploymentConfig)
        return (
            DockerZenServerConfig(
                root_runtime_path=DockerZenServer.config_path(),
                singleton=True,
                image=server_config.image,
                name=server_config.name,
                server=server_config,
            ),
            ContainerServiceEndpointConfig(
                protocol=ServiceEndpointProtocol.HTTP,
                port=server_config.port,
                # Use the configured port as-is instead of allocating one.
                allocate_port=False,
            ),
            HTTPEndpointHealthMonitorConfig(
                healthcheck_uri_path=ZEN_SERVER_HEALTHCHECK_URL_PATH,
                use_head_request=True,
            ),
        )

    def _create_service(
        self,
        config: ServerDeploymentConfig,
        timeout: Optional[int] = None,
    ) -> BaseService:
        """Create, start and return the docker ZenML server deployment service.

        Args:
            config: The server deployment configuration.
            timeout: The timeout in seconds to wait until the service is
                running.

        Returns:
            The service instance.

        Raises:
            RuntimeError: If a docker ZenML server service already exists.
        """
        assert isinstance(config, DockerServerDeploymentConfig)
        if timeout is None:
            timeout = DOCKER_ZENML_SERVER_DEFAULT_TIMEOUT

        # Only one docker ZenML server can exist at a time.
        existing_service = DockerZenServer.get_service()
        if existing_service:
            raise RuntimeError(
                f"A docker ZenML server with name '{existing_service.config.name}' "
                f"is already running. Please stop it first before starting a "
                f"new one."
            )

        (
            service_config,
            endpoint_cfg,
            monitor_cfg,
        ) = self._get_service_configuration(config)

        endpoint = ContainerServiceEndpoint(
            config=endpoint_cfg,
            monitor=HTTPEndpointHealthMonitor(
                config=monitor_cfg,
            ),
        )
        service = DockerZenServer(config=service_config, endpoint=endpoint)
        service.start(timeout=timeout)
        return service

    def _update_service(
        self,
        service: BaseService,
        config: ServerDeploymentConfig,
        timeout: Optional[int] = None,
    ) -> BaseService:
        """Update the docker ZenML server deployment service.

        The service is stopped, its service/endpoint/monitor configurations
        are replaced, and it is started again.

        Args:
            service: The service instance.
            config: The new server deployment configuration.
            timeout: The timeout in seconds to wait until the updated service is
                running.

        Returns:
            The updated service instance.
        """
        if timeout is None:
            timeout = DOCKER_ZENML_SERVER_DEFAULT_TIMEOUT

        (
            new_config,
            new_endpoint_cfg,
            new_monitor_cfg,
        ) = self._get_service_configuration(config)

        assert service.endpoint
        assert service.endpoint.monitor
        service.stop(timeout=timeout)
        (
            service.config,
            service.endpoint.config,
            service.endpoint.monitor.config,
        ) = (
            new_config,
            new_endpoint_cfg,
            new_monitor_cfg,
        )
        service.start(timeout=timeout)

        return service

    def _start_service(
        self,
        service: BaseService,
        timeout: Optional[int] = None,
    ) -> BaseService:
        """Start the docker ZenML server deployment service.

        Args:
            service: The service instance.
            timeout: The timeout in seconds to wait until the service is
                running.

        Returns:
            The updated service instance.
        """
        if timeout is None:
            timeout = DOCKER_ZENML_SERVER_DEFAULT_TIMEOUT

        service.start(timeout=timeout)
        return service

    def _stop_service(
        self,
        service: BaseService,
        timeout: Optional[int] = None,
    ) -> BaseService:
        """Stop the docker ZenML server deployment service.

        Args:
            service: The service instance.
            timeout: The timeout in seconds to wait until the service is
                stopped.

        Returns:
            The updated service instance.
        """
        if timeout is None:
            timeout = DOCKER_ZENML_SERVER_DEFAULT_TIMEOUT

        service.stop(timeout=timeout)
        return service

    def _delete_service(
        self,
        service: BaseService,
        timeout: Optional[int] = None,
    ) -> None:
        """Remove the docker ZenML server deployment service.

        Stops the service and then deletes the docker ZenML server runtime
        directory.

        Args:
            service: The service instance.
            timeout: The timeout in seconds to wait until the service is
                removed.
        """
        assert isinstance(service, DockerZenServer)
        if timeout is None:
            timeout = DOCKER_ZENML_SERVER_DEFAULT_TIMEOUT

        service.stop(timeout=timeout)
        shutil.rmtree(DockerZenServer.config_path())

    def _get_service(self, server_name: str) -> BaseService:
        """Get the docker ZenML server deployment service.

        Args:
            server_name: The server deployment name.

        Returns:
            The service instance.

        Raises:
            KeyError: If the server deployment is not found.
        """
        service = DockerZenServer.get_service()
        if service is None:
            raise KeyError("The docker ZenML server is not deployed.")

        if service.config.name != server_name:
            raise KeyError(
                "The docker ZenML server is deployed but with a different name."
            )
        return service

    def _list_services(self) -> List[BaseService]:
        """Get all service instances for all deployed ZenML servers.

        Returns:
            A list of service instances.
        """
        service = DockerZenServer.get_service()
        if service:
            return [service]
        return []

    def _get_deployment_config(
        self, service: BaseService
    ) -> ServerDeploymentConfig:
        """Recreate the server deployment configuration from a service instance.

        Args:
            service: The service instance.

        Returns:
            The server deployment configuration.
        """
        server = cast(DockerZenServer, service)
        return server.config.server
CONFIG_TYPE (ServerDeploymentConfig)
pydantic-model
Docker server deployment configuration.
Attributes:
Name | Type | Description |
---|---|---|
port |
int |
The TCP port number where the server is accepting connections. |
image |
str |
The Docker image to use for the server. |
Source code in zenml/zen_server/deploy/docker/docker_provider.py
class DockerServerDeploymentConfig(ServerDeploymentConfig):
    """Docker server deployment configuration.

    Attributes:
        port: The TCP port number where the server is accepting connections.
        image: The Docker image to use for the server.
        store: Optional custom store configuration for the server. When left
            unset, the server uses the default local store mounted into the
            container.
    """

    port: int = 8238
    image: str = DOCKER_ZENML_SERVER_DEFAULT_IMAGE
    store: Optional[StoreConfiguration] = None

    class Config:
        """Pydantic configuration."""

        # Reject any attribute that is not declared on this model.
        extra = "forbid"
Config
Pydantic configuration.
Source code in zenml/zen_server/deploy/docker/docker_provider.py
class Config:
    """Pydantic model settings for the docker deployment configuration."""

    # Undeclared attributes are rejected.
    extra = "forbid"
docker_zen_server
Service implementation for the ZenML docker server deployment.
DockerServerDeploymentConfig (ServerDeploymentConfig)
pydantic-model
Docker server deployment configuration.
Attributes:
Name | Type | Description |
---|---|---|
port |
int |
The TCP port number where the server is accepting connections. |
image |
str |
The Docker image to use for the server. |
Source code in zenml/zen_server/deploy/docker/docker_zen_server.py
class DockerServerDeploymentConfig(ServerDeploymentConfig):
    """Docker server deployment configuration.

    Attributes:
        port: The TCP port number where the server is accepting connections.
        image: The Docker image to use for the server.
        store: Optional custom store configuration for the server. When left
            unset, the server uses the default local store mounted into the
            container.
    """

    port: int = 8238
    image: str = DOCKER_ZENML_SERVER_DEFAULT_IMAGE
    store: Optional[StoreConfiguration] = None

    class Config:
        """Pydantic configuration."""

        # Reject any attribute that is not declared on this model.
        extra = "forbid"
Config
Pydantic configuration.
Source code in zenml/zen_server/deploy/docker/docker_zen_server.py
class Config:
    """Pydantic model settings for the docker deployment configuration."""

    # Undeclared attributes are rejected.
    extra = "forbid"
DockerZenServer (ContainerService)
pydantic-model
Service that can be used to start a docker ZenServer.
Attributes:
Name | Type | Description |
---|---|---|
config |
DockerZenServerConfig |
service configuration |
endpoint |
ContainerServiceEndpoint |
service endpoint |
Source code in zenml/zen_server/deploy/docker/docker_zen_server.py
class DockerZenServer(ContainerService):
    """Container service that runs a ZenML server inside Docker.

    Attributes:
        config: service configuration
        endpoint: service endpoint
    """

    SERVICE_TYPE = ServiceType(
        name="docker_zenml_server",
        type="zen_server",
        flavor="docker",
        description="Docker ZenML server deployment",
    )

    config: DockerZenServerConfig
    endpoint: ContainerServiceEndpoint

    @classmethod
    def config_path(cls) -> str:
        """Return the runtime directory of the docker ZenML server.

        Returns:
            The `zen_server/docker` directory below the global config
            directory.
        """
        runtime_subdirs = ("zen_server", "docker")
        return os.path.join(get_global_config_directory(), *runtime_subdirs)

    @property
    def _global_config_path(self) -> str:
        """Return the global configuration directory used by this server.

        Returns:
            The server's global configuration directory inside its runtime
            directory.
        """
        base_dir = self.config_path()
        return os.path.join(base_dir, SERVICE_CONTAINER_GLOBAL_CONFIG_DIR)

    def _copy_global_configuration(self) -> None:
        """Copy the client's global configuration for use by the server.

        The copy is written to the server's global configuration directory.
        A store configuration set explicitly on the server configuration is
        carried over; otherwise the copy gets an empty store so that the
        server points at the local store mounted in the container.
        """
        global_config = GlobalConfiguration()
        custom_store = self.config.server.store
        global_config.copy_configuration(
            config_path=self._global_config_path,
            store_config=custom_store,
            empty_store=custom_store is None,
        )

    @classmethod
    def get_service(cls) -> Optional["DockerZenServer"]:
        """Load the persisted docker ZenML server service, if there is one.

        Returns:
            The docker ZenML server service, or None when no `service.json`
            exists in the runtime directory.
        """
        from zenml.services import ServiceRegistry

        service_file = os.path.join(cls.config_path(), "service.json")
        try:
            with open(service_file, "r") as handle:
                loaded = ServiceRegistry().load_service_from_json(
                    handle.read()
                )
        except FileNotFoundError:
            return None
        return cast(DockerZenServer, loaded)

    def _get_container_cmd(self) -> Tuple[List[str], Dict[str, str]]:
        """Build the container command and environment for the server.

        Extends the inherited command so that the container uses the copied
        global configuration rather than the one mounted from the host.

        Returns:
            The command that launches the container and the environment
            variables to set, in the formats accepted by subprocess.Popen.
        """
        global_config = GlobalConfiguration()
        cmd, env = super()._get_container_cmd()
        env.update(
            {
                ENV_ZENML_CONFIG_PATH: os.path.join(
                    SERVICE_CONTAINER_PATH,
                    SERVICE_CONTAINER_GLOBAL_CONFIG_DIR,
                ),
                ENV_ZENML_SERVER_DEPLOYMENT_TYPE: ServerDeploymentType.DOCKER,
                ENV_ZENML_ANALYTICS_OPT_IN: str(
                    global_config.analytics_opt_in
                ),
                # Point the local stores path at the location where the
                # client's local stores are mounted in the container, so
                # the server's store configuration matches the client's.
                ENV_ZENML_LOCAL_STORES_PATH: os.path.join(
                    SERVICE_CONTAINER_GLOBAL_CONFIG_PATH,
                    LOCAL_STORES_DIRECTORY_NAME,
                ),
                ENV_ZENML_DISABLE_DATABASE_MIGRATION: "True",
            }
        )
        return cmd, env

    def provision(self) -> None:
        """Copy the global configuration, then provision the container."""
        self._copy_global_configuration()
        super().provision()

    def run(self) -> None:
        """Serve the ZenML server application with uvicorn (blocking).

        Raises:
            ValueError: if started with a global configuration that connects to
                another ZenML server.
        """
        import uvicorn

        active_store = GlobalConfiguration().store
        if active_store and active_store.type == StoreType.REST:
            raise ValueError(
                "The ZenML server cannot be started with REST store type."
            )

        logger.info(
            "Starting ZenML Server as blocking "
            "process... press CTRL+C once to stop it."
        )
        self.endpoint.prepare_for_start()

        try:
            listen_port = self.endpoint.config.port or 8000
            uvicorn.run(
                ZEN_SERVER_ENTRYPOINT,
                host="0.0.0.0",  # nosec
                port=listen_port,
                log_level="info",
            )
        except KeyboardInterrupt:
            logger.info("ZenML Server stopped. Resuming normal execution.")
config_path()
classmethod
Path to the directory where the docker ZenML server files are located.
Returns:
Type | Description |
---|---|
str |
Path to the docker ZenML server runtime directory. |
Source code in zenml/zen_server/deploy/docker/docker_zen_server.py
@classmethod
def config_path(cls) -> str:
    """Return the runtime directory of the docker ZenML server.

    Returns:
        The `zen_server/docker` directory below the global config directory.
    """
    runtime_subdirs = ("zen_server", "docker")
    return os.path.join(get_global_config_directory(), *runtime_subdirs)
get_service()
classmethod
Load and return the docker ZenML server service, if present.
Returns:
Type | Description |
---|---|
Optional[DockerZenServer] |
The docker ZenML server service or None, if the docker server deployment is not found. |
Source code in zenml/zen_server/deploy/docker/docker_zen_server.py
@classmethod
def get_service(cls) -> Optional["DockerZenServer"]:
    """Load the persisted docker ZenML server service, if there is one.

    Returns:
        The docker ZenML server service, or None when no `service.json`
        exists in the runtime directory.
    """
    from zenml.services import ServiceRegistry

    service_file = os.path.join(cls.config_path(), "service.json")
    try:
        with open(service_file, "r") as handle:
            loaded = ServiceRegistry().load_service_from_json(handle.read())
    except FileNotFoundError:
        return None
    return cast(DockerZenServer, loaded)
provision(self)
Provision the service.
Source code in zenml/zen_server/deploy/docker/docker_zen_server.py
def provision(self) -> None:
    """Provision the service.

    The global configuration is copied into the server's runtime directory
    before the parent class' provisioning runs.
    """
    self._copy_global_configuration()
    super().provision()
run(self)
Run the ZenML Server.
Exceptions:
Type | Description |
---|---|
ValueError |
if started with a global configuration that connects to another ZenML server. |
Source code in zenml/zen_server/deploy/docker/docker_zen_server.py
def run(self) -> None:
    """Serve the ZenML server application with uvicorn (blocking).

    Raises:
        ValueError: if started with a global configuration that connects to
            another ZenML server.
    """
    import uvicorn

    active_store = GlobalConfiguration().store
    if active_store and active_store.type == StoreType.REST:
        raise ValueError(
            "The ZenML server cannot be started with REST store type."
        )

    logger.info(
        "Starting ZenML Server as blocking "
        "process... press CTRL+C once to stop it."
    )
    self.endpoint.prepare_for_start()

    try:
        listen_port = self.endpoint.config.port or 8000
        uvicorn.run(
            ZEN_SERVER_ENTRYPOINT,
            host="0.0.0.0",  # nosec
            port=listen_port,
            log_level="info",
        )
    except KeyboardInterrupt:
        logger.info("ZenML Server stopped. Resuming normal execution.")
DockerZenServerConfig (ContainerServiceConfig)
pydantic-model
Docker Zen server configuration.
Attributes:
Name | Type | Description |
---|---|---|
server |
DockerServerDeploymentConfig |
The deployment configuration. |
Source code in zenml/zen_server/deploy/docker/docker_zen_server.py
class DockerZenServerConfig(ContainerServiceConfig):
    """Container service configuration for the docker ZenML server.

    Attributes:
        server: The server deployment configuration this service was built
            from.
    """

    server: DockerServerDeploymentConfig
exceptions
ZenML server deployment exceptions.
ServerDeploymentConfigurationError (ServerDeploymentError)
Raised when there is a ZenML server deployment configuration error.
Source code in zenml/zen_server/deploy/exceptions.py
class ServerDeploymentConfigurationError(ServerDeploymentError):
    """Raised when there is a ZenML server deployment configuration error."""
ServerDeploymentError (ZenMLBaseException)
Base exception class for all ZenML server deployment related errors.
Source code in zenml/zen_server/deploy/exceptions.py
class ServerDeploymentError(ZenMLBaseException):
    """Common base class for every ZenML server deployment error."""
ServerDeploymentExistsError (ServerDeploymentError)
Raised when trying to deploy a new ZenML server with the same name.
Source code in zenml/zen_server/deploy/exceptions.py
class ServerDeploymentExistsError(ServerDeploymentError):
    """Raised when a new ZenML server deployment reuses an existing name."""
ServerDeploymentNotFoundError (ServerDeploymentError)
Raised when trying to fetch a ZenML server deployment that doesn't exist.
Source code in zenml/zen_server/deploy/exceptions.py
class ServerDeploymentNotFoundError(ServerDeploymentError):
    """Raised when a requested ZenML server deployment does not exist."""
ServerProviderNotFoundError (ServerDeploymentError)
Raised when using a ZenML server provider that doesn't exist.
Source code in zenml/zen_server/deploy/exceptions.py
class ServerProviderNotFoundError(ServerDeploymentError):
    """Raised when a requested ZenML server provider does not exist."""
local
special
ZenML Server Local Deployment.
local_provider
Zen Server local provider implementation.
LocalServerProvider (BaseServerProvider)
Local ZenML server provider.
Source code in zenml/zen_server/deploy/local/local_provider.py
class LocalServerProvider(BaseServerProvider):
"""Local ZenML server provider."""
TYPE: ClassVar[ServerProviderType] = ServerProviderType.LOCAL
CONFIG_TYPE: ClassVar[
Type[ServerDeploymentConfig]
] = LocalServerDeploymentConfig
@staticmethod
def check_local_server_dependencies() -> None:
"""Check if local server dependencies are installed.
Raises:
RuntimeError: If the dependencies are not installed.
"""
try:
# Make sure the ZenML Server dependencies are installed
import fastapi # noqa
import fastapi_utils # noqa
import jwt # noqa
import multipart # noqa
import uvicorn # noqa
except ImportError:
# Unable to import the ZenML Server dependencies.
raise RuntimeError(
"The local ZenML server provider is unavailable because the "
"ZenML server requirements seems to be unavailable on your machine. "
"This is probably because ZenML was installed without the optional "
"ZenML Server dependencies. To install the missing dependencies "
f'run `pip install "zenml[server]=={__version__}"`.'
)
@classmethod
def _get_service_configuration(
cls,
server_config: ServerDeploymentConfig,
) -> Tuple[
ServiceConfig,
ServiceEndpointConfig,
ServiceEndpointHealthMonitorConfig,
]:
"""Construct the service configuration from a server deployment configuration.
Args:
server_config: server deployment configuration.
Returns:
The service, service endpoint and endpoint monitor configuration.
"""
assert isinstance(server_config, LocalServerDeploymentConfig)
return (
LocalZenServerConfig(
root_runtime_path=LocalZenServer.config_path(),
singleton=True,
name=server_config.name,
blocking=server_config.blocking,
server=server_config,
),
LocalDaemonServiceEndpointConfig(
protocol=ServiceEndpointProtocol.HTTP,
ip_address=str(server_config.ip_address),
port=server_config.port,
allocate_port=False,
),
HTTPEndpointHealthMonitorConfig(
healthcheck_uri_path=ZEN_SERVER_HEALTHCHECK_URL_PATH,
use_head_request=True,
),
)
def _create_service(
self,
config: ServerDeploymentConfig,
timeout: Optional[int] = None,
) -> BaseService:
"""Create, start and return the local ZenML server deployment service.
Args:
config: The server deployment configuration.
timeout: The timeout in seconds to wait until the service is
running.
Returns:
The service instance.
Raises:
RuntimeError: If a local service is already running.
"""
assert isinstance(config, LocalServerDeploymentConfig)
if timeout is None:
timeout = LOCAL_ZENML_SERVER_DEFAULT_TIMEOUT
self.check_local_server_dependencies()
existing_service = LocalZenServer.get_service()
if existing_service:
raise RuntimeError(
f"A local ZenML server with name '{existing_service.config.name}' "
f"is already running. Please stop it first before starting a "
f"new one."
)
(
service_config,
endpoint_cfg,
monitor_cfg,
) = self._get_service_configuration(config)
endpoint = LocalDaemonServiceEndpoint(
config=endpoint_cfg,
monitor=HTTPEndpointHealthMonitor(
config=monitor_cfg,
),
)
service = LocalZenServer(config=service_config, endpoint=endpoint)
service.start(timeout=timeout)
return service
def _update_service(
self,
service: BaseService,
config: ServerDeploymentConfig,
timeout: Optional[int] = None,
) -> BaseService:
"""Update the local ZenML server deployment service.
Args:
service: The service instance.
config: The new server deployment configuration.
timeout: The timeout in seconds to wait until the updated service is
running.
Returns:
The updated service instance.
"""
if timeout is None:
timeout = LOCAL_ZENML_SERVER_DEFAULT_TIMEOUT
(
new_config,
new_endpoint_cfg,
new_monitor_cfg,
) = self._get_service_configuration(config)
assert service.endpoint
assert service.endpoint.monitor
service.stop(timeout=timeout)
(
service.config,
service.endpoint.config,
service.endpoint.monitor.config,
) = (
new_config,
new_endpoint_cfg,
new_monitor_cfg,
)
service.start(timeout=timeout)
return service
def _start_service(
self,
service: BaseService,
timeout: Optional[int] = None,
) -> BaseService:
"""Start the local ZenML server deployment service.
Args:
service: The service instance.
timeout: The timeout in seconds to wait until the service is
running.
Returns:
The updated service instance.
"""
if timeout is None:
timeout = LOCAL_ZENML_SERVER_DEFAULT_TIMEOUT
service.start(timeout=timeout)
return service
def _stop_service(
self,
service: BaseService,
timeout: Optional[int] = None,
) -> BaseService:
"""Stop the local ZenML server deployment service.
Args:
service: The service instance.
timeout: The timeout in seconds to wait until the service is
stopped.
Returns:
The updated service instance.
"""
if timeout is None:
timeout = LOCAL_ZENML_SERVER_DEFAULT_TIMEOUT
service.stop(timeout=timeout)
return service
def _delete_service(
self,
service: BaseService,
timeout: Optional[int] = None,
) -> None:
"""Remove the local ZenML server deployment service.
Args:
service: The service instance.
timeout: The timeout in seconds to wait until the service is
removed.
"""
assert isinstance(service, LocalZenServer)
if timeout is None:
timeout = LOCAL_ZENML_SERVER_DEFAULT_TIMEOUT
service.stop(timeout)
shutil.rmtree(LocalZenServer.config_path())
def _get_service(self, server_name: str) -> BaseService:
    """Look up the local ZenML server deployment service by name.

    Args:
        server_name: The server deployment name.

    Returns:
        The service instance.

    Raises:
        KeyError: If the server deployment is not found.
    """
    service = LocalZenServer.get_service()
    if service is None:
        reason = "The local ZenML server is not deployed."
    elif service.config.name != server_name:
        reason = "The local ZenML server is deployed but with a different name."
    else:
        return service
    raise KeyError(reason)
def _list_services(self) -> List[BaseService]:
    """List the service instances of all deployed local ZenML servers.

    At most one local server can exist, so the result is either empty or
    holds a single service.

    Returns:
        A list of service instances.
    """
    local_service = LocalZenServer.get_service()
    return [local_service] if local_service else []
def _get_deployment_config(
    self, service: BaseService
) -> ServerDeploymentConfig:
    """Extract the server deployment configuration stored on a service.

    Args:
        service: The service instance (a ``LocalZenServer``).

    Returns:
        The server deployment configuration.
    """
    return cast(LocalZenServer, service).config.server
CONFIG_TYPE (ServerDeploymentConfig)
pydantic-model
Local server deployment configuration.
Attributes:
Name | Type | Description |
---|---|---|
port |
int |
The TCP port number where the server is accepting connections. |
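ip_address |
Union[ipaddress.IPv4Address, ipaddress.IPv6Address] |
The IP address where the server is reachable. |
blocking |
bool |
Run the server in blocking mode instead of using a daemon process. |
store |
Optional[StoreConfiguration] |
The store configuration to use for the server. When unset, the server uses the local default store. |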
Source code in zenml/zen_server/deploy/local/local_provider.py
class LocalServerDeploymentConfig(ServerDeploymentConfig):
    """Local server deployment configuration.

    Attributes:
        port: The TCP port number where the server is accepting connections.
        ip_address: The IP address where the server is reachable.
        blocking: Run the server in blocking mode instead of using a daemon
            process.
        store: Optional store configuration for the server. When left unset,
            the local server's copy of the global configuration points at the
            local default store instead.
    """

    port: int = 8237
    ip_address: Union[
        ipaddress.IPv4Address, ipaddress.IPv6Address
    ] = ipaddress.IPv4Address(DEFAULT_LOCAL_SERVICE_IP_ADDRESS)
    blocking: bool = False
    store: Optional[StoreConfiguration] = None

    class Config:
        """Pydantic configuration."""

        # Reject fields that are not declared on the model.
        extra = "forbid"
Config
Pydantic configuration.
Source code in zenml/zen_server/deploy/local/local_provider.py
class Config:
    """Pydantic configuration for the local server deployment config."""

    # Reject fields that are not declared on the model.
    extra = "forbid"
check_local_server_dependencies()
staticmethod
Check if local server dependencies are installed.
Exceptions:
Type | Description |
---|---|
RuntimeError |
If the dependencies are not installed. |
Source code in zenml/zen_server/deploy/local/local_provider.py
@staticmethod
def check_local_server_dependencies() -> None:
    """Check if local server dependencies are installed.

    Raises:
        RuntimeError: If the dependencies are not installed. The underlying
            ``ImportError`` is chained as the cause so the missing module
            name is visible in the traceback.
    """
    try:
        # Make sure the ZenML Server dependencies are installed
        import fastapi  # noqa
        import fastapi_utils  # noqa
        import jwt  # noqa
        import multipart  # noqa
        import uvicorn  # noqa
    except ImportError as e:
        # Unable to import the ZenML Server dependencies.
        raise RuntimeError(
            "The local ZenML server provider is unavailable because the "
            "ZenML server requirements seems to be unavailable on your machine. "
            "This is probably because ZenML was installed without the optional "
            "ZenML Server dependencies. To install the missing dependencies "
            f'run `pip install "zenml[server]=={__version__}"`.'
        ) from e
local_zen_server
Local ZenML server deployment service implementation.
LocalServerDeploymentConfig (ServerDeploymentConfig)
pydantic-model
Local server deployment configuration.
Attributes:
Name | Type | Description |
---|---|---|
port |
int |
The TCP port number where the server is accepting connections. |
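ip_address |
Union[ipaddress.IPv4Address, ipaddress.IPv6Address] |
The IP address where the server is reachable. |
blocking |
bool |
Run the server in blocking mode instead of using a daemon process. |
store |
Optional[StoreConfiguration] |
The store configuration to use for the server. When unset, the server uses the local default store. |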
Source code in zenml/zen_server/deploy/local/local_zen_server.py
class LocalServerDeploymentConfig(ServerDeploymentConfig):
    """Local server deployment configuration.

    Attributes:
        port: The TCP port number where the server is accepting connections.
        ip_address: The IP address where the server is reachable.
        blocking: Run the server in blocking mode instead of using a daemon
            process.
        store: Optional store configuration for the server. When left unset,
            the local server's copy of the global configuration points at the
            local default store instead.
    """

    port: int = 8237
    ip_address: Union[
        ipaddress.IPv4Address, ipaddress.IPv6Address
    ] = ipaddress.IPv4Address(DEFAULT_LOCAL_SERVICE_IP_ADDRESS)
    blocking: bool = False
    store: Optional[StoreConfiguration] = None

    class Config:
        """Pydantic configuration."""

        # Reject fields that are not declared on the model.
        extra = "forbid"
Config
Pydantic configuration.
Source code in zenml/zen_server/deploy/local/local_zen_server.py
class Config:
    """Pydantic configuration for the local server deployment config."""

    # Reject fields that are not declared on the model.
    extra = "forbid"
LocalZenServer (LocalDaemonService)
pydantic-model
Service daemon that can be used to start a local ZenML Server.
Attributes:
Name | Type | Description |
---|---|---|
config |
LocalZenServerConfig |
service configuration |
endpoint |
LocalDaemonServiceEndpoint |
optional service endpoint |
Source code in zenml/zen_server/deploy/local/local_zen_server.py
class LocalZenServer(LocalDaemonService):
    """Service daemon that can be used to start a local ZenML Server.

    Attributes:
        config: service configuration
        endpoint: optional service endpoint
    """

    SERVICE_TYPE = ServiceType(
        name="local_zenml_server",
        type="zen_server",
        flavor="local",
        description="Local ZenML server deployment",
    )

    config: LocalZenServerConfig
    endpoint: LocalDaemonServiceEndpoint

    @classmethod
    def config_path(cls) -> str:
        """Path to the directory where the local ZenML server files are located.

        Returns:
            Path to the local ZenML server runtime directory.
        """
        return os.path.join(
            get_global_config_directory(),
            "zen_server",
            "local",
        )

    @property
    def _global_config_path(self) -> str:
        """Path to the global configuration directory used by this server.

        Returns:
            Path to the global configuration directory used by this server.
        """
        return os.path.join(self.config_path(), ".zenconfig")

    def _copy_global_configuration(self) -> None:
        """Copy the global configuration to the local ZenML server location.

        The local ZenML server global configuration is a copy of the local
        global configuration. If a store configuration is explicitly set in
        the server configuration, it will be used. Otherwise, the store
        configuration is set to point to the local store.
        """
        gc = GlobalConfiguration()

        # this creates a copy of the global configuration and saves it to
        # the server configuration path. The store is set to point to the local
        # default database unless a custom store configuration is explicitly
        # supplied with the server configuration.
        gc.copy_configuration(
            config_path=self._global_config_path,
            store_config=self.config.server.store,
            empty_store=self.config.server.store is None,
        )

    @classmethod
    def get_service(cls) -> Optional["LocalZenServer"]:
        """Load and return the local ZenML server service, if present.

        Returns:
            The local ZenML server service or None, if the local server
            deployment is not found.
        """
        from zenml.services import ServiceRegistry

        config_filename = os.path.join(cls.config_path(), "service.json")
        try:
            with open(config_filename, "r") as f:
                return cast(
                    LocalZenServer,
                    ServiceRegistry().load_service_from_json(f.read()),
                )
        except FileNotFoundError:
            return None

    def _get_daemon_cmd(self) -> Tuple[List[str], Dict[str, str]]:
        """Get the command to start the daemon.

        Overrides the base class implementation to add the environment variable
        that forces the ZenML server to use the copied global config.

        Returns:
            The command to start the daemon and the environment variables to
            set for the command.
        """
        cmd, env = super()._get_daemon_cmd()
        env[ENV_ZENML_CONFIG_PATH] = self._global_config_path
        env[ENV_ZENML_SERVER_DEPLOYMENT_TYPE] = ServerDeploymentType.LOCAL
        # Set the local stores path to the same path used by the client. This
        # ensures that the server's store configuration is initialized with
        # the same path as the client.
        env[
            ENV_ZENML_LOCAL_STORES_PATH
        ] = GlobalConfiguration().local_stores_path
        env[ENV_ZENML_DISABLE_DATABASE_MIGRATION] = "True"

        return cmd, env

    def provision(self) -> None:
        """Provision the service."""
        self._copy_global_configuration()
        super().provision()

    def start(self, timeout: int = 0) -> None:
        """Start the service and optionally wait for it to become active.

        In blocking mode the server runs in the current process. While it
        runs, ``ENV_ZENML_CONFIG_PATH`` and ``ENV_ZENML_LOCAL_STORES_PATH``
        point at the server's copy of the global configuration. Afterwards
        both variables are restored to their previous values, or removed if
        they were previously unset.

        Args:
            timeout: amount of time to wait for the service to become active.
                If set to 0, the method will return immediately after checking
                the service status.
        """
        if not self.config.blocking:
            super().start(timeout)
            return

        self._copy_global_configuration()
        local_stores_path = GlobalConfiguration().local_stores_path

        GlobalConfiguration._reset_instance()
        Client._reset_instance()

        # Snapshot the caller's values so they can be restored exactly,
        # including variables that were explicitly set to an empty string.
        previous_env = {
            name: os.environ.get(name)
            for name in (ENV_ZENML_CONFIG_PATH, ENV_ZENML_LOCAL_STORES_PATH)
        }
        os.environ[ENV_ZENML_CONFIG_PATH] = self._global_config_path
        os.environ[ENV_ZENML_LOCAL_STORES_PATH] = local_stores_path
        try:
            self.run()
        finally:
            for name, value in previous_env.items():
                if value is None:
                    os.environ.pop(name, None)
                else:
                    os.environ[name] = value
            GlobalConfiguration._reset_instance()
            Client._reset_instance()

    def run(self) -> None:
        """Run the ZenML Server.

        Raises:
            ValueError: if started with a global configuration that connects to
                another ZenML server.
        """
        import uvicorn

        gc = GlobalConfiguration()
        if gc.store and gc.store.type == StoreType.REST:
            raise ValueError(
                "The ZenML server cannot be started with REST store type."
            )
        logger.info(
            "Starting ZenML Server as blocking "
            "process... press CTRL+C once to stop it."
        )

        self.endpoint.prepare_for_start()

        try:
            uvicorn.run(
                ZEN_SERVER_ENTRYPOINT,
                host=self.endpoint.config.ip_address,
                port=self.endpoint.config.port or 8000,
                log_level="info",
            )
        except KeyboardInterrupt:
            logger.info("ZenML Server stopped. Resuming normal execution.")
config_path()
classmethod
Path to the directory where the local ZenML server files are located.
Returns:
Type | Description |
---|---|
str |
Path to the local ZenML server runtime directory. |
Source code in zenml/zen_server/deploy/local/local_zen_server.py
@classmethod
def config_path(cls) -> str:
    """Return the runtime directory of the local ZenML server.

    Returns:
        The ``zen_server/local`` sub-directory of the global configuration
        directory, where the local ZenML server files are kept.
    """
    base_dir = get_global_config_directory()
    return os.path.join(base_dir, "zen_server", "local")
get_service()
classmethod
Load and return the local ZenML server service, if present.
Returns:
Type | Description |
---|---|
Optional[LocalZenServer] |
The local ZenML server service or None, if the local server deployment is not found. |
Source code in zenml/zen_server/deploy/local/local_zen_server.py
@classmethod
def get_service(cls) -> Optional["LocalZenServer"]:
    """Load the persisted local ZenML server service, if one exists.

    The service is deserialized from ``service.json`` inside the local
    server runtime directory.

    Returns:
        The local ZenML server service, or None when no ``service.json``
        file is present (i.e. no local server deployment exists).
    """
    from zenml.services import ServiceRegistry

    service_file = os.path.join(cls.config_path(), "service.json")
    try:
        with open(service_file, "r") as handle:
            serialized = handle.read()
        loaded = ServiceRegistry().load_service_from_json(serialized)
    except FileNotFoundError:
        return None
    return cast(LocalZenServer, loaded)
provision(self)
Provision the service.
Source code in zenml/zen_server/deploy/local/local_zen_server.py
def provision(self) -> None:
    """Provision the service.

    Copies the global configuration to the local server location first, then
    delegates to the parent class' provisioning logic.
    """
    # The configuration copy must exist before the daemon is provisioned.
    self._copy_global_configuration()
    super().provision()
run(self)
Run the ZenML Server.
Exceptions:
Type | Description |
---|---|
ValueError |
if started with a global configuration that connects to another ZenML server. |
Source code in zenml/zen_server/deploy/local/local_zen_server.py
def run(self) -> None:
    """Run the ZenML Server in the current (blocking) process.

    Serves the ZenML server application with uvicorn on the endpoint's
    configured address until interrupted with CTRL+C.

    Raises:
        ValueError: if started with a global configuration that connects to
            another ZenML server.
    """
    import uvicorn

    store_config = GlobalConfiguration().store
    if store_config and store_config.type == StoreType.REST:
        # A REST store would mean this server proxies to another server.
        raise ValueError(
            "The ZenML server cannot be started with REST store type."
        )

    logger.info(
        "Starting ZenML Server as blocking "
        "process... press CTRL+C once to stop it."
    )

    self.endpoint.prepare_for_start()
    endpoint_config = self.endpoint.config
    try:
        uvicorn.run(
            ZEN_SERVER_ENTRYPOINT,
            host=endpoint_config.ip_address,
            port=endpoint_config.port or 8000,
            log_level="info",
        )
    except KeyboardInterrupt:
        logger.info("ZenML Server stopped. Resuming normal execution.")
start(self, timeout=0)
Start the service and optionally wait for it to become active.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
timeout |
int |
amount of time to wait for the service to become active. If set to 0, the method will return immediately after checking the service status. |
0 |
Source code in zenml/zen_server/deploy/local/local_zen_server.py
def start(self, timeout: int = 0) -> None:
    """Start the service and optionally wait for it to become active.

    In blocking mode the server runs in the current process. While it
    runs, ``ENV_ZENML_CONFIG_PATH`` and ``ENV_ZENML_LOCAL_STORES_PATH``
    point at the server's copy of the global configuration. Afterwards
    both variables are restored to their previous values, or removed if
    they were previously unset.

    Args:
        timeout: amount of time to wait for the service to become active.
            If set to 0, the method will return immediately after checking
            the service status.
    """
    if not self.config.blocking:
        super().start(timeout)
        return

    self._copy_global_configuration()
    local_stores_path = GlobalConfiguration().local_stores_path

    GlobalConfiguration._reset_instance()
    Client._reset_instance()

    # Snapshot the caller's values so they can be restored exactly,
    # including variables that were explicitly set to an empty string.
    previous_env = {
        name: os.environ.get(name)
        for name in (ENV_ZENML_CONFIG_PATH, ENV_ZENML_LOCAL_STORES_PATH)
    }
    os.environ[ENV_ZENML_CONFIG_PATH] = self._global_config_path
    os.environ[ENV_ZENML_LOCAL_STORES_PATH] = local_stores_path
    try:
        self.run()
    finally:
        for name, value in previous_env.items():
            if value is None:
                os.environ.pop(name, None)
            else:
                os.environ[name] = value
        GlobalConfiguration._reset_instance()
        Client._reset_instance()
LocalZenServerConfig (LocalDaemonServiceConfig)
pydantic-model
Local Zen server configuration.
Attributes:
Name | Type | Description |
---|---|---|
server |
LocalServerDeploymentConfig |
The deployment configuration. |
Source code in zenml/zen_server/deploy/local/local_zen_server.py
class LocalZenServerConfig(LocalDaemonServiceConfig):
    """Local Zen server configuration.

    Extends the local daemon service configuration with the server
    deployment settings.

    Attributes:
        server: The deployment configuration.
    """

    # Deployment settings (port, IP address, blocking mode, store).
    server: LocalServerDeploymentConfig
terraform
special
ZenML Server Terraform Deployment.
providers
special
ZenML Server Terraform Providers.
aws_provider
Zen Server AWS Terraform deployer implementation.
AWSServerDeploymentConfig (TerraformServerDeploymentConfig)
pydantic-model
AWS server deployment configuration.
Attributes:
Name | Type | Description |
---|---|---|
region |
str |
The AWS region to deploy to. |
rds_name |
str |
The name of the RDS instance to create |
db_name |
str |
Name of RDS database to create. |
db_type |
str |
Type of RDS database to create. |
db_version |
str |
Version of RDS database to create. |
db_instance_class |
str |
Instance class of RDS database to create. |
db_allocated_storage |
int |
Allocated storage of RDS database to create. |
Source code in zenml/zen_server/deploy/terraform/providers/aws_provider.py
class AWSServerDeploymentConfig(TerraformServerDeploymentConfig):
    """AWS server deployment configuration.

    Attributes:
        region: The AWS region to deploy to.
        rds_name: The name of the RDS instance to create.
        db_name: Name of RDS database to create.
        db_type: Type of RDS database to create.
        db_version: Version of RDS database to create.
        db_instance_class: Instance class of RDS database to create.
        db_allocated_storage: Allocated storage of RDS database to create.
    """

    region: str = "eu-west-1"
    rds_name: str = "zenmlserver"
    db_name: str = "zenmlserver"
    db_type: str = "mysql"
    db_version: str = "5.7.38"
    db_instance_class: str = "db.t3.micro"
    # NOTE(review): unit presumably GB, per AWS RDS convention — confirm
    # against the terraform recipe.
    db_allocated_storage: int = 5
AWSServerProvider (TerraformServerProvider)
AWS ZenML server provider.
Source code in zenml/zen_server/deploy/terraform/providers/aws_provider.py
class AWSServerProvider(TerraformServerProvider):
    """AWS ZenML server provider.

    Reuses the generic terraform provider logic and only customizes the
    provider type and the deployment configuration model.
    """

    # Server provider type implemented by this class.
    TYPE: ClassVar[ServerProviderType] = ServerProviderType.AWS
    # Deployment configuration model used by this provider.
    CONFIG_TYPE: ClassVar[
        Type[TerraformServerDeploymentConfig]
    ] = AWSServerDeploymentConfig
CONFIG_TYPE (TerraformServerDeploymentConfig)
pydantic-model
AWS server deployment configuration.
Attributes:
Name | Type | Description |
---|---|---|
region |
str |
The AWS region to deploy to. |
rds_name |
str |
The name of the RDS instance to create |
db_name |
str |
Name of RDS database to create. |
db_type |
str |
Type of RDS database to create. |
db_version |
str |
Version of RDS database to create. |
db_instance_class |
str |
Instance class of RDS database to create. |
db_allocated_storage |
int |
Allocated storage of RDS database to create. |
Source code in zenml/zen_server/deploy/terraform/providers/aws_provider.py
class AWSServerDeploymentConfig(TerraformServerDeploymentConfig):
    """AWS server deployment configuration.

    Attributes:
        region: The AWS region to deploy to.
        rds_name: The name of the RDS instance to create.
        db_name: Name of RDS database to create.
        db_type: Type of RDS database to create.
        db_version: Version of RDS database to create.
        db_instance_class: Instance class of RDS database to create.
        db_allocated_storage: Allocated storage of RDS database to create.
    """

    region: str = "eu-west-1"
    rds_name: str = "zenmlserver"
    db_name: str = "zenmlserver"
    db_type: str = "mysql"
    db_version: str = "5.7.38"
    db_instance_class: str = "db.t3.micro"
    # NOTE(review): unit presumably GB, per AWS RDS convention — confirm
    # against the terraform recipe.
    db_allocated_storage: int = 5
azure_provider
Zen Server Azure Terraform deployer implementation.
AzureServerDeploymentConfig (TerraformServerDeploymentConfig)
pydantic-model
Azure server deployment configuration.
Attributes:
Name | Type | Description |
---|---|---|
resource_group |
str |
The Azure resource_group to deploy to. |
db_instance_name |
str |
The name of the Flexible MySQL instance to create |
db_name |
str |
Name of RDS database to create. |
db_version |
str |
Version of MySQL database to create. |
db_sku_name |
str |
The sku_name for the database resource. |
db_disk_size |
int |
Allocated storage of MySQL database to create. |
Source code in zenml/zen_server/deploy/terraform/providers/azure_provider.py
class AzureServerDeploymentConfig(TerraformServerDeploymentConfig):
    """Azure server deployment configuration.

    Attributes:
        resource_group: The Azure resource_group to deploy to.
        db_instance_name: The name of the Flexible MySQL instance to create.
        db_name: Name of the MySQL database to create.
        db_version: Version of MySQL database to create.
        db_sku_name: The sku_name for the database resource.
        db_disk_size: Allocated storage of MySQL database to create.
    """

    resource_group: str = "zenml"
    db_instance_name: str = "zenmlserver"
    db_name: str = "zenmlserver"
    db_version: str = "5.7"
    db_sku_name: str = "B_Standard_B1s"
    # NOTE(review): unit presumably GB — confirm against the terraform recipe.
    db_disk_size: int = 20
AzureServerProvider (TerraformServerProvider)
Azure ZenML server provider.
Source code in zenml/zen_server/deploy/terraform/providers/azure_provider.py
class AzureServerProvider(TerraformServerProvider):
    """Azure ZenML server provider.

    Reuses the generic terraform provider logic and only customizes the
    provider type and the deployment configuration model.
    """

    # Server provider type implemented by this class.
    TYPE: ClassVar[ServerProviderType] = ServerProviderType.AZURE
    # Deployment configuration model used by this provider.
    CONFIG_TYPE: ClassVar[
        Type[TerraformServerDeploymentConfig]
    ] = AzureServerDeploymentConfig
CONFIG_TYPE (TerraformServerDeploymentConfig)
pydantic-model
Azure server deployment configuration.
Attributes:
Name | Type | Description |
---|---|---|
resource_group |
str |
The Azure resource_group to deploy to. |
db_instance_name |
str |
The name of the Flexible MySQL instance to create |
db_name |
str |
Name of RDS database to create. |
db_version |
str |
Version of MySQL database to create. |
db_sku_name |
str |
The sku_name for the database resource. |
db_disk_size |
int |
Allocated storage of MySQL database to create. |
Source code in zenml/zen_server/deploy/terraform/providers/azure_provider.py
class AzureServerDeploymentConfig(TerraformServerDeploymentConfig):
    """Azure server deployment configuration.

    Attributes:
        resource_group: The Azure resource_group to deploy to.
        db_instance_name: The name of the Flexible MySQL instance to create.
        db_name: Name of the MySQL database to create.
        db_version: Version of MySQL database to create.
        db_sku_name: The sku_name for the database resource.
        db_disk_size: Allocated storage of MySQL database to create.
    """

    resource_group: str = "zenml"
    db_instance_name: str = "zenmlserver"
    db_name: str = "zenmlserver"
    db_version: str = "5.7"
    db_sku_name: str = "B_Standard_B1s"
    # NOTE(review): unit presumably GB — confirm against the terraform recipe.
    db_disk_size: int = 20
gcp_provider
Zen Server GCP Terraform deployer implementation.
GCPServerDeploymentConfig (TerraformServerDeploymentConfig)
pydantic-model
GCP server deployment configuration.
Attributes:
Name | Type | Description |
---|---|---|
project_id |
str |
The project in GCP to deploy the server to. |
region |
str |
The GCP region to deploy to. |
cloudsql_name |
str |
The name of the CloudSQL instance to create |
db_name |
str |
Name of CloudSQL database to create. |
db_instance_tier |
str |
Instance class of CloudSQL database to create. |
db_disk_size |
int |
Allocated storage of CloudSQL database to create. |
Source code in zenml/zen_server/deploy/terraform/providers/gcp_provider.py
class GCPServerDeploymentConfig(TerraformServerDeploymentConfig):
    """GCP server deployment configuration.

    Attributes:
        project_id: The project in GCP to deploy the server to. Required.
        region: The GCP region to deploy to.
        cloudsql_name: The name of the CloudSQL instance to create.
        db_name: Name of CloudSQL database to create.
        db_instance_tier: Instance class of CloudSQL database to create.
        db_disk_size: Allocated storage of CloudSQL database to create.
    """

    # Required: no default value is provided.
    project_id: str
    region: str = "europe-west3"
    cloudsql_name: str = "zenmlserver"
    db_name: str = "zenmlserver"
    db_instance_tier: str = "db-n1-standard-1"
    # NOTE(review): unit presumably GB — confirm against the terraform recipe.
    db_disk_size: int = 10
GCPServerProvider (TerraformServerProvider)
GCP ZenML server provider.
Source code in zenml/zen_server/deploy/terraform/providers/gcp_provider.py
class GCPServerProvider(TerraformServerProvider):
    """GCP ZenML server provider.

    Reuses the generic terraform provider logic and only customizes the
    provider type and the deployment configuration model.
    """

    # Server provider type implemented by this class.
    TYPE: ClassVar[ServerProviderType] = ServerProviderType.GCP
    # Deployment configuration model used by this provider.
    CONFIG_TYPE: ClassVar[
        Type[TerraformServerDeploymentConfig]
    ] = GCPServerDeploymentConfig
CONFIG_TYPE (TerraformServerDeploymentConfig)
pydantic-model
GCP server deployment configuration.
Attributes:
Name | Type | Description |
---|---|---|
project_id |
str |
The project in GCP to deploy the server to. |
region |
str |
The GCP region to deploy to. |
cloudsql_name |
str |
The name of the CloudSQL instance to create |
db_name |
str |
Name of CloudSQL database to create. |
db_instance_tier |
str |
Instance class of CloudSQL database to create. |
db_disk_size |
int |
Allocated storage of CloudSQL database to create. |
Source code in zenml/zen_server/deploy/terraform/providers/gcp_provider.py
class GCPServerDeploymentConfig(TerraformServerDeploymentConfig):
    """GCP server deployment configuration.

    Attributes:
        project_id: The project in GCP to deploy the server to. Required.
        region: The GCP region to deploy to.
        cloudsql_name: The name of the CloudSQL instance to create.
        db_name: Name of CloudSQL database to create.
        db_instance_tier: Instance class of CloudSQL database to create.
        db_disk_size: Allocated storage of CloudSQL database to create.
    """

    # Required: no default value is provided.
    project_id: str
    region: str = "europe-west3"
    cloudsql_name: str = "zenmlserver"
    db_name: str = "zenmlserver"
    db_instance_tier: str = "db-n1-standard-1"
    # NOTE(review): unit presumably GB — confirm against the terraform recipe.
    db_disk_size: int = 10
terraform_provider
Zen Server terraform deployer implementation.
TerraformServerProvider (BaseServerProvider)
Terraform ZenML server provider.
Source code in zenml/zen_server/deploy/terraform/providers/terraform_provider.py
class TerraformServerProvider(BaseServerProvider):
    """Terraform ZenML server provider."""

    CONFIG_TYPE: ClassVar[
        Type[ServerDeploymentConfig]
    ] = TerraformServerDeploymentConfig

    @staticmethod
    def _get_server_recipe_root_path() -> str:
        """Get the server recipe root path.

        The Terraform recipe files for all terraform server providers are
        located in a folder relative to the `zenml.zen_server.deploy.terraform`
        Python module.

        Returns:
            The server recipe root path.
        """
        import zenml.zen_server.deploy.terraform as terraform_module

        root_path = os.path.join(
            os.path.dirname(terraform_module.__file__),
            TERRAFORM_ZENML_SERVER_RECIPE_SUBPATH,
        )
        return root_path

    @classmethod
    def _get_service_configuration(
        cls,
        server_config: ServerDeploymentConfig,
    ) -> Tuple[
        ServiceConfig,
        ServiceEndpointConfig,
        ServiceEndpointHealthMonitorConfig,
    ]:
        """Construct the service configuration from a server deployment configuration.

        Args:
            server_config: server deployment configuration.

        Returns:
            The service configuration.
        """
        assert isinstance(server_config, TerraformServerDeploymentConfig)

        return (
            TerraformZenServerConfig(
                name=server_config.name,
                root_runtime_path=TERRAFORM_ZENML_SERVER_CONFIG_PATH,
                singleton=True,
                # Each provider has its own recipe sub-folder, named after
                # the provider.
                directory_path=os.path.join(
                    cls._get_server_recipe_root_path(),
                    server_config.provider,
                ),
                log_level=server_config.log_level,
                variables_file_path=TERRAFORM_VALUES_FILE_PATH,
                server=server_config,
            ),
            ServiceEndpointConfig(
                protocol=ServiceEndpointProtocol.HTTP,
                allocate_port=False,
            ),
            HTTPEndpointHealthMonitorConfig(
                healthcheck_uri_path=ZEN_SERVER_HEALTHCHECK_URL_PATH,
                use_head_request=True,
            ),
        )

    def _create_service(
        self,
        config: ServerDeploymentConfig,
        timeout: Optional[int] = None,
    ) -> BaseService:
        """Create, start and return the terraform ZenML server deployment service.

        Args:
            config: The server deployment configuration.
            timeout: The timeout in seconds to wait until the service is
                running.

        Returns:
            The service instance.

        Raises:
            RuntimeError: If a terraform service is already running.
        """
        assert isinstance(config, TerraformServerDeploymentConfig)

        if timeout is None:
            timeout = TERRAFORM_ZENML_SERVER_DEFAULT_TIMEOUT

        existing_service = TerraformZenServer.get_service()
        if existing_service:
            raise RuntimeError(
                f"A terraform ZenML server with name '{existing_service.config.name}' "
                f"is already running. Please stop it first before starting a "
                f"new one."
            )

        # Only the service configuration is used to build the terraform
        # service; the endpoint and health monitor configurations are unused.
        service_config, _, _ = self._get_service_configuration(config)

        service = TerraformZenServer(config=service_config)
        service.start(timeout=timeout)

        return service

    def _update_service(
        self,
        service: BaseService,
        config: ServerDeploymentConfig,
        timeout: Optional[int] = None,
    ) -> BaseService:
        """Update the terraform ZenML server deployment service.

        The service configuration is replaced with one derived from
        ``config``, and the service is then started with it.

        Args:
            service: The service instance.
            config: The new server deployment configuration.
            timeout: The timeout in seconds to wait until the updated service is
                running.

        Returns:
            The updated service instance.
        """
        if timeout is None:
            timeout = TERRAFORM_ZENML_SERVER_DEFAULT_TIMEOUT

        # The endpoint and health monitor configurations are not needed here.
        new_config, _, _ = self._get_service_configuration(config)

        assert isinstance(new_config, TerraformZenServerConfig)
        assert isinstance(service, TerraformZenServer)

        service.config = new_config
        service.start(timeout=timeout)

        return service

    def _start_service(
        self,
        service: BaseService,
        timeout: Optional[int] = None,
    ) -> BaseService:
        """Start the terraform ZenML server deployment service.

        Args:
            service: The service instance.
            timeout: The timeout in seconds to wait until the service is
                running.

        Returns:
            The updated service instance.
        """
        if timeout is None:
            timeout = TERRAFORM_ZENML_SERVER_DEFAULT_TIMEOUT

        service.start(timeout=timeout)
        return service

    def _stop_service(
        self,
        service: BaseService,
        timeout: Optional[int] = None,
    ) -> BaseService:
        """Stop the terraform ZenML server deployment service.

        Args:
            service: The service instance.
            timeout: The timeout in seconds to wait until the service is
                stopped.

        Returns:
            The updated service instance.
        """
        if timeout is None:
            timeout = TERRAFORM_ZENML_SERVER_DEFAULT_TIMEOUT

        service.stop(timeout=timeout)
        return service

    def _delete_service(
        self,
        service: BaseService,
        timeout: Optional[int] = None,
    ) -> None:
        """Remove the terraform ZenML server deployment service.

        Args:
            service: The service instance.
            timeout: The timeout in seconds to wait until the service is
                removed.
        """
        assert isinstance(service, TerraformZenServer)

        if timeout is None:
            timeout = TERRAFORM_ZENML_SERVER_DEFAULT_TIMEOUT

        service.stop(timeout=timeout)

    def _get_service(self, server_name: str) -> BaseService:
        """Get the terraform ZenML server deployment service.

        Args:
            server_name: The server deployment name.

        Returns:
            The service instance.

        Raises:
            KeyError: If the server deployment is not found.
        """
        service = TerraformZenServer.get_service()
        if service is None:
            raise KeyError("The terraform ZenML server is not deployed.")

        if service.config.server.name != server_name:
            raise KeyError(
                "The terraform ZenML server is deployed but with a different name."
            )
        return service

    def _list_services(self) -> List[BaseService]:
        """Get all service instances for all deployed ZenML servers.

        Returns:
            A list of service instances.
        """
        service = TerraformZenServer.get_service()
        if service:
            return [service]
        return []

    def _get_deployment_config(
        self, service: BaseService
    ) -> ServerDeploymentConfig:
        """Recreate the server deployment configuration from a service instance.

        Args:
            service: The service instance.

        Returns:
            The server deployment configuration.
        """
        server = cast(TerraformZenServer, service)
        return server.config.server

    def _get_deployment_status(
        self, service: BaseService
    ) -> ServerDeploymentStatus:
        """Get the status of a server deployment from its service.

        The URL and CA certificate are only queried while the service is
        running. The deployment counts as connected when the client's
        configured store URL matches the server URL.

        Args:
            service: The server deployment service.

        Returns:
            The status of the server deployment.
        """
        gc = GlobalConfiguration()
        url: Optional[str] = None
        service = cast(TerraformZenServer, service)
        ca_crt = None
        if service.is_running:
            url = service.get_server_url()
            ca_crt = service.get_certificate()
        connected = (
            url is not None and gc.store is not None and gc.store.url == url
        )

        return ServerDeploymentStatus(
            url=url,
            status=service.status.state,
            status_message=service.status.last_error,
            connected=connected,
            ca_crt=ca_crt,
        )
CONFIG_TYPE (ServerDeploymentConfig)
pydantic-model
Terraform server deployment configuration.
Attributes:
Name | Type | Description |
---|---|---|
log_level |
str |
The log level to set the terraform client to. Choose one of TRACE, DEBUG, INFO, WARN or ERROR (case insensitive). |
username |
str |
The username for the default ZenML server account. |
password |
str |
The password for the default ZenML server account. |
helm_chart |
str |
The path to the ZenML server helm chart to use for deployment. |
zenmlserver_image_repo |
str |
The repository to use for the zenml server. |
zenmlserver_image_tag |
str |
The tag to use for the zenml server docker image. |
namespace |
str |
The Kubernetes namespace to deploy the ZenML server to. |
kubectl_config_path |
str |
The path to the kubectl config file to use for deployment. |
ingress_tls |
bool |
Whether to use TLS for the ingress. |
ingress_tls_generate_certs |
bool |
Whether to generate self-signed TLS certificates for the ingress. |
ingress_tls_secret_name |
str |
The name of the Kubernetes secret to use for the ingress. |
create_ingress_controller |
bool |
Whether to deploy an nginx ingress controller as part of the deployment. |
ingress_controller_ip |
str |
The ingress controller IP to use for the ingress self-signed certificate and to compute the ZenML server URL. |
deploy_db |
bool |
Whether to create a SQL database service as part of the recipe. |
database_username |
str |
The username for the database. |
database_password |
str |
The password for the database. |
database_url |
str |
The URL of the RDS instance to use for the ZenML server. |
database_ssl_ca |
str |
The path to the SSL CA certificate to use for the database connection. |
database_ssl_cert |
str |
The path to the client SSL certificate to use for the database connection. |
database_ssl_key |
str |
The path to the client SSL key to use for the database connection. |
database_ssl_verify_server_cert |
bool |
Whether to verify the database server SSL certificate. |
analytics_opt_in |
bool |
Whether to enable analytics. |
Source code in zenml/zen_server/deploy/terraform/providers/terraform_provider.py
class TerraformServerDeploymentConfig(ServerDeploymentConfig):
    """Configuration for deploying a ZenML server with Terraform.

    Attributes:
        log_level: Log level of the terraform client. One of TRACE, DEBUG,
            INFO, WARN or ERROR (case insensitive).
        username: Username of the default ZenML server account.
        password: Password of the default ZenML server account.
        helm_chart: Path of the ZenML server helm chart used for the
            deployment.
        zenmlserver_image_repo: Repository of the ZenML server image.
        zenmlserver_image_tag: Tag of the ZenML server docker image.
        namespace: Kubernetes namespace the ZenML server is deployed to.
        kubectl_config_path: Path of the kubectl config file used for the
            deployment.
        ingress_tls: Whether the ingress uses TLS.
        ingress_tls_generate_certs: Whether self-signed TLS certificates are
            generated for the ingress.
        ingress_tls_secret_name: Name of the Kubernetes secret used for the
            ingress.
        create_ingress_controller: Whether an nginx ingress controller is
            deployed as part of the deployment.
        ingress_controller_ip: Ingress controller IP, used for the ingress
            self-signed certificate and to compute the ZenML server URL.
        deploy_db: Whether a SQL database service is created as part of the
            recipe.
        database_username: Username for the database.
        database_password: Password for the database.
        database_url: URL of the RDS instance used by the ZenML server.
        database_ssl_ca: Path of the SSL CA certificate used for the database
            connection.
        database_ssl_cert: Path of the client SSL certificate used for the
            database connection.
        database_ssl_key: Path of the client SSL key used for the database
            connection.
        database_ssl_verify_server_cert: Whether the database server SSL
            certificate is verified.
        analytics_opt_in: Whether analytics are enabled.
    """

    # Terraform client.
    log_level: str = "ERROR"

    # Default ZenML server account (required, no defaults).
    username: str
    password: str

    # Helm chart and server image.
    helm_chart: str = get_helm_chart_path()
    zenmlserver_image_repo: str = "zenmldocker/zenml-server"
    zenmlserver_image_tag: str = "latest"

    # Kubernetes target.
    namespace: str = "zenmlserver"
    kubectl_config_path: str = str(Path.home() / ".kube" / "config")

    # Ingress.
    ingress_tls: bool = False
    ingress_tls_generate_certs: bool = True
    ingress_tls_secret_name: str = "zenml-tls-certs"
    create_ingress_controller: bool = True
    ingress_controller_ip: str = ""

    # Database.
    deploy_db: bool = True
    database_username: str = "user"
    database_password: str = ""
    database_url: str = ""
    database_ssl_ca: str = ""
    database_ssl_cert: str = ""
    database_ssl_key: str = ""
    database_ssl_verify_server_cert: bool = True

    # Telemetry.
    analytics_opt_in: bool = True

    class Config:
        """Pydantic configuration."""

        # Keep fields that are not declared on the model.
        extra = "allow"
Config
Pydantic configuration.
Source code in zenml/zen_server/deploy/terraform/providers/terraform_provider.py
class Config:
    """Pydantic model settings that tolerate undeclared fields."""

    # Accept fields that are not declared on the model instead of rejecting them.
    extra = "allow"
terraform_zen_server
Service implementation for the ZenML terraform server deployment.
TerraformServerDeploymentConfig (ServerDeploymentConfig)
pydantic-model
Terraform server deployment configuration.
Attributes:
Name | Type | Description |
---|---|---|
log_level |
str |
The log level to set the terraform client to. Choose one of TRACE, DEBUG, INFO, WARN or ERROR (case insensitive). |
username |
str |
The username for the default ZenML server account. |
password |
str |
The password for the default ZenML server account. |
helm_chart |
str |
The path to the ZenML server helm chart to use for deployment. |
zenmlserver_image_repo |
str |
The repository to use for the zenml server. |
zenmlserver_image_tag |
str |
The tag to use for the zenml server docker image. |
namespace |
str |
The Kubernetes namespace to deploy the ZenML server to. |
kubectl_config_path |
str |
The path to the kubectl config file to use for deployment. |
ingress_tls |
bool |
Whether to use TLS for the ingress. |
ingress_tls_generate_certs |
bool |
Whether to generate self-signed TLS certificates for the ingress. |
ingress_tls_secret_name |
str |
The name of the Kubernetes secret to use for the ingress. |
create_ingress_controller |
bool |
Whether to deploy an nginx ingress controller as part of the deployment. |
ingress_controller_ip |
str |
The ingress controller IP to use for the ingress self-signed certificate and to compute the ZenML server URL. |
deploy_db |
bool |
Whether to create a SQL database service as part of the recipe. |
database_username |
str |
The username for the database. |
database_password |
str |
The password for the database. |
database_url |
str |
The URL of the RDS instance to use for the ZenML server. |
database_ssl_ca |
str |
The path to the SSL CA certificate to use for the database connection. |
database_ssl_cert |
str |
The path to the client SSL certificate to use for the database connection. |
database_ssl_key |
str |
The path to the client SSL key to use for the database connection. |
database_ssl_verify_server_cert |
bool |
Whether to verify the database server SSL certificate. |
analytics_opt_in |
bool |
Whether to enable analytics. |
Source code in zenml/zen_server/deploy/terraform/terraform_zen_server.py
class TerraformServerDeploymentConfig(ServerDeploymentConfig):
    """Configuration for deploying a ZenML server with Terraform.

    Attributes:
        log_level: Log level of the terraform client. One of TRACE, DEBUG,
            INFO, WARN or ERROR (case insensitive).
        username: Username of the default ZenML server account.
        password: Password of the default ZenML server account.
        helm_chart: Path of the ZenML server helm chart used for the
            deployment.
        zenmlserver_image_repo: Repository of the ZenML server image.
        zenmlserver_image_tag: Tag of the ZenML server docker image.
        namespace: Kubernetes namespace the ZenML server is deployed to.
        kubectl_config_path: Path of the kubectl config file used for the
            deployment.
        ingress_tls: Whether the ingress uses TLS.
        ingress_tls_generate_certs: Whether self-signed TLS certificates are
            generated for the ingress.
        ingress_tls_secret_name: Name of the Kubernetes secret used for the
            ingress.
        create_ingress_controller: Whether an nginx ingress controller is
            deployed as part of the deployment.
        ingress_controller_ip: Ingress controller IP, used for the ingress
            self-signed certificate and to compute the ZenML server URL.
        deploy_db: Whether a SQL database service is created as part of the
            recipe.
        database_username: Username for the database.
        database_password: Password for the database.
        database_url: URL of the RDS instance used by the ZenML server.
        database_ssl_ca: Path of the SSL CA certificate used for the database
            connection.
        database_ssl_cert: Path of the client SSL certificate used for the
            database connection.
        database_ssl_key: Path of the client SSL key used for the database
            connection.
        database_ssl_verify_server_cert: Whether the database server SSL
            certificate is verified.
        analytics_opt_in: Whether analytics are enabled.
    """

    # Terraform client.
    log_level: str = "ERROR"

    # Default ZenML server account (required, no defaults).
    username: str
    password: str

    # Helm chart and server image.
    helm_chart: str = get_helm_chart_path()
    zenmlserver_image_repo: str = "zenmldocker/zenml-server"
    zenmlserver_image_tag: str = "latest"

    # Kubernetes target.
    namespace: str = "zenmlserver"
    kubectl_config_path: str = str(Path.home() / ".kube" / "config")

    # Ingress.
    ingress_tls: bool = False
    ingress_tls_generate_certs: bool = True
    ingress_tls_secret_name: str = "zenml-tls-certs"
    create_ingress_controller: bool = True
    ingress_controller_ip: str = ""

    # Database.
    deploy_db: bool = True
    database_username: str = "user"
    database_password: str = ""
    database_url: str = ""
    database_ssl_ca: str = ""
    database_ssl_cert: str = ""
    database_ssl_key: str = ""
    database_ssl_verify_server_cert: bool = True

    # Telemetry.
    analytics_opt_in: bool = True

    class Config:
        """Pydantic configuration."""

        # Keep fields that are not declared on the model.
        extra = "allow"
Config
Pydantic configuration.
Source code in zenml/zen_server/deploy/terraform/terraform_zen_server.py
class Config:
    """Pydantic model settings that tolerate undeclared fields."""

    # Accept fields that are not declared on the model instead of rejecting them.
    extra = "allow"
TerraformZenServer (TerraformService)
pydantic-model
Service that can be used to start a terraform ZenServer.
Attributes:
Name | Type | Description |
---|---|---|
config |
TerraformZenServerConfig |
service configuration |
endpoint |
Optional[zenml.services.service_endpoint.BaseServiceEndpoint] |
service endpoint |
Source code in zenml/zen_server/deploy/terraform/terraform_zen_server.py
class TerraformZenServer(TerraformService):
    """Service that can be used to start a terraform ZenServer.

    Attributes:
        config: service configuration
        endpoint: service endpoint
    """

    SERVICE_TYPE = ServiceType(
        name="terraform_zenml_server",
        type="zen_server",
        flavor="terraform",
        description="Terraform ZenML server deployment",
    )

    config: TerraformZenServerConfig

    @classmethod
    def get_service(cls) -> Optional["TerraformZenServer"]:
        """Load and return the terraform ZenML server service, if present.

        Returns:
            The terraform ZenML server service or None, if the terraform server
            deployment is not found.
        """
        from zenml.services import ServiceRegistry

        try:
            with open(TERRAFORM_ZENML_SERVER_CONFIG_FILENAME, "r") as f:
                return cast(
                    TerraformZenServer,
                    ServiceRegistry().load_service_from_json(f.read()),
                )
        except FileNotFoundError:
            # No serialized service on disk means no terraform deployment.
            return None

    def get_vars(self) -> Dict[str, Any]:
        """Get variables as a dictionary.

        The server deployment config is converted to a dictionary. Keys that
        are not modeled as Terraform deployment variables are dropped, and
        UUID values are stringified because `json` cannot serialize them.
        The result is also written to the variables file inside the service
        runtime path.

        Returns:
            A dictionary of variables to use for the Terraform deployment.
        """
        # Keys of the server deployment config that are not modeled as
        # terraform deployment vars.
        excluded_keys = {"log_level", "provider"}
        tf_vars = {
            key: str(value) if isinstance(value, UUID) else value
            for key, value in self.config.server.dict().items()
            if key not in excluded_keys
        }

        assert self.status.runtime_path
        variables_file = os.path.join(
            self.status.runtime_path, self.config.variables_file_path
        )
        with open(variables_file, "w") as fp:
            json.dump(tf_vars, fp, indent=4)
        return tf_vars

    def provision(self) -> None:
        """Provision the service and log the URL of the deployed server."""
        super().provision()
        logger.info(
            f"Your ZenML server is now deployed with URL:\n"
            f"{self.get_server_url()}"
        )

    def get_server_url(self) -> str:
        """Returns the deployed ZenML server's URL.

        Returns:
            The URL of the deployed ZenML server.
        """
        return str(
            self.terraform_client.output(
                TERRAFORM_DEPLOYED_ZENSERVER_OUTPUT_URL, full_value=True
            )
        )

    def get_certificate(self) -> Optional[str]:
        """Returns the CA certificate configured for the ZenML server.

        Returns:
            The CA certificate configured for the ZenML server.
        """
        return cast(
            str,
            self.terraform_client.output(
                TERRAFORM_DEPLOYED_ZENSERVER_OUTPUT_CA_CRT, full_value=True
            ),
        )
get_certificate(self)
Returns the CA certificate configured for the ZenML server.
Returns:
Type | Description |
---|---|
Optional[str] |
The CA certificate configured for the ZenML server. |
Source code in zenml/zen_server/deploy/terraform/terraform_zen_server.py
def get_certificate(self) -> Optional[str]:
    """Fetch the CA certificate of the deployed ZenML server.

    The value is read from the Terraform outputs of the deployment.

    Returns:
        The CA certificate configured for the ZenML server.
    """
    ca_certificate = self.terraform_client.output(
        TERRAFORM_DEPLOYED_ZENSERVER_OUTPUT_CA_CRT, full_value=True
    )
    return cast(str, ca_certificate)
get_server_url(self)
Returns the deployed ZenML server's URL.
Returns:
Type | Description |
---|---|
str |
The URL of the deployed ZenML server. |
Source code in zenml/zen_server/deploy/terraform/terraform_zen_server.py
def get_server_url(self) -> str:
    """Fetch the URL of the deployed ZenML server.

    The value is read from the Terraform outputs of the deployment and
    converted to a string.

    Returns:
        The URL of the deployed ZenML server.
    """
    server_url = self.terraform_client.output(
        TERRAFORM_DEPLOYED_ZENSERVER_OUTPUT_URL, full_value=True
    )
    return str(server_url)
get_service()
classmethod
Load and return the terraform ZenML server service, if present.
Returns:
Type | Description |
---|---|
Optional[TerraformZenServer] |
The terraform ZenML server service or None, if the terraform server deployment is not found. |
Source code in zenml/zen_server/deploy/terraform/terraform_zen_server.py
@classmethod
def get_service(cls) -> Optional["TerraformZenServer"]:
    """Load the terraform ZenML server service from its JSON config file.

    Returns:
        The terraform ZenML server service, or None when the config file
        (and therefore the terraform server deployment) does not exist.
    """
    from zenml.services import ServiceRegistry

    try:
        with open(TERRAFORM_ZENML_SERVER_CONFIG_FILENAME, "r") as config_file:
            service = ServiceRegistry().load_service_from_json(
                config_file.read()
            )
    except FileNotFoundError:
        return None
    return cast(TerraformZenServer, service)
get_vars(self)
Get variables as a dictionary.
Returns:
Type | Description |
---|---|
Dict[str, Any] |
A dictionary of variables to use for the Terraform deployment. |
Source code in zenml/zen_server/deploy/terraform/terraform_zen_server.py
def get_vars(self) -> Dict[str, Any]:
    """Get variables as a dictionary.

    The server deployment config is converted to a dictionary. Keys that are
    not modeled as Terraform deployment variables are dropped, and UUID values
    are stringified because `json` cannot serialize them. The result is also
    written to the variables file inside the service runtime path.

    Returns:
        A dictionary of variables to use for the Terraform deployment.
    """
    # Keys of the server deployment config that are not modeled as
    # terraform deployment vars.
    excluded_keys = {"log_level", "provider"}
    # Named `tf_vars` rather than `vars` to avoid shadowing the builtin.
    tf_vars = {
        key: str(value) if isinstance(value, UUID) else value
        for key, value in self.config.server.dict().items()
        if key not in excluded_keys
    }

    assert self.status.runtime_path
    variables_file = os.path.join(
        self.status.runtime_path, self.config.variables_file_path
    )
    with open(variables_file, "w") as fp:
        json.dump(tf_vars, fp, indent=4)
    return tf_vars
provision(self)
Provision the service.
Source code in zenml/zen_server/deploy/terraform/terraform_zen_server.py
def provision(self) -> None:
    """Provision the service, then log the URL of the deployed server."""
    super().provision()
    server_url = self.get_server_url()
    logger.info(f"Your ZenML server is now deployed with URL:\n{server_url}")
TerraformZenServerConfig (TerraformServiceConfig)
pydantic-model
Terraform Zen server configuration.
Attributes:
Name | Type | Description |
---|---|---|
server |
TerraformServerDeploymentConfig |
The deployment configuration. |
Source code in zenml/zen_server/deploy/terraform/terraform_zen_server.py
class TerraformZenServerConfig(TerraformServiceConfig):
    """Service configuration of the Terraform-based ZenML server.

    Attributes:
        server: The deployment configuration.
    """

    server: TerraformServerDeploymentConfig
    # Enabled by default for the ZenML server service.
    copy_terraform_files: bool = True
get_helm_chart_path()
Get the ZenML server helm chart path.
The ZenML server helm chart files are located in a folder relative to the `zenml.zen_server.deploy` Python module.
Returns:
Type | Description |
---|---|
str |
The helm chart path. |
Source code in zenml/zen_server/deploy/terraform/terraform_zen_server.py
def get_helm_chart_path() -> str:
    """Locate the ZenML server helm chart.

    The chart lives at `ZENML_HELM_CHART_SUBPATH`, relative to the directory
    of the `zenml.zen_server.deploy` Python module.

    Returns:
        The helm chart path.
    """
    import zenml.zen_server.deploy as deploy_module

    deploy_dir = os.path.dirname(deploy_module.__file__)
    return os.path.join(deploy_dir, ZENML_HELM_CHART_SUBPATH)
exceptions
REST API exception handling.
ErrorModel (BaseModel)
pydantic-model
Base class for error responses.
Source code in zenml/zen_server/exceptions.py
class ErrorModel(BaseModel):
    """Base schema for error response bodies."""

    # Error payload; its type is deliberately unconstrained.
    detail: Any
error_detail(error, exception_type=None)
Convert an Exception to API representation.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
error |
Exception |
Exception to convert. |
required |
exception_type |
Optional[Type[Exception]] |
Exception type to use in the error response instead of the type of the supplied exception. This is useful when the raised exception is a subclass of an exception type that is properly handled by the REST API. |
None |
Returns:
Type | Description |
---|---|
List[str] |
List of strings representing the error. |
Source code in zenml/zen_server/exceptions.py
def error_detail(
    error: Exception, exception_type: Optional[Type[Exception]] = None
) -> List[str]:
    """Build the REST API representation of an exception.

    Args:
        error: Exception to convert.
        exception_type: Exception type to report instead of the actual type
            of `error`. Useful when the raised exception subclasses a type
            that the REST API knows how to handle.

    Returns:
        A two-item list: the reported exception class name, then the
        string form of the exception.
    """
    reported_type = exception_type or type(error)
    return [reported_type.__name__, str(error)]
exception_from_response(response)
Convert an error HTTP response to an exception.
Uses the REST_API_EXCEPTIONS list to determine the appropriate exception class to use based on the response status code and the exception class name embedded in the response body.
The last entry in the list of exceptions associated with a status code is used as a fallback if the exception class name in the response body is not found in the list.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
response |
Response |
HTTP error response to convert. |
required |
Returns:
Type | Description |
---|---|
Optional[Exception] |
Exception with the appropriate type and arguments, or None if the response does not contain an error or the response cannot be unpacked into an exception. |
Source code in zenml/zen_server/exceptions.py
def exception_from_response(
    response: requests.Response,
) -> Optional[Exception]:
    """Convert an error HTTP response to an exception.

    The REST_API_EXCEPTIONS list is filtered down to the entries registered
    for the response status code. The first of those whose class name equals
    the exception name embedded in the response body is used. If none matches,
    the last entry registered for that status code is used as a fallback.

    Args:
        response: HTTP error response to convert.

    Returns:
        Exception with the appropriate type and arguments, or None if the
        response does not contain an error or the response cannot be unpacked
        into an exception.
    """

    def _unpack_detail() -> Tuple[Optional[str], str]:
        """Split the response body into an exception name and message.

        Returns:
            Tuple of exception name (None if unknown) and message.
        """
        try:
            payload = response.json()
        except requests.exceptions.JSONDecodeError:
            return None, response.text

        detail = (
            payload.get("detail", response.text)
            if isinstance(payload, dict)
            else payload
        )

        # A plain string detail carries no exception class name.
        if isinstance(detail, str):
            return None, detail

        # Otherwise expect a list: [class_name, *exception_args].
        if (
            not isinstance(detail, list)
            or not detail
            or not isinstance(detail[0], str)
        ):
            return None, response.text

        class_name, *exc_args = detail
        return class_name, ": ".join(str(arg) for arg in exc_args)

    exc_name, exc_msg = _unpack_detail()

    candidates = [
        exc_type
        for exc_type, code in REST_API_EXCEPTIONS
        if response.status_code == code
    ]
    if not candidates:
        return None

    exact_match = next(
        (exc_type for exc_type in candidates if exc_type.__name__ == exc_name),
        None,
    )
    exc_class = exact_match if exact_match is not None else candidates[-1]
    return exc_class(exc_msg)
http_exception_from_error(error)
Convert an Exception to an HTTP error response.
Uses the REST_API_EXCEPTIONS list to determine the appropriate status code associated with the exception type. The exception class name and arguments are embedded in the HTTP error response body.
The lookup uses the first occurrence of the exception type in the list. If
the exception type is not found in the list, the lookup uses isinstance
to determine the most specific exception type corresponding to the supplied
exception. This allows users to call this method with exception types that
are not directly listed in the REST_API_EXCEPTIONS list.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
error |
Exception |
Exception to convert. |
required |
Returns:
Type | Description |
---|---|
HTTPException |
HTTPException with the appropriate status code and error detail. |
Source code in zenml/zen_server/exceptions.py
def http_exception_from_error(error: Exception) -> "HTTPException":
    """Convert an Exception to an HTTP error response.

    The status code comes from the REST_API_EXCEPTIONS list. An entry whose
    type is exactly the type of `error` wins immediately. Otherwise, among the
    entries that `error` is an instance of, the most specific one is used;
    ties keep the earlier entry. The reported exception class name and the
    exception message form the response detail. When nothing matches, a 500
    status code is returned and the error is reported as a `RuntimeError`.

    Args:
        error: Exception to convert.

    Returns:
        HTTPException with the appropriate status code and error detail.
    """
    from fastapi import HTTPException

    best_type: Optional[Type[Exception]] = None
    best_code = 0
    for candidate, code in REST_API_EXCEPTIONS:
        if error.__class__ is candidate:
            # An exact type match beats any superclass match.
            best_type, best_code = candidate, code
            break
        if not isinstance(error, candidate):
            continue
        if best_type is None or issubclass(candidate, best_type):
            # Either the first compatible entry, or a more specific one.
            best_type, best_code = candidate, code

    return HTTPException(
        status_code=best_code or 500,
        detail=error_detail(error, best_type or RuntimeError),
    )
jwt
Authentication module for ZenML server.
JWTToken (BaseModel)
pydantic-model
Pydantic object representing a JWT token.
Attributes:
Name | Type | Description |
---|---|---|
user_id |
UUID |
The id of the authenticated User. |
device_id |
Optional[uuid.UUID] |
The id of the authenticated device. |
pipeline_id |
Optional[uuid.UUID] |
The id of the pipeline for which the token was issued. |
schedule_id |
Optional[uuid.UUID] |
The id of the schedule for which the token was issued. |
permissions |
List[str] |
The permissions scope of the authenticated user. |
claims |
Dict[str, Any] |
The original token claims. |
Source code in zenml/zen_server/jwt.py
class JWTToken(BaseModel):
    """Pydantic object representing a JWT token.

    Attributes:
        user_id: The id of the authenticated User.
        device_id: The id of the authenticated device.
        pipeline_id: The id of the pipeline for which the token was issued.
        schedule_id: The id of the schedule for which the token was issued.
        permissions: The permissions scope of the authenticated user.
        claims: The original token claims.
    """

    user_id: UUID
    permissions: List[str]
    device_id: Optional[UUID] = None
    pipeline_id: Optional[UUID] = None
    schedule_id: Optional[UUID] = None
    claims: Dict[str, Any] = {}

    @classmethod
    def decode_token(
        cls,
        token: str,
        verify: bool = True,
    ) -> "JWTToken":
        """Decodes a JWT access token.

        Decodes a JWT access token and returns a `JWTToken` object with the
        information retrieved from its subject claim.

        Args:
            token: The encoded JWT token.
            verify: Whether to verify the signature of the token.

        Returns:
            The decoded JWT access token.

        Raises:
            AuthorizationException: If the token is invalid.
        """
        config = server_config()
        try:
            claims_data = jwt.decode(
                token,
                config.jwt_secret_key,
                algorithms=[config.jwt_token_algorithm],
                audience=config.get_jwt_token_audience(),
                issuer=config.get_jwt_token_issuer(),
                # PyJWT >= 2.0 ignores a `verify=` keyword argument (it only
                # warns), so signature verification must be toggled through
                # `options`. Per PyJWT, the other verify_* checks then
                # default to off as well.
                options={"verify_signature": verify},
                leeway=timedelta(seconds=config.jwt_token_leeway_seconds),
            )
            claims = cast(Dict[str, Any], claims_data)
        except (
            jwt.PyJWTError,
            jwt.exceptions.DecodeError,
            jwt.exceptions.PyJWKClientError,
        ) as e:
            raise AuthorizationException(f"Invalid JWT token: {e}") from e

        def parse_uuid_claim(claim_name: str, value: Any) -> UUID:
            """Parse a UUID-valued claim, rejecting malformed values.

            Args:
                claim_name: Name of the claim, used in the error message.
                value: Raw claim value.

            Returns:
                The parsed UUID.

            Raises:
                AuthorizationException: If the value is not a valid UUID.
            """
            try:
                return UUID(value)
            except (AttributeError, TypeError, ValueError) as e:
                # Non-string values raise AttributeError/TypeError instead of
                # ValueError; all of them mean the token is malformed.
                raise AuthorizationException(
                    f"Invalid JWT token: the {claim_name} claim is not a "
                    "valid UUID"
                ) from e

        subject: str = claims.get("sub", "")
        if not subject:
            raise AuthorizationException(
                "Invalid JWT token: the subject claim is missing"
            )
        user_id = parse_uuid_claim("subject", subject)

        device_id: Optional[UUID] = None
        if "device_id" in claims:
            device_id = parse_uuid_claim("device_id", claims["device_id"])

        pipeline_id: Optional[UUID] = None
        if "pipeline_id" in claims:
            pipeline_id = parse_uuid_claim(
                "pipeline_id", claims["pipeline_id"]
            )

        schedule_id: Optional[UUID] = None
        if "schedule_id" in claims:
            schedule_id = parse_uuid_claim(
                "schedule_id", claims["schedule_id"]
            )

        permissions: List[str] = claims.get("permissions", [])

        return JWTToken(
            user_id=user_id,
            device_id=device_id,
            pipeline_id=pipeline_id,
            schedule_id=schedule_id,
            permissions=list(set(permissions)),
            claims=claims,
        )

    def encode(self, expires: Optional[datetime] = None) -> str:
        """Creates a JWT access token.

        Encodes, signs and returns a JWT access token.

        Args:
            expires: Datetime after which the token will expire. If not
                provided, the JWT token will not be set to expire.

        Returns:
            The generated access token.
        """
        config = server_config()

        claims: Dict[str, Any] = dict(
            sub=str(self.user_id),
            permissions=list(self.permissions),
        )
        claims["iss"] = config.get_jwt_token_issuer()
        claims["aud"] = config.get_jwt_token_audience()

        if expires:
            claims["exp"] = expires
        if self.device_id:
            claims["device_id"] = str(self.device_id)
        if self.pipeline_id:
            claims["pipeline_id"] = str(self.pipeline_id)
        if self.schedule_id:
            claims["schedule_id"] = str(self.schedule_id)

        # Apply custom claims last; they override any of the claims above.
        claims.update(self.claims)

        return jwt.encode(
            claims,
            config.jwt_secret_key,
            algorithm=config.jwt_token_algorithm,
        )
decode_token(token, verify=True)
classmethod
Decodes a JWT access token.
Decodes a JWT access token and returns a `JWTToken` object with the information retrieved from its subject claim.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
token |
str |
The encoded JWT token. |
required |
verify |
bool |
Whether to verify the signature of the token. |
True |
Returns:
Type | Description |
---|---|
JWTToken |
The decoded JWT access token. |
Exceptions:
Type | Description |
---|---|
AuthorizationException |
If the token is invalid. |
Source code in zenml/zen_server/jwt.py
@classmethod
def decode_token(
    cls,
    token: str,
    verify: bool = True,
) -> "JWTToken":
    """Decodes a JWT access token.

    Decodes a JWT access token and returns a `JWTToken` object with the
    information retrieved from its subject claim.

    Args:
        token: The encoded JWT token.
        verify: Whether to verify the signature of the token.

    Returns:
        The decoded JWT access token.

    Raises:
        AuthorizationException: If the token is invalid.
    """
    config = server_config()
    try:
        claims_data = jwt.decode(
            token,
            config.jwt_secret_key,
            algorithms=[config.jwt_token_algorithm],
            audience=config.get_jwt_token_audience(),
            issuer=config.get_jwt_token_issuer(),
            # PyJWT >= 2.0 ignores a `verify=` keyword argument (it only
            # warns), so signature verification must be toggled through
            # `options`. Per PyJWT, the other verify_* checks then default to
            # off as well.
            options={"verify_signature": verify},
            leeway=timedelta(seconds=config.jwt_token_leeway_seconds),
        )
        claims = cast(Dict[str, Any], claims_data)
    except (
        jwt.PyJWTError,
        jwt.exceptions.DecodeError,
        jwt.exceptions.PyJWKClientError,
    ) as e:
        raise AuthorizationException(f"Invalid JWT token: {e}") from e

    def parse_uuid_claim(claim_name: str, value: Any) -> UUID:
        """Parse a UUID-valued claim, rejecting malformed values.

        Args:
            claim_name: Name of the claim, used in the error message.
            value: Raw claim value.

        Returns:
            The parsed UUID.

        Raises:
            AuthorizationException: If the value is not a valid UUID.
        """
        try:
            return UUID(value)
        except (AttributeError, TypeError, ValueError) as e:
            # Non-string values raise AttributeError/TypeError instead of
            # ValueError; all of them mean the token is malformed.
            raise AuthorizationException(
                f"Invalid JWT token: the {claim_name} claim is not a "
                "valid UUID"
            ) from e

    subject: str = claims.get("sub", "")
    if not subject:
        raise AuthorizationException(
            "Invalid JWT token: the subject claim is missing"
        )
    user_id = parse_uuid_claim("subject", subject)

    device_id: Optional[UUID] = None
    if "device_id" in claims:
        device_id = parse_uuid_claim("device_id", claims["device_id"])

    pipeline_id: Optional[UUID] = None
    if "pipeline_id" in claims:
        pipeline_id = parse_uuid_claim("pipeline_id", claims["pipeline_id"])

    schedule_id: Optional[UUID] = None
    if "schedule_id" in claims:
        schedule_id = parse_uuid_claim("schedule_id", claims["schedule_id"])

    permissions: List[str] = claims.get("permissions", [])

    return JWTToken(
        user_id=user_id,
        device_id=device_id,
        pipeline_id=pipeline_id,
        schedule_id=schedule_id,
        permissions=list(set(permissions)),
        claims=claims,
    )
encode(self, expires=None)
Creates a JWT access token.
Encodes, signs and returns a JWT access token.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
expires |
Optional[datetime.datetime] |
Datetime after which the token will expire. If not provided, the JWT token will not be set to expire. |
None |
Returns:
Type | Description |
---|---|
str |
The generated access token. |
Source code in zenml/zen_server/jwt.py
def encode(self, expires: Optional[datetime] = None) -> str:
    """Encode and sign this token as a JWT access token.

    The standard claims (subject, permissions, issuer, audience and the
    optional expiry) come first. The device, pipeline and schedule ids
    follow when set. Custom claims are merged last and override anything
    added earlier.

    Args:
        expires: Datetime after which the token will expire. If not
            provided, the JWT token will not be set to expire.

    Returns:
        The generated access token.
    """
    config = server_config()

    claims: Dict[str, Any] = {
        "sub": str(self.user_id),
        "permissions": list(self.permissions),
        "iss": config.get_jwt_token_issuer(),
        "aud": config.get_jwt_token_audience(),
    }
    if expires:
        claims["exp"] = expires

    optional_ids = (
        ("device_id", self.device_id),
        ("pipeline_id", self.pipeline_id),
        ("schedule_id", self.schedule_id),
    )
    claims.update(
        {claim_name: str(value) for claim_name, value in optional_ids if value}
    )

    # Custom claims win over every claim set above.
    claims.update(self.claims)

    return jwt.encode(
        claims,
        config.jwt_secret_key,
        algorithm=config.jwt_token_algorithm,
    )
routers
special
Endpoint definitions.
artifacts_endpoints
Endpoint definitions for steps (and artifacts) of pipeline runs.
create_artifact(artifact, _=Security(authorize, scopes=[PermissionType.WRITE]))
Create a new artifact.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
artifact |
ArtifactRequestModel |
The artifact to create. |
required |
Returns:
Type | Description |
---|---|
ArtifactResponseModel |
The created artifact. |
Source code in zenml/zen_server/routers/artifacts_endpoints.py
@router.post(
    "",
    response_model=ArtifactResponseModel,
    responses={401: error_response, 409: error_response, 422: error_response},
)
@handle_exceptions
def create_artifact(
    artifact: ArtifactRequestModel,
    _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]),
) -> ArtifactResponseModel:
    """Create a new artifact through the zen store.

    Requires the WRITE permission scope.

    Args:
        artifact: The artifact to create.

    Returns:
        The created artifact.
    """
    store = zen_store()
    created_artifact = store.create_artifact(artifact)
    return created_artifact
delete_artifact(artifact_id, _=Security(authorize, scopes=[PermissionType.WRITE]))
Delete an artifact by ID.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
artifact_id |
UUID |
The ID of the artifact to delete. |
required |
Source code in zenml/zen_server/routers/artifacts_endpoints.py
@router.delete(
    "/{artifact_id}",
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def delete_artifact(
    artifact_id: UUID,
    _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]),
) -> None:
    """Delete the artifact with the given ID through the zen store.

    Requires the WRITE permission scope.

    Args:
        artifact_id: The ID of the artifact to delete.
    """
    store = zen_store()
    store.delete_artifact(artifact_id)
get_artifact(artifact_id, _=Security(authorize, scopes=[PermissionType.READ]))
Get an artifact by ID.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
artifact_id |
UUID |
The ID of the artifact to get. |
required |
Returns:
Type | Description |
---|---|
ArtifactResponseModel |
The artifact with the given ID. |
Source code in zenml/zen_server/routers/artifacts_endpoints.py
@router.get(
    "/{artifact_id}",
    response_model=ArtifactResponseModel,
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def get_artifact(
    artifact_id: UUID,
    _: AuthContext = Security(authorize, scopes=[PermissionType.READ]),
) -> ArtifactResponseModel:
    """Fetch a single artifact by ID from the zen store.

    Requires the READ permission scope.

    Args:
        artifact_id: The ID of the artifact to get.

    Returns:
        The artifact with the given ID.
    """
    store = zen_store()
    return store.get_artifact(artifact_id)
get_artifact_visualization(artifact_id, index=0, _=Security(oauth2_authentication))
Get the visualization of an artifact.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
artifact_id |
UUID |
ID of the artifact for which to get the visualization. |
required |
index |
int |
Index of the visualization to get (if there are multiple). |
0 |
Returns:
Type | Description |
---|---|
LoadedVisualizationModel |
The visualization of the artifact. |
Source code in zenml/zen_server/routers/artifacts_endpoints.py
@router.get(
    "/{artifact_id}" + VISUALIZE,
    response_model=LoadedVisualizationModel,
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def get_artifact_visualization(
    artifact_id: UUID,
    index: int = 0,
    _: AuthContext = Security(authorize, scopes=[PermissionType.READ]),
) -> LoadedVisualizationModel:
    """Load one visualization attached to an artifact.

    Args:
        artifact_id: ID of the artifact whose visualization is requested.
        index: Position of the visualization to load when the artifact has
            several.

    Returns:
        The loaded visualization, with images encoded (`encode_image=True`).
    """
    zs = zen_store()
    return load_artifact_visualization(
        artifact=zs.get_artifact(artifact_id),
        index=index,
        zen_store=zs,
        encode_image=True,
    )
list_artifacts(artifact_filter_model=Depends(init_cls_and_handle_errors), _=Security(oauth2_authentication))
Get artifacts according to query filters.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
artifact_filter_model |
ArtifactFilterModel |
Filter model used for pagination, sorting, filtering |
Depends(init_cls_and_handle_errors) |
Returns:
Type | Description |
---|---|
Page[ArtifactResponseModel] |
The artifacts according to query filters. |
Source code in zenml/zen_server/routers/artifacts_endpoints.py
@router.get(
    "",
    response_model=Page[ArtifactResponseModel],
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def list_artifacts(
    artifact_filter_model: ArtifactFilterModel = Depends(
        make_dependable(ArtifactFilterModel)
    ),
    _: AuthContext = Security(authorize, scopes=[PermissionType.READ]),
) -> Page[ArtifactResponseModel]:
    """Return one page of artifacts matching the supplied filters.

    Args:
        artifact_filter_model: Filter model controlling pagination, sorting
            and filtering.

    Returns:
        The page of artifacts selected by the filter model.
    """
    store = zen_store()
    page = store.list_artifacts(artifact_filter_model=artifact_filter_model)
    return page
auth_endpoints
Endpoint definitions for authentication (login).
OAuthLoginRequestForm
OAuth2 grant type request form.
This form allows multiple grant types to be used with the same endpoint: the standard OAuth2 password grant type, the standard OAuth2 device authorization grant type, and the ZenML External Authenticator grant type.
Source code in zenml/zen_server/routers/auth_endpoints.py
class OAuthLoginRequestForm:
    """Form data accepted by the OAuth2 token (login) endpoint.

    A single endpoint serves several grant types:

    * the standard OAuth2 password grant type;
    * the standard OAuth2 device authorization grant type;
    * the ZenML External Authenticator grant type.

    When no grant type is sent explicitly, it is inferred from which form
    fields are present.
    """

    def __init__(
        self,
        grant_type: Optional[str] = Form(None),
        username: Optional[str] = Form(None),
        password: Optional[str] = Form(None),
        client_id: Optional[str] = Form(None),
        device_code: Optional[str] = Form(None),
    ):
        """Parse and validate the submitted login form.

        Args:
            grant_type: The grant type. Inferred when omitted: a present
                `username` selects the password grant, a non-empty
                `device_code` selects the device grant, otherwise the
                external-authenticator grant is used.
            username: The username. Only used for the password grant type.
            password: The password. Only used for the password grant type.
                A missing password is stored as an empty string.
            client_id: The client ID. Must parse as a UUID for the device
                authorization grant type.
            device_code: The device code. Only used for the device
                authorization grant type.

        Raises:
            HTTPException: With status 400 if the grant type is unknown, is
                not enabled by the server's auth scheme, or if fields required
                by the grant type are missing or malformed.
        """
        if grant_type:
            if grant_type not in OAuthGrantTypes.values():
                logger.info(
                    f"Request with unsupported grant type: {grant_type}"
                )
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail=f"Unsupported grant type: {grant_type}",
                )
            self.grant_type = OAuthGrantTypes(grant_type)
        elif username is not None:
            # No explicit grant type: infer it from the supplied fields.
            self.grant_type = OAuthGrantTypes.OAUTH_PASSWORD
        elif device_code:
            self.grant_type = OAuthGrantTypes.OAUTH_DEVICE_CODE
        else:
            self.grant_type = OAuthGrantTypes.ZENML_EXTERNAL

        config = server_config()

        if self.grant_type == OAuthGrantTypes.OAUTH_PASSWORD:
            # The password grant is only accepted by servers configured for
            # the OAuth2 password bearer scheme.
            if config.auth_scheme != AuthScheme.OAUTH2_PASSWORD_BEARER:
                logger.info(
                    f"Request with unsupported grant type: {self.grant_type}"
                )
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail=f"Unsupported grant type: {self.grant_type}.",
                )
            if not username:
                logger.info("Request with missing username")
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail="Invalid request: username is required.",
                )
            self.username = username
            self.password = password or ""
        elif self.grant_type == OAuthGrantTypes.OAUTH_DEVICE_CODE:
            if not (device_code and client_id):
                logger.info("Request with missing device code or client ID")
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail="Invalid request: device code and client ID are "
                    "required.",
                )
            try:
                self.client_id = UUID(client_id)
            except ValueError:
                logger.info(f"Request with invalid client ID: {client_id}")
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail="Invalid request: invalid client ID.",
                )
            self.device_code = device_code
        elif self.grant_type == OAuthGrantTypes.ZENML_EXTERNAL:
            # The external grant requires the server to delegate
            # authentication to an external authenticator.
            if config.auth_scheme != AuthScheme.EXTERNAL:
                logger.info(
                    f"Request with unsupported grant type: {self.grant_type}"
                )
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail=f"Unsupported grant type: {self.grant_type}.",
                )
__init__(self, grant_type=Form(None), username=Form(None), password=Form(None), client_id=Form(None), device_code=Form(None))
special
Initializes the form.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
grant_type |
Optional[str] |
The grant type. |
Form(None) |
username |
Optional[str] |
The username. Only used for the password grant type. |
Form(None) |
password |
Optional[str] |
The password. Only used for the password grant type. |
Form(None) |
client_id |
Optional[str] |
The client ID. |
Form(None) |
device_code |
Optional[str] |
The device code. Only used for the device authorization grant type. |
Form(None) |
Exceptions:
Type | Description |
---|---|
HTTPException |
If the request is invalid. |
Source code in zenml/zen_server/routers/auth_endpoints.py
def __init__(
    self,
    grant_type: Optional[str] = Form(None),
    username: Optional[str] = Form(None),
    password: Optional[str] = Form(None),
    client_id: Optional[str] = Form(None),
    device_code: Optional[str] = Form(None),
):
    """Parse and validate the submitted login form.

    Args:
        grant_type: The grant type. Inferred when omitted: a present
            `username` selects the password grant, a non-empty `device_code`
            selects the device grant, otherwise the external-authenticator
            grant is used.
        username: The username. Only used for the password grant type.
        password: The password. Only used for the password grant type.
            A missing password is stored as an empty string.
        client_id: The client ID. Must parse as a UUID for the device
            authorization grant type.
        device_code: The device code. Only used for the device authorization
            grant type.

    Raises:
        HTTPException: With status 400 if the grant type is unknown, is not
            enabled by the server's auth scheme, or if fields required by the
            grant type are missing or malformed.
    """
    if grant_type:
        if grant_type not in OAuthGrantTypes.values():
            logger.info(
                f"Request with unsupported grant type: {grant_type}"
            )
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"Unsupported grant type: {grant_type}",
            )
        self.grant_type = OAuthGrantTypes(grant_type)
    elif username is not None:
        # No explicit grant type: infer it from the supplied fields.
        self.grant_type = OAuthGrantTypes.OAUTH_PASSWORD
    elif device_code:
        self.grant_type = OAuthGrantTypes.OAUTH_DEVICE_CODE
    else:
        self.grant_type = OAuthGrantTypes.ZENML_EXTERNAL

    config = server_config()

    if self.grant_type == OAuthGrantTypes.OAUTH_PASSWORD:
        # The password grant is only accepted by servers configured for the
        # OAuth2 password bearer scheme.
        if config.auth_scheme != AuthScheme.OAUTH2_PASSWORD_BEARER:
            logger.info(
                f"Request with unsupported grant type: {self.grant_type}"
            )
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"Unsupported grant type: {self.grant_type}.",
            )
        if not username:
            logger.info("Request with missing username")
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Invalid request: username is required.",
            )
        self.username = username
        self.password = password or ""
    elif self.grant_type == OAuthGrantTypes.OAUTH_DEVICE_CODE:
        if not (device_code and client_id):
            logger.info("Request with missing device code or client ID")
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Invalid request: device code and client ID are "
                "required.",
            )
        try:
            self.client_id = UUID(client_id)
        except ValueError:
            logger.info(f"Request with invalid client ID: {client_id}")
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Invalid request: invalid client ID.",
            )
        self.device_code = device_code
    elif self.grant_type == OAuthGrantTypes.ZENML_EXTERNAL:
        # The external grant requires the server to delegate authentication
        # to an external authenticator.
        if config.auth_scheme != AuthScheme.EXTERNAL:
            logger.info(
                f"Request with unsupported grant type: {self.grant_type}"
            )
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"Unsupported grant type: {self.grant_type}.",
            )
api_token(pipeline_id=None, schedule_id=None, expires_minutes=None, auth_context=Security(oauth2_authentication))
Get a workload API token for the current user.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
pipeline_id |
Optional[uuid.UUID] |
The ID of the pipeline to get the API token for. |
None |
schedule_id |
Optional[uuid.UUID] |
The ID of the schedule to get the API token for. |
None |
expires_minutes |
Optional[int] |
The number of minutes for which the API token should be valid. If not provided, the API token will be valid indefinitely. |
None |
auth_context |
AuthContext |
The authentication context. |
Security(oauth2_authentication) |
Returns:
Type | Description |
---|---|
str |
The API token. |
Exceptions:
Type | Description |
---|---|
HTTPException |
If the user is not authenticated. |
Source code in zenml/zen_server/routers/auth_endpoints.py
@router.get(
    API_TOKEN,
    response_model=str,
)
@handle_exceptions
def api_token(
    pipeline_id: Optional[UUID] = None,
    schedule_id: Optional[UUID] = None,
    expires_minutes: Optional[int] = None,
    auth_context: AuthContext = Security(
        authorize, scopes=[PermissionType.WRITE]
    ),
) -> str:
    """Issue a workload API token for the calling user.

    Args:
        pipeline_id: Optional pipeline ID to embed in the issued token.
        schedule_id: Optional schedule ID to embed in the issued token.
        expires_minutes: Lifetime of the issued token in minutes. When not
            provided (or falsy), the token is encoded without an expiry.
        auth_context: The authentication context of the caller.

    Returns:
        The encoded API token. For callers that did not authenticate via a
        device, this is simply their current encoded access token.

    Raises:
        HTTPException: If the auth context carries no access token.
    """
    current_token = auth_context.access_token
    encoded = auth_context.encoded_access_token
    if not current_token or not encoded:
        # Defensive: not expected for an authenticated request.
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Not authenticated.",
        )

    if not current_token.device_id:
        # Scoped workload tokens are only issued to device-authenticated
        # users, because device tokens can be revoked at any time. Everyone
        # else gets their current token back unmodified.
        return encoded

    # Device-authenticated: scope a re-encoded copy of the token to the
    # requested pipeline and/or schedule.
    if pipeline_id:
        current_token.pipeline_id = pipeline_id
    if schedule_id:
        current_token.schedule_id = schedule_id

    expires: Optional[datetime] = (
        datetime.utcnow() + timedelta(minutes=expires_minutes)
        if expires_minutes
        else None
    )
    return current_token.encode(expires=expires)
device_authorization(request, client_id=Form(Ellipsis))
OAuth2 device authorization endpoint.
This endpoint implements the OAuth2 device authorization grant flow as defined in https://tools.ietf.org/html/rfc8628. It is called to initiate the device authorization flow by requesting a device and user code for a given client ID.
For a new client ID, a new OAuth device is created, stored in the DB and returned to the client along with a pair of newly generated device and user codes. If a device for the given client ID already exists, the existing DB entry is reused and new device and user codes are generated.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
request |
Request |
The request object. |
required |
client_id |
UUID |
The client ID. |
Form(Ellipsis) |
Returns:
Type | Description |
---|---|
OAuthDeviceAuthorizationResponse |
The device authorization response. |
Source code in zenml/zen_server/routers/auth_endpoints.py
@router.post(
    DEVICE_AUTHORIZATION,
    response_model=OAuthDeviceAuthorizationResponse,
)
def device_authorization(
    request: Request,
    client_id: UUID = Form(...),
) -> OAuthDeviceAuthorizationResponse:
    """OAuth2 device authorization endpoint.

    This endpoint implements the OAuth2 device authorization grant flow as
    defined in https://tools.ietf.org/html/rfc8628. It is called to initiate
    the device authorization flow by requesting a device and user code for a
    given client ID.

    For a new client ID, a new OAuth device is created, stored in the DB and
    returned to the client along with a pair of newly generated device and user
    codes. If a device for the given client ID already exists, the existing
    DB entry is reused and new device and user codes are generated.

    Args:
        request: The request object.
        client_id: The client ID.

    Returns:
        The device authorization response, including the verification URI
        (prefixed with the dashboard URL when one is configured).
    """
    config = server_config()
    store = zen_store()

    # Use this opportunity to delete expired devices
    store.delete_expired_authorized_devices()

    # Fetch additional details about the client from the user-agent header
    user_agent_header = request.headers.get("User-Agent")
    if user_agent_header:
        device_details = OAuthDeviceUserAgentHeader.decode(user_agent_header)
    else:
        device_details = OAuthDeviceUserAgentHeader()

    # Fetch the IP address of the client and look up its location
    ip_address: str = ""
    city, region, country = "", "", ""
    if request.client and request.client.host:
        ip_address = request.client.host
        city, region, country = get_ip_location(ip_address)

    # Check if a device is already registered for the same client ID.
    try:
        device_model = store.get_internal_authorized_device(
            client_id=client_id
        )
    except KeyError:
        device_model = store.create_authorized_device(
            OAuthDeviceInternalRequestModel(
                client_id=client_id,
                expires_in=config.device_auth_timeout,
                ip_address=ip_address,
                city=city,
                region=region,
                country=country,
                **device_details.dict(exclude_none=True),
            )
        )
    else:
        # Put the device into pending state and generate new codes. This
        # effectively invalidates the old codes and the device cannot be used
        # for authentication anymore.
        device_model = store.update_internal_authorized_device(
            device_id=device_model.id,
            update=OAuthDeviceInternalUpdateModel(
                trusted_device=False,
                expires_in=config.device_auth_timeout,
                status=OAuthDeviceStatus.PENDING,
                failed_auth_attempts=0,
                generate_new_codes=True,
                ip_address=ip_address,
                city=city,
                region=region,
                country=country,
                **device_details.dict(exclude_none=True),
            ),
        )

    if config.dashboard_url:
        # Strip *trailing* slashes from the configured dashboard URL before
        # appending the route path. The previous `lstrip("/")` removed
        # leading slashes instead, which is a no-op for an absolute URL, so a
        # dashboard URL configured as "https://host/" produced a double slash
        # (presumably DEVICES starts with "/" — confirm).
        verification_uri = (
            config.dashboard_url.rstrip("/") + DEVICES + DEVICE_VERIFY
        )
    else:
        verification_uri = DEVICES + DEVICE_VERIFY

    verification_uri_complete = (
        verification_uri
        + "?"
        + urlencode(
            dict(
                device_id=str(device_model.id),
                user_code=str(device_model.user_code),
            )
        )
    )

    return OAuthDeviceAuthorizationResponse(
        device_code=device_model.device_code,
        user_code=device_model.user_code,
        expires_in=config.device_auth_timeout,
        interval=config.device_auth_polling_interval,
        verification_uri=verification_uri,
        verification_uri_complete=verification_uri_complete,
    )
generate_access_token(user_id, response, device=None)
Generates an access token for the given user.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
user_id |
UUID |
The ID of the user. |
required |
response |
Response |
The FastAPI response object. |
required |
device |
Optional[zenml.models.device_models.OAuthDeviceInternalResponseModel] |
The device used for authentication. |
None |
Returns:
Type | Description |
---|---|
OAuthTokenResponse |
An authentication response with an access token. |
Source code in zenml/zen_server/routers/auth_endpoints.py
def generate_access_token(
    user_id: UUID,
    response: Response,
    device: Optional[OAuthDeviceInternalResponseModel] = None,
) -> OAuthTokenResponse:
    """Generates an access token for the given user.

    The token embeds the union of the permissions of all roles assigned to
    the user. For non-device logins, the token is also set as an HTTP-only
    cookie on `response`.

    Args:
        user_id: The ID of the user.
        response: The FastAPI response object.
        device: The device used for authentication. When given, the token
            expires together with the device and no cookie is set.

    Returns:
        An authentication response with an access token.
    """
    store = zen_store()
    role_assignments = store.list_user_role_assignments(
        user_role_assignment_filter_model=UserRoleAssignmentFilterModel(
            user_id=user_id
        )
    )
    # NOTE(review): only the page returned by list_user_role_assignments is
    # considered here — confirm the default page size covers all of a user's
    # role assignments.
    # TODO: This needs to happen at the sql level now
    # The same role can appear in several assignments; fetch each distinct
    # role only once instead of once per assignment.
    role_ids = {
        ra.role.id for ra in role_assignments.items if ra.role is not None
    }
    permissions = set().union(
        *(store.get_role(role_id).permissions for role_id in role_ids)
    )

    config = server_config()

    # The JWT tokens are set to expire according to the values configured
    # in the server config. Device tokens are handled separately from regular
    # user tokens.
    expires: Optional[datetime] = None
    expires_in: Optional[int] = None
    if device:
        # If a device was used for authentication, the token will expire
        # at the same time as the device.
        expires = device.expires
        if expires:
            # Remaining lifetime in seconds, clamped so it is never negative.
            expires_in = max(
                int(expires.timestamp() - datetime.utcnow().timestamp()), 0
            )
    elif config.jwt_token_expire_minutes:
        expires = datetime.utcnow() + timedelta(
            minutes=config.jwt_token_expire_minutes
        )
        expires_in = config.jwt_token_expire_minutes * 60

    access_token = JWTToken(
        user_id=user_id,
        device_id=device.id if device else None,
        permissions=[p.value for p in permissions],
    ).encode(expires=expires)

    if not device:
        # Also set the access token as an HTTP only cookie in the response
        response.set_cookie(
            key=config.get_auth_cookie_name(),
            value=access_token,
            httponly=True,
            samesite="lax",
            max_age=config.jwt_token_expire_minutes * 60
            if config.jwt_token_expire_minutes
            else None,
            domain=config.auth_cookie_domain,
        )

    return OAuthTokenResponse(
        access_token=access_token, expires_in=expires_in, token_type="bearer"
    )
logout(response)
Logs out the user.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
response |
Response |
The response object. |
required |
Source code in zenml/zen_server/routers/auth_endpoints.py
@router.get(
    LOGOUT,
)
def logout(
    response: Response,
) -> None:
    """Log the user out by clearing the HTTP-only auth cookie.

    Args:
        response: The response on which the cookie deletion is set.
    """
    server_cfg = server_config()
    cookie_name = server_cfg.get_auth_cookie_name()
    # Deleting is harmless even if the client never had the cookie.
    response.delete_cookie(
        key=cookie_name,
        httponly=True,
        samesite="lax",
        domain=server_cfg.auth_cookie_domain,
    )
token(request, response, auth_form_data=Depends(OAuthLoginRequestForm))
OAuth2 token endpoint.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
request |
Request |
The request object. |
required |
response |
Response |
The response object. |
required |
auth_form_data |
OAuthLoginRequestForm |
The OAuth 2.0 authentication form data. |
Depends(OAuthLoginRequestForm) |
Returns:
Type | Description |
---|---|
Union[zenml.models.auth_models.OAuthTokenResponse, zenml.models.auth_models.OAuthRedirectResponse] |
An access token or a redirect response. |
Exceptions:
Type | Description |
---|---|
ValueError |
If the grant type is invalid. |
Source code in zenml/zen_server/routers/auth_endpoints.py
@router.post(
    LOGIN,
    response_model=Union[OAuthTokenResponse, OAuthRedirectResponse],
)
@handle_exceptions
def token(
    request: Request,
    response: Response,
    auth_form_data: OAuthLoginRequestForm = Depends(),
) -> Union[OAuthTokenResponse, OAuthRedirectResponse]:
    """OAuth2 token endpoint: authenticate and issue an access token.

    Args:
        request: The incoming request (cookies and headers are inspected for
            the external grant).
        response: The response object, passed on to token generation.
        auth_form_data: The parsed OAuth 2.0 login form.

    Returns:
        An access token response. For the external grant without an external
        access token, a redirect response pointing at the external login URL
        is returned instead.

    Raises:
        ValueError: If the grant type is not one of the handled types.
    """
    grant = auth_form_data.grant_type

    if grant == OAuthGrantTypes.OAUTH_PASSWORD:
        auth_context = authenticate_credentials(
            user_name_or_id=auth_form_data.username,
            password=auth_form_data.password,
        )
    elif grant == OAuthGrantTypes.OAUTH_DEVICE_CODE:
        auth_context = authenticate_device(
            client_id=auth_form_data.client_id,
            device_code=auth_form_data.device_code,
        )
    elif grant == OAuthGrantTypes.ZENML_EXTERNAL:
        config = server_config()
        assert config.external_cookie_name is not None
        assert config.external_login_url is not None
        login_url = config.external_login_url

        # Prefer the external access token stored in the external cookie.
        external_access_token = request.cookies.get(
            config.external_cookie_name
        )
        if external_access_token:
            logger.info("External access token found in cookie.")
        else:
            # Fall back to an "Authorization: Bearer <token>" header.
            header_value = request.headers.get("Authorization")
            if header_value:
                scheme, _, bearer_value = header_value.partition(" ")
                if bearer_value and scheme.lower() == "bearer":
                    external_access_token = bearer_value
                    logger.info(
                        "External access token found in authorization header."
                    )

        if not external_access_token:
            logger.info(
                "External access token not found. Redirecting to "
                "external authenticator."
            )
            # Send the user to the external authentication login endpoint.
            return OAuthRedirectResponse(authorization_url=login_url)

        auth_context = authenticate_external_user(
            external_access_token=external_access_token
        )
    else:
        # Unreachable in practice: the login form validates the grant type.
        raise ValueError("Invalid grant type.")

    return generate_access_token(
        user_id=auth_context.user.id,
        response=response,
        device=auth_context.device,
    )
code_repositories_endpoints
Endpoint definitions for code repositories.
delete_code_repository(code_repository_id, _=Security(oauth2_authentication))
Deletes a specific code repository.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
code_repository_id |
UUID |
The ID of the code repository to delete. |
required |
Source code in zenml/zen_server/routers/code_repositories_endpoints.py
@router.delete(
    "/{code_repository_id}",
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def delete_code_repository(
    code_repository_id: UUID,
    _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]),
) -> None:
    """Remove a code repository from the store.

    Args:
        code_repository_id: The ID of the code repository to delete.
    """
    store = zen_store()
    store.delete_code_repository(code_repository_id=code_repository_id)
get_code_repository(code_repository_id, _=Security(oauth2_authentication))
Gets a specific code repository using its unique ID.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
code_repository_id |
UUID |
The ID of the code repository to get. |
required |
Returns:
Type | Description |
---|---|
CodeRepositoryResponseModel |
A specific code repository object. |
Source code in zenml/zen_server/routers/code_repositories_endpoints.py
@router.get(
    "/{code_repository_id}",
    response_model=CodeRepositoryResponseModel,
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def get_code_repository(
    code_repository_id: UUID,
    _: AuthContext = Security(authorize, scopes=[PermissionType.READ]),
) -> CodeRepositoryResponseModel:
    """Look up a single code repository by its unique ID.

    Args:
        code_repository_id: The ID of the code repository to fetch.

    Returns:
        The matching code repository.
    """
    store = zen_store()
    repository = store.get_code_repository(
        code_repository_id=code_repository_id
    )
    return repository
list_code_repositories(filter_model=Depends(init_cls_and_handle_errors), _=Security(oauth2_authentication))
Gets a page of code repositories.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
filter_model |
CodeRepositoryFilterModel |
Filter model used for pagination, sorting, filtering |
Depends(init_cls_and_handle_errors) |
Returns:
Type | Description |
---|---|
Page[CodeRepositoryResponseModel] |
Page of code repository objects. |
Source code in zenml/zen_server/routers/code_repositories_endpoints.py
@router.get(
    "",
    response_model=Page[CodeRepositoryResponseModel],
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def list_code_repositories(
    filter_model: CodeRepositoryFilterModel = Depends(
        make_dependable(CodeRepositoryFilterModel)
    ),
    _: AuthContext = Security(authorize, scopes=[PermissionType.READ]),
) -> Page[CodeRepositoryResponseModel]:
    """Return one page of code repositories matching the filters.

    Args:
        filter_model: Filter model controlling pagination, sorting and
            filtering.

    Returns:
        The selected page of code repositories.
    """
    store = zen_store()
    page = store.list_code_repositories(filter_model=filter_model)
    return page
update_code_repository(code_repository_id, update, _=Security(oauth2_authentication))
Updates a code repository.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
code_repository_id |
UUID |
The ID of the code repository to update. |
required |
update |
CodeRepositoryUpdateModel |
The model containing the attributes to update. |
required |
Returns:
Type | Description |
---|---|
CodeRepositoryResponseModel |
The updated code repository object. |
Source code in zenml/zen_server/routers/code_repositories_endpoints.py
@router.put(
    "/{code_repository_id}",
    response_model=CodeRepositoryResponseModel,
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def update_code_repository(
    code_repository_id: UUID,
    update: CodeRepositoryUpdateModel,
    _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]),
) -> CodeRepositoryResponseModel:
    """Apply an update to an existing code repository.

    Args:
        code_repository_id: The ID of the code repository to modify.
        update: Model holding the attributes to change.

    Returns:
        The code repository after the update was applied.
    """
    store = zen_store()
    updated = store.update_code_repository(
        code_repository_id=code_repository_id, update=update
    )
    return updated
devices_endpoints
Endpoint definitions for OAuth2 authorized devices.
delete_authorized_device(device_id, auth_context=Security(oauth2_authentication))
Deletes a specific OAuth2 authorized device using its unique ID.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
device_id |
UUID |
The ID of the OAuth2 authorized device to delete. |
required |
auth_context |
AuthContext |
The current auth context. |
Security(oauth2_authentication) |
Exceptions:
Type | Description |
---|---|
KeyError |
If the device with the given ID does not exist or does not belong to the current user. |
Source code in zenml/zen_server/routers/devices_endpoints.py
@router.delete(
    "/{device_id}",
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def delete_authorized_device(
    device_id: UUID,
    auth_context: AuthContext = Security(
        authorize, scopes=[PermissionType.WRITE]
    ),
) -> None:
    """Delete an OAuth2 authorized device owned by the calling user.

    Args:
        device_id: The ID of the OAuth2 authorized device to delete.
        auth_context: The auth context of the caller.

    Raises:
        KeyError: If the device does not exist or belongs to another user
            (or to no user). Both cases report "not found" so other users'
            devices are not revealed.
    """
    store = zen_store()
    owner = store.get_authorized_device(device_id=device_id).user
    if not owner or owner.id != auth_context.user.id:
        raise KeyError(
            f"Unable to get device with ID {device_id}: No device with "
            "this ID found."
        )
    store.delete_authorized_device(device_id=device_id)
get_authorization_device(device_id, user_code=None, auth_context=Security(oauth2_authentication))
Gets a specific OAuth2 authorized device using its unique ID.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
device_id |
UUID |
The ID of the OAuth2 authorized device to get. |
required |
user_code |
Optional[str] |
The user code of the OAuth2 authorized device to get. Needs to be specified with devices that have not been verified yet. |
None |
auth_context |
AuthContext |
The current auth context. |
Security(oauth2_authentication) |
Returns:
Type | Description |
---|---|
OAuthDeviceResponseModel |
A specific OAuth2 authorized device object. |
Exceptions:
Type | Description |
---|---|
KeyError |
If the device with the given ID does not exist, does not belong to the current user or could not be verified using the given user code. |
Source code in zenml/zen_server/routers/devices_endpoints.py
@router.get(
    "/{device_id}",
    response_model=OAuthDeviceResponseModel,
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def get_authorization_device(
    device_id: UUID,
    user_code: Optional[str] = None,
    auth_context: AuthContext = Security(
        authorize, scopes=[PermissionType.READ]
    ),
) -> OAuthDeviceResponseModel:
    """Fetch an OAuth2 authorized device by its unique ID.

    Args:
        device_id: The ID of the OAuth2 authorized device to fetch.
        user_code: The device's user code. Required for devices that are not
            yet associated with a user.
        auth_context: The auth context of the caller.

    Returns:
        The requested OAuth2 authorized device.

    Raises:
        KeyError: If the device does not exist, belongs to another user, or
            is unassociated and no matching user code was supplied.
    """
    store = zen_store()
    device = store.get_authorized_device(device_id=device_id)

    if device.user:
        # Associated devices are only visible to their owner.
        if device.user.id == auth_context.user.id:
            return device
    elif user_code:
        # A device that hasn't been verified and associated with a user yet
        # is only revealed to a caller presenting a valid user code.
        internal = store.get_internal_authorized_device(device_id=device_id)
        if internal.verify_user_code(user_code=user_code):
            return device

    raise KeyError(
        f"Unable to get device with ID {device_id}: No device with "
        "this ID found."
    )
list_authorized_devices(filter_model=Depends(init_cls_and_handle_errors), auth_context=Security(oauth2_authentication))
Gets a page of OAuth2 authorized devices belonging to the current user.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
filter_model |
OAuthDeviceFilterModel |
Filter model used for pagination, sorting, filtering |
Depends(init_cls_and_handle_errors) |
auth_context |
AuthContext |
The current auth context. |
Security(oauth2_authentication) |
Returns:
Type | Description |
---|---|
Page[OAuthDeviceResponseModel] |
Page of OAuth2 authorized device objects. |
Source code in zenml/zen_server/routers/devices_endpoints.py
@router.get(
    "",
    response_model=Page[OAuthDeviceResponseModel],
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def list_authorized_devices(
    filter_model: OAuthDeviceFilterModel = Depends(
        make_dependable(OAuthDeviceFilterModel)
    ),
    auth_context: AuthContext = Security(
        authorize, scopes=[PermissionType.READ]
    ),
) -> Page[OAuthDeviceResponseModel]:
    """List the calling user's OAuth2 authorized devices, one page at a time.

    Args:
        filter_model: Filter model controlling pagination, sorting and
            filtering. Its user scope is overwritten with the caller's ID.
        auth_context: The auth context of the caller.

    Returns:
        A page of OAuth2 authorized devices belonging to the caller.
    """
    # Restrict the query to devices owned by the caller.
    caller_id = auth_context.user.id
    filter_model.set_scope_user(caller_id)
    store = zen_store()
    return store.list_authorized_devices(filter_model=filter_model)
update_authorized_device(device_id, update, auth_context=Security(oauth2_authentication))
Updates a specific OAuth2 authorized device using its unique ID.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
device_id |
UUID |
The ID of the OAuth2 authorized device to update. |
required |
update |
OAuthDeviceUpdateModel |
The model containing the attributes to update. |
required |
auth_context |
AuthContext |
The current auth context. |
Security(oauth2_authentication) |
Returns:
Type | Description |
---|---|
OAuthDeviceResponseModel |
The updated OAuth2 authorized device object. |
Exceptions:
Type | Description |
---|---|
KeyError |
If the device with the given ID does not exist or does not belong to the current user. |
Source code in zenml/zen_server/routers/devices_endpoints.py
@router.put(
    "/{device_id}",
    response_model=OAuthDeviceResponseModel,
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def update_authorized_device(
    device_id: UUID,
    update: OAuthDeviceUpdateModel,
    auth_context: AuthContext = Security(
        authorize, scopes=[PermissionType.WRITE]
    ),
) -> OAuthDeviceResponseModel:
    """Update an OAuth2 authorized device owned by the calling user.

    Args:
        device_id: The ID of the OAuth2 authorized device to update.
        update: The model containing the attributes to update.
        auth_context: The current auth context.

    Returns:
        The updated OAuth2 authorized device object.

    Raises:
        KeyError: If the device with the given ID does not exist or does not
            belong to the current user.
    """
    device = zen_store().get_authorized_device(device_id=device_id)
    owner = device.user
    owned_by_caller = bool(owner) and owner.id == auth_context.user.id
    # Devices of other users (or unowned devices) are reported as missing so
    # their existence is not disclosed to the caller.
    if not owned_by_caller:
        raise KeyError(
            f"Unable to get device with ID {device_id}: No device with "
            "this ID found."
        )
    return zen_store().update_authorized_device(
        device_id=device_id, update=update
    )
verify_authorized_device(device_id, request, auth_context=Security(oauth2_authentication))
Verifies a specific OAuth2 authorized device using its unique ID.
This endpoint implements the OAuth2 device authorization grant flow as defined in https://tools.ietf.org/html/rfc8628. It is called to verify the user code for a given device ID.
If the user code is valid, the device is marked as verified and associated with the user that authorized the device. This association is required to be able to issue access tokens or revoke the device later on.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
device_id |
UUID |
The ID of the OAuth2 authorized device to update. |
required |
request |
OAuthDeviceVerificationRequest |
The model containing the verification request. |
required |
auth_context |
AuthContext |
The current auth context. |
Security(oauth2_authentication) |
Returns:
Type | Description |
---|---|
OAuthDeviceResponseModel |
The updated OAuth2 authorized device object. |
Exceptions:
Type | Description |
---|---|
ValueError |
If the device verification request fails. |
Source code in zenml/zen_server/routers/devices_endpoints.py
@router.put(
    "/{device_id}" + DEVICE_VERIFY,
    response_model=OAuthDeviceResponseModel,
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def verify_authorized_device(
    device_id: UUID,
    request: OAuthDeviceVerificationRequest,
    auth_context: AuthContext = Security(
        authorize, scopes=[PermissionType.READ]
    ),
) -> OAuthDeviceResponseModel:
    """Verify the user code of a pending OAuth2 authorized device.

    This is the verification step of the OAuth2 device authorization grant
    flow (RFC 8628, https://tools.ietf.org/html/rfc8628). When the supplied
    user code matches, the device is marked as verified and bound to the
    authenticated user. That binding is what later allows access tokens to be
    issued for, or revoked from, the device.

    Every wrong code increments the device's failed-attempt counter. Once the
    counter reaches the server's configured maximum, the device is locked.

    Args:
        device_id: The ID of the OAuth2 authorized device to update.
        request: The model containing the verification request.
        auth_context: The current auth context.

    Returns:
        The updated OAuth2 authorized device object.

    Raises:
        ValueError: If the device verification request fails.
    """
    config = server_config()
    store = zen_store()

    device_model = store.get_internal_authorized_device(
        device_id=device_id,
    )

    # Only devices still waiting for verification may be verified.
    if device_model.status != OAuthDeviceStatus.PENDING:
        raise ValueError("Invalid request: device not pending verification.")

    # Reject verification attempts after the device's expiry timestamp.
    expires_at = device_model.expires
    if expires_at and expires_at < datetime.utcnow():
        raise ValueError("Invalid request: device verification expired.")

    # A device already bound to a user may only be verified by that user.
    bound_user = device_model.user
    if bound_user and bound_user.id != auth_context.user.id:
        raise ValueError(
            "Invalid request: this device is associated with another user."
        )

    if not device_model.verify_user_code(request.user_code):
        # Wrong code: record the failure and lock the device once the
        # configured maximum number of failed attempts is reached.
        attempts = device_model.failed_auth_attempts + 1
        failure_update = OAuthDeviceInternalUpdateModel(
            failed_auth_attempts=attempts
        )
        must_lock = attempts >= config.max_failed_device_auth_attempts
        if must_lock:
            # Assigned only when locking so the field stays unset otherwise.
            failure_update.locked = True
        store.update_internal_authorized_device(
            device_id=device_model.id,
            update=failure_update,
        )
        if must_lock:
            raise ValueError(
                "Invalid request: device locked due to too many failed "
                "authentication attempts.",
            )
        raise ValueError("Invalid request: invalid user code.")

    # Correct code: bind the device to the caller and reset the failure
    # counter. The expiration date is intentionally left untouched here until
    # the client has received its access token, to avoid brute force attacks
    # on the device code.
    verified_update = OAuthDeviceInternalUpdateModel(
        status=OAuthDeviceStatus.VERIFIED,
        user_id=auth_context.user.id,
        failed_auth_attempts=0,
        trusted_device=request.trusted_device,
    )
    return store.update_internal_authorized_device(
        device_id=device_model.id,
        update=verified_update,
    )
flavors_endpoints
Endpoint definitions for flavors.
create_flavor(flavor, auth_context=Security(oauth2_authentication))
Creates a stack component flavor.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
flavor |
FlavorRequestModel |
Stack component flavor to register. |
required |
auth_context |
AuthContext |
Authentication context. |
Security(oauth2_authentication) |
Returns:
Type | Description |
---|---|
FlavorResponseModel |
The created stack component flavor. |
Exceptions:
Type | Description |
---|---|
IllegalOperationError |
If the user specified in the stack component flavor does not match the authenticated user. |
Source code in zenml/zen_server/routers/flavors_endpoints.py
@router.post(
    "",
    response_model=FlavorResponseModel,
    responses={401: error_response, 409: error_response, 422: error_response},
)
@handle_exceptions
def create_flavor(
    flavor: FlavorRequestModel,
    auth_context: AuthContext = Security(
        authorize, scopes=[PermissionType.WRITE]
    ),
) -> FlavorResponseModel:
    """Creates a stack component flavor.

    The flavor must be owned by the authenticated user. Creating a flavor on
    behalf of another user is rejected.

    Args:
        flavor: Stack component flavor to register.
        auth_context: Authentication context.

    Returns:
        The created stack component flavor.

    Raises:
        IllegalOperationError: If the user specified in the stack component
            flavor does not match the authenticated user.
    """
    # Only the user is checked here; no workspace validation happens in
    # this endpoint.
    if flavor.user != auth_context.user.id:
        raise IllegalOperationError(
            "Creating flavors for a user other than yourself "
            "is not supported."
        )
    return zen_store().create_flavor(flavor=flavor)
delete_flavor(flavor_id, _=Security(oauth2_authentication))
Deletes a flavor.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
flavor_id |
UUID |
ID of the flavor. |
required |
Source code in zenml/zen_server/routers/flavors_endpoints.py
@router.delete(
    "/{flavor_id}",
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def delete_flavor(
    flavor_id: UUID,
    _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]),
) -> None:
    """Remove the stack component flavor with the given ID.

    Args:
        flavor_id: ID of the flavor.
    """
    store = zen_store()
    store.delete_flavor(flavor_id)
get_flavor(flavor_id, _=Security(oauth2_authentication))
Returns the requested flavor.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
flavor_id |
UUID |
ID of the flavor. |
required |
Returns:
Type | Description |
---|---|
FlavorResponseModel |
The requested flavor. |
Source code in zenml/zen_server/routers/flavors_endpoints.py
@router.get(
    "/{flavor_id}",
    response_model=FlavorResponseModel,
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def get_flavor(
    flavor_id: UUID,
    _: AuthContext = Security(authorize, scopes=[PermissionType.READ]),
) -> FlavorResponseModel:
    """Returns the requested flavor.

    Args:
        flavor_id: ID of the flavor.

    Returns:
        The requested flavor.
    """
    return zen_store().get_flavor(flavor_id)
list_flavors(flavor_filter_model=Depends(init_cls_and_handle_errors), _=Security(oauth2_authentication))
Returns all flavors.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
flavor_filter_model |
FlavorFilterModel |
Filter model used for pagination, sorting, filtering |
Depends(init_cls_and_handle_errors) |
Returns:
Type | Description |
---|---|
Page[FlavorResponseModel] |
All flavors. |
Source code in zenml/zen_server/routers/flavors_endpoints.py
@router.get(
    "",
    response_model=Page[FlavorResponseModel],
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def list_flavors(
    flavor_filter_model: FlavorFilterModel = Depends(
        make_dependable(FlavorFilterModel)
    ),
    _: AuthContext = Security(authorize, scopes=[PermissionType.READ]),
) -> Page[FlavorResponseModel]:
    """List the stack component flavors that match the given filter.

    Args:
        flavor_filter_model: Filter model used for pagination, sorting,
            filtering

    Returns:
        A page of flavors matching the filter.
    """
    store = zen_store()
    flavors_page = store.list_flavors(flavor_filter_model=flavor_filter_model)
    return flavors_page
sync_flavors(_=Security(oauth2_authentication))
Purge all in-built and integration flavors from the DB and sync.
Returns:
Type | Description |
---|---|
None |
None if successful. Raises an exception otherwise. |
Source code in zenml/zen_server/routers/flavors_endpoints.py
@router.patch(
    "/sync",
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def sync_flavors(
    _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]),
) -> None:
    """Purge all in-built and integration flavors from the DB and sync.

    This delegates to the store's private `_sync_flavors` method.

    Returns:
        None if successful. Raises an exception otherwise.
    """
    store = zen_store()
    result = store._sync_flavors()
    return result
update_flavor(flavor_id, flavor_update, _=Security(oauth2_authentication))
Updates a flavor.
noqa: DAR401
Parameters:
Name | Type | Description | Default |
---|---|---|---|
flavor_id |
UUID |
ID of the flavor to update. |
required |
flavor_update |
FlavorUpdateModel |
The flavor attributes to update. |
required |
Returns:
Type | Description |
---|---|
FlavorResponseModel |
The updated flavor. |
Source code in zenml/zen_server/routers/flavors_endpoints.py
@router.put(
    # The path parameter must share its name with the `flavor_id` argument.
    # Under the previous "/{team_id}" path, FastAPI ignored the path segment
    # and required `flavor_id` as a query parameter instead.
    "/{flavor_id}",
    response_model=FlavorResponseModel,
    responses={401: error_response, 409: error_response, 422: error_response},
)
@handle_exceptions
def update_flavor(
    flavor_id: UUID,
    flavor_update: FlavorUpdateModel,
    _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]),
) -> FlavorResponseModel:
    """Updates a flavor.

    # noqa: DAR401

    Args:
        flavor_id: ID of the flavor to update.
        flavor_update: The flavor attributes to update.

    Returns:
        The updated flavor.
    """
    return zen_store().update_flavor(
        flavor_id=flavor_id, flavor_update=flavor_update
    )
models_endpoints
Endpoint definitions for models.
delete_model(model_name_or_id, _=Security(oauth2_authentication))
Delete a model by name or ID.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model_name_or_id |
Union[str, uuid.UUID] |
The name or ID of the model to delete. |
required |
Source code in zenml/zen_server/routers/models_endpoints.py
@router.delete(
    "/{model_name_or_id}",
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def delete_model(
    model_name_or_id: Union[str, UUID],
    _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]),
) -> None:
    """Remove the model identified by its name or ID.

    Args:
        model_name_or_id: The name or ID of the model to delete.
    """
    store = zen_store()
    store.delete_model(model_name_or_id)
delete_model_version(model_name_or_id, model_version_name_or_id, _=Security(oauth2_authentication))
Delete a model version by name or ID.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model_name_or_id |
Union[str, uuid.UUID] |
The name or ID of the model containing version. |
required |
model_version_name_or_id |
Union[str, uuid.UUID] |
The name or ID of the model version to delete. |
required |
Source code in zenml/zen_server/routers/models_endpoints.py
@router.delete(
    "/{model_name_or_id}" + MODEL_VERSIONS + "/{model_version_name_or_id}",
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def delete_model_version(
    model_name_or_id: Union[str, UUID],
    model_version_name_or_id: Union[str, UUID],
    _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]),
) -> None:
    """Delete a model version by name or ID.

    Args:
        model_name_or_id: The name or ID of the model containing the version.
        model_version_name_or_id: The name or ID of the model version to
            delete.
    """
    zen_store().delete_model_version(
        model_name_or_id, model_version_name_or_id
    )
delete_model_version_artifact_link(model_name_or_id, model_version_name_or_id, model_version_artifact_link_name_or_id, _=Security(oauth2_authentication))
Deletes a model version link.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model_name_or_id |
Union[str, uuid.UUID] |
name or ID of the model containing the model version. |
required |
model_version_name_or_id |
Union[str, uuid.UUID] |
name or ID of the model version containing the link. |
required |
model_version_artifact_link_name_or_id |
Union[str, uuid.UUID] |
name or ID of the model-version-to-artifact link to be deleted. |
required |
Source code in zenml/zen_server/routers/models_endpoints.py
@router.delete(
    "/{model_name_or_id}"
    + MODEL_VERSIONS
    + "/{model_version_name_or_id}"
    + ARTIFACTS
    + "/{model_version_artifact_link_name_or_id}",
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def delete_model_version_artifact_link(
    model_name_or_id: Union[str, UUID],
    model_version_name_or_id: Union[str, UUID],
    model_version_artifact_link_name_or_id: Union[str, UUID],
    _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]),
) -> None:
    """Remove a link between a model version and an artifact.

    Args:
        model_name_or_id: name or ID of the model containing the model version.
        model_version_name_or_id: name or ID of the model version containing
            the link.
        model_version_artifact_link_name_or_id: name or ID of the
            model-version-to-artifact link to be deleted.
    """
    store = zen_store()
    store.delete_model_version_artifact_link(
        model_name_or_id,
        model_version_name_or_id,
        model_version_artifact_link_name_or_id,
    )
delete_model_version_pipeline_run_link(model_name_or_id, model_version_name_or_id, model_version_pipeline_run_link_name_or_id, _=Security(oauth2_authentication))
Deletes a model version link.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model_name_or_id |
Union[str, uuid.UUID] |
name or ID of the model containing the model version. |
required |
model_version_name_or_id |
Union[str, uuid.UUID] |
name or ID of the model version containing the link. |
required |
model_version_pipeline_run_link_name_or_id |
Union[str, uuid.UUID] |
name or ID of the model-version-to-pipeline-run link to be deleted. |
required |
Source code in zenml/zen_server/routers/models_endpoints.py
@router.delete(
    "/{model_name_or_id}"
    + MODEL_VERSIONS
    + "/{model_version_name_or_id}"
    + RUNS
    + "/{model_version_pipeline_run_link_name_or_id}",
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def delete_model_version_pipeline_run_link(
    model_name_or_id: Union[str, UUID],
    model_version_name_or_id: Union[str, UUID],
    model_version_pipeline_run_link_name_or_id: Union[str, UUID],
    _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]),
) -> None:
    """Remove a link between a model version and a pipeline run.

    Args:
        model_name_or_id: name or ID of the model containing the model version.
        model_version_name_or_id: name or ID of the model version containing
            the link.
        model_version_pipeline_run_link_name_or_id: name or ID of the
            model-version-to-pipeline-run link to be deleted.
    """
    store = zen_store()
    store.delete_model_version_pipeline_run_link(
        model_name_or_id,
        model_version_name_or_id,
        model_version_pipeline_run_link_name_or_id,
    )
get_model(model_name_or_id, _=Security(oauth2_authentication))
Get a model by name or ID.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model_name_or_id |
Union[str, uuid.UUID] |
The name or ID of the model to get. |
required |
Returns:
Type | Description |
---|---|
ModelResponseModel |
The model with the given name or ID. |
Source code in zenml/zen_server/routers/models_endpoints.py
@router.get(
    "/{model_name_or_id}",
    response_model=ModelResponseModel,
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def get_model(
    model_name_or_id: Union[str, UUID],
    _: AuthContext = Security(authorize, scopes=[PermissionType.READ]),
) -> ModelResponseModel:
    """Look up a single model by its name or ID.

    Args:
        model_name_or_id: The name or ID of the model to get.

    Returns:
        The model with the given name or ID.
    """
    store = zen_store()
    model = store.get_model(model_name_or_id)
    return model
get_model_version(model_name_or_id, model_version_name_or_number_or_id='__latest__', is_number=False, _=Security(oauth2_authentication))
Get a model version by name or ID.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model_name_or_id |
Union[str, uuid.UUID] |
The name or ID of the model containing version. |
required |
model_version_name_or_number_or_id |
Union[str, int, uuid.UUID, zenml.enums.ModelStages] |
name, id, stage or number of the model version to be retrieved. If skipped, the latest version will be retrieved. |
'__latest__' |
is_number |
bool |
If the model_version_name_or_number_or_id is a version number |
False |
Returns:
Type | Description |
---|---|
ModelVersionResponseModel |
The model version with the given name or ID. |
Source code in zenml/zen_server/routers/models_endpoints.py
@router.get(
    "/{model_name_or_id}"
    + MODEL_VERSIONS
    + "/{model_version_name_or_number_or_id}",
    response_model=ModelVersionResponseModel,
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def get_model_version(
    model_name_or_id: Union[str, UUID],
    model_version_name_or_number_or_id: Union[
        str, int, UUID, ModelStages
    ] = LATEST_MODEL_VERSION_PLACEHOLDER,
    is_number: bool = False,
    _: AuthContext = Security(authorize, scopes=[PermissionType.READ]),
) -> ModelVersionResponseModel:
    """Look up a single model version of a model.

    Args:
        model_name_or_id: The name or ID of the model containing the version.
        model_version_name_or_number_or_id: name, id, stage or number of the
            model version to be retrieved. If skipped, the latest version will
            be retrieved.
        is_number: If True, the version reference is converted with `int()`
            and treated as a version number.

    Returns:
        The model version with the given name or ID.
    """
    store = zen_store()
    version_ref = model_version_name_or_number_or_id
    if is_number:
        version_ref = int(version_ref)
    return store.get_model_version(model_name_or_id, version_ref)
list_model_version_artifact_links(model_version_artifact_link_filter_model=Depends(init_cls_and_handle_errors), _=Security(oauth2_authentication))
Get model version to artifact links according to query filters.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model_version_artifact_link_filter_model |
ModelVersionArtifactFilterModel |
Filter model used for pagination, sorting, filtering |
Depends(init_cls_and_handle_errors) |
Returns:
Type | Description |
---|---|
Page[ModelVersionArtifactResponseModel] |
The model version to artifact links according to query filters. |
Source code in zenml/zen_server/routers/models_endpoints.py
@router.get(
    "/{model_name_or_id}"
    + MODEL_VERSIONS
    + "/{model_version_name_or_id}"
    + ARTIFACTS,
    response_model=Page[ModelVersionArtifactResponseModel],
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def list_model_version_artifact_links(
    model_version_artifact_link_filter_model: ModelVersionArtifactFilterModel = Depends(
        make_dependable(ModelVersionArtifactFilterModel)
    ),
    _: AuthContext = Security(authorize, scopes=[PermissionType.READ]),
) -> Page[ModelVersionArtifactResponseModel]:
    """List model-version-to-artifact links that match the given filter.

    NOTE(review): the `model_name_or_id` / `model_version_name_or_id` path
    segments are not declared as arguments, so they are not consumed here.
    Scoping presumably relies on the filter's own query parameters. Confirm
    this is intended.

    Args:
        model_version_artifact_link_filter_model: Filter model used for
            pagination, sorting, filtering

    Returns:
        The model version to artifact links according to query filters.
    """
    store = zen_store()
    return store.list_model_version_artifact_links(
        model_version_artifact_link_filter_model=model_version_artifact_link_filter_model,
    )
list_model_version_pipeline_run_links(model_version_pipeline_run_link_filter_model=Depends(init_cls_and_handle_errors), _=Security(oauth2_authentication))
Get model version to pipeline run links according to query filters.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model_version_pipeline_run_link_filter_model |
ModelVersionPipelineRunFilterModel |
Filter model used for pagination, sorting, and filtering |
Depends(init_cls_and_handle_errors) |
Returns:
Type | Description |
---|---|
Page[ModelVersionPipelineRunResponseModel] |
The model version to pipeline run links according to query filters. |
Source code in zenml/zen_server/routers/models_endpoints.py
@router.get(
    "/{model_name_or_id}"
    + MODEL_VERSIONS
    + "/{model_version_name_or_id}"
    + RUNS,
    response_model=Page[ModelVersionPipelineRunResponseModel],
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def list_model_version_pipeline_run_links(
    model_version_pipeline_run_link_filter_model: ModelVersionPipelineRunFilterModel = Depends(
        make_dependable(ModelVersionPipelineRunFilterModel)
    ),
    _: AuthContext = Security(authorize, scopes=[PermissionType.READ]),
) -> Page[ModelVersionPipelineRunResponseModel]:
    """List model-version-to-pipeline-run links that match the given filter.

    NOTE(review): the path segments are not declared as arguments and are
    therefore not used for scoping here. Confirm this is intended.

    Args:
        model_version_pipeline_run_link_filter_model: Filter model used for
            pagination, sorting, and filtering

    Returns:
        The model version to pipeline run links according to query filters.
    """
    store = zen_store()
    return store.list_model_version_pipeline_run_links(
        model_version_pipeline_run_link_filter_model=model_version_pipeline_run_link_filter_model,
    )
list_model_versions(model_version_filter_model=Depends(init_cls_and_handle_errors), _=Security(oauth2_authentication))
Get model versions according to query filters.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model_version_filter_model |
ModelVersionFilterModel |
Filter model used for pagination, sorting, filtering |
Depends(init_cls_and_handle_errors) |
Returns:
Type | Description |
---|---|
Page[ModelVersionResponseModel] |
The model versions according to query filters. |
Source code in zenml/zen_server/routers/models_endpoints.py
@router.get(
    "/{model_name_or_id}" + MODEL_VERSIONS,
    response_model=Page[ModelVersionResponseModel],
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def list_model_versions(
    model_version_filter_model: ModelVersionFilterModel = Depends(
        make_dependable(ModelVersionFilterModel)
    ),
    _: AuthContext = Security(authorize, scopes=[PermissionType.READ]),
) -> Page[ModelVersionResponseModel]:
    """List model versions that match the given filter.

    Args:
        model_version_filter_model: Filter model used for pagination, sorting,
            filtering

    Returns:
        The model versions according to query filters.
    """
    store = zen_store()
    versions_page = store.list_model_versions(
        model_version_filter_model=model_version_filter_model,
    )
    return versions_page
list_models(model_filter_model=Depends(init_cls_and_handle_errors), _=Security(oauth2_authentication))
Get models according to query filters.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model_filter_model |
ModelFilterModel |
Filter model used for pagination, sorting, filtering |
Depends(init_cls_and_handle_errors) |
Returns:
Type | Description |
---|---|
Page[ModelResponseModel] |
The models according to query filters. |
Source code in zenml/zen_server/routers/models_endpoints.py
@router.get(
    "",
    response_model=Page[ModelResponseModel],
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def list_models(
    model_filter_model: ModelFilterModel = Depends(
        make_dependable(ModelFilterModel)
    ),
    _: AuthContext = Security(authorize, scopes=[PermissionType.READ]),
) -> Page[ModelResponseModel]:
    """List models that match the given filter.

    Args:
        model_filter_model: Filter model used for pagination, sorting,
            filtering

    Returns:
        The models according to query filters.
    """
    store = zen_store()
    models_page = store.list_models(
        model_filter_model=model_filter_model,
    )
    return models_page
update_model(model_id, model_update, _=Security(oauth2_authentication))
Updates a model.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model_id |
UUID |
ID of the model to update. |
required |
model_update |
ModelUpdateModel |
The model attributes to update. |
required |
Returns:
Type | Description |
---|---|
ModelResponseModel |
The updated model. |
Source code in zenml/zen_server/routers/models_endpoints.py
@router.put(
    "/{model_id}",
    response_model=ModelResponseModel,
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def update_model(
    model_id: UUID,
    model_update: ModelUpdateModel,
    _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]),
) -> ModelResponseModel:
    """Updates a model.

    Args:
        model_id: ID of the model to update.
        model_update: The model attributes to update.

    Returns:
        The updated model.
    """
    return zen_store().update_model(
        model_id=model_id,
        model_update=model_update,
    )
update_model_version(model_version_id, model_version_update_model, _=Security(oauth2_authentication))
Updates a model version.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model_version_id |
UUID |
The ID of model version to be updated. |
required |
model_version_update_model |
ModelVersionUpdateModel |
The model version to be updated. |
required |
Returns:
Type | Description |
---|---|
ModelVersionResponseModel |
An updated model version. |
Source code in zenml/zen_server/routers/models_endpoints.py
@router.put(
    "/{model_id}" + MODEL_VERSIONS + "/{model_version_id}",
    response_model=ModelVersionResponseModel,
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def update_model_version(
    model_version_id: UUID,
    model_version_update_model: ModelVersionUpdateModel,
    _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]),
) -> ModelVersionResponseModel:
    """Updates a model version.

    NOTE(review): the `{model_id}` path segment is not declared as an
    argument, so the version is looked up by `model_version_id` alone.

    Args:
        model_version_id: The ID of the model version to be updated.
        model_version_update_model: The model version attributes to update.

    Returns:
        An updated model version.
    """
    return zen_store().update_model_version(
        model_version_id=model_version_id,
        model_version_update_model=model_version_update_model,
    )
pipeline_builds_endpoints
Endpoint definitions for builds.
delete_build(build_id, _=Security(oauth2_authentication))
Deletes a specific build.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
build_id |
UUID |
ID of the build to delete. |
required |
Source code in zenml/zen_server/routers/pipeline_builds_endpoints.py
@router.delete(
    "/{build_id}",
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def delete_build(
    build_id: UUID,
    _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]),
) -> None:
    """Remove the pipeline build with the given ID.

    Args:
        build_id: ID of the build to delete.
    """
    store = zen_store()
    store.delete_build(build_id=build_id)
get_build(build_id, _=Security(oauth2_authentication))
Gets a specific build using its unique id.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
build_id |
UUID |
ID of the build to get. |
required |
Returns:
Type | Description |
---|---|
PipelineBuildResponseModel |
A specific build object. |
Source code in zenml/zen_server/routers/pipeline_builds_endpoints.py
@router.get(
    "/{build_id}",
    response_model=PipelineBuildResponseModel,
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def get_build(
    build_id: UUID,
    _: AuthContext = Security(authorize, scopes=[PermissionType.READ]),
) -> PipelineBuildResponseModel:
    """Fetch a single pipeline build by its unique ID.

    Args:
        build_id: ID of the build to get.

    Returns:
        A specific build object.
    """
    store = zen_store()
    build = store.get_build(build_id=build_id)
    return build
list_builds(build_filter_model=Depends(init_cls_and_handle_errors), _=Security(oauth2_authentication))
Gets a list of builds.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
build_filter_model |
PipelineBuildFilterModel |
Filter model used for pagination, sorting, filtering |
Depends(init_cls_and_handle_errors) |
Returns:
Type | Description |
---|---|
Page[PipelineBuildResponseModel] |
List of build objects. |
Source code in zenml/zen_server/routers/pipeline_builds_endpoints.py
@router.get(
    "",
    response_model=Page[PipelineBuildResponseModel],
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def list_builds(
    build_filter_model: PipelineBuildFilterModel = Depends(
        make_dependable(PipelineBuildFilterModel)
    ),
    _: AuthContext = Security(authorize, scopes=[PermissionType.READ]),
) -> Page[PipelineBuildResponseModel]:
    """List pipeline builds that match the given filter.

    Args:
        build_filter_model: Filter model used for pagination, sorting,
            filtering

    Returns:
        A page of build objects.
    """
    store = zen_store()
    builds_page = store.list_builds(build_filter_model=build_filter_model)
    return builds_page
pipeline_deployments_endpoints
Endpoint definitions for deployments.
delete_deployment(deployment_id, _=Security(oauth2_authentication))
Deletes a specific deployment.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
deployment_id |
UUID |
ID of the deployment to delete. |
required |
Source code in zenml/zen_server/routers/pipeline_deployments_endpoints.py
@router.delete(
    "/{deployment_id}",
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def delete_deployment(
    deployment_id: UUID,
    _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]),
) -> None:
    """Remove the pipeline deployment with the given ID.

    Args:
        deployment_id: ID of the deployment to delete.
    """
    store = zen_store()
    store.delete_deployment(deployment_id=deployment_id)
get_deployment(deployment_id, _=Security(oauth2_authentication))
Gets a specific deployment using its unique id.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
deployment_id |
UUID |
ID of the deployment to get. |
required |
Returns:
Type | Description |
---|---|
PipelineDeploymentResponseModel |
A specific deployment object. |
Source code in zenml/zen_server/routers/pipeline_deployments_endpoints.py
@router.get(
    "/{deployment_id}",
    response_model=PipelineDeploymentResponseModel,
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def get_deployment(
    deployment_id: UUID,
    _: AuthContext = Security(authorize, scopes=[PermissionType.READ]),
) -> PipelineDeploymentResponseModel:
    """Fetch a single pipeline deployment by its unique ID.

    Args:
        deployment_id: ID of the deployment to get.

    Returns:
        A specific deployment object.
    """
    store = zen_store()
    deployment = store.get_deployment(deployment_id=deployment_id)
    return deployment
list_deployments(deployment_filter_model=Depends(init_cls_and_handle_errors), _=Security(oauth2_authentication))
Gets a list of deployments.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
deployment_filter_model |
PipelineDeploymentFilterModel |
Filter model used for pagination, sorting, filtering |
Depends(init_cls_and_handle_errors) |
Returns:
Type | Description |
---|---|
Page[PipelineDeploymentResponseModel] |
List of deployment objects. |
Source code in zenml/zen_server/routers/pipeline_deployments_endpoints.py
@router.get(
    "",
    response_model=Page[PipelineDeploymentResponseModel],
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def list_deployments(
    deployment_filter_model: PipelineDeploymentFilterModel = Depends(
        make_dependable(PipelineDeploymentFilterModel)
    ),
    _: AuthContext = Security(authorize, scopes=[PermissionType.READ]),
) -> Page[PipelineDeploymentResponseModel]:
    """Gets a list of deployments matching the given filter.

    Args:
        deployment_filter_model: Filter model used for pagination, sorting,
            filtering

    Returns:
        A page of deployment objects.
    """
    store = zen_store()
    deployments_page = store.list_deployments(
        deployment_filter_model=deployment_filter_model
    )
    return deployments_page
pipelines_endpoints
Endpoint definitions for pipelines.
delete_pipeline(pipeline_id, _=Security(oauth2_authentication))
Deletes a specific pipeline.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
pipeline_id |
UUID |
ID of the pipeline to delete. |
required |
Source code in zenml/zen_server/routers/pipelines_endpoints.py
@router.delete(
    "/{pipeline_id}",
    responses={
        401: error_response,
        404: error_response,
        422: error_response,
    },
)
@handle_exceptions
def delete_pipeline(
    pipeline_id: UUID,
    _: AuthContext = Security(
        authorize,
        scopes=[PermissionType.WRITE],
    ),
) -> None:
    """Remove the pipeline identified by `pipeline_id`.

    Args:
        pipeline_id: Unique ID of the pipeline that should be removed.
    """
    store = zen_store()
    store.delete_pipeline(pipeline_id=pipeline_id)
get_pipeline(pipeline_id, _=Security(oauth2_authentication))
Gets a specific pipeline using its unique id.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
pipeline_id |
UUID |
ID of the pipeline to get. |
required |
Returns:
Type | Description |
---|---|
PipelineResponseModel |
A specific pipeline object. |
Source code in zenml/zen_server/routers/pipelines_endpoints.py
@router.get(
    "/{pipeline_id}",
    response_model=PipelineResponseModel,
    responses={
        401: error_response,
        404: error_response,
        422: error_response,
    },
)
@handle_exceptions
def get_pipeline(
    pipeline_id: UUID,
    _: AuthContext = Security(
        authorize,
        scopes=[PermissionType.READ],
    ),
) -> PipelineResponseModel:
    """Fetch a single pipeline by its unique ID.

    Args:
        pipeline_id: Unique ID of the pipeline to look up.

    Returns:
        The matching pipeline as returned by the zen store.
    """
    store = zen_store()
    pipeline = store.get_pipeline(pipeline_id=pipeline_id)
    return pipeline
get_pipeline_spec(pipeline_id, _=Security(oauth2_authentication))
Gets the spec of a specific pipeline using its unique id.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
pipeline_id |
UUID |
ID of the pipeline to get. |
required |
Returns:
Type | Description |
---|---|
PipelineSpec |
The spec of the pipeline. |
Source code in zenml/zen_server/routers/pipelines_endpoints.py
@router.get(
    "/{pipeline_id}" + PIPELINE_SPEC,
    response_model=PipelineSpec,
    responses={
        401: error_response,
        404: error_response,
        422: error_response,
    },
)
@handle_exceptions
def get_pipeline_spec(
    pipeline_id: UUID,
    _: AuthContext = Security(
        authorize,
        scopes=[PermissionType.READ],
    ),
) -> PipelineSpec:
    """Return only the `spec` attribute of a pipeline.

    Args:
        pipeline_id: Unique ID of the pipeline whose spec is requested.

    Returns:
        The `spec` of the pipeline loaded from the zen store.
    """
    pipeline = zen_store().get_pipeline(pipeline_id)
    spec = pipeline.spec
    return spec
list_pipeline_runs(pipeline_run_filter_model=Depends(init_cls_and_handle_errors), _=Security(oauth2_authentication))
Get pipeline runs according to query filters.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
pipeline_run_filter_model |
PipelineRunFilterModel |
Filter model used for pagination, sorting, filtering |
Depends(init_cls_and_handle_errors) |
Returns:
Type | Description |
---|---|
Page[PipelineRunResponseModel] |
The pipeline runs according to query filters. |
Source code in zenml/zen_server/routers/pipelines_endpoints.py
@router.get(
    "/{pipeline_id}" + RUNS,
    response_model=Page[PipelineRunResponseModel],
    responses={
        401: error_response,
        404: error_response,
        422: error_response,
    },
)
@handle_exceptions
def list_pipeline_runs(
    pipeline_run_filter_model: PipelineRunFilterModel = Depends(
        make_dependable(PipelineRunFilterModel)
    ),
    _: AuthContext = Security(
        authorize,
        scopes=[PermissionType.READ],
    ),
) -> Page[PipelineRunResponseModel]:
    """List pipeline runs that match the supplied query filters.

    NOTE(review): the `{pipeline_id}` path segment is not a parameter of
    this function. Presumably it reaches the filter through a same-named
    field of the dependency built by `make_dependable`. Confirm.

    Args:
        pipeline_run_filter_model: Filter model used for pagination,
            sorting and filtering.

    Returns:
        A page of pipeline runs matching the filter.
    """
    store = zen_store()
    runs_page = store.list_runs(pipeline_run_filter_model)
    return runs_page
list_pipelines(pipeline_filter_model=Depends(init_cls_and_handle_errors), _=Security(oauth2_authentication))
Gets a list of pipelines.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
pipeline_filter_model |
PipelineFilterModel |
Filter model used for pagination, sorting, filtering |
Depends(init_cls_and_handle_errors) |
Returns:
Type | Description |
---|---|
Page[PipelineResponseModel] |
List of pipeline objects. |
Source code in zenml/zen_server/routers/pipelines_endpoints.py
@router.get(
    "",
    response_model=Page[PipelineResponseModel],
    responses={
        401: error_response,
        404: error_response,
        422: error_response,
    },
)
@handle_exceptions
def list_pipelines(
    pipeline_filter_model: PipelineFilterModel = Depends(
        make_dependable(PipelineFilterModel)
    ),
    _: AuthContext = Security(
        authorize,
        scopes=[PermissionType.READ],
    ),
) -> Page[PipelineResponseModel]:
    """List pipelines that match the supplied query filters.

    Args:
        pipeline_filter_model: Filter model used for pagination, sorting
            and filtering.

    Returns:
        A page of pipeline objects.
    """
    store = zen_store()
    pipelines_page = store.list_pipelines(
        pipeline_filter_model=pipeline_filter_model
    )
    return pipelines_page
update_pipeline(pipeline_id, pipeline_update, _=Security(oauth2_authentication))
Updates the attribute on a specific pipeline using its unique id.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
pipeline_id |
UUID |
ID of the pipeline to update. |
required |
pipeline_update |
PipelineUpdateModel |
the model containing the attributes to update. |
required |
Returns:
Type | Description |
---|---|
PipelineResponseModel |
The updated pipeline object. |
Source code in zenml/zen_server/routers/pipelines_endpoints.py
@router.put(
    "/{pipeline_id}",
    response_model=PipelineResponseModel,
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def update_pipeline(
    pipeline_id: UUID,
    pipeline_update: PipelineUpdateModel,
    _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]),
) -> PipelineResponseModel:
    """Updates the attributes of a specific pipeline using its unique id.

    Args:
        pipeline_id: ID of the pipeline to update.
        pipeline_update: The model containing the attributes to update.

    Returns:
        The updated pipeline object.
    """
    return zen_store().update_pipeline(
        pipeline_id=pipeline_id, pipeline_update=pipeline_update
    )
role_assignments_endpoints
Endpoint definitions for role assignments.
create_role_assignment(role_assignment, _=Security(oauth2_authentication))
Creates a role assignment.
noqa: DAR401
Parameters:
Name | Type | Description | Default |
---|---|---|---|
role_assignment |
UserRoleAssignmentRequestModel |
Role assignment to create. |
required |
Returns:
Type | Description |
---|---|
UserRoleAssignmentResponseModel |
The created role assignment. |
Source code in zenml/zen_server/routers/role_assignments_endpoints.py
@router.post(
    "",
    response_model=UserRoleAssignmentResponseModel,
    responses={
        401: error_response,
        409: error_response,
        422: error_response,
    },
)
@handle_exceptions
def create_role_assignment(
    role_assignment: UserRoleAssignmentRequestModel,
    _: AuthContext = Security(
        authorize,
        scopes=[PermissionType.WRITE],
    ),
) -> UserRoleAssignmentResponseModel:
    """Persist a new user role assignment.

    # noqa: DAR401

    Args:
        role_assignment: The role assignment request to store.

    Returns:
        The role assignment as created by the zen store.
    """
    store = zen_store()
    created = store.create_user_role_assignment(
        user_role_assignment=role_assignment
    )
    return created
delete_role_assignment(role_assignment_id, _=Security(oauth2_authentication))
Deletes a specific role assignment.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
role_assignment_id |
UUID |
The ID of the role assignment. |
required |
Source code in zenml/zen_server/routers/role_assignments_endpoints.py
@router.delete(
    "/{role_assignment_id}",
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def delete_role_assignment(
    role_assignment_id: UUID,
    _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]),
) -> None:
    """Deletes a specific user role assignment.

    Only the assignment is removed via `delete_user_role_assignment`; this
    endpoint does not call any role-deletion API.

    Args:
        role_assignment_id: The ID of the role assignment.
    """
    zen_store().delete_user_role_assignment(
        user_role_assignment_id=role_assignment_id
    )
get_role_assignment(role_assignment_id, _=Security(oauth2_authentication))
Returns a specific role assignment.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
role_assignment_id |
UUID |
ID of the role assignment. |
required |
Returns:
Type | Description |
---|---|
UserRoleAssignmentResponseModel |
A specific role assignment. |
Source code in zenml/zen_server/routers/role_assignments_endpoints.py
@router.get(
    "/{role_assignment_id}",
    response_model=UserRoleAssignmentResponseModel,
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def get_role_assignment(
    role_assignment_id: UUID,
    _: AuthContext = Security(authorize, scopes=[PermissionType.READ]),
) -> UserRoleAssignmentResponseModel:
    """Returns a specific user role assignment.

    Args:
        role_assignment_id: ID of the role assignment. The path parameter is
            typed as `UUID`, so lookup by name is not supported here.

    Returns:
        A specific role assignment.
    """
    return zen_store().get_user_role_assignment(
        user_role_assignment_id=role_assignment_id
    )
list_user_role_assignments(user_role_assignment_filter_model=Depends(init_cls_and_handle_errors), _=Security(oauth2_authentication))
Returns a list of all role assignments.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
user_role_assignment_filter_model |
UserRoleAssignmentFilterModel |
Filter model for user role assignments. |
Depends(init_cls_and_handle_errors) |
Returns:
Type | Description |
---|---|
Page[UserRoleAssignmentResponseModel] |
List of all role assignments. |
Source code in zenml/zen_server/routers/role_assignments_endpoints.py
@router.get(
    "",
    response_model=Page[UserRoleAssignmentResponseModel],
    responses={
        401: error_response,
        404: error_response,
        422: error_response,
    },
)
@handle_exceptions
def list_user_role_assignments(
    user_role_assignment_filter_model: UserRoleAssignmentFilterModel = Depends(
        make_dependable(UserRoleAssignmentFilterModel)
    ),
    _: AuthContext = Security(
        authorize,
        scopes=[PermissionType.READ],
    ),
) -> Page[UserRoleAssignmentResponseModel]:
    """List user role assignments matching the supplied query filters.

    Args:
        user_role_assignment_filter_model: Filter model used for
            pagination, sorting and filtering of user role assignments.

    Returns:
        A page of user role assignments.
    """
    store = zen_store()
    assignments_page = store.list_user_role_assignments(
        user_role_assignment_filter_model=user_role_assignment_filter_model
    )
    return assignments_page
roles_endpoints
Endpoint definitions for roles and role assignment.
create_role(role, _=Security(oauth2_authentication))
Creates a role.
noqa: DAR401
Parameters:
Name | Type | Description | Default |
---|---|---|---|
role |
RoleRequestModel |
Role to create. |
required |
Returns:
Type | Description |
---|---|
RoleResponseModel |
The created role. |
Source code in zenml/zen_server/routers/roles_endpoints.py
@router.post(
    "",
    response_model=RoleResponseModel,
    responses={
        401: error_response,
        409: error_response,
        422: error_response,
    },
)
@handle_exceptions
def create_role(
    role: RoleRequestModel,
    _: AuthContext = Security(
        authorize,
        scopes=[PermissionType.WRITE],
    ),
) -> RoleResponseModel:
    """Persist a new role.

    # noqa: DAR401

    Args:
        role: The role request to store.

    Returns:
        The role as created by the zen store.
    """
    store = zen_store()
    created_role = store.create_role(role=role)
    return created_role
delete_role(role_name_or_id, _=Security(oauth2_authentication))
Deletes a specific role.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
role_name_or_id |
Union[str, uuid.UUID] |
Name or ID of the role. |
required |
Source code in zenml/zen_server/routers/roles_endpoints.py
@router.delete(
    "/{role_name_or_id}",
    responses={
        401: error_response,
        404: error_response,
        422: error_response,
    },
)
@handle_exceptions
def delete_role(
    role_name_or_id: Union[str, UUID],
    _: AuthContext = Security(
        authorize,
        scopes=[PermissionType.WRITE],
    ),
) -> None:
    """Remove a role, addressed either by its name or by its ID.

    Args:
        role_name_or_id: Name or unique ID of the role to remove.
    """
    store = zen_store()
    store.delete_role(role_name_or_id=role_name_or_id)
get_role(role_name_or_id, _=Security(oauth2_authentication))
Returns a specific role.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
role_name_or_id |
Union[str, uuid.UUID] |
Name or ID of the role. |
required |
Returns:
Type | Description |
---|---|
RoleResponseModel |
A specific role. |
Source code in zenml/zen_server/routers/roles_endpoints.py
@router.get(
    "/{role_name_or_id}",
    response_model=RoleResponseModel,
    responses={
        401: error_response,
        404: error_response,
        422: error_response,
    },
)
@handle_exceptions
def get_role(
    role_name_or_id: Union[str, UUID],
    _: AuthContext = Security(
        authorize,
        scopes=[PermissionType.READ],
    ),
) -> RoleResponseModel:
    """Fetch a single role, addressed either by its name or by its ID.

    Args:
        role_name_or_id: Name or unique ID of the role to look up.

    Returns:
        The matching role as returned by the zen store.
    """
    store = zen_store()
    found_role = store.get_role(role_name_or_id=role_name_or_id)
    return found_role
list_roles(role_filter_model=Depends(init_cls_and_handle_errors), _=Security(oauth2_authentication))
Returns a list of all roles.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
role_filter_model |
RoleFilterModel |
Filter model used for pagination, sorting, filtering |
Depends(init_cls_and_handle_errors) |
Returns:
Type | Description |
---|---|
Page[RoleResponseModel] |
List of all roles. |
Source code in zenml/zen_server/routers/roles_endpoints.py
@router.get(
    "",
    response_model=Page[RoleResponseModel],
    responses={
        401: error_response,
        404: error_response,
        422: error_response,
    },
)
@handle_exceptions
def list_roles(
    role_filter_model: RoleFilterModel = Depends(
        make_dependable(RoleFilterModel)
    ),
    _: AuthContext = Security(
        authorize,
        scopes=[PermissionType.READ],
    ),
) -> Page[RoleResponseModel]:
    """List roles matching the supplied query filters.

    Args:
        role_filter_model: Filter model used for pagination, sorting and
            filtering.

    Returns:
        A page of roles.
    """
    store = zen_store()
    roles_page = store.list_roles(role_filter_model=role_filter_model)
    return roles_page
update_role(role_id, role_update, _=Security(oauth2_authentication))
Updates a role.
noqa: DAR401
Parameters:
Name | Type | Description | Default |
---|---|---|---|
role_id |
UUID |
The ID of the role. |
required |
role_update |
RoleUpdateModel |
Role update. |
required |
Returns:
Type | Description |
---|---|
RoleResponseModel |
The updated role. |
Source code in zenml/zen_server/routers/roles_endpoints.py
@router.put(
    "/{role_id}",
    response_model=RoleResponseModel,
    responses={401: error_response, 409: error_response, 422: error_response},
)
@handle_exceptions
def update_role(
    role_id: UUID,
    role_update: RoleUpdateModel,
    _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]),
) -> RoleResponseModel:
    """Updates a role.

    # noqa: DAR401

    Args:
        role_id: The ID of the role.
        role_update: Role update.

    Returns:
        The updated role.
    """
    return zen_store().update_role(role_id=role_id, role_update=role_update)
run_metadata_endpoints
Endpoint definitions for run metadata.
list_run_metadata(run_metadata_filter_model=Depends(init_cls_and_handle_errors), _=Security(oauth2_authentication))
Get run metadata according to query filters.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
run_metadata_filter_model |
RunMetadataFilterModel |
Filter model used for pagination, sorting, filtering. |
Depends(init_cls_and_handle_errors) |
Returns:
Type | Description |
---|---|
Page[RunMetadataResponseModel] |
The run metadata according to query filters. |
Source code in zenml/zen_server/routers/run_metadata_endpoints.py
@router.get(
    "",
    response_model=Page[RunMetadataResponseModel],
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def list_run_metadata(
    run_metadata_filter_model: RunMetadataFilterModel = Depends(
        make_dependable(RunMetadataFilterModel)
    ),
    _: AuthContext = Security(authorize, scopes=[PermissionType.READ]),
) -> Page[RunMetadataResponseModel]:
    """Get run metadata according to query filters.

    Args:
        run_metadata_filter_model: Filter model used for pagination, sorting,
            filtering.

    Returns:
        A page of run metadata entries matching the query filters.
    """
    return zen_store().list_run_metadata(run_metadata_filter_model)
runs_endpoints
Endpoint definitions for pipeline runs.
delete_run(run_id, _=Security(oauth2_authentication))
Deletes a run.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
run_id |
UUID |
ID of the run. |
required |
Source code in zenml/zen_server/routers/runs_endpoints.py
@router.delete(
    "/{run_id}",
    responses={
        401: error_response,
        404: error_response,
        422: error_response,
    },
)
@handle_exceptions
def delete_run(
    run_id: UUID,
    _: AuthContext = Security(
        authorize,
        scopes=[PermissionType.WRITE],
    ),
) -> None:
    """Remove the pipeline run identified by `run_id`.

    Args:
        run_id: Unique ID of the pipeline run that should be removed.
    """
    store = zen_store()
    store.delete_run(run_id=run_id)
get_pipeline_configuration(run_id, _=Security(oauth2_authentication))
Get the pipeline configuration of a specific pipeline run using its ID.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
run_id |
UUID |
ID of the pipeline run to get. |
required |
Returns:
Type | Description |
---|---|
Dict[str, Any] |
The pipeline configuration of the pipeline run. |
Source code in zenml/zen_server/routers/runs_endpoints.py
@router.get(
    "/{run_id}" + PIPELINE_CONFIGURATION,
    response_model=Dict[str, Any],
    responses={
        401: error_response,
        404: error_response,
        422: error_response,
    },
)
@handle_exceptions
def get_pipeline_configuration(
    run_id: UUID,
    _: AuthContext = Security(
        authorize,
        scopes=[PermissionType.READ],
    ),
) -> Dict[str, Any]:
    """Return the configuration of a pipeline run as a plain dictionary.

    Args:
        run_id: Unique ID of the pipeline run to inspect.

    Returns:
        The run's `config`, converted with `.dict()`.
    """
    pipeline_run = zen_store().get_run(run_name_or_id=run_id)
    run_config = pipeline_run.config
    return run_config.dict()
get_run(run_id, _=Security(oauth2_authentication))
Get a specific pipeline run using its ID.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
run_id |
UUID |
ID of the pipeline run to get. |
required |
Returns:
Type | Description |
---|---|
PipelineRunResponseModel |
The pipeline run. |
Source code in zenml/zen_server/routers/runs_endpoints.py
@router.get(
    "/{run_id}",
    response_model=PipelineRunResponseModel,
    responses={
        401: error_response,
        404: error_response,
        422: error_response,
    },
)
@handle_exceptions
def get_run(
    run_id: UUID,
    _: AuthContext = Security(
        authorize,
        scopes=[PermissionType.READ],
    ),
) -> PipelineRunResponseModel:
    """Fetch a single pipeline run by its unique ID.

    Args:
        run_id: Unique ID of the pipeline run to look up.

    Returns:
        The matching pipeline run as returned by the zen store.
    """
    store = zen_store()
    pipeline_run = store.get_run(run_name_or_id=run_id)
    return pipeline_run
get_run_dag(run_id, _=Security(oauth2_authentication))
Get the DAG for a given pipeline run.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
run_id |
UUID |
ID of the pipeline run to use to get the DAG. |
required |
Returns:
Type | Description |
---|---|
LineageGraph |
The DAG for a given pipeline run. |
Source code in zenml/zen_server/routers/runs_endpoints.py
@router.get(
    "/{run_id}" + GRAPH,
    response_model=LineageGraph,
    responses={
        401: error_response,
        404: error_response,
        422: error_response,
    },
)
@handle_exceptions
def get_run_dag(
    run_id: UUID,
    _: AuthContext = Security(
        authorize,
        scopes=[PermissionType.READ],
    ),
) -> LineageGraph:
    """Build the lineage DAG for a pipeline run.

    Args:
        run_id: Unique ID of the pipeline run whose DAG is requested.

    Returns:
        A `LineageGraph` populated with the run's nodes and edges.
    """
    pipeline_run = zen_store().get_run(run_name_or_id=run_id)
    lineage = LineageGraph()
    # The graph is filled in place from the fetched run.
    lineage.generate_run_nodes_and_edges(pipeline_run)
    return lineage
get_run_status(run_id, _=Security(oauth2_authentication))
Get the status of a specific pipeline run.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
run_id |
UUID |
ID of the pipeline run for which to get the status. |
required |
Returns:
Type | Description |
---|---|
ExecutionStatus |
The status of the pipeline run. |
Source code in zenml/zen_server/routers/runs_endpoints.py
@router.get(
    "/{run_id}" + STATUS,
    response_model=ExecutionStatus,
    responses={
        401: error_response,
        404: error_response,
        422: error_response,
    },
)
@handle_exceptions
def get_run_status(
    run_id: UUID,
    _: AuthContext = Security(
        authorize,
        scopes=[PermissionType.READ],
    ),
) -> ExecutionStatus:
    """Return only the execution status of a pipeline run.

    Args:
        run_id: Unique ID of the pipeline run whose status is requested.

    Returns:
        The `status` attribute of the pipeline run.
    """
    pipeline_run = zen_store().get_run(run_id)
    run_status = pipeline_run.status
    return run_status
get_run_steps(step_run_filter_model=Depends(init_cls_and_handle_errors), _=Security(oauth2_authentication))
Get all steps for a given pipeline run.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
step_run_filter_model |
StepRunFilterModel |
Filter model used for pagination, sorting, filtering |
Depends(init_cls_and_handle_errors) |
Returns:
Type | Description |
---|---|
Page[StepRunResponseModel] |
The steps for a given pipeline run. |
Source code in zenml/zen_server/routers/runs_endpoints.py
@router.get(
    "/{run_id}" + STEPS,
    response_model=Page[StepRunResponseModel],
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def get_run_steps(
    step_run_filter_model: StepRunFilterModel = Depends(
        make_dependable(StepRunFilterModel)
    ),
    _: AuthContext = Security(authorize, scopes=[PermissionType.READ]),
    run_id: Optional[UUID] = None,
) -> Page[StepRunResponseModel]:
    """Get all steps for a given pipeline run.

    Args:
        step_run_filter_model: Filter model used for pagination, sorting,
            filtering.
        run_id: ID of the pipeline run whose steps should be listed. It is
            bound from the `{run_id}` path segment of the route. The `None`
            default only keeps direct Python callers of this function
            working.

    Returns:
        The steps for a given pipeline run.
    """
    # Previously `run_id` was not declared, so the path value was silently
    # ignored and the query was not scoped to the requested run.
    if run_id is not None:
        # NOTE(review): assumes StepRunFilterModel exposes a
        # `pipeline_run_id` filter field. Confirm against the model.
        step_run_filter_model.pipeline_run_id = run_id
    return zen_store().list_run_steps(step_run_filter_model)
list_runs(runs_filter_model=Depends(init_cls_and_handle_errors), _=Security(oauth2_authentication))
Get pipeline runs according to query filters.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
runs_filter_model |
PipelineRunFilterModel |
Filter model used for pagination, sorting, filtering |
Depends(init_cls_and_handle_errors) |
Returns:
Type | Description |
---|---|
Page[PipelineRunResponseModel] |
The pipeline runs according to query filters. |
Source code in zenml/zen_server/routers/runs_endpoints.py
@router.get(
    "",
    response_model=Page[PipelineRunResponseModel],
    responses={
        401: error_response,
        404: error_response,
        422: error_response,
    },
)
@handle_exceptions
def list_runs(
    runs_filter_model: PipelineRunFilterModel = Depends(
        make_dependable(PipelineRunFilterModel)
    ),
    _: AuthContext = Security(
        authorize,
        scopes=[PermissionType.READ],
    ),
) -> Page[PipelineRunResponseModel]:
    """List pipeline runs matching the supplied query filters.

    Args:
        runs_filter_model: Filter model used for pagination, sorting and
            filtering.

    Returns:
        A page of pipeline runs.
    """
    store = zen_store()
    runs_page = store.list_runs(runs_filter_model=runs_filter_model)
    return runs_page
update_run(run_id, run_model, _=Security(oauth2_authentication))
Updates a run.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
run_id |
UUID |
ID of the run. |
required |
run_model |
PipelineRunUpdateModel |
Run model to use for the update. |
required |
Returns:
Type | Description |
---|---|
PipelineRunResponseModel |
The updated run model. |
Source code in zenml/zen_server/routers/runs_endpoints.py
@router.put(
    "/{run_id}",
    response_model=PipelineRunResponseModel,
    responses={
        401: error_response,
        404: error_response,
        422: error_response,
    },
)
@handle_exceptions
def update_run(
    run_id: UUID,
    run_model: PipelineRunUpdateModel,
    _: AuthContext = Security(
        authorize,
        scopes=[PermissionType.WRITE],
    ),
) -> PipelineRunResponseModel:
    """Apply an update to an existing pipeline run.

    Args:
        run_id: Unique ID of the pipeline run to modify.
        run_model: Update model forwarded to the store as `run_update`.

    Returns:
        The pipeline run after the update.
    """
    store = zen_store()
    updated_run = store.update_run(run_id=run_id, run_update=run_model)
    return updated_run
schedule_endpoints
Endpoint definitions for pipeline run schedules.
delete_schedule(schedule_id, _=Security(oauth2_authentication))
Deletes a specific schedule using its unique id.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
schedule_id |
UUID |
ID of the schedule to delete. |
required |
Source code in zenml/zen_server/routers/schedule_endpoints.py
@router.delete(
    "/{schedule_id}",
    responses={
        401: error_response,
        404: error_response,
        422: error_response,
    },
)
@handle_exceptions
def delete_schedule(
    schedule_id: UUID,
    _: AuthContext = Security(
        authorize,
        scopes=[PermissionType.WRITE],
    ),
) -> None:
    """Remove the schedule identified by `schedule_id`.

    Args:
        schedule_id: Unique ID of the schedule that should be removed.
    """
    store = zen_store()
    store.delete_schedule(schedule_id=schedule_id)
get_schedule(schedule_id, _=Security(oauth2_authentication))
Gets a specific schedule using its unique id.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
schedule_id |
UUID |
ID of the schedule to get. |
required |
Returns:
Type | Description |
---|---|
ScheduleResponseModel |
A specific schedule object. |
Source code in zenml/zen_server/routers/schedule_endpoints.py
@router.get(
    "/{schedule_id}",
    response_model=ScheduleResponseModel,
    responses={
        401: error_response,
        404: error_response,
        422: error_response,
    },
)
@handle_exceptions
def get_schedule(
    schedule_id: UUID,
    _: AuthContext = Security(
        authorize,
        scopes=[PermissionType.READ],
    ),
) -> ScheduleResponseModel:
    """Fetch a single schedule by its unique ID.

    Args:
        schedule_id: Unique ID of the schedule to look up.

    Returns:
        The matching schedule as returned by the zen store.
    """
    store = zen_store()
    schedule = store.get_schedule(schedule_id=schedule_id)
    return schedule
list_schedules(schedule_filter_model=Depends(init_cls_and_handle_errors), _=Security(oauth2_authentication))
Gets a list of schedules.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
schedule_filter_model |
ScheduleFilterModel |
Filter model used for pagination, sorting, filtering |
Depends(init_cls_and_handle_errors) |
Returns:
Type | Description |
---|---|
Page[ScheduleResponseModel] |
List of schedule objects. |
Source code in zenml/zen_server/routers/schedule_endpoints.py
@router.get(
    "",
    response_model=Page[ScheduleResponseModel],
    responses={
        401: error_response,
        404: error_response,
        422: error_response,
    },
)
@handle_exceptions
def list_schedules(
    schedule_filter_model: ScheduleFilterModel = Depends(
        make_dependable(ScheduleFilterModel)
    ),
    _: AuthContext = Security(
        authorize,
        scopes=[PermissionType.READ],
    ),
) -> Page[ScheduleResponseModel]:
    """List schedules matching the supplied query filters.

    Args:
        schedule_filter_model: Filter model used for pagination, sorting
            and filtering.

    Returns:
        A page of schedule objects.
    """
    store = zen_store()
    schedules_page = store.list_schedules(
        schedule_filter_model=schedule_filter_model
    )
    return schedules_page
update_schedule(schedule_id, schedule_update, _=Security(oauth2_authentication))
Updates the attribute on a specific schedule using its unique id.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
schedule_id |
UUID |
ID of the schedule to update. |
required |
schedule_update |
ScheduleUpdateModel |
the model containing the attributes to update. |
required |
Returns:
Type | Description |
---|---|
ScheduleResponseModel |
The updated schedule object. |
Source code in zenml/zen_server/routers/schedule_endpoints.py
@router.put(
    "/{schedule_id}",
    response_model=ScheduleResponseModel,
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def update_schedule(
    schedule_id: UUID,
    schedule_update: ScheduleUpdateModel,
    _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]),
) -> ScheduleResponseModel:
    """Updates the attributes of a specific schedule using its unique id.

    Args:
        schedule_id: ID of the schedule to update.
        schedule_update: The model containing the attributes to update.

    Returns:
        The updated schedule object.
    """
    return zen_store().update_schedule(
        schedule_id=schedule_id, schedule_update=schedule_update
    )
secrets_endpoints
Endpoint definitions for secrets.
delete_secret(secret_id, _=Security(oauth2_authentication))
Deletes a specific secret using its unique id.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
secret_id |
UUID |
ID of the secret to delete. |
required |
Source code in zenml/zen_server/routers/secrets_endpoints.py
@router.delete(
    "/{secret_id}",
    responses={
        401: error_response,
        404: error_response,
        422: error_response,
    },
)
@handle_exceptions
def delete_secret(
    secret_id: UUID,
    _: AuthContext = Security(
        authorize,
        scopes=[PermissionType.WRITE],
    ),
) -> None:
    """Remove the secret identified by `secret_id`.

    Args:
        secret_id: Unique ID of the secret that should be removed.
    """
    store = zen_store()
    store.delete_secret(secret_id=secret_id)
get_secret(secret_id, auth_context=Security(oauth2_authentication))
Gets a specific secret using its unique id.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
secret_id |
UUID |
ID of the secret to get. |
required |
auth_context |
AuthContext |
Authentication context. |
Security(oauth2_authentication) |
Returns:
Type | Description |
---|---|
SecretResponseModel |
A specific secret object. |
Source code in zenml/zen_server/routers/secrets_endpoints.py
@router.get(
    "/{secret_id}",
    response_model=SecretResponseModel,
    responses={
        401: error_response,
        404: error_response,
        422: error_response,
    },
)
@handle_exceptions
def get_secret(
    secret_id: UUID,
    auth_context: AuthContext = Security(
        authorize, scopes=[PermissionType.READ]
    ),
) -> SecretResponseModel:
    """Fetch a single secret by its unique ID.

    Callers without the `WRITE` permission receive the secret with its
    values stripped via `remove_secrets()`.

    Args:
        secret_id: Unique ID of the secret to look up.
        auth_context: Authentication context of the calling user.

    Returns:
        The secret, with values removed for callers lacking `WRITE`.
    """
    secret = zen_store().get_secret(secret_id=secret_id)
    can_write = PermissionType.WRITE in auth_context.permissions
    if not can_write:
        secret.remove_secrets()
    return secret
list_secrets(secret_filter_model=Depends(init_cls_and_handle_errors), auth_context=Security(oauth2_authentication))
Gets a list of secrets.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
secret_filter_model |
SecretFilterModel |
Filter model used for pagination, sorting, filtering |
Depends(init_cls_and_handle_errors) |
auth_context |
AuthContext |
Authentication context. |
Security(oauth2_authentication) |
Returns:
Type | Description |
---|---|
Page[SecretResponseModel] |
List of secret objects. |
Source code in zenml/zen_server/routers/secrets_endpoints.py
@router.get(
    "",
    response_model=Page[SecretResponseModel],
    responses={
        401: error_response,
        404: error_response,
        422: error_response,
    },
)
@handle_exceptions
def list_secrets(
    secret_filter_model: SecretFilterModel = Depends(
        make_dependable(SecretFilterModel)
    ),
    auth_context: AuthContext = Security(
        authorize, scopes=[PermissionType.READ]
    ),
) -> Page[SecretResponseModel]:
    """List secrets matching the supplied query filters.

    Callers without the `WRITE` permission receive every secret on the
    page with its values stripped via `remove_secrets()`.

    Args:
        secret_filter_model: Filter model used for pagination, sorting
            and filtering.
        auth_context: Authentication context of the calling user.

    Returns:
        A page of secrets, with values removed for callers lacking `WRITE`.
    """
    page = zen_store().list_secrets(secret_filter_model=secret_filter_model)
    if PermissionType.WRITE in auth_context.permissions:
        return page
    for item in page.items:
        item.remove_secrets()
    return page
update_secret(secret_id, secret_update, patch_values=False, _=Security(oauth2_authentication))
Updates the attribute on a specific secret using its unique id.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
secret_id |
UUID |
ID of the secret to update. |
required |
secret_update |
SecretUpdateModel |
the model containing the attributes to update. |
required |
patch_values |
Optional[bool] |
Whether to patch the secret values or replace them. |
False |
Returns:
Type | Description |
---|---|
SecretResponseModel |
The updated secret object. |
Source code in zenml/zen_server/routers/secrets_endpoints.py
@router.put(
    "/{secret_id}",
    response_model=SecretResponseModel,
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def update_secret(
    secret_id: UUID,
    secret_update: SecretUpdateModel,
    patch_values: Optional[bool] = False,
    _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]),
) -> SecretResponseModel:
    """Updates the attributes of a specific secret using its unique id.

    Args:
        secret_id: ID of the secret to update.
        secret_update: The model containing the attributes to update.
        patch_values: Whether to patch the secret values or replace them.
            When falsy, every existing key that is missing from
            `secret_update.values` is set to `None` so that it is deleted.

    Returns:
        The updated secret object.
    """
    if not patch_values:
        # Replacement mode: treat the update values as the complete new set
        # of values. Keys that exist today but are absent from the update
        # are explicitly set to None so that they are deleted.
        # NOTE(review): if `secret_update.values` can be None, the membership
        # test below raises TypeError. Confirm the model's default.
        existing_secret = zen_store().get_secret(secret_id=secret_id)
        for key in existing_secret.values:
            if key not in secret_update.values:
                secret_update.values[key] = None
    return zen_store().update_secret(
        secret_id=secret_id, secret_update=secret_update
    )
server_endpoints
Endpoint definitions for server information (server info and version).
server_info()
Get information about the server.
Returns:
Type | Description |
---|---|
ServerModel |
Information about the server. |
Source code in zenml/zen_server/routers/server_endpoints.py
@router.get(
    INFO,
    response_model=ServerModel,
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def server_info() -> ServerModel:
    """Get information about the server.

    The information is obtained from the zen store via `get_store_info()`.

    Returns:
        Information about the server.
    """
    store = zen_store()
    return store.get_store_info()
version()
Get version of the server.
Returns:
Type | Description |
---|---|
str |
String representing the version of the server. |
Source code in zenml/zen_server/routers/server_endpoints.py
@router.get("/version")
def version() -> str:
    """Get version of the server.

    Returns:
        String representing the version of the server, taken from
        `zenml.__version__`.
    """
    server_version: str = zenml.__version__
    return server_version
service_connectors_endpoints
Endpoint definitions for service connectors.
delete_service_connector(connector_id, auth_context=Security(oauth2_authentication))
Deletes a service connector.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
connector_id |
UUID |
ID of the service connector. |
required |
auth_context |
AuthContext |
Authentication context. |
Security(oauth2_authentication) |
Exceptions:
Type | Description |
---|---|
KeyError |
If the service connector does not exist or is not accessible. |
Source code in zenml/zen_server/routers/service_connectors_endpoints.py
@router.delete(
    "/{connector_id}",
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def delete_service_connector(
    connector_id: UUID,
    auth_context: AuthContext = Security(
        authorize, scopes=[PermissionType.WRITE]
    ),
) -> None:
    """Deletes a service connector.

    The connector can only be deleted by its owner, or by anyone if it is
    shared. Otherwise it is reported as not found.

    Args:
        connector_id: ID of the service connector.
        auth_context: Authentication context.

    Raises:
        KeyError: If the service connector does not exist or is not accessible.
    """
    connector = zen_store().get_service_connector(connector_id)

    # Don't allow users to access service connectors that don't belong to
    # them unless they are shared.
    is_owner = bool(connector.user) and (
        connector.user.id == auth_context.user.id
    )
    if not (is_owner or connector.is_shared):
        raise KeyError(f"Service connector with ID {connector_id} not found.")

    zen_store().delete_service_connector(connector_id)
get_service_connector(connector_id, expand_secrets=True, auth_context=Security(oauth2_authentication))
Returns the requested service connector.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
connector_id |
UUID |
ID of the service connector. |
required |
expand_secrets |
bool |
Whether to expand secrets or not. |
True |
auth_context |
AuthContext |
Authentication context. |
Security(oauth2_authentication) |
Returns:
Type | Description |
---|---|
ServiceConnectorResponseModel |
The requested service connector. |
Exceptions:
Type | Description |
---|---|
KeyError |
If the service connector does not exist or is not accessible. |
Source code in zenml/zen_server/routers/service_connectors_endpoints.py
@router.get(
    "/{connector_id}",
    response_model=ServiceConnectorResponseModel,
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def get_service_connector(
    connector_id: UUID,
    expand_secrets: bool = True,
    auth_context: AuthContext = Security(
        authorize, scopes=[PermissionType.READ]
    ),
) -> ServiceConnectorResponseModel:
    """Returns the requested service connector.

    Connectors owned by another user are hidden unless they are shared.
    Secret values are merged into the connector configuration only when
    `expand_secrets` is set, the connector references a secret, and the
    caller holds write permission.

    Args:
        connector_id: ID of the service connector.
        expand_secrets: Whether to expand secrets or not.
        auth_context: Authentication context.

    Returns:
        The requested service connector.

    Raises:
        KeyError: If the service connector does not exist or is not accessible.
    """
    store = zen_store()
    connector = store.get_service_connector(connector_id)

    # Don't allow users to access service connectors that don't belong to
    # them unless they are shared.
    is_owner = bool(connector.user) and (
        connector.user.id == auth_context.user.id
    )
    if not (is_owner or connector.is_shared):
        raise KeyError(f"Service connector with ID {connector_id} not found.")

    # Callers without write permission never get secret values expanded.
    if PermissionType.WRITE not in auth_context.permissions:
        return connector

    if expand_secrets and connector.secret_id:
        secret = store.get_secret(secret_id=connector.secret_id)
        # Update the connector configuration with the secret.
        connector.configuration.update(secret.secret_values)
    return connector
get_service_connector_client(connector_id, resource_type=None, resource_id=None, auth_context=Security(oauth2_authentication))
Get a service connector client for a service connector and given resource.
This requires the service connector implementation to be installed on the ZenML server, otherwise a 501 Not Implemented error will be returned.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
connector_id |
UUID |
ID of the service connector. |
required |
resource_type |
Optional[str] |
Type of the resource to list. |
None |
resource_id |
Optional[str] |
ID of the resource to list. |
None |
auth_context |
AuthContext |
Authentication context. |
Security(oauth2_authentication) |
Returns:
Type | Description |
---|---|
ServiceConnectorResponseModel |
A service connector client that can be used to access the given resource. |
Exceptions:
Type | Description |
---|---|
KeyError |
If the service connector does not exist or is not accessible. |
Source code in zenml/zen_server/routers/service_connectors_endpoints.py
@router.get(
    "/{connector_id}" + SERVICE_CONNECTOR_CLIENT,
    response_model=ServiceConnectorResponseModel,
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def get_service_connector_client(
    connector_id: UUID,
    resource_type: Optional[str] = None,
    resource_id: Optional[str] = None,
    auth_context: AuthContext = Security(
        authorize, scopes=[PermissionType.WRITE]
    ),
) -> ServiceConnectorResponseModel:
    """Get a service connector client for a service connector and given resource.

    This requires the service connector implementation to be installed
    on the ZenML server, otherwise a 501 Not Implemented error will be
    returned.

    Args:
        connector_id: ID of the service connector.
        resource_type: Optional type of the resource the client is
            requested for.
        resource_id: Optional ID of the resource the client is requested
            for.
        auth_context: Authentication context.

    Returns:
        A service connector client that can be used to access the given
        resource.

    Raises:
        KeyError: If the service connector does not exist or is not accessible.
    """
    store = zen_store()
    connector = store.get_service_connector(connector_id)

    # Don't allow users to access service connectors that don't belong to
    # them unless they are shared.
    is_owner = bool(connector.user) and (
        connector.user.id == auth_context.user.id
    )
    if not (is_owner or connector.is_shared):
        raise KeyError(f"Service connector with ID {connector_id} not found.")

    return store.get_service_connector_client(
        service_connector_id=connector_id,
        resource_type=resource_type,
        resource_id=resource_id,
    )
get_service_connector_type(connector_type, _=Security(oauth2_authentication))
Returns the requested service connector type.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
connector_type |
str |
the service connector type identifier. |
required |
Returns:
Type | Description |
---|---|
ServiceConnectorTypeModel |
The requested service connector type. |
Source code in zenml/zen_server/routers/service_connectors_endpoints.py
@types_router.get(
    "/{connector_type}",
    response_model=ServiceConnectorTypeModel,
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def get_service_connector_type(
    connector_type: str,
    _: AuthContext = Security(authorize, scopes=[PermissionType.READ]),
) -> ServiceConnectorTypeModel:
    """Returns the requested service connector type.

    Args:
        connector_type: the service connector type identifier.

    Returns:
        The requested service connector type.
    """
    return zen_store().get_service_connector_type(connector_type)
list_service_connector_types(connector_type=None, resource_type=None, auth_method=None, _=Security(oauth2_authentication))
Get a list of service connector types.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
connector_type |
Optional[str] |
Filter by connector type. |
None |
resource_type |
Optional[str] |
Filter by resource type. |
None |
auth_method |
Optional[str] |
Filter by auth method. |
None |
Returns:
Type | Description |
---|---|
List[zenml.models.service_connector_models.ServiceConnectorTypeModel] |
List of service connector types. |
Source code in zenml/zen_server/routers/service_connectors_endpoints.py
@types_router.get(
    "",
    response_model=List[ServiceConnectorTypeModel],
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def list_service_connector_types(
    connector_type: Optional[str] = None,
    resource_type: Optional[str] = None,
    auth_method: Optional[str] = None,
    _: AuthContext = Security(authorize, scopes=[PermissionType.READ]),
) -> List[ServiceConnectorTypeModel]:
    """Get a list of service connector types.

    All filters are optional and are forwarded unchanged to the zen store.

    Args:
        connector_type: Filter by connector type.
        resource_type: Filter by resource type.
        auth_method: Filter by auth method.

    Returns:
        List of service connector types.
    """
    store = zen_store()
    return store.list_service_connector_types(
        connector_type=connector_type,
        resource_type=resource_type,
        auth_method=auth_method,
    )
list_service_connectors(connector_filter_model=Depends(init_cls_and_handle_errors), expand_secrets=True, auth_context=Security(oauth2_authentication))
Get a list of all service connectors for a specific type.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
connector_filter_model |
ServiceConnectorFilterModel |
Filter model used for pagination, sorting, filtering |
Depends(init_cls_and_handle_errors) |
expand_secrets |
bool |
Whether to expand secrets or not. |
True |
auth_context |
AuthContext |
Authentication Context |
Security(oauth2_authentication) |
Returns:
Type | Description |
---|---|
Page[ServiceConnectorResponseModel] |
Page with list of service connectors for a specific type. |
Source code in zenml/zen_server/routers/service_connectors_endpoints.py
@router.get(
    "",
    response_model=Page[ServiceConnectorResponseModel],
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def list_service_connectors(
    connector_filter_model: ServiceConnectorFilterModel = Depends(
        make_dependable(ServiceConnectorFilterModel)
    ),
    expand_secrets: bool = True,
    auth_context: AuthContext = Security(
        authorize, scopes=[PermissionType.READ]
    ),
) -> Page[ServiceConnectorResponseModel]:
    """Get a list of all service connectors for a specific type.

    The filter is scoped to the calling user before querying. Secret values
    are merged into each connector's configuration only when
    `expand_secrets` is set and the caller holds write permission.

    Args:
        connector_filter_model: Filter model used for pagination, sorting,
            filtering
        expand_secrets: Whether to expand secrets or not.
        auth_context: Authentication Context

    Returns:
        Page with list of service connectors for a specific type.
    """
    connector_filter_model.set_scope_user(user_id=auth_context.user.id)
    store = zen_store()
    page = store.list_service_connectors(filter_model=connector_filter_model)

    if (
        not expand_secrets
        or PermissionType.WRITE not in auth_context.permissions
    ):
        return page

    for item in page.items:
        if item.secret_id:
            secret = store.get_secret(secret_id=item.secret_id)
            # Update the connector configuration with the secret.
            item.configuration.update(secret.secret_values)
    return page
update_service_connector(connector_id, connector_update, auth_context=Security(oauth2_authentication))
Updates a service connector.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
connector_id |
UUID |
ID of the service connector. |
required |
connector_update |
ServiceConnectorUpdateModel |
Service connector to use to update. |
required |
auth_context |
AuthContext |
Authentication context. |
Security(oauth2_authentication) |
Returns:
Type | Description |
---|---|
ServiceConnectorResponseModel |
Updated service connector. |
Exceptions:
Type | Description |
---|---|
KeyError |
If the service connector does not exist or is not accessible. |
Source code in zenml/zen_server/routers/service_connectors_endpoints.py
@router.put(
    "/{connector_id}",
    response_model=ServiceConnectorResponseModel,
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def update_service_connector(
    connector_id: UUID,
    connector_update: ServiceConnectorUpdateModel,
    auth_context: AuthContext = Security(
        authorize, scopes=[PermissionType.WRITE]
    ),
) -> ServiceConnectorResponseModel:
    """Updates a service connector.

    Only the owner may update a connector, unless the connector is shared.

    Args:
        connector_id: ID of the service connector.
        connector_update: Service connector to use to update.
        auth_context: Authentication context.

    Returns:
        Updated service connector.

    Raises:
        KeyError: If the service connector does not exist or is not accessible.
    """
    store = zen_store()
    connector = store.get_service_connector(connector_id)

    # Don't allow users to access service connectors that don't belong to
    # them unless they are shared.
    is_owner = bool(connector.user) and (
        connector.user.id == auth_context.user.id
    )
    if not (is_owner or connector.is_shared):
        raise KeyError(f"Service connector with ID {connector_id} not found.")

    return store.update_service_connector(
        service_connector_id=connector_id,
        update=connector_update,
    )
validate_and_verify_service_connector(connector_id, resource_type=None, resource_id=None, list_resources=True, auth_context=Security(oauth2_authentication))
Verifies if a service connector instance has access to one or more resources.
This requires the service connector implementation to be installed on the ZenML server, otherwise a 501 Not Implemented error will be returned.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
connector_id |
UUID |
The ID of the service connector to verify. |
required |
resource_type |
Optional[str] |
The type of resource to verify access to. |
None |
resource_id |
Optional[str] |
The ID of the resource to verify access to. |
None |
list_resources |
bool |
If True, the list of all resources accessible through the service connector and matching the supplied resource type and ID are returned. |
True |
auth_context |
AuthContext |
Authentication context. |
Security(oauth2_authentication) |
Returns:
Type | Description |
---|---|
ServiceConnectorResourcesModel |
The list of resources that the service connector has access to, scoped to the supplied resource type and ID, if provided. |
Exceptions:
Type | Description |
---|---|
KeyError |
If the service connector does not exist or is not accessible. |
Source code in zenml/zen_server/routers/service_connectors_endpoints.py
@router.put(
    "/{connector_id}" + SERVICE_CONNECTOR_VERIFY,
    response_model=ServiceConnectorResourcesModel,
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def validate_and_verify_service_connector(
    connector_id: UUID,
    resource_type: Optional[str] = None,
    resource_id: Optional[str] = None,
    list_resources: bool = True,
    auth_context: AuthContext = Security(
        authorize, scopes=[PermissionType.READ]
    ),
) -> ServiceConnectorResourcesModel:
    """Verifies if a service connector instance has access to one or more resources.

    This requires the service connector implementation to be installed
    on the ZenML server, otherwise a 501 Not Implemented error will be
    returned.

    Args:
        connector_id: The ID of the service connector to verify.
        resource_type: The type of resource to verify access to.
        resource_id: The ID of the resource to verify access to.
        list_resources: If True, the list of all resources accessible
            through the service connector and matching the supplied resource
            type and ID are returned.
        auth_context: Authentication context.

    Returns:
        The list of resources that the service connector has access to, scoped
        to the supplied resource type and ID, if provided.

    Raises:
        KeyError: If the service connector does not exist or is not accessible.
    """
    store = zen_store()
    connector = store.get_service_connector(connector_id)

    # Don't allow users to access service connectors that don't belong to
    # them unless they are shared.
    is_owner = bool(connector.user) and (
        connector.user.id == auth_context.user.id
    )
    if not (is_owner or connector.is_shared):
        raise KeyError(f"Service connector with ID {connector_id} not found.")

    return store.verify_service_connector(
        service_connector_id=connector_id,
        resource_type=resource_type,
        resource_id=resource_id,
        list_resources=list_resources,
    )
validate_and_verify_service_connector_config(connector, list_resources=True, _=Security(oauth2_authentication))
Verifies if a service connector configuration has access to resources.
This requires the service connector implementation to be installed on the ZenML server, otherwise a 501 Not Implemented error will be returned.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
connector |
ServiceConnectorRequestModel |
The service connector configuration to verify. |
required |
list_resources |
bool |
If True, the list of all resources accessible through the service connector is returned. |
True |
Returns:
Type | Description |
---|---|
ServiceConnectorResourcesModel |
The list of resources that the service connector configuration has access to. |
Source code in zenml/zen_server/routers/service_connectors_endpoints.py
@router.post(
    SERVICE_CONNECTOR_VERIFY,
    response_model=ServiceConnectorResourcesModel,
    responses={401: error_response, 409: error_response, 422: error_response},
)
@handle_exceptions
def validate_and_verify_service_connector_config(
    connector: ServiceConnectorRequestModel,
    list_resources: bool = True,
    _: AuthContext = Security(authorize, scopes=[PermissionType.READ]),
) -> ServiceConnectorResourcesModel:
    """Verifies if a service connector configuration has access to resources.

    This requires the service connector implementation to be installed
    on the ZenML server, otherwise a 501 Not Implemented error will be
    returned.

    Args:
        connector: The service connector configuration to verify.
        list_resources: If True, the list of all resources accessible
            through the service connector is returned.

    Returns:
        The list of resources that the service connector configuration has
        access to.
    """
    store = zen_store()
    resources = store.verify_service_connector_config(
        service_connector=connector,
        list_resources=list_resources,
    )
    return resources
stack_components_endpoints
Endpoint definitions for stack components.
deregister_stack_component(component_id, _=Security(oauth2_authentication))
Deletes a stack component.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
component_id |
UUID |
ID of the stack component. |
required |
Source code in zenml/zen_server/routers/stack_components_endpoints.py
@router.delete(
    "/{component_id}",
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def deregister_stack_component(
    component_id: UUID,
    _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]),
) -> None:
    """Deletes a stack component.

    Args:
        component_id: ID of the stack component.
    """
    store = zen_store()
    store.delete_stack_component(component_id)
get_stack_component(component_id, _=Security(oauth2_authentication))
Returns the requested stack component.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
component_id |
UUID |
ID of the stack component. |
required |
Returns:
Type | Description |
---|---|
ComponentResponseModel |
The requested stack component. |
Source code in zenml/zen_server/routers/stack_components_endpoints.py
@router.get(
    "/{component_id}",
    response_model=ComponentResponseModel,
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def get_stack_component(
    component_id: UUID,
    _: AuthContext = Security(authorize, scopes=[PermissionType.READ]),
) -> ComponentResponseModel:
    """Returns the requested stack component.

    Args:
        component_id: ID of the stack component.

    Returns:
        The requested stack component.
    """
    store = zen_store()
    component = store.get_stack_component(component_id)
    return component
get_stack_component_types(_=Security(oauth2_authentication))
Get a list of all stack component types.
Returns:
Type | Description |
---|---|
List[str] |
List of stack component types. |
Source code in zenml/zen_server/routers/stack_components_endpoints.py
@types_router.get(
    "",
    response_model=List[str],
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def get_stack_component_types(
    _: AuthContext = Security(authorize, scopes=[PermissionType.READ])
) -> List[str]:
    """Get a list of all stack component types.

    Returns:
        List of stack component types, as given by
        `StackComponentType.values()`.
    """
    return StackComponentType.values()
list_stack_components(component_filter_model=Depends(init_cls_and_handle_errors), auth_context=Security(oauth2_authentication))
Get a list of all stack components for a specific type.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
component_filter_model |
ComponentFilterModel |
Filter model used for pagination, sorting, filtering |
Depends(init_cls_and_handle_errors) |
auth_context |
AuthContext |
Authentication Context |
Security(oauth2_authentication) |
Returns:
Type | Description |
---|---|
Page[ComponentResponseModel] |
List of stack components for a specific type. |
Source code in zenml/zen_server/routers/stack_components_endpoints.py
@router.get(
    "",
    response_model=Page[ComponentResponseModel],
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def list_stack_components(
    component_filter_model: ComponentFilterModel = Depends(
        make_dependable(ComponentFilterModel)
    ),
    auth_context: AuthContext = Security(
        authorize, scopes=[PermissionType.READ]
    ),
) -> Page[ComponentResponseModel]:
    """Get a list of all stack components for a specific type.

    The filter is scoped to the calling user before the store is queried.

    Args:
        component_filter_model: Filter model used for pagination, sorting,
            filtering
        auth_context: Authentication Context

    Returns:
        List of stack components for a specific type.
    """
    component_filter_model.set_scope_user(user_id=auth_context.user.id)
    store = zen_store()
    page = store.list_stack_components(
        component_filter_model=component_filter_model
    )
    return page
update_stack_component(component_id, component_update, _=Security(oauth2_authentication))
Updates a stack component.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
component_id |
UUID |
ID of the stack component. |
required |
component_update |
ComponentUpdateModel |
Stack component to use to update. |
required |
Returns:
Type | Description |
---|---|
ComponentResponseModel |
Updated stack component. |
Source code in zenml/zen_server/routers/stack_components_endpoints.py
@router.put(
    "/{component_id}",
    response_model=ComponentResponseModel,
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def update_stack_component(
    component_id: UUID,
    component_update: ComponentUpdateModel,
    _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]),
) -> ComponentResponseModel:
    """Updates a stack component.

    Args:
        component_id: ID of the stack component.
        component_update: Stack component to use to update.

    Returns:
        Updated stack component.
    """
    store = zen_store()
    updated = store.update_stack_component(
        component_id=component_id,
        component_update=component_update,
    )
    return updated
stacks_endpoints
Endpoint definitions for stacks.
delete_stack(stack_id, _=Security(oauth2_authentication))
Deletes a stack.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
stack_id |
UUID |
ID of the stack. |
required |
Source code in zenml/zen_server/routers/stacks_endpoints.py
@router.delete(
    "/{stack_id}",
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def delete_stack(
    stack_id: UUID,
    _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]),
) -> None:
    """Deletes a stack.

    Args:
        stack_id: ID of the stack.
    """
    zen_store().delete_stack(stack_id)
get_stack(stack_id, _=Security(oauth2_authentication))
Returns the requested stack.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
stack_id |
UUID |
ID of the stack. |
required |
Returns:
Type | Description |
---|---|
StackResponseModel |
The requested stack. |
Source code in zenml/zen_server/routers/stacks_endpoints.py
@router.get(
    "/{stack_id}",
    response_model=StackResponseModel,
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def get_stack(
    stack_id: UUID,
    _: AuthContext = Security(authorize, scopes=[PermissionType.READ]),
) -> StackResponseModel:
    """Returns the requested stack.

    Args:
        stack_id: ID of the stack.

    Returns:
        The requested stack.
    """
    store = zen_store()
    stack = store.get_stack(stack_id)
    return stack
list_stacks(stack_filter_model=Depends(init_cls_and_handle_errors), auth_context=Security(oauth2_authentication))
Returns all stacks.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
stack_filter_model |
StackFilterModel |
Filter model used for pagination, sorting, filtering |
Depends(init_cls_and_handle_errors) |
auth_context |
AuthContext |
Authentication Context |
Security(oauth2_authentication) |
Returns:
Type | Description |
---|---|
Page[StackResponseModel] |
All stacks. |
Source code in zenml/zen_server/routers/stacks_endpoints.py
@router.get(
    "",
    response_model=Page[StackResponseModel],
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def list_stacks(
    stack_filter_model: StackFilterModel = Depends(
        make_dependable(StackFilterModel)
    ),
    auth_context: AuthContext = Security(
        authorize, scopes=[PermissionType.READ]
    ),
) -> Page[StackResponseModel]:
    """Returns all stacks.

    The filter is scoped to the calling user before the store is queried.

    Args:
        stack_filter_model: Filter model used for pagination, sorting,
            filtering
        auth_context: Authentication Context

    Returns:
        All stacks.
    """
    stack_filter_model.set_scope_user(user_id=auth_context.user.id)
    store = zen_store()
    page = store.list_stacks(stack_filter_model=stack_filter_model)
    return page
update_stack(stack_id, stack_update, _=Security(oauth2_authentication))
Updates a stack.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
stack_id |
UUID |
ID of the stack. |
required |
stack_update |
StackUpdateModel |
Stack to use for the update. |
required |
Returns:
Type | Description |
---|---|
StackResponseModel |
The updated stack. |
Source code in zenml/zen_server/routers/stacks_endpoints.py
@router.put(
    "/{stack_id}",
    response_model=StackResponseModel,
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def update_stack(
    stack_id: UUID,
    stack_update: StackUpdateModel,
    _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]),
) -> StackResponseModel:
    """Updates a stack.

    Args:
        stack_id: ID of the stack.
        stack_update: Stack to use for the update.

    Returns:
        The updated stack.
    """
    return zen_store().update_stack(
        stack_id=stack_id,
        stack_update=stack_update,
    )
steps_endpoints
Endpoint definitions for steps (and artifacts) of pipeline runs.
create_run_step(step, _=Security(oauth2_authentication))
Create a run step.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
step |
StepRunRequestModel |
The run step to create. |
required |
Returns:
Type | Description |
---|---|
StepRunResponseModel |
The created run step. |
Source code in zenml/zen_server/routers/steps_endpoints.py
@router.post(
    "",
    response_model=StepRunResponseModel,
    responses={401: error_response, 409: error_response, 422: error_response},
)
@handle_exceptions
def create_run_step(
    step: StepRunRequestModel,
    _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]),
) -> StepRunResponseModel:
    """Create a run step.

    Args:
        step: The run step to create.

    Returns:
        The created run step.
    """
    store = zen_store()
    created_step = store.create_run_step(step_run=step)
    return created_step
get_step(step_id, _=Security(oauth2_authentication))
Get one specific step.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
step_id |
UUID |
ID of the step to get. |
required |
Returns:
Type | Description |
---|---|
StepRunResponseModel |
The step. |
Source code in zenml/zen_server/routers/steps_endpoints.py
@router.get(
    "/{step_id}",
    response_model=StepRunResponseModel,
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def get_step(
    step_id: UUID,
    _: AuthContext = Security(authorize, scopes=[PermissionType.READ]),
) -> StepRunResponseModel:
    """Get one specific step.

    Args:
        step_id: ID of the step to get.

    Returns:
        The step.
    """
    store = zen_store()
    step_run = store.get_run_step(step_id)
    return step_run
get_step_configuration(step_id, _=Security(oauth2_authentication))
Get the configuration of a specific step.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
step_id |
UUID |
ID of the step to get. |
required |
Returns:
Type | Description |
---|---|
Dict[str, Any] |
The step configuration. |
Source code in zenml/zen_server/routers/steps_endpoints.py
@router.get(
    "/{step_id}" + STEP_CONFIGURATION,
    response_model=Dict[str, Any],
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def get_step_configuration(
    step_id: UUID,
    _: AuthContext = Security(authorize, scopes=[PermissionType.READ]),
) -> Dict[str, Any]:
    """Get the configuration of a specific step.

    Args:
        step_id: ID of the step to get.

    Returns:
        The step configuration, serialized with `.dict()`.
    """
    step_run = zen_store().get_run_step(step_id)
    return step_run.config.dict()
get_step_logs(step_id, _=Security(oauth2_authentication))
Get the logs of a specific step.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
step_id |
UUID |
ID of the step for which to get the logs. |
required |
Returns:
Type | Description |
---|---|
str |
The logs of the step. |
Exceptions:
Type | Description |
---|---|
HTTPException |
If no logs are available for this step. |
Source code in zenml/zen_server/routers/steps_endpoints.py
@router.get(
    "/{step_id}" + LOGS,
    response_model=str,
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def get_step_logs(
    step_id: UUID,
    _: AuthContext = Security(authorize, scopes=[PermissionType.READ]),
) -> str:
    """Get the logs of a specific step.

    The logs are read, in text mode, from the artifact store referenced by
    the step's log entry.

    Args:
        step_id: ID of the step for which to get the logs.

    Returns:
        The logs of the step.

    Raises:
        HTTPException: If no logs are available for this step.
    """
    store = zen_store()
    step_logs = store.get_run_step(step_id).logs
    if step_logs is None:
        raise HTTPException(
            status_code=404, detail="No logs available for this step"
        )

    # Resolve the artifact store holding the log file, then read it.
    artifact_store = _load_artifact_store(step_logs.artifact_store_id, store)
    log_contents = _load_file_from_artifact_store(
        step_logs.uri, artifact_store=artifact_store, mode="r"
    )
    return str(log_contents)
get_step_status(step_id, _=Security(oauth2_authentication))
Get the status of a specific step.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
step_id |
UUID |
ID of the step for which to get the status. |
required |
Returns:
Type | Description |
---|---|
ExecutionStatus |
The status of the step. |
Source code in zenml/zen_server/routers/steps_endpoints.py
@router.get(
    "/{step_id}" + STATUS,
    response_model=ExecutionStatus,
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def get_step_status(
    step_id: UUID,
    _: AuthContext = Security(authorize, scopes=[PermissionType.READ]),
) -> ExecutionStatus:
    """Get the status of a specific step.

    Args:
        step_id: ID of the step for which to get the status.

    Returns:
        The status of the step.
    """
    step_run = zen_store().get_run_step(step_id)
    return step_run.status
list_run_steps(step_run_filter_model=Depends(init_cls_and_handle_errors), _=Security(oauth2_authentication))
Get run steps according to query filters.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
step_run_filter_model |
StepRunFilterModel |
Filter model used for pagination, sorting, filtering |
Depends(init_cls_and_handle_errors) |
Returns:
Type | Description |
---|---|
Page[StepRunResponseModel] |
The run steps according to query filters. |
Source code in zenml/zen_server/routers/steps_endpoints.py
@router.get(
    "",
    response_model=Page[StepRunResponseModel],
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def list_run_steps(
    step_run_filter_model: StepRunFilterModel = Depends(
        make_dependable(StepRunFilterModel)
    ),
    _: AuthContext = Security(authorize, scopes=[PermissionType.READ]),
) -> Page[StepRunResponseModel]:
    """List step runs that match the supplied query filters.

    Args:
        step_run_filter_model: Filter model carrying the pagination,
            sorting and filtering options parsed from the query string.

    Returns:
        One page of step runs matching the filters.
    """
    store = zen_store()
    return store.list_run_steps(step_run_filter_model=step_run_filter_model)
update_step(step_id, step_model, _=Security(oauth2_authentication))
Updates a step.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
step_id |
UUID |
ID of the step. |
required |
step_model |
StepRunUpdateModel |
Step model to use for the update. |
required |
Returns:
Type | Description |
---|---|
StepRunResponseModel |
The updated step model. |
Source code in zenml/zen_server/routers/steps_endpoints.py
@router.put(
    "/{step_id}",
    response_model=StepRunResponseModel,
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def update_step(
    step_id: UUID,
    step_model: StepRunUpdateModel,
    _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]),
) -> StepRunResponseModel:
    """Apply an update to an existing step run.

    Args:
        step_id: ID of the step run to modify.
        step_model: The update to apply to that step run.

    Returns:
        The step run as stored after the update.
    """
    store = zen_store()
    return store.update_run_step(
        step_run_update=step_model,
        step_run_id=step_id,
    )
team_role_assignments_endpoints
Endpoint definitions for role assignments.
create_team_role_assignment(role_assignment, _=Security(oauth2_authentication))
Creates a role assignment.
noqa: DAR401
Parameters:
Name | Type | Description | Default |
---|---|---|---|
role_assignment |
TeamRoleAssignmentRequestModel |
Role assignment to create. |
required |
Returns:
Type | Description |
---|---|
TeamRoleAssignmentResponseModel |
The created role assignment. |
Source code in zenml/zen_server/routers/team_role_assignments_endpoints.py
@router.post(
    "",
    response_model=TeamRoleAssignmentResponseModel,
    responses={401: error_response, 409: error_response, 422: error_response},
)
@handle_exceptions
def create_team_role_assignment(
    role_assignment: TeamRoleAssignmentRequestModel,
    _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]),
) -> TeamRoleAssignmentResponseModel:
    """Assign a role to a team.

    # noqa: DAR401

    Args:
        role_assignment: The team role assignment to persist.

    Returns:
        The newly stored team role assignment.
    """
    store = zen_store()
    created = store.create_team_role_assignment(
        team_role_assignment=role_assignment
    )
    return created
delete_team_role_assignment(role_assignment_id, _=Security(oauth2_authentication))
Deletes a specific role assignment.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
role_assignment_id |
UUID |
The ID of the role assignment. |
required |
Source code in zenml/zen_server/routers/team_role_assignments_endpoints.py
@router.delete(
    "/{role_assignment_id}",
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def delete_team_role_assignment(
    role_assignment_id: UUID,
    _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]),
) -> None:
    """Remove a single team role assignment.

    Args:
        role_assignment_id: The ID of the team role assignment to delete.
    """
    store = zen_store()
    store.delete_team_role_assignment(
        team_role_assignment_id=role_assignment_id
    )
get_team_role_assignment(role_assignment_id, _=Security(oauth2_authentication))
Returns a specific role assignment.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
role_assignment_id |
UUID |
ID of the role assignment. |
required |
Returns:
Type | Description |
---|---|
TeamRoleAssignmentResponseModel |
A specific role assignment. |
Source code in zenml/zen_server/routers/team_role_assignments_endpoints.py
@router.get(
    "/{role_assignment_id}",
    response_model=TeamRoleAssignmentResponseModel,
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def get_team_role_assignment(
    role_assignment_id: UUID,
    _: AuthContext = Security(authorize, scopes=[PermissionType.READ]),
) -> TeamRoleAssignmentResponseModel:
    """Fetch one team role assignment by its ID.

    Args:
        role_assignment_id: ID of the team role assignment.

    Returns:
        The matching team role assignment.
    """
    store = zen_store()
    return store.get_team_role_assignment(
        team_role_assignment_id=role_assignment_id
    )
list_team_role_assignments(team_role_assignment_filter_model=Depends(init_cls_and_handle_errors), _=Security(oauth2_authentication))
Returns a list of all role assignments.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
team_role_assignment_filter_model |
TeamRoleAssignmentFilterModel |
filter models for team role assignments |
Depends(init_cls_and_handle_errors) |
Returns:
Type | Description |
---|---|
Page[TeamRoleAssignmentResponseModel] |
List of all role assignments. |
Source code in zenml/zen_server/routers/team_role_assignments_endpoints.py
@router.get(
    "",
    response_model=Page[TeamRoleAssignmentResponseModel],
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def list_team_role_assignments(
    team_role_assignment_filter_model: TeamRoleAssignmentFilterModel = Depends(
        make_dependable(TeamRoleAssignmentFilterModel)
    ),
    _: AuthContext = Security(authorize, scopes=[PermissionType.READ]),
) -> Page[TeamRoleAssignmentResponseModel]:
    """List team role assignments matching the query filters.

    Args:
        team_role_assignment_filter_model: Filter model for team role
            assignments (pagination, sorting and filtering options).

    Returns:
        One page of matching team role assignments.
    """
    filters = team_role_assignment_filter_model
    store = zen_store()
    return store.list_team_role_assignments(
        team_role_assignment_filter_model=filters
    )
teams_endpoints
Endpoint definitions for teams and team membership.
create_team(team, _=Security(oauth2_authentication))
Creates a team.
noqa: DAR401
Parameters:
Name | Type | Description | Default |
---|---|---|---|
team |
TeamRequestModel |
Team to create. |
required |
Returns:
Type | Description |
---|---|
TeamResponseModel |
The created team. |
Source code in zenml/zen_server/routers/teams_endpoints.py
@router.post(
    "",
    response_model=TeamResponseModel,
    responses={401: error_response, 409: error_response, 422: error_response},
)
@handle_exceptions
def create_team(
    team: TeamRequestModel,
    _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]),
) -> TeamResponseModel:
    """Persist a new team.

    # noqa: DAR401

    Args:
        team: The team definition to store.

    Returns:
        The team as stored by the ZenML store.
    """
    store = zen_store()
    new_team = store.create_team(team=team)
    return new_team
delete_team(team_name_or_id, _=Security(oauth2_authentication))
Deletes a specific team.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
team_name_or_id |
Union[str, uuid.UUID] |
Name or ID of the team. |
required |
Source code in zenml/zen_server/routers/teams_endpoints.py
@router.delete(
    "/{team_name_or_id}",
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def delete_team(
    team_name_or_id: Union[str, UUID],
    _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]),
) -> None:
    """Remove a team.

    Args:
        team_name_or_id: Name or ID identifying the team to delete.
    """
    store = zen_store()
    store.delete_team(team_name_or_id=team_name_or_id)
get_team(team_name_or_id, _=Security(oauth2_authentication))
Returns a specific team.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
team_name_or_id |
Union[str, uuid.UUID] |
Name or ID of the team. |
required |
Returns:
Type | Description |
---|---|
TeamResponseModel |
A specific team. |
Source code in zenml/zen_server/routers/teams_endpoints.py
@router.get(
    "/{team_name_or_id}",
    response_model=TeamResponseModel,
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def get_team(
    team_name_or_id: Union[str, UUID],
    _: AuthContext = Security(authorize, scopes=[PermissionType.READ]),
) -> TeamResponseModel:
    """Fetch a single team.

    Args:
        team_name_or_id: Name or ID identifying the team.

    Returns:
        The matching team.
    """
    store = zen_store()
    return store.get_team(team_name_or_id=team_name_or_id)
list_role_assignments_for_team(team_role_assignment_filter_model=Depends(init_cls_and_handle_errors), _=Security(oauth2_authentication))
Returns a list of all roles that are assigned to a team.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
team_role_assignment_filter_model |
TeamRoleAssignmentFilterModel |
All filter parameters including pagination params. |
Depends(init_cls_and_handle_errors) |
Returns:
Type | Description |
---|---|
Page[TeamRoleAssignmentResponseModel] |
A list of all roles that are assigned to a team. |
Source code in zenml/zen_server/routers/teams_endpoints.py
@router.get(
    "/{team_name_or_id}" + ROLES,
    response_model=Page[TeamRoleAssignmentResponseModel],
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def list_role_assignments_for_team(
    team_role_assignment_filter_model: TeamRoleAssignmentFilterModel = Depends(
        make_dependable(TeamRoleAssignmentFilterModel)
    ),
    _: AuthContext = Security(authorize, scopes=[PermissionType.READ]),
    team_name_or_id: Union[str, UUID, None] = None,
) -> Page[TeamRoleAssignmentResponseModel]:
    """Returns a list of all roles that are assigned to a team.

    Args:
        team_role_assignment_filter_model: All filter parameters including
            pagination params.
        team_name_or_id: Name or ID of the team, taken from the route path.
            When the function is called directly without it, no team
            scoping is applied (the previous behavior).

    Returns:
        A page of the role assignments of the given team.
    """
    # The route path contains `{team_name_or_id}`, but the handler used to
    # not declare it, so the team in the URL was silently ignored and
    # assignments of every team were returned. Bind it and scope the filter.
    if team_name_or_id is not None:
        team = zen_store().get_team(team_name_or_id=team_name_or_id)
        # NOTE(review): assumes TeamRoleAssignmentFilterModel exposes a
        # `team_id` filter field; confirm against the model definition.
        team_role_assignment_filter_model.team_id = team.id
    return zen_store().list_team_role_assignments(
        team_role_assignment_filter_model
    )
list_teams(team_filter_model=Depends(init_cls_and_handle_errors), _=Security(oauth2_authentication))
Returns a list of all teams.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
team_filter_model |
TeamFilterModel |
All filter parameters including pagination params. |
Depends(init_cls_and_handle_errors) |
Returns:
Type | Description |
---|---|
Page[TeamResponseModel] |
List of all teams. |
Source code in zenml/zen_server/routers/teams_endpoints.py
@router.get(
    "",
    response_model=Page[TeamResponseModel],
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def list_teams(
    team_filter_model: TeamFilterModel = Depends(
        make_dependable(TeamFilterModel)
    ),
    _: AuthContext = Security(authorize, scopes=[PermissionType.READ]),
) -> Page[TeamResponseModel]:
    """List teams matching the query filters.

    Args:
        team_filter_model: Filtering, sorting and pagination options.

    Returns:
        One page of matching teams.
    """
    store = zen_store()
    teams_page = store.list_teams(team_filter_model)
    return teams_page
update_team(team_id, team_update, _=Security(oauth2_authentication))
Updates a team.
noqa: DAR401
Parameters:
Name | Type | Description | Default |
---|---|---|---|
team_id |
UUID |
ID of the team to update. |
required |
team_update |
TeamUpdateModel |
Team update. |
required |
Returns:
Type | Description |
---|---|
TeamResponseModel |
The updated team. |
Source code in zenml/zen_server/routers/teams_endpoints.py
@router.put(
    "/{team_id}",
    response_model=TeamResponseModel,
    responses={401: error_response, 409: error_response, 422: error_response},
)
@handle_exceptions
def update_team(
    team_id: UUID,
    team_update: TeamUpdateModel,
    _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]),
) -> TeamResponseModel:
    """Apply an update to an existing team.

    # noqa: DAR401

    Args:
        team_id: ID of the team to modify.
        team_update: The changes to apply.

    Returns:
        The team as stored after the update.
    """
    store = zen_store()
    return store.update_team(team_update=team_update, team_id=team_id)
users_endpoints
Endpoint definitions for users.
activate_user(user_name_or_id, user_update)
Activates a specific user.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
user_name_or_id |
Union[str, uuid.UUID] |
Name or ID of the user. |
required |
user_update |
UserUpdateModel |
the user to use for the update. |
required |
Returns:
Type | Description |
---|---|
UserResponseModel |
The updated user. |
Source code in zenml/zen_server/routers/users_endpoints.py
@activation_router.put(
    "/{user_name_or_id}" + ACTIVATE,
    response_model=UserResponseModel,
    responses={
        401: error_response,
        404: error_response,
        422: error_response,
    },
)
@handle_exceptions
def activate_user(
    user_name_or_id: Union[str, UUID],
    user_update: UserUpdateModel,
) -> UserResponseModel:
    """Activate a user account using its activation token.

    This handler declares no authentication dependency. The caller proves
    ownership by supplying ``activation_token`` in the update body, which
    is passed to ``authenticate_credentials`` (presumably it rejects an
    invalid token; its behavior is not visible here, so confirm). On
    success the update is stored with ``active=True`` and the activation
    token cleared.

    Args:
        user_name_or_id: Name or ID of the user to activate.
        user_update: Update to apply; must carry the activation token.

    Returns:
        The user as stored after activation.
    """
    store = zen_store()
    target = store.get_user(user_name_or_id)
    authenticate_credentials(
        user_name_or_id=user_name_or_id,
        activation_token=user_update.activation_token,
    )
    # Mark the account active and drop the now-consumed token.
    user_update.active = True
    user_update.activation_token = None
    return store.update_user(user_id=target.id, user_update=user_update)
create_user(user, _=Security(oauth2_authentication))
Creates a user.
noqa: DAR401
Parameters:
Name | Type | Description | Default |
---|---|---|---|
user |
UserRequestModel |
User to create. |
required |
Returns:
Type | Description |
---|---|
UserResponseModel |
The created user. |
Source code in zenml/zen_server/routers/users_endpoints.py
@router.post(
    "",
    response_model=UserResponseModel,
    responses={
        401: error_response,
        409: error_response,
        422: error_response,
    },
)
@handle_exceptions
def create_user(
    user: UserRequestModel,
    _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]),
) -> UserResponseModel:
    """Create a user account.

    # noqa: DAR401

    A user created with a password is active immediately. A user created
    without one is stored inactive, and an activation token is generated
    for it. The unhashed token is attached to the response so the client
    can later activate the account.

    Args:
        user: The user to create.

    Returns:
        The stored user, carrying the plain activation token when one was
        generated.
    """
    plain_token: Optional[str] = None
    if user.password is not None:
        user.active = True
    else:
        # Passwordless accounts stay inactive until activated with a token.
        user.active = False
        plain_token = user.generate_activation_token()

    created_user = zen_store().create_user(user)

    # The store keeps only a hashed token; hand the original back once.
    if plain_token:
        created_user.activation_token = plain_token
    return created_user
deactivate_user(user_name_or_id, _=Security(oauth2_authentication))
Deactivates a user and generates a new activation token for it.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
user_name_or_id |
Union[str, uuid.UUID] |
Name or ID of the user. |
required |
Returns:
Type | Description |
---|---|
UserResponseModel |
The deactivated user, carrying the newly generated (unhashed) activation token. |
Source code in zenml/zen_server/routers/users_endpoints.py
@router.put(
    "/{user_name_or_id}" + DEACTIVATE,
    response_model=UserResponseModel,
    responses={
        401: error_response,
        404: error_response,
        422: error_response,
    },
)
@handle_exceptions
def deactivate_user(
    user_name_or_id: Union[str, UUID],
    _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]),
) -> UserResponseModel:
    """Deactivate a user and issue a fresh activation token for it.

    Args:
        user_name_or_id: Name or ID of the user to deactivate.

    Returns:
        The deactivated user, with the newly generated (unhashed)
        activation token attached.
    """
    store = zen_store()
    target = store.get_user(user_name_or_id)
    deactivation = UserUpdateModel(
        name=target.name,
        active=False,
    )
    plain_token = deactivation.generate_activation_token()
    updated_user = store.update_user(
        user_id=target.id, user_update=deactivation
    )
    # The stored token is hashed; expose the original to the caller.
    updated_user.activation_token = plain_token
    return updated_user
delete_user(user_name_or_id, auth_context=Security(oauth2_authentication))
Deletes a specific user.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
user_name_or_id |
Union[str, uuid.UUID] |
Name or ID of the user. |
required |
auth_context |
AuthContext |
The authentication context. |
Security(oauth2_authentication) |
Exceptions:
Type | Description |
---|---|
IllegalOperationError |
If the user to delete is the account currently used to authenticate. |
Source code in zenml/zen_server/routers/users_endpoints.py
@router.delete(
    "/{user_name_or_id}",
    responses={
        401: error_response,
        404: error_response,
        422: error_response,
    },
)
@handle_exceptions
def delete_user(
    user_name_or_id: Union[str, UUID],
    auth_context: AuthContext = Security(
        authorize, scopes=[PermissionType.WRITE]
    ),
) -> None:
    """Delete a user account other than the caller's own.

    Args:
        user_name_or_id: Name or ID of the user to delete.
        auth_context: Authentication context of the caller.

    Raises:
        IllegalOperationError: If the account to delete is the one the
            caller is currently authenticated as.
    """
    store = zen_store()
    target = store.get_user(user_name_or_id)
    deleting_self = auth_context.user.name == target.name
    if deleting_self:
        raise IllegalOperationError(
            "You cannot delete the user account currently used to authenticate "
            "to the ZenML server. If you wish to delete this account, "
            "please authenticate with another account or contact your ZenML "
            "administrator."
        )
    store.delete_user(user_name_or_id=user_name_or_id)
email_opt_in_response(user_name_or_id, user_response, auth_context=Security(oauth2_authentication))
Sets the response of the user to the email prompt.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
user_name_or_id |
Union[str, uuid.UUID] |
Name or ID of the user. |
required |
user_response |
UserUpdateModel |
User Response to email prompt |
required |
auth_context |
AuthContext |
The authentication context of the user |
Security(oauth2_authentication) |
Returns:
Type | Description |
---|---|
UserResponseModel |
The updated user. |
Exceptions:
Type | Description |
---|---|
AuthorizationException |
If the target user is not the authenticated user (users cannot opt in on behalf of another user). |
Source code in zenml/zen_server/routers/users_endpoints.py
@router.put(
    "/{user_name_or_id}" + EMAIL_ANALYTICS,
    response_model=UserResponseModel,
    responses={
        401: error_response,
        404: error_response,
        422: error_response,
    },
)
@handle_exceptions
def email_opt_in_response(
    user_name_or_id: Union[str, UUID],
    user_response: UserUpdateModel,
    auth_context: AuthContext = Security(
        authorize, scopes=[PermissionType.ME]
    ),
) -> UserResponseModel:
    """Sets the response of the user to the email prompt.

    Only the authenticated user may record their own response. When the
    response carries an explicit ``email_opted_in`` value, it is also
    forwarded to ``email_opt_int`` before the user record is updated.

    Args:
        user_name_or_id: Name or ID of the user.
        user_response: User Response to email prompt
        auth_context: The authentication context of the user

    Returns:
        The updated user.

    Raises:
        AuthorizationException: if the target user is not the
            authenticated user.
    """
    store = zen_store()
    user = store.get_user(user_name_or_id)

    # Compare resolved IDs rather than the raw path value. The path accepts
    # a name or an ID, so comparing `str(auth_context.user.id)` with it
    # wrongly rejected users addressing themselves by name (or by a UUID
    # string in a different case).
    if auth_context.user.id != user.id:
        raise AuthorizationException(
            "Users can not opt in on behalf of another user."
        )

    user_update = UserUpdateModel(
        name=user.name,
        email=user_response.email,
        email_opted_in=user_response.email_opted_in,
    )
    if user_response.email_opted_in is not None:
        email_opt_int(
            opted_in=user_response.email_opted_in,
            email=user_response.email,
            source="zenml server",
        )
    return store.update_user(user_id=user.id, user_update=user_update)
get_current_user(auth_context=Security(oauth2_authentication))
Returns the model of the authenticated user.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
auth_context |
AuthContext |
The authentication context. |
Security(oauth2_authentication) |
Returns:
Type | Description |
---|---|
UserResponseModel |
The model of the authenticated user. |
Source code in zenml/zen_server/routers/users_endpoints.py
@current_user_router.get(
    "/current-user",
    response_model=UserResponseModel,
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def get_current_user(
    auth_context: AuthContext = Security(
        authorize, scopes=[PermissionType.READ]
    ),
) -> UserResponseModel:
    """Return the user the request is authenticated as.

    Args:
        auth_context: Authentication context of the caller.

    Returns:
        The user model held in the authentication context.
    """
    authenticated_user = auth_context.user
    return authenticated_user
get_user(user_name_or_id, _=Security(oauth2_authentication))
Returns a specific user.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
user_name_or_id |
Union[str, uuid.UUID] |
Name or ID of the user. |
required |
Returns:
Type | Description |
---|---|
UserResponseModel |
A specific user. |
Source code in zenml/zen_server/routers/users_endpoints.py
@router.get(
    "/{user_name_or_id}",
    response_model=UserResponseModel,
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def get_user(
    user_name_or_id: Union[str, UUID],
    _: AuthContext = Security(authorize, scopes=[PermissionType.READ]),
) -> UserResponseModel:
    """Fetch a single user.

    Args:
        user_name_or_id: Name or ID identifying the user.

    Returns:
        The matching user.
    """
    store = zen_store()
    return store.get_user(user_name_or_id=user_name_or_id)
list_role_assignments_for_user(user_role_assignment_filter_model=Depends(init_cls_and_handle_errors), _=Security(oauth2_authentication))
Returns a list of all roles that are assigned to a user.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
user_role_assignment_filter_model |
UserRoleAssignmentFilterModel |
filter models for user role assignments |
Depends(init_cls_and_handle_errors) |
Returns:
Type | Description |
---|---|
Page[UserRoleAssignmentResponseModel] |
A list of all roles that are assigned to a user. |
Source code in zenml/zen_server/routers/users_endpoints.py
@router.get(
    "/{user_name_or_id}" + ROLES,
    response_model=Page[UserRoleAssignmentResponseModel],
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def list_role_assignments_for_user(
    user_role_assignment_filter_model: UserRoleAssignmentFilterModel = Depends(
        make_dependable(UserRoleAssignmentFilterModel)
    ),
    _: AuthContext = Security(authorize, scopes=[PermissionType.READ]),
    user_name_or_id: Union[str, UUID, None] = None,
) -> Page[UserRoleAssignmentResponseModel]:
    """Returns a list of all roles that are assigned to a user.

    Args:
        user_role_assignment_filter_model: filter models for user role
            assignments (pagination, sorting and filtering options).
        user_name_or_id: Name or ID of the user, taken from the route path.
            When the function is called directly without it, no user
            scoping is applied (the previous behavior).

    Returns:
        A page of the role assignments of the given user.
    """
    # The route path contains `{user_name_or_id}`, but the handler used to
    # not declare it, so the user in the URL was silently ignored and
    # assignments of every user were returned. Bind it and scope the filter.
    if user_name_or_id is not None:
        user = zen_store().get_user(user_name_or_id)
        # NOTE(review): assumes UserRoleAssignmentFilterModel exposes a
        # `user_id` filter field; confirm against the model definition.
        user_role_assignment_filter_model.user_id = user.id
    return zen_store().list_user_role_assignments(
        user_role_assignment_filter_model=user_role_assignment_filter_model
    )
list_users(user_filter_model=Depends(init_cls_and_handle_errors), _=Security(oauth2_authentication))
Returns a list of all users.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
user_filter_model |
UserFilterModel |
Model that takes care of filtering, sorting and pagination |
Depends(init_cls_and_handle_errors) |
Returns:
Type | Description |
---|---|
Page[UserResponseModel] |
A list of all users. |
Source code in zenml/zen_server/routers/users_endpoints.py
@router.get(
    "",
    response_model=Page[UserResponseModel],
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def list_users(
    user_filter_model: UserFilterModel = Depends(
        make_dependable(UserFilterModel)
    ),
    _: AuthContext = Security(authorize, scopes=[PermissionType.READ]),
) -> Page[UserResponseModel]:
    """List users matching the query filters.

    Args:
        user_filter_model: Filtering, sorting and pagination options.

    Returns:
        One page of matching users.
    """
    store = zen_store()
    users_page = store.list_users(user_filter_model=user_filter_model)
    return users_page
update_myself(user, auth_context=Security(oauth2_authentication))
Updates a specific user.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
user |
UserUpdateModel |
the user to use for the update. |
required |
auth_context |
AuthContext |
The authentication context. |
Security(oauth2_authentication) |
Returns:
Type | Description |
---|---|
UserResponseModel |
The updated user. |
Source code in zenml/zen_server/routers/users_endpoints.py
@current_user_router.put(
    "/current-user",
    response_model=UserResponseModel,
    responses={
        401: error_response,
        404: error_response,
        422: error_response,
    },
)
@handle_exceptions
def update_myself(
    user: UserUpdateModel,
    auth_context: AuthContext = Security(
        authorize, scopes=[PermissionType.ME]
    ),
) -> UserResponseModel:
    """Apply an update to the caller's own user account.

    The target account is always the authenticated user; no user ID is
    accepted from the request.

    Args:
        user: The changes to apply.
        auth_context: Authentication context of the caller.

    Returns:
        The caller's user as stored after the update.
    """
    own_id = auth_context.user.id
    return zen_store().update_user(user_update=user, user_id=own_id)
update_user(user_name_or_id, user_update, _=Security(oauth2_authentication))
Updates a specific user.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
user_name_or_id |
Union[str, uuid.UUID] |
Name or ID of the user. |
required |
user_update |
UserUpdateModel |
the user to use for the update. |
required |
Returns:
Type | Description |
---|---|
UserResponseModel |
The updated user. |
Source code in zenml/zen_server/routers/users_endpoints.py
@router.put(
    "/{user_name_or_id}",
    response_model=UserResponseModel,
    responses={
        401: error_response,
        404: error_response,
        422: error_response,
    },
)
@handle_exceptions
def update_user(
    user_name_or_id: Union[str, UUID],
    user_update: UserUpdateModel,
    _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]),
) -> UserResponseModel:
    """Apply an update to an arbitrary user account.

    Args:
        user_name_or_id: Name or ID identifying the user to modify.
        user_update: The changes to apply.

    Returns:
        The user as stored after the update.
    """
    store = zen_store()
    existing_user = store.get_user(user_name_or_id)
    return store.update_user(
        user_update=user_update,
        user_id=existing_user.id,
    )
workspaces_endpoints
Endpoint definitions for workspaces.
create_build(workspace_name_or_id, build, auth_context=Security(oauth2_authentication))
Creates a build.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
workspace_name_or_id |
Union[str, uuid.UUID] |
Name or ID of the workspace. |
required |
build |
PipelineBuildRequestModel |
Build to create. |
required |
auth_context |
AuthContext |
Authentication context. |
Security(oauth2_authentication) |
Returns:
Type | Description |
---|---|
PipelineBuildResponseModel |
The created build. |
Exceptions:
Type | Description |
---|---|
IllegalOperationError |
If the workspace or user specified in the build does not match the current workspace or authenticated user. |
Source code in zenml/zen_server/routers/workspaces_endpoints.py
@router.post(
    WORKSPACES + "/{workspace_name_or_id}" + PIPELINE_BUILDS,
    response_model=PipelineBuildResponseModel,
    responses={401: error_response, 409: error_response, 422: error_response},
)
@handle_exceptions
def create_build(
    workspace_name_or_id: Union[str, UUID],
    build: PipelineBuildRequestModel,
    auth_context: AuthContext = Security(
        authorize, scopes=[PermissionType.WRITE]
    ),
) -> PipelineBuildResponseModel:
    """Store a pipeline build inside the workspace named in the path.

    Args:
        workspace_name_or_id: Name or ID of the workspace from the route.
        build: The build to create.
        auth_context: Authentication context of the caller.

    Returns:
        The stored build.

    Raises:
        IllegalOperationError: If the build's workspace differs from the
            route's workspace, or its user differs from the caller.
    """
    store = zen_store()
    target_workspace = store.get_workspace(workspace_name_or_id)
    if build.workspace != target_workspace.id:
        raise IllegalOperationError(
            "Creating builds outside of the workspace scope "
            f"of this endpoint `{workspace_name_or_id}` is "
            f"not supported."
        )
    if build.user != auth_context.user.id:
        raise IllegalOperationError(
            "Creating builds for a user other than yourself "
            "is not supported."
        )
    return store.create_build(build=build)
create_code_repository(workspace_name_or_id, code_repository, auth_context=Security(oauth2_authentication))
Creates a code repository.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
workspace_name_or_id |
Union[str, uuid.UUID] |
Name or ID of the workspace. |
required |
code_repository |
CodeRepositoryRequestModel |
Code repository to create. |
required |
auth_context |
AuthContext |
Authentication context. |
Security(oauth2_authentication) |
Returns:
Type | Description |
---|---|
CodeRepositoryResponseModel |
The created code repository. |
Exceptions:
Type | Description |
---|---|
IllegalOperationError |
If the workspace or user specified in the code repository does not match the current workspace or authenticated user. |
Source code in zenml/zen_server/routers/workspaces_endpoints.py
@router.post(
    WORKSPACES + "/{workspace_name_or_id}" + CODE_REPOSITORIES,
    response_model=CodeRepositoryResponseModel,
    responses={401: error_response, 409: error_response, 422: error_response},
)
@handle_exceptions
def create_code_repository(
    workspace_name_or_id: Union[str, UUID],
    code_repository: CodeRepositoryRequestModel,
    auth_context: AuthContext = Security(
        authorize, scopes=[PermissionType.WRITE]
    ),
) -> CodeRepositoryResponseModel:
    """Register a code repository inside the workspace named in the path.

    Args:
        workspace_name_or_id: Name or ID of the workspace from the route.
        code_repository: The code repository to create.
        auth_context: Authentication context of the caller.

    Returns:
        The stored code repository.

    Raises:
        IllegalOperationError: If the repository's workspace differs from
            the route's workspace, or its user differs from the caller.
    """
    store = zen_store()
    target_workspace = store.get_workspace(workspace_name_or_id)
    if code_repository.workspace != target_workspace.id:
        raise IllegalOperationError(
            "Creating code repositories outside of the workspace scope "
            f"of this endpoint `{workspace_name_or_id}` is "
            f"not supported."
        )
    if code_repository.user != auth_context.user.id:
        raise IllegalOperationError(
            "Creating code repositories for a user other than yourself "
            "is not supported."
        )
    return store.create_code_repository(code_repository=code_repository)
create_deployment(workspace_name_or_id, deployment, auth_context=Security(oauth2_authentication))
Creates a deployment.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
workspace_name_or_id |
Union[str, uuid.UUID] |
Name or ID of the workspace. |
required |
deployment |
PipelineDeploymentRequestModel |
Deployment to create. |
required |
auth_context |
AuthContext |
Authentication context. |
Security(oauth2_authentication) |
Returns:
Type | Description |
---|---|
PipelineDeploymentResponseModel |
The created deployment. |
Exceptions:
Type | Description |
---|---|
IllegalOperationError |
If the workspace or user specified in the deployment does not match the current workspace or authenticated user. |
Source code in zenml/zen_server/routers/workspaces_endpoints.py
@router.post(
    WORKSPACES + "/{workspace_name_or_id}" + PIPELINE_DEPLOYMENTS,
    response_model=PipelineDeploymentResponseModel,
    responses={401: error_response, 409: error_response, 422: error_response},
)
@handle_exceptions
def create_deployment(
    workspace_name_or_id: Union[str, UUID],
    deployment: PipelineDeploymentRequestModel,
    auth_context: AuthContext = Security(
        authorize, scopes=[PermissionType.WRITE]
    ),
) -> PipelineDeploymentResponseModel:
    """Store a pipeline deployment inside the workspace named in the path.

    Args:
        workspace_name_or_id: Name or ID of the workspace from the route.
        deployment: The deployment to create.
        auth_context: Authentication context of the caller.

    Returns:
        The stored deployment.

    Raises:
        IllegalOperationError: If the deployment's workspace differs from
            the route's workspace, or its user differs from the caller.
    """
    store = zen_store()
    target_workspace = store.get_workspace(workspace_name_or_id)
    if deployment.workspace != target_workspace.id:
        raise IllegalOperationError(
            "Creating deployments outside of the workspace scope "
            f"of this endpoint `{workspace_name_or_id}` is "
            f"not supported."
        )
    if deployment.user != auth_context.user.id:
        raise IllegalOperationError(
            "Creating deployments for a user other than yourself "
            "is not supported."
        )
    return store.create_deployment(deployment=deployment)
create_model(workspace_name_or_id, model, auth_context=Security(oauth2_authentication))
Create a new model.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
workspace_name_or_id |
Union[str, uuid.UUID] |
Name or ID of the workspace. |
required |
model |
ModelRequestModel |
The model to create. |
required |
auth_context |
AuthContext |
Authentication context. |
Security(oauth2_authentication) |
Returns:
Type | Description |
---|---|
ModelResponseModel |
The created model. |
Exceptions:
Type | Description |
---|---|
IllegalOperationError |
If the workspace or user specified in the model does not match the current workspace or authenticated user. |
Source code in zenml/zen_server/routers/workspaces_endpoints.py
@router.post(
    WORKSPACES + "/{workspace_name_or_id}" + MODELS,
    response_model=ModelResponseModel,
    responses={401: error_response, 409: error_response, 422: error_response},
)
@handle_exceptions
def create_model(
    workspace_name_or_id: Union[str, UUID],
    model: ModelRequestModel,
    auth_context: AuthContext = Security(
        authorize, scopes=[PermissionType.WRITE]
    ),
) -> ModelResponseModel:
    """Register a new model in the given workspace.

    Args:
        workspace_name_or_id: Name or ID of the workspace the model must
            belong to.
        model: The model request to persist.
        auth_context: Authentication context of the calling user.

    Returns:
        The model as stored by the zen store.

    Raises:
        IllegalOperationError: If the model targets a workspace other than
            the one in the URL, or a user other than the caller.
    """
    scope_workspace = zen_store().get_workspace(
        workspace_name_or_id=workspace_name_or_id
    )
    # Reject requests whose body disagrees with the URL scope or the caller.
    if model.workspace != scope_workspace.id:
        raise IllegalOperationError(
            "Creating models outside of the workspace scope of this "
            f"endpoint `{workspace_name_or_id}` is not supported."
        )
    if model.user != auth_context.user.id:
        raise IllegalOperationError(
            "Creating models for a user other than yourself is not "
            "supported."
        )
    return zen_store().create_model(model)
create_model_version(workspace_name_or_id, model_name_or_id, model_version, auth_context=Security(oauth2_authentication))
Create a new model version.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model_name_or_id |
Union[str, uuid.UUID] |
Name or ID of the model. |
required |
workspace_name_or_id |
Union[str, uuid.UUID] |
Name or ID of the workspace. |
required |
model_version |
ModelVersionRequestModel |
The model version to create. |
required |
auth_context |
AuthContext |
Authentication context. |
Security(oauth2_authentication) |
Returns:
Type | Description |
---|---|
ModelVersionResponseModel |
The created model version. |
Exceptions:
Type | Description |
---|---|
IllegalOperationError |
If the workspace or user specified in the model version does not match the current workspace or authenticated user. |
Source code in zenml/zen_server/routers/workspaces_endpoints.py
@router.post(
    WORKSPACES
    + "/{workspace_name_or_id}"
    + MODELS
    + "/{model_name_or_id}"
    + MODEL_VERSIONS,
    response_model=ModelVersionResponseModel,
    responses={401: error_response, 409: error_response, 422: error_response},
)
@handle_exceptions
def create_model_version(
    workspace_name_or_id: Union[str, UUID],
    model_name_or_id: Union[str, UUID],
    model_version: ModelVersionRequestModel,
    auth_context: AuthContext = Security(
        authorize, scopes=[PermissionType.WRITE]
    ),
) -> ModelVersionResponseModel:
    """Create a new model version.

    Args:
        model_name_or_id: Name or ID of the model.
        workspace_name_or_id: Name or ID of the workspace.
        model_version: The model version to create.
        auth_context: Authentication context.

    Returns:
        The created model version.

    Raises:
        IllegalOperationError: If the workspace or user specified in the
            model version does not match the current workspace or
            authenticated user.
    """
    # NOTE(review): `model_name_or_id` from the URL is not cross-checked
    # against the request body here; the body alone decides the model.
    workspace = zen_store().get_workspace(workspace_name_or_id)
    if model_version.workspace != workspace.id:
        raise IllegalOperationError(
            "Creating model versions outside of the workspace scope "
            f"of this endpoint `{workspace_name_or_id}` is "
            f"not supported."
        )
    if model_version.user != auth_context.user.id:
        # The message previously said "models", which misdescribed the
        # resource being created by this endpoint.
        raise IllegalOperationError(
            "Creating model versions for a user other than yourself "
            "is not supported."
        )
    return zen_store().create_model_version(model_version)
create_model_version_artifact_link(workspace_name_or_id, model_name_or_id, model_version_name_or_id, model_version_artifact_link, auth_context=Security(oauth2_authentication))
Create a new model version to artifact link.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model_name_or_id |
Union[str, uuid.UUID] |
Name or ID of the model. |
required |
workspace_name_or_id |
Union[str, uuid.UUID] |
Name or ID of the workspace. |
required |
model_version_name_or_id |
Union[str, uuid.UUID] |
Name or ID of the model version. |
required |
model_version_artifact_link |
ModelVersionArtifactRequestModel |
The model version to artifact link to create. |
required |
auth_context |
AuthContext |
Authentication context. |
Security(oauth2_authentication) |
Returns:
Type | Description |
---|---|
ModelVersionArtifactResponseModel |
The created model version to artifact link. |
Exceptions:
Type | Description |
---|---|
IllegalOperationError |
If the workspace or user specified in the model version does not match the current workspace or authenticated user. |
Source code in zenml/zen_server/routers/workspaces_endpoints.py
@router.post(
    WORKSPACES
    + "/{workspace_name_or_id}"
    + MODELS
    + "/{model_name_or_id}"
    + MODEL_VERSIONS
    + "/{model_version_name_or_id}"
    + ARTIFACTS,
    response_model=ModelVersionArtifactResponseModel,
    responses={401: error_response, 409: error_response, 422: error_response},
)
@handle_exceptions
def create_model_version_artifact_link(
    workspace_name_or_id: Union[str, UUID],
    model_name_or_id: Union[str, UUID],
    model_version_name_or_id: Union[str, UUID],
    model_version_artifact_link: ModelVersionArtifactRequestModel,
    auth_context: AuthContext = Security(
        authorize, scopes=[PermissionType.WRITE]
    ),
) -> ModelVersionArtifactResponseModel:
    """Create a new model version to artifact link.

    Args:
        model_name_or_id: Name or ID of the model.
        workspace_name_or_id: Name or ID of the workspace.
        model_version_name_or_id: Name or ID of the model version.
        model_version_artifact_link: The model version to artifact link to
            create.
        auth_context: Authentication context.

    Returns:
        The created model version to artifact link.

    Raises:
        IllegalOperationError: If the workspace or user specified in the
            link does not match the current workspace or authenticated
            user.
    """
    # NOTE(review): `model_name_or_id` and `model_version_name_or_id` from
    # the URL are not cross-checked against the request body here.
    workspace = zen_store().get_workspace(workspace_name_or_id)
    if model_version_artifact_link.workspace != workspace.id:
        raise IllegalOperationError(
            "Creating model version to artifact links outside of the "
            f"workspace scope of this endpoint `{workspace_name_or_id}` is "
            "not supported."
        )
    if model_version_artifact_link.user != auth_context.user.id:
        # Previously said "model to artifact links"; this endpoint links
        # model *versions* to artifacts.
        raise IllegalOperationError(
            "Creating model version to artifact links for a user other "
            "than yourself is not supported."
        )
    return zen_store().create_model_version_artifact_link(
        model_version_artifact_link
    )
create_model_version_pipeline_run_link(workspace_name_or_id, model_name_or_id, model_version_name_or_id, model_version_pipeline_run_link, auth_context=Security(oauth2_authentication))
Create a new model version to pipeline run link.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model_name_or_id |
Union[str, uuid.UUID] |
Name or ID of the model. |
required |
workspace_name_or_id |
Union[str, uuid.UUID] |
Name or ID of the workspace. |
required |
model_version_name_or_id |
Union[str, uuid.UUID] |
Name or ID of the model version. |
required |
model_version_pipeline_run_link |
ModelVersionPipelineRunRequestModel |
The model version to pipeline run link to create. |
required |
auth_context |
AuthContext |
Authentication context. |
Security(oauth2_authentication) |
Returns:
Type | Description |
---|---|
ModelVersionPipelineRunResponseModel |
The created model version to pipeline run link, or the existing link if one already exists. |
Exceptions:
Type | Description |
---|---|
IllegalOperationError |
If the workspace or user specified in the model version does not match the current workspace or authenticated user. |
Source code in zenml/zen_server/routers/workspaces_endpoints.py
@router.post(
    WORKSPACES
    + "/{workspace_name_or_id}"
    + MODELS
    + "/{model_name_or_id}"
    + MODEL_VERSIONS
    + "/{model_version_name_or_id}"
    + RUNS,
    response_model=ModelVersionPipelineRunResponseModel,
    responses={401: error_response, 409: error_response, 422: error_response},
)
@handle_exceptions
def create_model_version_pipeline_run_link(
    workspace_name_or_id: Union[str, UUID],
    model_name_or_id: Union[str, UUID],
    model_version_name_or_id: Union[str, UUID],
    model_version_pipeline_run_link: ModelVersionPipelineRunRequestModel,
    auth_context: AuthContext = Security(
        authorize, scopes=[PermissionType.WRITE]
    ),
) -> ModelVersionPipelineRunResponseModel:
    """Create a new model version to pipeline run link.

    Args:
        model_name_or_id: Name or ID of the model.
        workspace_name_or_id: Name or ID of the workspace.
        model_version_name_or_id: Name or ID of the model version.
        model_version_pipeline_run_link: The model version to pipeline run
            link to create.
        auth_context: Authentication context.

    Returns:
        - If Model Version to Pipeline Run Link already exists - returns the
          existing link.
        - Otherwise, returns the newly created model version to pipeline run
          link.

    Raises:
        IllegalOperationError: If the workspace or user specified in the
            link does not match the current workspace or authenticated user.
    """
    # NOTE(review): `model_name_or_id` and `model_version_name_or_id` from
    # the URL are not cross-checked against the request body here.
    workspace = zen_store().get_workspace(workspace_name_or_id)
    if model_version_pipeline_run_link.workspace != workspace.id:
        # Messages previously spoke of "model versions"/"models", copied
        # from another endpoint; they now name the resource created here.
        raise IllegalOperationError(
            "Creating model version to pipeline run links outside of the "
            f"workspace scope of this endpoint `{workspace_name_or_id}` is "
            "not supported."
        )
    if model_version_pipeline_run_link.user != auth_context.user.id:
        raise IllegalOperationError(
            "Creating model version to pipeline run links for a user other "
            "than yourself is not supported."
        )
    return zen_store().create_model_version_pipeline_run_link(
        model_version_pipeline_run_link
    )
create_pipeline(workspace_name_or_id, pipeline, auth_context=Security(oauth2_authentication))
Creates a pipeline.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
workspace_name_or_id |
Union[str, uuid.UUID] |
Name or ID of the workspace. |
required |
pipeline |
PipelineRequestModel |
Pipeline to create. |
required |
auth_context |
AuthContext |
Authentication context. |
Security(oauth2_authentication) |
Returns:
Type | Description |
---|---|
PipelineResponseModel |
The created pipeline. |
Exceptions:
Type | Description |
---|---|
IllegalOperationError |
If the workspace or user specified in the pipeline does not match the current workspace or authenticated user. |
Source code in zenml/zen_server/routers/workspaces_endpoints.py
@router.post(
    WORKSPACES + "/{workspace_name_or_id}" + PIPELINES,
    response_model=PipelineResponseModel,
    responses={401: error_response, 409: error_response, 422: error_response},
)
@handle_exceptions
def create_pipeline(
    workspace_name_or_id: Union[str, UUID],
    pipeline: PipelineRequestModel,
    auth_context: AuthContext = Security(
        authorize, scopes=[PermissionType.WRITE]
    ),
) -> PipelineResponseModel:
    """Register a new pipeline in the given workspace.

    Args:
        workspace_name_or_id: Name or ID of the workspace the pipeline must
            belong to.
        pipeline: The pipeline request to persist.
        auth_context: Authentication context of the calling user.

    Returns:
        The pipeline as stored by the zen store.

    Raises:
        IllegalOperationError: If the pipeline targets a workspace other
            than the one in the URL, or a user other than the caller.
    """
    scope_workspace = zen_store().get_workspace(
        workspace_name_or_id=workspace_name_or_id
    )
    # Both ownership fields in the body must match the request context.
    if pipeline.workspace != scope_workspace.id:
        raise IllegalOperationError(
            "Creating pipelines outside of the workspace scope of this "
            f"endpoint `{workspace_name_or_id}` is not supported."
        )
    if pipeline.user != auth_context.user.id:
        raise IllegalOperationError(
            "Creating pipelines for a user other than yourself is not "
            "supported."
        )
    return zen_store().create_pipeline(pipeline=pipeline)
create_pipeline_run(workspace_name_or_id, pipeline_run, auth_context=Security(oauth2_authentication), get_if_exists=False)
Creates a pipeline run.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
workspace_name_or_id |
Union[str, uuid.UUID] |
Name or ID of the workspace. |
required |
pipeline_run |
PipelineRunRequestModel |
Pipeline run to create. |
required |
auth_context |
AuthContext |
Authentication context. |
Security(oauth2_authentication) |
get_if_exists |
bool |
If a similar pipeline run already exists, return it instead of raising an error. |
False |
Returns:
Type | Description |
---|---|
PipelineRunResponseModel |
The created pipeline run. |
Exceptions:
Type | Description |
---|---|
IllegalOperationError |
If the workspace or user specified in the pipeline run does not match the current workspace or authenticated user. |
Source code in zenml/zen_server/routers/workspaces_endpoints.py
@router.post(
    WORKSPACES + "/{workspace_name_or_id}" + RUNS,
    response_model=PipelineRunResponseModel,
    responses={401: error_response, 409: error_response, 422: error_response},
)
@handle_exceptions
def create_pipeline_run(
    workspace_name_or_id: Union[str, UUID],
    pipeline_run: PipelineRunRequestModel,
    auth_context: AuthContext = Security(
        authorize, scopes=[PermissionType.WRITE]
    ),
    get_if_exists: bool = False,
) -> PipelineRunResponseModel:
    """Register a new pipeline run in the given workspace.

    Args:
        workspace_name_or_id: Name or ID of the workspace the run must
            belong to.
        pipeline_run: The pipeline run request to persist.
        auth_context: Authentication context of the calling user.
        get_if_exists: When true, return an already existing matching run
            instead of raising an error.

    Returns:
        The created (or, with `get_if_exists`, possibly pre-existing)
        pipeline run.

    Raises:
        IllegalOperationError: If the run targets a workspace other than the
            one in the URL, or a user other than the caller.
    """
    target_workspace = zen_store().get_workspace(
        workspace_name_or_id=workspace_name_or_id
    )
    if pipeline_run.workspace != target_workspace.id:
        raise IllegalOperationError(
            "Creating pipeline runs outside of the workspace scope of this "
            f"endpoint `{workspace_name_or_id}` is not supported."
        )
    if pipeline_run.user != auth_context.user.id:
        raise IllegalOperationError(
            "Creating pipeline runs for a user other than yourself is not "
            "supported."
        )
    if not get_if_exists:
        return zen_store().create_run(pipeline_run=pipeline_run)
    # `get_or_create_run` returns (run, created); only the run is exposed.
    return zen_store().get_or_create_run(pipeline_run=pipeline_run)[0]
create_run_metadata(workspace_name_or_id, run_metadata, auth_context=Security(oauth2_authentication))
Creates run metadata.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
workspace_name_or_id |
Union[str, uuid.UUID] |
Name or ID of the workspace. |
required |
run_metadata |
RunMetadataRequestModel |
The run metadata to create. |
required |
auth_context |
AuthContext |
Authentication context. |
Security(oauth2_authentication) |
Returns:
Type | Description |
---|---|
List[zenml.models.run_metadata_models.RunMetadataResponseModel] |
The created run metadata. |
Exceptions:
Type | Description |
---|---|
IllegalOperationError |
If the workspace or user specified in the run metadata does not match the current workspace or authenticated user. |
Source code in zenml/zen_server/routers/workspaces_endpoints.py
@router.post(
    WORKSPACES + "/{workspace_name_or_id}" + RUN_METADATA,
    response_model=List[RunMetadataResponseModel],
    responses={401: error_response, 409: error_response, 422: error_response},
)
@handle_exceptions
def create_run_metadata(
    workspace_name_or_id: Union[str, UUID],
    run_metadata: RunMetadataRequestModel,
    auth_context: AuthContext = Security(
        authorize, scopes=[PermissionType.WRITE]
    ),
) -> List[RunMetadataResponseModel]:
    """Creates run metadata.

    Args:
        workspace_name_or_id: Name or ID of the workspace.
        run_metadata: The run metadata to create.
        auth_context: Authentication context.

    Returns:
        The created run metadata.

    Raises:
        IllegalOperationError: If the workspace or user specified in the run
            metadata does not match the current workspace or authenticated
            user.
    """
    # Resolve the workspace from the URL, not from the request body: looking
    # up `run_metadata.workspace` made the scope check below compare the
    # body's workspace with itself, so it could never fail.
    workspace = zen_store().get_workspace(workspace_name_or_id)
    if run_metadata.workspace != workspace.id:
        raise IllegalOperationError(
            "Creating run metadata outside of the workspace scope "
            f"of this endpoint `{workspace_name_or_id}` is "
            f"not supported."
        )
    if run_metadata.user != auth_context.user.id:
        raise IllegalOperationError(
            "Creating run metadata for a user other than yourself "
            "is not supported."
        )
    return zen_store().create_run_metadata(run_metadata=run_metadata)
create_schedule(workspace_name_or_id, schedule, auth_context=Security(oauth2_authentication))
Creates a schedule.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
workspace_name_or_id |
Union[str, uuid.UUID] |
Name or ID of the workspace. |
required |
schedule |
ScheduleRequestModel |
Schedule to create. |
required |
auth_context |
AuthContext |
Authentication context. |
Security(oauth2_authentication) |
Returns:
Type | Description |
---|---|
ScheduleResponseModel |
The created schedule. |
Exceptions:
Type | Description |
---|---|
IllegalOperationError |
If the workspace or user specified in the schedule does not match the current workspace or authenticated user. |
Source code in zenml/zen_server/routers/workspaces_endpoints.py
@router.post(
    WORKSPACES + "/{workspace_name_or_id}" + SCHEDULES,
    response_model=ScheduleResponseModel,
    responses={401: error_response, 409: error_response, 422: error_response},
)
@handle_exceptions
def create_schedule(
    workspace_name_or_id: Union[str, UUID],
    schedule: ScheduleRequestModel,
    auth_context: AuthContext = Security(
        authorize, scopes=[PermissionType.WRITE]
    ),
) -> ScheduleResponseModel:
    """Creates a schedule.

    Args:
        workspace_name_or_id: Name or ID of the workspace.
        schedule: Schedule to create.
        auth_context: Authentication context.

    Returns:
        The created schedule.

    Raises:
        IllegalOperationError: If the workspace or user specified in the
            schedule does not match the current workspace or authenticated
            user.
    """
    workspace = zen_store().get_workspace(workspace_name_or_id)
    # Error messages previously said "pipeline runs" (copied from the run
    # endpoint); they now name the resource this endpoint creates.
    if schedule.workspace != workspace.id:
        raise IllegalOperationError(
            "Creating schedules outside of the workspace scope "
            f"of this endpoint `{workspace_name_or_id}` is "
            f"not supported."
        )
    if schedule.user != auth_context.user.id:
        raise IllegalOperationError(
            "Creating schedules for a user other than yourself "
            "is not supported."
        )
    return zen_store().create_schedule(schedule=schedule)
create_secret(workspace_name_or_id, secret, auth_context=Security(oauth2_authentication))
Creates a secret.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
workspace_name_or_id |
Union[str, uuid.UUID] |
Name or ID of the workspace. |
required |
secret |
SecretRequestModel |
Secret to create. |
required |
auth_context |
AuthContext |
Authentication context. |
Security(oauth2_authentication) |
Returns:
Type | Description |
---|---|
SecretResponseModel |
The created secret. |
Exceptions:
Type | Description |
---|---|
IllegalOperationError |
If the workspace or user specified in the secret does not match the current workspace or authenticated user. |
Source code in zenml/zen_server/routers/workspaces_endpoints.py
@router.post(
    WORKSPACES + "/{workspace_name_or_id}" + SECRETS,
    response_model=SecretResponseModel,
    responses={401: error_response, 409: error_response, 422: error_response},
)
@handle_exceptions
def create_secret(
    workspace_name_or_id: Union[str, UUID],
    secret: SecretRequestModel,
    auth_context: AuthContext = Security(
        authorize, scopes=[PermissionType.WRITE]
    ),
) -> SecretResponseModel:
    """Store a new secret in the given workspace.

    Args:
        workspace_name_or_id: Name or ID of the workspace the secret must
            belong to.
        secret: The secret request to persist.
        auth_context: Authentication context of the calling user.

    Returns:
        The secret as stored by the zen store.

    Raises:
        IllegalOperationError: If the secret targets a workspace other than
            the one in the URL, or a user other than the caller.
    """
    scope_workspace = zen_store().get_workspace(
        workspace_name_or_id=workspace_name_or_id
    )
    if secret.workspace != scope_workspace.id:
        raise IllegalOperationError(
            "Creating a secret outside of the workspace scope of this "
            f"endpoint `{workspace_name_or_id}` is not supported."
        )
    if secret.user != auth_context.user.id:
        raise IllegalOperationError(
            "Creating secrets for a user other than yourself is not "
            "supported."
        )
    return zen_store().create_secret(secret=secret)
create_service_connector(workspace_name_or_id, connector, auth_context=Security(oauth2_authentication))
Creates a service connector.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
workspace_name_or_id |
Union[str, uuid.UUID] |
Name or ID of the workspace. |
required |
connector |
ServiceConnectorRequestModel |
Service connector to register. |
required |
auth_context |
AuthContext |
Authentication context. |
Security(oauth2_authentication) |
Returns:
Type | Description |
---|---|
ServiceConnectorResponseModel |
The created service connector. |
Exceptions:
Type | Description |
---|---|
IllegalOperationError |
If the workspace or user specified in the service connector does not match the current workspace or authenticated user. |
Source code in zenml/zen_server/routers/workspaces_endpoints.py
@router.post(
    WORKSPACES + "/{workspace_name_or_id}" + SERVICE_CONNECTORS,
    response_model=ServiceConnectorResponseModel,
    responses={401: error_response, 409: error_response, 422: error_response},
)
@handle_exceptions
def create_service_connector(
    workspace_name_or_id: Union[str, UUID],
    connector: ServiceConnectorRequestModel,
    auth_context: AuthContext = Security(
        authorize, scopes=[PermissionType.WRITE]
    ),
) -> ServiceConnectorResponseModel:
    """Register a new service connector in the given workspace.

    Args:
        workspace_name_or_id: Name or ID of the workspace the connector must
            belong to.
        connector: The service connector request to persist.
        auth_context: Authentication context of the calling user.

    Returns:
        The service connector as stored by the zen store.

    Raises:
        IllegalOperationError: If the connector targets a workspace other
            than the one in the URL, or a user other than the caller.
    """
    scope_workspace = zen_store().get_workspace(
        workspace_name_or_id=workspace_name_or_id
    )
    if connector.workspace != scope_workspace.id:
        raise IllegalOperationError(
            "Creating connectors outside of the workspace scope of this "
            f"endpoint `{workspace_name_or_id}` is not supported."
        )
    if connector.user != auth_context.user.id:
        raise IllegalOperationError(
            "Creating connectors for a user other than yourself is not "
            "supported."
        )
    return zen_store().create_service_connector(service_connector=connector)
create_stack(workspace_name_or_id, stack, auth_context=Security(oauth2_authentication))
Creates a stack for a particular workspace.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
workspace_name_or_id |
Union[str, uuid.UUID] |
Name or ID of the workspace. |
required |
stack |
StackRequestModel |
Stack to register. |
required |
auth_context |
AuthContext |
The authentication context. |
Security(oauth2_authentication) |
Returns:
Type | Description |
---|---|
StackResponseModel |
The created stack. |
Exceptions:
Type | Description |
---|---|
IllegalOperationError |
If the workspace or user specified in the stack does not match the current workspace or authenticated user. |
Source code in zenml/zen_server/routers/workspaces_endpoints.py
@router.post(
    WORKSPACES + "/{workspace_name_or_id}" + STACKS,
    response_model=StackResponseModel,
    responses={401: error_response, 409: error_response, 422: error_response},
)
@handle_exceptions
def create_stack(
    workspace_name_or_id: Union[str, UUID],
    stack: StackRequestModel,
    auth_context: AuthContext = Security(
        authorize, scopes=[PermissionType.WRITE]
    ),
) -> StackResponseModel:
    """Register a new stack in the given workspace.

    Args:
        workspace_name_or_id: Name or ID of the workspace the stack must
            belong to.
        stack: The stack request to persist.
        auth_context: The authentication context of the calling user.

    Returns:
        The stack as stored by the zen store.

    Raises:
        IllegalOperationError: If the stack targets a workspace other than
            the one in the URL, or a user other than the caller.
    """
    scope_workspace = zen_store().get_workspace(
        workspace_name_or_id=workspace_name_or_id
    )
    if stack.workspace != scope_workspace.id:
        raise IllegalOperationError(
            "Creating stacks outside of the workspace scope of this "
            f"endpoint `{workspace_name_or_id}` is not supported."
        )
    if stack.user != auth_context.user.id:
        raise IllegalOperationError(
            "Creating stacks for a user other than yourself is not "
            "supported."
        )
    return zen_store().create_stack(stack=stack)
create_stack_component(workspace_name_or_id, component, auth_context=Security(oauth2_authentication))
Creates a stack component.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
workspace_name_or_id |
Union[str, uuid.UUID] |
Name or ID of the workspace. |
required |
component |
ComponentRequestModel |
Stack component to register. |
required |
auth_context |
AuthContext |
Authentication context. |
Security(oauth2_authentication) |
Returns:
Type | Description |
---|---|
ComponentResponseModel |
The created stack component. |
Exceptions:
Type | Description |
---|---|
IllegalOperationError |
If the workspace or user specified in the stack component does not match the current workspace or authenticated user. |
Source code in zenml/zen_server/routers/workspaces_endpoints.py
@router.post(
    WORKSPACES + "/{workspace_name_or_id}" + STACK_COMPONENTS,
    response_model=ComponentResponseModel,
    responses={401: error_response, 409: error_response, 422: error_response},
)
@handle_exceptions
def create_stack_component(
    workspace_name_or_id: Union[str, UUID],
    component: ComponentRequestModel,
    auth_context: AuthContext = Security(
        authorize, scopes=[PermissionType.WRITE]
    ),
) -> ComponentResponseModel:
    """Register a new stack component in the given workspace.

    Args:
        workspace_name_or_id: Name or ID of the workspace the component must
            belong to.
        component: The stack component request to persist.
        auth_context: Authentication context of the calling user.

    Returns:
        The stack component as stored by the zen store.

    Raises:
        IllegalOperationError: If the component targets a workspace other
            than the one in the URL, or a user other than the caller.
    """
    scope_workspace = zen_store().get_workspace(
        workspace_name_or_id=workspace_name_or_id
    )
    if component.workspace != scope_workspace.id:
        raise IllegalOperationError(
            "Creating components outside of the workspace scope of this "
            f"endpoint `{workspace_name_or_id}` is not supported."
        )
    if component.user != auth_context.user.id:
        raise IllegalOperationError(
            "Creating components for a user other than yourself is not "
            "supported."
        )
    # TODO: [server] if possible it should validate here that the
    # configuration conforms to the flavor
    return zen_store().create_stack_component(component=component)
create_workspace(workspace, _=Security(oauth2_authentication))
Creates a workspace based on the requestBody.
noqa: DAR401
Parameters:
Name | Type | Description | Default |
---|---|---|---|
workspace |
WorkspaceRequestModel |
Workspace to create. |
required |
Returns:
Type | Description |
---|---|
WorkspaceResponseModel |
The created workspace. |
Source code in zenml/zen_server/routers/workspaces_endpoints.py
@router.post(
    WORKSPACES,
    response_model=WorkspaceResponseModel,
    responses={401: error_response, 409: error_response, 422: error_response},
)
@handle_exceptions
def create_workspace(
    workspace: WorkspaceRequestModel,
    _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]),
) -> WorkspaceResponseModel:
    """Create a new workspace from the request body.

    # noqa: DAR401

    Args:
        workspace: The workspace request to persist.

    Returns:
        The workspace as stored by the zen store.
    """
    created = zen_store().create_workspace(workspace=workspace)
    return created
delete_workspace(workspace_name_or_id, _=Security(oauth2_authentication))
Deletes a workspace.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
workspace_name_or_id |
Union[str, uuid.UUID] |
Name or ID of the workspace. |
required |
Source code in zenml/zen_server/routers/workspaces_endpoints.py
@router.delete(
    WORKSPACES + "/{workspace_name_or_id}",
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def delete_workspace(
    workspace_name_or_id: Union[str, UUID],
    _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]),
) -> None:
    """Remove the workspace identified by name or ID.

    Args:
        workspace_name_or_id: Name or ID of the workspace to remove.
    """
    store = zen_store()
    store.delete_workspace(workspace_name_or_id=workspace_name_or_id)
get_or_create_pipeline_run(workspace_name_or_id, pipeline_run, auth_context=Security(oauth2_authentication))
Get or create a pipeline run.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
workspace_name_or_id |
Union[str, uuid.UUID] |
Name or ID of the workspace. |
required |
pipeline_run |
PipelineRunRequestModel |
Pipeline run to create. |
required |
auth_context |
AuthContext |
Authentication context. |
Security(oauth2_authentication) |
Returns:
Type | Description |
---|---|
Tuple[zenml.models.pipeline_run_models.PipelineRunResponseModel, bool] |
The pipeline run and a boolean indicating whether the run was created or not. |
Exceptions:
Type | Description |
---|---|
IllegalOperationError |
If the workspace or user specified in the pipeline run does not match the current workspace or authenticated user. |
Source code in zenml/zen_server/routers/workspaces_endpoints.py
@router.post(
    WORKSPACES + "/{workspace_name_or_id}" + RUNS + GET_OR_CREATE,
    response_model=Tuple[PipelineRunResponseModel, bool],
    responses={401: error_response, 409: error_response, 422: error_response},
)
@handle_exceptions
def get_or_create_pipeline_run(
    workspace_name_or_id: Union[str, UUID],
    pipeline_run: PipelineRunRequestModel,
    auth_context: AuthContext = Security(
        authorize, scopes=[PermissionType.WRITE]
    ),
) -> Tuple[PipelineRunResponseModel, bool]:
    """Return a matching pipeline run, creating it first if necessary.

    Args:
        workspace_name_or_id: Name or ID of the workspace the run must
            belong to.
        pipeline_run: The pipeline run request to look up or persist.
        auth_context: Authentication context of the calling user.

    Returns:
        A `(run, created)` pair where `created` tells whether the run was
        newly created by this call.

    Raises:
        IllegalOperationError: If the run targets a workspace other than the
            one in the URL, or a user other than the caller.
    """
    target_workspace = zen_store().get_workspace(
        workspace_name_or_id=workspace_name_or_id
    )
    if pipeline_run.workspace != target_workspace.id:
        raise IllegalOperationError(
            "Creating pipeline runs outside of the workspace scope of this "
            f"endpoint `{workspace_name_or_id}` is not supported."
        )
    if pipeline_run.user != auth_context.user.id:
        raise IllegalOperationError(
            "Creating pipeline runs for a user other than yourself is not "
            "supported."
        )
    return zen_store().get_or_create_run(pipeline_run=pipeline_run)
get_workspace(workspace_name_or_id, _=Security(oauth2_authentication))
Get a workspace for given name.
noqa: DAR401
Parameters:
Name | Type | Description | Default |
---|---|---|---|
workspace_name_or_id |
Union[str, uuid.UUID] |
Name or ID of the workspace. |
required |
Returns:
Type | Description |
---|---|
WorkspaceResponseModel |
The requested workspace. |
Source code in zenml/zen_server/routers/workspaces_endpoints.py
@router.get(
    WORKSPACES + "/{workspace_name_or_id}",
    response_model=WorkspaceResponseModel,
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def get_workspace(
    workspace_name_or_id: Union[str, UUID],
    _: AuthContext = Security(authorize, scopes=[PermissionType.READ]),
) -> WorkspaceResponseModel:
    """Fetch a single workspace by name or ID.

    # noqa: DAR401

    Args:
        workspace_name_or_id: Name or ID of the workspace to fetch.

    Returns:
        The matching workspace.
    """
    found = zen_store().get_workspace(
        workspace_name_or_id=workspace_name_or_id
    )
    return found
get_workspace_statistics(workspace_name_or_id, _=Security(oauth2_authentication))
Gets statistics of a workspace.
noqa: DAR401
Parameters:
Name | Type | Description | Default |
---|---|---|---|
workspace_name_or_id |
Union[str, uuid.UUID] |
Name or ID of the workspace to get statistics for. |
required |
Returns:
Type | Description |
---|---|
Dict[str, int] |
The number of stacks, stack components, pipelines and pipeline runs in the workspace. |
Source code in zenml/zen_server/routers/workspaces_endpoints.py
@router.get(
    WORKSPACES + "/{workspace_name_or_id}" + STATISTICS,
    # NOTE(review): the handler produces integer counts, but the declared
    # response model is Dict[str, str], so values are presumably coerced to
    # strings on the wire. Left unchanged to avoid breaking existing
    # clients; consider Dict[str, int] in a versioned API change.
    response_model=Dict[str, str],
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def get_workspace_statistics(
    workspace_name_or_id: Union[str, UUID],
    _: AuthContext = Security(authorize, scopes=[PermissionType.READ]),
) -> Dict[str, int]:
    """Gets statistics of a workspace.

    # noqa: DAR401

    Args:
        workspace_name_or_id: Name or ID of the workspace to get statistics
            for.

    Returns:
        A mapping with the number of `stacks`, `components`, `pipelines` and
        `runs` in the workspace.
    """
    store = zen_store()
    workspace_id = store.get_workspace(workspace_name_or_id).id
    return {
        "stacks": store.count_stacks(workspace_id=workspace_id),
        "components": store.count_stack_components(workspace_id=workspace_id),
        "pipelines": store.count_pipelines(workspace_id=workspace_id),
        "runs": store.count_runs(workspace_id=workspace_id),
    }
list_runs(workspace_name_or_id, runs_filter_model=Depends(init_cls_and_handle_errors), _=Security(oauth2_authentication))
Get pipeline runs according to query filters.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
workspace_name_or_id |
Union[str, uuid.UUID] |
Name or ID of the workspace. |
required |
runs_filter_model |
PipelineRunFilterModel |
Filter model used for pagination, sorting and filtering. |
Depends(init_cls_and_handle_errors) |
Returns:
Type | Description |
---|---|
Page[PipelineRunResponseModel] |
The pipeline runs according to query filters. |
Source code in zenml/zen_server/routers/workspaces_endpoints.py
@router.get(
    WORKSPACES + "/{workspace_name_or_id}" + RUNS,
    response_model=Page[PipelineRunResponseModel],
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def list_runs(
    workspace_name_or_id: Union[str, UUID],
    runs_filter_model: PipelineRunFilterModel = Depends(
        make_dependable(PipelineRunFilterModel)
    ),
    _: AuthContext = Security(authorize, scopes=[PermissionType.READ]),
) -> Page[PipelineRunResponseModel]:
    """Get pipeline runs according to query filters.

    The filter is restricted to the requested workspace before the query is
    executed.

    Args:
        workspace_name_or_id: Name or ID of the workspace.
        runs_filter_model: Filter model used for pagination, sorting,
            filtering

    Returns:
        The pipeline runs according to query filters.
    """
    store = zen_store()
    scoped_workspace_id = store.get_workspace(workspace_name_or_id).id
    runs_filter_model.set_scope_workspace(scoped_workspace_id)
    return store.list_runs(runs_filter_model=runs_filter_model)
list_service_connector_resources(workspace_name_or_id, connector_type=None, resource_type=None, resource_id=None, auth_context=Security(oauth2_authentication))
List resources that can be accessed by service connectors.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
workspace_name_or_id |
Union[str, uuid.UUID] |
Name or ID of the workspace. |
required |
connector_type |
Optional[str] |
the service connector type identifier to filter by. |
None |
resource_type |
Optional[str] |
the resource type identifier to filter by. |
None |
resource_id |
Optional[str] |
the resource identifier to filter by. |
None |
auth_context |
AuthContext |
Authentication context. |
Security(oauth2_authentication) |
Returns:
Type | Description |
---|---|
List[zenml.models.service_connector_models.ServiceConnectorResourcesModel] |
The matching list of resources that available service connectors have access to. |
Source code in zenml/zen_server/routers/workspaces_endpoints.py
@router.get(
    WORKSPACES
    + "/{workspace_name_or_id}"
    + SERVICE_CONNECTORS
    + SERVICE_CONNECTOR_RESOURCES,
    response_model=List[ServiceConnectorResourcesModel],
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def list_service_connector_resources(
    workspace_name_or_id: Union[str, UUID],
    connector_type: Optional[str] = None,
    resource_type: Optional[str] = None,
    resource_id: Optional[str] = None,
    auth_context: AuthContext = Security(
        authorize, scopes=[PermissionType.READ]
    ),
) -> List[ServiceConnectorResourcesModel]:
    """List resources that can be accessed by service connectors.

    The lookup is performed on behalf of the authenticated user. The
    optional arguments narrow the result down further.

    Args:
        workspace_name_or_id: Name or ID of the workspace.
        connector_type: the service connector type identifier to filter by.
        resource_type: the resource type identifier to filter by.
        resource_id: the resource identifier to filter by.
        auth_context: Authentication context.

    Returns:
        The matching list of resources that available service
        connectors have access to.
    """
    optional_filters = dict(
        connector_type=connector_type,
        resource_type=resource_type,
        resource_id=resource_id,
    )
    return zen_store().list_service_connector_resources(
        user_name_or_id=auth_context.user.id,
        workspace_name_or_id=workspace_name_or_id,
        **optional_filters,
    )
list_team_role_assignments_for_workspace(workspace_name_or_id, team_role_assignment_filter_model=Depends(init_cls_and_handle_errors), _=Security(oauth2_authentication))
Returns a list of all roles that are assigned to a team.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
workspace_name_or_id |
Union[str, uuid.UUID] |
Name or ID of the workspace. |
required |
team_role_assignment_filter_model |
TeamRoleAssignmentFilterModel |
Filter model used for pagination, sorting, filtering |
Depends(init_cls_and_handle_errors) |
Returns:
Type | Description |
---|---|
Page[TeamRoleAssignmentResponseModel] |
A list of all roles that are assigned to a team. |
Source code in zenml/zen_server/routers/workspaces_endpoints.py
@router.get(
    WORKSPACES + "/{workspace_name_or_id}" + TEAM_ROLE_ASSIGNMENTS,
    response_model=Page[TeamRoleAssignmentResponseModel],
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def list_team_role_assignments_for_workspace(
    workspace_name_or_id: Union[str, UUID],
    team_role_assignment_filter_model: TeamRoleAssignmentFilterModel = Depends(
        make_dependable(TeamRoleAssignmentFilterModel)
    ),
    _: AuthContext = Security(authorize, scopes=[PermissionType.READ]),
) -> Page[TeamRoleAssignmentResponseModel]:
    """Returns a list of all roles that are assigned to a team.

    The filter's `workspace_id` is overwritten with the ID of the requested
    workspace before the query runs.

    Args:
        workspace_name_or_id: Name or ID of the workspace.
        team_role_assignment_filter_model: Filter model used for pagination,
            sorting, filtering

    Returns:
        A list of all roles that are assigned to a team.
    """
    store = zen_store()
    team_role_assignment_filter_model.workspace_id = store.get_workspace(
        workspace_name_or_id
    ).id
    return store.list_team_role_assignments(
        team_role_assignment_filter_model=team_role_assignment_filter_model
    )
list_user_role_assignments_for_workspace(workspace_name_or_id, user_role_assignment_filter_model=Depends(init_cls_and_handle_errors), _=Security(oauth2_authentication))
Returns a list of all roles that are assigned to users in a workspace.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
workspace_name_or_id |
Union[str, uuid.UUID] |
Name or ID of the workspace. |
required |
user_role_assignment_filter_model |
UserRoleAssignmentFilterModel |
Filter model used for pagination, sorting, filtering |
Depends(init_cls_and_handle_errors) |
Returns:
Type | Description |
---|---|
Page[UserRoleAssignmentResponseModel] |
A list of all roles that are assigned to users in the workspace. |
Source code in zenml/zen_server/routers/workspaces_endpoints.py
@router.get(
    WORKSPACES + "/{workspace_name_or_id}" + USER_ROLE_ASSIGNMENTS,
    response_model=Page[UserRoleAssignmentResponseModel],
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def list_user_role_assignments_for_workspace(
    workspace_name_or_id: Union[str, UUID],
    user_role_assignment_filter_model: UserRoleAssignmentFilterModel = Depends(
        make_dependable(UserRoleAssignmentFilterModel)
    ),
    _: AuthContext = Security(authorize, scopes=[PermissionType.READ]),
) -> Page[UserRoleAssignmentResponseModel]:
    """Returns a list of all roles that are assigned to users in a workspace.

    Args:
        workspace_name_or_id: Name or ID of the workspace.
        user_role_assignment_filter_model: Filter model used for pagination,
            sorting, filtering

    Returns:
        A list of all user role assignments within the workspace.
    """
    workspace = zen_store().get_workspace(workspace_name_or_id)
    # Force the filter onto the requested workspace, overriding any
    # `workspace_id` the client may have supplied as a query parameter.
    user_role_assignment_filter_model.workspace_id = workspace.id
    return zen_store().list_user_role_assignments(
        user_role_assignment_filter_model=user_role_assignment_filter_model
    )
list_workspace_builds(workspace_name_or_id, build_filter_model=Depends(init_cls_and_handle_errors), _=Security(oauth2_authentication))
Gets builds defined for a specific workspace.
noqa: DAR401
Parameters:
Name | Type | Description | Default |
---|---|---|---|
workspace_name_or_id |
Union[str, uuid.UUID] |
Name or ID of the workspace. |
required |
build_filter_model |
PipelineBuildFilterModel |
Filter model used for pagination, sorting, filtering |
Depends(init_cls_and_handle_errors) |
Returns:
Type | Description |
---|---|
Page[PipelineBuildResponseModel] |
All builds within the workspace. |
Source code in zenml/zen_server/routers/workspaces_endpoints.py
@router.get(
    WORKSPACES + "/{workspace_name_or_id}" + PIPELINE_BUILDS,
    response_model=Page[PipelineBuildResponseModel],
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def list_workspace_builds(
    workspace_name_or_id: Union[str, UUID],
    build_filter_model: PipelineBuildFilterModel = Depends(
        make_dependable(PipelineBuildFilterModel)
    ),
    _: AuthContext = Security(authorize, scopes=[PermissionType.READ]),
) -> Page[PipelineBuildResponseModel]:
    """Gets builds defined for a specific workspace.

    # noqa: DAR401

    Args:
        workspace_name_or_id: Name or ID of the workspace.
        build_filter_model: Filter model used for pagination, sorting,
            filtering

    Returns:
        All builds within the workspace.
    """
    store = zen_store()
    build_filter_model.set_scope_workspace(
        store.get_workspace(workspace_name_or_id).id
    )
    return store.list_builds(build_filter_model=build_filter_model)
list_workspace_code_repositories(workspace_name_or_id, filter_model=Depends(init_cls_and_handle_errors), _=Security(oauth2_authentication))
Gets code repositories defined for a specific workspace.
noqa: DAR401
Parameters:
Name | Type | Description | Default |
---|---|---|---|
workspace_name_or_id |
Union[str, uuid.UUID] |
Name or ID of the workspace. |
required |
filter_model |
CodeRepositoryFilterModel |
Filter model used for pagination, sorting, filtering |
Depends(init_cls_and_handle_errors) |
Returns:
Type | Description |
---|---|
Page[CodeRepositoryResponseModel] |
All code repositories within the workspace. |
Source code in zenml/zen_server/routers/workspaces_endpoints.py
@router.get(
    WORKSPACES + "/{workspace_name_or_id}" + CODE_REPOSITORIES,
    response_model=Page[CodeRepositoryResponseModel],
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def list_workspace_code_repositories(
    workspace_name_or_id: Union[str, UUID],
    filter_model: CodeRepositoryFilterModel = Depends(
        make_dependable(CodeRepositoryFilterModel)
    ),
    _: AuthContext = Security(authorize, scopes=[PermissionType.READ]),
) -> Page[CodeRepositoryResponseModel]:
    """Gets code repositories defined for a specific workspace.

    # noqa: DAR401

    Args:
        workspace_name_or_id: Name or ID of the workspace.
        filter_model: Filter model used for pagination, sorting,
            filtering

    Returns:
        All code repositories within the workspace.
    """
    store = zen_store()
    target_workspace = store.get_workspace(workspace_name_or_id)
    filter_model.set_scope_workspace(target_workspace.id)
    return store.list_code_repositories(filter_model=filter_model)
list_workspace_deployments(workspace_name_or_id, deployment_filter_model=Depends(init_cls_and_handle_errors), _=Security(oauth2_authentication))
Gets deployments defined for a specific workspace.
noqa: DAR401
Parameters:
Name | Type | Description | Default |
---|---|---|---|
workspace_name_or_id |
Union[str, uuid.UUID] |
Name or ID of the workspace. |
required |
deployment_filter_model |
PipelineDeploymentFilterModel |
Filter model used for pagination, sorting, filtering |
Depends(init_cls_and_handle_errors) |
Returns:
Type | Description |
---|---|
Page[PipelineDeploymentResponseModel] |
All deployments within the workspace. |
Source code in zenml/zen_server/routers/workspaces_endpoints.py
@router.get(
    WORKSPACES + "/{workspace_name_or_id}" + PIPELINE_DEPLOYMENTS,
    response_model=Page[PipelineDeploymentResponseModel],
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def list_workspace_deployments(
    workspace_name_or_id: Union[str, UUID],
    deployment_filter_model: PipelineDeploymentFilterModel = Depends(
        make_dependable(PipelineDeploymentFilterModel)
    ),
    _: AuthContext = Security(authorize, scopes=[PermissionType.READ]),
) -> Page[PipelineDeploymentResponseModel]:
    """Gets deployments defined for a specific workspace.

    # noqa: DAR401

    Args:
        workspace_name_or_id: Name or ID of the workspace.
        deployment_filter_model: Filter model used for pagination, sorting,
            filtering

    Returns:
        All deployments within the workspace.
    """
    store = zen_store()
    workspace_id = store.get_workspace(workspace_name_or_id).id
    deployment_filter_model.set_scope_workspace(workspace_id)
    deployments = store.list_deployments(
        deployment_filter_model=deployment_filter_model
    )
    return deployments
list_workspace_model_version_artifact_links(workspace_name_or_id, model_name_or_id, model_version_name_or_id, model_version_artifact_link_filter_model=Depends(init_cls_and_handle_errors), _=Security(oauth2_authentication))
Get model version to artifact links according to query filters.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model_name_or_id |
Union[str, uuid.UUID] |
Name or ID of the model. |
required |
workspace_name_or_id |
Union[str, uuid.UUID] |
Name or ID of the workspace. |
required |
model_version_name_or_id |
Union[str, uuid.UUID] |
Name or ID of the model version. |
required |
model_version_artifact_link_filter_model |
ModelVersionArtifactFilterModel |
Filter model used for pagination, sorting, filtering |
Depends(init_cls_and_handle_errors) |
Returns:
Type | Description |
---|---|
Page[ModelVersionArtifactResponseModel] |
The model version to artifact links according to query filters. |
Source code in zenml/zen_server/routers/workspaces_endpoints.py
@router.get(
    WORKSPACES
    + "/{workspace_name_or_id}"
    + MODEL_VERSIONS
    + "/{model_version_name_or_id}"
    + ARTIFACTS,
    response_model=Page[ModelVersionArtifactResponseModel],
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def list_workspace_model_version_artifact_links(
    workspace_name_or_id: Union[str, UUID],
    model_name_or_id: Union[str, UUID],
    model_version_name_or_id: Union[str, UUID],
    model_version_artifact_link_filter_model: ModelVersionArtifactFilterModel = Depends(
        make_dependable(ModelVersionArtifactFilterModel)
    ),
    _: AuthContext = Security(authorize, scopes=[PermissionType.READ]),
) -> Page[ModelVersionArtifactResponseModel]:
    """Get model version to artifact links according to query filters.

    Args:
        model_name_or_id: Name or ID of the model.
        workspace_name_or_id: Name or ID of the workspace.
        model_version_name_or_id: Name or ID of the model version.
        model_version_artifact_link_filter_model: Filter model used for
            pagination, sorting, filtering

    Returns:
        The model version to artifact links according to query filters.
    """
    # NOTE(review): `model_name_or_id` and `model_version_name_or_id` are
    # accepted but never used below, so the query is only scoped to the
    # workspace (plus whatever the filter model itself carries). Also,
    # `{model_name_or_id}` is not part of the route path, so FastAPI treats it
    # as a required query parameter. Confirm whether the links should be
    # scoped to the given model version.
    store = zen_store()
    model_version_artifact_link_filter_model.set_scope_workspace(
        store.get_workspace(workspace_name_or_id).id
    )
    return store.list_model_version_artifact_links(
        model_version_artifact_link_filter_model=model_version_artifact_link_filter_model,
    )
list_workspace_model_version_pipeline_run_links(workspace_name_or_id, model_name_or_id, model_version_name_or_id, model_version_pipeline_run_link_filter_model=Depends(init_cls_and_handle_errors), _=Security(oauth2_authentication))
Get model version to pipeline links according to query filters.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model_name_or_id |
Union[str, uuid.UUID] |
Name or ID of the model. |
required |
workspace_name_or_id |
Union[str, uuid.UUID] |
Name or ID of the workspace. |
required |
model_version_name_or_id |
Union[str, uuid.UUID] |
Name or ID of the model version. |
required |
model_version_pipeline_run_link_filter_model |
ModelVersionPipelineRunFilterModel |
Filter model used for pagination, sorting, filtering |
Depends(init_cls_and_handle_errors) |
Returns:
Type | Description |
---|---|
Page[ModelVersionPipelineRunResponseModel] |
The model version to pipeline run links according to query filters. |
Source code in zenml/zen_server/routers/workspaces_endpoints.py
@router.get(
    WORKSPACES
    + "/{workspace_name_or_id}"
    + MODEL_VERSIONS
    + "/{model_version_name_or_id}"
    + RUNS,
    response_model=Page[ModelVersionPipelineRunResponseModel],
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def list_workspace_model_version_pipeline_run_links(
    workspace_name_or_id: Union[str, UUID],
    model_name_or_id: Union[str, UUID],
    model_version_name_or_id: Union[str, UUID],
    model_version_pipeline_run_link_filter_model: ModelVersionPipelineRunFilterModel = Depends(
        make_dependable(ModelVersionPipelineRunFilterModel)
    ),
    _: AuthContext = Security(authorize, scopes=[PermissionType.READ]),
) -> Page[ModelVersionPipelineRunResponseModel]:
    """Get model version to pipeline links according to query filters.

    Args:
        model_name_or_id: Name or ID of the model.
        workspace_name_or_id: Name or ID of the workspace.
        model_version_name_or_id: Name or ID of the model version.
        model_version_pipeline_run_link_filter_model: Filter model used for
            pagination, sorting, filtering

    Returns:
        The model version to pipeline run links according to query filters.
    """
    # NOTE(review): `model_name_or_id` and `model_version_name_or_id` are
    # accepted but never used below, so the query is only scoped to the
    # workspace (plus whatever the filter model itself carries). Also,
    # `{model_name_or_id}` is not part of the route path, so FastAPI treats it
    # as a required query parameter. Confirm whether the links should be
    # scoped to the given model version.
    store = zen_store()
    scoped_workspace_id = store.get_workspace(workspace_name_or_id).id
    model_version_pipeline_run_link_filter_model.set_scope_workspace(
        scoped_workspace_id
    )
    return store.list_model_version_pipeline_run_links(
        model_version_pipeline_run_link_filter_model=model_version_pipeline_run_link_filter_model,
    )
list_workspace_model_versions(workspace_name_or_id, model_version_filter_model=Depends(init_cls_and_handle_errors), _=Security(oauth2_authentication))
Get model versions according to query filters.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
workspace_name_or_id |
Union[str, uuid.UUID] |
Name or ID of the workspace. |
required |
model_version_filter_model |
ModelVersionFilterModel |
Filter model used for pagination, sorting, filtering |
Depends(init_cls_and_handle_errors) |
Returns:
Type | Description |
---|---|
Page[ModelVersionResponseModel] |
The model versions according to query filters. |
Source code in zenml/zen_server/routers/workspaces_endpoints.py
@router.get(
    WORKSPACES + "/{workspace_name_or_id}" + MODEL_VERSIONS,
    response_model=Page[ModelVersionResponseModel],
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def list_workspace_model_versions(
    workspace_name_or_id: Union[str, UUID],
    model_version_filter_model: ModelVersionFilterModel = Depends(
        make_dependable(ModelVersionFilterModel)
    ),
    _: AuthContext = Security(authorize, scopes=[PermissionType.READ]),
) -> Page[ModelVersionResponseModel]:
    """Get model versions according to query filters.

    Args:
        workspace_name_or_id: Name or ID of the workspace.
        model_version_filter_model: Filter model used for pagination, sorting,
            filtering

    Returns:
        The model versions according to query filters.
    """
    store = zen_store()
    model_version_filter_model.set_scope_workspace(
        store.get_workspace(workspace_name_or_id).id
    )
    return store.list_model_versions(
        model_version_filter_model=model_version_filter_model,
    )
list_workspace_models(workspace_name_or_id, model_filter_model=Depends(init_cls_and_handle_errors), _=Security(oauth2_authentication))
Get models according to query filters.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
workspace_name_or_id |
Union[str, uuid.UUID] |
Name or ID of the workspace. |
required |
model_filter_model |
ModelFilterModel |
Filter model used for pagination, sorting, filtering |
Depends(init_cls_and_handle_errors) |
Returns:
Type | Description |
---|---|
Page[ModelResponseModel] |
The models according to query filters. |
Source code in zenml/zen_server/routers/workspaces_endpoints.py
@router.get(
    WORKSPACES + "/{workspace_name_or_id}" + MODELS,
    response_model=Page[ModelResponseModel],
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def list_workspace_models(
    workspace_name_or_id: Union[str, UUID],
    model_filter_model: ModelFilterModel = Depends(
        make_dependable(ModelFilterModel)
    ),
    _: AuthContext = Security(authorize, scopes=[PermissionType.READ]),
) -> Page[ModelResponseModel]:
    """Get models according to query filters.

    Args:
        workspace_name_or_id: Name or ID of the workspace.
        model_filter_model: Filter model used for pagination, sorting,
            filtering

    Returns:
        The models according to query filters.
    """
    store = zen_store()
    target_workspace = store.get_workspace(workspace_name_or_id)
    model_filter_model.set_scope_workspace(target_workspace.id)
    models_page = store.list_models(
        model_filter_model=model_filter_model,
    )
    return models_page
list_workspace_pipelines(workspace_name_or_id, pipeline_filter_model=Depends(init_cls_and_handle_errors), _=Security(oauth2_authentication))
Gets pipelines defined for a specific workspace.
noqa: DAR401
Parameters:
Name | Type | Description | Default |
---|---|---|---|
workspace_name_or_id |
Union[str, uuid.UUID] |
Name or ID of the workspace. |
required |
pipeline_filter_model |
PipelineFilterModel |
Filter model used for pagination, sorting, filtering |
Depends(init_cls_and_handle_errors) |
Returns:
Type | Description |
---|---|
Page[PipelineResponseModel] |
All pipelines within the workspace. |
Source code in zenml/zen_server/routers/workspaces_endpoints.py
@router.get(
    WORKSPACES + "/{workspace_name_or_id}" + PIPELINES,
    response_model=Page[PipelineResponseModel],
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def list_workspace_pipelines(
    workspace_name_or_id: Union[str, UUID],
    pipeline_filter_model: PipelineFilterModel = Depends(
        make_dependable(PipelineFilterModel)
    ),
    _: AuthContext = Security(authorize, scopes=[PermissionType.READ]),
) -> Page[PipelineResponseModel]:
    """Gets pipelines defined for a specific workspace.

    # noqa: DAR401

    Args:
        workspace_name_or_id: Name or ID of the workspace.
        pipeline_filter_model: Filter model used for pagination, sorting,
            filtering

    Returns:
        All pipelines within the workspace.
    """
    store = zen_store()
    pipeline_filter_model.set_scope_workspace(
        store.get_workspace(workspace_name_or_id).id
    )
    return store.list_pipelines(pipeline_filter_model=pipeline_filter_model)
list_workspace_service_connectors(workspace_name_or_id, connector_filter_model=Depends(init_cls_and_handle_errors), auth_context=Security(oauth2_authentication))
List service connectors that are part of a specific workspace.
noqa: DAR401
Parameters:
Name | Type | Description | Default |
---|---|---|---|
workspace_name_or_id |
Union[str, uuid.UUID] |
Name or ID of the workspace. |
required |
connector_filter_model |
ServiceConnectorFilterModel |
Filter model used for pagination, sorting, filtering |
Depends(init_cls_and_handle_errors) |
auth_context |
AuthContext |
Authentication Context |
Security(oauth2_authentication) |
Returns:
Type | Description |
---|---|
Page[ServiceConnectorResponseModel] |
All service connectors part of the specified workspace. |
Source code in zenml/zen_server/routers/workspaces_endpoints.py
@router.get(
    WORKSPACES + "/{workspace_name_or_id}" + SERVICE_CONNECTORS,
    response_model=Page[ServiceConnectorResponseModel],
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def list_workspace_service_connectors(
    workspace_name_or_id: Union[str, UUID],
    connector_filter_model: ServiceConnectorFilterModel = Depends(
        make_dependable(ServiceConnectorFilterModel)
    ),
    auth_context: AuthContext = Security(
        authorize, scopes=[PermissionType.READ]
    ),
) -> Page[ServiceConnectorResponseModel]:
    """List service connectors that are part of a specific workspace.

    The filter is scoped both to the workspace and to the authenticated user
    before querying.

    # noqa: DAR401

    Args:
        workspace_name_or_id: Name or ID of the workspace.
        connector_filter_model: Filter model used for pagination, sorting,
            filtering
        auth_context: Authentication Context

    Returns:
        All service connectors part of the specified workspace.
    """
    store = zen_store()
    workspace_id = store.get_workspace(workspace_name_or_id).id
    connector_filter_model.set_scope_workspace(workspace_id)
    connector_filter_model.set_scope_user(user_id=auth_context.user.id)
    return store.list_service_connectors(filter_model=connector_filter_model)
list_workspace_stack_components(workspace_name_or_id, component_filter_model=Depends(init_cls_and_handle_errors), auth_context=Security(oauth2_authentication))
List stack components that are part of a specific workspace.
noqa: DAR401
Parameters:
Name | Type | Description | Default |
---|---|---|---|
workspace_name_or_id |
Union[str, uuid.UUID] |
Name or ID of the workspace. |
required |
component_filter_model |
ComponentFilterModel |
Filter model used for pagination, sorting, filtering |
Depends(init_cls_and_handle_errors) |
auth_context |
AuthContext |
Authentication Context |
Security(oauth2_authentication) |
Returns:
Type | Description |
---|---|
Page[ComponentResponseModel] |
All stack components part of the specified workspace. |
Source code in zenml/zen_server/routers/workspaces_endpoints.py
@router.get(
    WORKSPACES + "/{workspace_name_or_id}" + STACK_COMPONENTS,
    response_model=Page[ComponentResponseModel],
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def list_workspace_stack_components(
    workspace_name_or_id: Union[str, UUID],
    component_filter_model: ComponentFilterModel = Depends(
        make_dependable(ComponentFilterModel)
    ),
    auth_context: AuthContext = Security(
        authorize, scopes=[PermissionType.READ]
    ),
) -> Page[ComponentResponseModel]:
    """List stack components that are part of a specific workspace.

    The filter is scoped both to the workspace and to the authenticated user
    before querying.

    # noqa: DAR401

    Args:
        workspace_name_or_id: Name or ID of the workspace.
        component_filter_model: Filter model used for pagination, sorting,
            filtering
        auth_context: Authentication Context

    Returns:
        All stack components part of the specified workspace.
    """
    store = zen_store()
    target_workspace = store.get_workspace(workspace_name_or_id)
    component_filter_model.set_scope_workspace(target_workspace.id)
    component_filter_model.set_scope_user(user_id=auth_context.user.id)
    components_page = store.list_stack_components(
        component_filter_model=component_filter_model
    )
    return components_page
list_workspace_stacks(workspace_name_or_id, stack_filter_model=Depends(init_cls_and_handle_errors), auth_context=Security(oauth2_authentication))
Get stacks that are part of a specific workspace for the user.
noqa: DAR401
Parameters:
Name | Type | Description | Default |
---|---|---|---|
workspace_name_or_id |
Union[str, uuid.UUID] |
Name or ID of the workspace. |
required |
stack_filter_model |
StackFilterModel |
Filter model used for pagination, sorting, filtering |
Depends(init_cls_and_handle_errors) |
auth_context |
AuthContext |
Authentication Context |
Security(oauth2_authentication) |
Returns:
Type | Description |
---|---|
Page[StackResponseModel] |
All stacks part of the specified workspace. |
Source code in zenml/zen_server/routers/workspaces_endpoints.py
@router.get(
    WORKSPACES + "/{workspace_name_or_id}" + STACKS,
    response_model=Page[StackResponseModel],
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def list_workspace_stacks(
    workspace_name_or_id: Union[str, UUID],
    stack_filter_model: StackFilterModel = Depends(
        make_dependable(StackFilterModel)
    ),
    auth_context: AuthContext = Security(
        authorize, scopes=[PermissionType.READ]
    ),
) -> Page[StackResponseModel]:
    """Get stacks that are part of a specific workspace for the user.

    The filter is scoped both to the workspace and to the authenticated user
    before querying.

    # noqa: DAR401

    Args:
        workspace_name_or_id: Name or ID of the workspace.
        stack_filter_model: Filter model used for pagination, sorting, filtering
        auth_context: Authentication Context

    Returns:
        All stacks part of the specified workspace.
    """
    store = zen_store()
    stack_filter_model.set_scope_workspace(
        store.get_workspace(workspace_name_or_id).id
    )
    stack_filter_model.set_scope_user(user_id=auth_context.user.id)
    return store.list_stacks(stack_filter_model=stack_filter_model)
list_workspaces(workspace_filter_model=Depends(init_cls_and_handle_errors), _=Security(oauth2_authentication))
Lists all workspaces in the organization.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
workspace_filter_model |
WorkspaceFilterModel |
Filter model used for pagination, sorting, filtering |
Depends(init_cls_and_handle_errors) |
Returns:
Type | Description |
---|---|
Page[WorkspaceResponseModel] |
A list of workspaces. |
Source code in zenml/zen_server/routers/workspaces_endpoints.py
@router.get(
    WORKSPACES,
    response_model=Page[WorkspaceResponseModel],
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def list_workspaces(
    workspace_filter_model: WorkspaceFilterModel = Depends(
        make_dependable(WorkspaceFilterModel)
    ),
    _: AuthContext = Security(authorize, scopes=[PermissionType.READ]),
) -> Page[WorkspaceResponseModel]:
    """Lists all workspaces in the organization.

    The filter model is passed through to the store unchanged.

    Args:
        workspace_filter_model: Filter model used for pagination, sorting,
            filtering

    Returns:
        A list of workspaces.
    """
    store = zen_store()
    workspaces_page = store.list_workspaces(
        workspace_filter_model=workspace_filter_model
    )
    return workspaces_page
update_workspace(workspace_name_or_id, workspace_update, _=Security(oauth2_authentication))
Updates a workspace.
noqa: DAR401
Parameters:
Name | Type | Description | Default |
---|---|---|---|
workspace_name_or_id |
UUID |
Name or ID of the workspace to update. |
required |
workspace_update |
WorkspaceUpdateModel |
the workspace to use to update |
required |
Returns:
Type | Description |
---|---|
WorkspaceResponseModel |
The updated workspace. |
Source code in zenml/zen_server/routers/workspaces_endpoints.py
@router.put(
    WORKSPACES + "/{workspace_name_or_id}",
    response_model=WorkspaceResponseModel,
    responses={401: error_response, 404: error_response, 422: error_response},
)
@handle_exceptions
def update_workspace(
    workspace_name_or_id: UUID,
    workspace_update: WorkspaceUpdateModel,
    _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]),
) -> WorkspaceResponseModel:
    """Updates a workspace.

    # noqa: DAR401

    Args:
        workspace_name_or_id: ID of the workspace to update. Despite the
            parameter name, only a UUID is accepted: it is annotated as
            `UUID` and forwarded as `workspace_id`.
        workspace_update: the workspace to use to update

    Returns:
        The updated workspace.
    """
    return zen_store().update_workspace(
        workspace_id=workspace_name_or_id,
        workspace_update=workspace_update,
    )
utils
Util functions for the ZenML Server.
get_active_deployment(local=False)
Get the active local or remote server deployment.
Call this function to retrieve the local or remote server deployment that was last provisioned on this machine.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
local |
bool |
Whether to return the local active deployment or the remote one. |
False |
Returns:
Type | Description |
---|---|
Optional[ServerDeployment] |
The local or remote active server deployment or None, if no deployment was found. |
Source code in zenml/zen_server/utils.py
def get_active_deployment(local: bool = False) -> Optional["ServerDeployment"]:
    """Get the active local or remote server deployment.

    Call this function to retrieve the local or remote server deployment that
    was last provisioned on this machine. A deployment counts as "local" when
    its provider is `ServerProviderType.LOCAL` or `ServerProviderType.DOCKER`.
    Any other provider counts as remote.

    Args:
        local: Whether to return the local active deployment or the remote one.

    Returns:
        The local or remote active server deployment or None, if no deployment
        was found.
    """
    from zenml.zen_server.deploy.deployer import ServerDeployer

    deployer = ServerDeployer()
    local_providers = [
        ServerProviderType.LOCAL,
        ServerProviderType.DOCKER,
    ]

    if not local:
        servers = deployer.list_servers()
    else:
        # Prefer plain local deployments and fall back to Docker ones.
        servers = deployer.list_servers(
            provider_type=ServerProviderType.LOCAL
        ) or deployer.list_servers(provider_type=ServerProviderType.DOCKER)

    if not servers:
        return None

    want_local = bool(local)
    for candidate in servers:
        # Return the first deployment whose locality matches the request.
        if (candidate.config.provider in local_providers) == want_local:
            return candidate
    return None
get_active_server_details()
Get the URL of the current ZenML Server.
When multiple servers are present, the following precedence is used to determine which server to use: - If the client is connected to a server, that server has precedence. - If no server is connected, a server that was deployed remotely has precedence over a server that was deployed locally.
Returns:
Type | Description |
---|---|
Tuple[str, Optional[int]] |
The URL and port of the currently active server. |
Exceptions:
Type | Description |
---|---|
RuntimeError |
If no server is active. |
Source code in zenml/zen_server/utils.py
def get_active_server_details() -> Tuple[str, Optional[int]]:
    """Get the URL of the current ZenML Server.

    When multiple servers are present, the following precedence is used to
    determine which server to use:
    - If the client is connected to a server, that server has precedence.
    - If no server is connected, a server that was deployed remotely has
      precedence over a server that was deployed locally.

    Returns:
        The URL and port of the currently active server. The port is only
        known for a connected server (taken from its URL) or for a local
        deployment (taken from its config); otherwise it is None.

    Raises:
        RuntimeError: If no server is active.
    """
    # Check for connected servers first
    gc = GlobalConfiguration()
    if not gc.uses_default_store() and gc.store is not None:
        logger.debug("Getting URL of connected server.")
        parsed_url = urlparse(gc.store.url)
        hostname = parsed_url.hostname
        # `urlparse(...).hostname` strips the brackets from IPv6 literals
        # (e.g. `[::1]` -> `::1`); restore them so the returned string is
        # still a valid URL. Regular hostnames never contain a colon.
        if hostname and ":" in hostname:
            hostname = f"[{hostname}]"
        return f"{parsed_url.scheme}://{hostname}", parsed_url.port

    # Else, check for deployed servers: remote ones take precedence
    server = get_active_deployment(local=False)
    if server:
        logger.debug("Getting URL of remote server.")
    else:
        server = get_active_deployment(local=True)
        logger.debug("Getting URL of local server.")

    if server and server.status and server.status.url:
        if isinstance(server.config, LocalServerDeploymentConfig):
            return server.status.url, server.config.port
        return server.status.url, None

    raise RuntimeError(
        "ZenML is not connected to any server right now. Please use "
        "`zenml connect` to connect to a server or spin up a new local server "
        "via `zenml up`."
    )
get_ip_location(ip_address)
Get the location of the given IP address.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
ip_address |
str |
The IP address to get the location for. |
required |
Returns:
Type | Description |
---|---|
Tuple[str, str, str] |
A tuple of city, region, country. |
Source code in zenml/zen_server/utils.py
def get_ip_location(ip_address: str) -> Tuple[str, str, str]:
    """Get the location of the given IP address.

    The lookup is best-effort: any failure, including the optional `ipinfo`
    package not being importable, is logged and an empty location is
    returned instead of raising.

    Args:
        ip_address: The IP address to get the location for.

    Returns:
        A tuple of city, region, country. Any field that is not available
        in the lookup result is returned as an empty string.
    """
    try:
        # Imported inside the `try` so that a missing optional dependency
        # degrades to an empty location instead of crashing the caller.
        import ipinfo  # type: ignore[import]

        handler = ipinfo.getHandler()
        details = handler.getDetails(ip_address)
        # The lookup result may not carry every field; default each one
        # individually rather than discarding the whole location.
        return (
            getattr(details, "city", None) or "",
            getattr(details, "region", None) or "",
            getattr(details, "country_name", None) or "",
        )
    except Exception:
        logger.exception(f"Could not get IP location for {ip_address}.")
        return "", "", ""
handle_exceptions(func)
Decorator to handle exceptions in the API.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
func |
~F |
Function to decorate. |
required |
Returns:
Type | Description |
---|---|
~F |
Decorated function. |
Source code in zenml/zen_server/utils.py
def handle_exceptions(func: F) -> F:
    """Wrap an API endpoint so that errors are translated for the client.

    Before calling the endpoint, the first `AuthContext` found among its
    arguments (positional arguments first, then keyword argument values) is
    registered via `set_auth_context`. While calling it:
    - an `OAuthError` is turned into a JSON response carrying the error's
      status code and `to_dict()` payload,
    - an `HTTPException` is re-raised unchanged,
    - any other exception is logged and re-raised as the HTTP exception
      produced by `http_exception_from_error`.

    Args:
        func: Function to decorate.

    Returns:
        Decorated function.
    """

    @wraps(func)
    def decorated(*args: Any, **kwargs: Any) -> Any:
        # These imports can't happen at module level as this module is also
        # used by the CLI when installed without the `server` extra
        from fastapi import HTTPException
        from fastapi.responses import JSONResponse

        from zenml.zen_server.auth import AuthContext, set_auth_context

        auth_context = next(
            (
                candidate
                for candidate in (*args, *kwargs.values())
                if isinstance(candidate, AuthContext)
            ),
            None,
        )
        if auth_context is not None:
            set_auth_context(auth_context)

        try:
            return func(*args, **kwargs)
        except OAuthError as oauth_error:
            # OAuth errors must be reported as a JSON body, not raised.
            return JSONResponse(
                status_code=oauth_error.status_code,
                content=oauth_error.to_dict(),
            )
        except HTTPException:
            # Already an HTTP error: let FastAPI handle it as-is.
            raise
        except Exception as unexpected:
            logger.exception("API error")
            raise http_exception_from_error(unexpected)

    return cast(F, decorated)
initialize_zen_store()
Initialize the ZenML Store.
Exceptions:
Type | Description |
---|---|
ValueError |
If the ZenML Store is using a REST back-end. |
Source code in zenml/zen_server/utils.py
def initialize_zen_store() -> None:
    """Set up the SQL-backed ZenML Store used by the server.

    Marks the current process as a server through an environment variable,
    then fetches the store from the global configuration and caches it in
    the module-level `_zen_store`.

    Raises:
        ValueError: If the configured store is not a `SqlZenStore` (e.g. it
            uses a REST back-end).
    """
    global _zen_store

    logger.debug("Initializing ZenML Store for FastAPI...")

    # Flag the instance as a server before the store is accessed.
    os.environ[ENV_ZENML_SERVER] = "true"

    store = GlobalConfiguration().zen_store
    if isinstance(store, SqlZenStore):
        _zen_store = store
        return

    raise ValueError(
        "Server cannot be started with a REST store type. Make sure you "
        "configure ZenML to use a non-networked store backend "
        "when trying to start the ZenML Server."
    )
make_dependable(cls)
This function makes a pydantic model usable for fastapi query parameters.
Additionally, it converts `InternalServerError`s that would happen due to `pydantic.ValidationError` into 422 responses that signal an invalid request.
Check out https://github.com/tiangolo/fastapi/issues/1474 for context.
Usage: `def f(model: Model = Depends(make_dependable(Model))): ...`
Parameters:
Name | Type | Description | Default |
---|---|---|---|
cls |
Type[pydantic.main.BaseModel] |
The model class. |
required |
Returns:
Type | Description |
---|---|
Callable[..., Any] |
Function to use in FastAPI `Depends`. |
Source code in zenml/zen_server/utils.py
def make_dependable(cls: Type[BaseModel]) -> Callable[..., Any]:
    """This function makes a pydantic model usable for fastapi query parameters.

    Additionally, it converts `InternalServerError`s that would happen due to
    `pydantic.ValidationError` into 422 responses that signal an invalid
    request.

    Check out https://github.com/tiangolo/fastapi/issues/1474 for context.

    Usage:
        def f(model: Model = Depends(make_dependable(Model))):
            ...

    Args:
        cls: The model class.

    Returns:
        Function to use in FastAPI `Depends`.
    """

    def init_cls_and_handle_errors(*args: Any, **kwargs: Any) -> BaseModel:
        from fastapi import HTTPException

        try:
            # Check the arguments against the model's signature (assigned to
            # `__signature__` below); `bind` raises `TypeError` on mismatch.
            inspect.signature(init_cls_and_handle_errors).bind(*args, **kwargs)
            return cls(*args, **kwargs)
        except ValidationError as e:
            # Take one snapshot of the errors: depending on the pydantic
            # version, `errors()` may build a fresh list on every call, which
            # would silently discard the `loc` rewrite below.
            errors = e.errors()
            for error in errors:
                # Report the offending fields as query parameters.
                error["loc"] = tuple(["query"] + list(error["loc"]))
            raise HTTPException(422, detail=errors) from e

    # Expose the model's signature so FastAPI derives query parameters from it.
    init_cls_and_handle_errors.__signature__ = inspect.signature(cls)  # type: ignore[attr-defined]
    return init_cls_and_handle_errors
server_config()
Returns the ZenML Server configuration.
Returns:
Type | Description |
---|---|
ServerConfiguration |
The ZenML Server configuration. |
Source code in zenml/zen_server/utils.py
def server_config() -> ServerConfiguration:
    """Return the ZenML Server configuration, loading it on first use.

    The configuration is read once via `ServerConfiguration.get_server_config`
    and cached in the module-level `_server_config` for later calls.

    Returns:
        The ZenML Server configuration.
    """
    global _server_config

    if _server_config is not None:
        return _server_config

    _server_config = ServerConfiguration.get_server_config()
    return _server_config
zen_store()
Return the initialized ZenML Store.
Returns:
Type | Description |
---|---|
SqlZenStore |
The ZenML Store. |
Exceptions:
Type | Description |
---|---|
RuntimeError |
If the ZenML Store has not been initialized. |
Source code in zenml/zen_server/utils.py
def zen_store() -> "SqlZenStore":
    """Return the ZenML Store cached by `initialize_zen_store`.

    Returns:
        The ZenML Store.

    Raises:
        RuntimeError: If the ZenML Store has not been initialized.
    """
    # Read-only access to the module-level cache; no `global` needed.
    store = _zen_store
    if store is None:
        raise RuntimeError("ZenML Store not initialized")
    return store
zen_server_api
Zen Server API.
catch_all(request, file_path)
Dashboard endpoint.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
request |
Request |
Request object. |
required |
file_path |
str |
Path to a file in the dashboard root folder. |
required |
Returns:
Type | Description |
---|---|
Any |
The ZenML dashboard. |
Exceptions:
Type | Description |
---|---|
HTTPException |
404 error if requested a non-existent static file or if the dashboard files are not included. |
Source code in zenml/zen_server/zen_server_api.py
@app.get("/{file_path:path}", include_in_schema=False)
def catch_all(request: Request, file_path: str) -> Any:
    """Serve dashboard assets and single-page application routes.

    Known root-level static files are returned directly. A single-segment
    path without a query string that is not such a file yields a 404. Every
    other path is rendered through the dashboard's `index.html`.

    Args:
        request: Request object.
        file_path: Path to a file in the dashboard root folder.

    Returns:
        The ZenML dashboard.

    Raises:
        HTTPException: 404 error if requested a non-existent static file or if
            the dashboard files are not included.
    """
    dashboard_root = relative_path(DASHBOARD_DIRECTORY)

    # Some static files need to be served directly from the dashboard root.
    if file_path and file_path in root_static_files:
        logger.debug(f"Returning static file: {file_path}")
        return FileResponse(os.path.join(dashboard_root, file_path))

    is_single_segment = "/" not in file_path
    if is_single_segment and not request.query_params:
        logger.debug(f"Requested non-existent static file: {file_path}")
        raise HTTPException(status_code=404)

    if not os.path.isfile(os.path.join(dashboard_root, "index.html")):
        raise HTTPException(status_code=404)

    # Everything else is handed to the single-page application.
    return templates.TemplateResponse("index.html", {"request": request})
dashboard(request)
Dashboard endpoint.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
request |
Request |
Request object. |
required |
Returns:
Type | Description |
---|---|
Any |
The ZenML dashboard. |
Exceptions:
Type | Description |
---|---|
HTTPException |
If the dashboard files are not included. |
Source code in zenml/zen_server/zen_server_api.py
@app.get("/", include_in_schema=False)
def dashboard(request: Request) -> Any:
    """Serve the dashboard entry page at the root URL.

    Args:
        request: Request object.

    Returns:
        The ZenML dashboard.

    Raises:
        HTTPException: If the dashboard files are not included.
    """
    index_file = os.path.join(relative_path(DASHBOARD_DIRECTORY), "index.html")
    if os.path.isfile(index_file):
        return templates.TemplateResponse("index.html", {"request": request})
    raise HTTPException(status_code=404)
get_root_static_files()
Get the list of static files in the root dashboard directory.
These are static files that are not in the /static subdirectory but still need to be served under the root URL path.
Returns:
Type | Description |
---|---|
List[str] |
List of static files in the root directory. |
Source code in zenml/zen_server/zen_server_api.py
def get_root_static_files() -> List[str]:
    """List the regular files placed directly in the dashboard root directory.

    These are static files outside the /static subdirectory that still need
    to be served under the root URL path. `index.html` is excluded because it
    is served by dedicated routes.

    Returns:
        List of static file names in the root directory, or an empty list if
        the dashboard directory does not exist.
    """
    dashboard_root = relative_path(DASHBOARD_DIRECTORY)
    if not os.path.isdir(dashboard_root):
        return []

    return [
        entry
        for entry in os.listdir(dashboard_root)
        if entry != "index.html"
        and isfile(os.path.join(dashboard_root, entry))
    ]
health()
Get health status of the server.
Returns:
Type | Description |
---|---|
str |
String representing the health status of the server. |
Source code in zenml/zen_server/zen_server_api.py
@app.head(HEALTH, include_in_schema=False)
@app.get(HEALTH)
def health() -> str:
    """Health-check endpoint that always reports the server as healthy.

    Returns:
        The string "OK".
    """
    status = "OK"
    return status
infer_source_context(request, call_next)
async
A middleware to track the source of an event.
It extracts the source context from the header of incoming requests and applies it to the ZenML source context on the API side. This way, the outgoing analytics request can append it as an additional field.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
request |
Request |
the incoming request object. |
required |
call_next |
Any |
a function that will receive the request as a parameter and pass it to the corresponding path operation. |
required |
Returns:
Type | Description |
---|---|
Any |
the response to the request. |
Source code in zenml/zen_server/zen_server_api.py
@app.middleware("http")
async def infer_source_context(request: Request, call_next: Any) -> Any:
    """HTTP middleware that records where a request originated.

    The source context is read from the request header named after
    `source_context` (defaulting to the API source) and stored in the ZenML
    source context, so outgoing analytics requests can include it. If the
    header cannot be interpreted, a warning is logged and the API source is
    used.

    Args:
        request: the incoming request object.
        call_next: a function that will receive the request as a parameter and
            pass it to the corresponding path operation.

    Returns:
        the response to the request.
    """
    try:
        header_value = request.headers.get(
            source_context.name,
            default=SourceContextTypes.API.value,
        )
        source_context.set(SourceContextTypes(header_value))
    except Exception as e:
        logger.warning(
            f"An unexpected error occurred while getting the source "
            f"context: {e}"
        )
        source_context.set(SourceContextTypes.API)

    response = await call_next(request)
    return response
initialize()
Initialize the ZenML server.
Source code in zenml/zen_server/zen_server_api.py
@app.on_event("startup")
def initialize() -> None:
    """Startup hook of the ZenML server.

    Registered as a FastAPI `startup` event handler; it initializes the
    ZenML Store used by the API.
    """
    # IMPORTANT: this must run before the FastAPI app starts serving
    # requests, to avoid race conditions.
    initialize_zen_store()
invalid_api(invalid_api_path)
Invalid API endpoint.
All API endpoints that are not defined in the API routers will be redirected to this endpoint and will return a 404 error.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
invalid_api_path |
str |
Invalid API path. |
required |
Exceptions:
Type | Description |
---|---|
HTTPException |
404 error. |
Source code in zenml/zen_server/zen_server_api.py
@app.get(
    API + "/{invalid_api_path:path}", status_code=404, include_in_schema=False
)
def invalid_api(invalid_api_path: str) -> None:
    """Reject requests for API paths that no API router defines.

    Such requests are routed to this catch-all endpoint, which always
    responds with a 404 error.

    Args:
        invalid_api_path: Invalid API path.

    Raises:
        HTTPException: 404 error.
    """
    logger.debug(f"Invalid API path requested: {invalid_api_path}")
    not_found = HTTPException(status_code=404)
    raise not_found
relative_path(rel)
Get the absolute path of a path relative to the ZenML server module.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
rel |
str |
Relative path. |
required |
Returns:
Type | Description |
---|---|
str |
Absolute path. |
Source code in zenml/zen_server/zen_server_api.py
def relative_path(rel: str) -> str:
    """Join a path onto the directory that contains this module.

    Args:
        rel: Path relative to the ZenML server module directory.

    Returns:
        The joined path (absolute whenever the module's `__file__` is).
    """
    module_dir = os.path.dirname(__file__)
    return os.path.join(module_dir, rel)
validation_exception_handler(request, exc)
Custom validation exception handler.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
request |
Any |
The request. |
required |
exc |
Exception |
The exception. |
required |
Returns:
Type | Description |
---|---|
ORJSONResponse |
The error response formatted using the ZenML API conventions. |
Source code in zenml/zen_server/zen_server_api.py
@app.exception_handler(RequestValidationError)
def validation_exception_handler(
    request: Any, exc: Exception
) -> ORJSONResponse:
    """Turn FastAPI request-validation failures into 422 responses.

    The response body is built by `error_detail`, which is given `ValueError`
    as the exception type to report.

    Args:
        request: The request.
        exc: The exception.

    Returns:
        The error response formatted using the ZenML API conventions.
    """
    detail = error_detail(exc, ValueError)
    return ORJSONResponse(detail, status_code=422)